In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
np.set_printoptions(suppress=True)
from tqdm import tqdm, trange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
torch.set_printoptions(sci_mode=False)
from torchvision import datasets, transforms
import math

In [2]:
# attempt to autodetect device
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

using device: mps


In [3]:
# Custom dataset for Tiny Shakespeare
class TinyShakespeareDataset(Dataset):
    def __init__(self, text, seq_length):
        self.text = text
        self.seq_length = seq_length
        self.chars = sorted(list(set(text)))
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        self.data_size = len(self.text)

    def __len__(self):
        return max(0, self.data_size - self.seq_length)

    def __getitem__(self, index):
        x = [self.char_to_idx[c] for c in self.text[index:index+self.seq_length]]
        y = [self.char_to_idx[c] for c in self.text[index+1:index+self.seq_length+1]]
        return torch.tensor(x), torch.tensor(y)

In [4]:
import requests
import os

# Download Tiny Shakespeare dataset
def download_tiny_shakespeare():
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    if not os.path.exists("data/tinyshakespeare.txt"):
        data = requests.get(url).text
        with open("data/tinyshakespeare.txt", "w") as f:
            f.write(data)
        
# Download the dataset
download_tiny_shakespeare()

# Read the dataset
with open("data/tinyshakespeare.txt", "r") as f:
    text = f.read()

# Set parameters
seq_length = 512
BS = 32

# Create dataset and split into train and test
dataset = TinyShakespeareDataset(text, seq_length)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create data loaders
loaders = {
    'train': DataLoader(train_dataset, batch_size=BS, shuffle=True),
    'test': DataLoader(test_dataset, batch_size=BS, shuffle=True),
}


In [5]:
vocab_size = len(dataset.chars)
emb_size = 512
n_layers = 16
n_heads = 16
class MLP(nn.Module):
    def __init__(self, emb_length):
        super().__init__()
        self.i2h = nn.Linear(emb_length, emb_length * 4)
        self.h2o = nn.Linear(emb_length * 4, emb_length)

    def forward(self, x):
        x = F.relu(self.i2h(x))
        x = self.h2o(x)
        return x

class ScaledDotProductAttention(nn.Module):
    def __init__(self, emb_length, n_heads):
        super().__init__()
        self.head_size = emb_length // n_heads
        self.key = nn.Linear(self.head_size, self.head_size)
        self.query = nn.Linear(self.head_size, self.head_size)
        self.value = nn.Linear(self.head_size, self.head_size)
        self.n_heads = n_heads
        self.ln1 = nn.LayerNorm(self.head_size)
        self.ln2 = nn.LayerNorm(self.head_size)

    def forward(self, x):
        xq = self.query(x)
        xk = self.key(x)
        xv = self.value(x)
        att = (xq @ xk.transpose(-2,-1)) * (1.0 / math.sqrt(self.n_heads))
        # TODO sean mask here
        x = F.softmax(att, dim=-1)
        x = x @ xv
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.heads = nn.ModuleList([ScaledDotProductAttention(emb_size, n_heads) for i in range(n_heads)])
        self.feed_forward = MLP(emb_size)

    def forward(self, x):
        x = torch.split(x, emb_size // n_heads, 2)
        attention_heads = []
        for i in range(self.n_heads):
            head_x = self.heads[i](x[i])
            attention_heads.append(head_x)
        x = torch.cat(attention_heads, 2)
        x = self.feed_forward(x)
        return x

class Transformer(nn.Module):
    def __init__(self, seq_length, emb_length, n_layers, n_heads, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_length)
        self.positional_encoding = nn.Embedding(seq_length, emb_length)
        self.layers = nn.ModuleList([MultiHeadAttention(n_heads) for i in range(n_layers)])
        self.n_layers = n_layers
        self.seq_length = seq_length
        self.drop = nn.Dropout(p=0.2)
        self.ln = nn.LayerNorm(emb_length)
        self.ll = nn.Linear(emb_length, vocab_size)
        self.embedding.weight = self.ll.weight


    def forward(self, x):
        b, t = x.size()
        tok_emb = self.embedding(x)
        assert t <= self.seq_length, f"Cannot forward sequence of length {t}, block size is only {self.seq_length}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
        pos_emb = self.positional_encoding(pos) # position embeddings of shape (t, n_embd)
        x = self.drop(tok_emb + pos_emb)
        for transformer_block in self.layers:
            x = transformer_block(x)
        x = self.ln(x)
        x = self.ll(x)
        x = F.softmax(x, dim=-1)
        return x
        
model = Transformer(seq_length, emb_size, n_layers, n_heads, vocab_size)
model.to(device)

Transformer(
  (embedding): Embedding(65, 512)
  (positional_encoding): Embedding(512, 512)
  (layers): ModuleList(
    (0-15): 16 x MultiHeadAttention(
      (heads): ModuleList(
        (0-15): 16 x ScaledDotProductAttention(
          (key): Linear(in_features=32, out_features=32, bias=True)
          (query): Linear(in_features=32, out_features=32, bias=True)
          (value): Linear(in_features=32, out_features=32, bias=True)
          (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (ln2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
      )
      (feed_forward): MLP(
        (i2h): Linear(in_features=512, out_features=2048, bias=True)
        (h2o): Linear(in_features=2048, out_features=512, bias=True)
      )
    )
  )
  (drop): Dropout(p=0.2, inplace=False)
  (ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (ll): Linear(in_features=512, out_features=65, bias=True)
)

In [8]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr = 6e-4)

In [9]:
num_epochs = 50
train_losses = {}

for epoch in range(num_epochs):
    epoch_losses = list()
    for i, (X, Y) in enumerate(loaders['train']):
        if X.shape[0] != BS:
            continue
        optim.zero_grad()
        X, Y, = X.to(device), Y.to(device)
        out = model(X)
        out = out.transpose(1, 2)  # batch_size x vocab_size
        loss = criterion(out, Y.long())
        loss.backward()
        optim.step()

        epoch_losses.append(loss.detach().item() / X.shape[1])
        if (i+1) % 250 == 0:
            print('Loss: {:.4f}'.format(loss.detach().item()))
    train_losses[epoch] = torch.tensor(epoch_losses).mean()
    print(f'=> epoch: {epoch + 1}, loss: {train_losses[epoch]}')

Loss: 4.0494
Loss: 4.0480
Loss: 4.0471
Loss: 4.0484
Loss: 4.0478
Loss: 4.0476
Loss: 4.0485
Loss: 4.0462
Loss: 4.0454
Loss: 4.0476
Loss: 4.0442
Loss: 4.0462
Loss: 4.0457
Loss: 4.0471
Loss: 4.0450
Loss: 4.0461
Loss: 4.0472
Loss: 4.0493
Loss: 4.0484
Loss: 4.0503
Loss: 4.0501
Loss: 4.0475
Loss: 4.0470
Loss: 4.0530
Loss: 4.0468
Loss: 4.0480
Loss: 4.0499
Loss: 4.0488
Loss: 4.0441
Loss: 4.0443
Loss: 4.0467
Loss: 4.0476
Loss: 4.0484
Loss: 4.0452
Loss: 4.0469


KeyboardInterrupt: 