# Transformer
The goal of this notebook is to create an educational implementation of the transformer with minimal dependencies that's easy to map to the paper.

### References
https://github.com/pytorch/examples/blob/main/word_language_model/

https://github.com/karpathy/nanoGPT/tree/master

In [1]:
import math
import torch
import torch.nn as nn

In [2]:
class Embeddings(nn.Module):
    def __init__(self, positions, vocab_size, dmodel):
        super().__init__()
        self.word_embs = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=dmodel
        )
        p_grid, i_grid = torch.meshgrid(
            torch.arange(positions),
            torch.arange(dmodel),
            indexing='ij'
        )
        self.pos_sin = torch.sin(p_grid / (10000**(2 * i_grid / dmodel)))
        self.pos_cos = torch.cos(p_grid / (10000**(2 * i_grid / dmodel)))

    def forward(self, x):
        e = self.word_embs(x)
        e = (e + self.pos_sin + self.pos_cos)
        return e

In [3]:
# Inputs q,k,v are positions x dk, or positions x dv
# Dot all q with k. One must be tranposed to make this line up
class SDPA(nn.Module):
    def __init__(self, causal, positions, dkv):
        super().__init__()
        self.pos = positions
        self.d = dkv
        self.causal = causal
        if causal:
            self.mask = torch.tril(torch.ones(self.pos, self.pos))
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, q, k, v):
        assert(self.d == k.shape[-1])
        q_dot_k = torch.bmm(q, (k.transpose(-2, -1)))
        scaled = q_dot_k / (math.sqrt(self.d))
        # masked = torch.bmm(scaled, self.mask.expand(q.shape[0], self.pos, self.pos))  # Expand to match the batch dimension
        if self.causal:
            # TODO is this an elementwise or a matmul?
            scaled = scaled * self.mask
        logits = self.softmax(scaled)
        out = torch.bmm(logits, v)
        return out


In [4]:
class Head(nn.Module):
    def __init__(self, causal, positions, dkv, dmodel):
        super().__init__()
        self.v_projection = nn.Linear(dmodel, dkv)
        self.k_projection = nn.Linear(dmodel, dkv)
        self.q_projection = nn.Linear(dmodel, dkv)

        self.sdpa = SDPA(causal, positions, dkv)

    def forward(self, v, k, q):
        vp = self.v_projection(v)
        kp = self.k_projection(k)
        qp = self.q_projection(q)
        return self.sdpa(qp, kp, vp)

# dk = dv = dmodel/nhead
# wq is dmodel x dk
# wk is dmodel x dk
# wv is dmodel x dv
class MHA(nn.Module):
    def __init__(self, causal, positions, nhead, dmodel):
        super().__init__()
        assert(dmodel % nhead == 0)
        self.heads = nn.ModuleList([Head(causal, positions, dmodel//nhead, dmodel) for _ in range(nhead)])
        
        # TODO In the paper it's called WO, and hd_v x dmodel, but b/c d_v = dmodel/h it turns out to be square.
        self.out_projection = nn.Linear(dmodel, dmodel)

    def forward(self, v, k, q):
        head_outputs = [h(v, k, q) for h in self.heads]
        catd = torch.cat(head_outputs, dim=-1)
        
        projection = self.out_projection(catd)
        return projection

In [5]:
class FeedForward(nn.Module):
    def __init__(self, dmodel, ff, dropout=0.1):
        super().__init__()
        self.f0 = nn.Linear(dmodel, ff)
        self.f1 = nn.Linear(ff, dmodel)
        self.dropout = nn.Dropout(p=dropout)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.f0(x)
        x = self.act(x)
        x = self.f1(x)
        x = self.dropout(x)
        return x

# Input/Output are Batch, Positions, then Model Dimension.
class TransformerDecoder(nn.Module):
    def __init__(self, positions, dmodel, ff, nhead, dropout=0.1):
        super().__init__()
        self.mha0 = MHA(True, positions, nhead, dmodel)
        self.layernorm0 = nn.LayerNorm(dmodel)
        self.mha1 = MHA(True, positions, nhead, dmodel)
        self.layernorm1 = nn.LayerNorm(dmodel)
        self.feedforward = FeedForward(dmodel, ff, dropout)
        self.layernorm2 = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, inputs, cross=None): 
        x = self.mha0(inputs, inputs, inputs)
        x = self.layernorm0(x + inputs)

        x_pre_mha1 = x
        if cross is not None:
            x = self.mha1(cross, cross, x)
        else:
            x = self.mha1(x, x, x)
        x = self.layernorm1(x + x_pre_mha1)

        x_pre_ff = x
        x = self.feedforward(x)
        x = self.layernorm2(x + x_pre_ff)
        return x

In [6]:
class OutputProjection(nn.Module):
    def __init__(self, dmodel, vocab_size):
        super().__init__()
        self.proj = nn.Linear(dmodel, vocab_size)
        self.softmax = nn.Softmax(-1)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.proj(x)
        x = self.softmax(x)
        return x

In [7]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, positions, dmodel, ff, nhead, nlayers, dropout=0.1):
        super().__init__()
        self.emb = Embeddings(positions, vocab_size, dmodel)
        self.transformer_layers = nn.ModuleList([TransformerDecoder(positions, dmodel, ff, nhead=nhead, dropout=dropout) for _ in range(nlayers)])
        self.out = OutputProjection(dmodel, vocab_size)

    def forward(self, x):
        x = self.emb(x)
        for layer in self.transformer_layers:
            x = layer(x)

        x = self.out(x)
        return x

In [8]:
dmodel = 64
positions = 512
vocab_size = 26
ff = 4 * dmodel
heads_per_layer = 4
layers = 4

model = Transformer(vocab_size, positions, dmodel, ff, heads_per_layer, layers, 0.1)

with torch.no_grad():
    ids = torch.randint(0, vocab_size, (1, positions))
    z = model(ids)
    print(z)
    print(z.shape)

tensor([[[0.0361, 0.0307, 0.0375,  ..., 0.0521, 0.0266, 0.0183],
         [0.0613, 0.0611, 0.0814,  ..., 0.0227, 0.0402, 0.0676],
         [0.0658, 0.0300, 0.0647,  ..., 0.0811, 0.0406, 0.0489],
         ...,
         [0.0307, 0.0546, 0.0344,  ..., 0.0345, 0.0317, 0.0858],
         [0.0520, 0.0482, 0.0167,  ..., 0.0709, 0.0861, 0.0758],
         [0.0684, 0.0304, 0.0275,  ..., 0.0284, 0.0591, 0.0717]]])
torch.Size([1, 512, 26])


In [9]:
epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

lossfn = nn.CrossEntropyLoss()

loss_plot = []
for epoch in range(epochs):
    model.train()
    for i, (text, target) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        images = images.float().to(device)
        targets = target.to(device)

        outs = model(text)

        loss = lossfn(outs, targets)
        loss.backward()
        optimizer.step()

    losses = []
    model.eval()
    for i, (text, target) in enumerate(tqdm(val_loader)):
        with torch.no_grad():
            images = images.float().to(device)
            targets = target.to(device)
            outs = model(text)
            loss = lossfn(outs, targets)
            losses.append(loss)

    epoch_loss = torch.Tensor(losses).mean().item()
    print("Epoch {}, Current loss is {}".format(epoch, epoch_loss))
    print("{}/{} correct, {:.2f}%".format(correct, total, 100*correct/total))
    loss_plot.append(epoch_loss)

plt.plot(loss_plot)

  from .autonotebook import tqdm as notebook_tqdm


NameError: name 'tqdm' is not defined