### Setting up a basic training loop - not using attention

In [1]:
# Changing directories: the rest of the notebook uses paths relative to the
# BERT/ directory, so step into it unless the current path already mentions it.
import os

if os.getcwd().find('BERT') == -1:
    os.chdir('BERT')
print("Current working dir is {}".format(os.getcwd()))

Current working dir is /juice/scr/scr110/scr/nlp/mtl_bert/unidirectional-NMT/BERT


In [2]:
import pyaml
import onmt
import math
import torch
from dataset import TextDataset, Collator
from encoder import Encoder 
from decoder import Decoder
from discriminator import Discriminator
from lib.huggingface.transformers import RobertaTokenizer, CamembertTokenizer
from torch.utils.data import DataLoader

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# The YAML config sits in config/config.yml one level above the working directory.
# NOTE: the full (non-safe) YAML Loader can construct arbitrary Python objects;
# that is only acceptable because this is a trusted, local project file.
config_path = os.path.join(os.path.dirname(os.getcwd()), "config", "config.yml")
with open(config_path, "r") as fd:
    config = pyaml.yaml.load(fd, Loader=pyaml.yaml.Loader)

In [5]:
# Maximum sentence length (in tokens) read from the config; passed as `maxlen`
# to both the datasets and the collator below.
sentence_token_length = config["maxlen"]

In [6]:
# Pretrained tokenizers: the 'roberta-base' vocabulary for English and the
# 'camembert-base' vocabulary for French.
tokenizer_en = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = CamembertTokenizer.from_pretrained('camembert-base')

In [7]:
# Collate function handed to both DataLoaders; configured with the same maxlen as the datasets.
collator = Collator(maxlen=sentence_token_length)

In [8]:
# Train and validation datasets are built from the same directory with the same
# tokenizers and maxlen; only the `training` flag differs.
# NOTE(review): presumably `training` selects which split is read - confirm in dataset.TextDataset.
text_dataset_train = TextDataset("data/europarl-v7/", tokenizer_en, tokenizer_fr, training=True, maxlen=sentence_token_length)
text_dataset_val = TextDataset("data/europarl-v7/",  tokenizer_en, tokenizer_fr, training=False, maxlen=sentence_token_length)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
1704 examples with length < 2 removed.
42487 examples with length > 50 removed.
0
36 examples with length < 2 removed.
1037 examples with length > 50 removed.


In [9]:
# Both loaders take their keyword arguments from config["data_loader"] and share the collator.
# NOTE(review): because the kwargs are shared, any shuffle/batch-size setting in the
# config also applies to validation - confirm that is intended.
train_dataloader = DataLoader(text_dataset_train, **config["data_loader"], collate_fn=collator)
val_dataloader = DataLoader(text_dataset_val, **config["data_loader"], collate_fn=collator)

#### Specifying the encoding and decoding models

In [10]:
# Pick the compute device once: CUDA when torch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
    print("Using CUDA!")
else:
    device = 'cpu'
    print("Using CPU - Played yourself!")

Using CUDA!


In [11]:
# When this cell is re-run, drop the previous encoders before building new ones.
# Each name is deleted independently and only NameError (name not defined yet, e.g.
# on a fresh kernel) is tolerated, so a missing `encoder_en` no longer prevents
# `encoder_fr` from being released and unrelated errors are not silently swallowed.
try:
    del encoder_en
except NameError:
    pass
try:
    del encoder_fr
except NameError:
    pass

# One encoder per language, moved to the selected device.
encoder_en = Encoder("english").to(device=device)
encoder_fr = Encoder("french").to(device=device)

In [12]:
# same
word_padding_idx_en = encoder_en._modules['model'].embeddings.padding_idx
word_padding_idx_fr = encoder_fr._modules['model'].embeddings.padding_idx

# en > fr
word_vocab_size_en = encoder_en._modules['model'].embeddings.word_embeddings.num_embeddings
word_vocab_size_fr = encoder_fr._modules['model'].embeddings.word_embeddings.num_embeddings

# same
word_vec_size_en = encoder_en._modules['model'].embeddings.word_embeddings.embedding_dim
word_vec_size_fr = encoder_fr._modules['model'].embeddings.word_embeddings.embedding_dim

In [13]:
# OpenNMT embedding layers (with positional encoding) for the two decoders, sized
# from the corresponding encoder's vocabulary. English is built first, then French.
embeddings_en, embeddings_fr = (
    onmt.modules.embeddings.Embeddings(
        vec_size, vocab_size, padding_idx, position_encoding=True
    ).to(device=device)
    for vec_size, vocab_size, padding_idx in (
        (word_vec_size_en, word_vocab_size_en, word_padding_idx_en),
        (word_vec_size_fr, word_vocab_size_fr, word_padding_idx_fr),
    )
)

In [14]:
# One decoder per target language; both use the "small_transformer" settings from
# the config and differ only in the embedding layer they receive.
decoder_en, decoder_fr = (
    Decoder(**config["small_transformer"], embeddings=target_embeddings).to(device=device)
    for target_embeddings in (embeddings_en, embeddings_fr)
)

In [15]:
# projection: 768 -> 1
discriminator = Discriminator(config["small_transformer"]['d_model'], 1).to(device=device)

Defining loss functions

In [16]:
def loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict, padding_idx=1):
    '''Standard machine translation cross entropy loss.

    Args:
        english_gt, french_gt: integer target token ids; positions equal to
            `padding_idx` are ignored by the loss.
        english_predict, french_predict: decoder scores with the vocabulary on the
            last axis; axes 1 and 2 are swapped below so the class axis comes second,
            as torch.nn.CrossEntropyLoss requires.
        padding_idx: token id treated as padding (defaults to 1, the value that was
            previously hard-coded).

    Returns:
        (total, loss_english_to_french, loss_french_to_english). The en->fr term is
        the loss on the *French* predictions and the fr->en term the loss on the
        *English* predictions; `total` is their sum.
    '''
    ce_loss = torch.nn.CrossEntropyLoss(ignore_index=padding_idx)  # ignoring padding tokens

    # English -> French produces French tokens, so it is scored against the French
    # targets (the two terms used to be labelled the wrong way round).
    loss_english_to_french = ce_loss(french_predict.transpose(1, 2), french_gt)
    loss_french_to_english = ce_loss(english_predict.transpose(1, 2), english_gt)

    return (loss_english_to_french + loss_french_to_english, loss_english_to_french, loss_french_to_english)

In [17]:
def loss_fn_single_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt, discriminator_predict):
    '''Adversarial loss: translation cross entropy plus one discriminator term.

    The translation part is delegated to `loss_fn_no_regularization`; the
    regularizer is a binary cross entropy (with logits) between the discriminator
    outputs and their target labels.

    Returns:
        (total, loss_english_to_french, loss_french_to_english, regularizing_term),
        where the two per-direction terms are passed through unchanged from
        `loss_fn_no_regularization` and `total` adds the regularizer to its total.
    '''
    translation_losses = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)
    ce_term, loss_english_to_french, loss_french_to_english = translation_losses

    regularizing_term = torch.nn.functional.binary_cross_entropy_with_logits(
        discriminator_predict, discriminator_gt
    )
    total = ce_term + regularizing_term

    return (total, loss_english_to_french, loss_french_to_english, regularizing_term)


In [18]:
def loss_fn_multi_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt_1, discriminator_predict_1,
                                discriminator_gt_2, discriminator_predict_2,):
    '''Adversarial loss: translation cross entropy plus two discriminator terms.

    Args:
        english_gt, french_gt, english_predict, french_predict: forwarded unchanged
            to `loss_fn_no_regularization`.
        discriminator_gt_1, discriminator_predict_1: targets and raw outputs of the
            first discriminator.
        discriminator_gt_2, discriminator_predict_2: targets and raw outputs of the
            second discriminator.

    Returns:
        A single scalar tensor: total translation loss + both binary-cross-entropy
        (with logits) regularizers.
    '''
    # loss_fn_no_regularization returns (total, per-direction, per-direction); only
    # the total is needed here. Adding the whole tuple to a tensor, as this function
    # previously did, cannot work.
    ce_term, _, _ = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)

    bce_loss = torch.nn.BCEWithLogitsLoss()
    regularizing_term_1 = bce_loss(discriminator_predict_1, discriminator_gt_1)
    regularizing_term_2 = bce_loss(discriminator_predict_2, discriminator_gt_2)

    return ce_term + regularizing_term_1 + regularizing_term_2

Importing and defining evaluation functions

In [19]:
def exact_match(prediction, gt, padding_idx=1):
    '''Fraction of non-padding target positions where the prediction equals the target.

    Args:
        prediction: tensor of predicted token ids, comparable element-wise with `gt`.
        gt: tensor of ground-truth token ids; positions equal to `padding_idx` are
            excluded from both the numerator and the denominator.
        padding_idx: token id treated as padding (defaults to 1, the value that was
            previously hard-coded).

    Returns:
        A Python float in [0, 1]; 0.0 when `gt` contains no non-padding position
        (instead of raising ZeroDivisionError).
    '''
    mask = gt != padding_idx
    num_tokens = torch.sum(mask).item()
    if num_tokens == 0:
        return 0.0
    num_correct = torch.sum((prediction == gt) & mask).item()
    return num_correct / num_tokens

In [20]:
from nltk.translate.bleu_score import sentence_bleu

Defining optimizer and hooks if required

In [21]:
def get_optimizer(regularize="None"):
    '''Build an Adam optimizer (default hyper-parameters) over the notebook's models.

    The parameters of both encoders and both decoders are always included.
    regularize="hidden_state" additionally includes `discriminator`;
    regularize="attention" additionally includes `attn_discriminator`, which must
    already exist in the namespace (it is not created in the cells shown here).
    '''
    translation_modules = (encoder_en, encoder_fr, decoder_fr, decoder_en)
    params = [p for module in translation_modules for p in module.parameters()]

    if regularize == "hidden_state":
        params.extend(discriminator.parameters())

    if regularize == "attention":
        params.extend(attn_discriminator.parameters())

    return torch.optim.Adam(params)

In [22]:
# Module-level optimizer used by train(); "hidden_state" adds the discriminator's
# parameters, matching the regularize mode passed to train() further down.
optimizer = get_optimizer(regularize="hidden_state")

Defining primary training loop

In [23]:
from torch.utils.tensorboard import SummaryWriter
from utils import write_to_tensorboard

In [24]:
def _set_requires_grad(module, flag):
    '''Switch gradient tracking on or off for every parameter of `module`.'''
    for param in module.parameters():
        param.requires_grad = flag


def _unpack_batch(batch):
    '''Split a collated batch into (english, french) triples and move them to `device`.

    Each triple is unpacked downstream as (sentences, sentences_no_eos, sentences_lengths).
    '''
    english_batch, french_batch = batch
    english = tuple(tensor.to(device=device) for tensor in english_batch)
    french = tuple(tensor.to(device=device) for tensor in french_batch)
    return english, french


def _forward_both_directions(english, french):
    '''Run the English->French and French->English encoder/decoder passes.

    Returns (encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr).
    '''
    (english_sentences, english_sentences_no_eos, english_sentences_lengths) = english
    (french_sentences, french_sentences_no_eos, french_sentences_lengths) = french

    # Encoding - Decoding for English -> French
    encoder_outputs_en = encoder_en(english_sentences)
    decoder_fr.init_state(english_sentences.unsqueeze(2).transpose(0, 1), None, None)
    dec_outs_fr, _ = decoder_fr(
        french_sentences_no_eos.unsqueeze(2).transpose(0, 1),
        encoder_outputs_en[0].transpose(0, 1),
        memory_lengths=english_sentences_lengths,
    )

    # Encoding - Decoding for French -> English
    encoder_outputs_fr = encoder_fr(french_sentences)
    decoder_en.init_state(french_sentences.unsqueeze(2).transpose(0, 1), None, None)
    dec_outs_en, _ = decoder_en(
        english_sentences_no_eos.unsqueeze(2).transpose(0, 1),
        encoder_outputs_fr[0].transpose(0, 1),
        memory_lengths=french_sentences_lengths,
    )

    return encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr


def _adversarial_labels(batch_num, batch_size):
    '''Targets for the concatenated (English, French) discriminator outputs.

    Even batches label English 1 / French 0; odd batches use the flipped labels.
    '''
    if batch_num % 2 == 0:
        labels = torch.tensor([1.0] * batch_size + [0.0] * batch_size)
    else:
        labels = torch.tensor([0.0] * batch_size + [1.0] * batch_size)
    return labels.unsqueeze(1).to(device=device)


def _compute_losses(regularize, batch_num, english_sentences, french_sentences,
                    encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr):
    '''Return (loss, loss_english_to_french, loss_french_to_english, regularizing_term).

    `regularizing_term` is None when no adversarial regularization is requested.
    The per-direction terms are whatever the loss functions return under those names.
    '''
    if regularize == "attention":
        # extracting the attention scores; index 6 is the 7th attention head,
        # as suggested by Clark et al, 2019 (original author's note).
        # NOTE(review): extract_attention_scores, _hooks_english, _hooks_french and
        # attn_discriminator are not defined in the cells visible here.
        english_attention = extract_attention_scores(_hooks_english)[6]
        french_attention = extract_attention_scores(_hooks_french)[6]
        batch_size = english_attention.shape[0]
        discriminator_outputs_en = attn_discriminator(english_attention.view(batch_size, -1))
        discriminator_outputs_fr = attn_discriminator(french_attention.view(batch_size, -1))
    elif regularize == "hidden_state":
        # using the pooled outputs of the encoders for regularizing
        discriminator_outputs_en = discriminator(encoder_outputs_en[1])
        discriminator_outputs_fr = discriminator(encoder_outputs_fr[1])
        batch_size = discriminator_outputs_en.shape[0]
    else:
        # The plain loss function returns a 3-tuple; unpack it so `loss` is a tensor
        # (assigning the whole tuple to `loss` made this branch crash).
        loss, loss_english_to_french, loss_french_to_english = loss_fn_no_regularization(
            english_sentences[:, 1:],
            french_sentences[:, 1:],
            dec_outs_en,
            dec_outs_fr,
        )
        return loss, loss_english_to_french, loss_french_to_english, None

    discriminator_outputs_cat = torch.cat((discriminator_outputs_en, discriminator_outputs_fr))
    discriminator_labels = _adversarial_labels(batch_num, batch_size)

    return loss_fn_single_regularization(
        english_sentences[:, 1:],
        french_sentences[:, 1:],
        dec_outs_en,
        dec_outs_fr,
        discriminator_labels,
        discriminator_outputs_cat,
    )


def _mean(values):
    '''Arithmetic mean of a list, or 0.0 for an empty list.'''
    return sum(values) / len(values) if values else 0.0


def _run_validation(val_data_iter, max_val_batches):
    '''Average sentence BLEU and exact match over at most `max_val_batches` batches.

    Returns a dict with keys "bleu_en_fr", "bleu_fr_en", "em_en_fr", "em_fr_en".
    NOTE(review): the models are not switched to eval() mode here, so layers such as
    dropout (if present) stay active during validation - confirm whether that is intended.
    '''
    bleu_scores_en_fr, exact_matches_en_fr = [], []
    bleu_scores_fr_en, exact_matches_fr_en = [], []

    with torch.no_grad():
        # Dedicated loop variable: reusing `batch_num` here used to clobber the
        # training step that the caller logs the metrics against.
        for val_batch_num, val_batch in enumerate(val_data_iter):
            if val_batch_num == max_val_batches:
                break

            english, french = _unpack_batch(val_batch)
            _, _, dec_outs_en, dec_outs_fr = _forward_both_directions(english, french)
            english_sentences, french_sentences = english[0], french[0]

            # Calculate BLEU scores and exact match from the greedy predictions
            predictions_fr = torch.argmax(dec_outs_fr, dim=2)
            predictions_en = torch.argmax(dec_outs_en, dim=2)

            # sentence_bleu expects a list of tokenized references and a tokenized
            # hypothesis; passing raw strings scores individual characters instead.
            for idx in range(french_sentences.shape[0]):
                # NOTE(review): convert_tokens_to_string is fed token ids here while the
                # English path uses decode(); kept as-is, confirm against the tokenizer.
                detokenized_french_gt = tokenizer_fr.convert_tokens_to_string(french_sentences[idx, 1:].tolist())
                detokenized_french_pred = tokenizer_fr.convert_tokens_to_string(predictions_fr[idx].tolist())
                bleu_scores_en_fr.append(
                    sentence_bleu([detokenized_french_gt.split()], detokenized_french_pred.split())
                )

            for idx in range(english_sentences.shape[0]):
                detokenized_english_gt = tokenizer_en.decode(english_sentences[idx, 1:].tolist())
                detokenized_english_pred = tokenizer_en.decode(predictions_en[idx].tolist())
                bleu_scores_fr_en.append(
                    sentence_bleu([detokenized_english_gt.split()], detokenized_english_pred.split())
                )

            exact_matches_en_fr.append(exact_match(predictions_fr, french_sentences[:, 1:]))
            exact_matches_fr_en.append(exact_match(predictions_en, english_sentences[:, 1:]))

    return {
        "bleu_en_fr": _mean(bleu_scores_en_fr),
        "bleu_fr_en": _mean(bleu_scores_fr_en),
        "em_en_fr": _mean(exact_matches_en_fr),
        "em_fr_en": _mean(exact_matches_fr_en),
    }


def train(train_data_iter, val_data_iter, regularize="None",
          log_dir="runs/regularize_hidden", log_every=10, val_every=100, max_val_batches=25):
    '''
    Train the encoding and decoding models for one pass over `train_data_iter`.

    Args:
        train_data_iter: iterable of collated (english_batch, french_batch) pairs.
        val_data_iter: iterable of validation batches in the same format.
        regularize: "hidden_state" or "attention" to add an adversarial
            regularizer; any other value ("None" by default) trains with the plain
            translation loss.
        log_dir: TensorBoard run directory.
        log_every: print the training loss every this many batches.
        val_every: run validation every this many batches (never at batch 0).
        max_val_batches: maximum number of validation batches per validation run.

    Uses the module-level models, `optimizer` and `device`. Nothing is returned;
    metrics are printed and written to TensorBoard.
    '''
    writer = SummaryWriter(log_dir)

    try:
        for batch_num, batch in enumerate(train_data_iter):
            optimizer.zero_grad()

            # TODO: generalize this to multiple discriminators
            # Alternate which side of the adversarial game receives gradients:
            # even batches -> discriminator only, odd batches -> encoders only.
            # NOTE(review): `discriminator` is toggled even in "attention" mode, where
            # the loss uses `attn_discriminator` - confirm.
            if regularize != "None":
                discriminator_turn = (batch_num % 2 == 0)
                _set_requires_grad(encoder_en, not discriminator_turn)
                _set_requires_grad(encoder_fr, not discriminator_turn)
                _set_requires_grad(discriminator, discriminator_turn)

            # Reading in input and moving to device
            english, french = _unpack_batch(batch)
            encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr = \
                _forward_both_directions(english, french)

            (loss, loss_english_to_french, loss_french_to_english, regularizing_term) = _compute_losses(
                regularize, batch_num, english[0], french[0],
                encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr,
            )

            if batch_num % log_every == 0:
                print("Batch num {}: Loss {}".format(batch_num, loss.item()))
            loss.backward()
            optimizer.step()

            write_to_tensorboard("CCE", {"en-fr": loss_english_to_french.item(), "fr-en": loss_french_to_english.item()}, training=True, step=batch_num, writer=writer)
            if regularizing_term is not None:
                write_to_tensorboard("BCE", {"attention-regularizing": regularizing_term.item()}, training=True, step=batch_num, writer=writer)

            # Running validation script
            if batch_num > 0 and batch_num % val_every == 0:
                val_metrics = _run_validation(val_data_iter, max_val_batches)
                # Logged against the *training* step.
                write_to_tensorboard("BLEU", {"en-fr": val_metrics["bleu_en_fr"], "fr-en": val_metrics["bleu_fr_en"]}, training=False, step=batch_num, writer=writer)
                write_to_tensorboard("EM", {"en-fr": val_metrics["em_en_fr"], "fr-en": val_metrics["em_fr_en"]}, training=False, step=batch_num, writer=writer)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()
            

In [None]:
# Start training with the hidden-state adversarial regularizer, matching the
# regularize mode the module-level optimizer was built with.
train(train_dataloader, val_dataloader, regularize="hidden_state")

Batch num 0: Loss 1527.726318359375
Batch num 10: Loss 312.41156005859375
Batch num 20: Loss 168.8173370361328
Batch num 30: Loss 126.66352081298828
Batch num 40: Loss 133.12794494628906
Batch num 50: Loss 106.6199951171875
Batch num 60: Loss 94.66155242919922
Batch num 70: Loss 80.3686294555664
Batch num 80: Loss 61.28822708129883
Batch num 90: Loss 56.556060791015625
Batch num 100: Loss 52.40564727783203


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


Batch num 110: Loss 56.741024017333984
Batch num 120: Loss 42.57631301879883
Batch num 130: Loss 54.447750091552734
Batch num 140: Loss 55.74707794189453
Batch num 150: Loss 46.81306457519531
Batch num 160: Loss 52.101295471191406
Batch num 170: Loss 58.495052337646484
Batch num 180: Loss 41.351593017578125
Batch num 190: Loss 42.04888916015625
Batch num 200: Loss 39.13063430786133
Batch num 210: Loss 35.858245849609375
Batch num 220: Loss 38.01784133911133
Batch num 230: Loss 32.590301513671875
Batch num 240: Loss 33.58393859863281
Batch num 250: Loss 33.45176315307617
Batch num 260: Loss 35.55852127075195
Batch num 270: Loss 70.44762420654297
Batch num 280: Loss 60.3500862121582
Batch num 290: Loss 60.22654724121094
Batch num 300: Loss 58.03997802734375
Batch num 310: Loss 74.3868408203125
Batch num 320: Loss 48.261207580566406
Batch num 330: Loss 30.621891021728516
Batch num 340: Loss 30.636999130249023
Batch num 350: Loss 33.46067428588867
Batch num 360: Loss 32.2385139465332
Batch

Batch num 2210: Loss 18.396495819091797
Batch num 2220: Loss 43.610572814941406
Batch num 2230: Loss 20.209150314331055
Batch num 2240: Loss 19.795150756835938
Batch num 2250: Loss 13.492600440979004
Batch num 2260: Loss 15.825240135192871
Batch num 2270: Loss 15.930524826049805
Batch num 2280: Loss 18.3903865814209
Batch num 2290: Loss 16.120758056640625
Batch num 2300: Loss 15.938831329345703
Batch num 2310: Loss 15.440675735473633
Batch num 2320: Loss 17.215120315551758
Batch num 2330: Loss 13.7899169921875
Batch num 2340: Loss 14.8385648727417
Batch num 2350: Loss 17.871795654296875
Batch num 2360: Loss 16.957395553588867
Batch num 2370: Loss 16.784107208251953
Batch num 2380: Loss 14.868050575256348
Batch num 2390: Loss 12.899948120117188
Batch num 2400: Loss 14.81849193572998
Batch num 2410: Loss 17.079845428466797
Batch num 2420: Loss 17.25981903076172
Batch num 2430: Loss 14.801520347595215
Batch num 2440: Loss 15.136960983276367
Batch num 2450: Loss 17.773427963256836
Batch nu

Batch num 4280: Loss 12.61088752746582
Batch num 4290: Loss 13.184579849243164
Batch num 4300: Loss 12.108922958374023
Batch num 4310: Loss 14.227204322814941
Batch num 4320: Loss 11.626322746276855
Batch num 4330: Loss 13.085803985595703
Batch num 4340: Loss 19.80148696899414
Batch num 4350: Loss 14.107999801635742
Batch num 4360: Loss 11.632828712463379
Batch num 4370: Loss 14.383805274963379
Batch num 4380: Loss 11.610433578491211
Batch num 4390: Loss 12.432424545288086
Batch num 4400: Loss 8.887800216674805
Batch num 4410: Loss 10.532144546508789
Batch num 4420: Loss 13.726178169250488
Batch num 4430: Loss 11.312039375305176
Batch num 4440: Loss 17.447341918945312
Batch num 4450: Loss 12.232714653015137
Batch num 4460: Loss 12.683088302612305
Batch num 4470: Loss 12.823184967041016
Batch num 4480: Loss 13.608194351196289
Batch num 4490: Loss 13.919692039489746
Batch num 4500: Loss 12.998674392700195
Batch num 4510: Loss 14.62230110168457
Batch num 4520: Loss 12.694602966308594
Batc

Batch num 6340: Loss 13.133964538574219
Batch num 6350: Loss 12.204507827758789
Batch num 6360: Loss 12.166802406311035
Batch num 6370: Loss 11.86599063873291
Batch num 6380: Loss 10.782913208007812
Batch num 6390: Loss 13.821427345275879
Batch num 6400: Loss 12.586196899414062
Batch num 6410: Loss 12.225229263305664
Batch num 6420: Loss 14.119084358215332
Batch num 6430: Loss 11.910850524902344
Batch num 6440: Loss 11.559564590454102
Batch num 6450: Loss 11.296067237854004
Batch num 6460: Loss 13.032496452331543
Batch num 6470: Loss 11.89477825164795
Batch num 6480: Loss 11.508818626403809
Batch num 6490: Loss 11.307814598083496
Batch num 6500: Loss 11.252508163452148
Batch num 6510: Loss 12.159656524658203
Batch num 6520: Loss 12.10630989074707
Batch num 6530: Loss 11.729484558105469
Batch num 6540: Loss 13.03402328491211
Batch num 6550: Loss 12.81981086730957
Batch num 6560: Loss 12.594045639038086
Batch num 6570: Loss 10.844291687011719
Batch num 6580: Loss 11.735206604003906
Batch

Batch num 8400: Loss 12.320418357849121
Batch num 8410: Loss 12.38836669921875
Batch num 8420: Loss 10.007856369018555
Batch num 8430: Loss 13.679346084594727
Batch num 8440: Loss 11.993042945861816
Batch num 8450: Loss 13.118221282958984
Batch num 8460: Loss 11.46019172668457
Batch num 8470: Loss 12.805059432983398
Batch num 8480: Loss 12.113720893859863
Batch num 8490: Loss 11.27938175201416
Batch num 8500: Loss 8.68008041381836
Batch num 8510: Loss 11.838997840881348
Batch num 8520: Loss 12.409638404846191
Batch num 8530: Loss 12.982776641845703
Batch num 8540: Loss 12.772794723510742
Batch num 8550: Loss 12.830733299255371
Batch num 8560: Loss 13.170856475830078
Batch num 8570: Loss 13.268133163452148
Batch num 8580: Loss 11.373501777648926
Batch num 8590: Loss 11.413782119750977
Batch num 8600: Loss 11.44261360168457
Batch num 8610: Loss 14.212586402893066
Batch num 8620: Loss 10.674121856689453
Batch num 8630: Loss 11.238775253295898
Batch num 8640: Loss 12.628458023071289
Batch 

Batch num 10460: Loss 10.858391761779785
Batch num 10470: Loss 12.892328262329102
Batch num 10480: Loss 8.484601020812988
Batch num 10490: Loss 11.072135925292969
Batch num 10500: Loss 11.677855491638184
Batch num 10510: Loss 10.643876075744629
Batch num 10520: Loss 11.120824813842773
Batch num 10530: Loss 11.391757011413574
Batch num 10540: Loss 11.082942962646484
Batch num 10550: Loss 11.673659324645996
Batch num 10560: Loss 10.787030220031738
Batch num 10570: Loss 12.285197257995605
Batch num 10580: Loss 10.833815574645996
Batch num 10590: Loss 12.058860778808594
Batch num 10600: Loss 11.389311790466309
Batch num 10610: Loss 12.09605884552002
Batch num 10620: Loss 11.568236351013184
Batch num 10630: Loss 9.907034873962402
Batch num 10640: Loss 11.140267372131348
Batch num 10650: Loss 11.82919979095459
Batch num 10660: Loss 9.665026664733887
Batch num 10670: Loss 11.914475440979004
Batch num 10680: Loss 13.762864112854004
Batch num 10690: Loss 10.621342658996582
Batch num 10700: Loss

Batch num 12480: Loss 10.077719688415527
Batch num 12490: Loss 11.784712791442871
Batch num 12500: Loss 12.58205509185791
Batch num 12510: Loss 10.40953254699707
Batch num 12520: Loss 10.646117210388184
Batch num 12530: Loss 11.8521089553833
Batch num 12540: Loss 11.339330673217773
Batch num 12550: Loss 12.517513275146484
Batch num 12560: Loss 14.443086624145508
Batch num 12570: Loss 11.316356658935547
Batch num 12580: Loss 9.814388275146484
Batch num 12590: Loss 11.813817024230957
Batch num 12600: Loss 18.65672492980957
Batch num 12610: Loss 15.470144271850586
Batch num 12620: Loss 12.837380409240723
Batch num 12630: Loss 17.190393447875977
Batch num 12640: Loss 97.02857208251953
Batch num 12650: Loss 8.846025466918945
Batch num 12660: Loss 11.241003036499023
Batch num 12670: Loss 10.981300354003906
Batch num 12680: Loss 9.915390014648438
Batch num 12690: Loss 12.138344764709473
Batch num 12700: Loss 11.785970687866211
Batch num 12710: Loss 15.774094581604004
Batch num 12720: Loss 10.

Batch num 14500: Loss 12.376367568969727
Batch num 14510: Loss 10.430978775024414
Batch num 14520: Loss 9.739832878112793
Batch num 14530: Loss 10.955766677856445
Batch num 14540: Loss 13.158984184265137
Batch num 14550: Loss 8.585968017578125
Batch num 14560: Loss 10.772320747375488
Batch num 14570: Loss 11.195357322692871
Batch num 14580: Loss 10.202375411987305
Batch num 14590: Loss 13.0786714553833
Batch num 14600: Loss 10.733305931091309
Batch num 14610: Loss 10.88233757019043
Batch num 14620: Loss 10.749368667602539
Batch num 14630: Loss 10.338922500610352
Batch num 14640: Loss 11.345375061035156
Batch num 14650: Loss 10.935029983520508
Batch num 14660: Loss 11.676665306091309
Batch num 14670: Loss 11.173802375793457
Batch num 14680: Loss 9.076919555664062
Batch num 14690: Loss 10.405904769897461
Batch num 14700: Loss 10.774101257324219
Batch num 14710: Loss 9.752720832824707
Batch num 14720: Loss 10.99216079711914
Batch num 14730: Loss 12.298288345336914
Batch num 14740: Loss 10

Batch num 16520: Loss 12.065289497375488
Batch num 16530: Loss 11.177978515625
Batch num 16540: Loss 10.480552673339844
Batch num 16550: Loss 10.095067977905273
Batch num 16560: Loss 12.663717269897461
Batch num 16570: Loss 11.056676864624023
Batch num 16580: Loss 9.329103469848633
Batch num 16590: Loss 12.764383316040039
Batch num 16600: Loss 9.365935325622559
Batch num 16610: Loss 9.27983283996582
Batch num 16620: Loss 10.032051086425781
Batch num 16630: Loss 10.627384185791016
Batch num 16640: Loss 9.567373275756836
Batch num 16650: Loss 11.837615013122559
Batch num 16660: Loss 13.02377700805664
Batch num 16670: Loss 13.262388229370117
Batch num 16680: Loss 10.321815490722656
Batch num 16690: Loss 10.02649211883545
Batch num 16700: Loss 10.78988265991211
Batch num 16710: Loss 10.469772338867188
Batch num 16720: Loss 11.659978866577148
Batch num 16730: Loss 10.459527969360352
Batch num 16740: Loss 10.90418815612793
Batch num 16750: Loss 11.280086517333984
Batch num 16760: Loss 11.238

Batch num 18550: Loss 11.646556854248047
Batch num 18560: Loss 11.047597885131836
Batch num 18570: Loss 9.894384384155273
Batch num 18580: Loss 11.46529769897461
Batch num 18590: Loss 10.82190227508545
Batch num 18600: Loss 11.083831787109375
Batch num 18610: Loss 12.366808891296387
Batch num 18620: Loss 11.067879676818848
Batch num 18630: Loss 9.013800621032715
Batch num 18640: Loss 10.119670867919922
Batch num 18650: Loss 10.83991813659668
Batch num 18660: Loss 10.47279167175293
Batch num 18670: Loss 11.414790153503418
Batch num 18680: Loss 12.132226943969727
Batch num 18690: Loss 12.242347717285156
Batch num 18700: Loss 10.281196594238281
Batch num 18710: Loss 11.191089630126953
Batch num 18720: Loss 11.255404472351074
Batch num 18730: Loss 10.345362663269043
Batch num 18740: Loss 11.25413703918457
Batch num 18750: Loss 10.909553527832031
Batch num 18760: Loss 10.350628852844238
Batch num 18770: Loss 10.78247356414795
Batch num 18780: Loss 10.935955047607422
Batch num 18790: Loss 10

Batch num 20570: Loss 10.235538482666016
Batch num 20580: Loss 9.939751625061035
Batch num 20590: Loss 11.473068237304688
Batch num 20600: Loss 9.83657169342041
Batch num 20610: Loss 11.21064567565918
Batch num 20620: Loss 9.827032089233398
Batch num 20630: Loss 8.866369247436523
Batch num 20640: Loss 11.359190940856934
Batch num 20650: Loss 11.109756469726562
Batch num 20660: Loss 7.051154613494873
Batch num 20670: Loss 9.132936477661133
Batch num 20680: Loss 10.76158618927002
Batch num 20690: Loss 10.382133483886719
Batch num 20700: Loss 8.862678527832031
Batch num 20710: Loss 10.427435874938965
Batch num 20720: Loss 11.401834487915039
Batch num 20730: Loss 105.22845458984375
Batch num 20740: Loss 25.29604721069336
Batch num 20750: Loss 18.915874481201172
Batch num 20760: Loss 11.599516868591309
Batch num 20770: Loss 9.918065071105957
Batch num 20780: Loss 9.893542289733887
Batch num 20790: Loss 11.208086967468262
Batch num 20800: Loss 10.008210182189941
Batch num 20810: Loss 11.3316

Batch num 22600: Loss 9.754932403564453
Batch num 22610: Loss 10.733416557312012
Batch num 22620: Loss 10.102594375610352
Batch num 22630: Loss 10.877947807312012
Batch num 22640: Loss 11.368189811706543
Batch num 22650: Loss 11.345340728759766
Batch num 22660: Loss 12.42745304107666
Batch num 22670: Loss 9.934901237487793
Batch num 22680: Loss 9.50755500793457
Batch num 22690: Loss 9.56198501586914
Batch num 22700: Loss 8.518766403198242
Batch num 22710: Loss 11.853572845458984
Batch num 22720: Loss 11.648247718811035
Batch num 22730: Loss 9.831513404846191
Batch num 22740: Loss 9.957441329956055
Batch num 22750: Loss 11.908453941345215
Batch num 22760: Loss 10.223410606384277
Batch num 22770: Loss 9.64030933380127
Batch num 22780: Loss 9.595108985900879
Batch num 22790: Loss 11.547274589538574
Batch num 22800: Loss 12.46478271484375
Batch num 22810: Loss 11.403555870056152
Batch num 22820: Loss 9.883271217346191
Batch num 22830: Loss 11.409259796142578
Batch num 22840: Loss 8.4415969

Batch num 24630: Loss 14.763765335083008
Batch num 24640: Loss 12.52927017211914
Batch num 24650: Loss 9.438425064086914
Batch num 24660: Loss 10.161961555480957
Batch num 24670: Loss 10.369175910949707
Batch num 24680: Loss 11.321742057800293
Batch num 24690: Loss 11.657962799072266
Batch num 24700: Loss 9.06144905090332
Batch num 24710: Loss 11.022412300109863
Batch num 24720: Loss 10.698310852050781
Batch num 24730: Loss 10.294012069702148
Batch num 24740: Loss 42.95354080200195
Batch num 24750: Loss 21.0216007232666
Batch num 24760: Loss 10.347230911254883
Batch num 24770: Loss 10.744314193725586
Batch num 24780: Loss 10.012750625610352
Batch num 24790: Loss 9.177306175231934
Batch num 24800: Loss 14.965353012084961
Batch num 24810: Loss 10.588312149047852
Batch num 24820: Loss 12.3358736038208
Batch num 24830: Loss 8.116700172424316
Batch num 24840: Loss 9.065238952636719
Batch num 24850: Loss 45.48961639404297
Batch num 24860: Loss 15.35739517211914
Batch num 24870: Loss 42.08970

Batch num 26660: Loss 10.004644393920898
Batch num 26670: Loss 11.190723419189453
Batch num 26680: Loss 7.2605299949646
Batch num 26690: Loss 9.84959602355957
Batch num 26700: Loss 10.935554504394531
Batch num 26710: Loss 9.514371871948242
Batch num 26720: Loss 8.619270324707031
Batch num 26730: Loss 10.982379913330078
Batch num 26740: Loss 11.472953796386719
Batch num 26750: Loss 10.428322792053223
Batch num 26760: Loss 11.553048133850098
Batch num 26770: Loss 10.759095191955566
Batch num 26780: Loss 7.299527168273926
Batch num 26790: Loss 10.75650405883789
Batch num 26800: Loss 10.261479377746582
Batch num 26810: Loss 10.253299713134766
Batch num 26820: Loss 10.338923454284668
Batch num 26830: Loss 9.867844581604004
Batch num 26840: Loss 10.8430814743042
Batch num 26850: Loss 11.335005760192871
Batch num 26860: Loss 9.695938110351562
Batch num 26870: Loss 12.805227279663086
Batch num 26880: Loss 10.942564010620117
Batch num 26890: Loss 9.597543716430664
Batch num 26900: Loss 11.38465

Batch num 28690: Loss 10.975594520568848
Batch num 28700: Loss 10.613696098327637
Batch num 28710: Loss 10.698972702026367
Batch num 28720: Loss 10.099038124084473
Batch num 28730: Loss 10.048246383666992
Batch num 28740: Loss 9.158878326416016
Batch num 28750: Loss 9.357900619506836
Batch num 28760: Loss 11.324456214904785
Batch num 28770: Loss 10.153430938720703
Batch num 28780: Loss 9.665699005126953
Batch num 28790: Loss 11.559667587280273
Batch num 28800: Loss 9.621159553527832
Batch num 28810: Loss 11.059752464294434
Batch num 28820: Loss 9.363100051879883
Batch num 28830: Loss 9.526044845581055
Batch num 28840: Loss 9.330915451049805
Batch num 28850: Loss 8.345442771911621
Batch num 28860: Loss 9.3341646194458
Batch num 28870: Loss 10.684869766235352
Batch num 28880: Loss 9.202886581420898
Batch num 28890: Loss 8.351156234741211
Batch num 28900: Loss 10.578193664550781
Batch num 28910: Loss 12.204509735107422
Batch num 28920: Loss 10.78753662109375
Batch num 28930: Loss 10.68581

To do for the final paper:
* Add TensorBoard logging
* Print out accuracy for the encoder
* Create a perplexity evaluation metric
* Run an example with the discriminator over both the attention scores and the hidden states
* Factor the code into utility modules and classes