In [None]:
# Switch the kernel's working directory to the project folder.
# NOTE(review): this is an absolute, machine-specific path. The relative paths
# used later in this notebook ('dataset/...', 'models/rnn.pth') are resolved
# against it, so the notebook only runs as-is on a machine with this layout.
%cd /home/bap/hana/Basic-NLP-RNN/rnn/rnn

In [None]:
import numpy as np
import io
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.data import DataLoader, Dataset

In [None]:
class Config:
    """Dataset locations and hyperparameters for the notebook.

    Every setting is a plain class attribute and is read as ``Config.<name>``;
    the class is not meant to be instantiated.
    """

    # ---- Data files (relative to the current working directory) ----
    data_train_url = "dataset/shakespeare_train.txt"
    data_val_url = "dataset/shakespeare_valid.txt"

    # ---- Model architecture ----
    hidden_size = 512  # size of hidden state
    num_layers = 3     # num of layers in LSTM layer stack

    # ---- Sequence / batch shape ----
    seq_len = 100      # length of LSTM sequence
    n_seqs = 128       # presumably the number of sequences per batch -- TODO confirm
    n_steps = 100      # presumably the number of time steps per sequence -- TODO confirm

    # ---- Optimisation ----
    epochs = 100       # max number of epochs
    lr = 0.002         # learning rate
    clip = 5           # presumably the gradient-clipping threshold -- TODO confirm
    num_workers = 2    # presumably the DataLoader worker count -- TODO confirm

    # ---- Sampling / checkpointing ----
    op_seq_len = 200   # total num of characters in output test sequence
    load_chk = False   # load weights from save_path directory to continue training
    save_path = "models/rnn.pth"

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Character-level dataset yielding (input, target) index sequences.

    The whole text file is read into memory, every distinct character is
    assigned an integer id (ids follow the sorted order of the characters),
    and the text is encoded as a list of those ids. Item ``idx`` is a window
    of ``seq_len`` ids starting at ``idx`` together with the same window
    shifted one position to the right (the next-character targets).

    The base class is spelled ``torch.utils.data.Dataset`` explicitly because
    this class reuses the name ``Dataset``; inheriting from the bare name
    would make the class subclass its own previous version whenever the cell
    is re-executed.

    Attributes:
        file_path: Path of the text file that was read.
        seq_len: Number of characters in each input/target sequence.
        data: Full text of the file.
        chars: Sorted list of the distinct characters in ``data``.
        char_to_id: Mapping character -> integer id.
        id_to_char: Mapping integer id -> character.
        word_indexes: ``data`` encoded as a list of ids (one per character).
    """

    def __init__(self, file_path, seq_len = Config.seq_len):
        """Read ``file_path`` and build the vocabulary and the encoded text.

        Args:
            file_path: Path to a text file.
            seq_len: Length of each returned sequence.
        """
        self.file_path = file_path
        self.seq_len = seq_len
        # Context manager so the file handle is closed as soon as it is read.
        with open(file_path, 'r') as text_file:
            self.data = text_file.read()
        self.chars = sorted(set(self.data))
        self.char_to_id = {ch: i for i, ch in enumerate(self.chars)}
        self.id_to_char = {i: ch for i, ch in enumerate(self.chars)}
        # Encode the full text, not the vocabulary: iterating over
        # ``self.chars`` here would just produce [0, 1, ..., V-1].
        self.word_indexes = [self.char_to_id[ch] for ch in self.data]

    def __len__(self):
        """Return the number of (input, target) windows available.

        Each window needs ``seq_len + 1`` consecutive characters, so there are
        ``len(word_indexes) - seq_len`` of them; clamped at 0 because ``len()``
        must not return a negative number for texts shorter than ``seq_len``.
        """
        return max(0, len(self.word_indexes) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``idx``-th (input, target) pair as two integer tensors.

        Args:
            idx: Window start position; negative values count from the end.

        Returns:
            Tuple ``(inputs, targets)``, each a 1-D tensor of ``seq_len`` ids,
            where ``targets`` is ``inputs`` shifted one character ahead.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        # Out-of-range starts would otherwise yield silently truncated tensors
        # (and plain iteration over the dataset would never terminate).
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")
        return (
            torch.tensor(self.word_indexes[idx: idx + self.seq_len]),
            torch.tensor(self.word_indexes[idx + 1: idx + self.seq_len + 1])
        )

In [None]:
class RNN(nn.Module):
    """Embedding -> stacked LSTM -> linear projection back to vocabulary size.

    The embedding width equals ``vocab_size``, which is also the LSTM's input
    width. The LSTM is built without ``batch_first``, so it uses PyTorch's
    default input layout.
    NOTE(review): that default is time-major (seq, batch, feature) -- confirm
    that callers feed ``input_seq`` in that layout.
    """

    def __init__(self, vocab_size, hidden_size, num_layers):
        """Create the three layers.

        Args:
            vocab_size: Number of distinct token ids; also the embedding width
                and the size of the output layer.
            hidden_size: Width of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        embed_dim = vocab_size  # embedding vectors are as wide as the vocabulary
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, input_seq, hidden_state):
        """Run one forward pass.

        Args:
            input_seq: Tensor of token ids fed to the embedding layer.
            hidden_state: Initial LSTM state passed straight to ``nn.LSTM``.

        Returns:
            Tuple ``(output, (h, c))``: ``output`` is the linear layer applied
            to every LSTM output step; ``h`` and ``c`` are the final LSTM
            states, detached so gradients do not flow back through them.
        """
        embedded = self.embedding(input_seq)
        lstm_out, (h_n, c_n) = self.lstm(embedded, hidden_state)
        logits = self.fc(lstm_out)
        return logits, (h_n.detach(), c_n.detach())