In [1]:
import json
import os
import re
import random
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.utils.data as data
from torch.nn.utils import clip_grad_norm_
from tensorboardX import SummaryWriter
import torch.utils
import torchtext
from torchtext.data import Example

import numpy as np
from PIL import Image
import h5py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
SRC_SG_FILE = "/raid6/home/ramraj/SG-VQA/data/scene_graphs.json"
SRC_QA_FILE = "/raid6/home/ramraj/SG-VQA/data/question_answers.json"
SRC_IMG_DIR = "/raid6/home/ramraj/SG-VQA/data/VG_100K_ONE/"

IMG_FEATURES_FILE = "/raid6/home/ramraj/SG-VQA/intermediate-data/one_img_features.h5"
SRC_SPLIT_PATH = "/raid6/home/ramraj/SG-VQA/data/vg_split.json"

RESULTS_DIR = "./results_debug_dir/"
SAVE_DIR = "./saved_debug_dir/"

In [3]:
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(SAVE_DIR, exist_ok=True)

In [4]:
SOS_WORD = '<SOS>'
EOS_WORD = '<EOS>'
PAD_WORD = '<PAD>'

MAX_Q_LEN = 27
MAX_A_LEN = 24

# ========================= model
question_token_embedding_dim = 300
attn_model = 'concat' #attn_model = 'general' #attn_model = 'concat' 'dot'
hidden_size = 200
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1

# ========================= optimizer
learning_rate = 0.0001
decoder_learning_ratio = 5.0
step_size = 1
gamma = 0.001
clip = 50.0
teacher_forcing_ratio = 1.0

# ========================= training
batch_size=32
data_worker=1
shuffle = True
resume_epoch = 0 # saved_model = "./saved/...."
num_epochs = 1
save_step = 1
print_every = 50
save_every = 1

In [5]:
sg_data = json.load(open(SRC_SG_FILE, 'r'))
qas_data = json.load(open(SRC_QA_FILE, 'r'))
len(sg_data), len(qas_data)

(108077, 108077)

# Vocab`

In [7]:
r"""
generated
1. question fixed length (including EOS) = MAX_Q_LEN + 1
2. answer fixed length (no EOS) = MAX_A_LEN
"""

# 1. Define the fields
QUE_TEXT = torchtext.data.Field(sequential=True, 
                                tokenize=lambda x: x.split(),
                                # init_token=SOS_WORD,
                                eos_token=EOS_WORD,
                                pad_token=PAD_WORD,
                                include_lengths=True,
                                batch_first=True,
                                fix_length=MAX_Q_LEN + 1, # ['who', 'are', 'you', 'EOS'] ==> +1 for 'EOS'
                                lower=True) # todo: do i need init_token, eos_token, preprocessing ?

ANS_TEXT = torchtext.data.Field(sequential=True, 
                                tokenize=lambda x: x.split(),                                
                                # eos_token=EOS_WORD, # init_token=SOS_WORD,
                                pad_token=PAD_WORD,
                                batch_first=True,
                                fix_length=MAX_A_LEN, # ['SOS', i', 'am', 'a', 'pilot']
                                lower=True) # todo: do i need init_token, eos_token, preprocessing ?


FIELDS = [('ans_text', ANS_TEXT), ('que_text', QUE_TEXT)]

example_texts = []

for idx, qas_ins in enumerate(qas_data):

    for qa_idx, qa in enumerate(qas_ins['qas']):
        a = qa['answer']
        a = re.sub(r'[^\w\s]', ' ', a).lower().strip()        
        
        q = qa['question']
        q = re.sub(r'[^\w\s]', ' ', q).lower().strip()        
        
        example_texts.append( Example.fromlist([a, q] , FIELDS ) )

    if idx % 10000 == 0:
        print("Finished ", idx)
        
torchtext_dataset = torchtext.data.Dataset(example_texts, fields=FIELDS)

ANS_TEXT.build_vocab(torchtext_dataset, vectors='glove.6B.300d', vectors_cache='../cache')
QUE_TEXT.build_vocab(torchtext_dataset, vectors='glove.6B.300d', vectors_cache='../cache')

print("======= Question ======")
print(ANS_TEXT.vocab.vectors.size())

print("======= Answer ========")
print(QUE_TEXT.vocab.vectors.size())

PAD_token = QUE_TEXT.vocab.stoi[PAD_WORD]
SOS_token = QUE_TEXT.vocab.stoi[SOS_WORD]
EOS_token = QUE_TEXT.vocab.stoi[EOS_WORD]



Finished  0
Finished  10000
Finished  20000
Finished  30000
Finished  40000
Finished  50000
Finished  60000
Finished  70000
Finished  80000
Finished  90000
Finished  100000


../cache/glove.6B.zip: 862MB [06:30, 2.21MB/s]                               
100%|█████████▉| 399999/400000 [00:59<00:00, 6684.93it/s]


torch.Size([23335, 300])
torch.Size([18046, 300])


# Data Loading`

In [8]:
def decouple_q_and_a(qas_data, mode_id_list):
    structured_qas = {"qas_id": [], "qa_id": [], "questions": [], "answers": []}
    for qas_idx, qas in enumerate(qas_data):
        if not qas['id'] in mode_id_list:
            continue
        for qa in qas['qas']:
            structured_qas['qas_id'].append(qas['id'])
            structured_qas['qa_id'].append(qa['qa_id'])
            structured_qas['questions'].append(qa['question'])
            structured_qas['answers'].append(qa['answer'])
    
    return structured_qas


with open(SRC_SPLIT_PATH, 'r') as split_df:
    data_ids = json.load(split_df)
    train_data_ids = data_ids['train']
    valid_data_ids = data_ids['val']

train_structured_qas = decouple_q_and_a(qas_data, train_data_ids)
print("Finished loading train_structured_qas ...")
valid_structured_qas = decouple_q_and_a(qas_data, valid_data_ids)
print("Finished loading valid_structured_qas ...")

Finished loading train_structured_qas ...
Finished loading valid_structured_qas ...


In [9]:
PAD_token = QUE_TEXT.vocab.stoi[PAD_WORD]
SOS_token = QUE_TEXT.vocab.stoi[SOS_WORD]
EOS_token = QUE_TEXT.vocab.stoi[EOS_WORD]

def prepare_questions(questions_list):
    for q in questions_list:
        q = re.sub(r'[^\w\s]', ' ', q).lower().strip()
        yield q.split()

def prepare_answers(answers_list):
    for a in answers_list:
        a = re.sub(r'[^\w\s]', ' ', a).lower().strip()
        yield a.split()
        
class VQA(torch.utils.data.Dataset):
    """ VQA dataset, open-ended """
    def __init__(self,
                image_features_path,
                structured_qas,
                question_field,
                answer_field,
                data_mode="train"):
        super(VQA, self).__init__()

        print("Step1 : Data loading")
        print("....... {} mode number of data samples : {}".format(data_mode, len(structured_qas['qas_id'])))

        # ========================= Load Answer Vocab =================
        self.answers_vocab = answer_field.vocab
        self.answers = list(prepare_answers( structured_qas['answers'] ))
        self.answers = answer_field.pad(self.answers)
        # print(self.answers[:2])
        self.answers = answer_field.numericalize(self.answers)
        # print(self.answers[:2])
        print("....... VG Answer data have been PREPARED & ENCODED ...")

        # ========================= Load Question Vocab =================
        self.questions_vocab = question_field.vocab
        self.questions = list(prepare_questions( structured_qas['questions'] ))
        self.questions = question_field.pad(self.questions)
        # print(self.questions[0][:2])
        # print(self.questions[1][:2])
        self.questions = question_field.numericalize(self.questions)
        # print(self.questions[0][:2])
        # print(self.questions[1][:2])
        print("....... VG Question data have been PREPARED & ENCODED ...")
    
        self.image_features_path = image_features_path
        self.features_file = h5py.File(self.image_features_path, 'r') 
        self.vg_ids = structured_qas['qas_id']
        print("....... Done Data loading ...")
    
    def __len__(self):
        assert len(self.answers) != len(self.questions), "mismatched questions & answer sample size"
        return len(self.answers)
    
    def __getitem__(self, item):
        q = self.questions[0][item]
        q_length = self.questions[1][item]
        a = self.answers[item]
        
        img_id = self.vg_ids[item]
        v = self._load_image(img_id)

        mask = self._binaryMatrix(q)        
        mask = torch.BoolTensor(mask)
        
        return v, q, a, mask, q_length, item
    
    def _binaryMatrix(self, seq, value=PAD_token):
        m = []
        for token in seq:
            if token == PAD_token:
                m.append(0)
            else:
                m.append(1)
        return m

    def _encode_answers(self, answer):
        answer_return = [self.answers_vocab.stoi[SOS_WORD]] # SOS token
        for a in answer:
            answer_return.append(self.answers_vocab.stoi[a])
        answer_return.append(self.answers_vocab.stoi[EOS_WORD]) # EOS token
        return answer_return
    
    def _encode_questions(self, question):
        question_return = [self.questions_vocab.stoi[SOS_WORD]] # SOS token
        for q in question:
            question_return.append(self.questions_vocab.stoi[q])
        question_return.append(self.questions_vocab.stoi[EOS_WORD]) # EOS token
        return question_return, len(question) + 2
    
    @property
    def max_que_length(self):
        if not hasattr(self, '_ques_max_length'):
            self._ques_max_length = max(map(len, self.questions))
        return self._ques_max_length
    @property
    def max_ans_length(self):
        if not hasattr(self, '_ans_max_length'):
            self._ans_max_length = max(map(len, self.answers))
        return self._ans_max_length


    def _load_image(self, img_id):
        if not hasattr(self, 'features_file'):
            self.features_file = h5py.File(self.image_features_path, 'r')
        img = np.array(self.features_file.get( str(img_id)) )
        return torch.from_numpy(img).float()

In [10]:
train_vqa_dataset = VQA(IMG_FEATURES_FILE,
                        train_structured_qas,
                        QUE_TEXT,
                        ANS_TEXT,
                        data_mode='train')

valid_vqa_dataset = VQA(IMG_FEATURES_FILE,
                        train_structured_qas,
                        QUE_TEXT,
                        ANS_TEXT,
                        data_mode='val')

Step1 : Data loading
....... train mode number of data samples : 1156933
....... VG Answer data have been PREPARED & ENCODED ...
....... VG Question data have been PREPARED & ENCODED ...
....... Done Data loading ...
Step1 : Data loading
....... val mode number of data samples : 1156933
....... VG Answer data have been PREPARED & ENCODED ...
....... VG Question data have been PREPARED & ENCODED ...
....... Done Data loading ...


In [11]:
def collate_fn(batch):
    # put question lengths in descending order so that we can use packed sequences later
    batch.sort(key=lambda x: x[-2], reverse=True)
    return data.dataloader.default_collate(batch)



train_data_loader = torch.utils.data.DataLoader(dataset=train_vqa_dataset,
                                                batch_size=batch_size,
                                                shuffle=shuffle,
                                                collate_fn=collate_fn,
                                                drop_last=True,
                                                num_workers=data_worker)
valid_data_loader = torch.utils.data.DataLoader(dataset=valid_vqa_dataset,
                                                batch_size=batch_size,
                                                shuffle=shuffle,
                                                collate_fn=collate_fn,
                                                drop_last=True,
                                                num_workers=data_worker)

print(len(train_data_loader), len(valid_data_loader))
print(train_vqa_dataset.__len__(), valid_vqa_dataset.__len__())

n_train_batches = len(train_data_loader)
n_val_batches = len(valid_data_loader)

36154 36154
1156933 1156933


# Model`

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderRNN(nn.Module):
    def __init__(self, q_word_emb_dim, hidden_size, q_embedding, question_length, padding_idx, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        
        self.q_len = question_length
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = q_embedding
        self.gru = nn.GRU(q_word_emb_dim, hidden_size, n_layers,  # todo: what is input size ?
                          dropout=(0 if n_layers == 1 else dropout), 
                          bidirectional=True)
        self.padding_idx = padding_idx
        self.img_encoder_lin = nn.Linear(512, hidden_size, bias=True) # TODO: replace with constant

    def forward(self, img, input_seq, input_lengths, hidden=None): 
        
        embedded = self.embedding(input_seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)        
        outputs, hidden = self.gru(packed, hidden) # output: (seq_len, batch, hidden*n_dir)        
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs,
                                                            batch_first=False, 
                                                            padding_value=self.padding_idx,
                                                            total_length=self.q_len)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs (1, batch, hidden)
        
        hidden = torch.mul(hidden, self.img_encoder_lin(img))
        return outputs, hidden


class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [13]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, q_word_emb_dim, attn_model, a_embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = a_embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(q_word_emb_dim, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        # Choose attention model
        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_seq, last_hidden, encoder_outputs):
        # Note: we run this one step at a time

        # Get the embedding of the current input word (last output word)
        embedded = self.embedding(input_seq)
        embedded = self.embedding_dropout(embedded) #[1, 64, 512]
        if(embedded.size(0) != 1):
            raise ValueError('Decoder input sequence length should be 1')

        # Get current hidden state from input word and last hidden state
        rnn_output, hidden = self.gru(embedded, last_hidden)

        # Calculate attention from current RNN state and all encoder outputs;
        # apply to encoder outputs to get weighted average
        attn_weights = self.attn(rnn_output, encoder_outputs) #[64, 1, 14]
        # encoder_outputs [14, 64, 512]
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) #[64, 1, 512]

        # Attentional vector using the RNN hidden state and context vector
        # concatenated together (Luong eq. 5)
        rnn_output = rnn_output.squeeze(0) #[64, 512]
        context = context.squeeze(1) #[64, 512]
        concat_input = torch.cat((rnn_output, context), 1) #[64, 1024]
        concat_output = torch.tanh(self.concat(concat_input)) #[64, 512]

        # # Finally predict next token (Luong eq. 6, without softmax)
        # output = self.out(concat_output) #[64, output_size]

        # # Return final output, hidden state, and attention weights (for visualization)
        # return output, hidden, attn_weights

        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

In [15]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

def format_time(elapsed_time):
    """
    Takes a time in seconds and returns a string hh:mm:ss
    """
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed_time)))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [16]:
question_embedding = nn.Embedding.from_pretrained(QUE_TEXT.vocab.vectors, padding_idx=PAD_token, freeze=False)
answer_embedding = nn.Embedding.from_pretrained(ANS_TEXT.vocab.vectors, padding_idx=PAD_token, freeze=False)

encoder = EncoderRNN(q_word_emb_dim=question_token_embedding_dim,
                     hidden_size=hidden_size,
                     q_embedding=question_embedding,
                     question_length=MAX_Q_LEN + 1, # FIXED LEN OF QUESTION
                     padding_idx=PAD_token,
                     n_layers=encoder_n_layers,
                     dropout=dropout)

decoder = LuongAttnDecoderRNN(q_word_emb_dim=question_token_embedding_dim,
                              attn_model=attn_model,
                              a_embedding=answer_embedding,
                              hidden_size=hidden_size,
                              output_size=len(ANS_TEXT.vocab),
                              n_layers=decoder_n_layers,
                              dropout=dropout)

encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Models built and ready to go!


In [18]:
# lr_default = 1e-3 # if eval_loader is not None else 7e-4
# lr_decay_step = 2
# lr_decay_rate = .25
# lr_decay_epochs = range(10,20,lr_decay_step) # if eval_loader is not None else range(10,20,lr_decay_step)
# gradual_warmup_steps = [0.5 * lr_default, 1.0 * lr_default, 1.5 * lr_default, 2.0 * lr_default]
# # saving_epoch = 3
# grad_clip = .25

In [19]:
lr_default = 0.0001

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

print('Building optimizers ...') # Initialize optimizers
# encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate * decoder_learning_ratio) # remove '* decoder_learning_ratio'
# decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
encoder_optimizer = optim.Adamax(filter(lambda p: p.requires_grad, encoder.parameters()), lr=lr_default)
decoder_optimizer = optim.Adamax(filter(lambda p: p.requires_grad, decoder.parameters()), lr=lr_default)
    
# If you have cuda, configure cuda to call
for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

Building optimizers ...


In [20]:
def train(v, q, a, mask, q_len,
          max_target_len, max_question_len,
          encoder, decoder,
          encoder_optimizer, decoder_optimizer, batch_size, clip):
    
    target_variable = a.view(max_target_len, -1)
    mask = mask.view(max_question_len, -1)

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    v, q, target_variable, mask = v.to(device), q.to(device), target_variable.to(device), mask.to(device)
    # Lengths for rnn packing should always be on the cpu
    q_len = q_len.to("cpu")

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    # print(q_len)
    encoder_outputs, encoder_hidden = encoder(v, q, q_len)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [None]:
import logger

def instance_bce_with_logits(logits, labels, reduction='mean'):
    assert logits.dim() == 2

    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels, reduction=reduction)
    if reduction == 'mean':
        loss *= labels.size(1)
    return loss


def trainIters(model_name, encoder, decoder, encoder_optimizer, decoder_optimizer, 
               encoder_n_layers, decoder_n_layers, save_dir, resume_epoch, num_epochs, batch_size, 
               print_every, save_every, clip, loadFilename):

    print('Initializing ...')
    start_iteration = 1
    print_loss = 0

    print("Training...")
    total_t0 = time.time()
    for epoch in range(resume_epoch, num_epochs):
        
        total_loss = 0
        train_score = 0
        total_norm = 0
        count_norm = 0
        
#         if epoch < len(gradual_warmup_steps):
#             optim.param_groups[0]['lr'] = gradual_warmup_steps[epoch]
#             logger.write('gradual warmup lr: %.4f' % optim.param_groups[0]['lr'])
#         elif epoch in lr_decay_epochs:
#             optim.param_groups[0]['lr'] *= lr_decay_rate
#             logger.write('decreased lr: %.4f' % optim.param_groups[0]['lr'])
#         else:
#             logger.write('lr: %.4f' % optim.param_groups[0]['lr'])
            
        
        running_loss = 0.0
        running_corr_exp1 = 0
        t0 = time.time()
        for batch_idx, batch in enumerate(train_data_loader):
            v, q, a, mask, q_len, _ = batch
        
            # Run a training iteration with batch
            loss = train(v, q, a, mask, q_len,
                         MAX_A_LEN, MAX_Q_LEN + 1, # FIXED Q & A LEN
                         encoder, decoder,
                         encoder_optimizer, decoder_optimizer, batch_size, clip)
            print_loss += loss
        

            # Print progress
            if batch_idx % print_every == 0:
                print_loss_avg = print_loss / print_every
                elapsed = format_time(time.time() - t0)
                # print("Batch: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(batch_idx, batch_idx / n_train_batches * 100, print_loss_avg))
                print('| TRAIN SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f} | Elapsed: {:}'
                          .format(epoch+1, num_epochs, batch_idx, int(n_train_batches), print_loss_avg, elapsed))
                print_loss = 0
                
            # Log the loss and accuracy in an epoch.
            with open(os.path.join(SAVE_DIR, 'train-log-epoch.txt'), 'a') as f:
                f.write(str(epoch+1) + '\t'
                        + str(print_loss) + '\t'
                        + str(print_loss_avg))
                
            # break # TODO:

        # Save checkpoint
        if (epoch % save_every == 0):
            directory = os.path.join(save_dir, model_name, 'final-{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            # TODO: change above name 'remove final term'
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'epoch': epoch,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'question_embedding': question_embedding.state_dict(),
                'answer_embedding': answer_embedding.state_dict(),
            }, os.path.join(directory, '{}_{}.tar'.format(epoch, 'checkpoint')))
            
        training_epoch_time = format_time(time.time() - t0)
        print(training_epoch_time)
        
    print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))
            


            
# Run training iterations
print("Starting Training!")

model_name = 'cb_model'
loadFilename = False # todo

trainIters(model_name, encoder, decoder, encoder_optimizer, decoder_optimizer,
           encoder_n_layers, decoder_n_layers, SAVE_DIR, resume_epoch, num_epochs, batch_size,
           print_every, save_every, clip, loadFilename)

Starting Training!
Initializing ...
Training...
| TRAIN SET | Epoch [01/01], Step [0000/36154], Loss: 0.1996 | Elapsed: 0:00:01
| TRAIN SET | Epoch [01/01], Step [0050/36154], Loss: 9.9691 | Elapsed: 0:00:13
| TRAIN SET | Epoch [01/01], Step [0100/36154], Loss: 9.9692 | Elapsed: 0:00:25
| TRAIN SET | Epoch [01/01], Step [0150/36154], Loss: 9.9696 | Elapsed: 0:00:38
| TRAIN SET | Epoch [01/01], Step [0200/36154], Loss: 9.9709 | Elapsed: 0:00:50
| TRAIN SET | Epoch [01/01], Step [0250/36154], Loss: 9.9709 | Elapsed: 0:01:03
| TRAIN SET | Epoch [01/01], Step [0300/36154], Loss: 9.9679 | Elapsed: 0:01:15
| TRAIN SET | Epoch [01/01], Step [0350/36154], Loss: 9.9677 | Elapsed: 0:01:27
| TRAIN SET | Epoch [01/01], Step [0400/36154], Loss: 9.9698 | Elapsed: 0:01:40
| TRAIN SET | Epoch [01/01], Step [0450/36154], Loss: 9.9679 | Elapsed: 0:01:52
| TRAIN SET | Epoch [01/01], Step [0500/36154], Loss: 9.9686 | Elapsed: 0:02:04
| TRAIN SET | Epoch [01/01], Step [0550/36154], Loss: 9.9713 | Elapsed: 

# Eval`

In [None]:
# def correct_tokens(pred, true_tokens, padding_idx):
#     pred = pred.view(-1)
#     # true_tokens = true_tokens[:, 1:].contiguous()
#     non_padding = true_tokens.view(-1).ne(padding_idx)
#     num_correct = pred.eq(true_tokens.view(-1)).masked_select(non_padding).sum().item()
#     num_non_padding = non_padding.sum().item()
#     return num_non_padding, num_correct


# def eval(v, q, a, mask, q_len,
#           max_target_len, max_question_len,
#           encoder, decoder, batch_size):

#     target_variable = a.view(max_target_len, -1)
#     mask = mask.view(max_question_len, -1)

#     # Set device options
#     v, q, target_variable, mask = v.to(device), q.to(device), target_variable.to(device), mask.to(device)
#     # Lengths for rnn packing should always be on the cpu
#     q_len = q_len.to("cpu")

#     # Initialize variables
#     loss = 0
#     print_losses = []
#     n_totals = 0
    
#     all_tokens = torch.zeros([0], device=device, dtype=torch.long) # TODO: # Initialize tensors to append decoded words to
#     all_scores = torch.zeros([0], device=device)

#     # Forward pass through encoder
#     encoder_outputs, encoder_hidden = encoder(v, q, q_len)

#     # Create initial decoder input (start with SOS tokens for each sentence)
#     decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
#     decoder_input = decoder_input.to(device)


#     # Set initial decoder hidden state to the encoder's final hidden state
#     decoder_hidden = encoder_hidden[:decoder.n_layers]

#     for t in range(max_target_len):
#         decoder_output, decoder_hidden = decoder(
#             decoder_input, decoder_hidden, encoder_outputs
#         )
#         # No teacher forcing: next input is decoder's own current output
#         # _, topi = decoder_output.topk(1)
#         # decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
#         # decoder_input = decoder_input.to(device)
        
#         # other code block
#         decoder_scores, decoder_input = torch.max(decoder_output, dim=1) # Obtain most likely word token and its softmax score
#         # print(decoder_scores.shape, decoder_input.shape)
#         decoder_input = torch.unsqueeze(decoder_input, 0) # Prepare current token to be next decoder input (add a dimension)
#         decoder_scores = torch.unsqueeze(decoder_scores, 0)
#         all_tokens = torch.cat((all_tokens, decoder_input), dim=0) # Record token and score
#         all_scores = torch.cat((all_scores, decoder_scores), dim=0)

#         # Calculate and accumulate loss
#         mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
#         loss += mask_loss
#         print_losses.append(mask_loss.item() * nTotal)
#         n_totals += nTotal
    
#     all_tokens = all_tokens.view(-1, max_target_len)
#     all_scores = all_scores.view(-1, max_target_len)
#     # print(all_tokens.shape, all_scores.shape)
    
#     # accuracy
#     num_non_padding, num_correct = correct_tokens(all_tokens, target_variable, PAD_token)
#     # acc = np.round(100 * (num_correct / num_non_padding), 2)
 
#     return sum(print_losses) / n_totals, all_tokens, all_scores, num_non_padding, num_correct


# from utils import MetricReporter
# mc = MetricReporter(verbose=False)
# mc.eval()
# t0 = time.time()
# with torch.no_grad():
#     for batch_idx, batch in enumerate(valid_data_loader):
#         v, q, a, mask, q_len, _ = batch

#         out_loss, out_all_tokens, out_all_scores, out_num_non_padding, out_num_correct = eval(v, q, a, mask, q_len,
#              MAX_A_LEN, MAX_Q_LEN + 1,
#              encoder, decoder,
#              batch_size)
#         mc.update_metrics(out_loss, out_num_non_padding, out_num_correct)
#         elapsed = format_time(time.time() - t0)

#         # break
#         mc.report_metrics()
#         print('| VAL SET | Step [{:04d}/{:04d}],  Loss: {:.4f}, Accuracy: {:.4f} | Elapsed: {:}'
#                               .format(batch_idx, int(n_val_batches), mc.list_valid_loss[-1], mc.list_valid_accuracy[-1], elapsed))