### Transformer Seq2Seq (Neural Signals to Bi-grams)

This notebook trains and evaluates a seq2seq Transformer on two-word sequences: the neural signals are fed to the encoder, and the corresponding bi-grams are fed to the decoder.

The random seeds are fixed further down (right after the arguments are parsed) for reproducibility. For more information, see https://pytorch.org/docs/stable/notes/randomness.html and https://discuss.pytorch.org/t/random-seed-initialization/7854/18

In [1]:
import json
import math
import os
import random
import sys
import time
import warnings
from collections import Counter
from datetime import datetime
from pprint import pprint

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import AdamW

from arg_parser import arg_parser
from build_matrices import (build_design_matrices_classification,
                            build_design_matrices_seq2seq)
from config import build_config
from dl_utils import Brain2enDataset, MyCollator
from models import PITOM, ConvNet10, MeNTAL, MeNTALmini
from train_eval import plot_training, train, valid
from eval_utils import evaluate_roc, evaluate_topk
from vocab_builder import get_sp_vocab, get_std_vocab, get_vocab

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

In [3]:
results_folder = '20200531-ipynb'

In [4]:
args = arg_parser(['--subjects', '625'])

In [5]:
# Training objective associated with each supported architecture name.
MODEL_OBJ = {
    "ConvNet10": "classifier",
    "PITOM": "classifier",
    "MeNTALmini": "classifier",
    "MeNTAL": "seq2seq"
}

# Pick the compute device and never ask for more GPUs than are present.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args.gpus = min(args.gpus, torch.cuda.device_count())

# Seed every random number generator in use with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(args.seed)

# Keep only the part of the model name before the first underscore, then
# treat everything that is not explicitly a seq2seq model as a classifier.
args.model = args.model.split("_")[0]
classify = MODEL_OBJ.get(args.model) != "seq2seq"

In [6]:
CONFIG = build_config(args, results_folder)

Subject: 625
Training Data:: Number of Conversations is: 63
Validation Data:: Number of Conversations is: 13


In [None]:
word2freq, word_list, n_classes, vocab, i2w = get_std_vocab(
    CONFIG, comprehension=False, classify=classify)

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
import matplotlib.ticker as mtick

plt.rc('text', usetex=False)
plt.rc('font', family='serif')


def figure1(SAVE_DIR, word2freq):
    '''Plot the sorted word frequencies on a log scale and save the figure.

    Args:
        SAVE_DIR: directory that receives 'word_frequency.svg'.
        word2freq: mapping of word -> frequency, or a plain sequence of
            frequencies. Entries equal to -1 are dropped before plotting.
    '''
    # Accept either a mapping or a bare sequence of counts; only the
    # "no callable .values()" case is treated as the sequence fallback.
    try:
        freqs = list(word2freq.values())
    except (AttributeError, TypeError):
        freqs = list(word2freq)
    # Compare by value so that -1 is removed whatever numeric type holds it
    # (int.__ne__ returns NotImplemented, which is truthy, for non-int values).
    freqs = [f for f in freqs if f != -1]

    fig, ax = plt.subplots()
    ax.plot(range(len(freqs)), sorted(freqs))
    # Scale the percent axis by the number of plotted points so the curve
    # ends at 100%; max(..., 1) avoids a zero xmax for empty input.
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(max(len(freqs), 1)))
    plt.title('Frequency of words (sorted)', fontsize=16)
    plt.xlabel('Percentage of Words', fontsize=16)
    plt.ylabel('Word Frequency', fontsize=16)
    plt.yscale('log')
    plt.grid(True, which='both')
    plt.savefig(os.path.join(SAVE_DIR, 'word_frequency.svg'))
    # plt.show() does not take a figure; a positional argument is read as `block`.
    plt.show()
    plt.close(fig)


def figure2(SAVE_DIR, word2freq):
    '''Plot how many words fall into each frequency bin and save the figure.

    Args:
        SAVE_DIR: directory that receives 'word_frequency_dist.svg'.
        word2freq: mapping of word -> frequency, or a plain sequence of
            frequencies. Entries equal to -1 are dropped before binning.
    '''
    bins = [0, 5, 10, 20, 30, 40, 50, 100, 250, 500, 750, 1000, 5000]
    # Accept either a mapping or a bare sequence of counts.
    try:
        freqs = list(word2freq.values())
    except (AttributeError, TypeError):
        freqs = list(word2freq)
    # Compare by value so that -1 is removed whatever numeric type holds it.
    freqs = [f for f in freqs if f != -1]

    categories = pd.cut(freqs, bins)
    # Series.value_counts replaces the deprecated top-level pd.value_counts;
    # reindexing puts the counts back in bin order.
    bin_counts = (pd.Series(categories)
                  .value_counts(sort=False)
                  .reindex(categories.categories))

    fig, ax = plt.subplots()
    ax.bar(range(0, len(bins) - 1), bin_counts, width=1, align='edge')
    # Rotation must be numeric (degrees), not the string '45'.
    plt.xticks(range(len(bins)), labels=bins, rotation=45)

    # Write the count above each bar.
    for i, v in enumerate(bin_counts.values):
        ax.text(i + 0.25, v + 5, str(v), color='blue', fontweight='bold')

    plt.title('Distribution of Words Frequency', fontsize=16)
    plt.xlabel('Word Frequency', fontsize=14)
    plt.ylabel('Count', fontsize=14)
    plt.savefig(os.path.join(SAVE_DIR, 'word_frequency_dist.svg'))
    plt.show()
    plt.close(fig)

In [None]:
figure1(CONFIG["SAVE_DIR"], word2freq)
figure2(CONFIG["SAVE_DIR"], word2freq)

In [None]:
# Re-parse the arguments with electrode/vocabulary/epoch overrides and rebuild
# CONFIG from them.
# NOTE(review): `classify` and the `args.model` prefix-stripping were derived
# from the earlier `args` object and are not recomputed here — confirm that
# this is intended.
args = arg_parser(['--subjects', '625',
                   '--max-electrodes', '55',
                   '--vocab-min-freq', '10',
                   '--vocab-max-freq', '250',
                  '--epochs', '5'])
CONFIG = build_config(args, results_folder)
# Clamp the requested GPU count to what is actually available.
args.gpus = min(args.gpus, torch.cuda.device_count())

In [None]:
word2freq, word_list, n_classes, vocab, i2w = get_std_vocab(
    CONFIG, comprehension=False, classify=classify)

In [None]:
figure1(CONFIG["SAVE_DIR"], word2freq)
figure2(CONFIG["SAVE_DIR"], word2freq)

In [None]:
# First pass over the data, without a `max_num_bins` limit; the sequence-length
# statistics and plots below are computed from these arrays.
# NOTE(review): `aug_shift_ms` presumably adds time-shifted copies of each
# example (training only) — confirm against build_design_matrices_seq2seq.
print("Loading training data")
x_train, y_train = build_design_matrices_seq2seq(
    'train', CONFIG, vocab, delimiter=" ", aug_shift_ms=[-1000, -500])

# The validation split is built with no shifts and with remove_unks=False.
print("Loading validation data")
x_valid, y_valid = build_design_matrices_seq2seq(
    'valid', CONFIG, vocab, delimiter=" ", aug_shift_ms=[], remove_unks=False)

### Some insights about the bigrams in the training set

In [None]:
def seq_len_stats(x_train, x_valid):
    """Print summary statistics of the sequence lengths of both splits.

    The length of a sample is the size of its first axis (``shape[0]``).

    Args:
        x_train: iterable of array-like training samples.
        x_valid: iterable of array-like validation samples.

    Returns:
        Tuple ``(train_seq_lengths, valid_seq_lengths)`` of plain lists.
    """
    def _report(header, lengths):
        # One header line followed by min / max / mean / median / std.
        print(header)
        print(f"\tMin: {min(lengths)}")
        print(f"\tMax: {max(lengths)}")
        print(f"\tMean: {np.mean(lengths):.2f}")
        print(f"\tMedian: {np.median(lengths):.2f}")
        print(f"\tStd: {np.std(lengths):.2f}")

    # Collect both length lists before anything is printed.
    train_seq_lengths = [arr.shape[0] for arr in x_train]
    valid_seq_lengths = [arr.shape[0] for arr in x_valid]

    _report("Training Seq Lengths::", train_seq_lengths)
    _report("Validation Seq Lengths::", valid_seq_lengths)

    return train_seq_lengths, valid_seq_lengths
    
def figure4(SAVE_DIR, lengths, string):
    '''Plot the sorted sequence lengths of one split on a log scale.

    Args:
        SAVE_DIR: directory that receives '<string>_signal_len.svg'.
        lengths: sequence of per-sample sequence lengths.
        string: split name used in the title and the file name.
    '''
    fig, ax = plt.subplots()
    ax.plot(range(len(lengths)), sorted(lengths))
    # x axis is the sample rank expressed as a percentage of all samples;
    # max(..., 1) avoids a zero xmax for empty input.
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(max(len(lengths), 1)))
    plt.title(string + ' set Seq lengths (sorted)', fontsize=16)
    plt.xlabel('Percentage of Samples', fontsize=14)
    plt.ylabel('Sequence Length', fontsize=14)
    plt.yscale('log')
    plt.grid(True, which='both')
    plt.savefig(os.path.join(SAVE_DIR, string + '_signal_len.svg'))
    # plt.show() does not take a figure; a positional argument is read as `block`.
    plt.show()
    plt.close(fig)
    

def figure5(SAVE_DIR, lengths, string):
    '''Plot how many samples fall into each sequence-length bin.

    Args:
        SAVE_DIR: directory that receives '<string>_signal_len_dist.svg'.
        lengths: sequence of per-sample sequence lengths.
        string: split name used in the title and the file name.
    '''
    bins = [0, 25, 50, 75, 100, 250, 500, 1000, 2500, 5000, 7500, 10000]

    categories = pd.cut(lengths, bins)
    # Series.value_counts replaces the deprecated top-level pd.value_counts;
    # reindexing puts the counts back in bin order.
    bin_counts = (pd.Series(categories)
                  .value_counts(sort=False)
                  .reindex(categories.categories))

    fig, ax = plt.subplots()
    ax.bar(range(0, len(bins) - 1), bin_counts, width=1, align='edge')
    plt.xticks(range(len(bins)), labels=bins)

    # Write the count above each bar.
    for i, v in enumerate(bin_counts.values):
        ax.text(i + 0.25, v + 5, str(v), color='blue', fontweight='bold')

    plt.title(f'Distribution of Seq lengths ({string})', fontsize=14)
    plt.xlabel('Sequence Length', fontsize=14)
    plt.ylabel('Count', fontsize=14)
    plt.savefig(os.path.join(SAVE_DIR, string + '_signal_len_dist.svg'))
    plt.show()
    plt.close(fig)
    

def figure6(SAVE_DIR, lengths, string):
    """Histogram of sequence lengths, zoomed to the 0-100 range.

    Args:
        SAVE_DIR: directory that receives '<string>_signal_len_dist_zoom.svg'.
        lengths: sequence of per-sample sequence lengths.
        string: split name used in the title and the file name.
    """
    out_path = os.path.join(SAVE_DIR, string + '_signal_len_dist_zoom.svg')

    # Draw on the current axes, exactly as the pyplot wrappers would.
    ax = plt.gca()
    ax.hist(lengths, bins=1000)
    ax.set_xlim([0, 100])
    ax.set_title(f'Distribution of Seq lengths ({string})', fontsize=14)
    ax.set_xlabel('Sequence Length', fontsize=14)
    ax.set_ylabel('Count', fontsize=14)

    plt.savefig(out_path)
    plt.show()

In [None]:
train_seq_lengths, valid_seq_lengths = seq_len_stats(x_train, x_valid)

# Draw each figure type for the training split first, then the validation split.
for plot_lengths in (figure4, figure5, figure6):
    plot_lengths(CONFIG["SAVE_DIR"], train_seq_lengths, 'Training')
    plot_lengths(CONFIG["SAVE_DIR"], valid_seq_lengths, 'Validation')

In [None]:
# Second pass: rebuild both splits with max_num_bins=60, overwriting the
# x_train/y_train/x_valid/y_valid arrays loaded earlier.
print("Loading training data")
x_train, y_train = build_design_matrices_seq2seq(
    'train', CONFIG, vocab, delimiter=" ", aug_shift_ms=[-1000, -500], max_num_bins=60)

# Earlier variant of the validation load (remove_unks=False), kept for reference.
# print("Loading validation data")
# x_valid, y_valid = build_design_matrices_seq2seq(
#     'valid', CONFIG, vocab, delimiter=" ", aug_shift_ms=[], max_num_bins=60, remove_unks=False)

# Unlike the first load, validation now uses remove_unks=True.
print("Loading validation data")
x_valid, y_valid = build_design_matrices_seq2seq(
    'valid', CONFIG, vocab, delimiter=" ", aug_shift_ms=[], max_num_bins=60, remove_unks=True)

In [None]:
train_seq_lengths, valid_seq_lengths = seq_len_stats(x_train, x_valid)

In [None]:
def replace_words(data, index_to_word=None):
    """Build a DataFrame from label sequences and map columns 1 and 2 to words.

    Args:
        data: row-wise label sequences accepted by ``pd.DataFrame``; columns
            1 and 2 are the ones translated.
        index_to_word: mapping of label index -> word. When omitted, the
            notebook-level ``i2w`` mapping is used.

    Returns:
        The DataFrame with columns 1 and 2 replaced by their mapped values;
        values without an entry in the mapping are left unchanged.
    """
    mapping = i2w if index_to_word is None else index_to_word
    df_y_train = pd.DataFrame(data)
    # Assign the result back instead of calling replace(..., inplace=True) on
    # the selected column: under pandas Copy-on-Write that chained in-place
    # call no longer updates the parent DataFrame.
    for col in (1, 2):
        df_y_train[col] = df_y_train[col].replace(mapping)

    return df_y_train


def bigram_freq_excel(data, word2freq, i2w, filename, ref_data=None):
    """Tabulate bigram counts, write them to an Excel file and return the table.

    Args:
        data: label sequences passed on to ``replace_words``.
        word2freq: mapping used to fill the 'VF1'/'VF2' columns.
        i2w: accepted for interface compatibility; not read by this function.
        filename: name of the Excel file created under ``CONFIG["SAVE_DIR"]``.
        ref_data: optional table that is left-merged on columns 1 and 2, with
            the suffixes '_valid' (this table) and '_train' (``ref_data``).

    Returns:
        The resulting DataFrame.
    """
    word_cols = (1, 2)

    # One row per distinct (first word, second word) pair with its count.
    table = replace_words(data)
    table = table.groupby([1, 2]).size().reset_index(name='Count')

    # 'BF<n>': occurrences of the word in column n of this grouped table.
    for col in word_cols:
        per_word = dict(table[col].value_counts())
        table[f'BF{col}'] = table[col].replace(per_word)
    # 'VF<n>': the word looked up in word2freq.
    for col in word_cols:
        table[f'VF{col}'] = table[col].replace(word2freq)

    if ref_data is not None:
        table = table.merge(ref_data, on=[1, 2],
                            suffixes=('_valid', '_train'), how='left')

    table.to_excel(os.path.join(CONFIG["SAVE_DIR"], filename), index=False)

    # Number of distinct words per position ...
    for col in word_cols:
        print(len(table[col].unique()))

    # ... and the word2freq keys that never occur in that position.
    known_words = set(word2freq.keys())
    for col in word_cols:
        print(known_words - set(table[col].unique()))

    return table


raw_train_df = bigram_freq_excel(y_train, word2freq, i2w, "625_bi-gram-freq-train.xlsx")
_ = bigram_freq_excel(y_valid, word2freq, i2w, "625_bi-gram-freq-valid.xlsx", ref_data=raw_train_df)

In [None]:
# def figure6(SAVE_DIR, df, word2freq, string):
#     sorted_w2f = sorted(word2freq.items())
#     l = [a[1] for a in sorted_w2f if a[1] != -1]
#     plt.plot(df[1].value_counts().sort_index(), marker='.', markersize = 2.5, linewidth=0.25)
#     plt.plot(df[2].value_counts().sort_index(), marker='.', markersize = 2.5, linewidth=0.25)
#     plt.plot(l, marker='.', markersize = 2.5, linewidth=0.25, color='k')
#     plt.xticks(list(range(0, len(vocab), 50)), list(range(0, len(vocab), 50)))
#     plt.legend(['First word', 'Second Word', 'Actual'])
#     plt.xlabel('Word Index', fontsize=14)
#     plt.ylabel('Frequency', fontsize=14)
#     plt.yscale('log')
#     plt.title(f'Frequency of each word in the bigram ({string})', fontsize=14)
#     plt.savefig(os.path.join(SAVE_DIR, string + '_bigram-Freq.svg'))
#     plt.show()
    
# figure6(SAVE_DIR, train_df, word2freq, 'Training')
# figure6(SAVE_DIR, train_df, word2freq, 'Validation')

#### Converting train and validation data to Loader objects

In [None]:
class Brain2enDataset(Dataset):
    """Brainwave-to-English Dataset.

    PyTorch Dataset wrapper around paired (signal, label-sequence) examples.
    Examples are ordered by (signal length, label length, original index) and
    those outside the configured length limits are dropped.

    Note: this definition replaces the ``Brain2enDataset`` name imported at
    the top of the notebook.
    """
    def __init__(self, signals, labels, max_signal_len=384,
                 min_label_len=4, max_label_len=128):
        """
        Args:
            signals (list): brainwave examples (numpy arrays).
            labels (list): english examples (sequences of label indices).
            max_signal_len (int): examples with a longer signal are skipped.
            min_label_len (int): examples with a shorter label are skipped.
            max_label_len (int): examples with a longer label are skipped.
        """
        assert (len(signals) == len(labels))
        # (original index, signal length, label length) for every example.
        indices = [(i, len(signals[i]), len(labels[i]))
                   for i in range(len(signals))]
        indices.sort(key=lambda x: (x[1], x[2], x[0]))
        self.examples = []
        # Longest signal / label among the examples that were kept.
        self.max_seq_len = 0
        self.max_sent_len = 0
        # How often each label index occurs in the kept examples.
        self.train_freq = Counter()
        skipped = 0
        for orig_idx, sig_len, lab_len in indices:
            if (sig_len > max_signal_len or lab_len < min_label_len
                    or lab_len > max_label_len):
                skipped += 1
                continue
            lab = labels[orig_idx]
            self.train_freq.update(lab)
            lab = torch.tensor(lab).long()
            self.examples.append(
                (torch.from_numpy(signals[orig_idx]).float(), lab))
            self.max_seq_len = max(self.max_seq_len, sig_len)
            self.max_sent_len = max(self.max_sent_len, len(lab))
        print("Skipped", skipped, "examples")

    def __len__(self):
        """Number of examples that passed the length filters."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the (signal tensor, label tensor) pair(s) at ``idx``."""
        return self.examples[idx]

In [None]:
train_ds = Brain2enDataset(x_train, y_train)
print("Number of training signals: ", len(train_ds))
valid_ds = Brain2enDataset(x_valid, y_valid)
print("Number of validation signals: ", len(valid_ds))

In [None]:
class MyCollator(object):
    """Collate (signal, label) pairs into padded batches plus decoder masks."""

    def __init__(self, CONFIG, vocabulary):
        """Keep the config, the vocabulary and the configured pad token."""
        self.CONFIG = CONFIG
        self.vocabulary = vocabulary
        self.pad_token = CONFIG["pad_token"]

    def __call__(self, batch):
        """Pad a list of (signal, label) pairs and build the decoder tensors.

        Returns:
            (src, trg, trg_y, pos_mask, pad_mask): zero-padded signals, the
            one-hot labels without their last step, the label indices
            without their first step, and the two masks from ``masks``.
        """
        signals = [example[0] for example in batch]
        targets = [example[1] for example in batch]
        pad_id = self.vocabulary[self.pad_token]

        src = pad_sequence(signals, batch_first=True, padding_value=0.)
        labels = pad_sequence(targets, batch_first=True, padding_value=pad_id)

        # One-hot encode every label position over the vocabulary.
        n_rows, n_steps = labels.size(0), labels.size(1)
        one_hot = torch.zeros(n_rows, n_steps, len(self.vocabulary))
        one_hot.scatter_(2, labels.unsqueeze(-1), 1)

        # Decoder input drops the last step; the target drops the first.
        trg = one_hot[:, :-1, :]
        trg_y = labels[:, 1:]
        pos_mask, pad_mask = self.masks(trg_y)
        return src, trg, trg_y, pos_mask, pad_mask

    def masks(self, labels):
        """Return the additive position mask and the boolean padding mask.

        The position mask has shape (1, T, T): 0.0 on and below the diagonal,
        -inf above it. The padding mask is True where ``labels`` equals the
        pad index.
        """
        steps = labels.size(1)
        visible = (torch.ones(steps, steps).tril() == 1).unsqueeze(0)
        pos_mask = torch.zeros(1, steps, steps).masked_fill(
            ~visible, float('-inf'))
        pad_mask = labels == self.vocabulary[self.pad_token]
        return pos_mask, pad_mask

In [None]:
batch = train_ds[:128]

In [None]:
CONFIG["pad_token"]

In [None]:
src = pad_sequence([batch[i][0] for i in range(len(batch))],
                           batch_first=True,
                           padding_value=0.)
labels = pad_sequence([batch[i][1] for i in range(len(batch))],
                              batch_first=True,
                              padding_value=vocab[CONFIG["pad_token"]])

In [None]:
trg = torch.zeros(labels.size(0), labels.size(1),
                  len(vocab)).scatter_(
                      2, labels.unsqueeze(-1), 1)

In [None]:
trg.shape

In [None]:
trg, trg_y = trg[:, :-1, :], labels[:, 1:]

In [None]:
trg_y

In [None]:
labels = trg_y

In [None]:
pos_mask = (torch.triu(torch.ones(labels.size(1),
                                  labels.size(1))) == 1).transpose(
                                      0, 1).unsqueeze(0)

In [None]:
pos_mask

In [None]:
pos_mask = pos_mask.float().masked_fill(pos_mask == 0,
                                        float('-inf')).masked_fill(
                                            pos_mask == 1, float(0.0))

In [None]:
pos_mask

In [None]:
pad_mask = labels == vocab[CONFIG["pad_token"]]

In [None]:
pad_mask

In [None]:
my_collator = MyCollator(CONFIG, vocab)
train_dl = data.DataLoader(train_ds,
                           batch_size=args.batch_size,
                           shuffle=True,
                           num_workers=CONFIG["num_cpus"],
                           collate_fn=my_collator)
valid_dl = data.DataLoader(valid_ds,
                           batch_size=args.batch_size,
                           num_workers=CONFIG["num_cpus"],
                           collate_fn=my_collator)

#### Creating a Model

In [None]:
# Positional constructor arguments for each built-in architecture.
DEFAULT_MODELS = {
    "ConvNet10": (len(vocab), ),
    "PITOM": (len(vocab), sum(args.max_electrodes)),
    "MeNTALmini":
    (sum(args.max_electrodes), len(vocab), args.tf_dmodel, args.tf_nhead,
     args.tf_nlayer, args.tf_dff, args.tf_dropout),
    "MeNTAL": (sum(args.max_electrodes), len(vocab), args.tf_dmodel,
               args.tf_nhead, args.tf_nlayer, args.tf_dff, args.tf_dropout)
}

# Create model
if args.init_model is None:
    if args.model in DEFAULT_MODELS:
        print("Building default model: %s" % args.model, end="")
        # The class is looked up by name among the notebook's globals.
        model_class = globals()[args.model]
        model = model_class(*(DEFAULT_MODELS[args.model]))
    else:
        # Only the architectures listed above can be built here.
        print("Building custom model: %s" % args.model, end="")
        sys.exit(1)
else:
    # Resume from a previously saved model. The save directory comes from
    # CONFIG (no cell visible here defines a bare SAVE_DIR name), and the
    # path is built the same way as the one the training cell saves to.
    model_name = "%s%s.pt" % (CONFIG["SAVE_DIR"], args.model)
    if os.path.isfile(model_name):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model = torch.load(model_name)
        # Unwrap a parallel wrapper if the checkpoint was saved with one.
        model = model.module if hasattr(model, 'module') else model
        print("Loaded initial model: %s " % args.model)
    else:
        print("No models found in: ", CONFIG["SAVE_DIR"])
        sys.exit(1)
print(" with %d trainable parameters" %
      sum([p.numel() for p in model.parameters() if p.requires_grad]))
sys.stdout.flush()

In [None]:
criterion = nn.CrossEntropyLoss()
step_size = int(math.ceil(len(train_ds) / args.batch_size))
optimizer = AdamW(model.parameters(),
                  lr=args.lr,
                  weight_decay=args.weight_decay)
scheduler = None

In [None]:
print("Training on %d GPU(s) with batch_size %d for %d epochs" %
      (args.gpus, args.batch_size, args.epochs))
print("=" * CONFIG["print_pad"])
sys.stdout.flush()

In [None]:
best_val_loss = float("inf")
best_model = model
history = {
    'train_loss': [],
    'train_acc': [],
    'valid_loss': [],
    'valid_acc': []
}
""" train_loss_compute = SimpleLossCompute(criterion,
                                       opt=optimizer,
                                       scheduler=scheduler)
valid_loss_compute = SimpleLossCompute(criterion, opt=None, scheduler=None)
"""
epoch = 0
model_name = "%s%s.pt" % (CONFIG["SAVE_DIR"], args.model)

In [None]:
CONFIG["SAVE_DIR"]

In [None]:
model_name

In [None]:
# training
data_iter = train_dl
device = DEVICE
opt = optimizer
scheduler = scheduler
seq2seq = True
pad_idx = vocab[CONFIG["pad_token"]]

In [None]:
batch = iter(data_iter)
batch = next(batch)

In [None]:
type(batch)

In [None]:
len(batch)

In [None]:
batch[0].shape

In [None]:
print('hello')

In [None]:
# The manual forward pass below moves its inputs to `device`, so the model
# has to live there too; no earlier cell moves it. Module.to() is a no-op
# when the model is already on that device.
model.to(device)
# Switch to training mode and reset the running statistics for the walkthrough.
model.train()
start_time = time.time()
total_loss = 0.
total_acc = 0.
count, batch_count = 0, 0
CLIP_NORM = 1.0

In [None]:
# Prevent gradient accumulation
model.zero_grad()
src = batch[0].to(device)
trg = batch[1].long().to(device)

trg_y = batch[2].long().to(device)
trg_pos_mask, trg_pad_mask = batch[3].to(device), batch[4].to(
    device)

In [None]:
# Perform loss computation during forward pass for parallelism
out, trg_y, loss = model.forward(src, trg, trg_pos_mask,
                                 trg_pad_mask, trg_y, criterion)

In [None]:
out.shape

In [None]:
trg_y.shape

In [None]:
trg_y

In [None]:
idx = (trg_y != pad_idx).nonzero(as_tuple=True)

In [None]:
idx

In [None]:
total_loss += loss.data.item()

In [None]:
out1 = out[idx]
trg_y1 = trg_y[idx]

In [None]:
out.shape

In [None]:
out1.shape

In [None]:
trg_y1.shape

In [None]:
trg_y1

In [None]:
print("\nTraining on %d GPU(s) with batch_size %d for %d epochs" %
      (args.gpus, args.batch_size, args.epochs))
sys.stdout.flush()

# Book-keeping for model selection and the learning curves.
best_val_loss = float("inf")
best_model = model
history = {
    metric: []
    for metric in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc')
}
# Disabled alternative that wrapped the criterion in SimpleLossCompute:
#   train_loss_compute = SimpleLossCompute(criterion,
#                                          opt=optimizer,
#                                          scheduler=scheduler)
#   valid_loss_compute = SimpleLossCompute(criterion, opt=None, scheduler=None)
epoch = 0
model_name = "%s%s.pt" % (CONFIG["SAVE_DIR"], args.model)
# Disabled diagnostic that printed the relative label frequencies:
#   totalfreq = float(sum(train_ds.train_freq.values()))
#   print(
#       sorted(((i2w[l], f / totalfreq)
#               for l, f in train_ds.train_freq.most_common()),
#              key=lambda x: -x[1]))
# Run training and validation for args.epochs epochs
lr = args.lr
for epoch in range(1, args.epochs + 1):
    epoch_start_time = time.time()
    print(f'Epoch: {epoch:02}')

    # Mode arguments shared by train() and valid(); the pad index is only
    # looked up for seq2seq models.
    mode_kwargs = {
        'seq2seq': not classify,
        'pad_idx': vocab[CONFIG["pad_token"]] if not classify else -1,
    }

    print('\tTrain: ', end='')
    train_loss, train_acc = train(train_dl,
                                  model,
                                  criterion,
                                  list(range(args.gpus)),
                                  DEVICE,
                                  optimizer,
                                  scheduler=scheduler,
                                  **mode_kwargs)
    # Report the learning rate of the first parameter group that has one.
    groups_with_lr = [g for g in optimizer.param_groups if 'lr' in g]
    if groups_with_lr:
        print(' | lr {:1.2E}'.format(groups_with_lr[0]['lr']))
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    print('\tValid: ', end='')
    with torch.no_grad():
        valid_loss, valid_acc = valid(valid_dl,
                                      model,
                                      criterion,
                                      DEVICE,
                                      temperature=args.temp,
                                      **mode_kwargs)
    history['valid_loss'].append(valid_loss)
    history['valid_acc'].append(valid_acc)

    # Store best model so far
    if valid_loss < best_val_loss:
        best_model, best_val_loss = model, valid_loss
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Save the wrapped module when the model carries one.
            model_to_save = getattr(best_model, 'module', best_model)
            torch.save(model_to_save, model_name)
        sys.stdout.flush()

#         # Additional Info when using cuda
#         if DEVICE.type == 'cuda':
#             print('Memory Usage:')
#             for i in range(args.gpus):
#                 max_alloc = round(
#                     torch.cuda.max_memory_allocated(i) / 1024**3, 1)
#                 cached = round(torch.cuda.memory_cached(i) / 1024**3, 1)
#                 print(f'GPU: {i} Allocated: {max_alloc}G Cached: {cached}G')

#         # if epoch > 10 and valid_loss > max(history['valid_loss'][-3:]):
#         #     lr /= 2.
#         #     for param_group in optimizer.param_groups:
#         #         param_group['lr'] = lr

In [None]:
# Plot loss,accuracy vs. time and save figures
plot_training(history, CONFIG["SAVE_DIR"], title="%s_lr%s" % (args.model, args.lr))

#### Post-processing

In [None]:
device = DEVICE
print("Evaluating predictions on test set")
# Load best model
model = torch.load(model_name)
if args.gpus:
    model.to(device)

softmax = nn.Softmax(dim=1)

In [None]:
vocab[CONFIG["begin_token"]]

In [None]:
len(vocab)

In [None]:
trg_y.size(0)

In [None]:
# NOTE(review): no earlier cell assigns `y`; it is first created inside the
# prediction loop further down, so this cell fails on a fresh top-to-bottom run.
y.shape

In [None]:
# NOTE(review): `train_bi_preds` is first assigned in the next cell, so this
# cell fails on a fresh top-to-bottom run.
train_bi_preds.shape

In [None]:
# Accumulators read by the evaluation cells below.
all_preds, categorical, all_labs = [], [], []

# Per-position output distributions; the sequence length is taken from the
# `trg_y` left over from the manual training step above.
train_bi_preds = torch.zeros(len(train_ds), trg_y.shape[1], len(vocab))
valid_bi_preds = torch.zeros(len(valid_ds), trg_y.shape[1], len(vocab))

# Calculate all predictions on test set
with torch.no_grad():
    model.eval()

    for enum, batch in enumerate(valid_dl):

        src = batch[0].to(device)
        trg_y = batch[2].long().to(device)
        trg_pos_mask = batch[3].to(device).squeeze()
        trg_pad_mask = batch[4].to(device)

        memory = model.encode(src)
        # Start every sequence with the one-hot begin token.
        y = torch.zeros(src.size(0), 1, len(vocab)).long().to(device)
        y[:, :, vocab[CONFIG["begin_token"]]] = 1

        bi_out = torch.zeros(len(batch[0]), trg_y.shape[1], len(vocab))
        # Greedy decoding: feed the predictions so far back into the decoder.
        for i in range(trg_y.size(1)):
            out = model.decode(memory, y,
                               trg_pos_mask[:y.size(1), :y.size(1)],
                               trg_pad_mask[:, :y.size(1)])[:, -1, :]
            out = softmax(out / args.temp)
            bi_out[:, i, :] = out
            temp = torch.zeros(src.size(0), len(vocab)).long().to(device)
            temp = temp.scatter_(1,
                                 torch.argmax(out, dim=1).unsqueeze(-1), 1)
            y = torch.cat([y, temp.unsqueeze(1)], dim=1)

        # Drop the begin token so `y` lines up with `trg_y`.
        y = y[:, 1:, :]
        valid_bi_preds[enum*args.batch_size:(enum+1)*args.batch_size, :, :] = bi_out

        # Keep only the non-padding target positions.
        idx = (trg_y != vocab[CONFIG["pad_token"]]).nonzero(as_tuple=True)
        lab = trg_y[idx]
        cat = torch.zeros((lab.size(0), len(vocab)),
                          dtype=torch.long).to(lab.device)
        cat = cat.scatter_(1, lab.unsqueeze(-1), 1)
        # Re-enabled: the cells below read all_preds / categorical / all_labs
        # as arrays, so they must be filled here.
        all_preds.extend(y[idx].cpu().numpy())
        categorical.extend(cat.cpu().numpy())
        all_labs.extend(lab.cpu().numpy())

all_preds = np.array(all_preds)
categorical = np.array(categorical)
all_labs = np.array(all_labs)
print("Calculated predictions")

# Re-enabled: evaluate_topk / evaluate_roc below are called with these names.
train_freq = train_ds.train_freq
# Left disabled: `i2w` already comes from get_std_vocab in this notebook.
# if CONFIG["vocabulary"] == 'spm':
#     i2w = {i: vocab.IdToPiece(i) for i in range(len(vocab))}
markers = [
    CONFIG["begin_token"], CONFIG["end_token"], CONFIG["oov_token"],
    CONFIG["pad_token"]
]

In [None]:
valid_bi_preds

In [None]:
print(all_preds.shape)
print(categorical.shape)
print(all_labs.shape)

In [None]:
# Evaluate top-k
print("Evaluating top-k")
sys.stdout.flush()
res = evaluate_topk(all_preds,
                    all_labs,
                    i2w,
                    train_freq,
                    CONFIG["SAVE_DIR"],
                    suffix='-val',
                    min_train=args.vocab_min_freq,
                    tokens_to_remove=markers)

In [None]:
# Evaluate ROC-AUC
print("Evaluating ROC-AUC")
sys.stdout.flush()
# Merge the ROC-AUC results into the top-k results from the previous cell.
res.update(
    evaluate_roc(all_preds,
                 categorical,
                 i2w,
                 train_freq,
                 CONFIG["SAVE_DIR"],
                 do_plot=not args.no_plot,
                 min_train=args.vocab_min_freq,
                 tokens_to_remove=markers))
pprint(res.items())
print("Saving results")
# os.path.join (as used for every figure) gives the same path whether or
# not SAVE_DIR ends with a path separator.
with open(os.path.join(CONFIG["SAVE_DIR"], "results.json"), "w") as fp:
    json.dump(res, fp, indent=4)