In [1]:
from pathlib import Path

import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from torch.nn import functional as F
from tqdm.notebook import trange, tqdm





# Paths

In [None]:
# Directory the notebook reads train.txt / val.txt from and writes the tokenizer file into.
# Requires `Path` from pathlib, which is imported in the notebook's import cell.
DATA_DIR = Path("data")
# Location of the serialized tokenizer: written by the tokenizer-training cell,
# read back by the tokenizer-loading cell.
TOK_LOC = DATA_DIR / "tokenizer-goethe_schiller_raimund.json"

# Params

In [2]:
# --- Tokenizer ---------------------------------------------------------------
SPECIAL_TOKENS = ["[UNK]", "[SOS]", "[EOS]"]
# Vocabulary size requested from the WordPiece trainer; also the width of the
# bigram model's embedding table.
VOCAB_SIZE = 4 * 512

# --- Batching ----------------------------------------------------------------
BLOCK_SIZE = 10   # number of tokens in each sampled window
BATCH_SIZE = 512  # number of independent windows processed in parallel

# --- Optimisation / evaluation -----------------------------------------------
# NOTE(review): the constants below are not referenced by the cells visible in
# this part of the notebook; presumably they are consumed further down.
LEARNING_RATE = 3e-4
MAX_ITERS = 5_000
EVAL_INTERVAL = 500
EVAL_ITERS = 200

# --- Model architecture ------------------------------------------------------
# NOTE(review): likewise unused by the bigram model shown here.
N_EMBD = 384
N_HEAD = 6
N_LAYER = 6
DROPOUT = 0.2



In [None]:
# Fit a WordPiece tokenizer on the training split and persist it to TOK_LOC.
trainer = WordPieceTrainer(special_tokens=SPECIAL_TOKENS, vocab_size=VOCAB_SIZE)

tokenizer = Tokenizer(WordPiece())
tokenizer.pre_tokenizer = Whitespace()

# The trainer takes file paths as plain strings, hence the str() conversions.
train_file = DATA_DIR / "train.txt"
tokenizer.train([str(train_file)], trainer)
tokenizer.save(str(TOK_LOC))


In [3]:
# Seed PyTorch's global RNG, then choose where tensors and the model will live.
torch.manual_seed(1337)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Reload the persisted tokenizer and encode both corpus splits in one go.
tokenizer = Tokenizer.from_file(str(TOK_LOC))

# Decode the corpus explicitly as UTF-8. Without `encoding=`, open() falls back
# to the platform's locale encoding, which can garble or reject the non-ASCII
# characters in the text on systems whose default is not UTF-8.
with open(DATA_DIR / "train.txt", "r", encoding="utf-8") as f:
    train_enc = tokenizer.encode(f.read())
with open(DATA_DIR / "val.txt", "r", encoding="utf-8") as f:
    val_enc = tokenizer.encode(f.read())

In [5]:
def build_batch(split):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: 'train' draws from the training encoding; any other value
            draws from the validation encoding.

    Returns:
        A pair (X, Y) of integer tensors on `device`, each of shape
        (BATCH_SIZE, BLOCK_SIZE). Y holds the same windows as X shifted one
        token to the right, i.e. the next-token targets.
    """
    token_ids = val_enc.ids if split != 'train' else train_enc.ids
    # One random start offset per sequence; the upper bound leaves room for
    # the one-token look-ahead needed by the target window.
    starts = torch.randint(len(token_ids) - BLOCK_SIZE, (BATCH_SIZE,))

    inputs, labels = [], []
    for start in starts:
        stop = start + BLOCK_SIZE
        # torch.tensor (lower-case) keeps the ids as integers, whereas
        # torch.Tensor would produce floats.
        inputs.append(torch.tensor(token_ids[start:stop]))
        labels.append(torch.tensor(token_ids[start + 1:stop + 1]))

    return torch.stack(inputs).to(device), torch.stack(labels).to(device)


In [6]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: a token's id directly looks up next-token logits.

    The only parameters are a VOCAB_SIZE x VOCAB_SIZE embedding table; the row
    selected by a token id is used as the unnormalised scores for the token
    that follows it.
    """

    def __init__(self):
        super().__init__()
        # The table is moved to `device` here, so the model is usable without
        # a separate .to(device) call by the caller.
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, VOCAB_SIZE).to(device)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional integer token ids of shape (B, T).

        Returns:
            A pair (logits, loss).
            * targets is None: logits has shape (B, T, C) and loss is None.
            * targets given: logits is flattened to (B*T, C) and loss is the
              cross-entropy between those logits and the flattened targets.
        """
        logits = self.token_embedding_table(idx)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # F.cross_entropy wants (N, C) logits against (N,) class indices.
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()  # sampling never back-propagates; skip building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Extend each sequence in `idx` by sampling `max_new_tokens` tokens.

        Args:
            idx: integer token ids of shape (B, T) used as the prompt.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the prompt
            followed by the sampled continuation.
        """
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)
            # Only the last position's scores decide the next token.
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
        
        


In [7]:
# Smoke test of the untrained model: one forward pass on a training batch,
# then sample 100 tokens and decode them back to text.
# The batch is drawn before the model is built; both consume the seeded RNG,
# so this order determines the values produced.
xb, yb = build_batch("train")
m = BigramLanguageModel()

logits, loss = m(xb, yb)
print(logits.shape, loss, sep="\n")

# Generation starts from a single sequence holding only token id 0.
start = torch.zeros((1, 1), dtype=torch.long, device=device)
idx = m.generate(start, max_new_tokens=100)[0].tolist()
print(tokenizer.decode(idx))

torch.Size([5120, 512])
tensor(6.7330, grad_fn=<NllLossBackward0>)
Und auch ##ot Er # 6 ##he ha ##r auch 7  ##rt  ##? § L ##acht È ihm hat — ##rie E } ##ë r ##û · ##A ##ll ô ; ##4 ##Y im @ dem ##H ##um ##ann ( ##ber Sie Ein ##eh ## ? s ##ck ##it ##à ‘ L aber V O ##ar ##ige ##uß sch X ##-- ##igen ##ls auch ##q ##cht ##hr “ f L ##R ##eit Sch ##Ä wie K ##ers ##nen ##ol ##ie ##X « H ##” ##ill . ” ##d ##anz ##as ##is ##ô ##er ##ber ##lle ##eine dem ##ung


In [8]:
# AdamW over every parameter of the bigram model `m`.
# NOTE(review): the step size is the literal 1e-3, not the LEARNING_RATE
# constant (3e-4) from the parameter cell — confirm this is intentional.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [9]:
# Train the bigram model for 1000 steps, each on a freshly sampled batch.
for step in trange(1000):
    # Drop the previous step's gradients before computing new ones.
    optimizer.zero_grad(set_to_none=True)

    xb, yb = build_batch('train')
    logits, loss = m(xb, yb)

    loss.backward()
    optimizer.step()

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
@torch.no_grad()
def estimate_loss(model=None, eval_iters=None):
    """Estimate the mean loss on the train and validation splits.

    The original body referred to `model`, `eval_iters` and `get_batch`, none
    of which this notebook defines; it now uses the notebook's own names
    (`m`, `EVAL_ITERS`, `build_batch`). Calling it with no arguments is still
    supported.

    Args:
        model: module to evaluate, called as ``model(X, Y)`` and expected to
            return ``(logits, loss)``. Defaults to the notebook-level `m`.
        eval_iters: number of random batches averaged per split. Defaults to
            the notebook-level `EVAL_ITERS`.

    Returns:
        dict mapping 'train' and 'val' to a scalar tensor holding the mean
        loss over `eval_iters` batches of that split.
    """
    # Resolve defaults at call time so the definition does not depend on cell order.
    if model is None:
        model = m
    if eval_iters is None:
        eval_iters = EVAL_ITERS

    out = {}
    model.eval()  # evaluation mode while measuring
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = build_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()  # hand the model back in training mode
    return out