In [1]:
import itertools

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

In [2]:
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input to queries, keys and values of width ``head_dim``,
    computes scaled dot-product attention with a lower-triangular mask so each
    position only attends to itself and earlier positions, and returns the
    attention-weighted values.

    Args:
        context_len: maximum sequence length; size of the causal mask buffer.
        embedding_dim: feature width of the input.
        head_dim: width of the query/key/value projections (and of the output).
    """

    def __init__(self, context_len, embedding_dim, head_dim):
        super().__init__()
        self.context_len = context_len
        # Backward-compatible alias for the previously misspelled attribute name.
        self.contex_len = context_len
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        assert embedding_dim % head_dim == 0

        self.q = nn.Linear(embedding_dim, head_dim, bias=False)
        self.k = nn.Linear(embedding_dim, head_dim, bias=False)
        self.v = nn.Linear(embedding_dim, head_dim, bias=False)
        # Registered as a buffer so it follows .to(device) without being a parameter.
        self.register_buffer("tril", torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x):
        """Apply causal attention.

        Args:
            x: tensor of shape (B, T, embedding_dim) with T <= context_len.

        Returns:
            Tensor of shape (B, T, head_dim).
        """
        seq_len = x.size(1)
        q = self.q(x)  # B, T, H
        k = self.k(x)  # B, T, H
        v = self.v(x)  # B, T, H
        w = q @ k.transpose(-2, -1) * self.head_dim**-0.5  # B, T, T
        # Slice the mask to the actual sequence length: inputs shorter than
        # context_len (e.g. short generation prompts) would otherwise fail to
        # broadcast against the full (context_len, context_len) buffer.
        causal_mask = self.tril[:seq_len, :seq_len]
        w = w.masked_fill(causal_mask == 0, float("-inf"))
        w = torch.softmax(w, dim=-1)
        return w @ v  # B, T, H (head_size)

In [3]:
class MHH(nn.Module):
    """Multi-head causal self-attention.

    Builds ``embedding_dim // head_dim`` independent ``Head`` modules that all
    read the same input, concatenates their outputs along the last axis (which
    gives back ``embedding_dim`` features) and mixes them with a linear layer.
    """

    def __init__(self, context_len, embedding_dim, head_dim):
        super().__init__()
        assert embedding_dim % head_dim == 0
        self.context_len = context_len
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        self.num_heads = embedding_dim // head_dim
        self.heads = nn.ModuleList(
            Head(context_len, embedding_dim, head_dim) for _ in range(self.num_heads)
        )
        self.projection_layer = nn.Linear(self.embedding_dim, self.embedding_dim)

    def forward(self, x):
        per_head = [attention(x) for attention in self.heads]
        stacked = torch.cat(per_head, dim=-1)
        return self.projection_layer(stacked)


class FeedForwardNetwork(nn.Module):
    """Two-layer MLP applied over the last axis: expand 4x, ReLU, project back."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        hidden_dim = 4 * embedding_dim  # conventional 4x expansion width
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )

    def forward(self, x):
        # nn.Linear acts on the last dimension, so every position is transformed independently.
        result = self.ffn(x)
        return result


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``h = x + MHH(LN1(x))`` followed by ``h + FFN(LN2(h))``: each
    sub-layer sees a normalized input and is added back onto the residual stream.
    """

    def __init__(self, context_len, embedding_dim, head_dim):
        super().__init__()
        self.context_len, self.embedding_dim, self.head_dim = (
            context_len,
            embedding_dim,
            head_dim,
        )
        self.mhh = MHH(context_len, embedding_dim, head_dim)
        self.ffn = FeedForwardNetwork(embedding_dim)
        self.layer_norm_1 = nn.LayerNorm(embedding_dim)
        self.layer_norm_2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        after_attention = x + self.mhh(self.layer_norm_1(x))
        return after_attention + self.ffn(self.layer_norm_2(after_attention))

In [4]:
class TransformerLanguageModel(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Token embeddings plus learned positional embeddings feed a stack of
    ``num_blocks`` TransformerBlocks and a final LayerNorm; a linear head then
    produces next-token logits over ``vocab_size`` classes.

    Args:
        vocab_size: number of distinct tokens.
        context_len: maximum sequence length (size of the positional table).
        embedding_dim: width of the residual stream.
        head_dim: per-head width; each block uses ``embedding_dim // head_dim`` heads.
        num_blocks: number of stacked TransformerBlocks (default 3).
    """

    def __init__(self, vocab_size, context_len, embedding_dim, head_dim, num_blocks=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.head_dim = head_dim
        self.context_len = context_len
        self.embedding_dim = embedding_dim
        # Backward-compatible alias for the previously misspelled attribute name.
        self.embedding_sim = embedding_dim
        self.num_blocks = num_blocks

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embeddings = nn.Embedding(context_len, embedding_dim)

        # transformers.0 .. transformers.{num_blocks-1} are blocks and the last
        # entry is the final LayerNorm (same state_dict layout as before).
        self.transformers = nn.Sequential(
            *(TransformerBlock(context_len, embedding_dim, head_dim) for _ in range(num_blocks)),
            nn.LayerNorm(embedding_dim),
        )
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: token ids of shape (B, T); T must not exceed ``context_len``
                because positions index the positional embedding table.
            targets: optional token ids of shape (B, T).

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        positions = torch.arange(x.size(1), device=x.device)
        x = self.embedding(x) + self.positional_embeddings(positions)
        logits = self.output_layer(self.transformers(x))  # B, T, vocab_size
        if targets is None:
            return logits, None
        # cross_entropy expects (N, C) logits and (N,) class ids.
        logits = logits.flatten(0, 1)
        loss = nn.functional.cross_entropy(logits, targets.flatten())
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, horizon_limit):
        """Autoregressively sample ``horizon_limit`` new tokens.

        Runs without autograd since sampling needs no gradients.

        Args:
            idx: prompt token ids of shape (B, T0).
            horizon_limit: number of tokens to append.

        Returns:
            Tensor of shape (B, T0 + horizon_limit): the prompt followed by samples.
        """
        for _ in range(horizon_limit):
            # Only the last context_len tokens fit in the positional table.
            window = idx[:, -self.context_len :]
            logits, _ = self(window)
            next_token_probs = torch.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            idx = torch.cat([idx, next_token], dim=-1)
        return idx

In [5]:
# Load the raw training corpus. Decode explicitly as UTF-8 so the result does
# not depend on the platform's default locale encoding.
with open("shake.txt", encoding="utf-8") as f:
    text = f.read()

In [6]:
# Character-level vocabulary: every distinct character in the corpus, sorted.
vocab = sorted(set(text))
vocab_size = len(vocab)

# Model / batching hyperparameters
batch_size = 4
context_len = 8  # characters per training window (and max attention span)
embedding_dim = 32
head_dim = 16  # each block uses embedding_dim // head_dim attention heads

# Train / held-out split: first 90% of characters train, the rest is held out.
train_split_pct = 0.9
train_split_idx = int(train_split_pct * len(text))
train_text = text[:train_split_idx]
test_text = text[train_split_idx:]

# Training schedule
max_steps = 10000  # training stops after this many steps
val_check = 500  # evaluate every `val_check` steps
val_steps = 50  # batches averaged per evaluation

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
seed = 1337
torch.manual_seed(seed)

char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for char, idx in char2idx.items()}


def encode(x):
    """Turn a string into the list of its character ids.

    Raises KeyError for any character that is not in the vocabulary.
    """
    return list(map(char2idx.__getitem__, x))


def decode(idxs):
    """Inverse of ``encode``: join the characters for a sequence of ids."""
    characters = (idx2char[token_id] for token_id in idxs)
    return "".join(characters)

In [7]:
class ShakeDataset(Dataset):
    """Sliding next-character-prediction windows over an encoded text.

    Item ``i`` is a pair ``(x, y)`` of 1-D id tensors of length ``context_len``,
    where ``y`` is ``x`` shifted one character to the right.

    Args:
        text: raw text to encode with ``encode``.
        context_len: window length.
    """

    def __init__(self, text: str, context_len: int):
        self.text = text
        self.context_len = context_len
        # Explicit dtype keeps an empty text as an integer tensor.
        self.encoded_text = torch.tensor(encode(text), dtype=torch.long)

    def __getitem__(self, idx: int):
        """Return the ``(x, y)`` window starting at ``idx``.

        Negative indices count from the end. Raises IndexError when out of range,
        which also lets plain ``for`` iteration over the dataset terminate.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index out of range for dataset of length {length}")
        # Use this instance's window size, not a notebook-global of the same name.
        window = self.encoded_text[idx : idx + self.context_len + 1]
        return window[:-1], window[1:]

    def __len__(self):
        # Each item needs context_len + 1 characters; too-short texts give no items.
        return max(0, len(self.encoded_text) - self.context_len)

In [8]:
# Sliding-window datasets for the training and held-out splits.
train_dataset = ShakeDataset(train_text, context_len)
test_dataset = ShakeDataset(test_text, context_len)

# Shuffle both splits and drop the trailing partial batch, so every batch has
# exactly `batch_size` rows.
train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

In [9]:
# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

model = TransformerLanguageModel(vocab_size, context_len, embedding_dim, head_dim)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

Using device: mps


In [10]:
steps = 0


@torch.no_grad()
def evaluate(model, dataloader, device, num_batches=None):
    """Average the model's loss over the first few batches of ``dataloader``.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.
        dataloader: iterable of ``(x, y)`` batches.
        device: device the batches are moved to before the forward pass.
        num_batches: how many batches to average; defaults to ``val_steps``.

    Returns:
        Mean of ``loss.item()`` over the batches seen, or ``inf`` if the
        dataloader yielded no batches.
    """
    if num_batches is None:
        num_batches = val_steps
    was_training = model.training
    model.eval()
    try:
        losses = []
        for x, y in itertools.islice(dataloader, num_batches):
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            losses.append(loss.item())
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    return sum(losses) / len(losses) if losses else float("inf")


model.train()
if len(train_dataloader) == 0:
    raise ValueError(
        "train_dataloader yields no batches; check batch_size against the training text length"
    )

# Cycle through the DataLoader indefinitely; every fresh pass re-shuffles the data.
train_batches = itertools.chain.from_iterable(itertools.repeat(train_dataloader))

# Perform exactly `max_steps` optimizer updates.
while steps < max_steps:
    if steps % val_check == 0:
        # Evaluate before the update so "Step N" reflects N completed updates.
        train_loss = evaluate(model, train_dataloader, device)
        test_loss = evaluate(model, test_dataloader, device)
        print(f"Step {steps} Train loss: {train_loss}")
        print(f"Step {steps} Test loss: {test_loss}")
    x, y = next(train_batches)
    x, y = x.to(device), y.to(device)
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    steps += 1

Train loss:  4.312373113632202
Test loss:  4.346557111740112
Train loss:  2.610311346054077
Test loss:  2.628956799507141
Train loss:  2.5060336208343506
Test loss:  2.5220720601081847
Train loss:  2.4326512694358824
Test loss:  2.4677705144882203
Train loss:  2.37625497341156
Test loss:  2.349957518577576
Train loss:  2.374982142448425
Test loss:  2.3944062995910644
Train loss:  2.2466498684883116
Test loss:  2.2545091271400453
Train loss:  2.289176394939423
Test loss:  2.303149197101593
Train loss:  2.25638267993927
Test loss:  2.235528516769409
Train loss:  2.221574754714966
Test loss:  2.3426869988441466
Train loss:  2.2510094237327576
Test loss:  2.1914745283126833
Train loss:  2.219421772956848
Test loss:  2.254181044101715
Train loss:  2.2224785184860227
Test loss:  2.2982584738731386
Train loss:  2.175561077594757
Test loss:  2.2401065540313723
Train loss:  2.2030136251449584
Test loss:  2.2291086173057555
Train loss:  2.149005799293518
Test loss:  2.1746397399902344
Train loss

In [13]:
# Shape of one training window once a leading batch axis is added (indexing with None == unsqueeze(0)).
train_dataset[500][0][None].shape


torch.Size([1, 8])

In [16]:
# Prompt with training window #500 (as a batch of one) and sample 500 new characters.
# Use the selected `device` rather than hard-coding "mps", which fails on other machines.
prompt = train_dataset[500][0].unsqueeze(0).to(device)
seq = model.generate(prompt, 500)
seq = seq.flatten().tolist()

In [17]:
# Render the prompt plus sampled continuation as text.
generated_text = decode(seq)
generated_text

" citizen\nTo ad And to stank.\nSo roffir couc not as  ane sold sels\nGown besmparrsefuld mingervince :\nO.\n\nCach sard'y?\n\nYour'll not Go by meappirnts have gend in forese that Loud unnow'd the ose dises!nows, artard is a cend bise hard cands mestll a he shoun Ted tconk the\nke spatind me cearel!\n\nKIRITES Io Car Rorf die her athe for discatuped is not had badbrudd lold, ar om? on that grup hanct will mars thit the to tust, swi warr and have they ow his dion beath:\nLoth, him my salpes move!\nHis steedecr; of an"