In [None]:
import torch
import torch.nn as nn
import numpy as np
from nltk import word_tokenize

<font size=6>
Preprocess the data.
</font>

In [None]:
# Load the corpus: one quote per line of the file, with the trailing '\n' removed.
# NOTE(review): no explicit encoding is passed to open(), so the platform default
# is used -- confirm the file's encoding and pin it if the notebook must be portable.
with open('./data/author_quotes.txt', 'rt', newline='\n') as f:
    quotes = [line.rstrip('\n') for line in f]

In [None]:
# Character vocabulary.
# sorted() makes the char -> id mapping deterministic: iterating a plain set of
# strings can yield a different order in every Python process, which would give
# different ids on each kernel restart and break any saved model weights.
# ' ' is added explicitly because it is used as the padding symbol below and
# must exist even if the corpus itself happens to contain no space.
unique_chars = sorted(set(' '.join(quotes)) | {' '})
# unique_chars[i] is the character with id i, so the list also serves as id -> char.
char2id = {ch: i for i, ch in enumerate(unique_chars)}
pad_value = char2id[' ']

In [None]:
# Character-level tokenization: each quote becomes a list of its characters with
# a single ' ' placed in front (the first input symbol from which the first real
# character of the quote is predicted during training).
char_tokens = [[' ', *q] for q in quotes]
# Length of the longest tokenized quote, leading ' ' included.
MAX_SENT_LEN = max(map(len, char_tokens))

In [None]:
# Vectorize: one row of character ids per quote, right-padded with the id of ' '
# up to MAX_SENT_LEN.
vect_data = torch.full((len(char_tokens), MAX_SENT_LEN), char2id[' '], dtype=torch.long)
for i, q in enumerate(char_tokens):
    # Write the whole row prefix in one slice assignment instead of setting
    # every element individually through (slow) per-item tensor indexing.
    vect_data[i, :len(q)] = torch.tensor([char2id[ch] for ch in q], dtype=torch.long)

<font size=6>
RNN model
</font>

In [None]:
class charRNN(nn.Module):
    """Character-level language model: embedding -> (multi-layer) RNN -> linear layer.

    Args:
        char_vocab_len: number of distinct character ids; size of the embedding
            table and of the output layer.
        emb_size: dimensionality of the character embeddings.
        num_layers: number of stacked RNN layers.
        hid_state_size: size of the RNN hidden state.
    """

    def __init__(self, char_vocab_len, emb_size=32, num_layers=3, hid_state_size=64):
        super().__init__()
        self.num_rnn_layers = num_layers
        self.hid_state_size = hid_state_size
        self.emb = nn.Embedding(char_vocab_len, emb_size)
        self.rnn = nn.RNN(emb_size, hid_state_size, num_layers=num_layers, batch_first=True)
        self.out = nn.Linear(hid_state_size, char_vocab_len)

    def forward(self, x, hidden=None):
        """Compute next-character scores for every position of ``x``.

        Args:
            x: integer tensor of character ids, batch_sz x seq_len.
            hidden: optional initial hidden state,
                num_layers x batch_sz x hid_state_size. A zero state is used
                when omitted.

        Returns:
            A tuple ``(pred, hidden)``: ``pred`` holds the output-layer scores
            as batch_sz x char_vocab_len x seq_len (class dimension second, the
            layout ``cross_entropy`` expects for sequence targets) and
            ``hidden`` is the final RNN hidden state.
        """
        x = self.emb(x)  # batch_sz x seq_len x emb_sz
        if hidden is None:
            # Create h0 on the same device/dtype as the activations; a plain
            # CPU float32 tensor breaks as soon as the model lives on the GPU
            # or is cast to another dtype.
            hidden = self.get_init_state(x.shape[0], device=x.device, dtype=x.dtype)
        features, hidden = self.rnn(x, hidden)  # batch_sz x seq_len x hid_state_size
        pred = self.out(features)  # batch_sz x seq_len x char_vocab_len
        return pred.permute(0, 2, 1), hidden

    def get_init_state(self, batch_size, device=None, dtype=None):
        """Return an all-zero initial hidden state for ``batch_size`` sequences.

        Args:
            batch_size: number of sequences in the batch.
            device: target device; defaults to the device of the embedding weights.
            dtype: target dtype; defaults to the dtype of the embedding weights.

        Returns:
            Zero tensor of shape num_layers x batch_size x hid_state_size.
        """
        ref = self.emb.weight
        return torch.zeros(
            self.num_rnn_layers,
            batch_size,
            self.hid_state_size,
            device=ref.device if device is None else device,
            dtype=ref.dtype if dtype is None else dtype,
        )

<font size=6>
Helper functions
</font>

In [None]:
def train(model, vectorized_input, lr=0.01, n_epochs=10, batch_size=128, max_num_batches=0):
    """Train ``model`` to predict the next character id at every position.

    Args:
        model: module whose call returns ``(pred, hidden)`` where ``pred`` has
            the class dimension second (batch x n_classes x seq_len).
        vectorized_input: 2-D integer tensor of character ids, one sequence per row.
        lr: learning rate for the Adam optimizer.
        n_epochs: number of passes over the data.
        batch_size: sequences per batch; incomplete final batches are dropped.
        max_num_batches: if non-zero, stop each epoch after this many batches.

    Returns:
        The same ``model`` object, left in training mode on the device that was
        used ('cuda' when available, otherwise 'cpu').
    """
    import time
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Inputs drop the last character (nothing follows it to predict); targets
    # drop the first one, so targets[t] is the character that follows inputs[t].
    dataset = torch.utils.data.TensorDataset(vectorized_input[:, :-1],
                                             vectorized_input[:, 1:])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    model.train()
    for epoch in range(n_epochs):
        t_beg = time.perf_counter()
        total_loss = 0.0
        nbatches = 0
        for batch_x, batch_y in dataloader:
            opt.zero_grad()
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred, _ = model(batch_x)
            loss = nn.functional.cross_entropy(pred, batch_y)
            loss.backward()
            opt.step()
            total_loss += loss.item()
            nbatches += 1
            if max_num_batches and nbatches == max_num_batches:
                break
        # With drop_last=True a dataset smaller than batch_size yields no
        # batches at all; report NaN instead of dividing by zero.
        mean_loss = total_loss / nbatches if nbatches else float('nan')
        print(f"epoch {epoch} loss {mean_loss}, epoch time {time.perf_counter()-t_beg}")
    return model

def predict_char(model, seed_seq, char2id, id2char):
    """Return the most likely next character after ``seed_seq``.

    Args:
        model: module whose call returns ``(pred, hidden)`` with ``pred`` of
            shape batch x n_classes x seq_len.
        seed_seq: non-empty string or sequence of characters, all present in
            ``char2id``.
        char2id: mapping from character to integer id.
        id2char: inverse lookup (list or dict) from integer id to character.

    Returns:
        The character whose score is highest at the last position.

    Raises:
        ValueError: if ``seed_seq`` is empty.
        KeyError: if a character of ``seed_seq`` is not in ``char2id``.
    """
    if len(seed_seq) == 0:
        raise ValueError("seed_seq must contain at least one character")
    # Build the input on the model's own device: after train() the model may
    # live on the GPU, and a CPU tensor would then fail inside the model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    vect_data = torch.tensor([[char2id[ch] for ch in seed_seq]], dtype=torch.long, device=device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        out, _ = model(vect_data)
    # Scores for the last position of the single batch item; convert the
    # winning index to a plain int so id2char may be a list or a dict.
    next_id = int(out[0, :, -1].argmax())
    return id2char[next_id]

def generate_sequence(model, char2id, id2char, seed_seq = "hi", seq_len = MAX_SENT_LEN):
    """Greedily extend ``seed_seq`` one character at a time up to ``seq_len`` characters.

    The model is switched to evaluation mode. At every step the whole text
    produced so far is handed to ``predict_char`` and its answer is appended.

    Args:
        model: model forwarded to ``predict_char``.
        char2id: mapping from character to integer id.
        id2char: lookup from integer id to character.
        seed_seq: initial text to continue.
        seq_len: total length of the result, seed included.

    Returns:
        The generated string, or ``seed_seq`` unchanged when it is already at
        least ``seq_len`` characters long.
    """
    model.eval()
    generated = list(seed_seq)
    n_missing = seq_len - len(seed_seq)
    if n_missing <= 0:
        return seed_seq

    while n_missing > 0:
        generated.append(predict_char(model, generated, char2id, id2char))
        n_missing -= 1
    return ''.join(generated)
    

In [None]:
# Build a single-layer model sized to the current vocabulary and restore
# previously saved weights, mapping all tensors onto the CPU.
# NOTE(review): the checkpoint presumably was trained with num_layers=1 and with
# the same char -> id assignment as the current char2id -- confirm, otherwise
# the loaded embedding/output rows refer to different characters.
# NOTE(review): torch.load relies on pickle; only load checkpoints from a
# trusted source.
ch_model = charRNN(len(char2id), num_layers=1)
ch_model.load_state_dict(torch.load('./models/RNN-2022-07-05.model', map_location=torch.device('cpu')))
# Alternative to loading: train on the vectorized corpus and save the weights.
# Note that the save path below differs from the path loaded above.
#ch_model = train(ch_model, vect_data, lr=0.001, n_epochs=10, max_num_batches=100)
#torch.save(ch_model.state_dict(), './models/char_rnn.model')

In [None]:
# Print a 100-character sample (default seed included). unique_chars is passed
# as the id -> char lookup because char2id was built by enumerating that list.
print(generate_sequence(ch_model, char2id, unique_chars, seq_len=100))