In [None]:
import torch

from tqdm import tqdm

In [None]:
# --- Model architecture -----------------------------------------------------
embedding_size = 100      # dimensionality of the token embeddings
sequence_length = 128     # number of context tokens fed to the model per sample
lstm_size = 256           # hidden size of each recurrent layer
bidirectional = True
n_layer = 2
dropout = 0.5
vocab_size = 50560

# --- Optimisation -----------------------------------------------------------
epochs = 10000
batch_size = 256
lr = 1e-3
num_workers = 1

# --- Checkpointing and data -------------------------------------------------
checkpoint_interval = 100  # a checkpoint is written every `checkpoint_interval` epochs
# The checkpoint file name is appended to `save_path` by plain string
# concatenation, so the value must end with a path separator.
save_path = "./"
train_data_path = "./data/train.txt"

# --- Reproducibility --------------------------------------------------------
# Seed torch's global RNG so weight initialisation, dropout and sampling are
# repeatable after "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

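### Define the dataset and model

The classes are defined inline, rather than imported from the project package, so that the notebook can be run on Colab/Kaggle.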

In [None]:
import torchdata.dataloader2 as dl2
import torchdata.datapipes as dp


class StoryDataset:
    """Sliding-window token dataset built on torchdata DataPipes.

    Every ``.txt`` file found under ``root`` is read in full and parsed as a
    comma-separated list of integer token ids. Each file is left-padded with
    ``sequence_size - 1`` copies of ``pad_idx`` and cut into overlapping
    windows of ``sequence_size + 1`` ids (stride 1). The windows are
    shuffled, sharded across workers and grouped into batches, so iterating
    the dataset yields lists of up to ``batch_size`` windows.

    Args:
        root: File or directory handed to ``FileLister`` (searched recursively).
        batch_size: Number of windows per yielded batch.
        num_workers: Worker processes used by the multiprocessing reading service.
        shuffle: Passed to the ``Shuffle`` adapter to switch shuffling on/off.
        drop_last: Whether the ``Batcher`` drops a trailing incomplete batch.
        sequence_size: Context length; every window holds ``sequence_size + 1`` ids.
        pad_idx: Token id used for the left padding.
    """

    def __init__(
            self,
            root,
            batch_size=1,
            num_workers=1,
            shuffle=True,
            drop_last=False,
            sequence_size=32,
            pad_idx=2,
    ):
        self.sequence_size = sequence_size
        self.pad_idx = pad_idx

        # NOTE(review): DataPipes / DataLoader2 appear to be deprecated in
        # recent torchdata releases -- confirm the installed version still
        # provides them.
        datapipe = dp.iter.FileLister(root, recursive=True).filter(
            filter_fn=self.filter_fn
        )
        datapipe = dp.iter.FileOpener(datapipe, mode="rt")
        datapipe = dp.iter.StreamReader(datapipe)
        datapipe = dp.iter.Mapper(datapipe, fn=self.map_fn)
        datapipe = (
            dp.iter.FlatMapper(datapipe, fn=self.batch_fn).shuffle().sharding_filter()
        )
        datapipe = dp.iter.Batcher(datapipe, batch_size=batch_size, drop_last=drop_last)

        self.dloader2 = dl2.DataLoader2(
            datapipe,
            reading_service=dl2.MultiProcessingReadingService(num_workers=num_workers),
            datapipe_adapter_fn=dl2.adapter.Shuffle(shuffle),
        )

    def __iter__(self):
        """Return a fresh iterator over the underlying ``DataLoader2``."""
        return iter(self.dloader2)

    def map_fn(self, x):
        """Parse one file into a left-padded list of token ids.

        Args:
            x: Indexable item whose element ``1`` is the file content as a
                string of comma-separated integers.

        Returns:
            ``sequence_size - 1`` pad ids followed by the parsed token ids.

        Raises:
            ValueError: If a non-blank field is not a valid integer.
        """
        # Skip blank fields (trailing comma, trailing newline, empty file);
        # ``int("")`` would otherwise raise and abort the whole pipeline.
        token_ids = [int(field) for field in x[1].split(",") if field.strip()]
        return (self.sequence_size - 1) * [self.pad_idx] + token_ids

    def batch_fn(self, x):
        """Cut a token list into every stride-1 window of ``sequence_size + 1`` ids.

        Returns an empty list when ``x`` has ``sequence_size`` items or fewer.
        """
        window = self.sequence_size + 1
        return [
            x[start: start + window]
            for start in range(0, len(x) - self.sequence_size)
        ]

    @staticmethod
    def filter_fn(name):
        """Keep only paths that end in ``.txt``."""
        return name.endswith(".txt")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# from ..utils.generation_utils import (
#     top_k_top_p_filtering,
#     greedy_search,
#     multinomial_sampling,
#     temperature_softmax,
# )


class LSTMTextGenerator(nn.Module):
    """Recurrent next-token predictor: embedding -> LSTM/GRU -> dropout -> linear.

    Args:
        vocab_size: Number of token ids; size of the embedding table and of
            the output layer.
        emb_size: Embedding dimensionality.
        lstm_size: Hidden size of each recurrent layer.
        lstm_layers: Number of stacked recurrent layers.
        lstm_bidirectional: Use a bidirectional recurrent network; doubles the
            feature size fed to the output layer.
        lstm_dropout: Dropout probability applied between stacked recurrent
            layers and before the output layer.
        use_gru: Build an ``nn.GRU`` instead of an ``nn.LSTM`` (the attribute
            is still called ``lstm``).
        pad_idx: Padding token id passed to the embedding as ``padding_idx``.
        seq_len: Maximum number of trailing tokens used as context by
            :meth:`generate`.
    """

    _STRATEGIES = ("greedy", "multinomial", "top_k_top_p")

    def __init__(
            self,
            vocab_size,
            emb_size,
            lstm_size,
            lstm_layers=1,
            lstm_bidirectional=False,
            lstm_dropout=0.2,
            use_gru=False,
            pad_idx=2,
            seq_len=20,
    ):
        super().__init__()

        self.pad_idx = pad_idx
        self.seq_len = seq_len

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=emb_size, padding_idx=pad_idx
        )

        rnn_cls = nn.GRU if use_gru else nn.LSTM
        self.lstm = rnn_cls(
            input_size=emb_size,
            hidden_size=lstm_size,
            num_layers=lstm_layers,
            batch_first=True,
            # Inter-layer dropout only exists for stacked RNNs; passing a
            # non-zero value with a single layer has no effect and only
            # triggers a UserWarning.
            dropout=lstm_dropout if lstm_layers > 1 else 0.0,
            bidirectional=lstm_bidirectional,
        )

        self.dropout = nn.Dropout(lstm_dropout)

        self.fc1 = nn.Linear(
            in_features=lstm_size * 2 if lstm_bidirectional else lstm_size,
            out_features=vocab_size,
        )

    def forward(self, inputs):
        """Return next-token logits for a batch of token-id sequences.

        Args:
            inputs: Integer tensor of shape ``(batch, seq)``.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` computed from the
            recurrent output at the last time step.
        """
        embedded = self.embedding(inputs)

        lstm_output, _ = self.lstm(embedded)

        # Only the final time step is used to predict the next token.
        dropped = self.dropout(lstm_output[:, -1, :])
        return self.fc1(dropped)

    @staticmethod
    def _top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("inf")):
        """Mask logits outside the top-k set and/or the top-p nucleus with ``filter_value``."""
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_value, filter_value)

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative > top_p
            # Shift right so the first token that crosses the threshold is
            # kept; the most likely token is therefore never removed.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            remove = torch.zeros_like(remove_sorted).scatter(
                -1, sorted_idx, remove_sorted
            )
            logits = logits.masked_fill(remove, filter_value)

        return logits

    def generate(
            self,
            start_text,
            length=100,
            temperature=1.0,
            strategy="top_k_top_p",
            top_k=0,
            top_p=1.0,
            n_samples=1,
    ):
        """Autoregressively extend a prompt of token ids.

        The model must be in eval mode. The caller's ``start_text`` is not
        modified; a new list is returned.

        Args:
            start_text: Non-empty sequence of integer token ids (the prompt).
            length: Maximum number of tokens to append.
            temperature: Positive divisor applied to the logits before sampling.
            strategy: ``"greedy"`` (argmax), ``"multinomial"`` (sample from the
                softmax) or ``"top_k_top_p"`` (filter, then sample).
            top_k: Keep only the ``top_k`` most likely tokens (``0`` disables).
            top_p: Keep the smallest set of tokens whose cumulative probability
                exceeds ``top_p`` (``1.0`` disables).
            n_samples: Number of samples drawn per step. The loop appends a
                single token per step, so only ``1`` is usable here.

        Returns:
            List with the prompt followed by the generated token ids.

        Raises:
            ValueError: On an unknown strategy, an empty prompt or a
                non-positive temperature.
        """
        assert not self.training
        if strategy not in self._STRATEGIES:
            raise ValueError(
                f"Unknown strategy {strategy!r}; expected one of {self._STRATEGIES}"
            )
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # Work on a copy so the caller's prompt list is left untouched.
        tokens = list(start_text)
        if not tokens:
            raise ValueError("start_text must contain at least one token id")

        # Build inputs on whatever device the model currently lives on.
        model_device = next(self.parameters()).device

        with torch.no_grad():
            for _ in range(length):
                inp = torch.tensor(
                    [tokens[-self.seq_len:]], dtype=torch.long, device=model_device
                )
                logits = self.forward(inp) / temperature

                if strategy == "greedy":
                    out = torch.argmax(logits, dim=-1)
                else:
                    if strategy == "top_k_top_p":
                        logits = self._top_k_top_p_filtering(
                            logits, top_k=top_k, top_p=top_p
                        )
                    probs = F.softmax(logits, dim=1)
                    out = torch.multinomial(probs, num_samples=n_samples)

                next_token = out.item()
                tokens.append(next_token)
                # Token id 1 stops generation -- presumably the
                # end-of-sequence id; confirm against the tokenizer.
                if next_token == 1:
                    break

        return tokens


# Initialize classes

In [None]:
# Build the generator from the configuration-cell hyper-parameters.
# pad_idx=2 equals the StoryDataset default, so the padding id used by the
# data pipeline and the embedding's padding_idx agree.
lstm = LSTMTextGenerator(
    vocab_size,
    embedding_size,
    lstm_size,
    lstm_layers=n_layer,
    lstm_bidirectional=bidirectional,
    lstm_dropout=dropout,
    seq_len=sequence_length,
    pad_idx=2,
)
# Module.to() moves the parameters in place; as the cell's last expression
# it also displays the model summary.
lstm.to(device)

In [None]:
# Sliding-window loader over the training tokens; the window length matches
# the model's context size (`sequence_length`).
data_loader = StoryDataset(
    root=train_data_path,
    sequence_size=sequence_length,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Loss over the next-token logits, and the optimiser for all model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.RAdam(params=lstm.parameters(), lr=lr)

# Training

In [None]:
# Upper bound on the global gradient norm; keeps recurrent gradients from
# exploding.
max_grad_norm = 0.5

lstm.train()

for epoch in tqdm(range(1, epochs + 1), desc="epochs"):
    running_loss = 0.0
    n_batches = 0

    # leave=False: do not keep one finished progress bar per epoch in the output.
    for batch in tqdm(data_loader, desc=f"epoch {epoch}", leave=False):
        batch_tensor = torch.tensor(batch, dtype=torch.long, device=device)

        # Each window holds `sequence_length` context tokens followed by the
        # token to predict.
        x_batch = batch_tensor[:, :sequence_length]
        y_batch = batch_tensor[:, sequence_length]

        optimizer.zero_grad()

        output = lstm(x_batch)
        loss = criterion(output, y_batch)
        loss.backward()

        nn.utils.clip_grad_norm_(lstm.parameters(), max_grad_norm)
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the per-batch mean rather than the raw sum, so the figure does
    # not scale with the number of batches and stays comparable across runs.
    mean_loss = running_loss / max(n_batches, 1)
    print(f"Epoch {epoch}, mean loss: {mean_loss:.4f}")

    if epoch % checkpoint_interval == 0:
        # The hyper-parameters are encoded in the file name so that
        # checkpoints from different configurations do not overwrite each other.
        save_name = (
            f"embedding_size_{embedding_size}_"
            f"sequence_length_{sequence_length}_"
            f"lstm_size_{lstm_size}_"
            f"bidirectional_{bidirectional}_"
            f"n_layer_{n_layer}_"
            f"dropout_{dropout}_"
            f"epoch_{epoch}_"
            f"class_{lstm.__class__.__name__}.pth"
        )
        torch.save(lstm.state_dict(), save_path + save_name)