In [None]:
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Read one name per line; `with` guarantees the file handle is closed.
with open('data/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

# '.' gets index 0: it is used both as start-of-name padding and as the end-of-name token.
all_chars = ['.'] + sorted(set(''.join(words)))
itos = dict(enumerate(all_chars))             # index -> character
stoi = {ch: idx for idx, ch in itos.items()}  # character -> index

NUM_CHARS = len(all_chars)

In [None]:
block_size = 3  # context length: how many previous characters are used to predict the next one


def build_dataset(words, context_size=None, vocab=None):
    """Build (context, next character) training pairs from a list of words.

    Every word starts with a context of ``context_size`` zeros (index 0 is '.'
    in this notebook's vocabulary) and is followed by the '.' terminator, so the
    examples cover both how a name begins and where it ends.

    Args:
        words: iterable of strings whose characters (and '.') are keys of ``vocab``.
        context_size: positive number of context characters per example;
            defaults to the global ``block_size``.
        vocab: character -> index mapping; defaults to the global ``stoi``.

    Returns:
        (X, Y): X is an int64 tensor of shape (num_examples, context_size) with
        the context indices; Y is an int64 tensor of shape (num_examples,) with
        the index of the character that follows each context. Both keep these
        shapes even when ``words`` is empty.
    """
    context_size = block_size if context_size is None else context_size
    vocab = stoi if vocab is None else vocab

    contexts, targets = [], []
    for w in words:
        context = [0] * context_size
        for ch in w + '.':
            ix = vocab[ch]
            contexts.append(context)
            targets.append(ix)
            context = context[1:] + [ix]  # slide the window by one character

    # Explicit dtype + reshape: without them an empty input becomes a float tensor of shape (0,).
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), context_size)
    Y = torch.tensor(targets, dtype=torch.long)
    return X, Y

SEED = 314
TRAIN_FRACTION = 0.7

random.seed(SEED)
torch.manual_seed(SEED)  # also make the later weight initialisations reproducible

# Shuffle a copy: shuffling `words` in place would give a different split
# every time this cell is re-run.
shuffled_words = words[:]
random.shuffle(shuffled_words)
n1 = int(TRAIN_FRACTION * len(shuffled_words))
Xtr, Ytr = build_dataset(shuffled_words[:n1])
Xte, Yte = build_dataset(shuffled_words[n1:])

print(Xtr.shape, Ytr.shape)
print(Xte.shape, Yte.shape)

In [None]:
EMBED_SIZE = 10


class Model(nn.Module):
    """MLP next-character model over a fixed-size context window.

    The context characters are embedded, concatenated and passed through one
    tanh hidden layer. ``forward`` returns raw logits: no softmax is applied
    because ``nn.CrossEntropyLoss`` applies log-softmax itself.

    Args:
        num_chars: vocabulary size; defaults to the global ``NUM_CHARS``.
        context_size: characters per input context, which must match the width
            of the input batches; defaults to the global ``block_size``.
        embed_size: embedding width; defaults to the global ``EMBED_SIZE``.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, num_chars=None, context_size=None, embed_size=None, hidden_size=32):
        super().__init__()
        self.num_chars = NUM_CHARS if num_chars is None else num_chars
        context_size = block_size if context_size is None else context_size
        embed_size = EMBED_SIZE if embed_size is None else embed_size
        # A bias-free Linear applied to one-hot vectors acts as an embedding lookup.
        self.embedding = nn.Linear(self.num_chars, embed_size, bias=False)
        self.layer1 = nn.Linear(context_size * embed_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, self.num_chars)
        self.softmax = nn.Softmax(dim=1)  # unused in forward: logits are returned

    def forward(self, xs):
        """Map a (batch, context_size) tensor of character indices to (batch, num_chars) logits."""
        x = F.one_hot(xs, num_classes=self.num_chars).float()
        x = self.embedding(x)
        x = x.view(xs.shape[0], -1)  # concatenate the embeddings of the context characters
        x = self.layer1(x).tanh()
        return self.layer2(x)
    
model = Model()
# Total number of parameter elements in the model
sum(p.numel() for p in model.parameters())

In [None]:
# A small batch of training contexts to inspect shapes with
xs = Xtr[0:5]
xs.size()

In [None]:
# Untrained demo layer: a bias-free Linear on one-hot vectors acts as an embedding table
embedding = nn.Linear(NUM_CHARS, EMBED_SIZE, bias=False)
embedded = embedding(F.one_hot(xs, NUM_CHARS).float())
embedded.size()

We can improve the model by letting the characters in the context communicate with each other, so that the prediction of the next character can take their relationships into account.

One mechanism for such communication is **attention**:
- Each character (via its embedding) is projected to a **Value**: the information it passes on
- Each character has an associated **Key**, which describes what it can provide
- Each character has an associated **Query**, which describes the information it needs

All these projections are learned during training.

In [None]:
HEAD_SIZE = 8
# Untrained query, key and value projections (created in that order)
Q, K, V = (nn.Linear(EMBED_SIZE, HEAD_SIZE) for _ in range(3))

# Embed the first 10 training contexts, then project each character
item = embedding(F.one_hot(Xtr[:10], NUM_CHARS).float())
q, k, v = Q(item), K(item), V(item)

In [None]:
tuple(t.shape for t in (item, q, k, v))

In [None]:
# Queries are multiplied by the transposed keys to build the attention matrix;
# swapping the last two axes puts the head dimension first for the matmul.
k.transpose(-1, -2).shape

In [None]:
# Raw attention scores: entry [b, i, j] is the dot product of query i with key j
attentions = q @ k.transpose(-1, -2)
attentions.size()

In [None]:
# Softmax over the keys turns each row of scores into weights that sum to 1;
# weighting the values gives the attention output (not probabilities, despite the name).
weights = attentions.softmax(dim=-1)
probs = weights @ v
probs.size()

So, for each element in the batch and for each character position, we now have a representation that mixes in the values of the other characters in the context, weighted by how well their keys match its query.

In [None]:
class Model2(nn.Module):
    """Next-character model with single-head self-attention over the context window.

    Each context character is embedded, projected to query/key/value vectors and
    mixed with the other context characters through scaled dot-product attention.
    The per-position attention outputs are concatenated and fed to a small MLP.
    ``forward`` returns raw logits, because ``nn.CrossEntropyLoss`` applies
    log-softmax itself.

    Args:
        num_chars: vocabulary size; defaults to the global ``NUM_CHARS``.
        context_size: characters per input context, which must match the width
            of the input batches; defaults to the global ``block_size``.
        embed_size: embedding width; defaults to the global ``EMBED_SIZE``.
        head_size: query/key/value width; defaults to the global ``HEAD_SIZE``.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, num_chars=None, context_size=None, embed_size=None,
                 head_size=None, hidden_size=24):
        super().__init__()
        self.num_chars = NUM_CHARS if num_chars is None else num_chars
        context_size = block_size if context_size is None else context_size
        embed_size = EMBED_SIZE if embed_size is None else embed_size
        self.head_size = HEAD_SIZE if head_size is None else head_size
        # Modules are created in the original order so initialisation and
        # parameter order are unchanged.
        # A bias-free Linear applied to one-hot vectors acts as an embedding lookup.
        self.embedding = nn.Linear(self.num_chars, embed_size, bias=False)
        self.layer1 = nn.Linear(context_size * self.head_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, self.num_chars)
        self.softmax = nn.Softmax(dim=1)  # unused in forward: logits are returned
        self.Q = nn.Linear(embed_size, self.head_size)
        self.K = nn.Linear(embed_size, self.head_size)
        self.V = nn.Linear(embed_size, self.head_size)

    def forward(self, xs):
        """Map a (batch, context_size) tensor of character indices to (batch, num_chars) logits."""
        x = F.one_hot(xs, num_classes=self.num_chars).float()
        x = self.embedding(x)                      # (batch, context, embed)
        q, k, v = self.Q(x), self.K(x), self.V(x)  # each (batch, context, head)
        # Scale by 1/sqrt(head_size) so the softmax does not saturate as head_size grows.
        scores = q @ k.transpose(-2, -1) * self.head_size ** -0.5
        weights = F.softmax(scores, dim=-1)
        attended = weights @ v                     # (batch, context, head)
        x = attended.view(xs.shape[0], -1)         # concatenate the per-position outputs
        x = self.layer1(x).tanh()
        return self.layer2(x)
    
model2 = Model2()
# Total number of parameter elements in the attention model
sum(p.numel() for p in model2.parameters())

In [None]:
# Sanity check: one row of logits per context in the batch
xs = Xtr[0:5]
model2(xs).size()

In [None]:
# Prefer the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Move the model to the target device *before* building the optimizer, as the
# torch.optim documentation recommends.
model2_dev = model2.to(device)
Xtr_dev = Xtr.to(device)
Ytr_dev = Ytr.to(device)

learning_rate = 0.01
num_epochs = 10000
log_every = max(1, num_epochs // 10)  # avoids modulo-by-zero when num_epochs < 10

optimizer = optim.Adam(model2_dev.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()  # takes raw logits; applies log-softmax internally

# Full-batch training loop: every step uses the whole training set.
model2_dev.train()
for epoch in range(num_epochs):
    outputs = model2_dev(Xtr_dev)
    loss = loss_fn(outputs, Ytr_dev)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0:
        print(epoch, loss.item())

print("Training complete")

# Evaluate on both splits; the held-out split (Xte, Yte) is otherwise never used.
model2_dev.eval()
with torch.no_grad():
    train_loss = loss_fn(model2_dev(Xtr_dev), Ytr_dev).item()
    test_loss = loss_fn(model2_dev(Xte.to(device)), Yte.to(device)).item()
print(f"train loss {train_loss:.4f} | test loss {test_loss:.4f}")