In [1]:
import os
import glob
import argparse
import yaml
import sys
import math

import timm #only needed if downloading pretrained models
from datetime import datetime

sys.path.append('../')
sys.path.append('./')
sys.path.append('../../')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd

from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau

from models.sit import SiT
from utils2.renm_utils import * #load_weights_imagenet
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open("/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/config/SiT/training/ICAd15_schfd100.yml") as f:
    config = yaml.safe_load(f)

config

{'resolution': {'ico': 6, 'sub_ico': 2},
 'data': {'data_path': '../data/{}/{}',
  'task': 'ICAd15_schfd100',
  'configuration': 'template',
  'dataset': 'HCPdb',
  'hemisphere': '1L'},
 'logging': {'folder_to_save_model': '../logs/SiT/'},
 'computation_opt': {'scale_grad_choice': False},
 'training': {'LR': 1e-05,
  'bs': 1,
  'bs_val': 1,
  'epochs': 10,
  'gpu': 0,
  'l1loss': False,
  'testing': True,
  'val_epoch': 2,
  'load_weights_ssl': False,
  'load_weights_imagenet': True,
  'save_ckpt': True,
  'finetuning': False,
  'dataset_ssl': 'hcpdb',
  'epoch_check': True},
 'weights': {'ssl_mpp': '/scratch/naranjorincon/surface-vision-transformers/logs/SiT/0509-01:57-small-1L-ICAd15_schfd100-imgnet/checkpoint.pth',
  'imagenet': 'vit_small_patch16_224'},
 'transformer': {'dim': 384,
  'depth': 12,
  'heads': 6,
  'mlp_dim': 1536,
  'pool': 'mean',
  'num_features': 15,
  'num_classes': 4950,
  'num_channels': 15,
  'dim_head': 64,
  'dropout': 0.3,
  'emb_dropout': 0.1,
  'model': '

In [3]:
train_data = np.load("/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/data/surf2mat/template/train_data.npy")
train_label = np.load("/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/data/surf2mat/template/train_labels.npy")


In [4]:
def add_start_token_torch(tensor, start_value=1):
    """
    Add a new column with a start value to the beginning of each sequence in the input tensor.
    
    :param tensor: Tensor of shape (batch_size, seq_length), input tensor
    :param start_value: int, value to add at the start of each sequence
    :return: Tensor of shape (batch_size, seq_length + 1), tensor with a new column added to the start of each sequence
    """
    batch_size, seq_length = tensor.size()
    new_column = torch.full((batch_size, 1), start_value, dtype=tensor.dtype, device=tensor.device)  # Create a new column with the start value
    out = torch.cat([new_column, tensor], dim=1)  # Concatenate the new column with the input tensor
    return out

def add_start_token_np(array, start_value=1):
    """
    Add a new column with a start value to the beginning of each sequence in the input array.
    
    :param array: Array of shape (batch_size, seq_length), input array
    :param start_value: int, value to add at the start of each sequence
    :return: Array of shape (batch_size, seq_length + 1), array with a new column added to the start of each sequence
    """
    batch_size, seq_length = array.shape
    new_column = np.full((batch_size, 1), start_value, dtype=array.dtype)  # Create a new column with the start value
    out = np.concatenate((new_column, array), axis=1)  # Concatenate the new column with the input array
    return out

train_label = add_start_token_np(train_label)


In [10]:
bs=2
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
#device = "cpu"
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data).float(), torch.from_numpy(train_label).float())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = bs, shuffle=True, num_workers=5)

In [11]:
#### MODIFIED SiT model from Dahan
import torch
from torch import nn

from einops import repeat
from einops.layers.torch import Rearrange

from vit_pytorch.vit import Transformer

class EncoderSiT(nn.Module):
    def __init__(self, *,
                        dim, 
                        depth,
                        heads,
                        mlp_dim,
                        num_patches = 320,
                        num_channels = 4,
                        num_vertices = 153,
                        dim_head = 64,
                        sequence_length = 1225,
                        dropout = 0.1,
                        emb_dropout = 0.1
                        ):

        super().__init__()

        patch_dim = num_channels * num_vertices

        self.sequence_length = sequence_length
        self.dim = dim

        # inputs has size = b * c * n * v where b = batch, c = channels, f = features, n=patches, v=verteces
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c n v  -> b n (v c)'),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) # See here: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

        self.linear = nn.Linear(num_patches * dim, sequence_length * dim)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        x += self.pos_embedding[:, :] # was originally sliced by [:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Reshape the input tensor to (batch_size, num_patches * dim)
        x_reshaped = x.view(b, -1)
        # Apply the linear layer
        output = self.linear(x_reshaped)
        # Reshape the output tensor to (batch_size, sequence_length, dim)
        output = output.view(b, self.sequence_length, self.dim)

        return output

In [12]:
model = EncoderSiT(dim=20,
                   depth=5,
                   heads=2, 
                   mlp_dim=40)

for i, data in enumerate(train_loader):
    inputs, targets = data[0].to(device), data[1].to(device).squeeze()
    model.to(device)
    output = model(inputs)
    print(output.shape)

torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])
torch.Size([2, 1225, 20])


In [13]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        # Create a vector of shape (d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model))
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model))
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
        return self.dropout(x)

class TransformerDecoderBlock(nn.Module):
    def __init__(self, input_dim, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderBlock, self).__init__()
        self.d_model = d_model
        self.input_dim = input_dim

        self.flatten_to_high_dim = nn.Linear(input_dim, input_dim * d_model)
        self.positional_encoding = PositionalEncoding(d_model=d_model, seq_len=input_dim, dropout=dropout)
        
        self.masked_multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        b, _ = tgt.size()

        # Project to high-dimensional space
        tgt = self.flatten_to_high_dim(tgt)
        tgt = tgt.view(b, -1, self.d_model)
        print(tgt.shape)
        
        # Apply positional encoding
        tgt = self.positional_encoding(tgt)
        

        tgt_mask = generate_subsequent_mask(self.input_dim).to(device)
        # Masked Multi-Head Attention
        tgt2, _ = self.masked_multihead_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout1(tgt2)  # Residual connection
        tgt = self.norm1(tgt)
        print(tgt.shape)
        
        # Cross-Multi-Head Attention
        tgt2, _ = self.cross_multihead_attn(tgt, memory, memory)
        tgt = tgt + self.dropout2(tgt2)  # Residual connection
        tgt = self.norm2(tgt)
        
        # Feed Forward
        tgt2 = self.feed_forward(tgt)
        tgt = tgt + self.dropout3(tgt2)  # Residual connection
        tgt = self.norm3(tgt)
        
        return tgt

def generate_subsequent_mask(size):
    """
    Generate a mask to ensure that each position in the sequence can only attend to
    positions up to and including itself. This is a lower triangular matrix filled with ones.
    
    :param size: int, the length of the sequence
    :return: tensor of shape (size, size), where element (i, j) is False if j <= i, and True otherwise (See attn_mask option here: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
    """
    mask = torch.triu(torch.ones(size, size)).bool()
    mask.diagonal().fill_(False)
    return mask




In [14]:
encoder = EncoderSiT(dim=20,
                     depth=5,
                     heads=2, 
                     mlp_dim=40)

decoder = TransformerDecoderBlock(input_dim=1226,
                                  d_model=20,
                                  nhead=2,
                                  dim_feedforward=80)

for i, data in enumerate(train_loader):
    inputs, targets = data[0].to(device), data[1].to(device).squeeze()

    encoder.to(device)
    decoder.to(device)

    encoder_out = encoder(inputs)
    output = decoder(tgt = targets.to(device), memory=encoder_out)
    print(output.shape)

torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])
torch.Size([2, 1226, 20])


In [17]:
def train():
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
    device = torch.device(device)

    
