In [None]:
import torch
import data_loader
from traineval import train, evaluate
import model

import matplotlib.pyplot as plt

# Fix torch's RNG so torch-side randomness (e.g. weight initialisation, dropout)
# is repeatable across Restart & Run All. NOTE(review): whether data_loader /
# traineval shuffling also draws from torch's RNG is not visible here — confirm.
SEED = 0
torch.manual_seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data

The following cell calls `data_loader.load`, which automatically downloads and extracts the dataset if needed.
It defines the following variables:
* `tokens_vocab` - the vocabulary of sentence words
* `y_vocab` - the vocabulary of labels (senses)
* `datasets` - a dictionary mapping each requested split (here `train` and `dev`; `test` can also be requested) to a `WSDDataset` instance.

Use the optional `sentence_count` kwarg to limit the number of sentences loaded.

In [None]:
# Which splits to load, and the cap passed as `sentence_count` (see the note above).
SPLITS = ['train', 'dev']
SENTENCE_COUNT = 100

datasets, tokens_vocab, y_vocab = data_loader.load(SPLITS, sentence_count=SENTENCE_COUNT)
datasets['train']

In [None]:
# Display the dev split loaded above.
datasets['dev']

## Ex2A.1: Implement and train a basic attention model

In [None]:
dropout = 0.25
D = 300

# Basic attention WSD model, sized by the token vocabulary and the sense (label) vocabulary.
m = model.WSDModel(tokens_vocab.size(), y_vocab.size(), D=D, dropout_prob=dropout)
m = m.to(device)

In [None]:
lr = 8e-5
batch_size = 100
num_epochs = 10

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

train_set, dev_set = datasets['train'], datasets['dev']
losses, train_acc, val_acc = train(
    m,
    optimizer,
    train_set,
    dev_set,
    num_epochs=num_epochs,
    batch_size=batch_size,
)

# Report the last recorded validation / training accuracy.
print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy:{train_acc[-1]:.3f}")

### Plot loss and train/val accuracy

You should be getting ~54% validation accuracy after 10 epochs.

In [None]:
# Two stacked panels: training loss (top) and train/validation accuracy (bottom).
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))
loss_ax, acc_ax = axs

loss_ax.plot(losses, '-', label='Train Loss')
loss_ax.set(title='Training loss', ylabel='Loss')
loss_ax.legend()

acc_ax.plot(train_acc, '-o', label='Train Acc')
acc_ax.plot(val_acc, '-o', label='Val Acc')
acc_ax.set(title='Train vs. validation accuracy', ylabel='Accuracy')
acc_ax.legend()

fig.tight_layout()

### Inspect attention

Invoke the attention highlight visualization to get a feel for what attention is doing.

The query token is highlighted in green, and the model's attention weights are shown with a pink-to-blue gradient.
In addition, the loss is shown with a red gradient.

In [None]:
from traineval import higlight_samples  # (sic) name as exported by traineval

dev_set = datasets['dev']
higlight_samples(m, dev_set, sample_size=5)

Notice that the model also attends to the padded positions.

## Ex2A.2: Attending Padding

In [None]:
# Fresh model for this exercise, using the same hyper-parameters as Ex2A.1.
m = model.WSDModel(
    tokens_vocab.size(),
    y_vocab.size(),
    D=D,
    dropout_prob=dropout
).to(device)

# The optimizer from Ex2A.1 was constructed over the *previous* model's
# parameters; re-using it would step those stale tensors and leave this new
# model untrained. Build a new one bound to the new model's parameters.
optimizer = torch.optim.Adam(m.parameters(), lr=lr)

losses, train_acc, val_acc = train(
    m, optimizer, datasets['train'], datasets['dev'], num_epochs=num_epochs, batch_size=batch_size)

print(f"Validation accuracy: {val_acc[-1]:.3f}, Training accuracy:{train_acc[-1]:.3f}")

In [None]:
# Two stacked panels: training loss (top) and train/validation accuracy (bottom).
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))
loss_ax, acc_ax = axs

loss_ax.plot(losses, '-', label='Train Loss')
loss_ax.set(title='Training loss', ylabel='Loss')
loss_ax.legend()

acc_ax.plot(train_acc, '-o', label='Train Acc')
acc_ax.plot(val_acc, '-o', label='Val Acc')
acc_ax.set(title='Train vs. validation accuracy', ylabel='Accuracy')
acc_ax.legend()

fig.tight_layout()

In [None]:
# Sanity check: number of entries in y_vocab.index (compare with y_vocab.size() in the next cell).
len(y_vocab.index)

In [None]:
# The label-vocabulary size that was passed to WSDModel.
y_vocab.size()

In [None]:
# Re-inspect attention on dev samples, now for the model retrained in Ex2A.2.
higlight_samples(m, datasets['dev'], sample_size=5)

#### If you like, feel free to inspect more samples using the API and pandas, as demonstrated below.

In [None]:
import pandas as pd
import numpy as np
from traineval import evaluate_verbose, highlight

# The short key 'max_columns' is ambiguous in pandas >= 1.4 (it also matches
# 'styler.render.max_columns') and raises OptionError; use the full key.
pd.set_option('display.max_columns', 100)

# Evaluation table (with y_true / y_pred / query_token columns, used below) and
# attention table for the dev split. NOTE(review): iter_lim presumably caps how
# much of the split is evaluated — confirm in traineval.
eval_df, attention_df = evaluate_verbose(m, datasets['dev'], iter_lim=100)

#### Show 5 correctly classified samples

In [None]:
# Positions of the first 5 samples whose predicted sense equals the gold sense,
# matching the heading above (the previous `!=` picked misclassified samples;
# switch back to `!=` to inspect errors instead).
idxs = list(np.flatnonzero(eval_df['y_true'] == eval_df['y_pred'])[:5])
highlight(eval_df, attention_df, idxs)

#### Show samples of the query word 'left'

In [None]:
# Positions of every sample whose query word is 'left'. Passed as a flat list of
# ints, like the cell above; bare np.where would hand highlight a tuple of arrays.
idxs = list(np.flatnonzero(eval_df['query_token'] == 'left'))
highlight(eval_df, attention_df, idxs)

## Ex2A.3: Self Attention

The method below converts the word-level `WSDDataset` instances into sentence-level `WSDSentencesDataset` instances for self-attention mode.

Notice that the number of samples now equals the number of sentences.

In [None]:
# Convert each word-level split into a sentence-level dataset for self-attention.
to_sentences = data_loader.WSDSentencesDataset.from_word_datasets
sa_datasets = to_sentences(datasets)
sa_datasets['train']

### Implement and train

In [None]:
# Hyper-parameters for the self-attention run.
lr = 2e-4
dropout = 0.2
D = 300
batch_size = 100
num_epochs = 5

m = model.WSDModel(tokens_vocab.size(), y_vocab.size(), D=D, dropout_prob=dropout)
m = m.to(device)

optimizer = torch.optim.Adam(m.parameters(), lr=lr)

losses, train_acc, val_acc = train(
    m,
    optimizer,
    sa_datasets['train'],
    sa_datasets['dev'],
    num_epochs=num_epochs,
    batch_size=batch_size,
)

In [None]:
# Two stacked panels: training loss (top) and train/validation accuracy (bottom).
fig, axs = plt.subplots(nrows=2, figsize=(15, 6))
loss_ax, acc_ax = axs

loss_ax.plot(losses, '-', label='Train Loss')
loss_ax.set(title='Training loss (self-attention)', ylabel='Loss')
loss_ax.legend()

acc_ax.plot(train_acc, '-o', label='Train Acc')
acc_ax.plot(val_acc, '-o', label='Val Acc')
acc_ax.set(title='Train vs. validation accuracy (self-attention)', ylabel='Accuracy')
acc_ax.legend()

fig.tight_layout()

# Ex2B: Position-Sensitive Attention

# Ex2C: Causal Attention