In [2]:
from pathlib import Path
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore")

In [3]:
data_base = Path("shared/NSD")
curr_subj = "subj01"
subj_behav_file = data_base / f"nsddata/ppdata/{curr_subj}/behav/responses.tsv"
subj_behav = pd.read_csv(subj_behav_file, delimiter="\t")

In [4]:
# add within session iteration indices
for sess_id in range(1, 40+1):
    subj_behav.loc[subj_behav["SESSION"]==sess_id, "INSESSIONIDX"] = list(range(1, 750+1))

In [5]:
subj_behav_sel = subj_behav.dropna(subset=["BUTTON"])
subj_behav_sel = subj_behav_sel[~subj_behav_sel["SESSION"].isin([38, 39, 40])]

In [7]:
# take the first saved data
betas = np.load("./nsd-data/MTL-1_full.npy")

In [8]:
X = betas[subj_behav_sel.index, :] / 300
y = subj_behav_sel["ISOLDCURRENT"].to_numpy()

In [9]:
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.metrics import confusion_matrix, classification_report

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

In [11]:
svc = LinearSVC(class_weight="balanced", verbose=10)
svc.fit(X_train, y_train)
y_pred = svc.predict(X_test)

[LibLinear]....................................................................................................
optimization finished, #iter = 1000

Using -s 2 may be faster (also see FAQ)

Objective value = -1179.217130
nSV = 18539


In [12]:
conf_mat = confusion_matrix(y_test, y_pred)
conf_mat

array([[5217, 1779],
       [1510,  648]])

In [13]:
classification_report(y_test, y_pred, output_dict=True)

{'0': {'precision': 0.7755314404638026,
  'recall': 0.7457118353344768,
  'f1-score': 0.7603293740435765,
  'support': 6996},
 '1': {'precision': 0.2669962917181706,
  'recall': 0.3002780352177943,
  'f1-score': 0.2826608505997819,
  'support': 2158},
 'accuracy': 0.6407035175879398,
 'macro avg': {'precision': 0.5212638660909866,
  'recall': 0.5229949352761356,
  'f1-score': 0.5214951123216792,
  'support': 9154},
 'weighted avg': {'precision': 0.6556473623566282,
  'recall': 0.6407035175879398,
  'f1-score': 0.6477219157093281,
  'support': 9154}}