### Post Processing and Visualization
---

In [11]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from torchtext.data.metrics import bleu_score

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
SOS_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'

### Load test performances
---

In [156]:
def reverse_vocab(vocab):
    return dict([(v, k) for k,v in vocab.items()])

def validate_vocab(d):
    print(f"one to one mapping: {bool(len(d) == len(set([v for _, v in d.items()])))}")
    print(len(d), len(set([v for _, v in d.items()])))
    
def set_util_maps(itos):
    itos[0] = '<unk>'
    itos[1] = '<pad>'
    itos[2] = '<sos>'
    itos[3] = '<eos>'
    return itos

def batched_data_to_list(src, tgt, pred_tgts):
    """ src, tgt, pred_tgts have format:
    [tensor(src_len, batch size), ...]
    
    return [[int0, int1, int2,...],...]
    """
    src_, tgt_, pred_tgts_ = [], [], []
    
    for batch in src:
        src_.extend([datum.tolist() for datum in batch.transpose(1,0)])
    
    for batch in tgt:
        tgt_.extend([datum.tolist() for datum in batch.transpose(1,0)])
        
    for batch in pred_tgts:
        pred_tgts_.extend([datum.tolist() for datum in batch.argmax(2).transpose(1,0)])
        
    return src_, tgt_, pred_tgts_


def ints_to_sentences(sent, vocab):
    """sent is a sequence of ints"""
    return [vocab[w] if w in vocab else "<unk>" for w in sent]

def clip_sentence(sent):
    """ sent: list[str] """
    # remove initial <sos>
    sent = sent[1:]
    
    # stop when first "<eos>" generated
    sent_ = []
    for w in sent:
        if w == '<eos>':
            return sent_
        sent_.append(w)
    return sent_

def postprocess(src, tgt, pred_tgts, src_itos, tgt_itos):
    src, tgt, pred_tgts = batched_data_to_list(src, tgt, pred_tgts)
    src = list(map(lambda sent: ints_to_sentences(sent, src_itos), src))
    tgt = list(map(lambda sent: ints_to_sentences(sent, tgt_itos), tgt))
    pred_tgts = list(map(lambda sent: ints_to_sentences(sent, tgt_itos), pred_tgts))
    
    src = list(map(clip_sentence, src))
    tgt = list(map(clip_sentence, tgt))
    pred_tgts = list(map(clip_sentence, pred_tgts))
    return src, tgt, pred_tgts

def bleu_score_wrapper(tgt, pred_tgts):
    """ convert format of tgt to appropriate format """
    tgt = [[t] for t in tgt]
    return bleu_score(pred_tgts, tgt)

def write_translations(sents, file):
    with open(file, 'w') as f:
        for sent in sents:
            f.write(" ".join(sent))

In [157]:
# Multi30k-gru**2-5
payload = torch.load("for-analysis/Multi30k-gru**2-5/payload.pt")
print(payload.keys())
pred_tgts = payload['pred_tgts']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']  # stoi
tgt_vocab = payload['TGT_vocab']  # stoi
src_itos = set_util_maps(reverse_vocab(src_vocab))
tgt_itos = set_util_maps(reverse_vocab(tgt_vocab))

# compute bleu score
src, tgt, pred_tgts = postprocess(src, tgt, pred_tgts, src_itos, tgt_itos)
write_translations(pred_tgts, "for-analysis/output_translations/gru**2.eng")
print(f"Bleu score: {bleu_score_wrapper(pred_tgts, tgt)*100}")

dict_keys(['pred_tgts', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])
Bleu score: 2.373911067843437


In [158]:
# Multi30k-gcn_gru-1
payload = torch.load("for-analysis/Multi30k-gcn_gru-1/payload.pt")
print(payload.keys())
pred_tgts = payload['pred_tgts']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']  # stoi
tgt_vocab = payload['TGT_vocab']  # stoi
src_itos = set_util_maps(reverse_vocab(src_vocab))
tgt_itos = set_util_maps(reverse_vocab(tgt_vocab))

# compute bleu score
src, tgt, pred_tgts = postprocess(src, tgt, pred_tgts, src_itos, tgt_itos)
write_translations(pred_tgts, "for-analysis/output_translations/gcn_gru.eng")
print(f"Bleu score: {bleu_score_wrapper(pred_tgts, tgt)*100}")

dict_keys(['pred_tgts', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])
Bleu score: 7.8253477811813354


In [159]:
# Multi30k-gcngru_gru-1
payload = torch.load("for-analysis/Multi30k-gcngru_gru-1/payload.pt")
print(payload.keys())
pred_tgts = payload['pred_tgts']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']  # stoi
tgt_vocab = payload['TGT_vocab']  # stoi
src_itos = set_util_maps(reverse_vocab(src_vocab))
tgt_itos = set_util_maps(reverse_vocab(tgt_vocab))

# compute bleu score
src, tgt, pred_tgts = postprocess(src, tgt, pred_tgts, src_itos, tgt_itos)
write_translations(pred_tgts, "for-analysis/output_translations/gcngru_gru.eng")
print(f"Bleu score: {bleu_score_wrapper(pred_tgts, tgt)*100}")

dict_keys(['pred_tgts', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])
Bleu score: 22.964130342006683


In [160]:
# Multi30k-gcnattn_gru-1
payload = torch.load("for-analysis/Multi30k-gcnattn_gru-1/payload.pt")
print(payload.keys())
pred_tgts = payload['pred_tgts']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']  # stoi
tgt_vocab = payload['TGT_vocab']  # stoi
src_itos = set_util_maps(reverse_vocab(src_vocab))
tgt_itos = set_util_maps(reverse_vocab(tgt_vocab))

# compute bleu score
src, tgt, pred_tgts = postprocess(src, tgt, pred_tgts, src_itos, tgt_itos)
write_translations(pred_tgts, "for-analysis/output_translations/gcnattn_gru.eng")
print(f"Bleu score: {bleu_score_wrapper(pred_tgts, tgt)*100}")

dict_keys(['pred_tgts', 'attns', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])
Bleu score: 16.638437879309034


In [161]:
# gru_attn**2
payload = torch.load("for-analysis/Multi30k-gru_attn**2-3/payload.pt")
print(payload.keys())
pred_tgts = payload['pred_tgts']
attns = payload['attns']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']  # stoi
tgt_vocab = payload['TGT_vocab']  # stoi
src_itos = set_util_maps(reverse_vocab(src_vocab))
tgt_itos = set_util_maps(reverse_vocab(tgt_vocab))

# compute bleu score
src, tgt, pred_tgts = postprocess(src, tgt, pred_tgts, src_itos, tgt_itos)
write_translations(pred_tgts, "for-analysis/output_translations/gru_attn.eng")
print(f"Bleu score: {bleu_score_wrapper(pred_tgts, tgt)*100}")

dict_keys(['pred_tgts', 'attns', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])
Bleu score: 29.867154359817505


In [162]:
# Multi30k-gcngruattn_gru-1
payload = torch.load("for-analysis/Multi30k-gcngruattn_gru-1/payload.pt")
print(payload.keys())
pred_tgts = payload['pred_tgts']
attns = payload['attns']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']  # stoi
tgt_vocab = payload['TGT_vocab']  # stoi
src_itos = set_util_maps(reverse_vocab(src_vocab))
tgt_itos = set_util_maps(reverse_vocab(tgt_vocab))

# compute bleu score
src, tgt, pred_tgts = postprocess(src, tgt, pred_tgts, src_itos, tgt_itos)
write_translations(pred_tgts, "for-analysis/output_translations/gcngruattn_gru.eng")
print(f"Bleu score: {bleu_score_wrapper(pred_tgts, tgt)*100}")

dict_keys(['pred_tgts', 'attns', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])
Bleu score: 29.064443707466125


In [163]:
write_translations(tgt, "for-analysis/output_translations/ref.eng")

#### Separately deal with transformer

In [129]:
# transformer
payload = torch.load("for-analysis/Multi30k-transformer-100/payload.pt", map_location=torch.device('cpu'))
print(payload.keys())
pred_tgts = payload['pred_tgts']
attns = payload['attns']
src = payload['src']
tgt = payload['tgt']
src_vocab = payload['SRC_vocab']
tgt_vocab = payload['TGT_vocab']

dict_keys(['pred_tgts', 'attns', 'src', 'tgt', 'SRC_vocab', 'TGT_vocab'])


In [131]:
print(f"Bleu score: {bleu_score(pred_tgts, tgt)*100}")

Bleu score: 36.49856947437553


We have 3 types of models:
* seq2seq -  don't need to worry about it because teacher forceing was turned off during testing
* seq2seq with attention - don't need to worry either...teacher forcing was turned off during testing
* transformer - worry about it.

Each would need a version of translate operation. So we need to load and run all of them to generate test performance.

In [1]:
import spacy, random, math, time, yaml, sys, os
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT
from torchtext.data import Field, BucketIterator, RawField, Dataset
from moduleloader import ModuleLoader
from models.gcn import GCNLayer
from src.utils import set_seed, tokenize_de, tokenize_en, batch_graph, \
get_sentence_lengths, counter2array, ensure_path_exist, \
print_status, learning_rate_decay, count_parameters
from src.early_stopping import EarlyStopping
from src.logging import Logger

#### Load Models

In [3]:
# transformer
# config_file = 'for-analysis/Multi30k-transformer-100/config.yaml'
# loader = ModuleLoader(config_file)
# config = loader.config
# train_data = loader.train_data
# valid_data = loader.valid_data
# test_data = loader.test_data
# train_iterator = loader.train_iterator
# valid_iterator = loader.valid_iterator
# test_iterator = loader.test_iterator
# SRC = loader.SRC
# TGT = loader.TGT
# GRH = loader.GRH
# SRC_PAD_IDX = loader.SRC_PAD_IDX
# TGT_PAD_IDX = loader.TGT_PAD_IDX
# enc = loader.enc
# dec = loader.dec
# model = loader.model
# train_epoch = loader.train_epoch
# evaluate = loader.evaluate
# criterion = loader.criterion
# device = loader.device