In [87]:
import torch
import random
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim

DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [88]:
BATCH_SIZE = 32
CONTEXT_SIZE = 32
MAX_LEN = 128
EMBEDDING_DIMS = 128
NO_HEADS = 4
assert EMBEDDING_DIMS % NO_HEADS == 0
HEAD_DIMS = EMBEDDING_DIMS // NO_HEADS
NO_LAYERS = 5
DROPOUT = 0.2

In [89]:
with open("data.txt") as f:
    data = f.read()
chars = list(set(data))
VOCAB_SIZE = len(chars)
s_to_i = {s: i for i, s in enumerate(chars)}
i_to_s = {v: k for k, v in s_to_i.items()}

def encode(text):
    return [s_to_i[c] for c in text]

def decode(ids):
    return "".join([i_to_s[i] for i in ids])

data = encode(data)
train_data = data[:int(len(data) * 0.9)]
val_data = data[int(len(data) * 0.9):]

def batch_generator(split, indefinite=False):
    idx = 0
    data = train_data if split == "train" else val_data
    per_batch_tokens = BATCH_SIZE * CONTEXT_SIZE + 1
    while True:
        if idx + per_batch_tokens > len(data):
            # TODO: finish this
            if not indefinite:
                return
            b_data = data[idx:]
            b_data += data[0: per_batch_tokens - len(b_data)]
        else:
            b_data = data[idx: idx + per_batch_tokens]
            idx += BATCH_SIZE * CONTEXT_SIZE
        x = [b_data[i * CONTEXT_SIZE: (i+1) * CONTEXT_SIZE] for i in range(BATCH_SIZE)]
        y = [b_data[i * CONTEXT_SIZE + 1: (i+1) * CONTEXT_SIZE + 1] for i in range(BATCH_SIZE)]
        yield torch.tensor(x, device=DEV), torch.tensor(y, device=DEV)

In [90]:
class AttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.Wk = nn.Linear(EMBEDDING_DIMS, HEAD_DIMS)
        self.Wv = nn.Linear(EMBEDDING_DIMS, HEAD_DIMS)
        self.Wq = nn.Linear(EMBEDDING_DIMS, HEAD_DIMS)
        self.ln = nn.LayerNorm(EMBEDDING_DIMS)
        self.drop = nn.Dropout(DROPOUT)
        self.register_buffer("mask", torch.tril(torch.ones((32, 32))))

    def forward(self, k, q, v):
        k, q, v = self.Wk(k), self.Wv(v), self.Wq(q)
        attn = (q @ k.transpose(-1, -2)) * EMBEDDING_DIMS**-0.5
        attn = attn.masked_fill(self.mask[:,:] == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.drop(attn)
        return attn @ v

In [91]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(EMBEDDING_DIMS, EMBEDDING_DIMS * 4),
            nn.ReLU(),
            nn.Linear(EMBEDDING_DIMS * 4, EMBEDDING_DIMS),
            nn.Dropout(DROPOUT)
        )

    def forward(self, x):
        return self.layers(x)

In [92]:
class MultiheadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead() for _ in range(NO_HEADS)]
        )
        self.proj = nn.Linear(EMBEDDING_DIMS, EMBEDDING_DIMS)
        self.drop = nn.Dropout(DROPOUT)

    def forward(self, x):
        x = torch.cat([h(k=x, q=x, v=x) for h in self.heads], dim=-1)
        x = self.proj(x)
        return self.drop(x)

In [93]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = MultiheadAttention()
        self.ln1 = nn.LayerNorm(EMBEDDING_DIMS)
        self.ffwd = FeedForward()
        self.ln2 = nn.LayerNorm(EMBEDDING_DIMS)
        self.drop = nn.Dropout(DROPOUT)

    def forward(self, x):
        return self.ln2(self.ffwd(x) + x)

In [94]:
class TransformerEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIMS)
        self.pos_enc = torch.randn((MAX_LEN, EMBEDDING_DIMS))
        positions = torch.arange(0, MAX_LEN).unsqueeze(1)
        _2i = torch.arange(0, EMBEDDING_DIMS, 2)
        self.pos_enc[:,::2] = torch.sin(positions / 10000 ** (_2i / EMBEDDING_DIMS))
        self.pos_enc[:, 1::2] = torch.cos(positions / 10000 ** (_2i / EMBEDDING_DIMS))

    def forward(self, tokens):
        return self.tok_emb(tokens) + self.pos_enc[tokens]

In [95]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = TransformerEmbedding()
        self.layers = nn.Sequential(*[
            Encoder() for _ in range(NO_LAYERS)
        ])
        self.to_vocab = nn.Linear(EMBEDDING_DIMS, VOCAB_SIZE)

    def forward(self, tokens):
        return self.to_vocab(self.layers(self.embeddings(tokens)))

In [59]:
# training
EPOCHS = 100
LR = 0.1

net = Model()
optimizer = optim.AdamW(net.parameters(), lr=LR)
generator = batch_generator("train")

for e in range(1, EPOCHS + 1):
    x, y = get_batch("train")
    logits = net(x)
    loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1))
    print("%.4f" %(loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

4.3093
3.7442
10.5646
14.0874
14.5437
11.7487
9.6533
7.0158
5.7383
5.1075
6.3777
6.1644
5.3338
4.7233
4.4499
4.6398
4.0669
4.0811
4.2670
3.9441
3.8328
3.8894
3.8092
3.5941
3.6241
3.6909
3.4887
3.5374
3.5249
3.4560
3.4633
3.4464
3.4208
3.4478
3.4146
3.4010
3.4250
3.3911
3.3944
3.3645
3.3379
3.3878
3.3717
3.4076
3.3715
3.3639
3.3821
3.4071
3.3146
3.3589
3.3394
3.3602
3.3433
3.3066
3.3420
3.2904
3.3029
3.3305
3.3465
3.3674
3.3764
3.3005
3.3268
3.3278
3.3606
3.3332
3.3406
3.3485
3.3330
3.3066
3.3301
3.3350
3.3323
3.3462
3.3246
3.3032
3.3472
3.3109
3.3300
3.3435
3.3179
3.3359
3.3008
3.3104
3.2983
3.3132
3.3395
3.3212
3.3392
3.3399
3.3126
3.2983
3.3071
3.2984
3.3110
3.3562
3.3388
3.3101
3.3272
3.3311
