In [1]:
import os, torch, torchvision, pandas, pickle
import numpy as np

from torch import nn
from torchvision.models import feature_extraction
from torch.utils.data import Dataset, DataLoader

#ridge regression
from sklearn.metrics import r2_score
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import GroupKFold

import sys
sys.path.append('/home/maellef/git/cNeuromod_encoding_2020')
from models import encoding_models as encod

#visualisation
from matplotlib import pyplot as plt
from matplotlib import colors, colormaps
#brain visualization import
from nilearn import regions, datasets, surface, plotting, image, maskers
from nilearn.plotting import plot_roi, plot_stat_map
# MIST parcellation label image (ROI resolution); surface_fig maps one value per parcel back onto it
MIST_path = '/home/maellef/DataBase/fMRI_parcellations/MIST_parcellation/Parcellations/MIST_ROI.nii.gz'

In [2]:
def surface_fig(parcel_data, vmax, threshold=0, cmap='turbo', inflate=True, colorbar=True,
                no_background=True, symmetric_cbar=True, atlas_path=None):
    """Project one value per atlas parcel onto the cortical surface and plot it.

    The parcel values are converted to a volume with
    ``nilearn.regions.signals_to_img_labels`` and rendered with
    ``nilearn.plotting.plot_img_on_surf`` (lateral and medial views, both hemispheres).

    Parameters
    ----------
    parcel_data : array-like
        Values to display; presumably one per atlas label, in label order
        (e.g. one R² per MIST_ROI parcel) -- TODO confirm against the atlas labels.
    vmax : float
        Upper bound of the color scale (forwarded to plot_img_on_surf).
    threshold : float
        Forwarded to plot_img_on_surf.
    cmap : str or matplotlib colormap
        Forwarded to plot_img_on_surf.
    inflate, colorbar, symmetric_cbar : bool
        Forwarded to plot_img_on_surf.
    no_background : bool
        Not used by this function; kept so existing calls keep working.
    atlas_path : str or None
        Label image used to map parcel values to voxels. ``None`` (default) uses the
        module-level ``MIST_path``, which is what this function always used before.

    Returns
    -------
    The figure returned by ``plot_img_on_surf``.
    """
    # resolve the default at call time so a later change of MIST_path is still honoured
    if atlas_path is None:
        atlas_path = MIST_path
    nii_data = regions.signals_to_img_labels(parcel_data, atlas_path)
    fig, _ = plotting.plot_img_on_surf(nii_data,
                                       views=['lateral', 'medial'], hemispheres=['left', 'right'],
                                       inflate=inflate, vmax=vmax, threshold=threshold,
                                       colorbar=colorbar, cmap=cmap,
                                       symmetric_cbar=symmetric_cbar, cbar_tick_format="%.1f")
    return fig

In [3]:
def extend_colormap(original_colormap='twilight', percent_start=0.25, percent_finish=0.25):
    """Pad a registered matplotlib colormap with constant color bands at both ends.

    The first color of the original map is repeated at the low end and its last color
    at the high end, so the original gradient only spans the middle of the [0, 1] range.

    Parameters
    ----------
    original_colormap : str
        Name of a colormap registered in ``matplotlib.colormaps``.
    percent_start : float
        Fraction, in [0, 1), of the returned colormap occupied by the low-end band.
    percent_finish : float
        Fraction, in [0, 1), of the returned colormap occupied by the high-end band.
        ``percent_start + percent_finish`` must be < 1.

    Returns
    -------
    matplotlib.colors.ListedColormap

    Raises
    ------
    ValueError
        If the percentages are out of range.
    """
    if not (0 <= percent_start < 1 and 0 <= percent_finish < 1
            and percent_start + percent_finish < 1):
        raise ValueError('percent_start and percent_finish must be in [0, 1) and sum to less than 1')

    colormap = colormaps[original_colormap]
    nb_colors = colormap.N
    original_colors = colormap(np.linspace(0, 1, nb_colors))

    # Size the result so the original colors fill exactly 1 - (start + finish) of it;
    # each band is then the requested fraction of the final colormap. With a single
    # band this gives the same count as round(N / (1 - p)) - N.
    total_colors = nb_colors / (1 - percent_start - percent_finish)
    n_start = round(percent_start * total_colors)
    n_finish = round(percent_finish * total_colors)

    # low end repeats the first color, high end repeats the last one
    start_band = np.tile(original_colors[0], (n_start, 1))
    finish_band = np.tile(original_colors[-1], (n_finish, 1))

    new_colors_range = np.concatenate((start_band, original_colors, finish_band), axis=0)
    return colors.ListedColormap(new_colors_range)

In [4]:
def load_sub_models(sub, scale, conv, models_path, no_init=False):
    """Build every saved encoding model in `models_path` matching subject, scale and layer.

    A file is selected when its name ends with '.pt' or '.pth' and contains `conv`, `sub`
    and `scale` as substrings. Files are visited in sorted order so the result is
    deterministic (os.listdir order is arbitrary).

    Parameters
    ----------
    sub, scale, conv : str
        Substrings that must all appear in the checkpoint filename.
    models_path : str
        Directory containing the checkpoint files.
    no_init : bool
        Forwarded to ``SoundNetEncoding_conv``; when True the saved weights
        (``modeldict['checkpoint']``) are not loaded into the network.

    Returns
    -------
    dict
        filename -> ``encod.SoundNetEncoding_conv`` instance.
    """
    models = {}
    for model in sorted(os.listdir(models_path)):
        # suffix test instead of `'.pt' in name`, which also matched e.g. 'x.pt.bak'
        if not model.endswith(('.pt', '.pth')):
            continue
        if not (conv in model and sub in model and scale in model):
            continue
        model_path = os.path.join(models_path, model)
        # NOTE: torch.load unpickles the file -- only use on trusted checkpoints
        modeldict = torch.load(model_path, map_location=torch.device('cpu'))
        model_net = encod.SoundNetEncoding_conv(out_size=modeldict['out_size'],
                                                output_layer=modeldict['output_layer'],
                                                kernel_size=modeldict['kernel_size'],
                                                no_init=no_init)
        if not no_init:
            model_net.load_state_dict(modeldict['checkpoint'])
        models[model] = model_net
    return models

In [5]:
class movie10_dataset(Dataset):
    """Audio/BOLD pairs cut into windows of a fixed number of TRs.

    Each run in `data` is a (wav, bold) pair: `wav` is indexed by audio sample
    (``wav[start:stop]``) and `bold` by TR along its first axis (``bold[start:stop, :]``).
    Each run is split into consecutive windows of `temporal_window` TRs; one dataset item
    is the (wav_chunk, bold_chunk) pair of one window. The last window of a run may be
    shorter. Windows that would contain no full TR are dropped.

    Parameters
    ----------
    data : iterable of (wav, bold)
    temporal_window : int
        Window length, in TRs.
    tr : float
        Repetition time, in the time unit of 1 / sr (presumably seconds).
    sr : int
        Audio sampling rate; one TR spans ``round(sr * tr)`` audio samples.
    """

    def __init__(self, data, temporal_window, tr, sr):
        self.temporal_window = temporal_window
        self.tr = tr
        self.sr = sr

        data_by_temporal_window = []
        for (run_wav, run_bold) in data:
            data_by_temporal_window.extend(self.__create_temporal_segments__(run_wav, run_bold))

        self.x = [wav for (wav, _) in data_by_temporal_window]
        self.y = [bold for (_, bold) in data_by_temporal_window]

    def __create_temporal_segments__(self, wav, bold):
        """Split one run into aligned (wav_chunk, bold_chunk) windows.

        A window is kept at full length only when BOTH the audio and the BOLD series
        have `temporal_window` TRs left; otherwise it is shortened to the number of TRs
        available in both. Windows shortened to 0 TRs are skipped so the model never
        receives an empty input.
        """
        # use the same integer samples-per-TR everywhere so wav and bold stay aligned
        samples_per_tr = round(self.sr * self.tr)
        chunk_length = samples_per_tr * self.temporal_window

        wav_starts = range(0, len(wav), chunk_length)
        bold_starts = range(0, len(bold), self.temporal_window)
        wavbold_by_temporalwindow = []
        for wav_start, bold_start in zip(wav_starts, bold_starts):
            wav_is_full = wav_start + chunk_length <= len(wav)
            bold_is_full = bold_start + self.temporal_window <= len(bold)
            if wav_is_full and bold_is_full:
                n_tr = self.temporal_window
            else:
                # tail of the run: keep only the TRs present in both signals
                wav_tr = round((len(wav) - wav_start) / samples_per_tr)
                bold_tr = len(bold) - bold_start
                n_tr = min(wav_tr, bold_tr)
                if n_tr <= 0:
                    continue

            wav_chunk = wav[wav_start:wav_start + n_tr * samples_per_tr]
            bold_chunk = bold[bold_start:bold_start + n_tr, :]
            wavbold_by_temporalwindow.append((wav_chunk, bold_chunk))

        return wavbold_by_temporalwindow

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [6]:
def test(dataloader, net, epoch, mseloss, return_nodes, gpu=True):
    net.eval()
    out_p = {layer_name:[] for layer_name in return_nodes.values()}
    
    with torch.no_grad():
        for (x,y) in dataloader:
            #print(x.shape, y.shape)
            # load data
            x = torch.Tensor(x).view(1,1,-1,1)
            # Forward pass
            y_p = net(x, epoch)
            
            for key, p in y_p.items():
                #print(val.shape)
                p = p.permute(2,1,0,3).squeeze()
                #print(p.shape, y.shape)
                out_p[key].append((p.numpy(), y.numpy()))
    return out_p

In [8]:
#necessary args
dataset = 'movie10'
sub = 'sub-06'  # other subjects: 'sub-02', 'sub-03', 'sub-04', 'sub-05'
no_init = False
conv = 'conv4'  # other model spec previously tried: 'opt110_wb'
scale = 'MIST_ROI'  # or 'auditory_Voxels'
shape = 210  # presumably the number of MIST_ROI parcels; not used below

models_path = '/home/maellef/Results/best_models/converted'
training_data_path = '/home/maellef/git/MuteMusic_analysis/data/training_data'

#load and extract model from dict
models = load_sub_models(sub, scale, conv, models_path, no_init=no_init)

# The rest of the notebook analyses a single model. Picking one of several matches by
# directory-listing order would be arbitrary, and with no match temporal_size would be
# undefined, so fail explicitly in both cases.
if len(models) != 1:
    raise ValueError(f'expected exactly one model for {sub}/{scale}/{conv} in {models_path}, '
                     f'found {len(models)}: {sorted(models)}')
name, model = next(iter(models.items()))

# The temporal window (number of TRs) is read from the 3 characters after 'conv_' in the filename
i = name.find('conv_')
if i == -1:
    raise ValueError(f"cannot find 'conv_' in model filename {name!r}")
i += len('conv_')
temporal_size = int(name[i:i+3])

print('nb of tr in temporal window: ', temporal_size)
#load data + metadata
metadata_path = os.path.join(training_data_path, f'{dataset}_{sub}_metadata.tsv')
pairbold_path = os.path.join(training_data_path, f'{dataset}_{sub}_pairWavBold')

data_df = pandas.read_csv(metadata_path, sep='\t')
with open(pairbold_path, 'rb') as f:
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files
    wavbold = pickle.load(f)

print(data_df)

shape of encoding matrice from last encoding layer : 1024 X 210
nb of tr in temporal window:  70
    Unnamed: 0 category       task  repetition
0            0     wolf     wolf09         NaN
1            1   bourne   bourne09         NaN
2            2  figures  figures03         1.0
3            3     wolf     wolf04         NaN
4            4     wolf     wolf03         NaN
..         ...      ...        ...         ...
56          56     life     life04         1.0
57          57     wolf     wolf05         NaN
58          58   bourne   bourne04         NaN
59          59  figures  figures08         1.0
60          60     wolf     wolf11         NaN

[61 rows x 4 columns]


In [9]:
#select data depending on metadata infos
category = 'wolf'
repetition = 'all'

# 'all' means "every value present in that column of the metadata"
cat = [category] if category != 'all' else data_df['category'].unique()
rep = [repetition] if repetition != 'all' else data_df['repetition'].unique()

is_selected = data_df['category'].isin(cat) & data_df['repetition'].isin(rep)
i_selection = data_df.loc[is_selected].index.values
print(i_selection)
print(len(wavbold))

# metadata row labels are used as positions in the list of (wav, bold) runs
selected_positions = set(i_selection.tolist())
selected_wavbold = [(wav, bold) for position, (wav, bold) in enumerate(wavbold)
                    if position in selected_positions]
print(len(selected_wavbold))

[ 0  3  4 12 18 21 23 25 29 44 47 48 51 52 53 57 60]
61
17


In [11]:
def select_df_index(df, **selectors):
    '''Return the row index labels of `df` matching every selector.

    Each keyword argument is ``column_name=value``:
      - `value` is a single value or a collection (list, tuple, set, numpy array,
        pandas Series/Index) of accepted values;
      - if `value` is 'all' (or a collection whose first element is 'all'), that column
        does not restrict the selection (rows with missing values are kept too).
    Selectors are combined with a logical AND; with no selector every row is returned.

    Raises
    ------
    KeyError
        If a selector names a column that is not in `df`.
    '''
    conditions = pandas.Series(True, index=df.index)
    for column_name, val in selectors.items():
        column = df[column_name]  # raise KeyError on a misspelled column, even for 'all'
        if isinstance(val, (list, tuple, set, frozenset, np.ndarray, pandas.Series, pandas.Index)):
            val = list(val)
        else:
            val = [val]
        # isinstance check avoids ambiguous comparisons when the first value is an array
        if val and isinstance(val[0], str) and val[0] == 'all':
            continue
        conditions &= column.isin(val)
    i_selection = df.loc[conditions].index.values
    return i_selection

# sanity check: same category/repetition as the manual selection cell above, so this is
# expected to print the same indexes as i_selection
i = select_df_index(data_df, category='wolf', repetition='all')

print(i)

[ 0  3  4 12 18 21 23 25 29 44 47 48 51 52 53 57 60]


In [None]:
#extract X and Y data for prediction + check for empty data (WIP: move to previous later)
# A run is unusable if EITHER its audio or its BOLD series is empty. The same rule is used
# both to report and to drop runs, so every dropped run is listed in empty_pair.
empty_pair = [i for i, (wav, bold) in enumerate(selected_wavbold)
              if wav.shape[0] == 0 or bold.shape[0] == 0]

empty_positions = set(empty_pair)
correct_wavbold = [(wav, bold) for i, (wav, bold) in enumerate(selected_wavbold)
                   if i not in empty_positions]
print(empty_pair)

In [None]:
#define all possible output (WIP) + create model with extractable embeddings
# graph node name -> key under which its activation is returned by the extractor
return_nodes = {'soundnet.conv7.2':'conv7', 'encoding_fmri':'encoding_conv'}

# list the node names of the traced graph (train and eval modes) to help choose return_nodes
train_nodes, eval_nodes = feature_extraction.get_graph_node_names(model)
print(eval_nodes)

model_feat = feature_extraction.create_feature_extractor(model, return_nodes=return_nodes)

In [None]:
#create dataset for dataloader
# one item = one window of `temporal_size` TRs; tr is the repetition time (presumably in
# seconds) and sr the audio sampling rate, so one TR spans round(sr * tr) samples
model_dataset = movie10_dataset(
    correct_wavbold,
    temporal_window=temporal_size,
    tr=1.49,
    sr=22050,
)
# DataLoader defaults spelled out: one window per batch, kept in run order
testloader = DataLoader(model_dataset, batch_size=1, shuffle=False)

In [None]:
#extract embedding from selected model
# test() accepts mseloss and gpu but its body does not use them
loss_fn = nn.MSELoss(reduction='sum')
out_p = test(testloader, net=model_feat, epoch=1, mseloss=loss_fn,
             return_nodes=return_nodes, gpu=False)

In [None]:
encoding_pairs = out_p['encoding_conv']
print(encoding_pairs[0][0].shape, encoding_pairs[0][1].shape)

# Align each window's prediction with its BOLD target on their common number of TRs.
# y carries the DataLoader batch axis of size 1; y[0] removes it without also
# collapsing a 1-TR window the way .squeeze() would.
predicted_chunks, real_chunks = [], []
for pred_y, y in encoding_pairs:
    y = y[0]
    n_tr = min(pred_y.shape[0], y.shape[0])
    predicted_chunks.append(pred_y[:n_tr, :])
    real_chunks.append(y[:n_tr, :])

predicted_y = np.vstack(predicted_chunks)
real_y = np.vstack(real_chunks)
assert predicted_y.shape == real_y.shape, (predicted_y.shape, real_y.shape)

print(predicted_y.shape, real_y.shape)

In [None]:
# one R² per output column (parcel)
r2 = r2_score(real_y, predicted_y, multioutput='raw_values')
# negative scores are set to 0 for display (NaN stays NaN, as with np.where)
r2 = np.clip(r2, 0, None)
print(r2.max())
# NOTE(review): this extended colormap is built but not passed to surface_fig below
colormap = extend_colormap(original_colormap='turbo',
                          percent_start=0.1, percent_finish=0)
fig = surface_fig(r2, vmax=0.25, threshold=0.00005, cmap='turbo', symmetric_cbar=False)

savepath = f'./figures/{sub}_generalisation_{dataset}_{category}_{repetition}.png'
os.makedirs(os.path.dirname(savepath), exist_ok=True)
# save the figure surface_fig returned rather than relying on pyplot's "current figure"
fig.savefig(savepath)

In [None]:
MISTinfo_path = '/home/maellef/DataBase/fMRI_parcellations/MIST_parcellation/Parcel_Information/MIST_ROI.csv'
MIST_df = pandas.read_csv(MISTinfo_path, sep=';')

# `r2` has been clipped at 0 for plotting, so searching it for values < -2 can never match.
# Recompute the unclipped per-parcel scores for this inspection.
r2_raw = r2_score(real_y, predicted_y, multioutput='raw_values')

# parcels with strongly negative R² (named min_r2 to avoid shadowing the builtin `min`)
imin = np.flatnonzero(r2_raw < -2).tolist()
min_r2 = r2_raw[imin].tolist()
print(imin, min_r2)
# assumes row k of MIST_df describes parcel k of r2 (csv sorted by label) -- TODO confirm
print(MIST_df.iloc[imin]['name'])

# parcels inspected individually
for parcel in (48, 49, 208, 209):
    print(MIST_df.iloc[parcel]['name'])
    print(r2_raw[parcel])