In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

print(torch.backends.mps.is_available())  # Should print True
print(torch.backends.mps.is_built())      # Should print True

device = torch.device('mps')
print('device', device)

True
True
device mps


In [2]:
alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
char_to_idx = {char: idx for idx, char in enumerate(alphabet)}
idx_to_char = {idx: char for idx, char in enumerate(alphabet)}

In [3]:
class ScratchRNN(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super(ScratchRNN, self).__init__()

    self.hidden_size = hidden_size
    self.inputToHidden = nn.Linear(input_size + hidden_size, hidden_size)
    self.inputToOutput = nn.Linear(input_size + hidden_size, output_size)
    self.softmax = nn.LogSoftmax(dim=1)

  def forward(self, input_tensor, hidden_tensor):
    combined = torch.cat((input_tensor, hidden_tensor), 1)

    hidden = self.inputToHidden(combined)
    output = self.inputToOutput(combined)
    output = self.softmax(output)

    return output, hidden

  def init_hidden(self):
    return torch.zeros(1, self.hidden_size).to(device)

  def predict(self, char, hidden=None):
    self.eval()

    if hidden is None:
      hidden = self.init_hidden()

    try:
        idx = char_to_idx[char]
    except KeyError:
        raise ValueError(f"Character '{char}' not in training vocabulary")

    input_tensor = torch.tensor([[idx]], dtype=torch.float).to(device)

    with torch.no_grad():
      output, hidden = self(input_tensor, hidden)

    _, predicted_idx = output.max(1)
    predicted_char = idx_to_char[predicted_idx.item()]

    self.train()
    return predicted_char, hidden

input_size = 1
hidden_size = 16
output_size = len(alphabet)

model = ScratchRNN(input_size, hidden_size, output_size).to(device)

In [4]:
X = []
y = []

for i in range(len(alphabet) - 1):
  X.append(char_to_idx[alphabet[i]])
  y.append(char_to_idx[alphabet[i + 1]])

X = torch.tensor(X).view(-1, 1).float().to(device)
y = torch.tensor(y).long().to(device)

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

num_epochs = 5_000
hidden = torch.zeros(1, hidden_size).to(device)
for epoch in range(num_epochs):
  model.zero_grad()
  total_loss = 0
  hidden = model.init_hidden()

  for idx, input_char in enumerate(X):
    input_tensor = input_char.view(1, 1)
    output, hidden = model.forward(input_tensor, hidden.detach())

    loss = criterion(output, y[idx].view(1))
    total_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  if (epoch + 1) % (num_epochs  / 10) == 0:
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {loss.item()/len(X):.4f}')

Epoch [500/5000], Average Loss: 0.0536
Epoch [1000/5000], Average Loss: 0.0534
Epoch [1500/5000], Average Loss: 0.0534
Epoch [2000/5000], Average Loss: 0.0533
Epoch [2500/5000], Average Loss: 0.0533
Epoch [3000/5000], Average Loss: 0.0532
Epoch [3500/5000], Average Loss: 0.0532
Epoch [4000/5000], Average Loss: 0.0532
Epoch [4500/5000], Average Loss: 0.0532
Epoch [5000/5000], Average Loss: 0.0532


In [9]:
next_char, _ = model.predict('a')
print(f"After 'a' comes '{next_char}'")

hidden = None
char = 'a'
print(f"Starting with: {char}")
for _ in range(36):
  char, hidden = model.predict(char, hidden)
  print(f"Next: {char}")

After 'a' comes 'b'
Starting with: a
Next: b
Next: c
Next: d
Next: e
Next: e
Next: g
Next: f
Next: l
Next: e
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
Next: d
Next: l
