In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import string
import random
import sys
import unidecode
from tqdm.auto import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

all_characters = string.printable
n_characters = len(all_characters)

file = unidecode.unidecode(open('shakespeare_larger.txt').read())

In [3]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        out = self.embed(x)
        out, (hidden, cell) = self.rnn(out.unsqueeze(1), (hidden, cell))
        out = self.fc(out.reshape(out.shape[0], -1))

        return out, hidden, cell
    
    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        cell = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)

        return hidden, cell
    
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.chunk_len = 250
        self.num_epochs = 10000
        self.batch_size = 1
        self.print_every = 50
        self.hidden_size = 256
        self.num_layers = 2
        self.learning_rate = 0.003

    def char_tensor(self, string):
        tensor = torch.zeros(len(string)).long()
        for c in range(len(string)):
            tensor[c] = all_characters.index(string[c])

        return tensor

    def get_random_batch(self):
        start_idx = random.randint(0, len(file) - self.chunk_len)
        end_idx = start_idx + self.chunk_len + 1
        text_string = file[start_idx:end_idx]
        text_input = torch.zeros(self.batch_size, self.chunk_len)
        text_target = torch.zeros(self.batch_size, self.chunk_len)

        for i in range(self.batch_size):
            text_input[i, :] = self.char_tensor(text_string[:-1])
            text_target[i, :] = self.char_tensor(text_string[1:])

        return text_input.long(), text_target.long()

    def generate(self, initial_string='A', prediction_len=100, temperature=0.85):
        hidden, cell = self.rnn.init_hidden(self.batch_size) #########
        initial_input = self.char_tensor(initial_string)
        predicted = initial_string

        for p in range(len(initial_string) - 1):
            out, hidden, cell = self.rnn(initial_input[p].view(1).to(device), hidden, cell) #######

        last_char = initial_input[-1]

        for p in range(prediction_len):
            out, hidden, cell = self.rnn(last_char.view(1).to(device), hidden, cell) #########
            output_dist = out.data.view(-1).div(temperature).exp()
            top_char = torch.multinomial(output_dist, 1)[0]
            predicted_char = all_characters[top_char]
            predicted += predicted_char
            last_char = self.char_tensor(predicted_char)

        return predicted

    def train(self):
        self.rnn = RNN(n_characters, self.hidden_size, self.num_layers, n_characters).to(device)
        optimizer = torch.optim.Adam(self.rnn.parameters(), self.learning_rate)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in tqdm(range(self.num_epochs)):
            input, target = self.get_random_batch()
            input, target = input.to(device), target.to(device)

            hidden, cell = self.rnn.init_hidden(self.batch_size)

            optimizer.zero_grad()
            loss = 0

            for c in range(self.chunk_len):
                output, hidden, cell = self.rnn(input[:, c], hidden, cell)
                loss += loss_fn(output, target[:, c])

            loss.backward()
            optimizer.step()
            loss = loss.item() / self.chunk_len

            if epoch % self.print_every == 0:
                print(f"loss: {loss} \n")

In [6]:
gen = Generator().to(device)

In [8]:
gen.load_state_dict(torch.load('models/model.pth'))

RuntimeError: Error(s) in loading state_dict for Generator:
	Unexpected key(s) in state_dict: "rnn.embed.weight", "rnn.rnn.weight_ih_l0", "rnn.rnn.weight_hh_l0", "rnn.rnn.bias_ih_l0", "rnn.rnn.bias_hh_l0", "rnn.rnn.weight_ih_l1", "rnn.rnn.weight_hh_l1", "rnn.rnn.bias_ih_l1", "rnn.rnn.bias_hh_l1", "rnn.fc.weight", "rnn.fc.bias". 

In [9]:
gen.train()

  0%|          | 0/10000 [00:00<?, ?it/s]

loss: 4.61071337890625 

loss: 2.510514892578125 

loss: 2.3075009765625 

loss: 2.18582763671875 

loss: 2.241069091796875 

loss: 1.9728653564453125 

loss: 2.106625732421875 

loss: 1.97143359375 

loss: 2.166918212890625 

loss: 1.9970152587890626 

loss: 2.058250244140625 

loss: 1.81194970703125 

loss: 1.98783642578125 

loss: 1.925274658203125 

loss: 2.01877978515625 

loss: 1.816458984375 

loss: 1.693320068359375 

loss: 1.7893072509765624 

loss: 1.8161129150390625 

loss: 1.92575537109375 

loss: 1.9623277587890624 

loss: 1.9824852294921875 

loss: 1.7685836181640624 

loss: 1.5866041259765624 

loss: 1.703200927734375 

loss: 1.86522509765625 

loss: 1.8189678955078126 

loss: 1.6810771484375 

loss: 1.7042664794921876 

loss: 1.789553466796875 

loss: 1.537436279296875 

loss: 1.9293896484375 

loss: 2.051091064453125 

loss: 1.65183349609375 

loss: 1.7150777587890624 

loss: 1.78080615234375 

loss: 1.679027099609375 

loss: 1.7594903564453126 

loss: 1.4643525390625 

In [15]:
gen.generate('gau')

'gaure:\nAnd where he life worship on her banmory\nTo it awe\nAs thou fear, night well; a for oness speech\n'

In [19]:
from pathlib import Path

In [49]:
model_path = Path('models')
model_path.mkdir(parents=True, exist_ok=True)

model_name = 'model.pth'
model_path_save = model_path/model_name

torch.save(obj=gen.state_dict(), f=model_path_save)