In [44]:
import torch
import torch.nn as nn
import torch.optim as optim

# Report whether the Apple-GPU (MPS) backend can be used in this environment.
mps_available = torch.backends.mps.is_available()
print(mps_available)                      # True when an MPS device is usable
print(torch.backends.mps.is_built())      # True when this PyTorch build includes MPS

# Use MPS when it is available; otherwise fall back to the CPU so the rest of
# the notebook still runs instead of failing on the first `.to(device)`.
device = torch.device('mps' if mps_available else 'cpu')
print('device', device)

True
True
device mps


In [45]:
class SimpleRNN(nn.Module):
  """Single-layer `nn.RNN` followed by a linear read-out layer."""

  def __init__(self, input_size, hidden_size, output_size):
    """Create the recurrent layer and the linear output layer.

    Args:
      input_size: number of features per time step passed to `nn.RNN`.
      hidden_size: width of the recurrent hidden state.
      output_size: number of values produced by the final linear layer.
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
    self.fc = nn.Linear(hidden_size, output_size)

  def forward(self, x, hidden=None):
    """Run the RNN over `x` and project its output through the linear layer.

    Args:
      x: input tensor; a trailing feature axis of size 1 is appended before it
        is handed to the RNN, and its first axis is used as the batch size.
      hidden: optional initial hidden state; an all-zero state on `x`'s device
        is created when omitted.

    Returns:
      A tuple `(out, hidden)`: the linear layer's output and the RNN's final
      hidden state.
    """
    if hidden is None:
      batch_size = x.size(0)
      hidden = torch.zeros(1, batch_size, self.hidden_size, device=x.device)

    # nn.RNN (batch_first=True) wants (batch, seq, feature); add the feature axis.
    rnn_input = x.unsqueeze(-1)
    rnn_out, hidden = self.rnn(rnn_input, hidden)

    # squeeze(1) removes the sequence axis only when it has length 1.
    out = self.fc(rnn_out.squeeze(1))
    return out, hidden

In [46]:
# Vocabulary: lowercase letters followed by uppercase letters.
alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

# Two-way lookup between a letter and its position in `alphabet`.
idx_to_char = dict(enumerate(alphabet))
char_to_idx = {char: idx for idx, char in idx_to_char.items()}

In [47]:
# Build the training pairs: each letter (input) is paired with the letter that
# follows it in `alphabet` (target). The last letter has no successor, so it
# appears only as a target.
input_indices = [char_to_idx[letter] for letter in alphabet[:-1]]
target_indices = [char_to_idx[letter] for letter in alphabet[1:]]

# X: one float feature per sample, shaped (num_pairs, 1).
# y: integer class index of the next letter, shaped (num_pairs,).
X = torch.tensor(input_indices).view(-1, 1).float().to(device)
y = torch.tensor(target_indices).long().to(device)

In [49]:
# Model dimensions: one scalar feature in, one logit per letter out.
input_size = 1
hidden_size = 16
output_size = len(alphabet)

# Seed PyTorch's RNG before the model is built so the random weight
# initialisation is repeatable from run to run on the same setup.
seed = 0
torch.manual_seed(seed)

model = SimpleRNN(input_size, hidden_size, output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

num_epochs = 10000
log_every = 1000  # print the loss once every `log_every` epochs

# Full-batch training: every epoch is one forward/backward pass over all pairs.
for epoch in range(num_epochs):
  # Clear gradients left over from the previous step (the optimizer owns all
  # of the model's parameters, so this covers the whole model).
  optimizer.zero_grad()

  output, hidden = model(X)
  loss = criterion(output, y)

  loss.backward()
  optimizer.step()

  if (epoch + 1) % log_every == 0:
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [1000/10000], Loss: 0.7802
Epoch [2000/10000], Loss: 0.4055
Epoch [3000/10000], Loss: 0.2405
Epoch [4000/10000], Loss: 0.1530
Epoch [5000/10000], Loss: 0.1013
Epoch [6000/10000], Loss: 0.0739
Epoch [7000/10000], Loss: 0.0544
Epoch [8000/10000], Loss: 0.0430
Epoch [9000/10000], Loss: 0.0334
Epoch [10000/10000], Loss: 0.0247


In [50]:
def predict_next_letter(current_letter):
  """Return the letter the model scores highest as the successor of `current_letter`.

  Uses the notebook-level `model`, `device`, `char_to_idx` and `idx_to_char`.

  Args:
    current_letter: a single character that is a key of `char_to_idx`.

  Returns:
    The character whose index has the largest model output.

  Raises:
    KeyError: if `current_letter` is not a key of `char_to_idx`.
  """
  letter_idx = char_to_idx[current_letter]

  # Inference only: no gradients are needed.
  with torch.no_grad():
    # A batch holding one sample with one feature (the letter's index as a float).
    features = torch.tensor([[letter_idx]]).float().to(device)
    logits, _ = model(features)

  best_idx = logits.argmax().item()
  return idx_to_char[best_idx]

In [51]:
# Spot-check the model on a handful of letters from both cases.
test_sequence = list('afmrzAJKZ')

print("\nPrediction examples:")
for letter in test_sequence:
  print(f"Current letter: {letter}, Predicted next letter: {predict_next_letter(letter)}")


Prediction examples:
Current letter: a, Predicted next letter: b
Current letter: f, Predicted next letter: g
Current letter: m, Predicted next letter: n
Current letter: r, Predicted next letter: s
Current letter: z, Predicted next letter: A
Current letter: A, Predicted next letter: B
Current letter: J, Predicted next letter: K
Current letter: K, Predicted next letter: L
Current letter: Z, Predicted next letter: Z


In [52]:
# Greedy roll-out: start from 'a' and repeatedly feed the model its own
# prediction, printing the text generated so far after every step.
current_letter = 'a'
generated_text = current_letter

# Starting from the first letter, len(alphabet) - 1 steps cover every
# transition the model was trained on. One more step would feed the last
# letter, which has no successor in the training pairs.
for _ in range(len(alphabet) - 1):
  # Reuse the shared inference helper (it already disables gradient tracking).
  current_letter = predict_next_letter(current_letter)
  generated_text += current_letter
  print(generated_text)

ab
abc
abcd
abcde
abcdef
abcdefg
abcdefgh
abcdefghi
abcdefghij
abcdefghijk
abcdefghijkl
abcdefghijklm
abcdefghijklmn
abcdefghijklmno
abcdefghijklmnop
abcdefghijklmnopq
abcdefghijklmnopqr
abcdefghijklmnopqrs
abcdefghijklmnopqrst
abcdefghijklmnopqrstu
abcdefghijklmnopqrstuv
abcdefghijklmnopqrstuvw
abcdefghijklmnopqrstuvwx
abcdefghijklmnopqrstuvwxy
abcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzA
abcdefghijklmnopqrstuvwxyzAB
abcdefghijklmnopqrstuvwxyzABC
abcdefghijklmnopqrstuvwxyzABCD
abcdefghijklmnopqrstuvwxyzABCDE
abcdefghijklmnopqrstuvwxyzABCDEF
abcdefghijklmnopqrstuvwxyzABCDEFG
abcdefghijklmnopqrstuvwxyzABCDEFGH
abcdefghijklmnopqrstuvwxyzABCDEFGHI
abcdefghijklmnopqrstuvwxyzABCDEFGHIJ
abcdefghijklmnopqrstuvwxyzABCDEFGHIJK
abcdefghijklmnopqrstuvwxyzABCDEFGHIJKL
abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLM
abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMN
abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNO
abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOP
abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQ
abcdefghijklm