In [1]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
# step 1 is to load the text data
with open('shakespeare.txt', 'r') as f:
  text = f.read()

In [4]:
# step 2 is to create a char-to-index mapping so that the model can understand our text
chars = sorted(list(set(text))) # unique chars in dataset
char2idx = {char: idx for idx, char in enumerate(chars)} # map chars to indicies
idx2char = {idx: char for idx, char in enumerate(chars)} # map indicies back to chars

# convert the entire text into a sequence or integers
data = [char2idx[char] for char in text]

print(f'Number of unique chars: {len(chars)}') # how many unique chars do we have
print(f'Example mapping: {char2idx}') # peek into the mapping

Number of unique chars: 91
Example mapping: {'\n': 0, ' ': 1, '!': 2, '"': 3, '#': 4, '$': 5, '%': 6, '&': 7, "'": 8, '(': 9, ')': 10, '*': 11, ',': 12, '-': 13, '.': 14, '/': 15, '0': 16, '1': 17, '2': 18, '3': 19, '4': 20, '5': 21, '6': 22, '7': 23, '8': 24, '9': 25, ':': 26, ';': 27, '<': 28, '>': 29, '?': 30, '@': 31, 'A': 32, 'B': 33, 'C': 34, 'D': 35, 'E': 36, 'F': 37, 'G': 38, 'H': 39, 'I': 40, 'J': 41, 'K': 42, 'L': 43, 'M': 44, 'N': 45, 'O': 46, 'P': 47, 'Q': 48, 'R': 49, 'S': 50, 'T': 51, 'U': 52, 'V': 53, 'W': 54, 'X': 55, 'Y': 56, 'Z': 57, '[': 58, ']': 59, '_': 60, '`': 61, 'a': 62, 'b': 63, 'c': 64, 'd': 65, 'e': 66, 'f': 67, 'g': 68, 'h': 69, 'i': 70, 'j': 71, 'k': 72, 'l': 73, 'm': 74, 'n': 75, 'o': 76, 'p': 77, 'q': 78, 'r': 79, 's': 80, 't': 81, 'u': 82, 'v': 83, 'w': 84, 'x': 85, 'y': 86, 'z': 87, '|': 88, '}': 89, '~': 90}


In [5]:
# step 3 is to define a RNN model for text generation
class RNN(nn.Module):
  def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
    super(RNN, self).__init__()
    self.embed = nn.Embedding(vocab_size, embed_size) # embedding layer to convert indices to vectors
    self.rnn = nn.RNN(embed_size, hidden_size, num_layers, batch_first=True) # vanilla RNN
    self.fc = nn.Linear(hidden_size, vocab_size) # output layer to predict the next character

  def forward(self, x, hidden):
    x = self.embed(x) # convert input to embeddings
    out, hidden = self.rnn(x, hidden) # pass through the RNN
    out = self.fc(out) # generate predictions
    return out, hidden


In [6]:
# step 4 is to make a helper function to create mini-batches of the data for training
seq_length = 100 # sequence length for training
def get_batches(data, batch_size):
  n_batches = len(data) // (batch_size * seq_length)
  data = data[:n_batches * batch_size * seq_length]
  x = np.array(data)
  y = np.roll(x, -1) # shift targets one position, like if you have an input of present character, output is the next character
  x = x.reshape((batch_size, -1))
  y = y.reshape((batch_size, -1))
  return x, y

In [7]:
batch_size = 64 # number of sequencies in a batch
x, y = get_batches(data, batch_size)
print(f'Input shape: {x.shape}, target shape: {y.shape}') # to show the data shapes

Input shape: (64, 85300), target shape: (64, 85300)


In [8]:
# step 5 is to train this RNN model
def train(model, data, epochs, batch_size, seq_length, vocab_size, lr=0.001):
  criterion = nn.CrossEntropyLoss() # loss function
  optimizer = torch.optim.Adam(model.parameters(), lr=lr) # optimizer

  for epoch in range(epochs):
    x, y = get_batches(data, batch_size) # get a batch of data
    hidden = None # initialize hidden state
    for i in range(0, x.shape[1], seq_length):
      inputs = torch.tensor(x[:, i:i+seq_length], dtype=torch.long)
      targets = torch.tensor(y[:, i:i+seq_length], dtype=torch.long) # the target outputs

      optimizer.zero_grad() # reset gradients

      # detach hidden state to prevent graph build up
      if hidden is not None:
        hidden = hidden.detach()

      # forward pass
      output, hidden = model(inputs, hidden)

      # compute loss
      loss = criterion(output.view(-1, vocab_size), targets.view(-1))

      # backprop and update
      loss.backward()
      optimizer.step()

    print(f'Epoch {epoch + 1}/{epochs}, loss: {loss.item():.5f}') # progress update



In [9]:
# hyperparameters for RNN
vocab_size = len(chars)
embed_size = 128
hidden_size = 256
num_layers = 2

In [11]:
# initialize and train the model
model = RNN(vocab_size, embed_size, hidden_size, num_layers)
train(model,data, epochs=10, batch_size=64, seq_length=100, vocab_size=vocab_size)

Epoch 1/10, loss: 1.44987
Epoch 2/10, loss: 1.35407
Epoch 3/10, loss: 1.31232
Epoch 4/10, loss: 1.28601
Epoch 5/10, loss: 1.26744
Epoch 6/10, loss: 1.25524
Epoch 7/10, loss: 1.24523
Epoch 8/10, loss: 1.23815
Epoch 9/10, loss: 1.23252
Epoch 10/10, loss: 1.22719


In [16]:
 # next generate text by the model
def generate_text(model, start_char, length, hidden):
  model.eval() # switch to evaluation mode
  input = torch.tensor([[char2idx[start_char]]], dtype=torch.long)
  generated = start_char # start the text with the start character

  for _ in range(length):
    output, hidden = model(input, hidden)
    prob = nn.functional.softmax(output.squeeze(), dim=-1).data # predict probabilities
    char_idx = torch.multinomial(prob, 1).item() # sample from probabilities
    generated += idx2char[char_idx] # append the character to the output
    input = torch.tensor([[char_idx]], dtype=torch.long) # update the input

  return generated

In [17]:
# Generate 1000 characters of text starting with 'B'
hidden = None # reset hidden state
print(generate_text(model, start_char='B', length=1000, hidden=hidden))

BOLAMANA, and HERMIONE and unthinks again,
    Speak it.
  BEROWNE. No, if I warrant you; an honest Suffolk; well!
  CONSTABLE. And so Gallingrable-on the moon-caked her monument,
    Speak "
  WARG HAMWEND MACHINE RICHARD, SALESS thereon, my does
    By hole wife, that no more to leave, and all the high modor? I'll use your degned devil!
  CAIUS. Stay, which you made but carried to Octiodwallow to th' ear so import whilst I know
    Had
    then; holly some cheticles trunk to the midderly in
    the bow of wine,
  And, in the Christian
    provese be advis'd sort, LIRION youtless the world from wedax'd
    The gross of worthless; so
    Could make it, look doth be here.

  I"-
  MESSENGER. Whitwer, thou can wrongs and milthal eye.
  HEBREY. What spoke this and false certain another!
  Fal. Display perfect:
    For that if I love him, you was,
    The gods, I faccutaties
    Romeligious and Lover, sir?
  SHELONS

  CLOWN. Gentle, Northumble,
    With the starbing of love, and most the 