In [1]:
import os
import sys
import torch
import glob
import pandas as pd
import numpy as np
import nilearn.connectome
import matplotlib.pyplot as plt
from nilearn.maskers import NiftiMasker
from nilearn.input_data import NiftiLabelsMasker, NiftiMapsMasker
from csv import writer
from load_confounds import Params9

from dypac.masker import LabelsMasker, MapsMasker
from nilearn.plotting import plot_roi, plot_stat_map
from nilearn import datasets
from nilearn.interfaces.fmriprep import load_confounds_strategy

sys.path.append(os.path.join(".."))
import time_windows_dataset
import graph_construction
import gcn_model



In [None]:
print(f"torch v{torch.__version__}")

In [None]:
print(f"nilearn v{nilearn.__version__}")

# Initial parameters

In [None]:
# Acquisition / reproducibility settings
TR = 1.49          # repetition time in seconds (window durations are reported as window_length * TR)
random_seed = 0

# Which data set / parcellation to model
subject = 'sub-02'
region_approach = 'dypac'
resolution = 1024  # dypac resolution; also appears in data and connectome file names

# Windowing and labelling
window_length = 1               # number of time points (rows) per split window
modality = 'all_mod'            # alternative: 'motor'
HRFlag_processes = '3volumes'   # file-name suffix that follows the condition name

# Fetch data

In [None]:
# Paths (relative to this notebook's directory)
data_dir = os.path.join('..', '..', '..', 'data')

# Pre-concatenated time series for this subject / parcellation
concat_data_dir = os.path.join(data_dir, 'concat_data', region_approach, str(resolution), subject)
processed_bold_files = sorted(glob.glob(os.path.join(concat_data_dir, '*.npy')))

# Connectome file(s); empty if the connectome has not been generated yet
conn_dir = os.path.join(data_dir, 'connectomes')
conn_files = sorted(glob.glob(os.path.join(
    conn_dir, 'conn_friends_{}_{}{}.npy'.format(subject, region_approach, resolution))))

# Per-window .npy files and the labels.csv describing them
split_dir = os.path.join(data_dir, 'split_win_data')
out_csv = os.path.join(split_dir, 'labels.csv')
out_file = os.path.join(split_dir, '{}_{:04d}.npy')  # formatted with (source file stem, window index)

result_dir = os.path.join('..', 'results')
result_csv = os.path.join(result_dir, 'result_df.csv')
model_path = os.path.join('..', 'models', 'gcn_test.pt')

# Create every output directory up front (including the one model_path is saved into)
for directory in (split_dir, conn_dir, result_dir, os.path.dirname(model_path)):
    os.makedirs(directory, exist_ok=True)

# Start the results file with a header the first time only; later runs append rows to it
if not os.path.exists(result_csv):
    result_df = pd.DataFrame(columns=['subject', 'modality', 'window_length', 'region_approach',
                                      'average_loss', 'average_accuracy', 'time_window'])
    result_df.to_csv(result_csv, index=False)

# Remove windows/labels left over from a previous run (files only; subdirectories are left alone)
for stale_path in glob.glob(os.path.join(split_dir, '*')):
    if os.path.isfile(stale_path):
        os.remove(stale_path)

# Generating connectomes

In [None]:
# Generates the connectome (from a Friends run) used to build the GCN graph.
# This only needs to run once: the result is saved in conn_dir and later runs skip it.
conn_path = os.path.join(conn_dir, 'conn_friends_{}_{}{}.npy'.format(subject, region_approach, resolution))

if os.path.exists(conn_path):
    print('Connectome already saved, skipping generation:', conn_path)
else:
    # NOTE(review): absolute cluster paths; they only resolve where the neuromod data is mounted
    bold_suffix = '_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz'
    path_cneuromod = '/data/neuromod/projects/ml_models_tutorial/data/friends/raw_data'
    file_epi = os.path.join(path_cneuromod, '{}_ses-003_task-s01e06a{}'.format(subject, bold_suffix))

    # 'simple' denoising strategy with basic global-signal regression; element 0 of the
    # returned value is what gets passed to the masker below
    conf = load_confounds_strategy(file_epi, denoise_strategy='simple', global_signal='basic')

    path_dypac = '/data/cisl/pbellec/models'
    file_mask = os.path.join(path_dypac, '{}_space-MNI152NLin2009cAsym_label-GM_mask.nii.gz'.format(subject))
    file_dypac = os.path.join(path_dypac, '{}_space-MNI152NLin2009cAsym_desc-dypac{}_'
                              'components.nii.gz'.format(subject, resolution))

    masker = NiftiMasker(standardize=True, detrend=False, smoothing_fwhm=5, mask_img=file_mask)
    masker.fit(file_epi)

    maps_masker = MapsMasker(masker=masker, maps_img=file_dypac)
    sample_ts = maps_masker.transform(img=file_epi, confound=conf[0])
    print('sample_ts shape:', sample_ts.shape)

    # Correlation connectome of the single extracted time series
    corr_measure = nilearn.connectome.ConnectivityMeasure(kind="correlation")
    conn = corr_measure.fit_transform([sample_ts])[0]
    np.save(conn_path, conn)

# conn_files was globbed in the paths cell, possibly before this file existed; refresh it
conn_files = [conn_path]

# Split timeseries & generate label file

In [None]:
# Integer class label for each condition name parsed from the file names
dic_labels = {'body0b':0,'body2b':1,'face0b':2,'face2b':3,'fear':4,'footL':5,'footR':6,
              'handL':7,'handR':8,'match':9,'math':10,'mental':11,'place0b':12,'place2b':13,
              'random':14,'relational':15,'shape':16,'story':17,'tongue':18,'tool0b':19,'tool2b':20}

label_records = []  # one {'label', 'filename'} record per saved window

for proc_bold in processed_bold_files:
    ts_data = np.load(proc_bold)
    ts_duration = len(ts_data)

    # File stem without extension; splitext keeps any other dots that are part of the name
    ts_filename = os.path.splitext(os.path.basename(proc_bold))[0]

    # The condition name sits between '<subject>_' and '_<HRFlag_processes>' in the stem
    ts_label = ts_filename.split(subject + '_', 1)[1].split('_' + HRFlag_processes, 1)[0]
    valid_label = dic_labels[ts_label]

    # Number of whole windows; trailing rows that do not fill a window are dropped
    n_splits = ts_duration // window_length
    print(f'{ts_filename}: label={ts_label!r}, n_splits={n_splits}')
    if n_splits == 0:
        # np.split cannot split into 0 sections
        print(f'  skipped: only {ts_duration} rows, fewer than window_length={window_length}')
        continue
    ts_data = ts_data[:n_splits * window_length, :]

    for j, split_ts in enumerate(np.split(ts_data, n_splits)):
        ts_output_file_name = out_file.format(ts_filename, j)
        # Saved with axes 0 and 1 swapped relative to ts_data
        np.save(ts_output_file_name, np.swapaxes(split_ts, 0, 1))
        label_records.append({'label': valid_label,
                              'filename': os.path.basename(ts_output_file_name)})

# Build the frame once (DataFrame.append was removed in pandas 2.0 and was quadratic anyway)
label_df = pd.DataFrame(label_records, columns=['label', 'filename'])
label_df.to_csv(out_csv, index=False)

In [None]:
# PyTorch datasets: one TimeWindowsDataset per partition, all reading the split windows in split_dir
def _make_windows_dataset(partition):
    """Return a normalized, shuffled TimeWindowsDataset for ``partition`` ('train'/'valid'/'test')."""
    return time_windows_dataset.TimeWindowsDataset(
        data_dir=split_dir,
        partition=partition,
        random_seed=random_seed,
        pin_memory=True,
        normalize=True,
        shuffle=True,
    )


train_dataset = _make_windows_dataset("train")
valid_dataset = _make_windows_dataset("valid")
test_dataset = _make_windows_dataset("test")

for name, dataset in (("train", train_dataset), ("valid", valid_dataset), ("test", test_dataset)):
    print(f"{name} dataset: {dataset}")

In [None]:
# PyTorch DataLoaders: wrap each dataset in an iterable that shuffles it and yields minibatches.
# Seed torch first so the DataLoader shuffling is reproducible.
torch.manual_seed(random_seed)

batch_size = 10
train_generator, valid_generator, test_generator = (
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for dataset in (train_dataset, valid_dataset, test_dataset)
)

# Peek at one training batch as a quick sanity check
train_features, train_labels = next(iter(train_generator))
print(f"Feature batch shape: {train_features.size()}; mean {train_features.mean()}")
print(f"Labels batch shape: {train_labels.size()}; mean {train_labels.float().mean()}")

# Model definition

In [None]:
## Model definition: build the group graph from the connectome(s) (k nearest neighbours), then the GCN.
# Re-glob here: conn_files from the paths cell may predate the connectome cell saving its file.
conn_files = sorted(glob.glob(os.path.join(
    conn_dir, 'conn_friends_{}_{}{}.npy'.format(subject, region_approach, resolution))))
if not conn_files:
    raise FileNotFoundError(
        f"No connectome for {subject} ({region_approach}{resolution}) in {conn_dir}; "
        "run the 'Generating connectomes' cell first.")

for conn_file in conn_files:
    print(conn_file)
connectomes = [np.load(conn_file) for conn_file in conn_files]

graph = graph_construction.make_group_graph(connectomes, k=8, self_loops=False, symmetric=True)

## Create model
gcn = gcn_model.GCN(graph.edge_index, graph.edge_attr,
                    n_timepoints=window_length, resolution=resolution)
gcn

In [None]:
# Sanity check: report the shape of each loaded connectome rather than dumping the full matrices.
# Uses the already-loaded data so no path has to be hard-coded here.
for conn_file, connectome in zip(conn_files, connectomes):
    print(conn_file, np.shape(connectome))

# Training

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    The model is put in training mode and cast to float64, and every input batch is
    cast to float64 to match. Per-batch loss and accuracy are printed.

    Args:
        dataloader: iterable of ``(X, y)`` batches exposing ``.dataset`` (used for the total size).
        model: torch module whose output has class scores along axis 1.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    # Module.double() modifies the module in place, so the optimizer keeps tracking its parameters
    model = model.double()
    # Evaluation may have left the model in eval mode; dropout/batch-norm layers (if any) need train mode
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X.double())
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Progress counter: samples processed so far, including this batch
        seen += X.shape[0]
        correct = (pred.argmax(1) == y).type(torch.float).sum().item() / X.shape[0]
        print(f"#{batch:>5};\ttrain_loss: {loss.item():>0.3f};\ttrain_accuracy:{(100*correct):>5.1f}%\t\t[{seen:>5d}/{size:>5d}]")

def valid_test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    The model is switched to eval mode for the pass and its previous train/eval mode is
    restored afterwards. Inputs are cast to the dtype of the model's parameters, because
    ``train_loop`` casts the model to float64.

    Assumes ``loss_fn`` uses mean reduction (the torch default). Each batch loss is
    therefore weighted by its batch size to obtain an exact per-sample average.

    Returns:
        (average_loss, accuracy): loss averaged over all samples, and the fraction of
        samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the dataloader's dataset is empty.
    """
    size = len(dataloader.dataset)
    if size == 0:
        raise ValueError("valid_test_loop: dataloader has an empty dataset")

    first_param = next(model.parameters(), None)
    input_dtype = first_param.dtype if first_param is not None else None

    was_training = model.training
    model.eval()
    total_loss, correct = 0.0, 0.0
    try:
        with torch.no_grad():
            for X, y in dataloader:
                if input_dtype is not None:
                    X = X.to(input_dtype)
                pred = model(X)
                total_loss += loss_fn(pred, y).item() * X.shape[0]
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)

    return total_loss / size, correct / size

In [None]:
# Train for a fixed number of epochs, reporting validation metrics after each epoch
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gcn.parameters(), lr=1e-4, weight_decay=5e-4)

epochs = 5
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}\n-------------------------------")
    train_loop(train_generator, gcn, loss_fn, optimizer)
    loss, correct = valid_test_loop(valid_generator, gcn, loss_fn)
    print("Valid metrics:\n\t avg_loss: {:>8f};\t avg_accuracy: {:>0.1f}%".format(loss, 100 * correct))
print("Done!")

# Test

In [None]:
# Evaluate the trained model once on the held-out test split
loss, correct = valid_test_loop(test_generator, gcn, loss_fn)
print("Test metrics:\n\t avg_loss: {:>8f};\t avg_accuracy: {:>0.1f}%".format(loss, 100 * correct))

In [None]:
# Convert the test metrics into the values stored in the results CSV.
# NOTE(review): the loss is also multiplied by 100 although it is not a percentage;
# presumably this is so it shares the plot axis with accuracy. Kept for consistency with
# rows already stored in result_csv.
# The outer round() removes float artefacts such as 100 * 0.07 == 7.000000000000001.
# Caveat: this rebinds `correct` from a fraction to a percentage, so re-running the cell
# scales it a second time.
average_loss = round(100 * round(loss, 2), 2)
correct = round(100 * round(correct, 4), 2)

print(average_loss)
print(correct)

# Saving the results

In [None]:
def append_list_as_row(file_name, list_of_elem):
    """Append ``list_of_elem`` as one CSV row at the end of ``file_name``, creating the file if needed."""
    with open(file_name, 'a+', newline='') as csv_file:
        writer(csv_file).writerow(list_of_elem)
        
# Window duration in seconds: rows per window times the repetition time from the config cell
time_window = window_length * TR

# Same column order as the header written when result_csv was first created
row_contents = [subject, modality, window_length, region_approach,
                average_loss, correct, time_window]
append_list_as_row(result_csv, row_contents)

In [None]:
# Legacy version of the results-saving step; it appears superseded by append_list_as_row
# (which also records time_window) and is kept for reference only.
# NOTE(review): DataFrame.append no longer exists in pandas >= 2.0; use pd.concat if revived.
# results = {'subject': subject,'modality': modality,
#            'window_length': window_length,'region_approach': region_approach,
#            'average_loss': average_loss,'average_accuracy': correct}

# result_df = result_df.append(results, ignore_index=True)
# result_df.to_csv(result_csv, index=False)

# Sanity check: inspect the test batches

In [None]:
# Print shape, feature mean and labels of every test batch
for features, labels in test_generator:
    print('X:', features.shape)
    print('y:', labels.shape)
    print(features.mean())
    print(labels)

# Model saving

In [None]:
# Persist the trained weights; make sure the target directory exists first
os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
torch.save(gcn.state_dict(), model_path)

In [None]:
print(f"{result_csv}")

In [None]:
print(f"{np.__version__}")

# Visualizations

In [None]:
df = pd.read_csv(result_csv)  # comma is read_csv's default separator

In [None]:
# Accuracy and (x100-scaled) loss versus window duration, for the current subject/setting only:
# result_csv accumulates rows from every run, so unrelated rows are filtered out first.
plot_df = (
    df[(df['subject'] == subject)
       & (df['modality'] == modality)
       & (df['region_approach'] == region_approach)]
    .sort_values('time_window')
)

fig, ax = plt.subplots()
ax.plot('time_window', 'average_accuracy', data=plot_df, marker='',
        color='#4a996f', linewidth=2, label='Average accuracy')
ax.plot('time_window', 'average_loss', data=plot_df, marker='', color='#d1cd10',
        linewidth=2, linestyle='dashed', label='Average loss (x100)')
# plt.title's second positional argument is fontdict, so the text must be a single string
ax.set_title(f'{subject}: {modality}')
ax.set_xlabel("Window time(sec)")
ax.set_ylabel("Accuracy (%) / loss x100")

# show legend
ax.legend()

# show graph
plt.show()