In [15]:
import numpy as np
import pickle
import torch
import torch.nn.functional as F

import os
import glob

from tqdm import tqdm

from tokenizers import Tokenizer
from tokenizers import decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from src.model import SimpleBigramModel
from src.dataloaders import build_loaders
from hparams import Hparams

In [16]:
# Single hyperparameter object shared by the rest of the notebook: it is handed to
# `build_loaders` and supplies `hparams.epochs` to the training loop.
hparams = Hparams()

### Tokenization experiments

In [17]:
# `build_loaders` returns four values; only the tokenizer and the training loader are
# needed in this section, so the last two are discarded.
# NOTE(review): the discarded values are presumably validation/test loaders -- confirm in src.dataloaders.
tokenizer, train_loader, _, _ = build_loaders(hparams)

In [18]:
# Build the bigram model with one vocabulary entry per tokenizer token, then move it to
# the GPU when one is available. Falling back to the CPU keeps this cell runnable on
# machines without CUDA instead of raising on a hard-coded 'cuda' device.
model = SimpleBigramModel(vocab_size = tokenizer.get_vocab_size())
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
# Number of trainable scalars in the model: parameters with requires_grad=False are skipped.
sum(param.numel() for param in model.parameters() if param.requires_grad)

9000000

In [20]:
# Optimiser: AdamW over every model parameter, with no hyperparameters overridden.
optim = torch.optim.AdamW(params=model.parameters())
# Objective: cross-entropy between predicted logits and target token ids.
loss_fn = torch.nn.CrossEntropyLoss()

In [21]:
# Train the model for `hparams.epochs` passes over `train_loader`.
# Batches are moved to whichever device the model already lives on, so this cell does
# not depend on a hard-coded 'cuda' string.
device = next(model.parameters()).device

model.train()
for epoch in range(hparams.epochs):
    # Running loss is kept as a detached tensor on the model's device so that
    # accumulating it does not force a host synchronisation on every step.
    epoch_loss = torch.zeros((), device=device)
    n_batches = 0

    for x, y in tqdm(train_loader):
        optim.zero_grad()

        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        # The loss takes (N, C) logits and (N,) targets, so the batch and time axes are
        # folded together. `reshape` is used instead of `view` so a non-contiguous
        # tensor is copied rather than raising.
        B, T, C = logits.shape
        logits = logits.reshape(B*T, C)
        y = y.reshape(B*T)

        loss = loss_fn(logits, y)

        loss.backward()
        optim.step()

        epoch_loss += loss.detach()
        n_batches += 1

    # Report the epoch average as a plain float rather than the raw tensor of the last
    # batch; the guard avoids a NameError / division by zero when the loader is empty.
    if n_batches:
        print(f'epoch {epoch}: mean train loss {epoch_loss.item() / n_batches:.4f}')
    else:
        print(f'epoch {epoch}: train_loader yielded no batches')



  0%|          | 52556/15924451 [02:27<12:21:39, 356.68it/s]


KeyboardInterrupt: 

In [23]:
# Attach a decoder so `tokenizer.decode` joins generated token ids back into text.
# NOTE(review): this notebook imports the BPE model/trainer, while the decoder chosen here
# is the WordPiece one -- confirm it matches how `build_loaders` trains the tokenizer.
tokenizer.decoder = decoders.WordPiece()

In [24]:
def inference(input:str, model, tokenizer, out_len:int, determenistic=False):
    """Generate a text continuation of `input` with `model`.

    The prompt is encoded with `tokenizer`, passed to `model.generate`, and the
    returned token ids are decoded back into a string.

    Args:
        input: Prompt text to continue.
        model: Module exposing `generate(ids, out_len, deterministic=...)`.
        tokenizer: Object exposing `encode(text).ids` and `decode(ids)`.
        out_len: Length argument forwarded unchanged to `model.generate`.
        determenistic: Forwarded to `model.generate` as `deterministic`. The
            misspelled parameter name is kept so existing callers keep working.

    Returns:
        The string produced by `tokenizer.decode` on the generated ids.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    # Put the prompt on the model's own device instead of assuming 'cuda', so the
    # function also works for CPU-resident models. A model without parameters
    # falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    with torch.inference_mode():
        prompt_ids = torch.tensor(tokenizer.encode(input).ids, dtype=torch.long, device=device)
        out = model.generate(prompt_ids, out_len, deterministic=determenistic)
        return tokenizer.decode(list(out))

In [31]:
# Sample a continuation of the prompt 'Hello' from the trained model (out_len=100).
inference(input='Hello', model=model, tokenizer=tokenizer, out_len=100)

"Hello ' s Briscrusbback, and its first prote Swd the CD gameplish cookanelfs acrossingoons. Angelization of 55 hung community ), the 25, detail Gunteer madian, and supposed that Buctically, it, writing for John ( Saki ( Duaded at time treatment with arr coron S"

## Self-Attention Experiments

### Notes:
- Query comes from other sequence (or the sequence itself in self attention) and is a value that converges to represent 