In [1]:
import glob
from pathlib import Path
from PIL import Image
from tqdm import tqdm, trange
import pickle

import numpy as np
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt

import nibabel as nib
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable

In [14]:
# TODO:
# - add comments / docstrings
# - directory handling (e.g. create or validate input/output directories?)
# - upload stimuli

In [None]:
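# NOTE: This import only works once the imgtofmri package has been pip-installed; the local generate_activations defined below shadows it.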
from imgtofmri import generate_activations

In [2]:
def generate_activations(input_dir, output_dir=""):
    """Save ResNet-18 penultimate-layer features for every image in a folder.

    Each supported image directly inside ``input_dir`` is resized to 224x224,
    converted to a tensor, normalised with the ImageNet channel statistics,
    passed through a pretrained ResNet-18 with its final layer removed, and the
    flattened result is saved as ``<output_dir>/<image stem>.npy``.

    Args:
        input_dir: Folder whose direct children are scanned for images.
        output_dir: Destination folder for the ``.npy`` feature files. Created
            if it does not exist. An empty string selects ``temp/activations/``.

    Files whose extension is not .jpg/.jpeg/.png (checked case-insensitively)
    are skipped with a printed message. Files with a valid extension that PIL
    cannot open still raise.
    """
    if output_dir == "":
        output_dir = "temp/activations/"
    # np.save does not create parent directories, so make sure the target exists.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Default input image transformations for ImageNet:
    # resize -> tensor in [0, 1] -> per-channel normalisation.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Load the pretrained model in eval mode and drop its last child module.
    # The truncated network is built once here rather than once per image.
    # NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
    # `weights=...`; kept for compatibility with the installed version.
    model = models.resnet18(pretrained=True)
    model.eval()
    layer_extractor = torch.nn.Sequential(*list(model.children())[:-1])

    allowed_suffixes = {".jpg", ".jpeg", ".png"}

    for filename in tqdm(glob.glob(f"{input_dir}/*"), desc='Pushing images through CNN'):
        suffix = Path(filename).suffix
        if suffix.lower() not in allowed_suffixes:
            print(f"skipping {filename} with suffix: {suffix}")
            continue

        # The context manager closes the file even if preprocessing raises.
        # Converting to RGB guarantees the 3 channels that the normalisation
        # and the network expect (grayscale/palette/RGBA/CMYK would not fit).
        with Image.open(filename) as img:
            t_img = preprocess(img.convert("RGB")).unsqueeze(0)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            feature_vec = layer_extractor(t_img).numpy().flatten()

        np.save(Path(output_dir) / f"{Path(filename).stem}.npy", feature_vec)

In [3]:
# Folder that receives one .npy feature vector per stimulus image.
output_dir = "cnn_activations"

In [4]:
# Run the feature extractor over every entry of presented_stimuli/
# (expected to be one folder per stimulus set).
stimulus_folders = glob.glob("presented_stimuli/*")
for stimulus_folder in stimulus_folders:
    generate_activations(stimulus_folder, output_dir)

Pushing images through CNN: 100%|██████████| 2000/2000 [02:57<00:00, 11.27it/s]
Pushing images through CNN: 100%|██████████| 1916/1916 [02:50<00:00, 11.21it/s]
Pushing images through CNN: 100%|██████████| 1000/1000 [01:26<00:00, 11.58it/s]


In [5]:
def load_activations(activations_folder, stimuli_list_path='stimuli_list.pkl'):
    """Stack saved activation vectors in the order given by the stimulus list.

    Args:
        activations_folder: Folder holding one ``<image stem>.npy`` file per
            image, the layout written by ``generate_activations``.
        stimuli_list_path: Pickled sequence of image file names that defines the
            row order. Defaults to ``stimuli_list.pkl`` in the working directory.

    Returns:
        ``np.ndarray`` with one row per entry of the stimulus list, in list order.

    Raises:
        FileNotFoundError: If any listed stimulus has no activation file.
            Skipping it silently would shift every later row out of alignment
            with the stimulus list (and any data ordered like it).
    """
    # NOTE(security): allow_pickle=True unpickles the file, which can execute
    # arbitrary code; only use this on trusted files.
    stimuli_list = np.load(stimuli_list_path, allow_pickle=True)
    activations_dir = Path(activations_folder)

    activations = []
    missing = []
    # Iterate the list itself so the progress total matches the work done.
    for image_name in tqdm(stimuli_list, desc='Loading activations'):
        # Build the path directly instead of globbing, so names containing
        # glob metacharacters ([, *, ?) still resolve.
        activation_path = activations_dir / f"{Path(image_name).stem}.npy"
        if not activation_path.is_file():
            missing.append(str(activation_path))
            continue
        activations.append(np.load(activation_path))

    if missing:
        raise FileNotFoundError(
            f"{len(missing)} stimuli have no activation file, e.g. {missing[:3]}"
        )

    return np.asarray(activations)

In [6]:
activations = load_activations(output_dir)

Loading activations: 100%|██████████| 4916/4916 [00:03<00:00, 1490.62it/s]


In [7]:
# Raw responses from bold5000_reordered_data.pkl; indexed later as fmri_raw[subject][roi].
# NOTE: allow_pickle unpickles the file — only load trusted data.
fmri_raw = np.load(file='bold5000_reordered_data.pkl', allow_pickle=True)

In [12]:
# Build the regression targets: for every (ROI, subject) pair, stack the
# responses, transpose (assumed to give num_voxels x num_samples — confirm
# against fmri_raw's layout) and z-score each row across samples.
N_ROIS, N_SUBJECTS = 5, 3  # the first 5 ROIs and first 3 subjects of fmri_raw
fmri_preprocessed = np.empty((N_ROIS, N_SUBJECTS), dtype=object)
for roi_idx in range(N_ROIS):
    for sub_idx in range(N_SUBJECTS):
        sub_roi = np.vstack(fmri_raw[sub_idx][roi_idx]).T
        # A row with no variation has std 0, so zscore yields NaN (or garbage)
        # for it, and Ridge later rejects NaN input. Such rows carry no
        # signal, so set them to 0.
        constant_rows = np.ptp(sub_roi, axis=1) == 0
        with np.errstate(invalid="ignore", divide="ignore"):
            sub_roi = stats.zscore(sub_roi, axis=1)
        sub_roi[constant_rows] = 0.0
        fmri_preprocessed[roi_idx, sub_idx] = sub_roi

In [9]:
# fmri_preprocessed is a (num_rois x num_subjects) object array; each entry is a (num_voxels x num_samples) matrix.

In [10]:
# Directory for the pickled ridge models. Unlike `!mkdir`, this is idempotent
# (no error when re-run) and portable.
Path("models").mkdir(exist_ok=True)

In [13]:
# Fit one cross-validated ridge regression per (ROI, subject) that maps the CNN
# activations to that subject's ROI responses, and pickle the best estimator.
roi_list = ["EarlyVis", "OPA", "LOC", "RSC", "PPA"]

ridge_p_grid = {'alpha': np.logspace(1, 5, 10)}  # 10 log-spaced alphas, 1e1 .. 1e5
save_location = "models/"

# Assumes fmri_raw's ROI order matches roi_list — confirm. At least the counts must agree.
assert len(roi_list) == fmri_preprocessed.shape[0], "roi_list does not match the ROI axis of fmri_preprocessed"
Path(save_location).mkdir(parents=True, exist_ok=True)

# The same feature matrix is the design matrix for every model.
X_train = activations

for roi_idx, roi in enumerate(roi_list):
    for subj in range(fmri_preprocessed.shape[1]):
        # Stored as (num_voxels x num_samples); Ridge wants samples as rows.
        y_train = fmri_preprocessed[roi_idx, subj].T

        grid = GridSearchCV(Ridge(), ridge_p_grid)
        grid.fit(X_train, y_train)

        pkl_filename = Path(save_location) / f"subj{subj+1}_{roi}_model.pkl"
        with open(pkl_filename, 'wb') as fh:
            pickle.dump(grid.best_estimator_, fh)
        print(f"ROI: {roi} for subject{subj+1} saved")

ROI: EarlyVis for subject1 saved
ROI: EarlyVis for subject2 saved
ROI: EarlyVis for subject3 saved
ROI: OPA for subject1 saved
ROI: OPA for subject2 saved
ROI: OPA for subject3 saved
ROI: LOC for subject1 saved
ROI: LOC for subject2 saved
ROI: LOC for subject3 saved
ROI: RSC for subject1 saved
ROI: RSC for subject2 saved
ROI: RSC for subject3 saved
ROI: PPA for subject1 saved
ROI: PPA for subject2 saved
ROI: PPA for subject3 saved
