In [5]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pandas as pd
from transformers import AdamW, get_scheduler
from datasets import load_metric

from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from saveAndLoad import *

from torch.utils.data import DataLoader, Subset, Dataset
from sklearn.model_selection import train_test_split

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8, bias = None):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.norm(keepdim=True, dim=-1) * (x.size(-1) ** -0.5)
        return self.scale * (x / (norm + self.eps))
    
class MLP(nn.Module):

    def __init__(self, config, use_dropout=True):
        super().__init__()
        self.c_fc    = nn.Linear(config.input_dim, 4 * config.input_dim, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.input_dim, config.input_dim, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.use_dropout = use_dropout

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        if self.use_dropout: x = self.dropout(x)
        return x
    
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_dim = config.input_dim
        self.query = nn.Linear(config.input_dim, config.input_dim)
        self.key = nn.Linear(config.input_dim, config.input_dim)
        self.value = nn.Linear(config.input_dim, config.input_dim)
        self.softmax = nn.Softmax(dim=2)
        self.dense_layer = nn.Linear(config.input_dim, config.input_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.config = config

    def forward(self, x):
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)
        scores = torch.bmm(queries, keys.transpose(-2, -1)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        y = torch.bmm(attention, values)
        return self.dense_layer(y)

class Block(nn.Module):
    def __init__(self, config, norm_fn = nn.LayerNorm):
        super().__init__()
        self.norm1 = config.norm_fn(config.input_dim)  # RMS normalization
        self.norm2 = config.norm_fn(config.input_dim)
        self.attn = Attention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class Classifier(nn.Module):
    def __init__(self, config):
        super(Classifier, self).__init__()

        self.blocks = nn.ModuleList([Block(config, norm_fn = config.norm_fn) for _ in range(config.n_layer)])

        self.input_dim = config.input_dim 
        self.pooling = config.pooling
        assert self.pooling in ['cls', 'mean', 'max'], 'pooling should be either cls, mean, or max'
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.input_dim))  # Learnable CLS token

        # self.emb_transform = nn.Linear(config.input_dim, config.input_dim)  
        # self.emb_transform = MLP(config,use_dropout=False)  

        self.num_labels = config.n_labels # number of labels for classifier
        self.classifier = nn.Linear(config.input_dim, config.n_labels) # FC Layer
        self.loss_func = nn.CrossEntropyLoss() # Change this if it becomes more than binary classification
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        # is_pad = x == float('-inf')
        # pad_rows = is_pad.all(dim=2)
        # x = self.emb_transform(x)

        if self.pooling == 'cls':
            batch_size = x.size(0)
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # Expand CLS token for each sequence in the batch
            x = torch.cat((cls_tokens, x), dim=1)
        
        for block in self.blocks:
            x = block(x)

        if self.pooling == 'cls':
            classifier_input = x[:, 0, :].view(-1, self.input_dim)
        elif self.pooling == 'mean':
            classifier_input = x.mean(dim=1)
        elif self.pooling == 'max':
            classifier_input, _ = x.max(dim=1)

        logits = self.classifier(classifier_input)
        return logits

In [10]:
from custom_dataset import custom_collate, Dataset_MutationList

# LOAD DATA
canonical_mut_embeddings_esm2 = np.load('../aa/canonical_mut_embeddings_esm2.npy')
data_dir = '../labeled_data/'
labeled_data = os.listdir(data_dir)
for ni,i in enumerate(labeled_data):print(ni,i)
data = labeled_data[0]
print('\n',data)
data_df = pd.read_csv(data_dir+data)
data = data_df['idxs'].values
labels = torch.tensor(data_df['int_label'].values,dtype=torch.long)
nlabels = len(data_df['int_label'].unique())
device = 'cuda:1'

# Create dataset
dataset = Dataset_MutationList(data, labels, canonical_mut_embeddings_esm2,device)

# Create DataLoader
# dataloader = DataLoader(dataset, batch_size=100, shuffle=False, collate_fn=custom_collate)

## TEST/TRAIN SPLIT
test_size = .2
random_state = 42
batch_size = 1
indices = list(range(len(dataset)))

train_indices, test_indices = train_test_split(
    indices, 
    test_size=test_size, 
    random_state=random_state
)

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

0 BINARYdata_CANCER_TYPE_3MinMutations_1696MinCancerType.csv
1 data_CANCER_TYPE_DETAILED_3MinMutations_1696MinCancerType.csv
2 BINARYdata_CANCER_TYPE_DETAILED_3MinMutations_1696MinCancerType.csv
3 data_CANCER_TYPE_DETAILED_0MinMutations_1696MinCancerType.csv
4 BINARYdata_CANCER_TYPE_DETAILED_0MinMutations_1696MinCancerType.csv
5 data_CANCER_TYPE_3MinMutations_169MinCancerType.csv
6 data_CANCER_TYPE_0MinMutations_1696MinCancerType.csv
7 BINARYdata_CANCER_TYPE_0MinMutations_169MinCancerType.csv
8 data_CANCER_TYPE_0MinMutations_169MinCancerType.csv
9 BINARYdata_CANCER_TYPE_DETAILED_0MinMutations_169MinCancerType.csv
10 data_CANCER_TYPE_DETAILED_0MinMutations_169MinCancerType.csv
11 BINARYdata_CANCER_TYPE_DETAILED_3MinMutations_169MinCancerType.csv
12 BINARYdata_CANCER_TYPE_3MinMutations_169MinCancerType.csv
13 BINARYdata_CANCER_TYPE_0MinMutations_1696MinCancerType.csv
14 data_CANCER_TYPE_3MinMutations_1696MinCancerType.csv
15 data_CANCER_TYPE_DETAILED_3MinMutations_169MinCancerType.csv

 

In [3]:
# majority classifier
print(len(data_df['int_label'].unique()))
sorted(data_df['int_label'].value_counts(),reverse=True)[0]/len(data_df['int_label'])

17


0.22298445159933342

In [6]:
import torch.optim as optim
from tqdm import tqdm

class Config:
    n_layer: int = 3
    input_dim: int = 640
    dropout: float = 0.0
    bias: bool = False
    n_labels: int = 17
    pooling : str = 'mean'
    norm_fn: nn.Module = nn.LayerNorm

print('n labels:',nlabels)
config = Config()
config.n_labels = nlabels

model = Classifier(config)
model.to(device)

num_epochs = 3
learning_rate = 0.001

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    with tqdm(enumerate(train_loader), total=len(train_loader),desc='TRAINING') as pbar:
        for batch_idx, (data, target) in pbar:
            optimizer.zero_grad()
            output = model(data)
            # assert False
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            pbar.set_postfix({'Epoch':f'{epoch+1}/{num_epochs}, Loss: {loss.item():.4f}'})
            if batch_idx % 20000 == 0:
                print('')

        # Evaluation
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in tqdm(test_loader,desc='TESTING'):
                output = model(data)
                _, predicted = torch.max(output.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%, ({correct} of {total})')

n labels: 17


TRAINING:   0%|          | 20/49445 [00:00<09:59, 82.38it/s, Epoch=1/3, Loss: 1.7225]  




TRAINING:  41%|████      | 20029/49445 [01:32<02:54, 168.98it/s, Epoch=1/3, Loss: 1.3050]  




TRAINING:  81%|████████  | 40020/49445 [04:06<01:16, 122.43it/s, Epoch=1/3, Loss: 1.0588]  




TRAINING: 100%|██████████| 49445/49445 [05:22<00:00, 153.24it/s, Epoch=1/3, Loss: 7.1183]   
TESTING: 100%|██████████| 12362/12362 [00:10<00:00, 1144.27it/s]


Test Accuracy: 22.30%, (2757 of 12362)


TRAINING:   0%|          | 15/49445 [00:00<05:41, 144.66it/s, Epoch=2/3, Loss: 1.8787]




TRAINING:  41%|████      | 20036/49445 [02:23<02:05, 235.26it/s, Epoch=2/3, Loss: 1.5146] 




TRAINING:  81%|████████  | 40015/49445 [04:58<01:22, 114.69it/s, Epoch=2/3, Loss: 3.8572] 




TRAINING: 100%|██████████| 49445/49445 [06:01<00:00, 136.72it/s, Epoch=2/3, Loss: 2.7306] 
TESTING: 100%|██████████| 12362/12362 [00:10<00:00, 1134.48it/s]


Test Accuracy: 38.05%, (4704 of 12362)


TRAINING:   0%|          | 17/49445 [00:00<05:04, 162.46it/s, Epoch=3/3, Loss: 3.9766]




TRAINING:  40%|████      | 20017/49445 [02:14<04:12, 116.61it/s, Epoch=3/3, Loss: 1.6758]   




TRAINING:  81%|████████  | 40023/49445 [04:22<00:43, 218.64it/s, Epoch=3/3, Loss: 2.1613]  




TRAINING: 100%|██████████| 49445/49445 [05:10<00:00, 159.12it/s, Epoch=3/3, Loss: 4.4619] 
TESTING: 100%|██████████| 12362/12362 [00:10<00:00, 1138.38it/s]

Test Accuracy: 40.17%, (4966 of 12362)



