In [None]:
import urllib3
import collections
import re
# Download the Project Gutenberg plain-text edition of Shakespeare's complete works.
shakespeare = 'http://www.gutenberg.org/files/100/100-0.txt'
http = urllib3.PoolManager()
response = http.request('GET', shakespeare)
if response.status != 200:
    # Fail loudly instead of tokenising an HTML error page as if it were the corpus.
    raise RuntimeError('failed to download %s (HTTP %d)' % (shakespeare, response.status))
text = response.data.decode('utf-8')
# Collapse every run of non-letters into a single space and lower-case the rest,
# leaving a corpus made only of the letters a-z and the space character.
raw_dataset = ' '.join(re.sub('[^A-Za-z]+', ' ', text).lower().split())
print('number of characters: ', len(raw_dataset))
print(raw_dataset[0:70])

# sorted() makes the index <-> character mapping identical on every kernel start;
# iterating a bare set of strings depends on per-process hash randomisation.
idx_to_char = sorted(set(raw_dataset))
char_to_idx = {char: i for i, char in enumerate(idx_to_char)}
vocab_size = len(char_to_idx)
corpus_indices = [char_to_idx[char] for char in raw_dataset]
sample = corpus_indices[:20]
print('chars:', ''.join([idx_to_char[idx] for idx in sample]))
print('indices:', sample)
# Hold out the last 100000 characters for testing; everything before is training data.
train_indices = corpus_indices[:-100000]
test_indices = corpus_indices[-100000:]

number of characters:  5058009
the project gutenberg ebook of the complete works of william shakespea
chars: the project gutenber
indices: [12, 22, 0, 10, 17, 14, 21, 5, 0, 13, 12, 10, 16, 6, 12, 0, 26, 7, 0, 14]


In [None]:
# Show the index -> character table: position i holds the character encoded as index i.
print(idx_to_char)

['e', 'f', 'x', 'd', 'k', 'j', 'u', 'b', 'y', 'q', ' ', 's', 't', 'c', 'r', 'v', 'g', 'p', 'w', 'a', 'i', 'o', 'h', 'l', 'z', 'm', 'n']


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import time

In [None]:
class RNN(nn.Module):
    """Character-level language model: one-hot input -> nn.RNN -> linear decoder.

    Args:
        vocab_size: number of distinct symbols; this is the width of the
            one-hot vectors fed to the recurrent layer.
        hidden_size: width of the recurrent hidden state (default 256).
        output_size: number of logits produced per time step; defaults to
            ``vocab_size`` when omitted.
    """

    def __init__(self, vocab_size, hidden_size=256, output_size=None):
        super(RNN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.output_size = vocab_size if output_size is None else output_size
        self.rnn = nn.RNN(vocab_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, state=None):
        """Compute next-symbol logits for every position of every sequence.

        Args:
            inputs: integer tensor of symbol indices with shape (batch, steps).
            state: optional initial hidden state passed straight to ``nn.RNN``;
                ``None`` lets the layer start from zeros.

        Returns:
            Float tensor of shape (steps * batch, output_size). Rows are in
            time-major order: all batch entries of step 0, then step 1, ...
        """
        # nn.RNN is time-major by default, so swap to (steps, batch) first.
        time_major = torch.transpose(inputs, 0, 1).long()
        # one_hot yields int64; the recurrent layer needs the dtype of its weights.
        X = F.one_hot(time_major, self.vocab_size).to(self.linear.weight.dtype)
        Y, _ = self.rnn(X, state)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output


In [None]:
def initialize_model(m):
    """Xavier-initialise weights and zero the biases of RNN and Linear layers.

    Works both as a callback for ``module.apply(initialize_model)`` and when
    called directly on a container: every ``nn.RNN`` / ``nn.Linear`` found in
    ``m.modules()`` (which includes ``m`` itself) is initialised. Other module
    types are left untouched.

    Args:
        m: any ``nn.Module``.
    """
    for module in m.modules():
        if isinstance(module, nn.RNN):
            # Iterate the registered parameters instead of hard-coding the
            # ``*_l0`` names: this covers stacked/bidirectional layers and
            # layers built with bias=False (which have no bias_* tensors).
            for name, param in module.named_parameters(recurse=False):
                if name.startswith('weight'):
                    nn.init.xavier_uniform_(param)
                elif name.startswith('bias'):
                    nn.init.constant_(param, 0)
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)


In [None]:
def predict_rnn(prefix, num_chars, model, device, vocab_size, idx_to_char, char_to_idx):
    """Greedily extend ``prefix`` by ``num_chars`` characters using ``model``.

    The model is called as ``model(X, None)`` with ``X`` an integer tensor of
    shape (1, steps) and must return one row of logits per step; the last row
    is taken as the prediction for the next character. Because the model does
    not hand its hidden state back, the whole sequence generated so far is
    re-fed at every step.

    Args:
        prefix: non-empty seed string; every character must be in ``char_to_idx``.
        num_chars: number of characters to generate after the prefix.
        model: callable as described above.
        device: device on which the input tensors are created.
        vocab_size: unused here; kept for interface compatibility.
        idx_to_char: sequence mapping an index to its character.
        char_to_idx: mapping from a character to its index.

    Returns:
        ``prefix`` followed by the ``num_chars`` generated characters.

    Raises:
        ValueError: if ``prefix`` is empty.
        KeyError: if ``prefix`` contains a character missing from ``char_to_idx``.
    """
    if not prefix:
        raise ValueError('prefix must contain at least one character')
    output = [char_to_idx[char] for char in prefix]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(num_chars):
            # Shape (1, steps): one sequence holding everything produced so far.
            X = torch.tensor([output], device=device)
            Y = model(X, None)
            # With a batch of one, the last row is the final time step.
            output.append(int(Y[-1].argmax().item()))
    return ''.join([idx_to_char[i] for i in output])


In [None]:
def train(model, data, epochs, tbptt_step, pred_period=1, prefixes=(),
          lr=0.1, clip_norm=1.0, pred_len=50, idx_to_char=None, char_to_idx=None):
    """Train ``model`` with SGD on batches of next-character prediction.

    Args:
        model: module called as ``model(x, None)`` that returns logits of shape
            (steps * batch, vocab) in time-major row order; it is expected to
            encode the integer indices itself.
        data: re-iterable of ``(x, y)`` integer tensor pairs, ``x`` of shape
            (batch, steps). A 2-D ``y`` is assumed to have the same layout as
            ``x`` and is flattened in time-major order to line up with the
            logits; a 1-D ``y`` is used as-is.
        epochs: number of passes over ``data``.
        tbptt_step: number of consecutive batches whose gradients are
            accumulated before each optimizer step (must be >= 1).
        pred_period: print perplexity and sample text every this many epochs.
        prefixes: seed strings passed to ``predict_rnn`` when reporting.
        lr: SGD learning rate.
        clip_norm: maximum total gradient norm; ``None`` disables clipping.
        pred_len: number of characters generated per prefix.
        idx_to_char, char_to_idx: vocabulary tables for ``predict_rnn``; when
            omitted, the notebook-level variables of the same names are used.

    Raises:
        ValueError: if ``tbptt_step < 1``, or if ``prefixes`` is non-empty and
            no vocabulary tables can be found.
    """
    if tbptt_step < 1:
        raise ValueError('tbptt_step must be >= 1')
    if prefixes:
        if idx_to_char is None:
            idx_to_char = globals().get('idx_to_char')
        if char_to_idx is None:
            char_to_idx = globals().get('char_to_idx')
        if idx_to_char is None or char_to_idx is None:
            raise ValueError('idx_to_char and char_to_idx are required when prefixes are given')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    def apply_update():
        # Clip first: plain RNNs are prone to exploding gradients.
        if clip_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()
        optimizer.zero_grad()

    for epoch in range(epochs):
        l_sum, n, start = 0.0, 0, time.time()
        optimizer.zero_grad()
        pending = 0
        for x, y in data:
            x, y = x.to(device), y.to(device)
            outputs = model(x, None)
            # The logits are time-major, so the targets must be flattened the same way.
            targets = y.transpose(0, 1).reshape(-1) if y.dim() > 1 else y
            loss = criterion(outputs, targets)
            # Scale so the accumulated gradient is the mean over the group.
            (loss / tbptt_step).backward()
            pending += 1
            if pending == tbptt_step:
                apply_update()
                pending = 0
            # criterion returns a per-target mean; weight it to get a true average.
            l_sum += loss.item() * targets.numel()
            n += targets.numel()
        if pending:
            # Flush a trailing partial group (its gradient is scaled slightly low).
            apply_update()
        if n and (epoch + 1) % pred_period == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn(prefix, pred_len, model, device,
                                        len(idx_to_char), idx_to_char, char_to_idx))

In [None]:
# Training configuration.
batch_size = 32      # sequences per batch
num_steps = 64       # characters per training sequence
num_epochs = 1
tbptt_step = 1       # batches between optimizer updates
pred_period = 1      # report every epoch
prefixes = ['the king', 'to be or not']  # seeds: lower-case letters and spaces only

# Cut the training indices into consecutive windows of num_steps characters;
# each target window is its input window shifted one character to the right.
corpus = torch.tensor(train_indices, dtype=torch.long)
num_windows = (len(corpus) - 1) // num_steps
inputs = corpus[:num_windows * num_steps].reshape(num_windows, num_steps)
targets = corpus[1:num_windows * num_steps + 1].reshape(num_windows, num_steps)
# A list (not a generator) so the batches can be iterated again every epoch.
train_batches = list(zip(inputs.split(batch_size), targets.split(batch_size)))

net = RNN(vocab_size)
# apply() visits every submodule, so the inner nn.RNN and nn.Linear are initialised.
net.apply(initialize_model)
train(net, train_batches, num_epochs, tbptt_step, pred_period, prefixes)