In [1]:
import numpy as np
import pickle
import torch
import torch.nn.functional as F

import os
import glob

from tqdm import tqdm

from tokenizers import Tokenizer
from tokenizers import decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from model.LM import SimpleBigramModel, AttentionLM
from src.dataloaders import build_loaders
from hparams import Hparams

In [2]:
# Experiment configuration.
# Seed the torch and NumPy RNGs so weight initialisation and any torch/NumPy-based
# sampling or shuffling are reproducible across "Restart & Run All".
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
hparams = Hparams()

### Tokenization experiments

In [3]:
# build_loaders returns four values; only the tokenizer and the training loader are
# used below (the other two are presumably evaluation loaders — TODO confirm).
# They are bound to explicit names instead of `_`: assigning `_` in IPython shadows
# the last-output variable and stops IPython from updating `_`, `__` and `___`.
tokenizer, train_loader, unused_loader_a, unused_loader_b = build_loaders(hparams)

In [4]:
# Full-attention language model whose output size matches the tokenizer's vocabulary.
model = AttentionLM(
    hparams,
    vocab_size=tokenizer.get_vocab_size(),
    att_func_type="full",
)
model.to("cuda")  # nn.Module.to moves parameters in place (and returns self)
model.compile()   # nn.Module.compile compiles forward in place

In [5]:
# Number of trainable parameters (left as the last expression so it is displayed).
sum(
    param.numel()
    for param in filter(lambda p: p.requires_grad, model.parameters())
)

2127800

In [6]:
# AdamW with PyTorch's default hyper-parameters over all model parameters.
optim = torch.optim.AdamW(params=model.parameters())
# Mean cross-entropy over the flattened (B*T, C) logits used in the training loop.
loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

In [7]:
# Length of the rolling window of batch losses averaged for the progress bar.
loss_buffer = 100

model.train()
for epoch in range(hparams.epochs):
    # Ring buffer holding the most recent `loss_buffer` batch losses as floats.
    windowed_loss = np.zeros(loss_buffer, dtype=np.float32)
    last_loss = None

    # Context manager guarantees the bar is closed even if a batch raises.
    with tqdm(total=len(train_loader)) as bar:
        for idx, (x, y) in enumerate(train_loader):
            optim.zero_grad()

            x = x.to('cuda')
            y = y.to('cuda')

            logits = model(x)
            B, T, C = logits.shape
            # reshape (unlike view) also works on non-contiguous tensors.
            logits = logits.reshape(B * T, C)
            y = y.reshape(B * T)

            loss = loss_fn(logits, y)

            loss.backward()
            optim.step()

            # .item() copies the scalar to the host as a Python float; storing the
            # grad-tracking CUDA tensor itself would make NumPy try to convert it.
            last_loss = loss.item()
            windowed_loss[idx % loss_buffer] = last_loss

            # Average only the slots filled so far; the zero-initialised slots
            # would otherwise bias the displayed mean low for the first batches.
            n_filled = min(idx + 1, loss_buffer)
            bar.set_description(f"Loss: {windowed_loss[:n_filled].mean():.5f}")
            bar.update()

    if last_loss is None:
        print(f"epoch {epoch}: train_loader yielded no batches")
    else:
        print(f"epoch {epoch}: last batch loss {last_loss:.5f}")
    



  0%|          | 0/1492918 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x128 and 512x64)

In [None]:
# Attach a WordPiece decoder (library defaults spelled out: "##" continuation
# prefix, cleanup enabled) so tokenizer.decode produces readable text.
# NOTE(review): the notebook imports BPE/BpeTrainer; confirm the tokenizer from
# build_loaders really emits "##"-prefixed continuation tokens.
tokenizer.decoder = decoders.WordPiece(prefix="##", cleanup=True)

In [None]:
def inference(input: str, model, tokenizer, out_len: int, determenistic=False, device=None):
    """Generate continuations of a text prompt and decode them to strings.

    Args:
        input: prompt text, encoded via ``tokenizer.encode(input).ids``.
        model: torch module exposing ``generate_batch(tokens, out_len, deterministic=...)``.
        tokenizer: object with ``encode(text).ids`` and ``decode(list_of_ints)``.
        out_len: forwarded to ``model.generate_batch`` (presumably the number of
            tokens to generate — TODO confirm).
        determenistic: forwarded as ``deterministic=``; the misspelled name is kept
            so existing keyword callers keep working.
        device: device for the prompt tensor. Defaults to the device of the model's
            first parameter, or "cuda" if the model has no parameters.

    Returns:
        list[str]: one decoded string per element of ``generate_batch``'s output.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'

    # Restore the caller's train/eval mode afterwards so a later training cell is
    # not silently run with the model still in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            prompt = torch.tensor(tokenizer.encode(input).ids, dtype=torch.long, device=device)
            out = model.generate_batch(prompt, out_len, deterministic=determenistic)
            # tolist() turns a tensor row into plain ints in one transfer (list()
            # would yield 0-d tensors); non-tensor rows are passed through list().
            return [
                tokenizer.decode(t.tolist() if torch.is_tensor(t) else list(t))
                for t in out
            ]
    finally:
        model.train(was_training)

In [None]:
# Sample continuations of "Hello" with out_len=100 (determenistic left at its default of False).
inference(input="Hello", model=model, tokenizer=tokenizer, out_len=100)

['##ueuueeeueueueeoueuueueueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee']

## Self-Attention Experiments

### Notes:
- Query comes from other sequence (or the sequence itself in self attention) and is a value that converges to represent 