In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
np.set_printoptions(suppress=True)
from tqdm import tqdm, trange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
torch.set_printoptions(sci_mode=False)
from torchvision import datasets, transforms

In [2]:
# attempt to autodetect device
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

using device: mps


In [3]:
# Custom dataset for Tiny Shakespeare
class TinyShakespeareDataset(Dataset):
    def __init__(self, text, seq_length):
        self.text = text
        self.seq_length = seq_length
        self.chars = sorted(list(set(text)))
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        self.data_size = len(self.text)

    def __len__(self):
        return max(0, self.data_size - self.seq_length)

    def __getitem__(self, index):
        x = [self.char_to_idx[c] for c in self.text[index:index+self.seq_length]]
        y = [self.char_to_idx[c] for c in self.text[index+1:index+self.seq_length+1]]
        return torch.tensor(x).float(), torch.tensor(y).float()

In [4]:
import requests
import os

# Download Tiny Shakespeare dataset
def download_tiny_shakespeare():
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    if not os.path.exists("data/tinyshakespeare.txt"):
        data = requests.get(url).text
        with open("data/tinyshakespeare.txt", "w") as f:
            f.write(data)
        
# Download the dataset
download_tiny_shakespeare()

# Read the dataset
with open("data/tinyshakespeare.txt", "r") as f:
    text = f.read()

# Set parameters
seq_length = 25
BS = 128

# Create dataset and split into train and test
dataset = TinyShakespeareDataset(text, seq_length)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create data loaders
loaders = {
    'train': DataLoader(train_dataset, batch_size=BS, shuffle=True),
    'test': DataLoader(test_dataset, batch_size=BS, shuffle=True),
}


In [5]:
input_size = 1
hidden_size = 64
output_size = len(dataset.chars)
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.sm = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden_state):
        x = F.relu(self.i2h(x))
        hidden_state = self.h2h(x)
        x = F.relu(x + hidden_state)
        x = F.relu(self.h2o(x))
        x = self.sm(x)
        return x, hidden_state
model = RNNModel(input_size, hidden_size, output_size)
model.to(device)

RNNModel(
  (i2h): Linear(in_features=1, out_features=64, bias=False)
  (h2h): Linear(in_features=64, out_features=64, bias=True)
  (h2o): Linear(in_features=64, out_features=65, bias=True)
  (sm): LogSoftmax(dim=1)
)

In [6]:
def generate_text(model: RNNModel, dataset: TinyShakespeareDataset, prediction_length: int = 100) -> str:
    """
    Generate text up to prediction_length characters
    This function requires the dataset as argument in order to properly
    generate the text and return the output as strings
    """
    model.eval()
    predicted = dataset.vector_to_string([random.randint(0, len(dataset.chars) -1)])
    hidden = model.init_zero_hidden()

    for i in range(prediction_length - 1):
        last_char = torch.Tensor([dataset.char_to_idx[predicted[-1]]])
        X, hidden = last_char.to(device), hidden.to(device)
        out, hidden = model(X, hidden)
        result = torch.multinomial(nn.functional.softmax(out, 1), 1).item()
        #result = out.argmax().item()
        predicted += dataset.idx_to_char[result]

    return predicted

In [7]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr = 0.001)

In [8]:
num_epochs = 2
train_losses = {}

for epoch in range(num_epochs):
    epoch_losses = list()
    for i, (X, Y) in enumerate(loaders['train']):
        if X.shape[0] != BS:
            continue
        hidden = torch.zeros(hidden_size, hidden_size, requires_grad=False).float()
        optim.zero_grad()
        X, Y, hidden = X.to(device), Y.to(device), hidden.to(device)
        loss = 0
        for c in range(X.shape[1]):
            out, hidden = model(X[:, c].reshape(X.shape[0],1), hidden)
            l = criterion(out, Y[:, c].long())
            loss += l
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3)
        optim.step()

        epoch_losses.append(loss.detach().item() / X.shape[1])
        if (i+1) % 50 == 0:
            print('Loss: {:.4f}'.format(loss.detach().item()))
    train_losses[epoch] = torch.tensor(epoch_losses).mean()
    print(f'=> epoch: {epoch + 1}, loss: {train_losses[epoch]}')
    #print(generate_text(model, data.dataset))

Loss: 101.4089
Loss: 100.6088
Loss: 100.5733
Loss: 100.4023
Loss: 100.4382
Loss: 100.2468
Loss: 100.0467
Loss: 100.5013
Loss: 100.4897
Loss: 100.7325
Loss: 100.3112
Loss: 100.6046
Loss: 99.9849
Loss: 100.0906
Loss: 100.6104
Loss: 100.2124
Loss: 100.1337
Loss: 99.4178
Loss: 99.7654
Loss: 99.6583
Loss: 99.4897
Loss: 100.1628
Loss: 100.0264
Loss: 99.8076
Loss: 100.1403
Loss: 100.0202
Loss: 99.8405
Loss: 100.2561
Loss: 99.9041
Loss: 99.7877
Loss: 100.0734
Loss: 100.2575
Loss: 99.7063
Loss: 99.8013
Loss: 99.0286
Loss: 99.7730
Loss: 100.5199
Loss: 99.9648
Loss: 99.4465
Loss: 99.8179
Loss: 99.7983
Loss: 99.8722
Loss: 100.4260
Loss: 99.3983
Loss: 99.8932
Loss: 99.8282
Loss: 99.7684
Loss: 99.7113
Loss: 99.3331
Loss: 99.8361
Loss: 100.3058
Loss: 99.7410
Loss: 99.3862
Loss: 99.2665
Loss: 99.8048
Loss: 100.1091
Loss: 99.9150
Loss: 99.2351
Loss: 99.7043
Loss: 99.7076
Loss: 99.9636
Loss: 99.3525
Loss: 100.0504
Loss: 100.0443
Loss: 99.6168
Loss: 100.1837
Loss: 99.8117
Loss: 100.1734
Loss: 100.1300
Lo