In [None]:
import torch
from pytorch_transformers import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm, trange, tqdm_notebook
import pandas as pd
import numpy as np

import sys
# Make the project-local modules imported in the next cell
# (data_format_utils, dataloaders, bert) importable from this notebook.
sys.path.append("../source/dataloaders/")
sys.path.append("../source/models/")

import os
# Debugging aid: makes CUDA kernel launches synchronous so errors are reported
# at the failing op. NOTE(review): this usually slows GPU training noticeably —
# consider removing it for full training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

try:
    from tensorboardX import SummaryWriter
except ImportError as err:
    # Chain the original ImportError so the real cause stays in the traceback.
    # Implicit string concatenation (instead of a backslash continuation inside
    # the literal) keeps a long run of indentation spaces out of the message.
    raise RuntimeError(
        "No tensorboardX package is found. Please install with the command: "
        "git clone https://github.com/lanpa/tensorboardX && cd tensorboardX && python setup.py install"
    ) from err

In [None]:
import data_format_utils as dfu
from dataloaders import TrainValDataloader
from bert import BertForWSD


In [None]:

from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss, RunningAverage, Precision, Recall
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite.contrib.handlers import ProgressBar

## Get and process data
#### Helper function to process dataset
#### The sample dataset has been preprocessed so that each context from the SemCor 3.0 corpus is matched with its proper WordNet gloss

In [None]:
# Helper that turns a processed CSV into train/validation dataloaders.

def gen_dataloader(_datapath, sample_size=100, batch_size=32, filter_bad_rows=True, random_state=None):
    """Load a processed CSV, optionally subsample it, index it and wrap it in dataloaders.

    Parameters
    ----------
    _datapath : str or path-like
        CSV file readable by ``pd.read_csv``.
    sample_size : int or None
        Number of rows to draw at random. Falsy values (``None`` / ``0``) keep
        the whole file. A request larger than the file is clamped to the file
        size instead of letting ``DataFrame.sample`` raise ``ValueError``.
    batch_size : int
        Forwarded to ``TrainValDataloader``.
    filter_bad_rows : bool
        Drop rows whose first target-token index is ``>= dfu.MAX_LEN``, i.e.
        whose target word would fall outside the tensor.
    random_state : int, numpy RandomState or None
        Forwarded to ``DataFrame.sample`` to make the subsample reproducible.
        ``None`` (default) keeps the previous unseeded behaviour.

    Returns
    -------
    TrainValDataloader
    """
    _df = pd.read_csv(_datapath)

    _smpldf = _df
    if sample_size:
        _smpldf = _df.sample(min(sample_size, len(_df)), random_state=random_state)

    # Return values are ignored, so these helpers presumably add their columns
    # (e.g. target_token_idx, read below) to _smpldf in place.
    dfu.tokenize_and_index(_smpldf)
    dfu.gen_sentence_indexes(_smpldf)
    dfu.find_index_of_target_token(_smpldf)

    if filter_bad_rows:
        # Keep only rows whose (first) target-token index fits within dfu.MAX_LEN.
        # NOTE(review): assumes each target_token_idx entry is an indexable sequence — confirm in dfu.
        _smpldf = _smpldf[_smpldf.target_token_idx.apply(lambda x: x[0] < dfu.MAX_LEN)]

    _dl = TrainValDataloader(_smpldf, batch_size)
    return _dl

In [None]:
# Seed the global RNGs before sampling. DataFrame.sample without random_state
# draws from NumPy's global RNG, so seeding makes the 4000-row subset identical
# on every Restart & Run All. The torch seed presumably also fixes later
# torch-side randomness (model head init, DataLoader shuffling).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
dl = gen_dataloader('../data/processed/sample_data.csv', sample_size=4000, batch_size=16)


## Load model

### Declare optimizer classes and loss criterion

In [None]:
model = BertForWSD()

# Parameters whose names contain any of these substrings are excluded from
# weight decay. pytorch_transformers names the LayerNorm scale
# 'LayerNorm.weight'. 'gamma'/'beta' are the names used by older BERT ports
# and are kept for compatibility.
WEIGHT_DECAY = 0.01
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']

# pytorch_transformers.AdamW reads the per-group key 'weight_decay'. The old
# BertAdam key 'weight_decay_rate' is silently ignored, which would leave every
# group at AdamW's default weight_decay of 0.0.
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': WEIGHT_DECAY},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}]

optimizer = AdamW(optimizer_grouped_parameters,
                  lr=2e-5)  # To reproduce BertAdam specific behavior set correct_bias=False

# The model output is scored against integer class labels with cross-entropy.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Use the GPU when one is visible; otherwise everything stays on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
if use_cuda:
    # Hand unused blocks held by PyTorch's CUDA caching allocator back to the
    # driver (e.g. left over from a previous run of this notebook).
    torch.cuda.empty_cache()
    # Module.to moves the parameters in place, so the optimizer built above
    # still references the same Parameter objects.
    model.to(device)

### Declare processing functions

In [None]:
def process_function(engine, batch):
    """Run one optimisation step on ``batch`` and return the loss as a Python float.

    ``batch`` must unpack into exactly four tensors: tokens, sentence tensor,
    target-token tensor and labels, in that order. Each is moved to ``device``
    before the forward pass.
    """
    model.train()
    optimizer.zero_grad()
    tokens, sentences, target_tokens, labels = (tensor.to(device) for tensor in batch)
    scores = model(tokens, sentences, target_tokens)
    step_loss = criterion(scores, labels)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def eval_function(engine, batch):
    """Score ``batch`` in eval mode with autograd disabled.

    Returns ``(y_pred, y)``, the model output and the labels, which is the
    output format the attached ignite metrics consume.
    """
    model.eval()
    with torch.no_grad():
        moved = (tensor.to(device) for tensor in batch)
        tokens, sentences, target_tokens, labels = moved
        scores = model(tokens, sentences, target_tokens)
    return scores, labels

In [None]:

# One engine drives training. Two separate evaluator engines share
# eval_function so the training-set and validation-set metrics keep
# independent state.
trainer = Engine(process_function)
train_evaluator, validation_evaluator = Engine(eval_function), Engine(eval_function)



## Declare Metrics

In [None]:
# process_function returns the scalar loss, so pass it through unchanged and
# expose its running average as trainer.state.metrics['loss'].
running_loss = RunningAverage(output_transform=lambda loss_value: loss_value)
running_loss.attach(trainer, 'loss')

In [None]:
def thresholded_output_transform(output):
    """Prepare an evaluator's ``(y_pred, y)`` for ignite's classification metrics.

    * Multiclass scores, where ``y_pred`` has one more dimension than ``y``
      (e.g. the (N, C) logits fed to ``CrossEntropyLoss``), are passed through
      unchanged. ignite's Accuracy/Precision/Recall take the argmax themselves,
      and rounding the raw scores first turns close logits into ties and can
      change which class wins.
    * Any other ``y_pred`` (e.g. binary probabilities shaped like ``y``) is
      rounded to 0/1, as before.
    """
    y_pred, y = output
    # Same rule ignite uses to recognise the multiclass case.
    if y_pred.dim() == y.dim() + 1:
        return y_pred, y
    return torch.round(y_pred), y

### Training Accuracy

In [None]:

# Metrics computed by train_evaluator over the training loader (run once per
# epoch by log_training_results). 'bce' holds the CrossEntropyLoss criterion's value.
for metric_name, metric in (('accuracy', Accuracy(output_transform=thresholded_output_transform)),
                            ('bce', Loss(criterion))):
    metric.attach(train_evaluator, metric_name)

### Validation Metrics

In [None]:
Accuracy(output_transform=thresholded_output_transform).attach(validation_evaluator, 'accuracy')
# 'bce' holds the value of the CrossEntropyLoss criterion.
Loss(criterion).attach(validation_evaluator, 'bce')

# Keep per-class precision/recall (average=False) so F1 is computed per class
# and then macro-averaged, as the ignite docs prescribe for building F1 from
# Precision and Recall. Combining the already-averaged (average=True) values
# gives the harmonic mean of macro-P and macro-R, which is not macro-F1.
precision = Precision(output_transform=thresholded_output_transform, average=False)
recall = Recall(output_transform=thresholded_output_transform, average=False)

# Macro averages, reported under the same metric names as before.
precision.mean().attach(validation_evaluator, 'Precision')
recall.mean().attach(validation_evaluator, 'Recall')

# The tiny epsilon keeps a class with precision == recall == 0 from producing NaN.
F1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()
F1.attach(validation_evaluator, 'F1')

In [None]:

# Progress bar over training iterations showing the running 'loss' metric;
# persist=True leaves each finished bar on screen.
pbar = ProgressBar(persist=True, bar_format="")
pbar.attach(trainer, metric_names=['loss'])

In [None]:
def score_function(engine):
    """EarlyStopping score: the negated validation loss, so a lower loss scores higher."""
    return -engine.state.metrics['bce']

# Checked each time validation_evaluator completes (once per epoch). Stops the
# trainer after 5 evaluations in a row without an improved score.
handler = EarlyStopping(
    patience=5,
    score_function=score_function,
    trainer=trainer,
)
validation_evaluator.add_event_handler(Events.COMPLETED, handler)

In [None]:
def create_summary_writer(model, data_loader, log_dir):
    """Open a tensorboardX ``SummaryWriter`` in ``log_dir`` and try to log the model graph.

    The graph is traced with the first batch of ``data_loader``, moved to the
    global ``device``, minus its last element (the labels, which the model does
    not take as input). Graph export is best effort: a failure is printed, not
    raised, and the writer is returned either way.
    """
    writer = SummaryWriter(logdir=log_dir)
    first_batch = next(iter(data_loader))
    model_inputs = tuple(part.to(device) for part in first_batch)[:-1]
    try:
        writer.add_graph(model, model_inputs)
    except Exception as graph_error:
        print("Failed to save model graph: {}".format(graph_error))
    return writer

In [None]:
# TensorBoard event files (model graph plus the scalars logged below) go here.
log_dir = './logs'
writer = create_summary_writer(model, data_loader=dl.train_dataloader, log_dir=log_dir)

## Declare result-logging handlers

In [None]:

log_interval = 10

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Write the latest batch loss to TensorBoard every ``log_interval`` iterations of an epoch."""
    # engine.state.iteration counts across epochs; map it to a 1-based index within the current epoch.
    iter_in_epoch = (engine.state.iteration - 1) % len(dl.train_dataloader) + 1
    if iter_in_epoch % log_interval != 0:
        return
    writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    """After every training epoch, evaluate on the full training loader and log the results.

    Prints average accuracy and loss through the progress bar and writes them
    to TensorBoard under ``training/avg_accuracy`` and ``training/avg_loss``.
    """
    train_evaluator.run(dl.train_dataloader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    # 'bce' is the name the CrossEntropyLoss-based Loss metric is attached under.
    avg_bce = metrics['bce']
    pbar.log_message(
        "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
        .format(engine.state.epoch, avg_accuracy, avg_bce))
    # Each value goes under its own tag (these two were previously swapped).
    writer.add_scalar("training/avg_loss", avg_bce, engine.state.epoch)
    writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)
    
def log_validation_results(engine):
    """After every training epoch, evaluate on the validation loader and log the metrics.

    Prints a one-line summary through the progress bar and writes accuracy,
    loss, F1, precision and recall to TensorBoard under ``validation/avg_*``.
    validation_evaluator completing also triggers the EarlyStopping check.
    """
    validation_evaluator.run(dl.val_dataloader)
    metrics = validation_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_bce = metrics['bce']
    avg_precision = metrics['Precision']
    avg_recall = metrics['Recall']
    avg_F1 = metrics['F1']
    pbar.log_message(
        "Validation Results - Epoch: {} Averages: Acc: {:.3f} Loss: {:.3f} Precision: {:.3f} Recall: {:.3f} F1: {:.3f}"
        .format(engine.state.epoch, avg_accuracy, avg_bce, avg_precision, avg_recall, avg_F1))
    epoch = engine.state.epoch
    writer.add_scalar("validation/avg_accuracy", avg_accuracy, epoch)
    writer.add_scalar("validation/avg_loss", avg_bce, epoch)
    writer.add_scalar("validation/avg_F1", avg_F1, epoch)
    writer.add_scalar("validation/avg_precision", avg_precision, epoch)
    writer.add_scalar("validation/avg_recall", avg_recall, epoch)

trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

## Checkpoints

In [None]:
# Save the model's state_dict at the end of every epoch, keeping the 2 most
# recent files. require_empty=False allows reusing a directory that already
# holds checkpoints from earlier runs.
checkpointer = ModelCheckpoint(
    dirname='./model_checkpoints/models',
    filename_prefix='bertWSD',
    save_interval=1,
    n_saved=2,
    create_dir=True,
    save_as_state_dict=True,
    require_empty=False,
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'bertWSD': model})

## Train Model

In [None]:
# Always flush and close the TensorBoard writer, even if training raises or is
# interrupted (e.g. KeyboardInterrupt from the notebook's stop button).
try:
    trainer.run(dl.train_dataloader, max_epochs=3)
finally:
    writer.close()

In [None]:
# Plain-NumPy accuracy of arg-max predictions against labels (scikit-learn is not involved).
def flat_accuracy(preds, labels):
    """Return the fraction of rows whose highest-scoring entry along axis 1 equals the label.

    ``preds`` holds per-class scores along axis 1. ``labels`` is an ndarray
    holding one label per row once flattened.
    """
    predicted = np.ravel(np.argmax(preds, axis=1))
    expected = labels.flatten()
    hits = (predicted == expected).sum()
    return hits / len(expected)

def accuracy_precision_recall_fscore(confusion_matrix):
    """Compute accuracy, precision, recall and F1 from a 2x2 binary confusion matrix.

    ``confusion_matrix`` uses scikit-learn's layout (rows = true label,
    columns = predicted label, negative class first), so it flattens to
    ``TN, FP, FN, TP``. The parameter name shadows
    ``sklearn.metrics.confusion_matrix`` inside this function only.

    A ratio whose denominator is zero (e.g. precision when nothing was
    predicted positive) is returned as 0.0 instead of NaN. This matches
    scikit-learn's default zero-division result, without the warning.

    Returns
    -------
    tuple
        ``(accuracy, precision, recall, F1)`` as floats.

    Raises
    ------
    ValueError
        If the input does not contain exactly four counts.
    """
    counts = np.asarray(confusion_matrix).ravel()
    if counts.size != 4:
        raise ValueError(
            "expected a 2x2 confusion matrix, got {} entries".format(counts.size))
    TN, FP, FN, TP = counts

    def _safe_div(num, den):
        # 0.0 for an undefined ratio instead of NaN + RuntimeWarning.
        return float(num) / den if den else 0.0

    accuracy = _safe_div(TP + TN, TP + FP + FN + TN)
    precision = _safe_div(TP, TP + FP)
    recall = _safe_div(TP, TP + FN)
    F1 = _safe_div(2 * precision * recall, precision + recall)
    return accuracy, precision, recall, F1