In [9]:
from tqdm import tqdm
import numpy as np
import os
import torch
import urllib.request
import string

## Download the data

The best place to access books that are no longer under Copyright is [Project Gutenberg](https://www.gutenberg.org/). Today we recommend using [Alice’s Adventures in Wonderland by Lewis Carroll](https://www.gutenberg.org/files/11/11-0.txt) for consistency. Of course you can experiment with other books as well.

In [10]:
data_url = 'https://www.gutenberg.org/files/219/219-0.txt'
fname = 'heart_of_darkness.txt'

if fname not in os.listdir():
    urllib.request.urlretrieve(data_url, fname)

## Load data and create character to integer mappings

- Open the text file, read the data then convert it to lowercase letters.
- Map each character to a respective number. Keep 2 dictionaries in order to have more easily access to the mappings both ways around.
- Transform the data from a list of characters to a list of integers

In [11]:
# Load data
with open(fname, 'r') as f:
    data = f.read()

# Preprocess data
table = str.maketrans('\n', ' ')
data = list(data.lower().translate(table))

# Build char-to-int and int-to-char dictionaries
c2i = {x: i for i, x in enumerate(set(data))}
i2c = {i: x for x, i in c2i.items()}

# Transform the data from chars to integers
data = [c2i[c] for c in data]
data[:10], [i2c[i] for i in data][:10]

([22, 43, 41, 13, 25, 2, 39, 45, 29, 13],
 ['\ufeff', 't', 'h', 'e', ' ', 'p', 'r', 'o', 'j', 'e'])

In [12]:
list(c2i.items())[:10], list(i2c.items())[:10]

([('’', 0),
  ('u', 1),
  ('p', 2),
  ('i', 3),
  ('l', 4),
  ('_', 5),
  (',', 6),
  ('4', 7),
  ('“', 8),
  ('z', 9)],
 [(0, '’'),
  (1, 'u'),
  (2, 'p'),
  (3, 'i'),
  (4, 'l'),
  (5, '_'),
  (6, ','),
  (7, '4'),
  (8, '“'),
  (9, 'z')])

## Define the datasets and dataloaders
- We are "thinking" in sequences of 100 characters: 99 characters in the input and 1 in the output.  
E.g. for the sequence *\['h', 'e', 'l', 'l'\]* as input, we will have *\['o'\]* as the expected output.
- Each pair (sample, label) from the training dataset will be composed from a sequence of 99 ints and a single integer label
- We will keep the first 85% sequences as training data and use the remaining for validation

In [13]:
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import typing as t
import string


# Define datasets
class SequenceDataset(data.Dataset):
    def __init__(self, data_url: str, fname: str, seq_len: int=99) -> None:
        super().__init__()

        # Useful props
        self.__data_url = data_url
        self.__seq_len = seq_len
        self.__fname = fname

        # Populated through loading
        self.c2i: t.Dict[str, int]
        self.i2c: t.Dict[int, str]
        self.char: bool = False

        # Load the data
        self.__data = self.__load()

    @property
    def units(self) -> int:
        return len(self.c2i)

    def seq_to_txt(self, seq: t.List[int]) -> str:
        return ''.join([self.i2c[i] for i in seq])

    def __getitem__(self, index: int):
        X = self.__data[index:index + self.__seq_len]
        y = self.__data[index + self.__seq_len]

        if not self.char:
            X = torch.tensor(X)
            y = torch.tensor(y)
            return X, y

        return self.seq_to_txt(X), self.i2c[y]

    def __len__(self) -> int:
        return max(0, len(self.__data) - self.__seq_len)

    def __load(self) -> t.List[int]:
        # Download it if does not exist
        if self.__fname not in os.listdir():
            urllib.request.urlretrieve(self.__data_url, self.__fname)

        # Load data
        with open(self.__fname, 'r') as f:
            data = f.read()

        # Preprocess data
        table = str.maketrans('\n', ' ')
        data = list(data.lower().translate(table))

        # Build char-to-int and int-to-char dictionaries
        self.c2i = {x: i for i, x in enumerate(set(data))}
        self.i2c = {i: x for x, i in self.c2i.items()}

        # Transform the data from chars to integers
        return [self.c2i[c] for c in data]

# Create datasets
dataset = SequenceDataset(
    data_url='https://www.gutenberg.org/cache/epub/65565/pg65565.txt',
    fname='Țara mea.txt'
)

# Split into Train & Validation
gen = torch.Generator('cpu')
train_d, valid_d  = data.random_split(dataset, [0.85, 0.15], generator=gen)

# Specify the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define dataloaders
batch_size = 128
train_dl = data.DataLoader(train_d, batch_size, shuffle=True, generator=gen, num_workers=8, prefetch_factor=2)
valid_dl = data.DataLoader(valid_d, batch_size, shuffle=True, generator=gen, num_workers=8, prefetch_factor=2)

## Define a model with
- An embedding layer with size 32
- Three LSTM layers with a hidden size of 256 and a dropout rate of 20%
- A final linear classification layer

In [14]:
import torch.nn as nn


class RNNModel(nn.Module):
    def __init__(self, num_embeddings: int):
        super().__init__()

        # From int to internal learnable embeddings
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim=32)

        # Define a RNN using three LSTM layers, applied one after another
        self.rnn = nn.LSTM(32, 256, 3, batch_first=True, dropout=0.2)

        # Apply a classifier on the final hidden state
        self.dense = nn.Linear(in_features=256, out_features=num_embeddings, bias=True)

    def forward(self, x: torch.Tensor):
        x = self.embeddings(x)
        o, (h, c) = self.rnn(x)
        x = self.dense(o[:, -1, :])
        return x

## Define the training loop and train the model to predict the next character in the sequence

In [15]:
from torch import Tensor


class Metrics(t.TypedDict):
    accuracy: t.List[float]
    loss: t.List[float]


class TrainHistory(t.TypedDict):
    train: Metrics
    valid: Metrics


def train_validate(model: nn.Module,
                   train_dl: DataLoader,
                   valid_dl: DataLoader,
                   epochs: int,
                   loss_fn: nn.Module,
                   optim: torch.optim.Optimizer) -> TrainHistory:
    # Track history
    history: TrainHistory = {
        'train': {
            'accuracy': [],
            'loss': [],
        },
        'valid': {
            'accuracy': [],
            'loss': [],
        }
    }

    # Do Training & Validation & Testing
    for epoch in range(epochs):
        print('Epoch [%d/%d]' % (epoch + 1, epochs), end=' - ')

        ### Training ###
        model.train(True)

        # Track across a single epoch
        train_loss = []
        train_accuracy = []

        for b, (X, y) in enumerate(train_dl):
            X, y = X.to(device), y.to(device)

            # Prevent grad accumulation
            optim.zero_grad()

            # Forward pass
            logits = model.forward(X)
            loss: Tensor = loss_fn(logits, y)
            y_pred: Tensor = logits.argmax(dim=1).detach()

            # Backward pass
            loss.backward()
            optim.step()

            # Track metrics
            train_loss.append(loss.detach().cpu().item())
            train_accuracy.extend((y_pred == y).detach().cpu().tolist())

        # Aggregate training results
        history['train']['loss'].append(
            torch.mean(torch.tensor(train_loss)).item())
        history['train']['accuracy'].append(
            (torch.sum(torch.tensor(train_accuracy)) / len(train_accuracy)).item())

        ### Validation ###
        model.train(False)

        # Track across a single epoch
        valid_loss = []
        valid_accuracy = []

        for b, (X, y) in enumerate(valid_dl):
            X, y = X.to(device), y.to(device)

            # Forward pass
            with torch.no_grad():
                logits = model.forward(X)
                loss: Tensor = loss_fn(logits, y)
                y_pred: Tensor = logits.argmax(dim=1)

            # Track metrics
            valid_loss.append(loss.detach().cpu().item())
            valid_accuracy.extend((y_pred == y).detach().cpu().tolist())

        # Aggregate training results
        history['valid']['loss'].append(
            torch.mean(torch.tensor(valid_loss)).item())
        history['valid']['accuracy'].append(
            (torch.sum(torch.tensor(valid_accuracy)) / len(valid_accuracy)).item())

        # Inform regarding current metrics
        print('t_loss: %f, t_acc: %f, v_loss: %f, v_acc: %f'
              % (history['train']['loss'][-1], history['train']['accuracy'][-1], history['valid']['loss'][-1], history['valid']['accuracy'][-1]))

    # Output the obtained results so far
    return history

In [16]:
from torch.optim import Adam


# Configure the training settings
model: RNNModel = RNNModel(num_embeddings=dataset.units)
model = model.to(device)
optim = Adam(model.parameters(), lr=8e-4)
loss_fn = nn.CrossEntropyLoss()
epochs = 100

# define the training loop and traing the model
model.requires_grad_(True)
train_validate(model, train_dl, valid_dl, epochs, loss_fn, optim)

Epoch [1/100] - 

t_loss: 2.429425, t_acc: 0.297249, v_loss: 2.089754, v_acc: 0.382244
Epoch [2/100] - t_loss: 1.978714, t_acc: 0.409957, v_loss: 1.848271, v_acc: 0.448214
Epoch [3/100] - t_loss: 1.788922, t_acc: 0.460497, v_loss: 1.690870, v_acc: 0.487920
Epoch [4/100] - t_loss: 1.668142, t_acc: 0.492647, v_loss: 1.614141, v_acc: 0.511128
Epoch [5/100] - t_loss: 1.589415, t_acc: 0.513011, v_loss: 1.557657, v_acc: 0.523108
Epoch [6/100] - t_loss: 1.531417, t_acc: 0.529143, v_loss: 1.522816, v_acc: 0.532011
Epoch [7/100] - t_loss: 1.483238, t_acc: 0.542023, v_loss: 1.498687, v_acc: 0.542175
Epoch [8/100] - t_loss: 1.445887, t_acc: 0.550992, v_loss: 1.475052, v_acc: 0.547933
Epoch [9/100] - t_loss: 1.414433, t_acc: 0.559077, v_loss: 1.458326, v_acc: 0.552273
Epoch [10/100] - t_loss: 1.386680, t_acc: 0.566889, v_loss: 1.444717, v_acc: 0.556060
Epoch [11/100] - 

KeyboardInterrupt: 

In [17]:
torch.save(model.state_dict(), 'weights.pt')

## Evaluate the model by generating text

- Start with 99 characters (potentially chosen from a text)
- Generate a new character using the trained network
- Repeat the process by appending the generated character and making a prediction for a new one

In [18]:
model = RNNModel(num_embeddings=dataset.units)
model.load_state_dict(torch.load('weights.pt'))
model = model.to(device)
model.requires_grad_(False)
model.train(False)

RNNModel(
  (embeddings): Embedding(78, 32)
  (rnn): LSTM(32, 256, num_layers=3, batch_first=True, dropout=0.2)
  (dense): Linear(in_features=256, out_features=78, bias=True)
)

In [24]:
# Generate Text
# 
start_pos = torch.randint(low=0, high=len(dataset), size=(1,), generator=gen)
start, _ = dataset[start_pos]

# Show info from dataset
dataset.char = True
text = t.cast(str, dataset[start_pos][0])
print(start_pos.item(), ':', text)
dataset.char = False

context: Tensor = t.cast(Tensor, start)
gen_count = 100
k = 25

for i in range(gen_count):
    # Infer new character
    logits: Tensor = model(context.to(device).unsqueeze(0)).cpu()

    # Perform Top-K Sampling
    top_k = torch.topk(logits, k=k, sorted=False)
    top_p = top_k.values.squeeze(0)
    top_p = torch.nn.functional.softmax(top_p, 0).cumsum(0)
    top_i = torch.searchsorted(top_p, torch.rand(1))

    # Obtain prediction index
    top_i = int(top_k.indices.squeeze(0)[top_i].item())

    # Shift context and use predicted value
    context = context.roll(-1)
    context[-1] = top_i
    text += dataset.i2c[top_i]

text

291579 : ld be clearly marked as such and sent to the project         gutenberg literary archive foundation 


'ld be clearly marked as such and sent to the project         gutenberg literary archive foundation edecher work. paragraph domp, your with the plote statos efpiand project gutenberg™. is was empoks, '

In [None]:
context, context.roll(-1)

(tensor([27, 36, 21, 27, 37, 23,  9, 27,  3,  9, 18, 13, 27,  9, 34, 34,  9,  8,
          1,  9, 27, 36, 21, 27, 24, 18,  9, 31, 25, 34, 22, 27, 27, 23,  9, 27,
          2, 31, 34, 27, 34, 15, 32,  9,  8, 37, 27, 21, 36, 18, 27, 31, 27,  2,
         23, 15, 32,  9, 27, 27,  4, 27,  8, 36, 27, 15, 37, 27, 15, 34, 27, 15,
         25, 20, 36, 34, 34, 15,  6, 32,  9, 27, 15, 37, 27, 15, 34, 27, 15, 25,
         20, 36, 34, 34, 15,  6, 32,  9, 27]),
 tensor([36, 21, 27, 37, 23,  9, 27,  3,  9, 18, 13, 27,  9, 34, 34,  9,  8,  1,
          9, 27, 36, 21, 27, 24, 18,  9, 31, 25, 34, 22, 27, 27, 23,  9, 27,  2,
         31, 34, 27, 34, 15, 32,  9,  8, 37, 27, 21, 36, 18, 27, 31, 27,  2, 23,
         15, 32,  9, 27, 27,  4, 27,  8, 36, 27, 15, 37, 27, 15, 34, 27, 15, 25,
         20, 36, 34, 34, 15,  6, 32,  9, 27, 15, 37, 27, 15, 34, 27, 15, 25, 20,
         36, 34, 34, 15,  6, 32,  9, 27, 27]))

In [None]:
top_k = torch.topk(torch.tensor([0.8, 0.0, 0.05, 0.15]), k=4).values.squeeze(0).cpu()
top_k = torch.nn.functional.softmax(top_k, 0).cumsum(0)
top_i = torch.searchsorted(top_k, torch.rand(1))
top_i

tensor([1])

In [None]:
top_k

tensor([0.1059, 0.1057, 0.1012, 0.0999, 0.0997, 0.0995, 0.0990, 0.0973, 0.0961,
        0.0958])