## In the following cell, change the various paths to the ones corresponding to your system

In [None]:
#### CHANGE THE PATHS TO CORRESPOND TO THE PATHS ON YOUR SYSTEM
### The paths that should be changed are preceded by <--- CHANGE --->

def initialize(movie):
    """Create the output folders for *movie* and return its path configuration.

    Parameters
    ----------
    movie : str
        One of ``'sherlock'``, ``'merlin'`` or ``'twilight-zone'``.

    Returns
    -------
    tuple
        ``(local_movie_path, generic_mask_name, generic_filename, anat_filename,
        feature_folder, result_folder, sub_values)``.
        ``generic_filename`` and ``anat_filename`` are format strings filled with
        the subject id. ``generic_mask_name`` is filled with the subject id twice
        (folder and file), or is ``""`` when the dataset ships no masks.
        ``result_folder`` is filled with the layer id.
        ``sub_values`` is ``(min_sub, max_sub, null_sub)``.

    Raises
    ------
    ValueError
        If *movie* is not a supported name. This is checked before any folder
        is created.
    """
    # Validate first so that an unknown name does not leave empty folders behind.
    if movie not in ("sherlock", "merlin", "twilight-zone"):
        raise ValueError("The movie name has to be 'twilight-zone', 'merlin' or 'sherlock'")

    # Generate folders for organized storage
    for folder in ["fmri_mean_{}", "wards_{}", "fmri_ready_{}", "results"]:
        folder_name = folder.format(movie)
        try:
            os.mkdir(folder_name)
        except FileExistsError:
            print("Folder already exists, skipping creation ({})".format(folder_name))
        else:
            print("Folder created  ({})".format(folder_name))

    # One result folder per layer id 1..7 (created inside "results", made above)
    for id_layer in range(1, 8):
        layer_folder = "results/{}_conv{}".format(movie, id_layer)
        try:
            os.mkdir(layer_folder)
        except FileExistsError:
            print("Folder already exists, skipping creation ({})".format(layer_folder))
        else:
            print("Folder created ({})".format(layer_folder))

    # Set the different paths to data and folders
    if movie in ("sherlock", "merlin"):
        # Location of the dataset
        # <--- CHANGE --->
        local_movie_path = "/home/brain/datasets/SherlockMerlin_ds001110/"

        # mask name
        movie_mask = "{}Movie_bold_space-T1w_brainmask.nii.gz".format(movie.capitalize())

        # Generic name of the brain masks (non MNI) from the dataset
        # <--- CHANGE --->
        generic_mask_name = "/home/brain/datasets/SherlockMerlin_ds001110/sub-{:02d}/func/sub-{:02d}_task-" + movie_mask

        # fmri file name
        generic_filename = "sub-{:02d}_task-" + "{}Movie_bold_space-T1w_preproc.nii.gz".format(movie.capitalize())

        # anat file name
        # <--- CHANGE --->
        anat_filename = "/home/brain/victor/datasets/fmri_anat_{}/".format(movie) + "sub-{:02d}.nii.gz"
    else:  # twilight-zone
        # fmri data location
        # <--- CHANGE --->
        local_movie_path = "/home/brain/victor/datasets/twilight-zone"

        # there are no pre-existing masks
        generic_mask_name = ""

        # fmri file name
        generic_filename = "sub-{:02d}_task-watchmovie_bold.nii.gz"

        # anat file name. It must be defined on this branch too, otherwise the
        # return below raises UnboundLocalError.
        # NOTE(review): assumed to follow the same layout as the Sherlock/Merlin
        # anatomical images; adjust to your system.
        # <--- CHANGE --->
        anat_filename = "/home/brain/victor/datasets/fmri_anat_{}/".format(movie) + "sub-{:02d}.nii.gz"

    # Folder containing the feature vectors extracted from soundnet for the
    # corresponding movie (e.g. merlin_pytorch or sherlock_pytorch)
    # <--- CHANGE --->
    feature_folder = "soundnet_features/{}_pytorch/".format(movie)

    # Folder for storing the resulting r2 brain maps; format it with the layer id
    result_folder = "results/{}_".format(movie) + "conv{}/"

    # Configure the subjects corresponding to the movie as (min_sub, max_sub, null_sub):
    # Sherlock : [1,18]\{5}
    # Merlin : [19,37]\{25}
    # Twilight-zone : [1,25]
    # NOTE(review): the notebook iterates with range(min_sub, max_sub), which
    # excludes max_sub, while the intervals above read as closed. Confirm which
    # is intended.
    id_subjects = {
        "sherlock": (1, 18, 5),
        "merlin": (19, 37, 25),
        "twilight-zone": (1, 25, 0),
    }
    sub_values = id_subjects[movie]

    return local_movie_path, generic_mask_name, generic_filename, anat_filename, feature_folder, result_folder, sub_values


In [None]:
# nilearn imports
from nilearn.plotting import plot_roi, plot_stat_map, plot_anat, view_img, cm
from nilearn.image import mean_img

In [None]:
### Silence noisy warnings (e.g. from scikit-learn); note that this disables every call made through warnings.warn

import warnings


def warn(*_args, **_kwargs):
    """Drop-in no-op replacement for ``warnings.warn``.

    Accepts and ignores every positional and keyword argument, and returns None.
    """
    return None


# Replace the module attribute so that callers resolving ``warnings.warn`` at
# call time are silenced.
# NOTE(review): code that already did ``from warnings import warn`` keeps the
# original function.
warnings.warn = warn

In [None]:
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

In [None]:
# sklearn imports
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# utility imports
import numpy as np
from matplotlib import pyplot as plt
import os.path
from joblib import dump, load
import tqdm

In [None]:
class Dataset(data.Dataset):
    """Map-style PyTorch dataset pairing feature vectors with their targets.

    Note the constructor order is ``(labels, fv_vector)``, while items are
    returned as ``(features, label)``.

    Parameters
    ----------
    labels : indexable
        Targets; ``labels[i]`` is the target of sample ``i``.
    fv_vector : indexable
        Input feature vectors; its length defines the dataset size.
    """

    def __init__(self, labels, fv_vector):
        self.labels, self.vectors = labels, fv_vector

    def __len__(self):
        """Return the number of feature vectors."""
        n_samples = len(self.vectors)
        return n_samples

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at *index*."""
        sample = (self.vectors[index], self.labels[index])
        return sample


In [None]:
from nilearn.regions import Parcellations
from nilearn.input_data import NiftiMasker
from nilearn.masking import compute_background_mask
import os

def parcellate(id_subject, n_frames, compute_mean = True, compute_ward = True, compute_ready = True):
    """Compute and save the mean image, ward parcellation and parcel time series of a subject.

    Each artefact is recomputed when its ``compute_*`` flag is True or when its
    file is missing; otherwise the saved version is reused. Outputs go to the
    ``fmri_mean_<movie>``, ``wards_<movie>`` and ``fmri_ready_<movie>`` folders.

    Relies on the module-level globals ``movie``, ``local_movie_path``,
    ``generic_filename`` and ``generic_mask_name``.

    Parameters
    ----------
    id_subject : int
        Subject number (formatted as ``sub-XX``).
    n_frames : int
        Number of fMRI frames to keep after the acquisition offset.
    compute_mean, compute_ward, compute_ready : bool
        Force recomputation of the corresponding artefact.
    """
    folder_name = "sub-{:02d}/func".format(id_subject)
    subject_filename = generic_filename.format(id_subject)
    fmri_file = os.path.join(local_movie_path, folder_name, subject_filename)

    mean_path = "fmri_mean_{}/sub-{:02d}.nii.gz".format(movie, id_subject)
    # NOTE: the ward file is a joblib dump despite its .nii.gz extension;
    # load_fmri_data reads the same path back with joblib.
    ward_path = "wards_{}/sub-{:02d}.nii.gz".format(movie, id_subject)
    ready_path = "fmri_ready_{}/sub-{:02d}.npy".format(movie, id_subject)

    # Compute the mean of fmri across time
    if compute_mean or not os.path.isfile(mean_path):
        fmri_mean = mean_img(fmri_file)
        fmri_mean.to_filename(mean_path)
        print("Saved mean fmri for subject {}".format(id_subject))
    else:
        print("Mean fmri already exists for subject {}".format(id_subject))

    # Compute ward parcellation
    if compute_ward or not os.path.isfile(ward_path):
        # The mask is only needed to fit the parcellation, so the (expensive)
        # background mask is not computed when the ward is loaded from disk.
        if movie == "twilight-zone":
            print("Computing background mask")
            mask_img = compute_background_mask(fmri_file)
        else:
            print("Loading pre-generated mask")
            mask_img = generic_mask_name.format(id_subject, id_subject)
        masker = NiftiMasker(mask_img=mask_img, detrend=True, standardize=True)
        masker.fit()
        ward = Parcellations(method='ward', mask=masker, standardize=True, smoothing_fwhm=None, n_parcels=500)
        ward.fit(fmri_file)
        dump(ward, ward_path)
        print("Saved ward mask for subject {}".format(id_subject))
    else:
        ward = load(ward_path)
        print("Ward mask exists for subject {}".format(id_subject))

    # Compute fmri_ready
    if compute_ready or not os.path.isfile(ready_path):
        print("fmri_file : ", fmri_file)
        fmri_data = ward.transform(fmri_file)
        # Truncate the data because of an offset in the fmri (see dataset description)
        if movie in ("merlin", "sherlock"):
            # 25 seconds of offset, and 1.5s per frame
            offset = 17
        else:
            # 15 TR of offset
            offset = 15
        # An explicit end index avoids the [offset:-remaining] form, which yields
        # an empty array when remaining == 0 and wrong data when it is negative.
        fmri_ready = fmri_data[offset:offset + n_frames]
        if fmri_ready.shape[0] < n_frames:
            print("Warning: only {} fmri frames available after the offset for subject {} ({} expected)".format(
                fmri_ready.shape[0], id_subject, n_frames))
        np.save(ready_path, fmri_ready)
        print("Saved fmri_ready for subject {}".format(id_subject))
    else:
        print("fmri_ready exists for subject {} and movie {}".format(id_subject, movie))

In [None]:
def load_feature_vector(id_layer) :
    """Load and standardize the SoundNet features of one layer.

    Reads ``conv<id_layer>.npz`` from the module-level ``feature_folder`` and
    takes its ``'fv'`` array without its first row. Each feature column is
    standardized with ``StandardScaler``.

    Returns
    -------
    fv_normalized : numpy.ndarray
        Standardized feature vectors.
    n_frames : int
        Number of frames (rows) kept.
    """
    file_fv = os.path.join(feature_folder, "conv{}.npz".format(id_layer))
    # np.load on an .npz returns an NpzFile that keeps the file open; the context
    # manager closes it once the array has been read into memory.
    with np.load(file_fv) as npz_file:
        # The first row is dropped. load_fmri_data also drops its first row,
        # presumably to keep both aligned (TODO confirm).
        fv = npz_file['fv'][1:]
    n_frames = fv.shape[0]
    fv_normalized = StandardScaler().fit_transform(fv)
    print("layer {}, {} frames, FV dimension is {}".format(id_layer, n_frames, fv.shape[1]))
    return fv_normalized, n_frames

In [None]:
def load_fmri_data(id_subject):
    """Load the preprocessed outputs of ``parcellate`` for one subject.

    Uses the module-level ``movie`` global to locate the files.

    Returns
    -------
    fmri_ready : numpy.ndarray
        Saved parcel time series, without its first row.
    fmri_mean_path : str
        Path of the mean fMRI image (only the path; the image is not loaded).
    ward : object
        Ward parcellation restored with joblib ``load``.
    """
    ready_file = "fmri_ready_{}/sub-{:02d}.npy".format(movie, id_subject)
    mean_file = "fmri_mean_{}/sub-{:02d}.nii.gz".format(movie, id_subject)
    ward_file = "wards_{}/sub-{:02d}.nii.gz".format(movie, id_subject)

    # The first row is dropped. load_feature_vector also drops its first row,
    # presumably to keep both aligned (TODO confirm).
    time_series = np.load(ready_file)[1:]
    print("Loaded {}, with length : {}".format(ready_file, len(time_series)))
    print("Loaded {}".format(mean_file))

    parcellation = load(ward_file)
    print("Loaded {}".format(ward_file))

    return time_series, mean_file, parcellation

In [None]:
def load_data(id_layer,id_subject):
    """Load the features of one layer and the fMRI data of one subject.

    Returns
    -------
    tuple
        ``(X, y, fmri_mean, ward)``. ``X`` comes from
        ``load_feature_vector(id_layer)``; ``y``, ``fmri_mean`` (a path) and
        ``ward`` come from ``load_fmri_data(id_subject)``.
    """
    # Unused requires_grad tensor copies of X and y were removed: they were
    # never returned and only duplicated both arrays in memory.
    X, _ = load_feature_vector(id_layer)
    y, fmri_mean, ward = load_fmri_data(id_subject)
    return X, y, fmri_mean, ward

In [None]:
class Net(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The defaults reproduce the original architecture: 1024 inputs, 1000 hidden
    units and 500 outputs (matching ``n_parcels=500`` used for the ward
    parcellation). Outputs are always returned on the CPU.
    """

    def __init__(self, n_input=1024, n_hidden=1000, n_output=500):
        super(Net, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(n_input, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        # Move the inputs to wherever the parameters live instead of relying on
        # a module-level ``device`` global. This is identical to the old
        # behaviour when ``net.to(device)`` was called.
        x = x.to(self.fc1.weight.device)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.cpu()

    def num_flat_features(self, x):
        """Return the number of elements per sample (all dims except the batch dim)."""
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features


In [None]:
def _mean_loss(net, generator, criterion, optimizer=None):
    """Run *net* over every mini-batch of *generator* and return the per-sample mean loss.

    With an *optimizer* the pass trains the network (one step per mini-batch);
    without one, gradients are disabled and the pass only evaluates.
    *criterion* must return a batch-averaged loss (as ``nn.MSELoss()`` does by
    default), so each batch is weighted by its size. This gives a true average
    even when the last batch is smaller. Returns ``nan`` for an empty generator.
    """
    training = optimizer is not None
    total_loss = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for local_batch, local_labels in generator:
            if training:
                optimizer.zero_grad()   # zero the gradient buffers
            # .float() keeps float64 feature arrays from reaching float32 layers
            output = net(local_batch.float())
            loss = criterion(output, local_labels.float())
            if training:
                loss.backward()
                optimizer.step()   # Does the update
            batch_size = local_batch.shape[0]
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    return total_loss / n_samples if n_samples else float("nan")


def train(id_subject, tol = 1e-4, n_iter_no_change=10) :
    """Cross-validate an MLP predicting one subject's parcel time series from conv7 features.

    For each split of the module-level ``cv``:
    - a fresh ``Net`` is trained with Adam on 90% of the training indices;
    - the other 10% is a validation set used for early stopping;
    - the net is evaluated on the test fold;
    - the clipped per-parcel R2 map is saved as NIfTI and PNG under
      ``result_folder.format(7)``.

    Reported losses are per-sample means of ``nn.MSELoss``.

    Relies on module-level globals set by the notebook: ``cv``, ``params``,
    ``max_epochs``, ``device``, ``optim``, ``result_folder`` and ``anat_filename``.

    Parameters
    ----------
    id_subject : int
        Subject to train on.
    tol : float, optional, default 1e-4
        Tolerance for the optimization. When the validation loss does not
        improve by at least ``tol`` for more than ``n_iter_no_change``
        consecutive epochs, convergence is considered reached and training stops.
    n_iter_no_change : int, optional, default 10
        Maximum number of epochs allowed to miss the ``tol`` improvement.
    """
    X, y, _, ward = load_data(id_layer = 7, id_subject = id_subject)
    result_folder_layer = result_folder.format(7)

    for fold_number, (train_index, test_index) in enumerate(cv.split(X), start=1):
        ### Create a NEW neural network for every fold
        net = Net()
        net.to(device)

        # Test fold from cv; 10% of the remaining samples are held out for validation
        X_train, y_train = X[train_index], y[train_index]
        X_test, y_test = X[test_index], y[test_index]
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size = 0.1)
        training_generator = data.DataLoader(Dataset(y_train, X_train), **params)
        validation_generator = data.DataLoader(Dataset(y_val, X_val), **params)
        testing_generator = data.DataLoader(Dataset(y_test, X_test), **params)

        optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay = 0.0001)
        criterion = nn.MSELoss()

        train_losses = []
        validation_losses = []
        best_validation_loss = float("inf")
        no_improvement_count = 0

        for epoch in range(1, max_epochs + 1):
            train_losses.append(_mean_loss(net, training_generator, criterion, optimizer))
            validation_loss = _mean_loss(net, validation_generator, criterion)
            validation_losses.append(validation_loss)

            # Early stopping: count consecutive epochs without a gain of at least tol
            if validation_loss > (best_validation_loss - tol):
                no_improvement_count += 1
            else:
                no_improvement_count = 0
            if validation_loss < best_validation_loss:
                best_validation_loss = validation_loss
            if epoch % 5 == 0:
                print("Fold {} epoch {}".format(fold_number, epoch))
            if no_improvement_count > n_iter_no_change:
                # The loss hasn't improved by 'tol' for more than 'n_iter_no_change' epochs
                print("Early stopping at epoch {}".format(epoch))
                break

        test_loss = _mean_loss(net, testing_generator, criterion)
        print("Fold {} testing loss : {}".format(fold_number, test_loss))

        # Per-parcel R2 on the test fold; negative scores are clipped to 0
        with torch.no_grad():
            predictions = net(torch.tensor(X_test, dtype=torch.float)).numpy()
        scores_mlp = r2_score(y_test, predictions, multioutput='raw_values')
        scores_mlp[scores_mlp < 0] = 0
        scores_img = ward.inverse_transform(scores_mlp.reshape((1, -1)))
        scores_img.to_filename("./{}/fold{}_sub{:02d}.nii.gz".format(result_folder_layer, fold_number, id_subject))
        fmri_anat = anat_filename.format(id_subject)
        plot_stat_map(scores_img, bg_img = fmri_anat, dim = -0.5)
        plt.savefig("./{}/fold{}_sub{:02d}.png".format(result_folder_layer, fold_number, id_subject))
        plt.show()

        # Plot the loss curves of this fold
        x_axis = list(range(len(train_losses)))
        plt.plot(x_axis, train_losses)
        plt.plot(x_axis, validation_losses)
        plt.plot(x_axis, [test_loss] * len(x_axis))
        plt.legend(["train loss", "validation loss", "test loss"])
        plt.show()
    

In [None]:
### SHERLOCK

# Movie has to be 'sherlock', 'merlin' or 'twilight-zone'
movie = 'sherlock'
# Use the first GPU when available; otherwise fall back to the CPU instead of
# failing on the first transfer to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
local_movie_path, generic_mask_name, generic_filename, anat_filename, feature_folder, result_folder, sub_values = initialize(movie)
min_sub, max_sub, null_sub = sub_values
# Subjects to process (null_sub is excluded).
# NOTE(review): range() excludes max_sub; confirm this against the subject
# intervals documented in initialize.
subjects = [id_subject for id_subject in range(min_sub, max_sub) if id_subject != null_sub]

# Compute parcellation if needed (flags False: only missing files are computed)
_, n_frames = load_feature_vector(7)
for id_subject in subjects:
    parcellate(id_subject, n_frames, False, False, False)

# Train the network
import torch.optim as optim
n_folds = 4
cv = KFold(n_splits = n_folds)

# Training settings read by train() as globals
max_epochs = 1000
params = {
    'batch_size': 32,
    'shuffle': False,
    'num_workers': 6
}
for id_subject in subjects:
    train(id_subject)

In [None]:
### MERLIN

# Movie has to be 'sherlock', 'merlin' or 'twilight-zone'
movie = 'merlin'
# Use the first GPU when available; otherwise fall back to the CPU instead of
# failing on the first transfer to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
local_movie_path, generic_mask_name, generic_filename, anat_filename, feature_folder, result_folder, sub_values = initialize(movie)
min_sub, max_sub, null_sub = sub_values
# Subjects to process (null_sub is excluded).
# NOTE(review): range() excludes max_sub; confirm this against the subject
# intervals documented in initialize.
subjects = [id_subject for id_subject in range(min_sub, max_sub) if id_subject != null_sub]

# Compute parcellation if needed (flags False: only missing files are computed)
_, n_frames = load_feature_vector(7)
for id_subject in subjects:
    parcellate(id_subject, n_frames, False, False, False)

# Train the network
import torch.optim as optim
n_folds = 4
cv = KFold(n_splits = n_folds)

# Training settings read by train() as globals
max_epochs = 1000
params = {
    'batch_size': 32,
    'shuffle': False,
    'num_workers': 6
}
for id_subject in subjects:
    train(id_subject)