In [1]:
from sklearn.model_selection import train_test_split
from pipeline import Pipeline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transforms import pad_token
import numpy as np
from tqdm import tqdm_notebook, tqdm
import pickle
import random
from dataset import TBTTScriptsDataset
from model import GRU

In [2]:
# Load the stored "10k_common" pipeline output and unpack its three parts
# (the data itself plus the two lookup objects), then load the ratings.
_common_payload = Pipeline.load("10k_common").data
data, token_idx, idx_token = _common_payload

_ratings_pipeline = Pipeline.load("ratings")
ratings = _ratings_pipeline.data

In [3]:
def test_model(loader, model):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for data, labels in tqdm_notebook(loader, desc = "Validation Batches", unit = "batch", leave = False):
        batch_size, seq_len = data.shape
        hidden = model.init_hidden(batch_size)
        
        data_truncated = torch.chunk(data, int(seq_len / T), dim = 1)
        for truncated_slice in data_truncated:
            predicted, hidden = model(truncated_slice, hidden)
            
        total += labels.size(0)
        correct += torch.mean((predicted - labels) ** 2)
    return (correct / total)

In [None]:
# Reloaded in this cell so the split does not depend on whatever `data`
# currently refers to in the kernel (the name is reused further down).
data, token_idx, idx_token = Pipeline.load("10k_common").data
ratings = Pipeline.load("ratings").data

X_train, X_test, y_train, y_test = train_test_split(data, ratings, test_size=0.15, random_state=42)

# train_test_split keeps the shuffled original index on pandas objects. Reset it
# on every split that has one so inputs and targets are both numbered 0..n-1;
# previously only the inputs were reset. Non-pandas splits pass through as-is.
X_train, X_test, y_train, y_test = (
    split.reset_index(drop=True) if hasattr(split, "reset_index") else split
    for split in (X_train, X_test, y_train, y_test)
)

train_loader = TBTTScriptsDataset(X_train, y_train).get_loader()
val_loader = TBTTScriptsDataset(X_test, y_test).get_loader()

In [None]:
# Seed every RNG in play before the model is built so weight initialisation
# (and any shuffling that draws from these generators) is repeatable.
# 42 matches the random_state already used for the train/test split.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = GRU(emb_size = 100, hidden_size = 128, num_layers = 1, vocab_size = len(idx_token), pad_idx = token_idx[pad_token])

learning_rate = .001
num_epochs = 2

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Number of batches per epoch; only used for progress reporting.
total_step = len(train_loader)
# Number of time steps fed to the model per forward call when a batch is
# split along the sequence dimension (read by test_model and the training loop).
T = 50

In [None]:
# Per-step training loss for the whole run, and a window that is averaged and
# cleared each time validation runs.
losses = []
interval_loss = []
for epoch in tqdm_notebook(range(num_epochs), desc = "Training Epochs", unit = "epoch"):
    # The batch is bound to `batch`, not `data`, so the notebook-level `data`
    # loaded earlier is not overwritten by a tensor.
    for i, (batch, labels) in enumerate(tqdm_notebook(train_loader, desc = "Batches", unit = "batch")):
        # test_model() switches the model to eval mode, so re-enable training
        # mode at the start of every step.
        model.train()
        optimizer.zero_grad()

        batch_size, seq_len = batch.shape
        hidden = model.init_hidden(batch_size)

        # torch.chunk rejects a chunk count of 0, which the plain quotient
        # produces whenever a batch is shorter than T.
        num_chunks = max(1, seq_len // T)
        # NOTE(review): hidden is passed between chunks without an explicit
        # detach here, so unless the model detaches internally the gradient
        # flows through the full sequence - confirm this is intended.
        for truncated_slice in torch.chunk(batch, num_chunks, dim = 1):
            outputs, hidden = model(truncated_slice, hidden)

        # Only the output after the final chunk is compared with the labels.
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()

        # Clip the global gradient norm to 40 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 40)

        optimizer.step()

        step_loss = loss.item()
        losses.append(step_loss)
        interval_loss.append(step_loss)
        if i > 0 and i % 10 == 0:
            # validate
            val_acc = test_model(val_loader, model)
            avg_intval_loss = sum(interval_loss) / len(interval_loss)
            tqdm.write('Epoch: [{}/{}], Step: [{}/{}], Average MSE: {:.4f}, Avg Loss: {:.4f}'.format(
                       epoch+1, num_epochs, i+1, total_step, val_acc, avg_intval_loss))
            interval_loss = []

HBox(children=(IntProgress(value=0, description='Training Epochs', max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Batches', max=3519), HTML(value='')))