# Import libraries


In [1]:
import sys, os, json
import mne, sklearn, wandb
import numpy as np
import pandas as pd

from scipy.interpolate import interp1d
from nilearn import datasets, image, masking, plotting
from nilearn.input_data import NiftiLabelsMasker


# animation part
from IPython.display import HTML
import matplotlib
import matplotlib.pyplot as plt
# from celluloid import Camera   # it is convinient method to animate
from matplotlib import animation, rc
from matplotlib.animation import FuncAnimation


## torch libraries 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader, Subset

from pytorch_model_summary import summary




In [2]:
%load_ext autoreload
%autoreload 2
# Put the parent of the current working directory on sys.path (at index 1, after the
# script dir) so the project-local `utils` package below can be imported.
# NOTE(review): assumes the notebook is launched from a subdirectory of the repo root — confirm.
sys.path.insert(1, os.path.realpath(os.path.pardir))

from utils import get_datasets
from utils import preproc
from utils import torch_dataset
from utils import train_utils
from utils import inference
from utils.models_arch import autoencoder_new, autoencoder_v3_separable

# Set all hyperparameters
- CUDA and GPU setup.
- Dataset parameters.
- Random seeds (for reproducibility).


In [3]:
import random

# Seed every RNG the experiment uses so runs are reproducible.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's legacy global RNG

# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(torch.cuda.is_available(), torch.cuda.device_count())
# Pin GPU 0 only when CUDA exists; an unconditional set_device() raises on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

True 4


In [4]:
# Experiment settings; this dict is also what gets logged to wandb later on.
config = {
    # --- data ---
    'dataset_name': 'CWL',  # 'CWL' or 'NODDI'
    'new_fps': 100,
    'freqs': np.logspace(np.log10(2), np.log10(99), 16),  # 16 log-spaced bands from 2 to 99
    'n_channels': 30,
    'n_roi': 6,
    # --- windowing / sampling ---
    'bold_delay': 5,
    'to_many': True,
    'random_subsample': True,
    'sample_per_epoch': 128,
    'WINDOW_SIZE': 512,
    # --- optimisation ---
    'optimizer': 'adam',
    'lr': 1e-3,
    'weight_decay': 0,
    'batch_size': 32,
    # --- bookkeeping ---
    'preproc_type': 'dB_log',
    'loss_function': 'mse',
    'model_type': 'CorrModel',
}

# Architecture hyper-parameters, derived from the data settings above.
# NOTE(review): channels / kernel_sizes / strides are not consumed by CrossCorrModel;
# they only end up in the logged config.
hp_autoencoder = {
    'n_electrodes': config['n_channels'],
    'n_freqs': len(config['freqs']),
    'n_channels_out': config['n_roi'],
    'channels': [128, 64, 64, 32],
    'kernel_sizes': [15, 11, 5],
    'strides': [8, 4, 2],
}

# Merge: architecture keys come first, experiment keys are appended after them.
config = dict(hp_autoencoder, **config)

# Keyword arguments for the train / validation DataLoaders.
params_train = dict(batch_size=config['batch_size'], shuffle=True, num_workers=0)
params_val = dict(batch_size=config['batch_size'], shuffle=False)

# Load the preprocessed dataset from `.npz` files
Loading precomputed features instead of re-running preprocessing speeds up experiments.

In [5]:
with open("../data/processed/labels_roi_6.json", 'r') as f:
    labels_roi = json.load(f)

# Preprocessed .npz archive for each supported dataset.
DATASET_PATHS = {
    'CWL': '../data/processed/CWL/trio1_100_hz_6_roi_2_99_freqs.npz',
    'NODDI': '../data/processed/NODDI/32_100_hz_6_roi_2_99_freqs.npz',
}
# Keep every VAL_WINDOW_STEP-th test window for validation (see the Subset below).
VAL_WINDOW_STEP = 100

if config['dataset_name'] not in DATASET_PATHS:
    # Fail here rather than with a confusing NameError on `dataset_path` further down.
    raise ValueError(f"no such dataset: {config['dataset_name']!r}; "
                     f"expected one of {sorted(DATASET_PATHS)}")
dataset_path = DATASET_PATHS[config['dataset_name']]

# Read the arrays; the `with` block closes the underlying .npz file handle afterwards.
with np.load(dataset_path) as data:
    train_dataset_prep = (data['x_train'], data['y_train'])
    test_dataset_prep = (data['x_test'], data['y_test'])

# Align EEG and BOLD in time using config['bold_delay'] at config['new_fps'];
# see preproc.bold_time_delay_align for the exact shift convention.
train_dataset_prep = preproc.bold_time_delay_align(train_dataset_prep,
                                                   config['new_fps'],
                                                   config['bold_delay'])
test_dataset_prep = preproc.bold_time_delay_align(test_dataset_prep,
                                                  config['new_fps'],
                                                  config['bold_delay'])

print('Size of train dataset:', train_dataset_prep[0].shape, train_dataset_prep[1].shape)
print('Size of test dataset:', test_dataset_prep[0].shape, test_dataset_prep[1].shape)

# Torch datasets: random windows for training, deterministic windows for testing.
torch_dataset_train = torch_dataset.CreateDataset_eeg_fmri(train_dataset_prep,
                                                           random_sample=config['random_subsample'],
                                                           sample_per_epoch=config['sample_per_epoch'],
                                                           to_many=config['to_many'],
                                                           window_size=config['WINDOW_SIZE'])

torch_dataset_test = torch_dataset.CreateDataset_eeg_fmri(test_dataset_prep,
                                                          random_sample=False,
                                                          sample_per_epoch=None,
                                                          to_many=config['to_many'],
                                                          window_size=config['WINDOW_SIZE'])

# The test dataset has no stride between windows (presumably one window per time step),
# so neighbouring windows are near-duplicates; subsample them to keep validation cheap.
torch_dataset_test = Subset(torch_dataset_test,
                            np.arange(len(torch_dataset_test))[::VAL_WINDOW_STEP])

# DataLoaders used for training.
train_loader = torch.utils.data.DataLoader(torch_dataset_train, **params_train)
val_loader = torch.utils.data.DataLoader(torch_dataset_test, **params_val)




Size of train dataset: (30, 16, 20690) (6, 20690)
Size of test dataset: (30, 16, 5500) (6, 5500)


## Model investigation 

In [6]:
   
class UpsampleConvBlock(nn.Module):
    """Linear upsampling along the last axis followed by Conv1d -> ReLU -> Dropout(0.1).

    Args:
        in_channels: channels of the incoming (batch, channels, length) tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size; 'same' padding keeps the upsampled length.
        scale: upsampling factor applied to the length axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, scale=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale, mode='linear')
        conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding='same')
        self.conv_block = nn.Sequential(conv, nn.ReLU(), nn.Dropout(0.1))

    def forward(self, x):
        return self.conv_block(self.upsample(x))

class CrossCorrModel(nn.Module):
    """Decode output time courses from inter-electrode correlations of an EEG window.

    For every frequency band, the Pearson correlation between all electrode pairs is
    computed over the time axis. Correlations with ``|r| <= corr_threshold`` are zeroed,
    and the strict upper triangle is flattened into one feature vector. That vector is
    projected by 1x1 convolutions, then decoded into a time series by three upsampling
    blocks.

    Args:
        n_electrodes: number of electrodes (dim 1 of the input).
        n_freqs: number of frequency bands (dim 2 of the input).
        n_channels_out: number of output channels (the notebook passes the ROI count).
        window_size: stored for reference only. NOTE: the decoded length is fixed by the
            architecture at 2 * 4 * 8 * 8 = 512 samples, whatever this value is.
        corr_threshold: correlations whose absolute value is not above this are zeroed.

    Input:  (batch, n_electrodes, n_freqs, time)
    Output: (batch, n_channels_out, 512)
    """

    def __init__(self, n_electrodes=30,
                 n_freqs=16,
                 n_channels_out=21,
                 window_size=512,
                 corr_threshold=0.3):
        super().__init__()

        self.n_electrodes = n_electrodes
        self.n_freqs = n_freqs
        self.window_size = window_size
        self.corr_threshold = corr_threshold

        # One feature per (frequency band, unordered electrode pair); integer arithmetic
        # is exact because n * (n - 1) is always even.
        inp_channels = n_freqs * n_electrodes * (n_electrodes - 1) // 2
        self.project = nn.Sequential(nn.Conv1d(inp_channels, 128, 1),
                                     nn.Dropout(p=0.2),
                                     nn.ReLU(),
                                     nn.Conv1d(128, 256, 1),
                                     nn.Dropout(p=0.2),
                                     nn.ReLU())

        # (batch, 128, 2) -> (batch, 128, 8) -> (batch, 128, 64) -> (batch, 64, 512)
        self.upsample = nn.Sequential(UpsampleConvBlock(128, 128, kernel_size=3, scale=4),
                                      UpsampleConvBlock(128, 128, kernel_size=7, scale=8),
                                      UpsampleConvBlock(128, 64, kernel_size=7, scale=8))

        # Honour n_channels_out (previously hard-coded to 6).
        self.last = nn.Conv1d(64, n_channels_out, 1)

    @staticmethod
    def _batched_corrcoef(x):
        """Pearson correlation between rows, for every matrix in a batch.

        Vectorised equivalent of ``torch.stack([torch.corrcoef(m) for m in x])``:
        (n, variables, observations) -> (n, variables, variables). As with
        ``torch.corrcoef``, rows with zero variance produce NaN entries.
        """
        centered = x - x.mean(dim=-1, keepdim=True)
        normed = centered / centered.norm(dim=-1, keepdim=True)
        corr = normed @ normed.transpose(-1, -2)
        return corr.clamp(-1, 1)

    def forward(self, x):
        batch, elec, n_freq, _ = x.shape

        # (batch, elec, freq, time) -> (batch * freq, elec, time): one matrix per band.
        x = x.transpose(1, 2).reshape(batch * n_freq, elec, -1)
        x_corrs = self._batched_corrcoef(x).reshape(batch, n_freq, elec, elec)

        # Drop weak correlations (NaNs also fail the comparison and become 0).
        x_corrs = torch.where(torch.abs(x_corrs) > self.corr_threshold,
                              x_corrs, torch.zeros_like(x_corrs))

        # Keep only the strict upper triangle (each pair once, no diagonal); build the
        # indices on the input's device to avoid a host->device transfer every step.
        triu_idxs = torch.triu_indices(elec, elec, offset=1, device=x.device)
        x_corrs_vec = x_corrs[..., triu_idxs[0], triu_idxs[1]]

        # (batch, n_freq, n_pairs) -> (batch, features, 1) for the 1x1 conv projection.
        x_corrs_vec = x_corrs_vec.reshape(batch, -1, 1)
        x_corrs_vec = torch.nan_to_num(x_corrs_vec)

        # (batch, 256, 1) -> (batch, 128, 2): fold projected features into a short sequence.
        x_proj = self.project(x_corrs_vec)
        x_proj = x_proj.reshape(batch, -1, 2)

        # Generator: upsample to full length, then map to output channels.
        x = self.upsample(x_proj)
        x = self.last(x)

        return x

# Initialize the model and inspect its layer summary

In [7]:
# Build an instance only to inspect layer shapes and parameter counts on a dummy batch of 4;
# the training cell creates its own fresh model. All sizes come from `config`.
model = CrossCorrModel(n_electrodes=config['n_electrodes'],
                       n_freqs=config['n_freqs'],
                       n_channels_out=config['n_channels_out'],
                       window_size=config['WINDOW_SIZE'])
dummy_batch = torch.zeros(4, config['n_electrodes'],
                          config['n_freqs'],
                          config['WINDOW_SIZE']).float()
print(summary(model, dummy_batch, show_input=True))


---------------------------------------------------------------------------
          Layer (type)         Input Shape         Param #     Tr. Param #
              Conv1d-1        [4, 6960, 1]         891,008         891,008
             Dropout-2         [4, 128, 1]               0               0
                ReLU-3         [4, 128, 1]               0               0
              Conv1d-4         [4, 128, 1]          33,024          33,024
             Dropout-5         [4, 256, 1]               0               0
                ReLU-6         [4, 256, 1]               0               0
   UpsampleConvBlock-7         [4, 128, 2]          49,280          49,280
   UpsampleConvBlock-8         [4, 128, 8]         114,816         114,816
   UpsampleConvBlock-9        [4, 128, 64]          57,408          57,408
             Conv1d-10        [4, 64, 512]             390             390
Total params: 1,145,926
Trainable params: 1,145,926
Non-trainable params: 0
-----------------------



# Model training

In [8]:
n_runs = 1
N_EPOCHS = 500
# Use the GPU pinned in the setup cell when present; otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

for i in range(n_runs):
    # Fresh model and optimizer per run so repeated runs are independent.
    model = CrossCorrModel(n_electrodes=config['n_electrodes'],
                           n_freqs=config['n_freqs'],
                           n_channels_out=config['n_channels_out'],
                           window_size=config['WINDOW_SIZE'])

    # NOTE(review): config['loss_function'] and config['optimizer'] are logged to wandb,
    # but MSE + Adam are hard-wired here — keep them in sync if either changes.
    loss_func = train_utils.make_mse_loss()
    train_step = train_utils.train_step

    optimizer = optim.Adam(model.parameters(),
                           lr=config['lr'],
                           weight_decay=config['weight_decay'])

    parameters = {
        'EPOCHS': N_EPOCHS,
        'model': model,
        'train_loader': train_loader,
        'val_loader': val_loader,
        'loss_function': loss_func,
        'train_step': train_step,
        'optimizer': optimizer,
        'device': DEVICE,
        'raw_test_data': test_dataset_prep,
        'show_info': 5,
        'num_losses': 5,
        'labels': labels_roi,
        'inference_function': inference.model_inference_function,
        'to_many': config['to_many']
    }

    # NOTE(review): not referenced anywhere in this notebook.
    path_to_save_wandb = 'common/koval_alvi/Checkpoints/wandb_brain'

    with wandb.init(project="eeg_fmri", config=config, save_code=True):

        # Report the best (max) validation correlation in the run summary.
        wandb.define_metric("val/corr_mean", summary="max")

        # All runs reuse the auto-generated name of the first run, suffixed by run index.
        if i == 0:
            exp_name = wandb.run.name

        wandb.run.name = f"{exp_name}_run_{i}"

        print(config)
        print(parameters['model'])
        print(summary(model, torch.zeros(4, config['n_channels'],
                                         len(config['freqs']), config['WINDOW_SIZE']), show_input=True))

        model = train_utils.wanb_train_regression(**parameters)
        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'n_electrodes': 30, 'n_freqs': 16, 'n_channels_out': 6, 'channels': [128, 64, 64, 32], 'kernel_sizes': [15, 11, 5], 'strides': [8, 4, 2], 'dataset_name': 'CWL', 'new_fps': 100, 'freqs': array([ 2.        ,  2.59420132,  3.36494024,  4.3646662 ,  5.6614114 ,
        7.34342046,  9.52515552, 12.3550855 , 16.02578954, 20.78706217,
       26.96291204, 34.97361097, 45.36429384, 58.84205542, 76.32406886,
       99.        ]), 'n_channels': 30, 'n_roi': 6, 'bold_delay': 5, 'to_many': True, 'random_subsample': True, 'sample_per_epoch': 128, 'WINDOW_SIZE': 512, 'optimizer': 'adam', 'lr': 0.001, 'weight_decay': 0, 'batch_size': 32, 'preproc_type': 'dB_log', 'loss_function': 'mse', 'model_type': 'CorrModel'}
CrossCorrModel(
  (project): Sequential(
    (0): Conv1d(6960, 128, kernel_size=(1,), stride=(1,))
    (1): Dropout(p=0.2, inplace=False)
    (2): ReLU()
    (3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
    (4): Dropout(p=0.2, inplace=False)
    (5): ReLU()
  )
  (upsample): Sequenti



---------------------------------------------------------------------------
          Layer (type)         Input Shape         Param #     Tr. Param #
              Conv1d-1        [4, 6960, 1]         891,008         891,008
             Dropout-2         [4, 128, 1]               0               0
                ReLU-3         [4, 128, 1]               0               0
              Conv1d-4         [4, 128, 1]          33,024          33,024
             Dropout-5         [4, 256, 1]               0               0
                ReLU-6         [4, 256, 1]               0               0
   UpsampleConvBlock-7         [4, 128, 2]          49,280          49,280
   UpsampleConvBlock-8         [4, 128, 8]         114,816         114,816
   UpsampleConvBlock-9        [4, 128, 64]          57,408          57,408
             Conv1d-10        [4, 64, 512]             390             390
Total params: 1,145,926
Trainable params: 1,145,926
Non-trainable params: 0
-----------------------



........



....



...
Epoch 5 train loss_0 : 0.975 val loss_0 : 1.32 train loss_1 : 0.00383 val loss_1 : 0.0109 
.........



....



.......
Epoch 10 train loss_0 : 0.928 val loss_0 : 1.33 train loss_1 : 0.0158 val loss_1 : 0.0754 
....................
Epoch 15 train loss_0 : 0.849 val loss_0 : 1.35 train loss_1 : 0.0438 val loss_1 : 0.0917 
.



...................
Epoch 20 train loss_0 : 0.818 val loss_0 : 1.28 train loss_1 : 0.105 val loss_1 : 0.0933 
.



....



....



...........
Epoch 25 train loss_0 : 0.729 val loss_0 : 1.27 train loss_1 : 0.188 val loss_1 : 0.24 
.



................



...
Epoch 30 train loss_0 : 0.65 val loss_0 : 1.26 train loss_1 : 0.31 val loss_1 : 0.256 
.....



...............
Epoch 35 train loss_0 : 0.58 val loss_0 : 1.26 train loss_1 : 0.347 val loss_1 : 0.156 
....................
Epoch 40 train loss_0 : 0.556 val loss_0 : 1.22 train loss_1 : 0.353 val loss_1 : 0.202 
....................
Epoch 45 train loss_0 : 0.524 val loss_0 : 1.24 train loss_1 : 0.388 val loss_1 : 0.156 
.............



.......
Epoch 50 train loss_0 : 0.492 val loss_0 : 1.15 train loss_1 : 0.453 val loss_1 : 0.338 
.



............



.......
Epoch 55 train loss_0 : 0.447 val loss_0 : 1.23 train loss_1 : 0.515 val loss_1 : 0.248 
....................
Epoch 60 train loss_0 : 0.468 val loss_0 : 1.18 train loss_1 : 0.441 val loss_1 : 0.24 
.....



...............
Epoch 65 train loss_0 : 0.406 val loss_0 : 1.16 train loss_1 : 0.509 val loss_1 : 0.226 
....................
Epoch 70 train loss_0 : 0.461 val loss_0 : 1.23 train loss_1 : 0.429 val loss_1 : 0.235 
.................



...
Epoch 75 train loss_0 : 0.482 val loss_0 : 1.34 train loss_1 : 0.506 val loss_1 : 0.136 
....................
Epoch 80 train loss_0 : 0.457 val loss_0 : 1.2 train loss_1 : 0.439 val loss_1 : 0.241 
....................
Epoch 85 train loss_0 : 0.449 val loss_0 : 1.19 train loss_1 : 0.501 val loss_1 : 0.208 
....................
Epoch 90 train loss_0 : 0.387 val loss_0 : 1.24 train loss_1 : 0.514 val loss_1 : 0.308 
....................
Epoch 95 train loss_0 : 0.401 val loss_0 : 1.18 train loss_1 : 0.512 val loss_1 : 0.192 
....................
Epoch 100 train loss_0 : 0.397 val loss_0 : 1.15 train loss_1 : 0.554 val loss_1 : 0.283 
....................
Epoch 105 train loss_0 : 0.397 val loss_0 : 1.15 train loss_1 : 0.534 val loss_1 : 0.264 
....................
Epoch 110 train loss_0 : 0.372 val loss_0 : 1.16 train loss_1 : 0.535 val loss_1 : 0.322 
....................
Epoch 115 train loss_0 : 0.373 val loss_0 : 1.17 train loss_1 : 0.524 val loss_1 : 0.286 
....................
Epo

VBox(children=(Label(value=' 3.18MB of 3.18MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/loss_0,█▆▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▂▁▁
train/loss_1,▁▂▃▄▅▆▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇██▇█▇▇▇▇███▇██
val/corr_mean,▂▃▃▄▃▁▃▃▂▄▄▄▇▆▇██
val/loss_0,▆█▆▃▅▂▃▂▃▂▄▁▃▁▅▆▄▃▅▃▄▅▃▅▃▅▃▃▂▂▄▅▄▄▅▄▅▅▄▄
val/loss_1,▁▁▅▄▃█▅▅▅▅▆▇▆▆▆▇▆▆▆▇▅▆▅▆▆▅▇▆▅▆▆▅▅▅▆▆▆▆▆▆

0,1
train/loss_0,0.25174
train/loss_1,0.69885
val/loss_0,1.22658
val/loss_1,0.28057


# 