In [1]:
from models import *

import os
import glob
import argparse
import yaml
import sys
import math

import matplotlib.pyplot as plt

import timm #only needed if downloading pretrained models
from datetime import datetime

sys.path.append('../')
sys.path.append('./')
sys.path.append('../../')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd

from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils2.renm_utils import * #load_weights_imagenet
import random

from einops import repeat
from einops.layers.torch import Rearrange

from vit_pytorch.vit import Transformer


  from .autonotebook import tqdm as notebook_tqdm


In [15]:
train_data = np.load("/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/data/surf2mat/template/train_data.npy")
train_label = np.load("/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/data/surf2mat/template/train_labels.npy")



def add_start_token_torch(tensor, start_value=1):
    """
    Add a new column with a start value to the beginning of each sequence in the input tensor.
    
    :param tensor: Tensor of shape (batch_size, seq_length), input tensor
    :param start_value: int, value to add at the start of each sequence
    :return: Tensor of shape (batch_size, seq_length + 1), tensor with a new column added to the start of each sequence
    """
    batch_size, seq_length = tensor.size()
    new_column = torch.full((batch_size, 1), start_value, dtype=tensor.dtype, device=tensor.device)  # Create a new column with the start value
    out = torch.cat([new_column, tensor], dim=1)  # Concatenate the new column with the input tensor
    return out

def add_start_token_np(array, start_value=1):
    """
    Add a new column with a start value to the beginning of each sequence in the input array.
    
    :param array: Array of shape (batch_size, seq_length), input array
    :param start_value: int, value to add at the start of each sequence
    :return: Array of shape (batch_size, seq_length + 1), array with a new column added to the start of each sequence
    """
    batch_size, seq_length = array.shape
    new_column = np.full((batch_size, 1), start_value, dtype=array.dtype)  # Create a new column with the start value
    out = np.concatenate((new_column, array), axis=1)  # Concatenate the new column with the input array
    return out

# taking only first two subjects and first 100 cells in the parcellation
train_label = train_label[:1, :100]
train_data = train_data[:1, :, :, :]

train_label = add_start_token_np(train_label)


In [16]:
bs=1
#device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
device = "cpu"
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data).float(), torch.from_numpy(train_label).float())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = bs, shuffle=True, num_workers=5)

In [4]:
def generate_subsequent_mask(size):
    """
    Generate a mask to ensure that each position in the sequence can only attend to
    positions up to and including itself. This is a lower triangular matrix filled with ones.
    
    :param size: int, the length of the sequence
    :return: tensor of shape (size, size), where element (i, j) is False if j <= i, and True otherwise (See attn_mask option here: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
    """
    mask = torch.triu(torch.ones(size, size)).bool()
    mask.diagonal().fill_(False)
    return mask

In [5]:
def greedy_decode(model, source, input_dim, device, b=2): #b=batch size
    '''
    Greedy decode algorithm for a full encoder-decoder architecture (inference).

    Implements using ALL options (initialization with torch.ones vs zeros vs randn; generating ONLY i+1th val VS :i+1 tokens on each iteration)
    '''
    encoder_output = model.encode(source)
    initialization_list = [torch.zeros(b,input_dim).to(device), torch.ones(b,input_dim).to(device), torch.randn(b,input_dim).to(device)]
    
    # build target mask
    decoder_mask = generate_subsequent_mask(initialization_list[0].size(1)).to(device)

    out_list = []

    for i, decoder_input in enumerate(initialization_list):
        decoder_input_copy = decoder_input.clone()
        for i in range(input_dim-1):
            # compute next output
            out = model.decode(encoder_out=encoder_output, tgt=decoder_input, tgt_mask=decoder_mask)
            decoder_input[:, i+1] = out.squeeze(1)[:, i+1]

        out_list.append(decoder_input.squeeze(0))

        for i in range(input_dim-1):
            out = model.decode(encoder_out=encoder_output, tgt=decoder_input_copy, tgt_mask=decoder_mask)
            decoder_input_copy[:, :i+1] = out.squeeze(1)[:, :i+1] 
        
        out_list.append(decoder_input_copy.squeeze(0))

    return out_list

In [6]:
def plot_mses(mses):
    fig, ax = plt.subplots()
    ax.set_title('Mean Squared Errors (MSEs) Over Iterations')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE')
    ax.grid(True)

    plt.plot(mses)
    
    plt.show()

In [7]:
model = FullTransformer(dim_model=20, 
                        encoder_depth=5,
                        nhead = 2,
                        encoder_mlp_dim=80,
                        decoder_input_dim=101,
                        decoder_dim_feedforward=80,
                        decoder_depth=5,
                        dropout=0.1)

In [36]:
# like a train function
model.to(device)
#model._reset_parameters() reset if you wanna restart

trainlosses = []

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, eps=1e-9)
global_step = 0

loss_fn = nn.L1Loss()

for epoch in range(300):
    model.train()

    for i, data in enumerate(train_loader):
        inputs, targets = data[0].to(device), data[1].to(device).squeeze().unsqueeze(0) # USE THIS unsqueeze(0) ONLY if batch size = 1

        print(inputs.shape)
        print(targets.shape)
        
        pred = model(src=inputs, tgt=targets, tgt_mask=generate_subsequent_mask(101).to(device))

        loss = loss_fn(pred, targets)

        print(f"Train Loss, Epoch {epoch}: {loss}")
        trainlosses.append(loss.detach().cpu().numpy())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        global_step+=1
    
    if (epoch % 25 == 0) and epoch>0:
        #plot_mses(trainlosses)
        break

    #torch.save(model.state_dict(), "/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/_FyzTests/torchsave.pt")



    

torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 0: 0.18584363162517548
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 0: 0.19281522929668427
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 1: 0.19106265902519226
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 1: 0.1921955645084381
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 2: 0.23258045315742493
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 2: 0.20905910432338715
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 3: 0.22530755400657654
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 3: 0.20205941796302795
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 4: 0.2379966378211975
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 4: 0.1964716911315918
torch.Size([1, 4, 320, 153])
torch.Size([1, 101])
Train Loss, Epoch 5: 0.23316511511802673
to

KeyboardInterrupt: 

In [17]:
model.load_state_dict(torch.load("/Users/fyzeen/FyzeenLocal/GitHub/NeuroTranslate/LocalTransformerTests/_FyzTests/torchsave.pt"))
model.eval()
model.to(device)

FullTransformer(
  (flatten_to_high_dim): Linear(in_features=101, out_features=2020, bias=True)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): EncoderSiT(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c n v  -> b n (v c)')
      (1): Linear(in_features=612, out_features=20, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (transformer): Transformer(
      (norm): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-4): 5 x ModuleList(
          (0): Attention(
            (norm): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.1, inplace=False)
            (to_qkv): Linear(in_features=20, out_features=384, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=128, out_features=20, bias=True)
              (1): Dropout(p=0.1, inplace=False)
          

In [34]:
targets_ = []
preds_ = []
with torch.no_grad():
    for i, data in enumerate(train_loader):
        inputs, targets = data[0].to(device), data[1].to(device).squeeze().unsqueeze(0)
        print(targets.shape)

        predictions = greedy_decode(model=model, source=inputs, input_dim=101, device=device, b=1)

        targets_.append(targets.cpu().numpy())
        preds_.append(predictions)

torch.Size([1, 101])


  return torch._native_multi_head_attention(


In [24]:
preds_

[[tensor([ 0.0000, -0.0012, -0.2333,  0.0283, -0.0626, -0.0278,  0.0190,  0.0650,
           0.0293, -0.1221, -0.0468, -0.0365,  0.0700, -0.1215, -0.0069, -0.0789,
           0.0489, -0.0974,  0.0722, -0.0711,  0.0138, -0.0056,  0.0720,  0.1411,
          -0.0706, -0.0844, -0.0241, -0.0993, -0.0778, -0.1968, -0.0922, -0.0545,
           0.0340, -0.0179, -0.0769,  0.0472,  0.0780,  0.1329,  0.0369,  0.0494,
           0.1386, -0.0062,  0.1717, -0.0839,  0.0495, -0.0064, -0.0580,  0.0181,
          -0.1644, -0.1207,  0.0155, -0.1651, -0.0338, -0.0375, -0.0615,  0.0658,
          -0.0311, -0.0154, -0.0289, -0.0670, -0.0782, -0.0160,  0.1308, -0.0201,
           0.0133, -0.0255, -0.0775,  0.0808,  0.0275,  0.1124, -0.0071, -0.0862,
          -0.0218, -0.0736, -0.0714,  0.0319,  0.0111,  0.1312,  0.1367, -0.0236,
          -0.0871,  0.0510,  0.1205, -0.0584, -0.1066,  0.0437,  0.0313,  0.0409,
           0.0020,  0.0817, -0.1449,  0.1471, -0.0926,  0.1592,  0.0373,  0.0681,
          -0.015

In [51]:
mae_epoch = np.mean(np.abs(np.concatenate(targets_[0]) - np.concatenate(preds_[0][1].cpu().numpy())))
mae_epoch

ValueError: zero-dimensional arrays cannot be concatenated

In [22]:
loss_fn = nn.L1Loss()
loss_fn(pred, targets)

tensor(0.1975, device='mps:0', grad_fn=<MeanBackward0>)

In [43]:
print(targets_[0] - preds_[0][0].cpu().numpy())

[[ 1.00000000e+00  5.47885336e-02  1.11828014e-01 -7.64793307e-02
  -3.72464880e-02  6.47229403e-02  8.78072307e-02 -1.02245107e-01
  -3.53789702e-02  1.31891251e-01 -4.49974351e-02  1.91501789e-02
  -6.17876165e-02  2.79100388e-02 -2.18449114e-03 -2.35218555e-02
  -1.32850170e-01  1.09109387e-01  3.00634652e-03  2.81155631e-02
  -3.93006466e-02  1.02485843e-01 -2.96451151e-02 -3.85278836e-02
  -1.37200058e-02 -4.43901941e-02 -4.68282700e-02  5.52757978e-02
   2.55711637e-02  6.91241026e-02  8.09423774e-02 -1.08139217e-03
   2.63378620e-02  4.06384747e-03  2.13998556e-02 -5.50149493e-02
  -1.24481253e-01 -8.41911137e-02  6.76657632e-03 -8.53681192e-03
  -4.32314202e-02  7.63071477e-02 -9.01077166e-02  7.78756812e-02
  -2.94941068e-02  2.01291982e-02 -4.88891527e-02  6.59624115e-03
   1.12655088e-01  2.78413221e-02 -2.35273652e-02  1.23063624e-02
   8.73373002e-02  5.58957160e-02  7.24873170e-02 -9.17304009e-02
  -6.25871494e-03  8.31624866e-02  4.20214459e-02  9.85423028e-02
   2.25876

In [41]:
print(targets_[0])

[[ 1.         0.053635  -0.12151   -0.04817   -0.099836   0.036915
   0.10685   -0.037254  -0.0060648  0.009801  -0.091813  -0.017395
   0.0082223 -0.093588  -0.0091054 -0.10244   -0.083986   0.011728
   0.075165  -0.042984  -0.025545   0.096871   0.042341   0.10258
  -0.084292  -0.12884   -0.070937  -0.043987  -0.052232  -0.12772
  -0.011292  -0.055536   0.06038   -0.013792  -0.055451  -0.0078369
  -0.046508   0.048743   0.043623   0.04084    0.095379   0.070152
   0.081631  -0.0060417  0.019975   0.013762  -0.10691    0.024739
  -0.051731  -0.092834  -0.0079792 -0.15281    0.053569   0.018373
   0.011028  -0.02589   -0.037309   0.067792   0.013131   0.031553
  -0.055607   0.057753   0.2424     0.023996   0.018646   0.095537
  -0.041546   0.072851  -0.028209   0.036517  -0.090177  -0.050468
  -0.043845   0.026591  -0.11971   -0.0052818  0.025643  -0.04022
   0.028181   0.036962   0.01726   -0.025386   0.095228   0.014827
  -0.054002   0.0757     0.050593  -0.051038   0.0019048  0.0245

In [42]:
print(preds_[0][0])

tensor([ 0.0000, -0.0012, -0.2333,  0.0283, -0.0626, -0.0278,  0.0190,  0.0650,
         0.0293, -0.1221, -0.0468, -0.0365,  0.0700, -0.1215, -0.0069, -0.0789,
         0.0489, -0.0974,  0.0722, -0.0711,  0.0138, -0.0056,  0.0720,  0.1411,
        -0.0706, -0.0844, -0.0241, -0.0993, -0.0778, -0.1968, -0.0922, -0.0545,
         0.0340, -0.0179, -0.0769,  0.0472,  0.0780,  0.1329,  0.0369,  0.0494,
         0.1386, -0.0062,  0.1717, -0.0839,  0.0495, -0.0064, -0.0580,  0.0181,
        -0.1644, -0.1207,  0.0155, -0.1651, -0.0338, -0.0375, -0.0615,  0.0658,
        -0.0311, -0.0154, -0.0289, -0.0670, -0.0782, -0.0160,  0.1308, -0.0201,
         0.0133, -0.0255, -0.0775,  0.0808,  0.0275,  0.1124, -0.0071, -0.0862,
        -0.0218, -0.0736, -0.0714,  0.0319,  0.0111,  0.1312,  0.1367, -0.0236,
        -0.0871,  0.0510,  0.1205, -0.0584, -0.1066,  0.0437,  0.0313,  0.0409,
         0.0020,  0.0817, -0.1449,  0.1471, -0.0926,  0.1592,  0.0373,  0.0681,
        -0.0157, -0.0181,  0.1086, -0.12

In [31]:
np.corrcoef(pred[1].detach().cpu().numpy(), targets[1].detach().cpu().numpy())

array([[1.        , 0.44203997],
       [0.44203997, 1.        ]])

In [34]:
class ProjectionConv(nn.Module):
  def __init__(self):
    super(ProjectionConv, self).__init__()
    self.projection_conv = nn.Conv1d(in_channels=100, out_channels=360, kernel_size=1, groups=10)

  
  def forward(self, input):
    B, C = input.shape
    output = self.projection_conv(input.unsqueeze(-1))
    return output.squeeze()

In [41]:
model = ProjectionConv()



#model(torch.ones(1, 100))

model.projection_conv.weight.shape

torch.Size([360, 10, 1])