# Import libraries


In [1]:
import sys, os, json
import mne, sklearn, wandb
import numpy as np
import pandas as pd

from scipy.interpolate import interp1d
from nilearn import datasets, image, masking, plotting
from nilearn.input_data import NiftiLabelsMasker


# animation part
from IPython.display import HTML
import matplotlib
import matplotlib.pyplot as plt
# from celluloid import Camera   # a convenient way to build animations
from matplotlib import animation, rc
from matplotlib.animation import FuncAnimation


## torch libraries 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader, Subset

from pytorch_model_summary import summary




In [2]:
%load_ext autoreload
%autoreload 2
sys.path.insert(1, os.path.realpath(os.path.pardir))

from utils import get_datasets
from utils import preproc
from utils import torch_dataset
from utils import train_utils
from utils import inference
from utils.models_arch import autoencoder_new, autoencoder_v3_separable

# Set all hyperparameters
- CUDA / GPU device selection.
- Dataset parameters.
- Random seeds (if necessary).


In [3]:
import random

# Single seed shared by every RNG used in the notebook.
SEED = 0

torch.manual_seed(SEED)
random.seed(SEED)  # python operation seed
np.random.seed(SEED)

# Trade cuDNN autotuning speed for repeatable kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(torch.cuda.is_available(), torch.cuda.device_count())
# Only pin a CUDA device when one exists, so the cell also runs on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

True 4


In [4]:
# Experiment-level settings (dataset, windowing, optimisation, logging tags).
config = dict(
    dataset_name='CWL',  # CWL
    new_fps=100,
    freqs=np.logspace(np.log10(2), np.log10(99), 16),

    n_channels=30,  # 30
    n_roi=6,

    bold_delay=5,
    to_many=True,
    random_subsample=True,
    sample_per_epoch=256,
    WINDOW_SIZE=512,

    optimizer='adamW',
    lr=3e-4,
    weight_decay=1e-4,
    batch_size=16,

    preproc_type='dB_log',
    loss_function='mse',
    model_type='CorrModel_AE1D_standard',
)

# Architecture hyperparameters. Input/output sizes are derived from `config`
# so they cannot drift from the dataset settings above.
hp_autoencoder = dict(
    n_electrodes=config['n_channels'],
    n_freqs=len(config['freqs']),
    n_channels_out=config['n_roi'],
    corr_proj_size=64,
    channels=[128, 128, 128, 64],
    kernel_sizes=[5, 5, 3],
    strides=[8, 8, 4],
    dilation=[1, 1, 1],
    decoder_reduce=4,
    window_size=config['WINDOW_SIZE'],
)

# Flat dict logged to wandb: model hyperparameters first, then experiment settings.
config = {**hp_autoencoder, **config}

# DataLoader keyword arguments.
params_train = {'batch_size': config['batch_size'],
                'shuffle': True,
                'num_workers': 0}

params_val = {'batch_size': config['batch_size'],
              'shuffle': False}

# Load the preprocessed dataset from .npz files
Loading already-preprocessed arrays speeds up the experiments.

In [5]:
with open("../data/processed/labels_roi_6.json", 'r') as f:
    labels_roi = json.load(f)


# Preprocessed archive for each supported dataset.
dataset_paths = {
    'CWL': '../data/processed/CWL/trio1_100_hz_6_roi_2_99_freqs.npz',
    'NODDI': '../data/processed/NODDI/32_100_hz_6_roi_2_99_freqs.npz',
}

# Fail with a clear message instead of a NameError on `dataset_path` further down.
if config['dataset_name'] not in dataset_paths:
    raise ValueError('no such dataset: {!r} (expected one of {})'.format(
        config['dataset_name'], sorted(dataset_paths)))
dataset_path = dataset_paths[config['dataset_name']]


# load data
data = np.load(dataset_path)

train_dataset_prep = (data['x_train'], data['y_train'])
test_dataset_prep = (data['x_test'], data['y_test'])


# apply the BOLD time-delay correction
train_dataset_prep = preproc.bold_time_delay_align(train_dataset_prep,
                                                   config['new_fps'],
                                                   config['bold_delay'])
test_dataset_prep = preproc.bold_time_delay_align(test_dataset_prep,
                                                  config['new_fps'],
                                                  config['bold_delay'])


print('Size of train dataset:', train_dataset_prep[0].shape, train_dataset_prep[1].shape)
print('Size of test dataset:', test_dataset_prep[0].shape, test_dataset_prep[1].shape)

# torch dataset creation
torch_dataset_train = torch_dataset.CreateDataset_eeg_fmri(train_dataset_prep,
                                                            random_sample=config['random_subsample'],
                                                            sample_per_epoch=config['sample_per_epoch'],
                                                            to_many=config['to_many'],
                                                            window_size=config['WINDOW_SIZE'])

torch_dataset_test = torch_dataset.CreateDataset_eeg_fmri(test_dataset_prep,
                                                           random_sample=False,
                                                           sample_per_epoch=None,
                                                           to_many=config['to_many'],
                                                           window_size=config['WINDOW_SIZE'])

# The validation dataset has no stride of its own, so keep every VAL_STRIDE-th sample.
VAL_STRIDE = 100
torch_dataset_test = Subset(torch_dataset_test, np.arange(len(torch_dataset_test))[::VAL_STRIDE])

# init dataloaders for training
train_loader = torch.utils.data.DataLoader(torch_dataset_train, **params_train)
val_loader = torch.utils.data.DataLoader(torch_dataset_test, **params_val)




Size of train dataset: (30, 16, 20690) (6, 20690)
Size of test dataset: (30, 16, 5500) (6, 5500)


## Model investigation 

In [6]:
import torch 

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader, Subset
from pytorch_model_summary import summary
import numpy as np




class ConvBlock(nn.Module):
    """
    Simple 1D conv block in the style of wav2vec 2.0.

    Input is [batch, emb, time]; output is [batch, out_channels, time // stride].

    Pipeline:
        - conv ('same' padding, no bias)
        - layer norm over the embedding (channel) axis
        - GELU activation
        - dropout
        - max-pool downsampling by `stride`

    Args:
        in_channels: number of input channels (embedding size).
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: temporal downsampling factor, applied with max pooling
            after the convolution (the convolution itself is unstrided).
        dilation: dilation of the convolution kernel.
        p_conv_drop: dropout probability applied after the activation.

    To do:
        add res blocks.
    """
    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, dilation=1, p_conv_drop=0.1):
        super(ConvBlock, self).__init__()

        # `dilation` used to be accepted but silently dropped; forward it to the conv.
        self.conv1d = nn.Conv1d(in_channels, out_channels, 
                                kernel_size=kernel_size, 
                                dilation=dilation,
                                bias=False, 
                                padding='same')

        self.norm = nn.LayerNorm(out_channels)
        self.activation = nn.GELU()
        self.drop = nn.Dropout(p=p_conv_drop)

        # With stride == 1 the pooling is an identity operation.
        self.downsample = nn.MaxPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        """
        Apply conv -> norm -> activation -> dropout -> downsample.

        x: tensor of shape [batch, in_channels, time].
        Returns a tensor of shape [batch, out_channels, time // stride].
        """
        x = self.conv1d(x)

        # LayerNorm normalises the last axis, so move channels there and back.
        x = torch.transpose(x, -2, -1) 
        x = self.norm(x)
        x = torch.transpose(x, -2, -1) 

        x = self.activation(x)
        x = self.drop(x)

        x = self.downsample(x)

        return x


    
class UpConvBlock(nn.Module):
    """
    Decoder counterpart of ConvBlock: a ConvBlock followed by linear
    upsampling of the time axis.

    Args:
        scale: upsampling factor passed to nn.Upsample as `scale_factor`.
        **args: keyword arguments forwarded unchanged to ConvBlock.
    """
    def __init__(self, scale, **args):
        super().__init__()
        self.conv_block = ConvBlock(**args)
        self.upsample = nn.Upsample(scale_factor=scale, mode='linear', align_corners=False)

    def forward(self, x):
        """Run the conv block, then stretch its output along the last axis."""
        features = self.conv_block(x)
        return self.upsample(features)
    
def extract_cross_corr(x, th=0.):
    """
    Calculate the cross-correlation between electrodes for each frequency band
    and return the strict upper triangle of every correlation matrix as a vector.

    Input:
        x - tensor of shape [batch, elec, n_freqs, time]
        th - threshold for removing weak connectivity: correlations with
             |r| <= th are set to zero.
    Returns:
        x_corrs_vec - tensor of shape [batch, n_freqs, elec*(elec-1)//2]

    Undefined correlations (NaN, e.g. from a constant signal) are replaced by 0.
    """
    batch, elec, n_freq, time = x.shape

    # One [elec, time] matrix per (batch, band) pair.
    x = x.transpose(1, 2)
    x = x.reshape(batch*n_freq, elec, time)

    # torch.corrcoef works on a single 2D matrix, hence the per-slice loop.
    x_corrs = torch.stack([torch.corrcoef(x_) for x_ in x])
    x_corrs = torch.nan_to_num(x_corrs)
    x_corrs = x_corrs.reshape(batch, n_freq, elec, elec)

    x_corrs = torch.where(torch.abs(x_corrs) > th, x_corrs, torch.zeros_like(x_corrs))

    # Build the indices on the same device as the data instead of always on the CPU.
    triu_idxs = torch.triu_indices(elec, elec, offset=1, device=x.device)
    x_corrs_vec = x_corrs[..., triu_idxs[0], triu_idxs[1]]

    return x_corrs_vec


    
    
class CrossCorrAutoEncoder1D(nn.Module):
    """
    1D convolutional autoencoder for time-series regression, with an extra
    branch built from electrode cross-correlations.

    Input:  [batch, n_electrodes, n_freqs, window_size]
    Output: [batch, n_channels_out, T], where T is the bottleneck length
            multiplied by every stride (equal to window_size when window_size
            is divisible by the product of the strides).

    Two feature streams are merged at the bottleneck:
        - the input, flattened to [batch, n_electrodes*n_freqs, time], passed
          through `spatial_reduce` and the strided `encoder`;
        - the thresholded cross-correlation vector (see `extract_cross_corr`),
          projected by a 1x1 conv and reshaped to
          [batch, corr_proj_size, bottleneck_len].

    Args:
        window_size: number of time samples in every input window.
        n_electrodes: number of electrodes in the input.
        n_freqs: number of frequency bands in the input.
        n_channels_out: number of output channels.
        corr_proj_size: number of channels the cross-correlation branch adds
            at the bottleneck.
        channels: channel widths; channels[0] is the width after
            `spatial_reduce`, channels[i+1] the width after encoder block i.
        kernel_sizes: kernel size per encoder/decoder level.
        strides: downsampling (encoder) / upsampling (decoder) factor per level.
        dilation: dilation per encoder level.
        decoder_reduce: integer divisor applied to all decoder channel widths
            except the bottleneck, to keep the decoder small.
        corr_threshold: correlations with absolute value not above this
            threshold are zeroed before projection.

    Raises:
        ValueError: if window_size is too short to survive the downsampling.
    """
    def __init__(self,
                 window_size,
                 n_electrodes=30,
                 n_freqs = 16,
                 n_channels_out=21,
                 corr_proj_size = 128,
                 channels = [8, 16, 32, 32], 
                 kernel_sizes=[3, 3, 3],
                 strides=[4, 4, 4], 
                 dilation=[1, 1, 1], 
                 decoder_reduce=1,
                 corr_threshold=0.3):

        super(CrossCorrAutoEncoder1D, self).__init__()

        self.n_electrodes = n_electrodes
        self.n_freqs = n_freqs
        self.n_inp_features = n_freqs*n_electrodes
        self.n_channels_out = n_channels_out
        self.window_size = window_size
        self.corr_threshold = corr_threshold

        self.model_depth = len(channels)-1

        # Time length left after the encoder: each block max-pools by its
        # stride, which floors the length. This used to be hard-coded to 2.
        bottleneck_len = window_size
        for stride in strides[:self.model_depth]:
            bottleneck_len //= stride
        if bottleneck_len < 1:
            raise ValueError(
                'window_size={} is too short for strides={}'.format(window_size, strides))
        self.bottleneck_len = bottleneck_len

        ## cross corr parameters.
        # One upper-triangular correlation vector per frequency band, flattened.
        cross_corr_vector_dim = n_freqs * n_electrodes * (n_electrodes - 1) // 2
        self.project = nn.Sequential(nn.Conv1d(cross_corr_vector_dim,
                                               int(bottleneck_len*corr_proj_size), 1),
                                     nn.Dropout(p=0.25),
                                     nn.ReLU(),
                                     )

        # Encoder
        channels = np.array(channels)

        self.spatial_reduce = ConvBlock(self.n_inp_features, int(channels[0]), kernel_size=3)
        self.encoder = nn.Sequential(*[ConvBlock(int(channels[i]), 
                                                 int(channels[i+1]), 
                                                 kernel_sizes[i],
                                                 stride=strides[i], 
                                                 dilation=dilation[i], 
                                                 p_conv_drop=0.1) for i in range(self.model_depth)])
        ## Decoder
        # Reduce the number of channels (do not touch the last one); the
        # bottleneck additionally receives the correlation features.
        channels[:-1] = channels[:-1]//decoder_reduce
        channels[-1] = channels[-1] + corr_proj_size
        print('Channels for decoder', channels)

        self.mapping = ConvBlock(int(channels[-1]), int(channels[-1]), 1)

        # Mirror of the encoder: deepest level first.
        self.decoder = nn.Sequential(*[UpConvBlock(scale=strides[i],
                                                   in_channels=int(channels[i+1]),
                                                   out_channels=int(channels[i]),
                                                   kernel_size=kernel_sizes[i], 
                                                   p_conv_drop=0.1) for i in range(self.model_depth-1, -1, -1)])

        self.conv1x1_one = nn.Conv1d(int(channels[0]), 
                                     self.n_channels_out, 
                                     kernel_size=1,
                                     padding='same')

    def forward(self, x):
        """
        x: tensor of shape [batch, n_electrodes, n_freqs, window_size].
        Returns a tensor of shape [batch, n_channels_out, T] (see class docstring).
        """
        batch, elec, n_freq, time = x.shape
        assert time == self.window_size, "PROBLEM with size "

        # cross corr features
        x_corr_vec = extract_cross_corr(x, th=self.corr_threshold)
        x_corr_vec = x_corr_vec.reshape(batch, -1, 1)

        x_proj = self.project(x_corr_vec)  # [batch, bottleneck_len*corr_proj_size, 1]

        # wavelet features
        x = x.reshape(batch, -1, time)
        x = self.spatial_reduce(x)

        x = self.encoder(x)

        # aggregate features: spread the projection over the bottleneck time axis.
        x_proj_rep = x_proj.reshape(batch, -1, self.bottleneck_len)
        x = torch.cat([x, x_proj_rep], dim=1)
        x = self.mapping(x)

        x = self.decoder(x)
        x = self.conv1x1_one(x)

        return x


# Init Model, Loss, optimizers

In [8]:
model = CrossCorrAutoEncoder1D(**hp_autoencoder)

# Dummy batch of 4 windows, shaped from the config rather than hard-coded numbers.
x = torch.zeros(4, config['n_channels'], len(config['freqs']), config['WINDOW_SIZE'])
print(model(x).shape)

print(summary(model, x, show_input=True))

Channels for decoder [ 32  32  32 128]
torch.Size([4, 6, 512])
-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Conv1d-1        [4, 6960, 1]         891,008         891,008
         Dropout-2         [4, 128, 1]               0               0
            ReLU-3         [4, 128, 1]               0               0
       ConvBlock-4       [4, 480, 512]         184,576         184,576
       ConvBlock-5       [4, 128, 512]          82,176          82,176
       ConvBlock-6        [4, 128, 64]          82,176          82,176
       ConvBlock-7         [4, 128, 8]          24,704          24,704
       ConvBlock-8         [4, 128, 2]          16,640          16,640
     UpConvBlock-9         [4, 128, 2]          12,352          12,352
    UpConvBlock-10          [4, 32, 8]           5,184           5,184
    UpConvBlock-11         [4, 32, 64]           5,184           5,184
         Conv

# Model training

In [9]:
n_runs = 1

for i in range(n_runs):

    # Every run starts from a freshly initialised model and optimizer.
    model = CrossCorrAutoEncoder1D(**hp_autoencoder)

    loss_func = train_utils.make_mse_loss()
    train_step = train_utils.train_step

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay'],
    )

    # Keyword arguments handed to train_utils.wanb_train_regression.
    parameters = dict(
        EPOCHS=500,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_function=loss_func,
        train_step=train_step,
        optimizer=optimizer,
        device='cuda',
        raw_test_data=test_dataset_prep,
        show_info=5,
        num_losses=5,
        labels=labels_roi,
        inference_function=inference.model_inference_function,
        to_many=config['to_many'],
    )

    # NOTE(review): defined but not used anywhere in this cell.
    path_to_save_wandb = 'common/koval_alvi/Checkpoints/wandb_brain'

    with wandb.init(project="eeg_fmri", config=config, save_code=True):

        wandb.define_metric("val/corr_mean", summary="max")

        # The first run's generated name becomes the shared prefix of all runs.
        if i == 0:
            exp_name = wandb.run.name

        wandb.run.name = '_run_'.join([exp_name, str(i)])

        print(config)
        print(parameters['model'])

        dummy_batch = torch.zeros(4,
                                  config['n_channels'],
                                  len(config['freqs']),
                                  config['WINDOW_SIZE'])
        print(summary(model, dummy_batch, show_input=True))

        model = train_utils.wanb_train_regression(**parameters)
        

Channels for decoder [ 32  32  32 128]


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'n_electrodes': 30, 'n_freqs': 16, 'n_channels_out': 6, 'corr_proj_size': 64, 'channels': [128, 128, 128, 64], 'kernel_sizes': [5, 5, 3], 'strides': [8, 8, 4], 'dilation': [1, 1, 1], 'decoder_reduce': 4, 'window_size': 512, 'dataset_name': 'CWL', 'new_fps': 100, 'freqs': array([ 2.        ,  2.59420132,  3.36494024,  4.3646662 ,  5.6614114 ,
        7.34342046,  9.52515552, 12.3550855 , 16.02578954, 20.78706217,
       26.96291204, 34.97361097, 45.36429384, 58.84205542, 76.32406886,
       99.        ]), 'n_channels': 30, 'n_roi': 6, 'bold_delay': 5, 'to_many': True, 'random_subsample': True, 'sample_per_epoch': 256, 'WINDOW_SIZE': 512, 'optimizer': 'adamW', 'lr': 0.0003, 'weight_decay': 0.0001, 'batch_size': 16, 'preproc_type': 'dB_log', 'loss_function': 'mse', 'model_type': 'CorrModel_AE1D_standard'}
CrossCorrAutoEncoder1D(
  (project): Sequential(
    (0): Conv1d(6960, 128, kernel_size=(1,), stride=(1,))
    (1): Dropout(p=0.25, inplace=False)
    (2): ReLU()
  )
  (spatial_reduce)

VBox(children=(Label(value=' 4.83MB of 4.83MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/loss_0,█▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_1,▁▄▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█████████████████████
val/corr_mean,▁▂▁▁▁▄▇▆▆▆▇▇▆█▇▇▆▆▇▇▇▇▇▇██▇█
val/loss_0,▃▃▁▂▃▂▁▁▄▂▃▂▂▁▃▃▄▃▄▃▃▄▅▆▆▇▄▆▄█▅▇▅▆▇██▆██
val/loss_1,▁▆▇▇▇▇█████████▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▅▇▇

0,1
train/loss_0,0.07262
train/loss_1,0.87282
val/loss_0,1.36539
val/loss_1,0.42608


# 