In [1]:
from sklearn.model_selection import train_test_split
from pipeline import Pipeline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transforms import pad_token
import numpy as np
from tqdm import tqdm, tqdm_notebook

In [2]:
# Load the preprocessed corpus. The "10k_common" pipeline yields a 3-tuple;
# presumably the encoded scripts plus token->index and index->token lookups
# (TODO confirm against the pipeline definition).
common_pipeline = Pipeline.load("10k_common")
data, token_idx, idx_token = common_pipeline.data

# Regression targets, one per script (assumed aligned with `data` - verify).
ratings = Pipeline.load("ratings").data

In [3]:
class RNN(nn.Module):
    """Recurrent regressor mapping a padded batch of token ids to one scalar per sequence.

    Tokens are embedded, fed through an ``nn.RNN``, the per-timestep outputs
    are summed over the time axis and a linear layer projects that sum to a
    single value.

    Args:
        emb_size: dimensionality of the token embeddings.
        hidden_size: number of hidden units in each recurrent layer.
        num_layers: number of stacked recurrent layers.
        vocab_size: number of rows in the embedding table.
        pad_idx: token id used for padding; forwarded to the embedding as
            ``padding_idx`` so that its vector stays fixed at zero.
    """

    def __init__(self, emb_size, hidden_size, num_layers, vocab_size, pad_idx):
        super(RNN, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.rnn = nn.RNN(emb_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)

    def init_hidden(self, batch_size):
        """Return the hidden state used at timestep 0.

        The state is all zeros so that identical inputs always produce
        identical outputs (a random initial state makes evaluation noisy).
        It is allocated with the dtype and device of the model parameters.

        Args:
            batch_size: number of sequences in the batch.

        Returns:
            Tensor of shape ``(num_layers, batch_size, hidden_size)``.
        """
        return self.linear.weight.new_zeros(
            self.num_layers, batch_size, self.hidden_size
        )

    def forward(self, x, lengths):
        """Score a batch of padded sequences.

        Args:
            x: integer tensor of shape ``(batch, seq_len)`` holding token ids.
            lengths: true length of every row of ``x`` (tensor or sequence of
                ints), sorted in descending order as required by
                ``pack_padded_sequence`` with its default settings.

        Returns:
            Tensor of shape ``(batch, 1)`` with one prediction per sequence.
            The final hidden state is also stored on ``self.hidden``.
        """
        batch_size = x.size(0)

        # fresh hidden state for every batch
        initial_hidden = self.init_hidden(batch_size)

        # get embedding of tokens
        embed = self.embedding(x)

        # pack_padded_sequence needs the lengths on the CPU, whatever device
        # the inputs live on
        cpu_lengths = torch.as_tensor(lengths).cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embed, cpu_lengths, batch_first=True
        )

        # fprop through RNN
        rnn_out, self.hidden = self.rnn(packed, initial_hidden)

        # undo packing; padded timesteps come back as zeros
        rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)

        # sum hidden activations of RNN across time
        rnn_out = torch.sum(rnn_out, dim=1)

        logits = self.linear(rnn_out)
        return logits

class ScriptsDataset(Dataset):
    """Pairs variable-length token sequences with their labels.

    Args:
        data: indexable collection of token-id sequences; ``data[i]`` must
            work for ``i`` in ``range(len(data))``.
        labels: indexable collection of labels aligned with ``data``.
        pad_value: token id used to pad shorter sequences in a batch
            (default 0, the value ``pad_sequence`` uses by default).
    """

    def __init__(self, data, labels, pad_value = 0):
        self.data = data
        self.labels = labels
        self.pad_value = pad_value

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, key):
        """Return the ``(sequence, label)`` pair stored at ``key``."""
        return (self.data[key], self.labels[key])

    def get_loader(self, batch_size = 32, shuffle = True):
        """Build a ``DataLoader`` over this dataset that batches with ``collate``.

        Args:
            batch_size: number of examples per batch.
            shuffle: whether to reshuffle the examples every epoch; pass
                ``False`` for a deterministic order (e.g. validation).
        """
        return DataLoader(
            dataset = self,
            batch_size = batch_size,
            collate_fn = self.collate,
            shuffle = shuffle
        )

    def collate(self, batch):
        """Pad a list of ``(sequence, label)`` pairs into batch tensors.

        The batch is ordered by decreasing sequence length, which is what
        ``pack_padded_sequence`` expects by default.

        Returns:
            Tuple ``(tokens, lengths, labels)`` where ``tokens`` is a
            ``(batch, max_len)`` int64 tensor, ``lengths`` holds the original
            sequence lengths and ``labels`` the matching labels, all three in
            the same (length-descending) order.
        """
        sequences = []
        label_list = []
        length_list = []

        for tokens, label in batch:
            label_list.append(label)
            length_list.append(len(tokens))
            # explicit int64 so the ids can always index an embedding table
            sequences.append(torch.tensor(tokens, dtype = torch.long))

        padded = pad_sequence(sequences, batch_first = True, padding_value = self.pad_value)
        sorted_length_list, sorted_idxs = torch.sort(torch.tensor(length_list), descending = True)
        padded = padded[sorted_idxs]
        label_tensor = torch.tensor(label_list)[sorted_idxs]

        return padded, sorted_length_list, label_tensor

In [4]:
def test_model(loader, model):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for data, lengths, labels in tqdm(loader, desc = "Validation Batches", unit = "batch"):
        data_batch, lengths_batch, label_batch = data, lengths, labels
        predicted = model(data_batch, lengths_batch)

        total += labels.size(0)
        correct += torch.mean((predicted - labels) ** 2)
    return (correct / total)

In [None]:
# Hold out 15% of the scripts for validation (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(data, ratings, test_size=0.15, random_state=42)

# ScriptsDataset indexes features and labels with the same integer key in
# 0..len-1, so both must be re-indexed after the shuffle. Labels are only
# reset when they are a pandas object; lists/arrays are already positional.
X_train, X_test = X_train.reset_index(drop=True), X_test.reset_index(drop=True)
if hasattr(y_train, "reset_index"):
    y_train, y_test = y_train.reset_index(drop=True), y_test.reset_index(drop=True)

train_loader = ScriptsDataset(X_train, y_train).get_loader(batch_size = 5)
val_loader = ScriptsDataset(X_test, y_test).get_loader(batch_size = 5)

In [None]:
# Optimisation hyperparameters
learning_rate = 0.1
num_epochs = 2  # number of passes over the training set

# Single-layer recurrent regressor over the loaded vocabulary
model = RNN(
    emb_size=100,
    hidden_size=200,
    num_layers=1,
    vocab_size=len(idx_token),
    pad_idx=token_idx[pad_token],
)

# Mean-squared-error objective, optimised with Adam
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Number of batches per training epoch
total_step = len(train_loader)

In [None]:
# Per-batch training loss, in order of execution
losses = []

for epoch in tqdm(range(num_epochs), desc = "Training Epochs", unit = "epoch"):
    # Batch variables carry a `batch_` prefix so the loop does not overwrite
    # the notebook-level `data` corpus (which would break re-running the
    # train/test split cell).
    for i, (batch_data, batch_lengths, batch_labels) in enumerate(tqdm(train_loader, desc = "Batches", unit = "batch")):
        # test_model() leaves the model in eval mode, so re-enable training
        # mode on every step
        model.train()
        optimizer.zero_grad()

        # Forward pass. Flatten predictions and targets to the same 1-D shape:
        # comparing (batch, 1) outputs with (batch,) labels would broadcast to
        # a (batch, batch) matrix and optimise the wrong objective.
        outputs = model(batch_data, batch_lengths).reshape(-1)
        targets = batch_labels.reshape(-1).to(device = outputs.device, dtype = outputs.dtype)
        loss = criterion(outputs, targets)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if i > 0 and i % 100 == 0:
            # validate on the held-out set every 100 steps
            val_mse = float(test_model(val_loader, model))
            print('Epoch: [{}/{}], Step: [{}/{}], Average MSE: {}'.format(
                       epoch+1, num_epochs, i+1, len(train_loader), val_mse))

Training Epochs:   0%|          | 0/2 [00:00<?, ?epoch/s]
Batches:   0%|          | 0/704 [00:00<?, ?batch/s][A
Batches:   0%|          | 1/704 [00:05<1:09:29,  5.93s/batch][A
Batches:   0%|          | 2/704 [00:12<1:12:14,  6.17s/batch][A
Batches:   0%|          | 3/704 [00:17<1:07:34,  5.78s/batch][A
Batches:   1%|          | 4/704 [00:24<1:13:12,  6.27s/batch][A
Batches:   1%|          | 5/704 [00:29<1:06:01,  5.67s/batch][A
Batches:   1%|          | 6/704 [00:38<1:19:52,  6.87s/batch][A
Batches:   1%|          | 7/704 [00:42<1:09:38,  6.00s/batch][A
Batches:   1%|          | 8/704 [00:47<1:06:32,  5.74s/batch][A
Batches:   1%|▏         | 9/704 [00:52<1:01:49,  5.34s/batch][A
Batches:   1%|▏         | 10/704 [00:58<1:04:09,  5.55s/batch][A
Batches:   2%|▏         | 11/704 [01:02<59:28,  5.15s/batch]  [A
Batches:   2%|▏         | 12/704 [01:09<1:06:15,  5.74s/batch][A
Batches:   2%|▏         | 13/704 [01:14<1:03:18,  5.50s/batch][A
Batches:   2%|▏         | 14/704 [01:1

Epoch: [1/2], Step: [101/704], Average MSE: 140868.0



Batches:  14%|█▍        | 102/704 [11:44<2:47:04, 16.65s/batch][A
Batches:  15%|█▍        | 103/704 [11:49<2:12:09, 13.19s/batch][A
Batches:  15%|█▍        | 104/704 [11:54<1:45:52, 10.59s/batch][A
Batches:  15%|█▍        | 105/704 [11:58<1:28:09,  8.83s/batch][A
Batches:  15%|█▌        | 106/704 [12:11<1:39:22,  9.97s/batch][A
Batches:  15%|█▌        | 107/704 [12:17<1:27:24,  8.79s/batch][A
Batches:  15%|█▌        | 108/704 [12:22<1:15:55,  7.64s/batch][A
Batches:  15%|█▌        | 109/704 [12:27<1:06:32,  6.71s/batch][A
Batches:  16%|█▌        | 110/704 [12:37<1:16:38,  7.74s/batch][A
Batches:  16%|█▌        | 111/704 [12:45<1:19:04,  8.00s/batch][A
Batches:  16%|█▌        | 112/704 [12:52<1:15:51,  7.69s/batch][A
Batches:  16%|█▌        | 113/704 [13:01<1:19:15,  8.05s/batch][A
Batches:  16%|█▌        | 114/704 [13:14<1:32:13,  9.38s/batch][A
Batches:  16%|█▋        | 115/704 [13:19<1:19:32,  8.10s/batch][A
Batches:  16%|█▋        | 116/704 [13:23<1:07:28,  6.88s/batc