In [1]:
import torch
import torch.nn as nn

from tqdm import tqdm
import wandb
import numpy as np
import pandas as pd

from models.transformer import Transformer
from datasets.npy_dataset import NpyDataset

Try balanced loss to account for class imbalance
log accuracy
balanced accuracy?

In [2]:
batch_size=8
train_dataset = NpyDataset('../data/Xtrain_base.npy', '../data/Ytrain_base.npy')
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True, num_workers=8)
val_dataset = NpyDataset('../data/Xval_base.npy', '../data/Yval_base.npy')
val_dataloader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False, num_workers=8)

vocab = pd.read_csv('../data/vocab.csv', header=None, index_col=0)
vocab_size = vocab.shape[0]
cat_label_mapping = pd.read_csv('../data/cat_label_mapping.csv', header=None)
num_classes = cat_label_mapping.shape[0]
context_length = train_dataset.x.shape[1]


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

net = Transformer(num_layers=4,
                  vocab_size=vocab_size,
                  d_model=16,
                  d_q_k_v=16,
                  num_heads=4,
                  num_classes=num_classes,
                  hidden_dim=16)

net.to(device=device)
optim = torch.optim.AdamW(params = net.parameters())
loss_fn = nn.CrossEntropyLoss()
epochs=100

print(net.named_modules)

<bound method Module.named_modules of Transformer(
  (embed): Embedding(27038, 16)
  (attn_ops): ModuleList(
    (0-3): 4 x SelfAttention(
      (query): Linear(in_features=16, out_features=64, bias=True)
      (key): Linear(in_features=16, out_features=64, bias=True)
      (value): Linear(in_features=16, out_features=64, bias=True)
      (softmax): Softmax(dim=-1)
      (output): Linear(in_features=64, out_features=16, bias=True)
    )
  )
  (first_layernorms): ModuleList(
    (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
  (feed_fwd_layers): ModuleList(
    (0-3): 4 x Linear(in_features=16, out_features=16, bias=True)
  )
  (feed_fwd_activations): ModuleList(
    (0-3): 4 x GELU(approximate='none')
  )
  (second_layernorms): ModuleList(
    (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
  (final_mlp): MLP(
    (feed_fwd): Linear(in_features=16, out_features=16, bias=True)
    (sigmoid): GELU(approximate='none')
    (feed_fwd_2): Linear(in_fe

In [3]:
np.sum([param.numel() for param in net.parameters()])

np.int64(452239)

In [4]:
def accuracy(preds, labels):
    pred_classes = preds.argmax(dim=1)
    correct = (pred_classes == labels).sum().item()
    total = labels.size(0)
    return correct, total

In [5]:
wandb.init(project="cell_type_classification",config={})

for e in range(epochs):
    net.train()
    print("Training model")
    for i, (batch_x, batch_y) in tqdm(enumerate(train_dataloader), total=len(train_dataset)):
        batch_x = batch_x.to(device=device).to(torch.long)
        batch_y = batch_y.to(device=device)
        optim.zero_grad()
        pred = net(batch_x)
        loss = loss_fn(pred,batch_y)
        loss.backward()
        optim.step()
        
    net.eval()
    print("Computing validation loss")
    val_loss = 0.0
    num_batches = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for batch_x,batch_y in tqdm(val_dataloader, total=len(val_dataset)):
            batch_x = batch_x.to(device=device).to(torch.long)
            batch_y = batch_y.to(device=device)
            pred = net(batch_x)
            loss = loss_fn(pred,batch_y).item()

            correct, total = accuracy(pred, batch_y)
            val_correct += correct
            val_total += total

            val_loss+=loss
            num_batches+=1
    
    val_loss/=num_batches
    val_acc = val_correct / val_total

    train_loss = 0.0
    num_batches = 0
    train_correct = 0
    train_total = 0
    with torch.no_grad():
        for batch_x,batch_y in tqdm(train_dataloader, total=len(train_dataset)):
            batch_x = batch_x.to(device=device).to(torch.long)
            batch_y = batch_y.to(device=device)
            pred = net(batch_x)
            loss = loss_fn(pred,batch_y).item()

            correct, total = accuracy(pred, batch_y)
            train_correct += correct
            train_total += total
            
            train_loss+=loss
            num_batches+=1
    
    train_loss/=num_batches
    train_acc = train_correct / train_total

    print({"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc})
    wandb.log({"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc})


[34m[1mwandb[0m: Currently logged in as: [33measwaran[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training model


  1%|          | 66/7000 [00:23<41:23,  2.79it/s] 


KeyboardInterrupt: 