# CS 229B Project

A resource that explains time series forecasting with RNNs (including how to construct the X and y arrays):
https://www.geeksforgeeks.org/time-series-forecasting-using-recurrent-neural-networks-rnn-in-tensorflow/

## Upload Dataset & Install packages

The FS Peptide dataset contains 28 trajectories, each with 10000 frames.

We need to install packages for data preprocessing (MDTraj) and for the models we are running (PyEMMA, PyTorch).

In [None]:
# Download the FS Peptide archive from figshare and unpack it into the working directory.
!wget https://ndownloader.figshare.com/articles/1030363/versions/1 -O fs_peptide.zip
# -o overwrites files left by a previous extraction without prompting.
!unzip -o fs_peptide.zip

In [None]:
# Install the modelling, preprocessing, and profiling packages used by the cells below.
# !pip install --pre torch
!pip install torch
# xformers is pulled from the PyTorch cu118 wheel index.
!pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
# !pip install xformers
!pip install mdtraj pyemma
!pip install mdshare
# torchprofile and deepspeed are only needed by the complexity-profiling cells.
!pip install torchprofile
!pip install deepspeed

Load dataset into files (including corresponding pdb file)

# Preprocessing Data
Please run the data-loading cell below before running any of the models that follow.

In [None]:
import time
import warnings
import math
import pandas as pd
import matplotlib.pyplot as plt
import json

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from pathlib import Path
import pickle

import pyemma
from pyemma.util.contexts import settings
import mdtraj as md
import mdshare


class CustomDataset(Dataset):
    """Dataset of next-frame prediction pairs built from MD trajectory files.

    Each item corresponds to one trajectory file. The trajectory is superposed
    onto frame 0 of the first file's trajectory, centered, reduced to the
    protein backbone atoms and flattened to one feature vector per frame.

    Args:
        pdb: Topology (PDB) file used to load every trajectory.
        files: Sequence of ``.xtc`` trajectory paths; ``files[0]`` provides the
            reference frame for superposition.
        chunk_size: Number of frames per chunk (``'chunk'`` mode) or number of
            leading frames kept (``'truncate'`` mode).
        mode: ``'chunk'`` splits the whole trajectory into chunks of
            ``chunk_size`` frames; ``'truncate'`` keeps only the first
            ``chunk_size`` frames as a single chunk.
    """

    def __init__(self, pdb, files, chunk_size = 1000, mode='chunk'):
        self.pdb = pdb
        self.files = files
        # Loader warnings are silenced deliberately.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.topology = md.load(pdb).topology
            self.ref_traj = md.load_xtc(files[0], top=pdb)
        self.chunk_size = chunk_size
        self.mode = mode

    def __len__(self):
        """Return the number of trajectory files."""
        return len(self.files)

    def set_mode(self, mode):
        """Switch between ``'chunk'`` and ``'truncate'`` (validated on access)."""
        self.mode = mode
        print("updated mode to: ", mode)

    def set_chunk_size(self, chunk_size):
        """Change the number of frames per chunk."""
        self.chunk_size = chunk_size
        print("updated chunk_size to: ", chunk_size)

    def __getitem__(self, idx):
        """Return ``(X, Y)`` for trajectory ``idx``.

        Both tensors have shape ``(num_chunks, frames_per_chunk - 1, n_features)``
        where ``n_features = 3 * n_backbone_atoms``; ``Y`` is ``X`` shifted one
        frame forward in time.
        """
        assert self.mode in ['chunk', 'truncate']
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            traj = md.load_xtc(self.files[idx], top=self.pdb)

        # Align to a common reference so coordinates are comparable across files.
        traj.superpose(self.ref_traj, frame=0)
        traj.center_coordinates()

        backbone = traj.atom_slice(self.topology.select("protein and backbone"))

        # Derive the feature width from the selection instead of hard-coding
        # 264, so a different topology or atom selection still reshapes correctly.
        n_features = backbone.n_atoms * 3
        traj_coords = torch.from_numpy(backbone.xyz.reshape(-1, n_features))  # (n_frames, n_features)

        if self.mode == 'chunk':
            traj_coords = traj_coords.reshape(-1, self.chunk_size, n_features)
        else:
            traj_coords = traj_coords[:self.chunk_size, :].unsqueeze(0)

        # Inputs are frames 0..T-2, targets are frames 1..T-1.
        X = traj_coords[:, :-1, :]
        Y = traj_coords[:, 1:, :]

        return X, Y

def load_files(source = '.'):
    """Return the 28 trajectory paths under ``source`` and the topology file name.

    The trajectory paths are absolute (``source`` is resolved); the PDB name is
    returned as a bare relative file name.
    """
    root = Path(source).resolve()
    files = []
    for index in range(1, 29):
        files.append(root / f'trajectory-{index}.xtc')
    pdb = '100-fs-peptide-400K.pdb'
    return files, pdb

def build_loaders(source = '.',
                  train_batch_size = 4,
                  test_batch_size = 8,
                  num_workers = 2,
                  seed = 1, test_frac = 0.2,train_drop = 0):
    """Split the trajectory files into train/test sets and wrap them in DataLoaders.

    Args:
        source: Directory containing the trajectory files.
        train_batch_size: Batch size (in trajectories) of the training loader.
        test_batch_size: Batch size (in trajectories) of the test loader.
        num_workers: Worker processes per loader.
        seed: Seed for NumPy's global RNG, which drives the split.
        test_frac: Fraction of files held out for testing.
        train_drop: Fraction of the training files to discard after the split.

    Returns:
        ``(train_loader, test_loader)``.
    """
    files, pdb = load_files(source)
    np.random.seed(seed)

    # Draw the training subset first; everything not drawn becomes the test set.
    n_train = int(len(files) * (1 - test_frac))
    train_files = np.random.choice(files, size=n_train, replace=False)
    test_files = np.setdiff1d(files, train_files)

    # Optionally thin out the training set (also reshuffles it).
    n_kept = int(len(train_files) * (1 - train_drop))
    train_files = np.random.choice(train_files, size=n_kept, replace=False)

    train_dataset = CustomDataset(pdb, train_files)
    test_dataset = CustomDataset(pdb, test_files)

    train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size,
                             shuffle=False, num_workers=num_workers)

    print(f'Number of training samples: {len(train_dataset)}')
    print(f'Number of testing samples: {len(test_dataset)}')
    return train_loader, test_loader


def save_pkl(filename, save_object):
    """Pickle ``save_object`` to ``filename``.

    The file is opened in a ``with`` block so the handle is closed even if
    pickling raises.
    """
    with open(filename, 'wb') as writer:
        pickle.dump(save_object, writer)



In [None]:
# Width of one flattened backbone frame; every model predicts a frame of the same width.
INPUT_SIZE = 264
OUTPUT_SIZE = INPUT_SIZE
# Default chunk of 1000 frames minus the one-frame shift between inputs and targets.
SEQ_LEN = 1000 - 1

# Define an RNN model
class RNNModel(nn.Module):
    """Multi-layer Elman RNN followed by a per-step linear read-out.

    Args:
        input_size: Feature width of each input frame.
        hidden_size: Width of the recurrent state.
        num_layers: Number of stacked RNN layers.
        output_size: Feature width of each predicted frame.
        nonlinearity: Activation passed to ``nn.RNN``.
        dropout: Dropout between RNN layers, passed to ``nn.RNN``.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size,
                 nonlinearity='tanh', dropout=0.):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers,
                          batch_first=True, nonlinearity=nonlinearity,
                          dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def init_hidden(self, device, batch_size=1, dtype=None):
        """Return a zero state shaped ``(num_layers, batch_size, hidden_size)``.

        This is the layout ``nn.RNN`` expects for ``h_0`` even when
        ``batch_first=True``.
        """
        return torch.zeros(self.num_layers, batch_size, self.hidden_size,
                           device=device, dtype=dtype)

    def forward(self, x, hidden = None,return_hidden = False):
        """Run the RNN over ``x`` of shape ``(B, T, d)``.

        Args:
            x: Input sequence batch.
            hidden: Optional initial state; zeros are used when omitted.
            return_hidden: When true, also return the final recurrent state.

        Returns:
            Predictions of shape ``(B, T, output_size)``, optionally followed by
            the final hidden state.
        """
        B, T, d = x.size()
        if hidden is None:
            hidden = self.init_hidden(x.device, batch_size=B, dtype=x.dtype)
        # Feed the state to the RNN: without it, step-by-step decoding would
        # restart from zeros at every step.
        out, h = self.rnn(x, hidden)
        out = self.fc(out)
        if return_hidden:
            return out, h
        return out

    def forward_autoregressive(self, x0, max_len, hidden = None):
        """Roll the model forward ``max_len`` steps from the single frame ``x0``.

        Each prediction is fed back as the next input and the recurrent state
        is carried between steps.

        Args:
            x0: Starting frame of shape ``(B, 1, d)``.
            max_len: Number of frames to generate.
            hidden: Optional initial recurrent state.

        Returns:
            Generated frames of shape ``(B, max_len, d)``.
        """
        if hidden is None:
            hidden = self.init_hidden(x0.device, batch_size=x0.size(0), dtype=x0.dtype)

        # Slot 0 holds the seed frame; it is dropped from the returned tensor.
        all_out = torch.zeros(x0.size(0), max_len+1, x0.size(2)).to(x0.device)
        all_out[:,0,:] = x0[:,0,:]
        xt = x0
        for t in range(max_len):
            out, hidden = self.forward(xt, hidden = hidden, return_hidden = True)
            all_out[:,t+1,:] = out.squeeze(1)
            xt = out

        return all_out[:, 1:, :]

import xformers
from xformers.components.feedforward import MLP
from xformers.components.attention import NystromAttention

class XformerBlock(nn.Module):
    """Encoder block: self-attention then an MLP, each with a residual add and LayerNorm.

    Normalisation is applied after each residual sum. Attention uses xformers'
    ``NystromAttention`` and the feed-forward part is xformers' ``MLP`` with a
    GELU activation and a 4x hidden multiplier; both run without dropout.
    """

    def __init__(self, num_heads = 8,
                 embed_size = 32,
                 num_landmarks = 64, causal = True):
        super().__init__()
        attention_config = dict(dropout=0.,
                                num_heads=num_heads,
                                num_landmarks=num_landmarks,
                                causal=causal)
        self.attn = NystromAttention(**attention_config)
        self.ff = MLP(dim_model=embed_size, hidden_layer_multiplier=4,
                      activation="gelu", dropout=0.)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        attended = self.norm1(self.attn(k=x, q=x, v=x) + x)
        return self.norm2(self.ff(attended) + attended)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a ``(B, T, emb_dim)`` input.

    Args:
        emb_dim: Embedding width; even and odd widths are both supported.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, emb_dim, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-math.log(10000.0) / emb_dim))
        angles = position * div_term  # (max_len, ceil(emb_dim / 2))

        pe = torch.zeros(max_len, emb_dim)
        pe[:, 0::2] = torch.sin(angles)
        # For odd emb_dim there is one fewer odd column than even columns, so
        # trim the angles to fit; for even emb_dim the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :emb_dim // 2])
        # Stored as a buffer of shape (1, max_len, emb_dim) so it follows the
        # module across devices without being trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

# Implement causal masking
def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` additive causal mask.

    Entries strictly above the diagonal are ``-inf`` (future positions are
    blocked); the diagonal and everything below it are ``0``.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    return blocked.triu(diagonal=1)

class TransformerModel(nn.Module):
    """Causal transformer built from ``XformerBlock`` layers.

    Frames are linearly embedded, given sinusoidal position encodings, passed
    through ``num_layers`` blocks and projected back to ``output_size``.

    Args:
        input_size: Feature width of each input frame.
        embed_size: Model width.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        output_size: Feature width of each predicted frame.
        seq_len: Longest sequence the positional encoding is built for.
    """

    def __init__(self, input_size, embed_size, num_heads, num_layers, output_size,seq_len):
        super().__init__()
        self.embedding = nn.Linear(input_size, embed_size)
        # Honour the seq_len argument rather than a module-level constant, so
        # the model can be built for any sequence length.
        self.pos_encoder = PositionalEncoding(embed_size, max_len=seq_len)
        self.max_len = seq_len
        self.embed_size = embed_size
        transformer_layers = [
            XformerBlock(num_heads, embed_size, num_landmarks = 64, causal = True)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*transformer_layers)
        self.fc = nn.Linear(embed_size, output_size)

    def adjust_pe(self, new_seq_len):
        """Grow the positional encoding so it covers ``new_seq_len`` positions."""
        if self.max_len < new_seq_len:
            # Keep the replacement on the same device as the old table and
            # record the new capacity so later calls compare against it.
            device = self.pos_encoder.pe.device
            self.pos_encoder = PositionalEncoding(self.embed_size, max_len=new_seq_len).to(device)
            self.max_len = new_seq_len

    def forward(self, x):
        """Map ``x`` of shape ``(B, T, input_size)`` to ``(B, T, output_size)``."""
        B, T, d = x.size()
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.fc(x)
        return x

    def forward_autoregressive(self, x0, max_len):
        """Generate ``max_len`` frames from the seed frame ``x0`` of shape ``(B, 1, d)``.

        The whole generated prefix is re-encoded at every step and only the
        last prediction is kept.
        """
        all_out = torch.zeros(x0.size(0), max_len+1, x0.size(2)).to(x0.device) # preallocate output
        all_out[:,0,:] = x0[:,0,:]
        for t in tqdm(range(max_len)):
            xt = all_out[:, :t+1, :]
            next_output = self.forward(xt)[:,-1,:].unsqueeze(1)
            all_out[:,t+1,:] = next_output.squeeze(1)
        # Slot 0 is the seed frame, not a prediction.
        return all_out[:, 1:, :]



# Implement Transformer Native Model
class TransformerModel_NN(nn.Module):
    """Causal transformer built on ``nn.TransformerEncoder``.

    Frames are linearly embedded, given sinusoidal position encodings, passed
    through the encoder under a causal mask and projected to ``output_size``.

    Args:
        input_size: Feature width of each input frame.
        embed_size: Model width.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers.
        output_size: Feature width of each predicted frame.
        seq_len: Sequence length the positional encoding and causal mask are
            initially built for.
    """

    def __init__(self, input_size, embed_size, num_heads, num_layers, output_size,seq_len):
        super().__init__()
        self.embedding = nn.Linear(input_size, embed_size)
        # Honour the seq_len argument rather than a module-level constant.
        self.pos_encoder = PositionalEncoding(embed_size, max_len=seq_len)
        self.max_len = seq_len
        self.embed_size = embed_size
        transformer_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads, batch_first=True,
                                                       dim_feedforward = 4 * embed_size, dropout = 0)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_size, output_size)
        # Build the mask for the expected sequence length up front so the
        # buffer's shape does not change on the first forward pass.
        self.register_buffer("mask", generate_square_subsequent_mask(seq_len))

    def adjust_pe(self, new_seq_len):
        """Grow the positional encoding so it covers ``new_seq_len`` positions."""
        if self.max_len < new_seq_len:
            # Keep the replacement on the same device and record the new capacity.
            device = self.pos_encoder.pe.device
            self.pos_encoder = PositionalEncoding(self.embed_size, max_len=new_seq_len).to(device)
            self.max_len = new_seq_len

    def forward(self, x):
        """Map ``x`` of shape ``(B, T, input_size)`` to ``(B, T, output_size)``."""
        B, T, d = x.size()
        x = self.embedding(x)

        if self.mask.size(0) < T:
            # Assigning to an existing buffer name replaces the buffer.
            self.mask = generate_square_subsequent_mask(T).to(x.device)
        mask = self.mask[:T, :T]

        x = self.pos_encoder(x)
        x = self.transformer(x, mask=mask)
        x = self.fc(x)
        return x

    def forward_autoregressive(self, x0, max_len):
        """Generate ``max_len`` frames from the seed frame ``x0`` of shape ``(B, 1, d)``.

        The whole generated prefix is re-encoded at every step and only the
        last prediction is kept.
        """
        all_out = torch.zeros(x0.size(0), max_len+1, x0.size(2)).to(x0.device) # preallocate output
        all_out[:,0,:] = x0[:,0,:]
        for t in tqdm(range(max_len)):
            xt = all_out[:, :t+1, :]
            next_output = self.forward(xt)[:,-1,:].unsqueeze(1)
            all_out[:,t+1,:] = next_output.squeeze(1)
        # Slot 0 is the seed frame, not a prediction.
        return all_out[:, 1:, :]

from torch.nn.utils.parametrizations import weight_norm
# Apply convolution with causal padding: following https://discuss.pytorch.org/t/causal-convolution/3456/4
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` padded by ``(kernel_size - 1) * dilation`` with the tail trimmed.

    ``nn.Conv1d`` pads both ends of the sequence; dropping the last
    ``padding[0]`` outputs leaves an output in which, for ``stride=1``,
    position ``t`` only sees inputs at positions ``<= t``.
    """

    def __init__(self, input_size, output_size, kernel_size, stride=1, dilation=1):
        left_context = dilation * (kernel_size - 1)
        super().__init__(input_size, output_size, kernel_size,
                         stride=stride, padding=left_context, dilation=dilation)

    def forward(self, x):
        """Convolve ``x`` of shape ``(B, C, L)`` and trim the trailing padded outputs."""
        padded_out = super().forward(x)
        trim = self.padding[0]
        if trim == 0:
            # kernel_size == 1: nothing was padded, nothing to trim.
            return padded_out
        return padded_out[:, :, :-trim]

class ResidTempBlock(nn.Module):
    """Residual temporal block: two weight-normalised causal convolutions plus a skip path.

    Args:
        input_size: Channels of the incoming sequence.
        channel_size: Channels between the two convolutions.
        output_size: Channels produced by the block.
        kernel_size: Kernel width of both convolutions.
        stride: Stride of both convolutions.
        dilation: Dilation of both convolutions.
        dropout: Dropout probability applied after each activation.
    """

    def __init__(self, input_size, channel_size, output_size, kernel_size, stride=1, dilation=1, dropout=0):
        super().__init__()

        # First causal convolution: input_size -> channel_size.
        self.conv1 = weight_norm(CausalConv1d(input_size,
                                              channel_size,
                                              kernel_size,
                                              stride=stride,
                                              dilation=dilation))
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second causal convolution: channel_size -> output_size.
        self.conv2 = weight_norm(CausalConv1d(channel_size,
                                              output_size,
                                              kernel_size,
                                              stride=stride,
                                              dilation=dilation))
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.relu1, self.dropout1,
                                 self.conv2, self.relu2, self.dropout2)

        # 1x1 convolution that matches channel counts on the skip path when needed.
        self.reshape = nn.Conv1d(input_size, output_size, 1) if input_size != output_size else None

        self.relu = nn.ReLU()

        self.init_weights()

    def init_weights(self):
        """Draw all convolution weights from ``N(0, 0.01)``.

        ``conv1`` and ``conv2`` are wrapped by the parametrization-based
        ``weight_norm``, so reading ``conv.weight`` returns a freshly computed
        tensor and an in-place edit of it is discarded. Assigning to
        ``conv.weight`` instead routes the value through the parametrization,
        which stores it in the underlying magnitude/direction parameters.
        """
        with torch.no_grad():
            for conv in (self.conv1, self.conv2):
                conv.weight = torch.empty_like(conv.weight).normal_(0, 0.01)
            if self.reshape is not None:
                self.reshape.weight.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + skip(x))`` for ``x`` of shape ``(B, C, L)``."""
        out = self.net(x)
        residual = x if self.reshape is None else self.reshape(x)
        out = self.relu(out + residual)
        return out


class TCNModel(nn.Module):
    """Temporal convolutional network: a stack of residual blocks with doubling dilation.

    Args:
        input_size: Feature width of each frame; also the width of the output.
        channel_size: List with one hidden width per residual block.
        input_length: Nominal sequence length (stored, not used in the forward pass).
        kernel_size: Kernel width of every causal convolution.
        stride: Stride of every causal convolution.
        dropout: Dropout probability inside each residual block.
    """

    def __init__(self, input_size, channel_size, input_length,
                 kernel_size, stride=1, dropout=0):
        super().__init__()

        self.input_length = input_length
        self.input_size = input_size
        num_layers = len(channel_size)

        layers=[]
        for i in range(num_layers):
            dilation_size = 2**i
            in_channels = input_size if i == 0 else channel_size[i-1]
            # The last block maps back to the input width so predictions have
            # the same shape as the inputs.
            out_channels = input_size if i == (num_layers-1) else channel_size[i]

            # The block's hidden width is this block's own entry; indexing with
            # i-1 would wrap around to the last entry for the first block.
            # dropout is forwarded so the constructor argument takes effect.
            layers.append(ResidTempBlock(in_channels, channel_size[i], out_channels,
                                         kernel_size, stride=stride,
                                         dilation=dilation_size,
                                         dropout=dropout))

        self.tcn = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(B, T, input_size)`` to a tensor of the same layout."""
        # Conv1d expects (batch, channels, length), so swap the last two dims.
        x = torch.transpose(x, 1,2)

        out = self.tcn(x)

        # Swap back so the output lines up with the targets for the loss.
        out = torch.transpose(out, 1, 2)
        return out

    def forward_autoregressive(self, x0, max_len):
        """Generate ``max_len`` frames from the seed frame ``x0`` of shape ``(B, 1, d)``.

        The whole generated prefix is re-processed at every step and only the
        last prediction is kept.
        """
        all_out = torch.zeros(x0.size(0), max_len+1, x0.size(2)).to(x0.device) # preallocate output
        all_out[:,0,:] = x0[:,0,:]
        for t in tqdm(range(max_len)):
            xt = all_out[:, :t+1, :]
            next_output = self.forward(xt)[:,-1,:].unsqueeze(1)
            all_out[:,t+1,:] = next_output.squeeze(1)
        # Slot 0 is the seed frame, not a prediction.
        return all_out[:, 1:, :]

In [None]:
from tqdm import tqdm
import random

def merge_dict(main_dict, new_dict, value_fn = None):
    """
    Fold new_dict into main_dict in place and return main_dict.

    Every value of new_dict is treated as a list (a non-list value becomes a
    one-element list). For a key already in main_dict the new items are
    concatenated onto the existing value; otherwise the key is added.
    Args:
        main_dict: dict that accumulates the results (modified in place)
        new_dict: dict whose entries are merged in
        value_fn: optional function applied to every item of every value of
            new_dict before it is merged
    """
    transform = (lambda item: item) if value_fn is None else value_fn
    for key, incoming in new_dict.items():
        items = incoming if isinstance(incoming, list) else [incoming]
        converted = [transform(item) for item in items]
        main_dict[key] = main_dict[key] + converted if key in main_dict else converted
    return main_dict

def train(model, optimizer, train_loader, test_loader, criterion, device, num_epochs = 10, reg_coef=0.05):
    """Train for ``num_epochs`` epochs, evaluate once at the end and save the predictions.

    Args:
        model: Model to train; must also provide ``forward_autoregressive``.
        optimizer: Optimizer over the model's parameters.
        train_loader: Loader of training batches.
        test_loader: Loader of test batches.
        criterion: Loss function.
        device: Device the batches are moved to.
        num_epochs: Number of training epochs.
        reg_coef: Weight of the regularisation term used by ``train_loop``.

    Returns:
        Dict of lists: the training metrics have one entry per epoch; the test
        metrics (and the test predictions) have a single entry from the
        evaluation after the last epoch. With ``num_epochs <= 0`` nothing is
        run or saved and the lists are empty.
    """
    outputs = {'train_loss': [],
               'train_loss_over_time': [],
               'test_loss': [],
               'test_loss_over_time': [],
               'test_loss_autoregressive': [],
               'test_loss_autoregressive_over_time': []}
    if num_epochs <= 0:
        # Nothing to train or evaluate; without this guard the code below
        # would reference names that are only bound inside the loop.
        return outputs

    for epoch in range(num_epochs):
        #### Training
        train_out = train_loop(model, optimizer, train_loader, criterion, device,reg_coef = reg_coef)
        outputs = merge_dict(outputs, train_out)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_out["train_loss"]}')

    #### Eval (once, after the final epoch; `epoch` is the last epoch index here)
    test_out = test_loop(model, test_loader, criterion, device)
    outputs = merge_dict(outputs, test_out)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Test Loss: {test_out["test_loss"]}')
    print(f'Epoch [{epoch + 1}/{num_epochs}], Test Loss Autoregressive: {test_out["test_loss_autoregressive"]}')

    torch.save(test_out["test_preds"], f'test_preds_{epoch}_no_reg.pt')
    torch.save(test_out["test_autoregressive_preds"], f'test_autoregressive_preds_{epoch}_no_reg.pt')
    # torch.save(model, 'model_no_reg.pth')

    return outputs

def random_seed(seed=42, rank=0):
    """Seed PyTorch, NumPy and Python's ``random`` with ``seed + rank``."""
    effective_seed = seed + rank
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(effective_seed)

def train_loop(model, optimizer, train_loader, criterion, device, reg_coef=0.05):
    """Run one training epoch.

    Args:
        model: Model being trained.
        optimizer: Optimizer over the model's parameters.
        train_loader: Loader yielding ``(X, Y)`` batches of shape
            ``(B, num_chunks, T, d)``.
        criterion: Loss function applied to ``(predictions, targets)``.
        device: Device the batches are moved to.
        reg_coef: Weight of the centroid-spread term. The mean distance of the
            predicted atoms from their centroid is *subtracted* from the loss,
            so a larger spread lowers the loss. ``0`` disables the term.

    Returns:
        Dict with the mean loss per batch (including the regularisation term)
        and the mean per-time-step MSE, an array of length ``T``.
    """
    model.train()  # Set the model to training mode
    train_loss = 0
    # Sized from the first batch instead of a global constant, so any chunk
    # length works.
    train_loss_over_time = None
    for X_train, Y_train in tqdm(train_loader):
        B, num_chunks, T, d = X_train.shape
        # Every chunk is treated as an independent sequence.
        X_train = X_train.reshape(B*num_chunks, T, d).to(device)
        Y_train = Y_train.reshape(B*num_chunks, T, d).to(device)

        # Forward pass
        outputs = model(X_train)
        loss = criterion(outputs, Y_train)

        step_mse = torch.mean((outputs - Y_train) ** 2, dim=(0,2)).detach().cpu().numpy()
        if train_loss_over_time is None:
            train_loss_over_time = np.zeros(step_mse.shape)
        train_loss_over_time += step_mse

        if reg_coef > 0:
            n_atoms = d // 3
            # reshape (not view) so non-contiguous model outputs are accepted.
            coords = outputs.reshape(B * num_chunks, T, n_atoms, 3)

            # Centroid of each predicted configuration: [B*num_chunks, T, 1, 3]
            centroid = coords.mean(dim=2, keepdim=True)

            # Euclidean distance of every atom from its centroid: [B*num_chunks, T, n_atoms]
            distances = (coords - centroid).norm(dim=-1)

            # Mean distance per configuration: [B*num_chunks, T]
            avg_dist_to_centroid = distances.mean(dim=-1)

            reg_loss = reg_coef * avg_dist_to_centroid.mean()
            loss = loss - reg_loss

        train_loss += loss.item()

        # Backward and optimize
        optimizer.zero_grad()  # Clear gradients
        loss.backward()  # Compute gradient
        optimizer.step()  # Update weights
    return {'train_loss': train_loss / len(train_loader),
            'train_loss_over_time': train_loss_over_time / len(train_loader)}

def test_loop(model, test_loader, criterion, device, seq_len = SEQ_LEN):
    """Evaluate ``model`` with teacher forcing and with an autoregressive rollout.

    For every batch the model is run once on the true inputs and once
    autoregressively from the first frame only; both outputs are scored
    against the same targets.

    Args:
        model: Model to evaluate; must provide ``forward_autoregressive``.
        test_loader: Loader yielding ``(X, Y)`` batches of shape
            ``(B, num_chunks, T, d)``.
        criterion: Loss function applied to ``(predictions, targets)``.
        device: Device the batches are moved to.
        seq_len: Length of the per-time-step loss accumulators.

    Returns:
        Dict with the concatenated predictions of both modes and, for each
        mode, the mean loss per batch and the mean per-time-step MSE.
    """
    model.eval()
    teacher_forced_preds, rollout_preds = [], []
    loss_sum = 0
    rollout_loss_sum = 0
    loss_by_step = np.zeros((seq_len,))
    rollout_loss_by_step = np.zeros((seq_len,))

    def per_step_mse(pred, target):
        # MSE averaged over batch and feature dims, leaving one value per time step.
        return torch.mean((pred - target) ** 2, dim=(0,2)).detach().cpu().numpy()

    for X_test, Y_test in test_loader:
        B, num_chunks, T, d = X_test.shape
        # Every chunk is treated as an independent sequence.
        X_test = X_test.reshape(B*num_chunks, T, d).to(device)
        Y_test = Y_test.reshape(B*num_chunks, T, d).to(device)

        # Teacher-forced pass
        with torch.no_grad():
            outputs = model(X_test)
        loss_sum += criterion(outputs, Y_test).item()

        torch.cuda.empty_cache()

        # Autoregressive pass, seeded with the first frame only
        first_frame = X_test[:,0,:].unsqueeze(1)
        with torch.no_grad():
            outputs_autoregressive = model.forward_autoregressive(first_frame, max_len = X_test.size(1))
        rollout_loss_sum += criterion(outputs_autoregressive, Y_test).item()

        teacher_forced_preds.append(outputs.detach().cpu().numpy())
        rollout_preds.append(outputs_autoregressive.detach().cpu().numpy())

        loss_by_step += per_step_mse(outputs, Y_test)
        rollout_loss_by_step += per_step_mse(outputs_autoregressive, Y_test)

    return {'test_preds': np.concatenate(teacher_forced_preds, axis=0),
            'test_autoregressive_preds': np.concatenate(rollout_preds, axis=0),
            'test_loss': loss_sum / len(test_loader),
            'test_loss_over_time': loss_by_step / len(test_loader),
            'test_loss_autoregressive': rollout_loss_sum / len(test_loader),
            'test_loss_autoregressive_over_time': rollout_loss_by_step / len(test_loader)}



In [None]:
# Architectures that model_generator can build.
VALID_MODEL_TYPES = ["RNN", "Xformer", "Transformer", "TCN"]
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data parameters
train_batch_size, test_batch_size = 4, 8
num_workers = 2
seed = 7

# Training parameters
lr = 1e-4
epochs = 10



def model_generator(model_type,INPUT_SIZE,OUTPUT_SIZE,SEQ_LEN,params):
    """Build a model of the requested type.

    Args:
        model_type: One of ``VALID_MODEL_TYPES``.
        INPUT_SIZE: Feature width of each input frame (used by the defaults).
        OUTPUT_SIZE: Feature width of each predicted frame (used by the defaults).
        SEQ_LEN: Sequence length (used by the defaults).
        params: Keyword arguments for the model constructor, or ``None`` to use
            the built-in defaults for ``model_type``.

    Returns:
        The instantiated model.

    Raises:
        AssertionError: If ``model_type`` is not in ``VALID_MODEL_TYPES``.
    """
    # Validate before any lookup so an unknown type always gets this message.
    assert model_type in VALID_MODEL_TYPES, f'model_type must be one of {VALID_MODEL_TYPES}'

    # Depth of the default TCN, derived from the sequence length for a kernel
    # of width 8 and dilation base 2.
    tcn_kernel = 8
    tcn_depth = int(np.ceil(np.log2(((SEQ_LEN-1)*(2-1))/(2*(tcn_kernel-1)) + 1))) - 2

    # Default constructor arguments, keyed by the names in VALID_MODEL_TYPES
    # (the xformer entry must be 'Xformer', not 'Xformers', to be found).
    default_params = {
        'RNN': {'input_size': INPUT_SIZE,
                'hidden_size': 32,
                'output_size': OUTPUT_SIZE,
                'num_layers': 1},
        'Xformer': {'input_size': INPUT_SIZE,
                    'output_size': OUTPUT_SIZE,
                    'embed_size': 32,
                    'num_layers': 2,
                    'num_heads': 8,
                    'seq_len': SEQ_LEN},
        'Transformer': {'input_size': INPUT_SIZE,
                        'output_size': OUTPUT_SIZE,
                        'embed_size': 96,
                        'num_layers': 6,
                        'num_heads': 16,
                        'seq_len': SEQ_LEN},
        'TCN': {'input_size': INPUT_SIZE,
                'channel_size': [6] * tcn_depth,
                'input_length': SEQ_LEN,
                'kernel_size': tcn_kernel},
    }
    if params is not None:
        model_params = {**params}
    else:
        model_params = default_params[model_type]

    constructors = {
        'RNN': RNNModel,
        'Xformer': TransformerModel,
        'Transformer': TransformerModel_NN,
        'TCN': TCNModel,
    }
    if model_type not in constructors:
        # Only reachable if VALID_MODEL_TYPES lists a type with no constructor
        # or assertions are disabled.
        raise ValueError(f'Invalid model_type {model_type}')
    return constructors[model_type](**model_params)


# Tuned learning rate

In [None]:
# Tuned learning rate for every (model type, model size) pair; each row lists
# the rates for the small, medium and large variants in that order.
learning_rate_dict = {
    (model_name, size_name): rate
    for model_name, rates in (
        ('RNN', (0.0001, 0.0001, 0.0001)),
        ('TCN', (0.01, 0.01, 0.001)),
        ('Xformer', (0.01, 0.001, 0.001)),
        ('Transformer', (0.001, 0.001, 0.0001)),
    )
    for size_name, rate in zip(('small', 'medium', 'large'), rates)
}

# Run specific model and save outputs:

In [None]:
# Seed every RNG, then split the trajectories into train and test loaders.
random_seed(seed=seed)
train_loader, test_loader = build_loaders(
    train_batch_size=train_batch_size,
    test_batch_size=test_batch_size,
    num_workers=num_workers,
    seed=seed,
)

# Architecture to train; params=None selects the built-in default configuration.
model_type = 'Transformer'
model = model_generator(model_type, INPUT_SIZE, OUTPUT_SIZE, SEQ_LEN, params=None)

In [None]:
# Loss, device placement and optimizer
criterion = nn.MSELoss()
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# reg_coef=0 turns the centroid-spread regularisation off for this run.
outputs = train(model, optimizer, train_loader, test_loader, criterion, device,
                num_epochs=epochs, reg_coef=0)

In [None]:
# Re-seed and rebuild the loaders so the inspection cells below see the same split.
random_seed(seed=seed)
train_loader, test_loader = build_loaders(
    train_batch_size=train_batch_size,
    test_batch_size=test_batch_size,
    num_workers=num_workers,
    seed=seed,
)

In [None]:
# Find the single test frame with the lowest teacher-forced MSE and keep both
# the true frame and the prediction as (n_atoms, 3) arrays.
model.eval()
min_loss = np.inf
sample_test = sample_output = None
for X_test, Y_test in test_loader:

    B, num_chunks, T, d = X_test.shape
    X_test = X_test.reshape(B*num_chunks, T, d).to(device)
    Y_test = Y_test.reshape(B*num_chunks, T, d).to(device)

    # Named `preds` so the training-history dict `outputs` is not overwritten.
    with torch.no_grad():
        preds = model(X_test)

    # Per-frame MSE, shape (sequence, time step).
    loss = nn.functional.mse_loss(Y_test, preds, reduction='none').mean(dim=-1)
    # Best sequence for each time step, then the best time step overall.
    min_val, min_idx_cols = torch.min(loss, dim=0)
    overall_min_val, min_idx_row = torch.min(min_val, dim=0)
    if overall_min_val < min_loss:
        # min_idx = (time step, sequence index)
        min_idx = (min_idx_row.item(), min_idx_cols[min_idx_row].item())
        sample_test = Y_test[min_idx[1], min_idx[0], :].detach().cpu().numpy().reshape(-1, 3)
        sample_output = preds[min_idx[1], min_idx[0], :].detach().cpu().numpy().reshape(-1, 3)
        min_loss = overall_min_val
        print(overall_min_val)

if sample_test is None:
    raise ValueError('no test frame was selected; is test_loader empty?')

In [None]:
# Use the first frame of trajectory 1, reduced to backbone atoms, as a template
# whose coordinates are overwritten before each save.
topology = md.load('100-fs-peptide-400K.pdb').topology
traj0 = md.load_xtc('trajectory-1.xtc', top='100-fs-peptide-400K.pdb')
backbone = traj0.atom_slice(topology.select("protein and backbone"))
bb_slice = backbone[0]

for frame_coords, pdb_name in ((sample_test, 'TEST_BACKBONE.pdb'),
                               (sample_output, 'OUTPUT_BACKBONE.pdb')):
    bb_slice.xyz = frame_coords
    bb_slice.save(pdb_name)

In [None]:
# Find the single training frame with the lowest teacher-forced MSE and keep
# both the true frame and the prediction as (n_atoms, 3) arrays.
model.eval()

min_loss = np.inf
sample_train = sample_output = None

for X_train, Y_train in train_loader:

    B, num_chunks, T, d = X_train.shape
    X_train = X_train.reshape(B*num_chunks, T, d).to(device)
    Y_train = Y_train.reshape(B*num_chunks, T, d).to(device)

    # Named `preds` so the training-history dict `outputs` is not overwritten.
    with torch.no_grad():
        preds = model(X_train)

    # Per-frame MSE, shape (sequence, time step).
    loss = nn.functional.mse_loss(Y_train, preds, reduction='none').mean(dim=-1)
    # Best sequence for each time step, then the best time step overall.
    min_val, min_idx_cols = torch.min(loss, dim=0)
    overall_min_val, min_idx_row = torch.min(min_val, dim=0)
    if overall_min_val < min_loss:
        # min_idx = (time step, sequence index)
        min_idx = (min_idx_row.item(), min_idx_cols[min_idx_row].item())
        sample_train = Y_train[min_idx[1], min_idx[0], :].detach().cpu().numpy().reshape(-1, 3)
        sample_output = preds[min_idx[1], min_idx[0], :].detach().cpu().numpy().reshape(-1, 3)
        min_loss = overall_min_val
        print(overall_min_val)

if sample_train is None:
    raise ValueError('no training frame was selected; is train_loader empty?')

In [None]:
# Shape of the last training batch's targets after flattening to (B*num_chunks, T, d).
Y_train.shape

In [None]:
# Shape of the selected training frame after it was reshaped to rows of three values.
sample_train.shape

In [None]:
# Use the first frame of trajectory 1, reduced to backbone atoms, as a template
# whose coordinates are overwritten before each save.
topology = md.load('100-fs-peptide-400K.pdb').topology
traj0 = md.load_xtc('trajectory-1.xtc', top='100-fs-peptide-400K.pdb')
backbone = traj0.atom_slice(topology.select("protein and backbone"))
bb_slice = backbone[0]

for frame_coords, pdb_name in ((sample_train, 'TRAIN_BACKBONE.pdb'),
                               (sample_output, 'TRAIN_OUTPUT_BACKBONE.pdb')):
    bb_slice.xyz = frame_coords
    bb_slice.save(pdb_name)

In [None]:
# Summarise, plot and pickle the training history returned by train().
output_dir = Path(f'outputs_{model_type}')
output_dir.mkdir(exist_ok=True)

# The per-time-step curves are stored as lists of arrays; stack them to 2-D.
stack_keys = ['train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']
for key in stack_keys:
    if not isinstance(outputs[key], np.ndarray):
        outputs[key] = np.stack(outputs[key])

all_train_loss = outputs['train_loss']
all_test_loss = outputs['test_loss']
all_test_loss_autoregressive = outputs['test_loss_autoregressive']
all_train_loss_over_time = outputs['train_loss_over_time']
all_test_loss_over_time = outputs['test_loss_over_time']
all_test_loss_autoregressive_over_time = outputs['test_loss_autoregressive_over_time']

print(f'Final Train Loss: {all_train_loss[-1]}')
print(f'Final Test Loss: {all_test_loss[-1]}')
print(f'Final Test Loss Autoregressive: {all_test_loss_autoregressive[-1]}')

# Plot the train and test loss progression per epoch

import matplotlib.pyplot as plt

# train() evaluates only after its final epoch, so there are fewer test entries
# than training epochs; the test entries belong to the last epochs.
n_train_epochs = len(all_train_loss)
n_test_evals = len(all_test_loss_autoregressive)
first_test_epoch = n_train_epochs - n_test_evals

plt.plot(range(n_train_epochs), all_train_loss, label='train')
# plt.plot(all_test_loss, label='test')
plt.plot(range(first_test_epoch, n_train_epochs), all_test_loss_autoregressive,
         marker='o', label='test autoregressive')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.savefig(Path(output_dir) / 'loss.png')
plt.show()

for epoch in range(n_train_epochs):
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(5, 4)
    plt.plot(all_train_loss_over_time[epoch,:], label=f'train {epoch}')
    # Only epochs that were actually evaluated have a test curve; indexing the
    # test array by the raw epoch number would run past its single row.
    test_row = epoch - first_test_epoch
    if test_row >= 0:
        plt.plot(all_test_loss_autoregressive_over_time[test_row,:], label=f'test AR {epoch}')
    plt.ylabel('MSE')
    plt.legend()

    plt.savefig(Path(output_dir) / f'loss_epoch{epoch}_ar_only.png')
    plt.show()

save_pkl(Path(output_dir) / 'outputs.pkl', outputs)

# Establish model comparison criteria using MACs (multiply-accumulate operations)

In [None]:
# Define function to assess computational complexity of model
from deepspeed.profiling.flops_profiler import get_model_profile

def get_comp(model,batchsize=1, sequence_length = 999, input_size = 264, x=None):
    """Profile ``model`` with DeepSpeed's flops profiler.

    Args:
        model: Model to profile.
        batchsize: Batch size of the generated input (ignored when ``x`` is given).
        sequence_length: Sequence length of the generated input (ignored when
            ``x`` is given).
        input_size: Feature width of the generated input (ignored when ``x`` is
            given).
        x: Optional example input; its shape then defines the profiled shape.

    Returns:
        ``(flops, macs, params)`` as returned by ``get_model_profile``.
    """
    if x is None:
        # Create the input on the model's device so a GPU model can consume it.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device('cpu')
        x = torch.rand(batchsize, sequence_length, input_size, device=target_device)
    # Take the shape from x so input_shape and args always describe the same
    # input; for a generated x this is (batchsize, sequence_length, input_size).
    # NOTE(review): the profiler appears to build its own input from
    # input_shape when one is given — confirm against the DeepSpeed docs.
    profiled_shape = tuple(x.shape)
    # model.eval()
    flops, macs, params = get_model_profile(model=model, # model
                                    input_shape=profiled_shape,  # shape of the profiled input
                                    args=(x,),  # Passing the input as part of args
                                    print_profile=False,  # Set to True if you want to print the profile
                                    detailed=False,  # Set to True for a detailed profile
                                    module_depth=-1,  # Adjust as needed
                                    top_modules=1,  # Adjust as needed
                                    warm_up=10,  # Number of warm-ups
                                    as_string=False,  # Set to True if you want human-readable strings
                                    output_file=None,  # Set a file path to save the profile
                                    ignore_modules=None  # List any modules to ignore
                                    )
    return flops, macs, params

# Define function that assesses the computational complexity of the model
from torchprofile import profile_macs
def get_macs(model,batchsize=1, sequence_length = 999, input_size = 264, x=None):
    """Count the multiply-accumulate operations of one forward pass with torchprofile.

    Args:
        model: Model to profile.
        batchsize: Batch size of the generated input (ignored when ``x`` is given).
        sequence_length: Sequence length of the generated input (ignored when
            ``x`` is given).
        input_size: Feature width of the generated input (ignored when ``x`` is
            given).
        x: Optional example input to profile with instead of a random one.

    Returns:
        The MAC count reported by ``profile_macs``.
    """
    if x is None:
        # Create the input on the model's device so a GPU model can consume it.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device('cpu')
        x = torch.rand(batchsize, sequence_length, input_size, device=target_device)
    # model.eval()
    macs = profile_macs(model, x)
    return macs

In [None]:
# Define function to fetch the correct parameters based on the chosen model size
def get_model_parameters(model_type, size, input_size, output_size, seq_len):
    """
    Get parameters for a specified model type and size.

    :param model_type: Type of the model ('RNN', 'Xformer', 'Transformer', 'TCN').
    :param size: Size of the model ('small', 'medium', 'large').
    :param input_size: Size of the input.
    :param output_size: Size of the output.
    :param seq_len: Sequence length.
    :return: Dictionary of constructor keyword arguments for the model.
    :raises ValueError: If model_type is not one of the supported names.
    :raises KeyError: If size is not one of the supported names.
    """

    if model_type == 'RNN':
        hidden_sizes = {'small': 256, 'medium': 128*5, 'large': 256*10}
        num_layers = {'small': 1, 'medium': 2, 'large': 6}
        return {
            'input_size': input_size,
            'hidden_size': hidden_sizes[size],
            'output_size': output_size,
            'num_layers': num_layers[size]
        }

    elif model_type in ('Xformer', 'Transformer'):
        embed_sizes = {'small': 32, 'medium': 64, 'large': 96}
        num_layers = {'small': 2, 'medium': 4, 'large': 6}
        num_heads = {'small': 4, 'medium': 8, 'large': 16}
        return {
            'input_size': input_size,
            'output_size': output_size,
            'embed_size': embed_sizes[size],
            'num_layers': num_layers[size],
            'num_heads': num_heads[size],
            'seq_len': seq_len
        }

    elif model_type == 'TCN':
        kernel_size = 8

        def depth(length):
            # Number of residual blocks for a sequence of this length, given
            # dilation base 2 and the kernel width above; at least one block.
            span = (length - 1) * (2 - 1) / (2 * (kernel_size - 1)) + 1
            return max(1, int(np.ceil(np.log2(span))) - 2)

        # The small variant is sized for a tenth of the sequence length.
        channel_sizes = {'small': [8] * depth(seq_len / 10),
                         'medium': [16] * depth(seq_len),
                         'large': [48] * depth(seq_len)}
        return {
            'input_size': input_size,
            'channel_size': channel_sizes[size],
            'input_length': seq_len,
            'kernel_size': kernel_size
        }

    else:
        # List the names exactly as they are accepted (matching is case-sensitive).
        raise ValueError("Invalid model type. Choose from 'RNN', 'Xformer', 'Transformer', 'TCN'.")


In [None]:
# Profile the computational cost (FLOPs, MACs, parameter count) of every model type at every size
def evaluate_models(input_size, output_size, seq_len, batchsize=64, ntest=100):
    """Build each model type/size combination and tabulate its computational cost.

    :param input_size: Input feature size forwarded to the model builders and to ``get_comp``.
    :param output_size: Output feature size forwarded to the model builders.
    :param seq_len: Sequence length forwarded to the model builders and to ``get_comp``.
    :param batchsize: Accepted for interface compatibility; not used by this function.
    :param ntest: Accepted for interface compatibility; not used by this function.
    :return: DataFrame with one row per (model type, size) and the columns
        ``model_type``, ``size``, ``flops``, ``macs`` and ``parameters``.
    """
    rows = []
    for arch in ('RNN', 'Xformer', 'Transformer', 'TCN'):
        for scale in ('small', 'medium', 'large'):
            # Hyper-parameters for this architecture at this scale
            hyper = get_model_parameters(arch, scale, input_size, output_size, seq_len)

            # The model is moved to the GPU before it is profiled
            net = model_generator(arch, input_size, output_size, seq_len, hyper).cuda()

            # get_comp's result is unpacked as (FLOPs, MACs, parameter count)
            flops, macs, n_params = get_comp(net, input_size=input_size, sequence_length=seq_len)

            rows.append(dict(model_type=arch, size=scale, flops=flops, macs=macs, parameters=n_params))

    return pd.DataFrame(rows)


In [None]:
# Shell setup for the NVIDIA libraries: three environment exports followed by
# an ldconfig refresh of /usr/lib64-nvidia.
# NOTE(review): IPython runs each `!` line in its own subshell, so these
# `export`s presumably do not persist into the kernel or into later cells --
# confirm; `%env NAME=value` is the IPython way to set a persistent variable.
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

In [None]:
# Profile every model type/size combination, then order the rows by the
# 'size' label in descending (string) order.
comp_complex = evaluate_models(
    input_size=INPUT_SIZE,
    output_size=OUTPUT_SIZE,
    seq_len=SEQ_LEN,
).sort_values(by='size', ascending=False)

In [None]:
# Separate bar charts for FLOPs, MACs, and parameter counts
criteria = ['flops', 'macs', 'parameters']
for criterion in criteria:
    # Pivot the DataFrame to get 'model_type' as index and 'size' as columns
    pivot_df = comp_complex.pivot(index='model_type', columns='size', values=criterion)

    # Grouped bars: one group per model type, one bar per size
    ax = pivot_df.plot(kind='bar', figsize=(5, 5))

    # Label the y-axis with the metric that is actually plotted, so the FLOPs
    # and parameter figures are not mislabelled as MACs
    ax.set_ylabel(criterion.upper(), fontsize=14)
    ax.set_title(f'{criterion.upper()} for Different Models and Sizes')
    ax.set_xlabel('Model Type', fontsize=14)
    plt.xticks(rotation=0)  # Show the x-axis labels horizontally

    ax.tick_params(axis='y', labelsize=13)  # Font size for y-axis

    # Reverse the order of the three legend entries
    handles, labels = ax.get_legend_handles_labels()
    new_order = [2, 1, 0]
    ordered_handles = [handles[idx] for idx in new_order]
    ordered_labels = [labels[idx] for idx in new_order]

    ax.legend(ordered_handles, ordered_labels, fontsize=10, title='Size')
    plt.tight_layout()  # Adjust the layout to fit all elements
    plt.show()

# Model Comparison: test throughput - DONE

In [None]:
# Measure the autoregressive inference throughput (sequences per second) of a model
def speed_test(model, ntest=100, batchsize=1, sequence_length = SEQ_LEN, input_size = INPUT_SIZE, x=None):
    """Time ``model.forward_autoregressive`` on the GPU and return sequences per second.

    The model's train/eval mode is left exactly as the caller set it. CUDA is
    required: the random input is created with ``.cuda()`` and the timing relies
    on ``torch.cuda.synchronize()``.

    :param model: Model that is callable on ``x`` and provides
        ``forward_autoregressive(first_step, max_len=...)``.
    :param ntest: Number of timed autoregressive rollouts.
    :param batchsize: Batch size of the random input; replaced by ``x.shape[0]`` when ``x`` is given.
    :param sequence_length: Length of the random input sequence (ignored when ``x`` is given).
    :param input_size: Feature size of the random input (ignored when ``x`` is given).
    :param x: Optional input tensor indexed as (batch, step, feature); a random one is generated when omitted.
    :return: ``batchsize * ntest / elapsed_seconds``.
    """
    if x is None:
        x = torch.rand(batchsize, sequence_length, input_size).cuda()
    else:
        batchsize = x.shape[0]

    # Every rollout starts from the first step of x and runs for the full length of x
    first_step = x[:, 0, :].unsqueeze(1)
    rollout_len = x.size(1)

    # Neither warmup nor the timed runs need gradients
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            model(x)

        # CUDA kernels are launched asynchronously: wait for the warmup work to
        # finish so it is not billed to the timed section
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(ntest):
            model.forward_autoregressive(first_step, max_len=rollout_len)
        # Wait for the timed kernels to finish before reading the clock
        torch.cuda.synchronize()
        elapse = time.perf_counter() - start

    return batchsize * ntest / elapse

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import pickle
import os

def fetch_performance_metrics_test_thruput(seed = seed, train_batch_size = train_batch_size, test_batch_size = test_batch_size,
                              num_workers = num_workers,device = device,learning_rate_dict = learning_rate_dict, num_epochs = epochs,
                              input_size=INPUT_SIZE, output_size=OUTPUT_SIZE, seq_len=SEQ_LEN):
    """Collect final autoregressive test loss and inference throughput for every model type and size.

    For each (model type, size) pair the results are read from
    ``/content/drive/My Drive/{model_type}_{size}_test_thruput.pkl`` when that
    file exists. Otherwise the model is built, trained, benchmarked with
    ``speed_test``, and the results are pickled to that path.

    :param seed: Seed passed to ``random_seed`` and ``build_loaders``.
    :param train_batch_size: Training batch size passed to ``build_loaders``.
    :param test_batch_size: Test batch size passed to ``build_loaders``.
    :param num_workers: Worker count passed to ``build_loaders``.
    :param device: Device the model is moved to before training.
    :param learning_rate_dict: Mapping from ``(model_type, size)`` to the AdamW learning rate.
    :param num_epochs: Number of training epochs passed to ``train``.
    :param input_size: Input feature size of the models.
    :param output_size: Output feature size of the models.
    :param seq_len: Sequence length the models are built and benchmarked with.
    :return: List of dicts with the keys ``model_type``, ``size``,
        ``test_loss_autoregressive`` (last recorded value) and
        ``throughput_time`` (the value returned by ``speed_test``).
    """
    performance_metrics = []

    # Nested results: all_results[model_type][size] -> results dict
    all_results = {}

    model_types = ['RNN', 'Xformer', 'Transformer', 'TCN']
    sizes = ['small', 'medium', 'large']

    # Entries of the training output that are kept in the cached results
    kept_keys = ['train_loss', 'test_loss', 'test_loss_autoregressive',
                 'train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']
    stack_keys = ['train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']

    for model_type in model_types:
        all_results[model_type] = {}
        for size in sizes:
            print('Current model is: ',model_type,'and size is: ',size)

            file_name = f'{model_type}_{size}_test_thruput.pkl'
            file_path = os.path.join('/content/drive/My Drive', file_name)

            if os.path.exists(file_path):
                # Reuse the cached results of an earlier run
                with open(file_path, 'rb') as f:
                    results = pickle.load(f)
                print(f'{file_path} found')
            else:
                # Get model parameters
                params = get_model_parameters(model_type, size, input_size, output_size, seq_len)

                # Create model instance
                model = model_generator(model_type,input_size, output_size, seq_len, params).cuda()

                # Seed, then build the data loaders
                random_seed(seed=seed)
                train_loader, test_loader = build_loaders(train_batch_size=train_batch_size,
                                                          test_batch_size=test_batch_size,
                                                          num_workers=num_workers,
                                                          seed=seed)
                # Loss and optimizer
                criterion = nn.MSELoss()
                optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_dict[(model_type,size)])
                model.to(device)

                # Train for the number of epochs the caller asked for
                outputs = train(model, optimizer, train_loader, test_loader, criterion, device, num_epochs = num_epochs)

                # Store the per-step loss histories as arrays
                for key in stack_keys:
                    if not isinstance(outputs[key], np.ndarray):
                        outputs[key] = np.stack(outputs[key])

                # Benchmark with the same input shape the model was built for
                throughput_time = speed_test(model, sequence_length=seq_len, input_size=input_size)

                # Organize the outputs
                results = {key: outputs[key] for key in kept_keys}
                results['throughput_time'] = throughput_time

                # Cache only freshly computed results; a loaded file is left untouched
                with open(file_path, 'wb') as f:
                    pickle.dump(results, f)

            all_results[model_type][size] = results

            # Store the summary row for this model/size
            performance_metrics.append({
                'model_type': model_type,
                'size': size,
                'test_loss_autoregressive': results['test_loss_autoregressive'][-1],
                'throughput_time': results['throughput_time']
            })

    return performance_metrics

def plot_performance_metrics_test_thruput(performance_metrics,markers,colors):
    """Plot final autoregressive test loss against throughput, one line per model type.

    The legend is reordered with fixed indices (0, 3, 2, 1), so at least four
    model types must be present in ``performance_metrics``.

    :param performance_metrics: Records with the keys ``model_type``, ``size``,
        ``throughput_time`` and ``test_loss_autoregressive``.
    :param markers: Mapping from size label to a matplotlib marker; sizes
        missing from the mapping use ``'o'``.
    :param colors: Mapping from model type to a matplotlib color; model types
        missing from the mapping use ``'black'``.
    """
    # Convert to DataFrame for easier plotting
    df = pd.DataFrame(performance_metrics)

    # Sort by size label so every model's line visits its points in the same order
    df = df.sort_values('size')

    # Create a figure and a set of subplots
    fig, ax = plt.subplots(figsize=(10, 6))

    # Iterate over each model type
    for model_type in df['model_type'].unique():
        # sort_values returns a new frame, so the filtered slice is never
        # modified in place (an in-place sort on a slice triggers pandas'
        # SettingWithCopyWarning)
        model_data = df[df['model_type'] == model_type].sort_values('size')

        # Use the color for the model type
        color = colors.get(model_type, 'black')  # Default color if model type not in dict

        # Plot the line for this model type
        ax.plot(model_data['throughput_time'], model_data['test_loss_autoregressive'],
                label=model_type, linestyle='-', color=color)

        # Overlay one marker per size; these carry no label, so they add no legend entries
        for size in model_data['size'].unique():
            size_data = model_data[model_data['size'] == size]
            marker = markers.get(size, 'o')  # Default marker if size not in dict

            ax.scatter(size_data['throughput_time'], size_data['test_loss_autoregressive'],
                      marker=marker, color=color, s=100)

    # Set font size of the x and y axis tick labels
    ax.tick_params(axis='x', labelsize=15)
    ax.tick_params(axis='y', labelsize=15)
    ax.set_xlabel('Throughput Speed (Num. Seq./s)', fontsize=20)
    ax.set_ylabel('Test Loss', fontsize=20)

    # Reorder the legend entries
    handles, labels = ax.get_legend_handles_labels()
    new_order = [0, 3, 2, 1]
    ordered_handles = [handles[idx] for idx in new_order]
    ordered_labels = [labels[idx] for idx in new_order]
    ax.legend(ordered_handles, ordered_labels, fontsize=20)

    # Show the plot
    plt.show()

# Model Comparison: train data amount - DONE

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import pickle
import os
# Examine the test autoregressive loss with varying amount of train data available
def fetch_performance_metrics_train_data_amount(seed = seed, train_batch_size = train_batch_size, test_batch_size = test_batch_size,
                              num_workers = num_workers,device = device,learning_rate_dict = learning_rate_dict, num_epochs = epochs,
                              input_size=INPUT_SIZE, output_size=OUTPUT_SIZE, seq_len=SEQ_LEN):
    """Collect the final autoregressive test loss of each 'small' model for several ``train_drop`` values.

    For each (model type, train_drop) pair the results are read from
    ``/content/drive/My Drive/{model_type}_{percent}_train_amt.pkl`` when that
    file exists. Otherwise the model is built and trained, and the results are
    pickled to that path. Each ``train_drop`` value is forwarded to
    ``build_loaders`` (presumably the fraction of training data to drop --
    confirm against ``build_loaders``).

    :param seed: Seed passed to ``random_seed`` and ``build_loaders``.
    :param train_batch_size: Training batch size passed to ``build_loaders``.
    :param test_batch_size: Test batch size passed to ``build_loaders``.
    :param num_workers: Worker count passed to ``build_loaders``.
    :param device: Device the model is moved to before training.
    :param learning_rate_dict: Mapping from ``(model_type, size)`` to the AdamW learning rate.
    :param num_epochs: Number of training epochs passed to ``train``.
    :param input_size: Input feature size of the models.
    :param output_size: Output feature size of the models.
    :param seq_len: Sequence length the models are built with.
    :return: List of dicts with the keys ``model_type``, ``train_drop_percent``
        and ``test_loss_autoregressive`` (last recorded value).
    """
    performance_metrics = []

    # Nested results: all_results[model_type][train_drop] -> results dict
    all_results = {}

    model_types = ['RNN', 'Xformer', 'Transformer', 'TCN']
    train_drop = [0.9,0.7,0.5,0.3]
    size = 'small'

    # Entries of the training output that are kept in the cached results
    kept_keys = ['train_loss', 'test_loss', 'test_loss_autoregressive',
                 'train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']
    stack_keys = ['train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']

    for model_type in model_types:
        all_results[model_type] = {}
        for percent in train_drop:
            print('Current model is: ',model_type,'and train drop amount is: ',percent)

            file_name = f'{model_type}_{percent}_train_amt.pkl'
            file_path = os.path.join('/content/drive/My Drive', file_name)

            if os.path.exists(file_path):
                # Reuse the cached results of an earlier run
                with open(file_path, 'rb') as f:
                    results = pickle.load(f)
                print(f'{file_path} found')
            else:
                # Get model parameters
                params = get_model_parameters(model_type, size, input_size, output_size, seq_len)

                # Create model instance
                model = model_generator(model_type,input_size, output_size, seq_len, params).cuda()

                # Seed, then build the data loaders with the requested train_drop
                random_seed(seed=seed)
                train_loader, test_loader = build_loaders(train_batch_size=train_batch_size,
                                                          test_batch_size=test_batch_size,
                                                          num_workers=num_workers,
                                                          seed=seed,train_drop = percent)
                # Loss and optimizer
                criterion = nn.MSELoss()
                optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_dict[(model_type,size)])
                model.to(device)

                # Train for the number of epochs the caller asked for
                outputs = train(model, optimizer, train_loader, test_loader, criterion, device, num_epochs = num_epochs)

                # Store the per-step loss histories as arrays
                for key in stack_keys:
                    if not isinstance(outputs[key], np.ndarray):
                        outputs[key] = np.stack(outputs[key])

                # Organize the outputs
                results = {key: outputs[key] for key in kept_keys}

                # Cache only freshly computed results; a loaded file is left untouched
                with open(file_path, 'wb') as f:
                    pickle.dump(results, f)

            all_results[model_type][percent] = results

            # Store the summary row for this model/train_drop
            performance_metrics.append({
                'model_type': model_type,
                'train_drop_percent': percent,
                'test_loss_autoregressive': results['test_loss_autoregressive'][-1]
            })

    return performance_metrics

def plot_performance_metrics_train_data_amt(performance_metrics,markers,colors):
    """Plot final autoregressive test loss against the percentage of training data kept.

    The x value is ``100 - train_drop_percent * 100``. The legend is reordered
    with fixed indices (0, 3, 2, 1), so four model types are expected.

    :param performance_metrics: Records with the keys ``model_type``,
        ``train_drop_percent`` and ``test_loss_autoregressive``.
    :param markers: Mapping from model type to a matplotlib marker.
    :param colors: Mapping from model type to a matplotlib color.
    """
    frame = pd.DataFrame(performance_metrics)
    frame['train_data_remaining'] = 100 - frame['train_drop_percent']*100

    # Ascending x order so each line is drawn left to right
    frame = frame.sort_values('train_data_remaining')

    fig, ax = plt.subplots(figsize=(10, 6))

    # One line per model type, in order of first appearance in the sorted frame
    for name, rows in frame.groupby('model_type', sort=False):
        ax.plot(rows['train_data_remaining'], rows['test_loss_autoregressive'],
                label=name, marker=markers[name], linestyle='-', markersize=8, color=colors[name])

    # Tick and axis label styling
    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, labelsize=15)
    ax.set_xlabel('Train Data Amount (%)', fontsize=20)
    ax.set_ylabel('Test Loss', fontsize=20)

    # Legend entries in a fixed, hand-picked order
    handles, labels = ax.get_legend_handles_labels()
    legend_order = (0, 3, 2, 1)
    ax.legend([handles[i] for i in legend_order],
              [labels[i] for i in legend_order],
              fontsize=20)

    plt.show()

# Model Comparison: learning rate tuning - DONE

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import pickle
import os
# Examine the test autoregressive loss of each model for a range of learning rates
def fetch_performance_metrics_lr(seed = seed, train_batch_size = train_batch_size, test_batch_size = test_batch_size,
                              num_workers = num_workers,device = device, num_epochs = epochs,
                              input_size=INPUT_SIZE, output_size=OUTPUT_SIZE, seq_len=SEQ_LEN):
    """Collect the final autoregressive test loss of each 'small' model for several learning rates.

    For each (model type, learning rate) pair the results are read from
    ``/content/drive/My Drive/{size}_{model_type}_{lr}.pkl`` when that file
    exists. Otherwise the model is built and trained, its ``state_dict`` is
    saved to ``{model_type}_{lr}.pth`` in the working directory, and the
    results are pickled to the cache path.

    :param seed: Seed passed to ``random_seed`` and ``build_loaders``.
    :param train_batch_size: Training batch size passed to ``build_loaders``.
    :param test_batch_size: Test batch size passed to ``build_loaders``.
    :param num_workers: Worker count passed to ``build_loaders``.
    :param device: Device the model is moved to before training.
    :param num_epochs: Number of training epochs passed to ``train``.
    :param input_size: Input feature size of the models.
    :param output_size: Output feature size of the models.
    :param seq_len: Sequence length the models are built with.
    :return: List of dicts with the keys ``model_type``, ``learning_rate`` and
        ``test_loss_autoregressive`` (last recorded value).
    """
    performance_metrics = []

    # Nested results: all_results[model_type][lr] -> results dict
    all_results = {}

    model_types = ['RNN', 'Xformer', 'Transformer', 'TCN']
    lrs = [1e-2,1e-3,1e-4,1e-5,1e-6]
    size = 'small'

    # Entries of the training output that are kept in the cached results
    kept_keys = ['train_loss', 'test_loss', 'test_loss_autoregressive',
                 'train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']
    stack_keys = ['train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']

    for model_type in model_types:
        all_results[model_type] = {}
        for lr in lrs:
            print('Current model is: ',model_type,'and learning rate is: ',lr)

            file_name = f'{size}_{model_type}_{lr}.pkl'
            file_path = os.path.join('/content/drive/My Drive', file_name)

            if os.path.exists(file_path):
                # Reuse the cached results of an earlier run
                with open(file_path, 'rb') as f:
                    results = pickle.load(f)
                print(f'{file_path} found')
            else:
                # Get model parameters
                params = get_model_parameters(model_type, size, input_size, output_size, seq_len)

                # Create model instance
                model = model_generator(model_type,input_size, output_size, seq_len, params).cuda()

                # Seed, then build the data loaders
                random_seed(seed=seed)
                train_loader, test_loader = build_loaders(train_batch_size=train_batch_size,
                                                          test_batch_size=test_batch_size,
                                                          num_workers=num_workers,
                                                          seed=seed)
                # Loss and optimizer
                criterion = nn.MSELoss()
                optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
                model.to(device)

                # Train for the number of epochs the caller asked for
                outputs = train(model, optimizer, train_loader, test_loader, criterion, device, num_epochs = num_epochs)

                # Keep the trained weights for this learning rate
                torch.save(model.state_dict(), f'{model_type}_{lr}.pth')

                # Store the per-step loss histories as arrays
                for key in stack_keys:
                    if not isinstance(outputs[key], np.ndarray):
                        outputs[key] = np.stack(outputs[key])

                # Organize the outputs
                results = {key: outputs[key] for key in kept_keys}

                # Cache only freshly computed results; a loaded file is left untouched
                with open(file_path, 'wb') as f:
                    pickle.dump(results, f)

            all_results[model_type][lr] = results

            # Store the summary row for this model/learning rate
            performance_metrics.append({
                'model_type': model_type,
                'learning_rate': lr,
                'test_loss_autoregressive': results['test_loss_autoregressive'][-1]
            })

    return performance_metrics

def plot_performance_metrics_lr(performance_metrics, markers,colors):
    """Plot final autoregressive test loss against learning rate (log x-axis), one line per model type.

    The legend is reordered with fixed indices (2, 3, 1, 0), so four model
    types are expected.

    :param performance_metrics: Records with the keys ``model_type``,
        ``learning_rate`` and ``test_loss_autoregressive``.
    :param markers: Mapping from model type to a matplotlib marker.
    :param colors: Mapping from model type to a matplotlib color.
    """
    # Ascending learning rate so each line is drawn left to right
    table = pd.DataFrame(performance_metrics).sort_values('learning_rate')

    fig, ax = plt.subplots(figsize=(10, 6))

    # One line per model type, in order of first appearance in the sorted table
    for name, rows in table.groupby('model_type', sort=False):
        ax.plot(rows['learning_rate'], rows['test_loss_autoregressive'],
                label=name, marker=markers[name], linestyle='-', markersize=8, color=colors[name])

    # Tick and axis label styling; the x-axis is logarithmic
    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, labelsize=15)
    ax.set_xlabel('Learning Rate', fontsize=20)
    ax.set_xscale('log')
    ax.set_ylabel('Test Loss', fontsize=20)

    # Legend entries in a fixed, hand-picked order
    handles, labels = ax.get_legend_handles_labels()
    legend_order = (2, 3, 1, 0)
    ax.legend([handles[i] for i in legend_order],
              [labels[i] for i in legend_order],
              fontsize=20)

    plt.show()

# Model Comparison: test on different seq len

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import pickle
import os
# Examine the test autoregressive loss when a trained model is evaluated on longer sequences
def fetch_performance_metrics_test_different_seq_lens(seed = seed, train_batch_size = train_batch_size, test_batch_size = test_batch_size,
                              num_workers = num_workers,device = device,learning_rate_dict = learning_rate_dict, num_epochs = epochs,
                              input_size=INPUT_SIZE, output_size=OUTPUT_SIZE, seq_len=SEQ_LEN):
    """Train each 'small' model once, then evaluate it on test chunks of 3000, 5000, 7000 and 9000 frames.

    For each model type the results are read from
    ``/content/drive/My Drive/{model_type}_test_different_seq_lens.pkl`` when
    that file exists. Otherwise the model is built, trained and re-evaluated
    with ``test_loop`` for every new chunk size, and the accumulated results
    dict is pickled to that path. The cache file holds a dict keyed by model
    type; only the entry of the model being processed is read back.

    :param seed: Seed passed to ``random_seed`` and ``build_loaders``.
    :param train_batch_size: Training batch size passed to ``build_loaders``.
    :param test_batch_size: Test batch size passed to ``build_loaders``.
    :param num_workers: Worker count passed to ``build_loaders``.
    :param device: Device the model is moved to before training.
    :param learning_rate_dict: Mapping from ``(model_type, size)`` to the AdamW learning rate.
    :param num_epochs: Number of training epochs passed to ``train``.
    :param input_size: Input feature size of the models.
    :param output_size: Output feature size of the models.
    :param seq_len: Sequence length the models are built with.
    :return: List of dicts with the keys ``model_type``, ``seq_len`` and
        ``test_loss_autoregressive``: one baseline row (``seq_len`` 1000) plus
        one row per new chunk size for every model type.
    """
    performance_metrics = []

    # all_results[model_type]['train'] -> training results
    # all_results[model_type][new_seq_len] -> evaluation results for that chunk size
    all_results = {}

    model_types = ['RNN', 'Xformer', 'Transformer', 'TCN']
    new_seq_lens = [3000, 5000,7000,9000]
    size = 'small'

    # Entries of the training output that are kept in the cached results
    kept_keys = ['train_loss', 'test_loss', 'test_loss_autoregressive',
                 'train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']
    stack_keys = ['train_loss_over_time', 'test_loss_over_time', 'test_loss_autoregressive_over_time']

    for model_type in model_types:
        print('Current model is: ', model_type)

        file_name = f'{model_type}_test_different_seq_lens.pkl'
        file_path = os.path.join('/content/drive/My Drive', file_name)

        if os.path.exists(file_path):
            with open(file_path, 'rb') as f:
                cached = pickle.load(f)
            # Take only this model's entry, so results gathered earlier in this
            # run for other models are not replaced by what the file contains
            all_results[model_type] = cached[model_type]
            print(f'{file_path} found')
        else:
            all_results[model_type] = {}

            # Get model parameters
            params = get_model_parameters(model_type, size, input_size, output_size, seq_len)

            # Create model instance
            model = model_generator(model_type,input_size, output_size, seq_len, params).cuda()

            # Seed, then build the data loaders
            random_seed(seed=seed)
            train_loader, test_loader = build_loaders(train_batch_size=train_batch_size,
                                                        test_batch_size=test_batch_size,
                                                        num_workers=num_workers,
                                                        seed=seed)
            # Loss and optimizer
            criterion = nn.MSELoss()
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_dict[(model_type,size)])
            model.to(device)

            outputs = train(model, optimizer, train_loader, test_loader, criterion, device, num_epochs = num_epochs)

            # Store the per-step loss histories as arrays
            for key in stack_keys:
                if not isinstance(outputs[key], np.ndarray):
                    outputs[key] = np.stack(outputs[key])

            all_results[model_type]['train'] = {key: outputs[key] for key in kept_keys}

            # Re-evaluate the trained model on progressively longer test chunks
            for new_seq_len in new_seq_lens:
                test_loader.dataset.set_chunk_size(new_seq_len)
                test_loader.dataset.set_mode("truncate")
                if model_type in ['Transformer', 'Xformer']:
                    # Adjust the positional encoding to the new length, then
                    # move it back to the GPU
                    model.adjust_pe(new_seq_len)
                    model.pos_encoder = model.pos_encoder.cuda()
                test_outputs = test_loop(model, test_loader, criterion, device, seq_len=new_seq_len - 1)
                all_results[model_type][new_seq_len] = {
                    'test_preds': test_outputs['test_preds'],
                    'test_loss': test_outputs['test_loss'],
                    'test_loss_autoregressive': test_outputs['test_loss_autoregressive'],
                }
                print(len(test_outputs['test_preds']))

            # Cache only freshly computed results; a loaded file is left untouched
            with open(file_path, 'wb') as f:
                pickle.dump(all_results, f)

        # Baseline row taken from the training run.
        # NOTE(review): the baseline length is hard-coded to 1000 rather than
        # derived from seq_len -- confirm it matches the training chunk size.
        performance_metrics.append({
            'model_type': model_type,
            'seq_len': 1000,
            'test_loss_autoregressive': all_results[model_type]['train']['test_loss_autoregressive'][-1],
        })

        # One row per longer chunk size
        for new_seq_len in new_seq_lens:
            print(all_results[model_type][new_seq_len]['test_loss_autoregressive'])
            performance_metrics.append({
                'model_type': model_type,
                'seq_len': new_seq_len,
                'test_loss_autoregressive':all_results[model_type][new_seq_len]['test_loss_autoregressive']})

    return performance_metrics

def plot_performance_metrics_new_seq_lens(performance_metrics,markers,colors):
    """Plot autoregressive test loss against sequence length, one line per model type.

    The legend is reordered with fixed indices (0, 3, 2, 1), so four model
    types are expected.

    :param performance_metrics: Records with the keys ``model_type``,
        ``seq_len`` and ``test_loss_autoregressive``.
    :param markers: Mapping from model type to a matplotlib marker.
    :param colors: Mapping from model type to a matplotlib color.
    """
    # Ascending sequence length so each line is drawn left to right
    table = pd.DataFrame(performance_metrics).sort_values('seq_len')

    fig, ax = plt.subplots(figsize=(10, 6))

    # One line per model type, in order of first appearance in the sorted table
    for name, rows in table.groupby('model_type', sort=False):
        ax.plot(rows['seq_len'], rows['test_loss_autoregressive'],
                label=name, marker=markers[name], linestyle='-', markersize=8, color=colors[name])

    # Tick and axis label styling
    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, labelsize=15)
    ax.set_xlabel('Sequence Length', fontsize=20)
    ax.set_ylabel('Test Loss', fontsize=20)

    # Legend entries in a fixed, hand-picked order
    handles, labels = ax.get_legend_handles_labels()
    legend_order = (0, 3, 2, 1)
    ax.legend([handles[i] for i in legend_order],
              [labels[i] for i in legend_order],
              fontsize=20)

    plt.show()