### Setting up a basic training loop - not using attention

In [1]:
# Changing directories: the rest of the notebook uses paths relative to the BERT folder.
import os
working_dir = os.getcwd()
if not ('BERT' in working_dir):
    # Only descend when the current path does not already mention BERT,
    # so re-running this cell does not try to enter BERT/BERT.
    os.chdir('BERT')
print("Current working dir is {}".format(os.getcwd()))

Current working dir is /juice/scr/scr110/scr/nlp/mtl_bert/unidirectional-NMT/BERT


In [2]:
import pyaml
import onmt
import math
import torch
from dataset import TextDataset, Collator
from encoder import Encoder 
from decoder import Decoder
from discriminator import Discriminator
from lib.huggingface.transformers import RobertaTokenizer, CamembertTokenizer
from torch.utils.data import DataLoader

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [3]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# The YAML config lives in the `config` folder next to the current working directory
# (i.e. under the parent of the cwd).
config_path = os.path.join(os.path.dirname(os.getcwd()), "config", "config.yml")
with open(config_path, "r") as fd:
    config = pyaml.yaml.load(fd, Loader=pyaml.yaml.Loader)

In [5]:
# Length limit taken from the config's "maxlen" entry; it is passed as `maxlen`
# to the collator and to both datasets.
sentence_token_length = config["maxlen"]

In [6]:
# Pretrained tokenizers: 'roberta-base' for the English side, 'camembert-base' for the French side.
tokenizer_en = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = CamembertTokenizer.from_pretrained('camembert-base')

In [7]:
# Batch collate function handed to both DataLoaders; built with the same length limit as the datasets.
collator = Collator(maxlen=sentence_token_length)

In [8]:
# Train and validation datasets are built from the same Europarl directory with the
# same tokenizers and length limit; only the `training` flag differs.
text_dataset_train = TextDataset(
    "data/europarl-v7/",
    tokenizer_en,
    tokenizer_fr,
    training=True,
    maxlen=sentence_token_length,
)
text_dataset_val = TextDataset(
    "data/europarl-v7/",
    tokenizer_en,
    tokenizer_fr,
    training=False,
    maxlen=sentence_token_length,
)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
1704 examples with length < 2 removed.
42487 examples with length > 50 removed.
0
36 examples with length < 2 removed.
1037 examples with length > 50 removed.


In [9]:
# Both loaders take their keyword settings from the config's "data_loader" section
# and share the same collate function.
train_dataloader = DataLoader(
    text_dataset_train,
    **config["data_loader"],
    collate_fn=collator,
)
val_dataloader = DataLoader(
    text_dataset_val,
    **config["data_loader"],
    collate_fn=collator,
)

#### Specifying the encoding and decoding models

In [10]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = 'cuda' if cuda_available else 'cpu'
print("Using CUDA!" if cuda_available else "Using CPU - Played yourself!")

Using CUDA!


In [11]:
# Drop encoders left over from a previous run of this cell before building new ones
# (presumably so the old copies can be freed before the new ones are allocated -- TODO confirm).
try:
    del encoder_en
    del encoder_fr
except NameError:
    # First execution of the cell: the names do not exist yet, nothing to delete.
    # Only NameError is ignored so that real failures (and KeyboardInterrupt) still surface.
    pass
encoder_en = Encoder("english").to(device=device)
encoder_fr = Encoder("french").to(device=device)

In [12]:
# Embedding sub-modules of the two encoders; every value below is read off these.
pretrained_embeddings_en = encoder_en._modules['model'].embeddings
pretrained_embeddings_fr = encoder_fr._modules['model'].embeddings

# Padding index (author's note: same for both languages)
word_padding_idx_en = pretrained_embeddings_en.padding_idx
word_padding_idx_fr = pretrained_embeddings_fr.padding_idx

# Vocabulary size (author's note: English > French)
word_vocab_size_en = pretrained_embeddings_en.word_embeddings.num_embeddings
word_vocab_size_fr = pretrained_embeddings_fr.word_embeddings.num_embeddings

# Embedding dimension (author's note: same for both languages)
word_vec_size_en = pretrained_embeddings_en.word_embeddings.embedding_dim
word_vec_size_fr = pretrained_embeddings_fr.word_embeddings.embedding_dim

In [13]:
def _make_positional_embeddings(vec_size, vocab_size, padding_idx):
    '''Build an OpenNMT `Embeddings` layer with position encoding enabled and move it to `device`.'''
    layer = onmt.modules.embeddings.Embeddings(
        vec_size,
        vocab_size,
        padding_idx,
        position_encoding=True
    )
    return layer.to(device=device)


# One embedding layer per language, sized from the corresponding encoder's metadata.
embeddings_en = _make_positional_embeddings(word_vec_size_en, word_vocab_size_en, word_padding_idx_en)
embeddings_fr = _make_positional_embeddings(word_vec_size_fr, word_vocab_size_fr, word_padding_idx_fr)

In [14]:
# One decoder per target language. Both take their hyper-parameters from the
# config's "small_transformer" section and differ only in the embedding layer they receive.
decoder_en = Decoder(**config["small_transformer"], embeddings=embeddings_en).to(device=device)
decoder_fr = Decoder(**config["small_transformer"], embeddings=embeddings_fr).to(device=device)

Defining loss functions

In [15]:
def loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict, padding_idx=1):
    '''Standard machine translation cross entropy loss.

    Args:
        english_gt: target English token ids, indexed as (batch, position).
        french_gt: target French token ids, indexed as (batch, position).
        english_predict: scores for the English targets, indexed as
            (batch, position, vocabulary); produced by the French -> English direction.
        french_predict: scores for the French targets, indexed as
            (batch, position, vocabulary); produced by the English -> French direction.
        padding_idx: target id that is excluded from the loss (defaults to 1,
            the value that was previously hard-coded).

    Returns:
        Tuple ``(total, loss_english_to_french, loss_french_to_english)`` where
        ``total`` is the sum of the two directional losses.
    '''
    ce_loss = torch.nn.CrossEntropyLoss(ignore_index=padding_idx)  # ignoring padding tokens

    # CrossEntropyLoss expects the class dimension second, hence the transpose.
    # Predicting French tokens is the English -> French direction, and predicting
    # English tokens is the French -> English direction; the two used to be swapped,
    # which mislabelled the per-direction values returned to callers.
    loss_english_to_french = ce_loss(french_predict.transpose(1, 2), french_gt)
    loss_french_to_english = ce_loss(english_predict.transpose(1, 2), english_gt)
    return (loss_english_to_french + loss_french_to_english, loss_english_to_french, loss_french_to_english)

In [16]:
def loss_fn_single_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt, discriminator_predict):
    '''Adversarial loss: the translation cross entropy plus one discriminator penalty.

    The penalty is a binary cross entropy (on logits) between `discriminator_predict`
    and `discriminator_gt`.

    Returns:
        Tuple ``(total, second, third, penalty)`` where ``second`` and ``third`` are the
        per-direction losses exactly as returned by `loss_fn_no_regularization`, and
        ``total`` is the translation total plus ``penalty``.
    '''
    translation_losses = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)
    translation_total, direction_loss_a, direction_loss_b = translation_losses

    adversarial_penalty = torch.nn.BCEWithLogitsLoss()(discriminator_predict, discriminator_gt)

    return (translation_total + adversarial_penalty, direction_loss_a, direction_loss_b, adversarial_penalty)



In [17]:
def loss_fn_multi_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt_1, discriminator_predict_1,
                                discriminator_gt_2, discriminator_predict_2,):
    '''Adversarial loss with two discriminators.

    Adds one binary cross entropy term (on logits) per discriminator to the
    total translation cross entropy.

    Returns:
        The scalar total: translation loss + first penalty + second penalty.
    '''
    # loss_fn_no_regularization returns (total, per-direction, per-direction); only the
    # total belongs in the sum. Adding the whole tuple to a tensor, as was done before,
    # fails at runtime.
    ce_term, _, _ = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)

    bce_loss = torch.nn.BCEWithLogitsLoss()
    regularizing_term_1 = bce_loss(discriminator_predict_1, discriminator_gt_1)
    regularizing_term_2 = bce_loss(discriminator_predict_2, discriminator_gt_2)

    return ce_term + regularizing_term_1 + regularizing_term_2

Importing and defining evaluation functions

In [18]:
def exact_match(prediction, gt, padding_idx=1):
    '''Fraction of non-padding target tokens that the prediction reproduces exactly.

    Despite the name this is a token-level accuracy, not a sentence-level exact match.

    Args:
        prediction: predicted token ids, same shape as `gt`.
        gt: ground-truth token ids; positions equal to `padding_idx` are ignored.
        padding_idx: id treated as padding (defaults to 1, the value that was
            previously hard-coded).

    Returns:
        A float in [0, 1]; 0.0 when `gt` contains no non-padding tokens.
    '''
    mask = gt != padding_idx
    total = torch.sum(mask).item()
    if total == 0:
        # Nothing to score; avoid dividing by zero on an all-padding target.
        return 0.0
    correct = torch.sum((prediction == gt) & mask).item()
    return correct / total

In [19]:
from nltk.translate.bleu_score import sentence_bleu

Defining optimizer and hooks if required

In [20]:
def get_optimizer(regularize="None"):
    '''Build an Adam optimizer (default settings) over the models that are being trained.

    The two encoders and two decoders are always included. `regularize="hidden_state"`
    additionally includes `discriminator`, and `regularize="attention"` additionally
    includes `attn_discriminator`; any other value adds nothing.
    '''
    trained_modules = [encoder_en, encoder_fr, decoder_fr, decoder_en]

    if regularize == "hidden_state":
        trained_modules.append(discriminator)

    if regularize == "attention":
        trained_modules.append(attn_discriminator)

    params = [param for module in trained_modules for param in module.parameters()]
    return torch.optim.Adam(params)

In [21]:
# Module-level optimizer used by train(); built here for the run without a discriminator.
optimizer = get_optimizer(regularize="None")

Defining primary training loop

In [22]:
from torch.utils.tensorboard import SummaryWriter
from utils import write_to_tensorboard

In [23]:
def _set_requires_grad(module, flag):
    '''Enable or disable gradient tracking for every parameter of `module`.'''
    for param in module.parameters():
        param.requires_grad = flag


def _batch_to_device(batch):
    '''Split a collated batch into its English and French parts and move every tensor to `device`.'''
    english_batch, french_batch = batch
    english = tuple(tensor.to(device=device) for tensor in english_batch)
    french = tuple(tensor.to(device=device) for tensor in french_batch)
    return english, french


def _forward_both_directions(english, french):
    '''Run English -> French and French -> English through the encoder/decoder pairs.

    Each argument is the (sentences, sentences_no_eos, sentences_lengths) triple of one language.
    Returns (encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr).
    '''
    (english_sentences, english_sentences_no_eos, english_sentences_lengths) = english
    (french_sentences, french_sentences_no_eos, french_sentences_lengths) = french

    # Encoding - Decoding for English -> French
    encoder_outputs_en = encoder_en(english_sentences)
    decoder_fr.init_state(english_sentences.unsqueeze(2).transpose(0,1), None, None)
    dec_outs_fr, _ = decoder_fr(french_sentences_no_eos.unsqueeze(2).transpose(0,1), encoder_outputs_en[0].transpose(0,1), memory_lengths=english_sentences_lengths)

    # Encoding - Decoding for French -> English
    encoder_outputs_fr = encoder_fr(french_sentences)
    decoder_en.init_state(french_sentences.unsqueeze(2).transpose(0,1), None, None)
    dec_outs_en, _ = decoder_en(english_sentences_no_eos.unsqueeze(2).transpose(0,1), encoder_outputs_fr[0].transpose(0,1), memory_lengths=french_sentences_lengths)

    return encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr


def _discriminator_labels(batch_size, batch_num):
    '''Labels for the concatenated (English, French) discriminator outputs.

    Even batches label English 1 / French 0; odd batches use the flipped labels.
    '''
    if (batch_num % 2 == 0):
        labels = torch.tensor([1.0]*batch_size + [0.0]*batch_size)
    else:
        labels = torch.tensor([0.0]*batch_size + [1.0]*batch_size)
    return labels.unsqueeze(1).to(device=device)


def _sentence_bleu_scores(tokenizer, gt_ids, pred_ids, padding_idx=1):
    '''Word-level sentence BLEU for every row of a batch.

    Positions where the ground truth is padding are dropped from both sequences
    (same convention as `exact_match`). Ids are detokenized with `tokenizer.decode`
    and split on whitespace, because `sentence_bleu` expects a list of reference
    token lists and a hypothesis token list rather than raw strings.
    '''
    scores = []
    for gt_row, pred_row in zip(gt_ids, pred_ids):
        keep = gt_row != padding_idx
        reference = tokenizer.decode(gt_row[keep].tolist()).split()
        hypothesis = tokenizer.decode(pred_row[keep].tolist()).split()
        scores.append(sentence_bleu([reference], hypothesis))
    return scores


def _run_validation(val_data_iter, max_batches=25):
    '''Average BLEU and token accuracy over the first `max_batches` validation batches.

    Returns a dict with keys "bleu_en_fr", "bleu_fr_en", "em_en_fr", "em_fr_en",
    or None when the iterator yielded nothing to score.
    '''
    # NOTE(review): the models are not switched to eval mode here, so any dropout
    # that is active during training is also active during validation -- confirm intent.
    bleu_scores_en_fr, bleu_scores_fr_en = [], []
    exact_matches_en_fr, exact_matches_fr_en = [], []

    with torch.no_grad():
        # Dedicated loop variable names: the training loop's `batch_num` must not be overwritten.
        for val_batch_num, val_batch in enumerate(val_data_iter):
            if (val_batch_num == max_batches):
                break

            english, french = _batch_to_device(val_batch)
            _, _, dec_outs_en, dec_outs_fr = _forward_both_directions(english, french)

            # Calculate BLEU scores and EM
            predictions_fr = torch.argmax(dec_outs_fr, dim=2)
            predictions_en = torch.argmax(dec_outs_en, dim=2)
            french_gt = french[0][:, 1:]
            english_gt = english[0][:, 1:]

            bleu_scores_en_fr.extend(_sentence_bleu_scores(tokenizer_fr, french_gt, predictions_fr))
            bleu_scores_fr_en.extend(_sentence_bleu_scores(tokenizer_en, english_gt, predictions_en))

            exact_matches_en_fr.append(exact_match(predictions_fr, french_gt))
            exact_matches_fr_en.append(exact_match(predictions_en, english_gt))

    if not (bleu_scores_en_fr and bleu_scores_fr_en and exact_matches_en_fr and exact_matches_fr_en):
        return None

    return {
        "bleu_en_fr": sum(bleu_scores_en_fr)/len(bleu_scores_en_fr),
        "bleu_fr_en": sum(bleu_scores_fr_en)/len(bleu_scores_fr_en),
        "em_en_fr": sum(exact_matches_en_fr)/len(exact_matches_en_fr),
        "em_fr_en": sum(exact_matches_fr_en)/len(exact_matches_fr_en),
    }


def train(train_data_iter, val_data_iter, regularize="None"):
    '''
    Train the encoding and decoding models. User needs to pass in a valid iterator over the data,
    and also specify a type of adversarial regularization. regularize = ["hidden_state", "attention"]

    Args:
        train_data_iter: iterable of training batches; each batch is an (english, french)
            pair of (sentences, sentences_no_eos, sentences_lengths) triples.
        val_data_iter: iterable of validation batches in the same format; the first 25
            batches are scored every 100 training batches.
        regularize: "None" for plain translation training, "hidden_state" to regularize
            with `discriminator` on the encoders' second output, or "attention" to
            regularize with `attn_discriminator` on extracted attention scores.

    Uses the module-level `optimizer`, so build it with a `get_optimizer` argument that
    matches `regularize`. Training CCE is written to TensorBoard every batch, validation
    BLEU/EM every 100 batches, all indexed by the training batch number.
    '''
    writer = SummaryWriter("runs/without_regularization")

    try:
        for batch_num, batch in enumerate(train_data_iter):
            optimizer.zero_grad()

            # TODO: generalize this to multiple discriminators
            if (regularize != "None"):
                # Toggle the discriminator that is actually used by this mode.
                active_discriminator = attn_discriminator if regularize == "attention" else discriminator
                # Even batches update the discriminator only, odd batches the encoders only.
                discriminator_turn = (batch_num % 2 == 0)
                _set_requires_grad(encoder_en, not discriminator_turn)
                _set_requires_grad(encoder_fr, not discriminator_turn)
                _set_requires_grad(active_discriminator, discriminator_turn)

            # Reading in input and moving to device
            english, french = _batch_to_device(batch)
            english_sentences = english[0]
            french_sentences = french[0]

            encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr = _forward_both_directions(english, french)

            if (regularize == "attention" or regularize == "hidden_state"):
                if (regularize == "attention"):
                    # extracting the attention scores from the datasets; using 7th attention head
                    # as suggested by Clark et al, 2019
                    english_attention = extract_attention_scores(_hooks_english)[6]
                    french_attention = extract_attention_scores(_hooks_french)[6]
                    batch_size = english_attention.shape[0]

                    discriminator_outputs_en = attn_discriminator(english_attention.view(batch_size, -1))
                    discriminator_outputs_fr = attn_discriminator(french_attention.view(batch_size, -1))
                else:
                    # using the pooled outputs of the encoders for regularizing
                    discriminator_outputs_en = discriminator(encoder_outputs_en[1])
                    discriminator_outputs_fr = discriminator(encoder_outputs_fr[1])
                    batch_size = discriminator_outputs_en.shape[0]

                discriminator_outputs_cat = torch.cat((discriminator_outputs_en, discriminator_outputs_fr))
                discriminator_labels = _discriminator_labels(batch_size, batch_num)

                all_losses = loss_fn_single_regularization(english_sentences[:, 1:],
                                                   french_sentences[:, 1:],
                                                   dec_outs_en,
                                                   dec_outs_fr,
                                                   discriminator_labels,
                                                   discriminator_outputs_cat,
                                                  )
                (loss, loss_english_to_french, loss_french_to_english, regularizing_term) = all_losses
            else:
                all_losses = loss_fn_no_regularization(english_sentences[:, 1:],
                                               french_sentences[:, 1:],
                                               dec_outs_en,
                                               dec_outs_fr,
                                              )
                (loss, loss_english_to_french, loss_french_to_english) = all_losses

            if (batch_num % 10 == 0):
                print("Batch num {}: Loss {}".format(batch_num, loss.item()))
            loss.backward()
            optimizer.step()

            write_to_tensorboard("CCE", {"en-fr": loss_english_to_french.item(), "fr-en":loss_french_to_english.item()}, training=True, step=batch_num, writer=writer)

            # Running validation script
            if (batch_num > 0 and batch_num % 100 == 0):
                metrics = _run_validation(val_data_iter)
                if metrics is not None:
                    # `step` is the training batch number, so successive validations land
                    # on distinct points of the TensorBoard curve.
                    write_to_tensorboard("BLEU", {"en-fr": metrics["bleu_en_fr"], "fr-en":metrics["bleu_fr_en"]}, training=False, step=batch_num, writer=writer)
                    write_to_tensorboard("EM", {"en-fr": metrics["em_en_fr"], "fr-en":metrics["em_fr_en"]}, training=False, step=batch_num, writer=writer)
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()
            

In [None]:
# Baseline run: translation losses only, no discriminator (regularize="None").
train(train_dataloader, val_dataloader, regularize="None")

Batch num 0: Loss 1538.1806640625
Batch num 10: Loss 290.81146240234375
Batch num 20: Loss 165.44488525390625
Batch num 30: Loss 149.76431274414062
Batch num 40: Loss 133.50181579589844
Batch num 50: Loss 95.65170288085938
Batch num 60: Loss 83.58039855957031
Batch num 70: Loss 76.00181579589844
Batch num 80: Loss 66.68729400634766
Batch num 90: Loss 68.71720123291016
Batch num 100: Loss 58.46533203125


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


Batch num 110: Loss 61.50366973876953
Batch num 120: Loss 42.91539001464844
Batch num 130: Loss 45.67382049560547
Batch num 140: Loss 43.91678237915039
Batch num 150: Loss 48.81451416015625
Batch num 160: Loss 48.088043212890625
Batch num 170: Loss 43.744911193847656
Batch num 180: Loss 42.92486572265625
Batch num 190: Loss 33.52617645263672
Batch num 200: Loss 36.62435531616211
Batch num 210: Loss 32.46125411987305
Batch num 220: Loss 39.31957244873047
Batch num 230: Loss 28.299510955810547
Batch num 240: Loss 30.110687255859375
Batch num 250: Loss 38.564945220947266
Batch num 260: Loss 39.400508880615234
Batch num 270: Loss 32.4063720703125
Batch num 280: Loss 30.811477661132812
Batch num 290: Loss 27.167308807373047
Batch num 300: Loss 33.816566467285156
Batch num 310: Loss 26.51189422607422
Batch num 320: Loss 27.344032287597656
Batch num 330: Loss 31.231231689453125
Batch num 340: Loss 22.12495994567871
Batch num 350: Loss 26.32811737060547
Batch num 360: Loss 24.62741470336914
Ba

Batch num 2210: Loss 20.53413200378418
Batch num 2220: Loss 15.628986358642578
Batch num 2230: Loss 16.13553810119629
Batch num 2240: Loss 18.90988540649414
Batch num 2250: Loss 16.247264862060547
Batch num 2260: Loss 14.947830200195312
Batch num 2270: Loss 16.749595642089844
Batch num 2280: Loss 15.873306274414062
Batch num 2290: Loss 19.189470291137695
Batch num 2300: Loss 16.1698055267334
Batch num 2310: Loss 15.115942001342773
Batch num 2320: Loss 15.474618911743164
Batch num 2330: Loss 17.020959854125977
Batch num 2340: Loss 16.400415420532227
Batch num 2350: Loss 14.388948440551758
Batch num 2360: Loss 14.64815902709961
Batch num 2370: Loss 16.385149002075195
Batch num 2380: Loss 15.149892807006836
Batch num 2390: Loss 18.040380477905273
Batch num 2400: Loss 15.09393310546875
Batch num 2410: Loss 19.349336624145508
Batch num 2420: Loss 17.89707374572754
Batch num 2430: Loss 14.088546752929688
Batch num 2440: Loss 16.871841430664062
Batch num 2450: Loss 14.468758583068848
Batch nu

Batch num 4280: Loss 14.837448120117188
Batch num 4290: Loss 14.07795238494873
Batch num 4300: Loss 13.6301908493042
Batch num 4310: Loss 11.659866333007812
Batch num 4320: Loss 12.39797306060791
Batch num 4330: Loss 12.003469467163086
Batch num 4340: Loss 13.75466537475586
Batch num 4350: Loss 12.183684349060059
Batch num 4360: Loss 12.728195190429688
Batch num 4370: Loss 12.197675704956055
Batch num 4380: Loss 13.98775863647461
Batch num 4390: Loss 12.355978012084961
Batch num 4400: Loss 12.436922073364258
Batch num 4410: Loss 12.443350791931152
Batch num 4420: Loss 12.746835708618164
Batch num 4430: Loss 11.20602798461914
Batch num 4440: Loss 13.297197341918945
Batch num 4450: Loss 13.128600120544434
Batch num 4460: Loss 12.414178848266602
Batch num 4470: Loss 14.248371124267578
Batch num 4480: Loss 13.470954895019531
Batch num 4490: Loss 15.86652946472168
Batch num 4500: Loss 15.757064819335938
Batch num 4510: Loss 11.87980842590332
Batch num 4520: Loss 14.84808349609375
Batch num 

Batch num 6340: Loss 12.538497924804688
Batch num 6350: Loss 11.171161651611328
Batch num 6360: Loss 11.874031066894531
Batch num 6370: Loss 10.31435775756836
Batch num 6380: Loss 12.271791458129883
Batch num 6390: Loss 12.640090942382812
Batch num 6400: Loss 11.508548736572266
Batch num 6410: Loss 12.406991958618164
Batch num 6420: Loss 12.513989448547363
Batch num 6430: Loss 11.909392356872559
Batch num 6440: Loss 11.9193754196167
Batch num 6450: Loss 12.210372924804688
Batch num 6460: Loss 12.908041000366211
Batch num 6470: Loss 14.682632446289062
Batch num 6480: Loss 12.518199920654297
Batch num 6490: Loss 11.82234001159668
Batch num 6500: Loss 11.53640365600586
Batch num 6510: Loss 13.381515502929688
Batch num 6520: Loss 12.781628608703613
Batch num 6530: Loss 12.057126998901367
Batch num 6540: Loss 12.989946365356445
Batch num 6550: Loss 14.086227416992188
Batch num 6560: Loss 10.890249252319336
Batch num 6570: Loss 11.354394912719727
Batch num 6580: Loss 12.269386291503906
Batch

Batch num 8400: Loss 12.0958251953125
Batch num 8410: Loss 9.964010238647461
Batch num 8420: Loss 12.56396484375
Batch num 8430: Loss 12.472752571105957
Batch num 8440: Loss 11.968271255493164
Batch num 8450: Loss 12.842056274414062
Batch num 8460: Loss 14.150684356689453
Batch num 8470: Loss 12.351936340332031
Batch num 8480: Loss 13.734625816345215
Batch num 8490: Loss 12.33475399017334
Batch num 8500: Loss 10.920330047607422
Batch num 8510: Loss 10.494182586669922
Batch num 8520: Loss 12.374107360839844
Batch num 8530: Loss 13.06254768371582
Batch num 8540: Loss 11.628233909606934
Batch num 8550: Loss 9.647408485412598
Batch num 8560: Loss 12.325315475463867
Batch num 8570: Loss 11.383443832397461
Batch num 8580: Loss 12.634618759155273
Batch num 8590: Loss 11.958566665649414
Batch num 8600: Loss 12.466405868530273
Batch num 8610: Loss 11.249820709228516
Batch num 8620: Loss 11.21460247039795
Batch num 8630: Loss 12.28781509399414
Batch num 8640: Loss 10.765907287597656
Batch num 86

Batch num 10460: Loss 12.073395729064941
Batch num 10470: Loss 12.153554916381836
Batch num 10480: Loss 11.529979705810547
Batch num 10490: Loss 12.735363006591797
Batch num 10500: Loss 9.886653900146484
Batch num 10510: Loss 10.285593032836914
Batch num 10520: Loss 11.181818962097168
Batch num 10530: Loss 12.831205368041992
Batch num 10540: Loss 10.808584213256836
Batch num 10550: Loss 10.475578308105469
Batch num 10560: Loss 11.30087661743164
Batch num 10570: Loss 10.983662605285645
Batch num 10580: Loss 11.271774291992188
Batch num 10590: Loss 11.35125732421875
Batch num 10600: Loss 12.12234115600586
Batch num 10610: Loss 10.910028457641602
Batch num 10620: Loss 10.427988052368164
Batch num 10630: Loss 11.918657302856445
Batch num 10640: Loss 11.722230911254883
Batch num 10650: Loss 10.75849437713623
Batch num 10660: Loss 10.7138671875
Batch num 10670: Loss 12.755048751831055
Batch num 10680: Loss 9.921062469482422
Batch num 10690: Loss 10.595979690551758
Batch num 10700: Loss 13.75

Batch num 12480: Loss 12.044109344482422
Batch num 12490: Loss 9.679555892944336
Batch num 12500: Loss 10.499136924743652
Batch num 12510: Loss 10.075279235839844
Batch num 12520: Loss 9.989789962768555
Batch num 12530: Loss 11.120414733886719
Batch num 12540: Loss 13.10830020904541
Batch num 12550: Loss 10.332477569580078
Batch num 12560: Loss 9.501270294189453
Batch num 12570: Loss 9.79444408416748
Batch num 12580: Loss 10.072305679321289
Batch num 12590: Loss 13.048086166381836
Batch num 12600: Loss 13.01939582824707
Batch num 12610: Loss 11.56692886352539
Batch num 12620: Loss 11.57791805267334
Batch num 12630: Loss 11.758646011352539
Batch num 12640: Loss 11.011640548706055
Batch num 12650: Loss 11.860480308532715
Batch num 12660: Loss 10.122438430786133
Batch num 12670: Loss 11.436981201171875
Batch num 12680: Loss 9.903020858764648
Batch num 12690: Loss 9.1346435546875
Batch num 12700: Loss 10.555830001831055
Batch num 12710: Loss 11.594884872436523
Batch num 12720: Loss 11.1636

Batch num 14500: Loss 10.531059265136719
Batch num 14510: Loss 12.52972412109375
Batch num 14520: Loss 10.58328914642334
Batch num 14530: Loss 12.168708801269531
Batch num 14540: Loss 10.61160659790039
Batch num 14550: Loss 9.130583763122559
Batch num 14560: Loss 11.298576354980469
Batch num 14570: Loss 10.495771408081055
Batch num 14580: Loss 12.02033805847168
Batch num 14590: Loss 9.761374473571777
Batch num 14600: Loss 10.262524604797363
Batch num 14610: Loss 12.392995834350586
Batch num 14620: Loss 12.135869979858398
Batch num 14630: Loss 12.516550064086914
Batch num 14640: Loss 11.138627052307129
Batch num 14650: Loss 10.737394332885742
Batch num 14660: Loss 11.963736534118652
Batch num 14670: Loss 10.230920791625977
Batch num 14680: Loss 12.182991027832031
Batch num 14690: Loss 9.43471908569336
Batch num 14700: Loss 10.30388069152832
Batch num 14710: Loss 9.68006706237793
Batch num 14720: Loss 10.843214988708496
Batch num 14730: Loss 10.35739803314209
Batch num 14740: Loss 12.085

To do for the final paper:
* Add TensorBoard logging
* Print out accuracy for the encoder
* Create a perplexity evaluation metric
* Run an example with the discriminator over both the attention scores and the hidden states
* Factor the code into utility modules and classes