In [1]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import statistics
import torch
import torch.nn as nn
from torch.utils import data

import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

## QUESTIONS:
* `.to(device)`?
* Low loss but bad accuracy — why?

In [None]:
import data_loader
import evaluation

# Load the raw examples together with the token and label vocabularies.
# NOTE(review): S=1000 presumably caps how much data is loaded — confirm in data_loader.
raw_dataset, tokens_vocab, y_vocab = data_loader.load_raw_data(S=1000)

In [None]:
# Print one example in each of its four stored views (integer-encoded and
# string sentences, then integer-encoded and string labels).
idx = 1
for field in ('int_sentences', 'str_sentences', 'int_labels', 'str_labels'):
    print(raw_dataset[field][idx])

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

In [None]:
# Seed PyTorch's global RNG for reproducibility.
torch.manual_seed(0)

class WSDModel(nn.Module):
    """Attention-based classifier that scores labels for one query token per sentence.

    The query token's embedding attends over the embeddings of every token in
    the sentence; the attended context is added to the query embedding and the
    result is scored against a table of label embeddings.

    Args:
        V: Number of rows in the token embedding table.
        Y: Number of rows in the label embedding table (number of output classes).
        D: Embedding / hidden dimension.
    """

    def __init__(self, V, Y, D=300):
        super().__init__()
        self.E_v = nn.Embedding(V, D)
        self.E_y = nn.Embedding(Y, D)

        # torch.empty only allocates memory, so the attention matrices must be
        # filled explicitly; otherwise they may contain arbitrary (even NaN)
        # values and training starts from garbage.
        self.W_A = nn.Parameter(torch.empty(D, D))
        self.W_O = nn.Parameter(torch.empty(D, D))
        nn.init.xavier_uniform_(self.W_A)
        nn.init.xavier_uniform_(self.W_O)

        # Layers used only by the alternative `attention_2` variant.
        self.linear_in = nn.Linear(D, D, bias=False)
        self.linear_out = nn.Linear(D * 2, D, bias=False)

        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def attention(self, X, Q):
        """Bilinear attention of the query over the sentence.

        Args:
            X: Sentence token embeddings, [B, N, D].
            Q: Query embeddings, [B, 1, D].

        Returns:
            Tuple ``(Q_c, A)``: the attended context projected by ``W_O``
            ([B, 1, D]) and the attention weights over the N tokens ([B, 1, N]).
        """
        # scores: [B, 1, N]; the softmax normalises over the N sentence positions.
        scores = Q @ self.W_A @ X.transpose(1, 2)
        A = self.softmax(scores)
        Q_c = A @ X @ self.W_O
        return Q_c, A

    def attention_2(self, X, Q):
        """Alternative attention using a projected query and a tanh output layer.

        Adapted from
        https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/attention.html

        Args:
            X: Context embeddings, [B, N, D].
            Q: Query embeddings, [B, output_len, D].

        Returns:
            Tuple ``(output, A)``: the attended output ([B, output_len, D]) and
            the attention weights ([B, output_len, N]).
        """
        # nn.Linear acts on the last dimension, so no flattening is required.
        Q = self.linear_in(Q)

        # [B, output_len, D] x [B, D, N] -> [B, output_len, N]
        A_logits = torch.bmm(Q, X.transpose(1, 2))
        A = self.softmax(A_logits)

        # [B, output_len, N] x [B, N, D] -> [B, output_len, D]
        mix = torch.bmm(A, X)

        # Concatenate the context with the projected query, then map 2D -> D.
        combined = torch.cat((mix, Q), dim=-1)
        output = self.tanh(self.linear_out(combined))

        return output, A

    def forward(self, M_s, v_q):
        """Compute label logits and attention weights for a batch.

        Args:
            M_s: Token ids of each sentence, [B, N].
            v_q: Position of the query token within each sentence, [B].

        Returns:
            Tuple ``(y_logits, A)`` with shapes [B, Y] and [B, N].
        """
        X = self.E_v(M_s)

        # Pick the query token id of every row; the index tensor is created on
        # the same device as the input.
        batch_idx = torch.arange(M_s.size(0), device=M_s.device)
        Q_idxs = M_s[batch_idx, v_q]
        Q = self.E_v(Q_idxs).unsqueeze(1)

        Q_c, A = self.attention(X, Q)

        H = F.relu(Q_c + Q)
        # Drop only the singleton query axis. A bare squeeze() would also remove
        # the batch axis when B == 1 and break the downstream loss.
        y_logits = (H @ self.E_y.weight.T).squeeze(1)
        return y_logits, A.squeeze(1)

In [None]:
# Wrap the raw data and both vocabularies in the project's dataset class;
# the bare name on the last line displays the object in the cell output.
wsd_dataset = data_loader.WSDDataset(raw_dataset, tokens_vocab, y_vocab)
wsd_dataset

In [None]:
# The vocabulary sizes determine the sizes of the model's two embedding tables.
V = tokens_vocab.size()
Y = y_vocab.size()
# Use a 50-dimensional model (the class default is 300) and move it to the selected device.
model = WSDModel(V, Y, D=50).to(device)

In [None]:
# Re-seed before building the loader — presumably so the shuffled batch order is reproducible.
torch.manual_seed(0)

# Mini-batch size used for training.
B = 64

# Shuffled mini-batch loader over the dataset, backed by 4 worker processes.
training_generator = data.DataLoader(
    wsd_dataset, 
    batch_size=B, 
    shuffle=True,
    num_workers=4
)

In [None]:
# Objective and optimiser for the training run.
ce_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

# History buffers: `losses` gets one entry per optimisation step and
# `train_acc` one entry per epoch; `val_acc` is not filled in by this cell.
losses = []
train_acc = []
val_acc = []

for epoch in range(10):
    with tqdm(training_generator) as prg_train:
        for M_s, v_q, y_true in prg_train:
            # Batch tensors: M_s is [B, N]; v_q and y_true are both [B].
            M_s = M_s.to(device)
            v_q = v_q.to(device)
            y_true = y_true.to(device)

            # One optimisation step: clear old gradients, forward, backward, update.
            optimizer.zero_grad()
            y_logits, _ = model(M_s, v_q)
            loss = ce_loss(y_logits, y_true)
            loss.backward()
            optimizer.step()

            # Show the mean of (at most) the last 100 step losses on the progress bar.
            losses.append(loss.item())
            running_mean_loss = statistics.mean(losses[-100:])
            status_str = f'[{epoch}] loss: {running_mean_loss:.3f}'
            prg_train.set_description(status_str)

        # After each epoch, measure accuracy on the training data without tracking gradients.
        with torch.no_grad():
            cur_train_acc = evaluation.evaluate(model, training_generator)
            train_acc.append(cur_train_acc)

In [None]:
import matplotlib.pyplot as plt

# Two stacked panels: the loss recorded at every optimisation step (top) and
# the training accuracy recorded once per epoch (bottom).
fig, axs = plt.subplots(nrows=2, figsize=(15, 5))

axs[0].plot(losses, '-')
axs[0].set(title='Training loss', xlabel='Optimisation step', ylabel='Cross-entropy loss')

axs[1].plot(train_acc, '-o')
axs[1].set(title='Training accuracy', xlabel='Epoch', ylabel='Accuracy')

# Keep the titles and axis labels of the two panels from overlapping.
fig.tight_layout();

In [None]:
# Training accuracy recorded after the last epoch.
train_acc[-1]

In [None]:
import pandas as pd

# Use the fully qualified option name: the bare 'max_columns' pattern is
# ambiguous in recent pandas (it also matches 'styler.render.max_columns')
# and raises OptionError there.
pd.set_option('display.max_columns', 100)

# Small shuffled loader (5 examples per batch) used for the verbose evaluation below.
g = data.DataLoader(
    wsd_dataset, 
    batch_size=5, 
    shuffle=True,
    num_workers=4
)

# NOTE(review): iter_lim=4 presumably caps the number of batches inspected — confirm in evaluation.py.
acc, eval_df, attention_df = evaluation.evaluate_verbose(model, g, tokens_vocab, y_vocab, iter_lim=4)

In [None]:
# Format the evaluation and attention results for display (fancy_display is a
# project helper) and show the evaluation view as this cell's output.
ev_styled, att_styled = evaluation.fancy_display(eval_df, attention_df)
ev_styled

In [None]:
# Show the attention view returned by evaluation.fancy_display.
att_styled