In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import requests

print(torch.backends.mps.is_available())  # Should print True
print(torch.backends.mps.is_built())      # Should print True

device = torch.device('mps')
print('device', device)

True
True
device mps


In [3]:
r = requests.get("https://s3.amazonaws.com/text-datasets/nietzsche.txt")
nietzsche_corpus = r.text

In [4]:
chars = sorted(list(set(nietzsche_corpus)))
vocab_size = len(chars)

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [5]:
data = torch.tensor(encode(nietzsche_corpus), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [6]:
def get_batch(data, block_size, batch_size):
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x, y

In [7]:
class CharacterLSTM(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=3):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(
      input_size=embedding_dim,
      hidden_size=hidden_size,
      num_layers=num_layers,
      dropout=0.2,
      batch_first=True
    )
    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, idx, hidden=None):
    # idx shape: (batch, sequence)
    embeddings = self.embedding(idx)  # (batch, sequence, embedding_dim)
    output, hidden = self.lstm(embeddings, hidden)
    logits = self.fc(output)  # (batch, sequence, vocab_size)
    return logits, hidden

model = CharacterLSTM(vocab_size=vocab_size, embedding_dim=384, hidden_size=512).to(device)

In [8]:
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
criterion = nn.CrossEntropyLoss()

In [9]:
block_size = 256  # context length
batch_size = 64   # batch size for training
num_epochs = 5000

for epoch in range(num_epochs):
  model.train()

  X, y = get_batch(train_data, block_size, batch_size)
  X, y = X.to(model.embedding.weight.device), y.to(model.embedding.weight.device)

  # Forward pass
  optimizer.zero_grad()
  logits, _ = model(X)

  # Reshape for loss calculation
  B, T, C = logits.shape
  logits = logits.view(B*T, C)
  targets = y.view(B*T)

  # Calculate loss and update
  loss = nn.CrossEntropyLoss()(logits, targets)
  loss.backward()

  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  optimizer.step()

  loss_val = loss.item()
  if (epoch + 1) % 500 == 0:
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_val:.4f}')

Epoch [500/5000], Loss: 1.3368
Epoch [1000/5000], Loss: 1.2057
Epoch [1500/5000], Loss: 1.1028
Epoch [2000/5000], Loss: 1.0242
Epoch [2500/5000], Loss: 1.0004
Epoch [3000/5000], Loss: 0.9186
Epoch [3500/5000], Loss: 0.9233
Epoch [4000/5000], Loss: 0.9225
Epoch [4500/5000], Loss: 0.8820
Epoch [5000/5000], Loss: 0.8799


In [10]:
@torch.no_grad()
def generate(model, start_sequence, max_new_tokens, temperature=0.8):
  model.eval()

  x = start_sequence.to(model.embedding.weight.device)
  start_text = decode(x.tolist())
  print(start_text, end='', flush=True)

  current_text = start_text

  for _ in range(max_new_tokens):
    logits, _ = model(x.view(1, -1))
    logits = logits[0, -1, :] / temperature

    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    new_char = decode([next_token.item()])

    # Only add newline if it comes after a period and there's no newline already
    if new_char == '\n' and not current_text.endswith('.'):
      new_char = ' '  # Replace newline with space

    # Handle multiple spaces
    if new_char == ' ' and current_text.endswith(' '):
      continue  # Skip consecutive spaces

    print(new_char, end='', flush=True)
    current_text += new_char

    x = torch.cat([x, next_token])

  print()

In [12]:
# # Save the model
# torch.save(model, 'nietzsche_lstm.pth')

# # Load the model
# loaded_model = torch.load('nietzsche_lstm.pth')
# loaded_model.eval()  # Set to evaluation mode

torch.save(model.state_dict(), 'nietzsche_lstm_weights.pth')

In [13]:
saved_model = CharacterLSTM(vocab_size=vocab_size, embedding_dim=384, hidden_size=512).to(device)
saved_model.load_state_dict(torch.load('nietzsche_lstm_weights.pth'))
saved_model.eval()

CharacterLSTM(
  (embedding): Embedding(85, 384)
  (lstm): LSTM(384, 512, num_layers=3, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=512, out_features=85, bias=True)
)

In [15]:
context = torch.tensor(encode("Thus spoke Zarathustra: "), dtype=torch.long).to(device)
generate(saved_model, context, max_new_tokens=500)

Thus spoke Zarathustra: oper it responsibilities on.
 1od in Loding Rolii are Were aloup out in it Betroked Wide Who Wanner anoude's Warking Wire Europeanize, Woo, in Rudance on it.
 2eaniL Will a vowid Whatever Waking in yoowe' Will Who Begain under his virtue Were Will W lited our inward Werebling Wather and injurious Beoli gid Whole onca alone be Bodges on ea how everious either Rorigio more under Bush So Roded) We hone's Beasig Win: Who have Rudain: What is a religiou
