In [None]:
from torch import nn, optim
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np

In [None]:
# Select the compute device once. Falling back to the CPU keeps `device`
# defined on machines without CUDA, so later cells that read it do not
# fail with NameError.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('CUDA')
else:
    device = torch.device("cpu")

In [None]:
# Load the training corpus and lowercase it. The context manager closes
# the file handle as soon as the text has been read.
path = '/kaggle/input/tales-dataset/train.txt'
with open(path, 'r', encoding='utf-8') as corpus_file:
    input_text = corpus_file.read().lower()

In [None]:
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

    Args:
        dataset: object exposing ``uniq_words``; its length is the vocabulary size.
        lstm_size: hidden size of each LSTM layer (default 300).
        embedding_dim: size of the word embeddings (default 200).
        num_layers: number of stacked LSTM layers (default 3).
        dropout: dropout applied between LSTM layers (default 0.2).

    The LSTM is created without ``batch_first``, so with the ``nn.LSTM``
    defaults dim 0 of the input is the time axis and dim 1 is the batch axis.
    NOTE(review): callers that pass ``(batch, seq_len)`` tensors make the LSTM
    iterate over the batch axis -- confirm this is intended.
    """
    def __init__(self, dataset, lstm_size=300, embedding_dim=200, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(num_embeddings=n_vocab, embedding_dim=self.embedding_dim)
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits, state)`` for index tensor ``x`` and the previous ``(h, c)`` state.

        ``logits`` has the shape of ``x`` plus a trailing vocabulary axis.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` pair shaped ``(num_layers, sequence_length, lstm_size)``.

        ``sequence_length`` sizes the middle state axis, so it must equal
        ``x.shape[1]`` of the tensors later passed to ``forward``. The tensors
        are created with the model's device and dtype, so they can be fed to
        ``forward`` without an extra transfer.
        """
        reference = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset over a space-tokenized text.

    Item ``i`` is ``(words[i:i+seq_len], words[i+1:i+seq_len+1])`` encoded as
    vocabulary indices, i.e. the target is the input shifted by one token.

    Tokenization is ``text.split(' ')``: only single spaces separate tokens, so
    newlines stay attached to neighbouring words and consecutive spaces produce
    empty-string tokens. Indices are assigned by descending word frequency.

    Note: this class reuses the name ``Dataset`` that the notebook also imports
    from ``torch.utils.data``; the base class is referenced by its full path.
    """
    def __init__(self, text, seq_len):
        self.seq_len = seq_len
        self.text = text
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Split the raw text into tokens on single spaces."""
        return self.text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens ordered from most to least frequent."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Every window needs seq_len inputs plus one extra token for the
        # shifted target. Clamp at zero so a text shorter than that gives an
        # empty dataset instead of a negative length (which makes len() raise).
        return max(0, len(self.words_indexes) - self.seq_len)

    def __getitem__(self, index):
        """Return the ``(input, target)`` index tensors for the window starting at ``index``."""
        return (
            torch.tensor(self.words_indexes[index:index+self.seq_len]),
            torch.tensor(self.words_indexes[index+1:index+self.seq_len+1]),
        )

In [None]:
seq_len = 5
max_epochs = 15
def train(dataset, model, device=None):
    """Train ``model`` on ``dataset`` for ``max_epochs`` epochs with Adam and cross-entropy.

    Args:
        dataset: map-style dataset yielding ``(input, target)`` index tensors.
            Its ``seq_len`` attribute, when present, sizes the initial LSTM
            state; otherwise the module-level ``seq_len`` is used.
        model: module exposing ``init_state(n)`` and
            ``forward(x, (h, c)) -> (logits, (h, c))``.
        device: where batches and the recurrent state are placed. Defaults to
            the device of the model's parameters, so the function does not
            depend on a module-level ``device`` existing.

    Raises:
        ValueError: if the dataset yields no batches.

    Prints one dict per epoch with the index and loss of the last batch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    dataloader = DataLoader(dataset, batch_size=2048)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    state_width = getattr(dataset, "seq_len", seq_len)
    for epoch in range(max_epochs):
        # Fresh zero state each epoch, moved to the target device once.
        state_h, state_c = model.init_state(state_width)
        state_h, state_c = state_h.to(device), state_c.to(device)
        batch, loss = None, None
        for batch, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class axis at dim 1.
            loss = criterion(y_pred.transpose(1, 2), y)
            # Carry the state values into the next batch without their graph.
            state_h = state_h.detach()
            state_c = state_c.detach()
            loss.backward()
            optimizer.step()
        if loss is None:
            raise ValueError("dataset produced no batches; nothing to train on")
        print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [None]:
def predict(dataset, model, text, next_words=300, temperature=1.0):
    """Sample ``next_words`` additional words that continue ``text``.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``.
        model: module exposing ``init_state(n)`` and
            ``forward(x, (h, c)) -> (logits, (h, c))``.
        text: space-separated prompt; every word must be in the vocabulary.
        next_words: number of words to generate.
        temperature: divisor applied to the logits before the softmax; values
            below 1 sharpen the distribution, values above 1 flatten it.

    Returns:
        The prompt followed by the generated words, joined with single spaces.

    Raises:
        ValueError: if ``temperature`` is not positive.
        KeyError: if the prompt contains words missing from the vocabulary.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got %r" % (temperature,))
    model.eval()
    # Run on whatever device the model lives on; fall back to CPU for a
    # parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError("prompt words not in vocabulary: %r" % (unknown,))
    state_h, state_c = model.init_state(len(words))
    state_h, state_c = state_h.to(device), state_c.to(device)
    # Inference only: without no_grad the carried state would accumulate an
    # autograd graph across all generated steps.
    with torch.no_grad():
        for i in range(0, next_words):
            # The window always holds the most recent len(prompt) words.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]], device=device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits / temperature, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])
    # words began as text.split(' '), so joining reproduces the prompt exactly.
    return ' '.join(words)

In [None]:
# Build the vocabulary and windows from the corpus, create the model on the
# selected device (Module.to returns the module itself), then run training.
dataset = Dataset(input_text, seq_len)
model = Model(dataset).to(device)
train(dataset, model)

In [None]:
def save_checkpoint(path, model):
    """Write the model's ``state_dict`` to ``path`` (wrapped in a dict under the key ``'state_dict'``) and print the destination."""
    torch.save({'state_dict': model.state_dict()}, path)
    print('model saved to %s' % path)

In [None]:
# Persist the weights from the training run above.
# NOTE(review): only one train() call (max_epochs = 15) precedes this cell in
# the visible notebook, while the filename says 30 epochs -- confirm the name.
save_checkpoint(path="LSTM30epoch.pth", model=model)

In [None]:
# Move the model to the CPU (Module.to returns the same module) and sample a
# continuation at temperature 1.0.
print(predict(dataset, model.to('cpu'), text='once upon a', temperature=1.0))

In [None]:
# Same prompt at a lower temperature.
print(
    predict(dataset=dataset, model=model, text='once upon a', temperature=0.8)
)

In [None]:
# Put the model back on the training device and run another max_epochs epochs
# on top of the weights it already has.
train(dataset, model.to(device))

In [None]:
# Sample again after the additional training, on the CPU, at temperature 1.0.
print(predict(dataset, model.to('cpu'), text='once upon a', temperature=1.0))

In [None]:
# Same prompt at temperature 0.8; the model is kept on the CPU for sampling.
print(predict(dataset, model.to('cpu'), text='once upon a', temperature=0.8))

In [None]:
# Different prompt; the literal end-of-text marker is turned into a line break
# before printing.
print(
    predict(dataset, model, text='the young man', temperature=0.8)
    .replace("<| end of text |>", "\n")
)