# Interpretable Transformer-based architecture.

Pretrain via MAE: -> https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/mae.py

In [1]:
import sys, os


import sklearn
import mne
import wandb
sys.path.insert(1, os.path.realpath(os.path.pardir))



from nilearn import image
from nilearn import plotting

import numpy as np
import matplotlib.pyplot as plt

import mne
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

from scipy.interpolate import interp1d
from nilearn import datasets, image, masking
from nilearn.input_data import NiftiLabelsMasker


# animation part
from IPython.display import HTML
# from celluloid import Camera   # it is convinient method to animate
from matplotlib import animation, rc
from matplotlib.animation import FuncAnimation



## torch libraries 

import torch

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader, Subset

from pytorch_model_summary import summary


  warn("Fetchers from the nilearn.datasets module will be "


# Feature extractor.
- ConvNext and wav2vec inspired

In [2]:
class Conv_block(nn.Module):
    """
    Input is [batch, emb, time]
    simple conv block from wav2vec 2.0 
        - conv
        - layer norm by embedding axis
        - activation
    To do: 
        add res blocks.
    """
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1, dilation=1):
        super(Conv_block, self).__init__()
        
        # use it instead stride. 
        self.downsample = nn.AvgPool1d(kernel_size=stride, stride=stride)
        
        self.conv1d = nn.Conv1d(in_channels, out_channels, 
                                kernel_size=kernel_size, 
                                bias=False, 
                                padding='same')
        
        self.norm = nn.LayerNorm(out_channels)
        self.activation = nn.GELU()

        
        
    def forward(self, x):
        """
        - downsample 
        - conv 
        - norm 
        - activation
        """
        
        x = self.downsample(x)
        
        x = self.conv1d(x)
        
        # norm by last axis.
        x = torch.transpose(x, -2, -1) 
        x = self.norm(x)
        x = torch.transpose(x, -2, -1) 
        
        x = self.activation(x)
        
        return x
    
class wav2vec_conv(nn.Module):
    """
    Extract some features from one time serias as raw speech.
    To do make it possibl;e to work with raw EEG electrode recording. 

    """
    def __init__(self,
                 n_inp_features=30,
                 channels = [32, 32, 32], 
                 kernel_sizes=[3, 3, 3],
                 strides=[2, 2, 2], 
                 dilations = [1, 1, 1]):
        
        super(wav2vec_conv, self).__init__()
        
        # freqs-electrodes  reducing to channels[0].
        # add additional layer
        channels = [n_inp_features] + channels
        
        self.model_depth = len(channels)-1
        # create downsample blcoks in Sequentional manner.
        self.downsample_blocks = nn.ModuleList([Conv_block(channels[i], 
                                                        channels[i+1], 
                                                        kernel_sizes[i],
                                                        stride=strides[i], 
                                                        dilation=dilations[i]) for i in range(self.model_depth)])
        
    def forward(self, x):
        """
        1. Encode  
        """
        batch, n_freq, time = x.shape
    
        # encode information
        for block  in self.downsample_blocks:
            x = block(x)
        return x


In [13]:
class Wav2Vec_aggregate(nn.Module):
    """
    Inpyt should be. 
    batch, n_ch, n_freq, time = x.shape
    
    model_embedding - should be work with [batch, n_freqs, time]
    
    Return [batch, n_ch, emb, time//stride] 
    """
    def __init__(self, model_embedding):
        
        super(Wav2Vec_aggregate, self).__init__()
        
        # create downsample blcoks in Sequentional manner.
        self.embedding = model_embedding
        
    def forward(self, x):
        """
        1. Apply for each channnels.  
        """
        batch, n_ch, n_freq, time = x.shape
        emb_list = [self.embedding(x[:, ch]) for ch in range(n_ch)]
        emb_list = torch.stack(emb_list, dim = 1)
        return emb_list

# Extract features from EEG 
- Raw signal 
    - apply wav2net encoder for each electrode separately 
- Wavelet features 
    - Apply wav2net separately for better features extraction 
    - Simple downsampling temporal part.

In [79]:
class Vanilla_transformer(nn.Module):
    """
    Vanilla transformer aggregate all 
    [batch, n_electrodes, embed_dim, time]
    
    Input of transformer is -> [batch, sequence_length, embed_dim]
    
    Return 
        [batch, 21]
    """
    def __init__(self,
                 n_electrodes = 64,
                 sequence_length = 128,
                 embed_dim=256,
                 n_roi=21,
                 num_heads=4,
                 mlp_ratio=2,
                 attn_dropout=0.1,
                 num_layers=1,                
                 mlp_dropout=0.1,
                ):
        super(Vanilla_transformer, self).__init__()
        self.n_electrodes = n_electrodes
        self.embed_dim = embed_dim
        self.sequence_length = sequence_length

        self.class_embed = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        
        # just add vector 
        self.pe = nn.Parameter(torch.zeros(1, n_electrodes*sequence_length + 1, embed_dim), requires_grad=True)
        
        # less parameters. factorises pe.

        self.pe_spatial = nn.Parameter(torch.zeros(1, n_electrodes, 1, embed_dim), requires_grad=True)
        self.pe_temporal = nn.Parameter(torch.zeros(1, 1, sequence_length , embed_dim), requires_grad=True)
        self.pe_cls =  nn.Parameter(torch.zeros(1, 1, self.embed_dim),  requires_grad=True)


        
        transformer_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                            nhead=num_heads, 
                                                            dim_feedforward=embed_dim*mlp_ratio, 
                                                            dropout=attn_dropout, 
                                                            activation='relu', 
                                                            batch_first=True)
        
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)


        self.norm = nn.LayerNorm(embed_dim)
        # take our vector and make 21 prediction. 
        self.mlp_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*mlp_ratio),
            nn.Dropout(p=mlp_dropout), 
            nn.GELU(),
            nn.Linear(embed_dim*mlp_ratio, n_roi))
            
        
    def forward(self, x):
        """
        x.shape = > [batch, n_electrodes, embed_dim, time]
        
        """
        print('Input size: ', x.shape)
        batch = x.shape[0]
        x = x.transpose(2, 3)
        # print('After reshape: ', x.shape)
        # print(self.sequence_length , self.embed_dim, self.embed_dim)
        x = x.reshape(batch, self.sequence_length * self.n_electrodes, self.embed_dim)
        print('After reshape: ', x.shape)
        
        # repeating like batch size and add seg lenght. 
        class_token = self.class_embed.expand(batch, -1, -1)
        x = torch.cat((class_token, x), dim=1)
        
        # add positional embeddings [batch, sequence_length, embed_dim]
        self.merge_pos_encoding()
        
        x += self.pe_merged
        x = self.transformer(x)
        
        # take first vector -> normalize -> mlp into 
        x_cls = x[:, 0]
        x_cls = self.mlp_head(self.norm(x_cls))
        return x_cls
    
    def merge_pos_encoding(self):
        """
        for calcualation.
        """
        self.pe_merged = self.pe_spatial + self.pe_temporal
        self.pe_merged = self.pe_merged.reshape(1, self.n_electrodes*self.sequence_length, self.embed_dim)
        self.pe_merged = torch.cat([self.pe_cls, self.pe_merged], dim= 1)
        
        
    
    


In [80]:
class Factorized_transformer(nn.Module):
    """
    Vanilla transformer aggregate all 
    [batch, n_electrodes, embed_dim, time]
    
    Input of transformer is -> [batch, sequence_length, embed_dim]
    
    
    
    Spatial transformer ( electrode aggregation ):
        [batch, n_electrodes, embed_dim, time] -> [batch, n_roi, embed_dim, time]
    Temporal transformer: (time aggregation) 
        [batch, n_roi, embed_dim, time] -> [batch, n_roi, embed_dim]
    Prediction
         [batch, n_roi, embed_dim] -> [batch, n_roi]
        
    
    
    
    Return 
        [batch, 21]
    """
    def __init__(self,
                 n_electrodes = 64,
                 sequence_length = 128,
                 embed_dim=256,
                 n_roi=21,
                 num_heads=4,
                 mlp_ratio=2,
                 attn_dropout=0.1,
                 num_layers_spatial=1,
                 num_layers_temporal=1,
                 mlp_dropout=0.1,
                ):
        super(Factorized_transformer, self).__init__()
        self.n_electrodes = n_electrodes
        self.embed_dim = embed_dim
        self.sequence_length = sequence_length
        self.n_roi = n_roi
        
                
        self.roi_tokens = nn.Parameter(torch.zeros(1, n_roi, embed_dim), requires_grad=True)
        self.time_tokens = nn.Parameter(torch.zeros(1, n_roi, 1, embed_dim), requires_grad=True) # use different reg tokent for each roi.
   

        self.pe_spatial = nn.Parameter(torch.zeros(1, n_roi + n_electrodes, embed_dim), requires_grad=True)
        self.pe_temporal = nn.Parameter(torch.zeros(1, 1 + sequence_length, embed_dim), requires_grad=True)
        
        
        spatial_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                                nhead=num_heads, 
                                                                dim_feedforward=embed_dim*mlp_ratio, 
                                                                dropout=attn_dropout, 
                                                                activation='relu', 
                                                                batch_first=True)
        
        temporal_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                            nhead=num_heads, 
                                                            dim_feedforward=embed_dim*mlp_ratio, 
                                                            dropout=attn_dropout, 
                                                            activation='relu', 
                                                            batch_first=True)
        
        self.spatial_transformer = nn.TransformerEncoder(spatial_layer, num_layers=num_layers_spatial)
        self.temporal_transformer = nn.TransformerEncoder(temporal_layer, num_layers=num_layers_temporal)



#         self.mlp_heads = nn.ModuleList([nn.Sequential(
#                                         nn.LayerNorm(embed_dim),
#                                         nn.Linear(embed_dim, embed_dim*mlp_ratio),
#                                         nn.Dropout(p=mlp_dropout), 
#                                         nn.GELU(),
#                                         nn.Linear(embed_dim*mlp_ratio, 1)
#                                         ) for i in range(self.n_roi)])
        
        self.mlp_heads = nn.ModuleList([nn.Sequential(
                                        nn.LayerNorm(embed_dim),
                                        nn.Linear(embed_dim, embed_dim//2),
                                        nn.Dropout(p=mlp_dropout), 
                                        nn.GELU(),
                                        nn.Linear(embed_dim//2, 1)
                                        ) for i in range(self.n_roi)])

        
    def forward(self, x):
        """
        x.shape = > [batch, n_electrodes, embed_dim, sequence_length]
        
        """
        
        batch = x.shape[0]
        x = x.permute(0, 3, 1, 2)
        
        # [batch, sequence_length, n_electrodes, embed_dim]
        # spatial transformer.
        
        x_spatial = x.reshape(batch*self.sequence_length, self.n_electrodes, self.embed_dim)
        
        roi_tokens = self.roi_tokens.expand(batch*self.sequence_length, -1, -1)
        x_spatial = torch.cat((roi_tokens, x_spatial), dim=1)
        x_spatial += self.pe_spatial
        x_spatial_transformed = self.spatial_transformer(x_spatial)
        
        roi_tokens = x_spatial_transformed[:, :self.n_roi]
        
        # OUTPUT: [batch*sequence_length, n_roi, embed_dim]
        ## temporal transfromer.
        # take only roi tokens 
        
        roi_tokens = roi_tokens.reshape(batch, self.sequence_length, self.n_roi, self.embed_dim)
        roi_tokens = roi_tokens.permute(0, 2, 1, 3)
        # [batch, n_roi, sequence_length, embed_dim]
        
        time_tokens = self.time_tokens.expand(batch, -1, -1, -1)
        roi_tokens = torch.cat((time_tokens, roi_tokens), dim=2)  # [batch, n_roi, 1+sequence_length, embed_dim]
        roi_tokens = roi_tokens.reshape(batch*self.n_roi, 1 + self.sequence_length, self.embed_dim)
        # [batch*n_roi, 1+sequence_length, embed_dim]
        
        
        roi_tokens += self.pe_temporal
        roi_tokens = self.temporal_transformer(roi_tokens)
        roi_tokens = roi_tokens[:, 0]
        roi_tokens = roi_tokens.reshape(batch, self.n_roi, self.embed_dim)
        
        # prediction 
        preds = []
        for roi in range(self.n_roi):
            res = self.mlp_heads[roi](roi_tokens[:, roi])
            preds.append(res)
        preds= torch.cat(preds, dim=1)
        
        return preds
    
    


In [81]:
def get_strided_func(stride):
    """
    stride parameters pooling 
    pool by last axis .
    """
    def get_strided_input(x):
        return x[..., ::stride]
    return get_strided_input


class Super_model(nn.Module):
    """
    Inpyt should be. 
    batch, n_ch, n_freq, time = x.shape
    
    
    Return [batch, n_ch_out, emb, time//stride] 
    """
    def __init__(self, feature_extractor, transformer):
        
        super(Super_model, self).__init__()
        
        # create downsample blcoks in Sequentional manner.
        
        self.feature_extractor = feature_extractor
        self.transformer = transformer
        
    def forward(self, x):
        """
        1. Apply for each channnels.  
        """
        batch, n_ch, n_freq, time = x.shape
        
        x_features = self.feature_extractor(x)
        print('Size', x_features.shape)
        x_out = self.transformer(x_features)
        
        return x_out

In [82]:
n_electrodes = 64
n_features = 16
window_size = 4096
n_time_points = 32 
stride = window_size//n_time_points

eeg = torch.ones([2, n_electrodes, n_features, window_size])
print(eeg.shape)

torch.Size([2, 64, 16, 4096])


In [83]:
n_electrodes = 64
n_features = 16
window_size = 4096
n_time_points = 32 
stride = window_size//n_time_points

eeg = torch.ones([2, n_electrodes, n_features, window_size])
print(eeg.shape)


config_feature_extractor = dict(n_inp_features=n_features,
                                 channels = [32, 32, 32, 32, 32], 
                                 kernel_sizes=[9, 5, 3, 3, 3],
                                 strides=[4, 4, 2, 2, 2], 
                                 dilations = [1, 1, 1, 1, 1])


config_vanilla_transformer = dict(n_electrodes = n_electrodes,
                                 sequence_length = n_time_points,
                                 embed_dim=n_features,
                                 n_roi=21,
                                 num_heads=4,
                                 mlp_ratio=4,
                                 num_layers=2,
                                 attn_dropout=0.1,                
                                 mlp_dropout=0.5)

config_factorized_transformer = dict(n_electrodes = n_electrodes,
                                 sequence_length = n_time_points,
                                 embed_dim=32,
                                 n_roi=21,
                                 num_heads=4,
                                 mlp_ratio=4,
                                 num_layers_spatial=2,
                                 num_layers_temporal=2,
                                 attn_dropout=0.1,                
                                 mlp_dropout=0.5)

conv_wav2net = wav2vec_conv(**config_feature_extractor)
model_wav2net = Wav2Vec_aggregate(conv_wav2net)



vit = Vanilla_transformer(**config_vanilla_transformer)
factorized_vit = Factorized_transformer(**config_factorized_transformer)


model_v1 = Super_model(feature_extractor = get_strided_func(128), 
                       transformer = vit)


model_v2 = Super_model(feature_extractor = model_wav2net, 
                       transformer = factorized_vit)



y_hat = model_v1(eeg) 
y_hat_2 = model_v2(eeg) 

print('Input size', eeg.shape)
print('Prediction size ', y_hat.shape)
print('Prediction size 2', y_hat_2.shape)


# print('Output size complex', eeg_complex_features.shape)
# print('Output size simple ', eeg_simple_features.shape)



# y_hat = vit(eeg_complex_features)
# y_hat_2 = factorizes_vit(eeg_complex_features)






print(summary(model_v1, eeg, show_input=False))
print(summary(model_v2, eeg, show_input=False))
# print(summary(factorizes_vit, eeg_complex_features, show_input=False))


torch.Size([2, 64, 16, 4096])
Size torch.Size([2, 64, 16, 32])
Input size:  torch.Size([2, 64, 16, 32])
After reshape:  torch.Size([2, 2048, 16])
Size torch.Size([2, 64, 32, 32])
Input size torch.Size([2, 64, 16, 4096])
Prediction size  torch.Size([2, 21])
Prediction size 2 torch.Size([2, 21])
Size torch.Size([2, 64, 16, 32])
Input size:  torch.Size([2, 64, 16, 32])
After reshape:  torch.Size([2, 2048, 16])
-----------------------------------------------------------------------------
            Layer (type)        Output Shape         Param #     Tr. Param #
   Vanilla_transformer-1             [2, 21]          43,397          43,397
Total params: 43,397
Trainable params: 43,397
Non-trainable params: 0
-----------------------------------------------------------------------------
Size torch.Size([2, 64, 32, 32])
--------------------------------------------------------------------------------
               Layer (type)        Output Shape         Param #     Tr. Param #
        Wav2Vec

In [84]:
model_v1 = model_v1.cuda()
print(summary(model_v1, eeg.cuda(), show_input=False))


Size torch.Size([2, 64, 16, 32])
Input size:  torch.Size([2, 64, 16, 32])
After reshape:  torch.Size([2, 2048, 16])
-----------------------------------------------------------------------------
            Layer (type)        Output Shape         Param #     Tr. Param #
   Vanilla_transformer-1             [2, 21]          43,397          43,397
Total params: 43,397
Trainable params: 43,397
Non-trainable params: 0
-----------------------------------------------------------------------------


In [85]:
func = get_strided_func(128)
func(eeg).shape

torch.Size([2, 64, 16, 32])

# Following experiments 

1. Vanilla Transformer v1/v2. 
    - strided input / Extracted features.
    - vanilla transformer.
    

2. Factorized Transformer v1/v2. 
    - strided input / Extracted features.
    - Factorized transformer.
    
    
3. SSL transforemr.
    - strided input + Masked tokens 
    - Vanilla transformer learn to predict MASK.
    - EEG pretrain 
    
    


## Diffusion models 
- Create universal 1D UNet model( simple )
1. Generate fMRI samples from noise Normal  
2. Generate fMRI sample from noise based on EEG data. 
3. Profit 



## Normalizing flows
NF - change of variables. Transform function we parametrize with NN. 
So we can build two different NF: for EEG and for fMRI. 


EEG --> Z --> fMRI   
EEG <-- Z <-- fMRI 

1. Learn generate fmRI. 
2. Learn generate EEG.
3. Get Z for each EEG and for each fMRI. 
4. Learn connection between different Z. 
   
 



