In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import wandb
import numpy as np
import pandas as pd

from models.transformer import Transformer

In [2]:
class NpyDataset(torch.utils.data.Dataset):
    def __init__(self, X_path, Y_path):
        self.x = np.load(X_path, mmap_mode='r')
        self.y = np.load(Y_path, mmap_mode='r')

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        return x.astype(int), y.astype(int)

In [3]:
batch_size=16428
train_dataset = NpyDataset('../data/Xtrain_base.npy', '../data/Ytrain_base.npy')
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_dataset = NpyDataset('../data/Xtest_base.npy', '../data/Ytest_base.npy')
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True)

vocab = pd.read_csv('../data/vocab.csv', header=None, index_col=0)
vocab_size = vocab.shape[0]
cat_label_mapping = pd.read_csv('../data/cat_label_mapping.csv', header=None)
num_classes = cat_label_mapping.shape[0]
context_length = train_dataset.x.shape[1]


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
net = Transformer(num_layers=4, vocab_size=vocab_size, num_tokens=context_length, d_model=16, d_q_k_v=16, num_heads=4, num_classes=num_classes, hidden_dim=16)
net.to(device=device)
optim = torch.optim.AdamW(params = net.parameters())
loss_fn = nn.CrossEntropyLoss()
epochs=100


print(net.parameters)

<bound method Module.parameters of Transformer(
  (embed): Embedding(27038, 16)
  (attn_ops): ModuleList(
    (0-3): 4 x SelfAttention(
      (query): Linear(in_features=16, out_features=64, bias=True)
      (key): Linear(in_features=16, out_features=64, bias=True)
      (value): Linear(in_features=16, out_features=64, bias=True)
      (softmax): Softmax(dim=2)
      (output): Linear(in_features=64, out_features=16, bias=True)
    )
  )
  (first_layernorms): ModuleList(
    (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
  (feed_fwd_layers): ModuleList(
    (0-3): 4 x Linear(in_features=16, out_features=16, bias=True)
  )
  (feed_fwd_activations): ModuleList(
    (0-3): 4 x Sigmoid()
  )
  (second_layernorms): ModuleList(
    (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
  (final_mlp): MLP(
    (feed_fwd): Linear(in_features=16, out_features=16, bias=True)
    (sigmoid): Sigmoid()
    (feed_fwd_2): Linear(in_features=16, out_features=47, bias=T

In [4]:
np.sum([param.numel() for param in net.parameters()])

np.int64(452239)

In [5]:
# wandb.init(project="cell_type_classification",config={})

for e in range(epochs):
    net.eval()
    print("Computing training loss")
    lossTrain = 0.0
    num_batches = 0
    for batch_x,batch_y in tqdm(train_dataloader):
        batch_x = batch_x.to(device=device).to(torch.long)
        batch_y = batch_y.to(device=device)
        pred_batch = net(batch_x)
        loss = loss_fn(pred_batch,batch_y).item()
        lossTrain+=loss
        num_batches+=1
        loss_fn.zero_grad()

    lossTrain/=num_batches

    print("Computing testing loss")
    lossTest = 0.0
    num_batches = 0
    for batch_x,batch_y in tqdm(test_dataloader):
        batch_x = batch_x.to(device=device).to(torch.long)
        batch_y = batch_y.to(device=device)
        pred_batch = net(batch_x)
        loss = loss_fn(pred_batch,batch_y).item()
        lossTest+=loss
        num_batches+=1
        loss_fn.zero_grad()
    
    lossTest/=num_batches

    print({"train_loss": lossTrain, "test_loss": lossTest})
    # wandb.log({"train_loss": lossTrain, "test_loss": lossTest})

    print("Training model")
    net.train()
    for i, (batch_x, batch_y) in tqdm(enumerate(train_dataloader)):
        batch_x = batch_x.to(device=device).to(torch.long)
        batch_y = batch_y.to(device=device)
        pred = net(batch_x)
        loss = loss_fn(pred,batch_y)
        loss.backward()
        optim.step()
        optim.zero_grad()

Computing training loss


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([7000, 2049, 4, 16]) torch.Size([7000, 2049, 4, 16])





RuntimeError: Invalid buffer size: 437.93 GB

In [11]:
from models.self_attention import SelfAttention

In [13]:
attn_block = nn.Sequential()
attn_block.add_module('attn_1', SelfAttention().to(device='mps'))

In [19]:
attn_block.named_modules

<bound method Module.named_modules of Sequential(
  (attn_1): SelfAttention(
    (query): Linear(in_features=512, out_features=768, bias=True)
    (key): Linear(in_features=512, out_features=768, bias=True)
    (value): Linear(in_features=512, out_features=768, bias=True)
    (softmax): Softmax(dim=2)
    (output): Linear(in_features=768, out_features=512, bias=True)
  )
)>

In [30]:
layer = attn_block.modules().__next__()

In [None]:
layer.

Sequential(
  (attn_1): SelfAttention(
    (query): Linear(in_features=512, out_features=768, bias=True)
    (key): Linear(in_features=512, out_features=768, bias=True)
    (value): Linear(in_features=512, out_features=768, bias=True)
    (softmax): Softmax(dim=2)
    (output): Linear(in_features=768, out_features=512, bias=True)
  )
)

In [38]:
num_layers = 4
attn_blocks = nn.ModuleList()
num_tokens = 2048
d_model = 512
d_q_k_v = 16
num_heads = 4

for i in range(num_layers):
    attn_block = nn.Sequential()
    attn_block.add_module(f'attn_{i}', SelfAttention(num_tokens, d_model, d_q_k_v, num_heads).to(device='mps'))
    attn_block.add_module(f'ln1_{i}', nn.LayerNorm(d_model).to(device='mps'))
    attn_block.add_module(f'fwd_{i}', nn.Linear(in_features=d_model, out_features=d_model).to(device='mps'))
    attn_block.add_module(f'fwd_activation_{i}', nn.Sigmoid().to(device='mps'))
    attn_block.add_module(f'ln2_{i}', nn.LayerNorm(d_model).to(device='mps'))
    attn_blocks.add_module(f'attn_block_{i}', attn_block)

In [47]:
attn_blocks['attn_0']

TypeError: 'str' object cannot be interpreted as an integer

In [52]:
attn_blocks.attn_block_0.attn_0

SelfAttention(
  (query): Linear(in_features=512, out_features=64, bias=True)
  (key): Linear(in_features=512, out_features=64, bias=True)
  (value): Linear(in_features=512, out_features=64, bias=True)
  (softmax): Softmax(dim=2)
  (output): Linear(in_features=64, out_features=512, bias=True)
)