In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pandas as pd
from transformers import AdamW, get_scheduler
from datasets import load_metric

from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from saveAndLoad import *

from torch.utils.data import DataLoader, Subset, Dataset
from sklearn.model_selection import train_test_split
from torch.nn import functional as F
import math

class LayerNorm(nn.Module):
    """Layer normalization whose additive bias can be switched off.

    PyTorch's ``nn.LayerNorm`` historically had no ``bias=False`` option, so this
    module owns the affine parameters itself and hands the arithmetic to
    ``F.layer_norm`` (eps fixed at 1e-5).

    Args:
        ndim: size of the trailing dimension that is normalized.
        bias: if truthy, register a learnable bias initialised to zeros;
            otherwise ``self.bias`` is ``None`` and no bias is added.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalize over exactly the trailing dims covered by the weight vector.
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)

class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature gain.

    Every vector along the last axis is divided by its RMS value
    (``||x||_2 / sqrt(d)``) plus ``eps`` and then scaled by ``self.scale``.
    No mean is subtracted and no bias is learned.

    Args:
        dim: size of the last dimension.
        eps: added to the RMS (outside the square root) to avoid division by zero.
        bias: ignored; accepted so the constructor can be called exactly like
            ``LayerNorm(ndim, bias=...)`` when used as ``config.norm_fn``.
    """

    def __init__(self, dim, eps=1e-8, bias = None):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_sqrt_dim = x.size(-1) ** -0.5
        rms = x.norm(keepdim=True, dim=-1) * inv_sqrt_dim
        normalized = x / (rms + self.eps)
        return self.scale * normalized

class Block(nn.Module):
    """Pre-norm transformer layer: a residual attention sub-layer followed by a
    residual MLP sub-layer, each applied to a normalized copy of the stream.

    ``config.norm_fn`` is called as ``norm_fn(n_embd, bias=config.bias)``, so it
    must be a class/callable with that signature (``LayerNorm`` or ``RMSNorm``).
    """

    def __init__(self, config):
        super().__init__()
        make_norm = config.norm_fn
        width = config.n_embd
        # Creation order (ln_1, attn, ln_2, mlp) fixes parameter order and the
        # sequence of random initialisations.
        self.ln_1 = make_norm(width, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = make_norm(width, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed
    
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: expand to 4*n_embd, GELU, project back.

    Args:
        config: object providing ``n_embd``, ``bias`` and ``dropout``.
        use_dropout: when False, the output dropout is skipped even though the
            ``nn.Dropout`` module is still constructed.
    """

    def __init__(self, config, use_dropout=True):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.use_dropout = use_dropout

    def forward(self, x):
        projected = self.c_proj(self.gelu(self.c_fc(x)))
        return self.dropout(projected) if self.use_dropout else projected

class SelfAttention(nn.Module):
    """Multi-head, non-causal self-attention built on ``F.scaled_dot_product_attention``.

    ``config`` must provide ``n_embd``, ``n_head`` (dividing ``n_embd``), ``bias``
    and ``dropout``. No attention mask is passed, so every position attends to
    every other position, including any padding present in the input.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One fused projection produces query, key and value for all heads.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection applied after the heads are concatenated again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # attn_dropout is kept for structure; attention-weight dropout is done via dropout_p.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

    def forward(self, x):
        batch, seq_len, width = x.size()  # (B, T, C)
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (split_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        # Dropout on attention weights only while training.
        drop_p = self.dropout if self.training else 0
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=drop_p, is_causal=False
        )
        # (B, n_head, T, head_dim) -> (B, T, C): put head outputs side by side again.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class GPT(nn.Module):
    """Stack of transformer Blocks + pooling + linear head, used as a sequence classifier.

    The name and most of the structure come from nanoGPT. This variant has no
    token/position embedding tables, and its attention is non-causal
    (``SelfAttention`` uses ``is_causal=False``). ``forward`` consumes
    already-embedded inputs whose last dimension is ``config.n_embd`` and returns
    one row of logits per sequence.

    Config fields read here: ``n_layer``, ``n_embd``, ``dropout``, ``bias``,
    ``norm_fn``, ``pooling`` ('cls', 'mean' or 'max') and ``vocab_size`` (used as
    the number of output classes). ``Block`` reads further fields.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Pooling strategy over the sequence axis; validated in forward().
        self.pooling = config.pooling

        self.transformer = nn.ModuleDict(dict(
            # NOTE(review): `drop` is constructed but forward() never applies it
            # (nanoGPT applies it to the embedding sum, which does not exist here).
            # Kept so the module structure is unchanged.
            drop = nn.Dropout(config.dropout),
            blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = config.norm_fn(config.n_embd, bias=config.bias),
        ))
        # Classification head: n_embd -> vocab_size (the number of classes). The
        # nanoGPT name is kept so state_dict keys stay stable. There is no token
        # embedding in this model, so nanoGPT's weight tying does not apply.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Init all weights: N(0, 0.02) for Linear/Embedding weights, zero biases.
        self.apply(self._init_weights)
        # Apply the special scaled init to the residual output projections, per the
        # GPT-2 paper: the std shrinks as the number of layers grows.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # Report number of parameters.
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        """Return the total number of parameters in the model.

        This model has no embedding tables, so nothing is excluded from the count:
        every parameter (Blocks, final norm, classification head) is included.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights for Linear/Embedding, zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, targets=None):
        """Encode ``x``, pool over the sequence axis, and return class logits.

        Args:
            x: embedded inputs of shape (batch, seq_len, n_embd). ``SelfAttention``
                unpacks exactly three dimensions.
            targets: unused; kept for signature compatibility.

        Returns:
            Logits of shape (batch, vocab_size).

        Raises:
            ValueError: if ``self.pooling`` is not 'cls', 'mean' or 'max'.
        """
        for block in self.transformer.blocks:
            x = block(x)
        x = self.transformer.ln_f(x)

        # NOTE(review): no padding mask is applied, so with padded batches the
        # 'mean'/'max' pools also see padding positions.
        if self.pooling == 'cls':
            # First position of every sequence. The old `.view(-1, self.n_embd)`
            # raised AttributeError (GPT has no n_embd attribute) and was a no-op
            # reshape anyway.
            classifier_input = x[:, 0, :]
        elif self.pooling == 'mean':
            classifier_input = x.mean(dim=1)
        elif self.pooling == 'max':
            classifier_input, _ = x.max(dim=1)
        else:
            raise ValueError(
                f"unknown pooling {self.pooling!r}; expected 'cls', 'mean' or 'max'"
            )

        logits = self.lm_head(classifier_input)
        return logits

In [None]:
import os

from custom_dataset import *

# LOAD DATA
canonical_mut_embeddings_esm2 = np.load('../aa/canonical_mut_embeddings_esm2.npy')
data_dir = '../labeled_data/'
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort it
# so both the printed index -> filename menu and the [0] choice below are
# reproducible across machines and runs.
labeled_data = sorted(os.listdir(data_dir))
for ni, i in enumerate(labeled_data):
    print(ni, i)
data_file = labeled_data[0]  # change the index to train on a different file
print('\n', data_file)
data_df = pd.read_csv(os.path.join(data_dir, data_file))
data = data_df['idxs'].values
labels = torch.tensor(data_df['int_label'].values, dtype=torch.long)
nlabels = len(data_df['int_label'].unique())
# CrossEntropyLoss requires class ids in [0, nlabels). Fail here with a clear
# message instead of an out-of-bounds error (a device-side assert on CUDA)
# during the first training step.
if labels.min() < 0 or labels.max() >= nlabels:
    raise ValueError(
        f"int_label values must lie in [0, {nlabels - 1}]; got "
        f"[{labels.min().item()}, {labels.max().item()}]"
    )
device = 'cuda:1'

# Create dataset
dataset = Dataset_MutationList(data, labels, canonical_mut_embeddings_esm2, device)

# Create DataLoader
# dataloader = DataLoader(dataset, batch_size=100, shuffle=False, collate_fn=custom_collate)

# TEST/TRAIN SPLIT
test_size = .2
random_state = 42
batch_size = 1
indices = list(range(len(dataset)))

train_indices, test_indices = train_test_split(
    indices,
    test_size=test_size,
    random_state=random_state
)

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

In [None]:
# Majority-class baseline: the accuracy obtained by always predicting the most
# frequent label. value_counts() is sorted by count (descending), so .iloc[0]
# is the size of the largest class.
print(data_df['int_label'].nunique(dropna=False))
data_df['int_label'].value_counts().iloc[0] / len(data_df['int_label'])

In [None]:
import torch.optim as optim
from tqdm import tqdm

class GPTConfig:
    """Default hyper-parameters for GPT and the Block/MLP/SelfAttention layers it builds.

    These are plain class attributes, not a dataclass. Instances start with these
    defaults, and individual fields can be overridden by assigning to the instance
    after construction.
    """
    block_size: int = 30  # NOTE(review): not read by any model code visible in this file
    vocab_size: int = 17  # used as the number of output classes (set to n_labels)
    n_layer: int = 1  # number of transformer Blocks (original note: 12)
    n_head: int = 1  # attention heads; must divide n_embd (original note: 10)
    n_embd: int = 640  # model width; must equal the last dim of the input embeddings
    dropout: float = 0.0  # used by SelfAttention, MLP and GPT.transformer.drop
    # A norm *class* (not an instance), called as norm_fn(n_embd, bias=bias).
    norm_fn: "type[nn.Module]" = LayerNorm
    pooling: str = 'mean'  # one of 'cls', 'mean', 'max'
    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

print('n labels:',nlabels)
config = GPTConfig()
config.vocab_size = nlabels
config.pooling = 'mean'
config.norm_fn = RMSNorm

model = GPT(config)
# Use the same device as the data-loading cell instead of re-hard-coding it.
model.to(device)

num_epochs = 10
learning_rate = 0.001

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    with tqdm(enumerate(train_loader), total=len(train_loader), desc='TRAINING') as pbar:
        for batch_idx, (data, target) in pbar:
            # No-op if custom_collate already placed them on `device`.
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            pbar.set_postfix({'Epoch':f'{epoch+1}/{num_epochs}, Loss: {loss.item():.4f}'})
            if batch_idx % 20000 == 0:
                print('')

    # Evaluation runs after the training progress bar is closed, so the two
    # bars do not interleave.
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc='TESTING'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        print('Test set is empty; skipping accuracy.')
    else:
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%, ({correct} of {total})')