In [1]:
import pandas, pickle, os
import numpy as np
from matplotlib import pyplot as plt

#ridge regression
from sklearn.metrics import r2_score
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import GroupKFold

from torch import Tensor
from torchvision.models import feature_extraction
from torch.utils.data import DataLoader

import encoding_utils as eu
import models_class as mc
import visualisation_utils as visu
import seaborn as sns

In [2]:
# ----------- run configuration (everything a reader might tune) -----------
dataset = 'mutemusic'
sub = 'sub-03'
no_init = False
tr = 1.49  # divides timestamps to obtain BOLD indices, and is passed to the dataset as `tr`

# absolute, machine-specific paths -- adapt before running on another machine
model_path = '/home/maellef/Results/best_models/converted'
training_data_path = '/home/maellef/git/MuteMusic_analysis/data/training_data'

# soundnet / audio settings
model_type = 'soundnet'  # other value noted by the author: 'conv4'
resolution = 'MIST_ROI'  # other value noted by the author: 'auditory_Voxels'
sr = 22050  # sampling rate handed to the soundnet dataset

# settings specific to this run
stim_tracks = 'silenced'
category = 'all'
repetition = 'all'
original_sr = 48000  # sampling rate used to cut the raw wav tracks, resampled if it differs from `sr`

# visualisation
r2_max_threshold = 1

In [None]:
#specific to mutemusic---------separate silence bold from music bold
def pair_silence_with_music(track_data, metadata_df, sr=22050, bold=False, tr_s=None):
    """Cut every track into (music, silence) pairs, one pair per annotated silence.

    Tracks and metadata rows are matched positionally with ``zip``: both must be
    in the same order, and surplus items on either side are ignored.

    For each metadata row the columns ``S1``..``S4`` ``_duration``/``_start``/``_stop``
    are read. Silences with a missing value are dropped and the remaining ones are
    processed by increasing start. Each pair holds the slice running from the end
    of the previous silence (index 0 for the first one) up to the silence start,
    followed by the silence slice itself. Whatever comes after the last silence
    of a track is not returned.

    Parameters
    ----------
    track_data : iterable
        Sliceable sequences (lists, arrays, ...); slicing is done on the first axis.
    metadata_df : pandas.DataFrame
        One row per track, holding the silence timestamps (presumably in
        seconds -- TODO confirm against the metadata file).
    sr : number, default 22050
        Factor multiplying the timestamps when ``bold`` is False.
    bold : bool, default False
        When True, timestamps are divided by the repetition time instead of
        being multiplied by ``sr``.
    tr_s : number, optional
        Repetition time used when ``bold`` is True. When omitted, the
        notebook-level ``tr`` is used, as before.

    Returns
    -------
    list of tuple
        ``(music_segment, silence_segment)`` pairs, for all tracks in order.
    """
    if bold:
        # Only touch the notebook-level `tr` when the caller gave no value, so the
        # function can be used (and tested) without that hidden global.
        repetition_time = tr if tr_s is None else tr_s

    def to_index(timestamp):
        # Convert a timestamp into an integer index along the track's first axis.
        scaled = timestamp / repetition_time if bold else timestamp * sr
        return round(scaled)

    fields = ('duration', 'start', 'stop')
    paired_segments = []
    for track, (_, metadata) in zip(track_data, metadata_df.iterrows()):
        timestamps_s = {field: [metadata[f'S{k}_{field}'] for k in range(1, 5)]
                        for field in fields}
        siltt_df = pandas.DataFrame(timestamps_s).sort_values(by='start').dropna()

        music_start = 0
        for _, silence_tt in siltt_df.iterrows():
            silence_start = to_index(silence_tt['start'])
            silence_stop = to_index(silence_tt['stop'])
            # music runs from the end of the previous silence to the start of this one
            paired_segments.append((track[music_start:silence_start],
                                    track[silence_start:silence_stop]))
            music_start = silence_stop

    return paired_segments

In [None]:
#------------load training data: metadata table + pickled wav/bold pairs---------------
# Both files live in `training_data_path` and are named after dataset, subject and stimuli.
metadata_filename = f'{dataset}_{sub}_{stim_tracks}_metadata.tsv'
pairbold_filename = f'{dataset}_{sub}_{stim_tracks}_pairWavBold'
metadata_path = os.path.join(training_data_path, metadata_filename)
pairbold_path = os.path.join(training_data_path, pairbold_filename)

# tab-separated metadata, one table for the whole subject
data_df = pandas.read_csv(metadata_path, sep='\t')

# NOTE(review): unpickling can execute arbitrary code -- only load files you trust.
with open(pairbold_path, 'rb') as f:
    wavbold = pickle.load(f)

In [None]:
#----------load model + convert to extract embedding-------------------------------------
#load model (specific to soundnet model)
print(sub, resolution, model_type, category)
# Use the `no_init` flag from the configuration cell rather than a hard-coded False,
# so changing the configuration actually changes which model is loaded.
model_name, model = eu.load_sub_model(sub, resolution, model_type, model_path, no_init=no_init)
print(model_name)

# The three characters following 'conv_' in the model name are parsed as the temporal size.
size_marker = 'conv_'
marker_pos = model_name.find(size_marker)
if marker_pos == -1:
    # str.find returns -1 when the marker is absent; without this check the slice
    # below would silently parse an unrelated part of the name.
    raise ValueError(f"cannot parse temporal size: {size_marker!r} not found in model name {model_name!r}")
i = marker_pos + len(size_marker)
temporal_size = int(model_name[i:i+3])

#create model with extractable embeddings
# maps graph node names inside the model to the keys under which their outputs are returned
return_nodes = {'soundnet.conv7.2':'conv7', 'encoding_fmri':'encoding_conv'}
model_feat = feature_extraction.create_feature_extractor(model, return_nodes=return_nodes)

In [None]:
#create train/test for ridge regression
# Each split is described by the keyword arguments handed to eu.extract_selected_data.
split_selection = {
    'train': {'repetition': [1, 2]},
    'test': {'repetition': [3]},
    'test_F': {'repetition': [3], 'groupe': 'F'},
    'test_U': {'repetition': [3], 'groupe': 'U'},
}

# split name -> (selected wav/bold pairs, matching metadata rows)
datasets_dict = {}
for split_name, selection_kwargs in split_selection.items():
    selected_data, selected_df = eu.extract_selected_data(wavbold, data_df, **selection_kwargs)
    datasets_dict[split_name] = (selected_data, selected_df)

# keep the individual notebook-level names available as well
train_data, train_data_df = datasets_dict['train']
test_data, test_data_df = datasets_dict['test']
f_test_data, f_test_data_df = datasets_dict['test_F']
u_test_data, u_test_data_df = datasets_dict['test_U']

In [None]:
#create dataset
# The loop variables are deliberately NOT named `data` / `data_df`: at notebook level
# that would overwrite the full metadata table `data_df` with the table of the last
# split, so re-running the train/test split cell afterwards would use the wrong table.
embedding_dict = {}
for dataset_type, (split_data, split_df) in datasets_dict.items():
    # cut wav and bold of every track into (music, silence) pairs
    wav_data = [wav_track for (wav_track, _) in split_data]
    wav_mussil = pair_silence_with_music(wav_data, split_df,
                                         sr=original_sr, bold=False)
    bold_data = [bold_track for (_, bold_track) in split_data]
    bold_mussil = pair_silence_with_music(bold_data, split_df,
                                          sr=original_sr, bold=True)
    # number of bold samples in each silence, kept to separate music from silence later
    sil_tr_len = [bold_s.shape[0] for _, bold_s in bold_mussil]

    # one sample = music segment immediately followed by its silence
    wav_paired = [np.concatenate([wav_m, wav_s]) for wav_m, wav_s in wav_mussil]
    bold_paired = [np.concatenate([bold_m, bold_s]) for bold_m, bold_s in bold_mussil]
    wavbold_data = list(zip(wav_paired, bold_paired))

    #create embedding through pretrained network
    encoding_dataset = mc.soundnet_dataset(wavbold_data, tr=tr, sr=sr)
    if original_sr != encoding_dataset.sr:
        encoding_dataset.resample_input(input_sr=original_sr)
    encoding_dataset.convert_input_to_tensor()

    testloader = DataLoader(encoding_dataset)
    out_p = eu.test(testloader, net=model_feat, return_nodes=return_nodes, gpu=False)

    # NOTE(review): `redimension_output(..., cut='end')` presumably aligns the lengths of
    # embedding and bold by trimming the end -- confirm it never trims into the silence
    # samples counted in `sil_tr_len`.
    Y_pred_converted, Y_real_converted = [], []
    for y_p, y_r in out_p['conv7']:
        (y_p_converted, y_r_converted) = encoding_dataset.redimension_output(y_p, y_r, cut='end')
        Y_pred_converted.append(y_p_converted)
        Y_real_converted.append(y_r_converted)
    print(len(Y_pred_converted), Y_pred_converted[0].shape,
          len(Y_real_converted), Y_real_converted[0].shape)

    embedding_dict[dataset_type] = (Y_pred_converted, Y_real_converted, sil_tr_len)


In [None]:
# sanity check: one entry per split built in the previous cell
available_splits = embedding_dict.keys()
print(available_splits)

In [None]:
# Training matrices: stack the per-segment arrays of the 'train' split row-wise.
train_embedding_list, train_bold_list, _ = embedding_dict['train']
embedding = np.vstack(train_embedding_list)
bold = np.vstack(train_bold_list)

print(embedding.shape, bold.shape)

In [None]:
# Ridge regression from embedding to bold; the penalty is picked by 10-fold cross-validation
# among 10 candidates log-spaced between 10**0.1 and 10**3.
alphas = np.logspace(0.1, 3, num=10)
ridge_options = {'alphas': alphas, 'fit_intercept': True, 'cv': 10}
model = RidgeCV(**ridge_options)

model.fit(embedding, bold)

In [None]:
# 'test' split: full segments, plus their music part and trailing silence part separately.
test_x, test_y, len_sil_tr = embedding_dict['test']
print(len(test_x), len(test_y), len(len_sil_tr))

test_fullx = np.vstack(test_x)
test_fully = np.vstack(test_y)

# The last `n_sil` rows of each segment are the silence. The cut index is computed
# explicitly because negative slicing breaks for n_sil == 0 (`x[:-0]` is empty and
# `x[-0:]` is the whole segment). It is clamped at 0 so that a silence longer than
# the segment still yields an empty music part, as before.
test_musx = np.vstack([x[:max(len(x) - n_sil, 0)] for x, n_sil in zip(test_x, len_sil_tr)])
test_silx = np.vstack([x[max(len(x) - n_sil, 0):] for x, n_sil in zip(test_x, len_sil_tr)])

test_musy = np.vstack([y[:max(len(y) - n_sil, 0)] for y, n_sil in zip(test_y, len_sil_tr)])
test_sily = np.vstack([y[max(len(y) - n_sil, 0):] for y, n_sil in zip(test_y, len_sil_tr)])

print(test_fullx.shape, test_fully.shape,
      test_musx.shape, test_musy.shape,
      test_silx.shape, test_sily.shape)

In [None]:
# 'test_F' split: full segments, plus their music part and trailing silence part separately.
testf_x, testf_y, lenf_sil_tr = embedding_dict['test_F']

testf_fullx = np.vstack(testf_x)
testf_fully = np.vstack(testf_y)

# The last `n_sil` rows of each segment are the silence. The cut index is computed
# explicitly because negative slicing breaks for n_sil == 0 (`x[:-0]` is empty and
# `x[-0:]` is the whole segment). It is clamped at 0 so that a silence longer than
# the segment still yields an empty music part, as before.
testf_musx = np.vstack([x[:max(len(x) - n_sil, 0)] for x, n_sil in zip(testf_x, lenf_sil_tr)])
testf_silx = np.vstack([x[max(len(x) - n_sil, 0):] for x, n_sil in zip(testf_x, lenf_sil_tr)])

testf_musy = np.vstack([y[:max(len(y) - n_sil, 0)] for y, n_sil in zip(testf_y, lenf_sil_tr)])
testf_sily = np.vstack([y[max(len(y) - n_sil, 0):] for y, n_sil in zip(testf_y, lenf_sil_tr)])

print(testf_fullx.shape, testf_fully.shape,
      testf_musx.shape, testf_musy.shape,
      testf_silx.shape, testf_sily.shape)

In [None]:
# 'test_U' split: full segments, plus their music part and trailing silence part separately.
testu_x, testu_y, lenu_sil_tr = embedding_dict['test_U']

testu_fullx = np.vstack(testu_x)
testu_fully = np.vstack(testu_y)

# The last `n_sil` rows of each segment are the silence. The cut index is computed
# explicitly because negative slicing breaks for n_sil == 0 (`x[:-0]` is empty and
# `x[-0:]` is the whole segment). It is clamped at 0 so that a silence longer than
# the segment still yields an empty music part, as before.
testu_musx = np.vstack([x[:max(len(x) - n_sil, 0)] for x, n_sil in zip(testu_x, lenu_sil_tr)])
testu_silx = np.vstack([x[max(len(x) - n_sil, 0):] for x, n_sil in zip(testu_x, lenu_sil_tr)])

testu_musy = np.vstack([y[:max(len(y) - n_sil, 0)] for y, n_sil in zip(testu_y, lenu_sil_tr)])
testu_sily = np.vstack([y[max(len(y) - n_sil, 0):] for y, n_sil in zip(testu_y, lenu_sil_tr)])

print(testu_fullx.shape, testu_fully.shape,
      testu_musx.shape, testu_musy.shape,
      testu_silx.shape, testu_sily.shape)

In [None]:
# Predict bold for the full test segments (music + silence) and score each output column.
y_p = model.predict(test_fullx)
r2 = r2_score(test_fully, y_p, multioutput='raw_values')
# negative scores are set to 0 before display
r2 = np.where(r2<0, 0, r2)
print(max(r2))
# NOTE(review): `colormap` is built but never used -- the figure below is drawn with
# cmap='turbo'. Either pass `cmap=colormap` or drop this call; left as is because the
# intended rendering cannot be determined from this notebook.
colormap = visu.extend_colormap(original_colormap='turbo',
                          percent_start = 0.1, percent_finish=0)
visu.surface_fig(r2, vmax=0.4, threshold=0.005, cmap='turbo', symmetric_cbar=False)

savepath = f'./figures/{sub}_{dataset}_{model_type}_paired_mussil_predict_mussil_HRF2.png'
# plt.savefig does not create missing directories; make sure ./figures exists first
os.makedirs(os.path.dirname(savepath), exist_ok=True)
plt.savefig(savepath)