In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import requests
import math
from pathlib import Path

print(torch.backends.mps.is_available())  # Should print True
print(torch.backends.mps.is_built())      # Should print True

device = torch.device('mps')
print('device', device)

True
True
device mps


In [17]:
r = requests.get("https://s3.amazonaws.com/text-datasets/nietzsche.txt")
corpus = r.text

In [18]:
chars = sorted(list(set(corpus)))
vocab_size = len(chars)

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [None]:
data = torch.tensor(encode(corpus), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [20]:
def get_batch(data, block_size, batch_size):
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x, y

In [21]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len=5000):
    super().__init__()
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)

  def forward(self, x):
    return x + self.pe[:, :x.size(1)]

class TransformerBlock(nn.Module):
  def __init__(self, d_model, n_heads, dropout=0.1):
    super().__init__()
    self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.ff = nn.Sequential(
      nn.Linear(d_model, 4 * d_model),
      nn.ReLU(),
      nn.Linear(4 * d_model, d_model),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    # Self attention
    attended, _ = self.attention(x, x, x)
    x = self.norm1(x + attended)

    # Feedforward
    fed_forward = self.ff(x)
    x = self.norm2(x + fed_forward)
    return x

class SimpleTransformer(nn.Module):
  def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, dropout=0.1):
    super().__init__()
    self.d_model = d_model

    # Token embedding and positional encoding
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoder = PositionalEncoding(d_model)
    self.dropout = nn.Dropout(dropout)

    # Transformer blocks
    self.transformer_blocks = nn.ModuleList([
      TransformerBlock(d_model, n_heads, dropout)
      for _ in range(n_layers)
    ])

    # Output layer
    self.output = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    # Input embedding and positional encoding
    x = self.embedding(x) * math.sqrt(self.d_model)
    x = self.pos_encoder(x)
    x = self.dropout(x)

    # Process through transformer blocks
    for block in self.transformer_blocks:
      x = block(x)

    # Output projection
    x = self.output(x)
    return x

  def save_checkpoint(self, epoch, optimizer, loss, filename):
    checkpoint = {
      'epoch': epoch,
      'model_state_dict': self.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss,
      'vocab_size': self.output.out_features,
      'd_model': self.d_model,
      'n_heads': self.transformer_blocks[0].attention.num_heads,
      'n_layers': len(self.transformer_blocks)
    }
    torch.save(checkpoint, filename)

  @classmethod
  def load_checkpoint(cls, filename, device):
    checkpoint = torch.load(filename, map_location=device)
    model = cls(
      vocab_size=checkpoint['vocab_size'],
      d_model=checkpoint['d_model'],
      n_heads=checkpoint['n_heads'],
      n_layers=checkpoint['n_layers']
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint

In [28]:
def train_model(model, train_data, val_data=None, epochs=5000, batch_size=64, block_size=256, lr=3e-4, save_every=1000):
  optimizer = optim.AdamW(model.parameters(), lr=lr)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

  best_loss = float('inf')
  Path('checkpoints').mkdir(exist_ok=True)

  for epoch in range(epochs):  # Now epochs is used directly
    model.train()
    x, y = get_batch(train_data, block_size, batch_size)
    x, y = x.to(device), y.to(device)

    logits = model(x)
    B, T, C = logits.shape
    loss = F.cross_entropy(logits.view(B*T, C), y.view(B*T))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
      print(f'Epoch {epoch}: Loss {loss.item():.4f}')

    if epoch % save_every == 0:
      model.save_checkpoint(
        epoch, optimizer, loss.item(),
        f'checkpoints/transformer_epoch_{epoch}.pt'
      )

    if loss.item() < best_loss:
      best_loss = loss.item()
      model.save_checkpoint(
        epoch, optimizer, loss.item(),
        'checkpoints/transformer_best.pt'
      )

In [29]:
model = SimpleTransformer(vocab_size=vocab_size, d_model=256, n_heads=8, n_layers=6).to(device)
train_model(model, train_data, val_data)

Epoch 0: Loss 4.5719
Epoch 100: Loss 2.5312
Epoch 200: Loss 2.4845
Epoch 300: Loss 2.4810
Epoch 400: Loss 2.4867
Epoch 500: Loss 2.4894
Epoch 600: Loss 2.4869
Epoch 700: Loss 2.5114
Epoch 800: Loss 2.4637
Epoch 900: Loss 2.4861
Epoch 1000: Loss 2.4772
Epoch 1100: Loss 2.4796
Epoch 1200: Loss 2.4833
Epoch 1300: Loss 2.4584
Epoch 1400: Loss 2.4778
Epoch 1500: Loss 2.4755
Epoch 1600: Loss 2.4703
Epoch 1700: Loss 2.4659
Epoch 1800: Loss 2.4736
Epoch 1900: Loss 2.4761
Epoch 2000: Loss 2.4762
Epoch 2100: Loss 2.5033
Epoch 2200: Loss 2.4570
Epoch 2300: Loss 2.4569
Epoch 2400: Loss 2.4721
Epoch 2500: Loss 2.4862
Epoch 2600: Loss 2.4676
Epoch 2700: Loss 2.4941
Epoch 2800: Loss 2.4937
Epoch 2900: Loss 2.4661
Epoch 3000: Loss 2.4615
Epoch 3100: Loss 2.4681
Epoch 3200: Loss 2.4795
Epoch 3300: Loss 2.4576
Epoch 3400: Loss 2.4738
Epoch 3500: Loss 2.4741
Epoch 3600: Loss 2.4608
Epoch 3700: Loss 2.4517
Epoch 3800: Loss 2.4460
Epoch 3900: Loss 2.4784
Epoch 4000: Loss 2.4811
Epoch 4100: Loss 2.4766
Epoc

In [30]:
@torch.no_grad()
def generate(model, start_sequence, max_new_tokens, temperature=0.8):
  model.eval()
  x = start_sequence.to(model.embedding.weight.device)

  print(decode(x.tolist()), end='', flush=True)
  for _ in range(max_new_tokens):
    logits = model(x.view(1, -1))
    logits = logits[0, -1, :] / temperature

    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

    print(decode([next_token.item()]), end='', flush=True)

    x = torch.cat([x, next_token])

  print()

In [31]:
torch.save(model.state_dict(), 'nietzsche_simple_transformer_weights.pth')

In [33]:
saved_model = SimpleTransformer(vocab_size=vocab_size, d_model=256, n_heads=8, n_layers=6).to(device)
saved_model.load_state_dict(torch.load('nietzsche_simple_transformer_weights.pth'))
saved_model.eval()

SimpleTransformer(
  (embedding): Embedding(85, 256)
  (pos_encoder): PositionalEncoding()
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer_blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output): Linear(in_features=256, out_features=85, bias=True)
)

In [34]:
context = torch.tensor(encode("Thus spoke Zarathustra: "), dtype=torch.long).to(device)
generate(saved_model, context, max_new_tokens=500)

Thus spoke Zarathustra: this t evon t ast treas a onthersomas monthe me and toLal
s m t osereshexpsere therare ty" thesiste s me ars are t athace ictt tomiteallyes ve tontout alatrenerd fes serencare omasss thathen  outhe be pof ts in f th atit h te one amio t s ulliopoff aland ind oburer othe as st-tilwh ler acthi

KeyboardInterrupt: 