In [1]:
import glob
import pickle
import zipfile
from os.path import join
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import wget
from PIL import Image
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from torch.autograd import Variable
from tqdm import tqdm, trange
# Shared `bar_format` template passed to the tqdm progress bars in this notebook.
# `{bar:10}` renders a 10-character bar; the trailing `{bar:-10b}` is presumably a
# blank-bar padding field -- confirm against the tqdm `bar_format` documentation.
TQDM_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}'

In [2]:
def generate_activations(input_dir, output_dir=join('temp','activations'), center_crop=False):
    """Push every image in a directory through a pretrained ResNet-18 and save the activations.

    Each image is resized (or resized and center-cropped) to 224x224,
    normalized, and run through the network with its last child layer removed.
    The flattened output is written to ``<output_dir>/<image stem>.npy``.
    To use a different network or layer, change ``model`` / ``layer_extractor``.

    Parameters
    ----------
    input_dir : str
        Relative path to input directory of images to be predicted. Only files
        with a .jpg, .jpeg or .png suffix (any letter case) are processed.
    output_dir : str
        Relative path to save intermediate files used in prediction.
        Created if it does not exist. Defaults to 'temp/activations/'.
    center_crop : bool
        If True, crops each image to a square, from the center of the
        image, before processing. If False, images are resized to a square before processing.
        Defaults to False.

    Returns
    -------
    None
        The activations are written to disk as a side effect.
    """
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    # Default input image transformations for ImageNet
    if center_crop:
        scaler = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop((224, 224))
        ])
    else:
        scaler = transforms.Resize((224, 224))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    to_tensor = transforms.ToTensor()

    # Load our pretrained model. NOTE: this is where the model could be extended or feature
    # extraction model changed.
    model = models.resnet18(weights='DEFAULT')
    model.eval()

    # Network up to (but excluding) the last layer. Built once, not once per image.
    layer_extractor = torch.nn.Sequential(*list(model.children())[:-1])

    # np.save does not create missing directories
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    desc = 'Pushing images through CNN'
    # sorted() makes the processing order reproducible across runs and platforms
    image_paths = sorted(glob.glob(join(input_dir,'*')))
    for filename in tqdm(image_paths, bar_format=TQDM_FORMAT, desc=desc):
        if Path(filename).suffix.lower() not in image_suffixes:
            continue

        # The context manager closes the file even if a transform raises
        with Image.open(filename) as img:
            # Normalize expects exactly three channels; grayscale / RGBA files
            # would otherwise fail, so force RGB first.
            t_img = normalize(to_tensor(scaler(img.convert('RGB')))).unsqueeze(0)

        # Inference only: no autograd graph is needed
        with torch.no_grad():
            feature_vec = layer_extractor(t_img).numpy().flatten()

        # Save image activations
        image_name = Path(filename).stem
        np.save(join(output_dir,f'{image_name}.npy'), feature_vec)

In [3]:
# Directory the per-image CNN activations (.npy files) are written to and later loaded from.
output_dir = 'cnn_activations'

In [4]:
# Create the activations directory. Unlike `!mkdir`, this is a no-op when the
# directory already exists, so the cell can be re-run safely on any platform.
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [5]:
# Download and extract stimuli data for training models (341MB)
# output is directory 'presented_stimuli' with subdirectories for each dataset
stimuli_url = "https://figshare.com/ndownloader/files/36563031"
input_dir = 'presented_stimuli'

# Skip the download when the extracted directory is already present, so that
# re-running the notebook does not fetch and unpack the archive again.
# Delete the directory to force a fresh download.
if not Path(input_dir).is_dir():
    filename = wget.download(stimuli_url)

    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall()

100% [......................................................................] 341219259 / 341219259

In [6]:
# Push images through CNN, one dataset subdirectory of `input_dir` at a time.
# Sorting makes the processing order reproducible; stray non-directory entries
# are skipped.
for folder in sorted(glob.glob(f"{input_dir}/*")):
    if not Path(folder).is_dir():
        continue
    generate_activations(folder, output_dir, center_crop=False)

Pushing images through CNN: 100%|██████████| 2000/2000 [02:20<00:00, 14.23it/s]                                                                                                     
Pushing images through CNN: 100%|██████████| 1916/1916 [02:18<00:00, 13.80it/s]                                                                                                     
Pushing images through CNN: 100%|██████████| 1000/1000 [01:12<00:00, 13.73it/s]                                                                                                     


In [7]:
# Loads CNN activations previously saved in specified directory
def load_activations(activations_folder, stimuli_list_path='stimuli_list.pkl'):
    """Load saved CNN activations in the order given by a stimuli list.

    Parameters
    ----------
    activations_folder : str
        Directory containing one ``<image stem>.npy`` file per stimulus.
    stimuli_list_path : str
        Path to the pickled sequence of stimulus file names that defines the
        row order of the result. Defaults to 'stimuli_list.pkl'.

    Returns
    -------
    numpy.ndarray
        The loaded activations, one entry per stimulus that has a saved file,
        in stimuli-list order. Stimuli with no saved file are skipped and
        reported with a warning.
    """
    import warnings

    # NOTE(security): allow_pickle=True unpickles this file, which can execute
    # arbitrary code. Only load a stimuli list from a trusted source.
    stimuli_list = np.load(stimuli_list_path, allow_pickle=True)
    activations = []
    missing = []

    # The bar advances once per stimulus, so its total is the list length
    with tqdm(total=len(stimuli_list), bar_format=TQDM_FORMAT, desc='Loading activations') as pbar:
        for image_name in stimuli_list:
            # Direct path lookup (not a glob pattern), so stems containing
            # characters such as '[' or '*' are matched literally.
            activation_path = Path(activations_folder) / f'{Path(image_name).stem}.npy'
            if activation_path.is_file():
                # Plain numeric arrays: no pickle support needed
                activations.append(np.load(str(activation_path)))
            else:
                missing.append(image_name)
            pbar.update(1)

    if missing:
        # A skipped stimulus shifts every later row, so make it visible
        warnings.warn(
            f'{len(missing)} of {len(stimuli_list)} stimuli have no saved activation in '
            f'{activations_folder!r} (first missing: {missing[0]!r}); the returned rows '
            'no longer line up one-to-one with the stimuli list.'
        )

    return np.asarray(activations)

In [8]:
# Collect the activations saved in `output_dir` into a single array, ordered by the stimuli list.
activations = load_activations(output_dir)

Loading activations: 100%|██████████| 4916/4916 [00:01<00:00, 3893.31it/s]                                                                                                          


In [9]:
# Download fMRI data for training models (282MB)
fmri_url = "https://figshare.com/ndownloader/files/34907763"
fmri_filename = 'bold5000_reordered_data.npy'

# Skip the download when the file is already present, so that re-running the
# notebook does not fetch it again. Delete the file to force a fresh download.
if not Path(fmri_filename).exists():
    wget.download(fmri_url, out=fmri_filename)

100% [......................................................................] 281878296 / 281878296

'bold5000_reordered_data.npy'

In [10]:
# Load the downloaded fMRI data; it is indexed below as fmri_raw[subject][roi].
# NOTE: allow_pickle=True lets numpy unpickle object arrays, which can execute
# arbitrary code -- only use it on a trusted file.
fmri_raw = np.load('bold5000_reordered_data.npy', allow_pickle=True)

In [11]:
# Rearrange fmri data into a (num_rois, num_subjects) object array. Each cell
# holds one 2-D array for that ROI/subject pair, z-scored along axis 1
# (presumably num_voxels x num_samples -- confirm against the data layout).
fmri_preprocessed = np.empty((5, 3), dtype=object)

# np.ndindex visits (roi, subject) pairs in row-major order: the ROI index
# varies slowest and the subject index fastest.
for roi_idx, sub_idx in np.ndindex(*fmri_preprocessed.shape):
    stacked = np.vstack(fmri_raw[sub_idx][roi_idx]).T
    fmri_preprocessed[roi_idx, sub_idx] = np.asarray(stats.zscore(stacked, axis=1))

In [12]:
# Create the directory the fitted models are saved into. Unlike `!mkdir`, this
# is a no-op when the directory already exists, so the cell can be re-run safely.
Path('models').mkdir(parents=True, exist_ok=True)

In [13]:
# Fit one grid-searched ridge regression per (ROI, subject) pair and pickle it.
# Models are saved for all five ROIs specified with the BOLD5000 dataset,
# although only LOC, RSC, and PPA end up being used in our analyses.
roi_list = ["EarlyVis","OPA", "LOC", "RSC", "PPA"]

# Candidate regularization strengths: 10 log-spaced values between 1e1 and 1e5
ridge_p_grid = {'alpha': np.logspace(1, 5, 10)}
save_location = "models/"

# The predictors are the same CNN activations for every model
X_train = activations

for roi_idx, roi in enumerate(roi_list):
    for subj in range(3):
        # Transpose the stored array for this ROI/subject before fitting
        y_train = fmri_preprocessed[roi_idx, subj].T

        grid = GridSearchCV(Ridge(), ridge_p_grid)
        grid.fit(X_train, y_train)

        # Persist only the best estimator found by the grid search
        pkl_filename = f'{save_location}subj{subj+1}_{roi}_model.pkl'
        with open(pkl_filename, 'wb') as model_file:
            pickle.dump(grid.best_estimator_, model_file)

        print(f"ROI: {roi} for subject{subj+1} saved")

ROI: EarlyVis for subject1 saved
ROI: EarlyVis for subject2 saved
ROI: EarlyVis for subject3 saved
ROI: OPA for subject1 saved
ROI: OPA for subject2 saved
ROI: OPA for subject3 saved
ROI: LOC for subject1 saved
ROI: LOC for subject2 saved
ROI: LOC for subject3 saved
ROI: RSC for subject1 saved
ROI: RSC for subject2 saved
ROI: RSC for subject3 saved
ROI: PPA for subject1 saved
ROI: PPA for subject2 saved
ROI: PPA for subject3 saved
