### Setting up a basic training loop - not using attention

In [2]:
# Step into the BERT sub-directory unless the current path already mentions it.
import os

already_in_bert = 'BERT' in os.getcwd()
if not already_in_bert:
    os.chdir('BERT')
print("Current working dir is {}".format(os.getcwd()))

Current working dir is /juice/scr/scr110/scr/nlp/mtl_bert/unidirectional-NMT/BERT


In [3]:
import pyaml
import onmt
import math
import torch
from dataset import TextDataset, Collator
from encoder import Encoder 
from decoder import Decoder
from discriminator import Discriminator
from lib.huggingface.transformers import RobertaTokenizer, CamembertTokenizer
from torch.utils.data import DataLoader

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [3]:
# IPython autoreload in mode 2: re-import edited modules before each cell runs,
# so changes to the local project files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# The experiment configuration lives in <parent of working dir>/config/config.yml.
config_path = os.path.join(os.path.dirname(os.getcwd()), "config", "config.yml")
with open(config_path, "r") as fd:
    # NOTE(review): yaml's full Loader can construct arbitrary Python objects; that is
    # acceptable for a trusted local config, but consider SafeLoader if the file may
    # ever come from elsewhere.
    config = pyaml.yaml.load(fd, Loader=pyaml.yaml.Loader)

In [5]:
# Sentence length limit taken from the "maxlen" config entry; it is passed as `maxlen`
# to the collator and datasets and sizes the attention discriminator's input.
sentence_token_length = config["maxlen"]

In [6]:
# Pretrained tokenizers: RoBERTa ('roberta-base') for English, CamemBERT ('camembert-base') for French.
tokenizer_en = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = CamembertTokenizer.from_pretrained('camembert-base')

In [7]:
# Batch collation callable handed to the DataLoaders as `collate_fn`; configured with the
# same maxlen as the datasets.
collator = Collator(maxlen=sentence_token_length)

In [9]:
# Build the training split first, then the validation split, from the same Europarl folder.
text_dataset_train, text_dataset_val = (
    TextDataset("data/europarl-v7/", tokenizer_en, tokenizer_fr, training=is_training, maxlen=sentence_token_length)
    for is_training in (True, False)
)

0
10000
20000
175 examples with length < 2 removed.
4191 examples with length > 50 removed.
0
36 examples with length < 2 removed.
1037 examples with length > 50 removed.


In [10]:
# Both loaders share the "data_loader" config section and the same collate function.
train_dataloader, val_dataloader = (
    DataLoader(split, **config["data_loader"], collate_fn=collator)
    for split in (text_dataset_train, text_dataset_val)
)

#### Specifying the encoding and decoding models

In [11]:
# Pick the compute device once and report the choice.
cuda_available = torch.cuda.is_available()
device = 'cuda' if cuda_available else 'cpu'
print("Using CUDA!" if cuda_available else "Using CPU - Played yourself!")

Using CUDA!


In [12]:
# Drop encoders left over from a previous run of this cell *before* building new ones,
# so the old instances can be released first instead of coexisting with the new ones.
try:
    del encoder_en
    del encoder_fr
except NameError:
    # First execution of the cell: nothing to delete yet. Only NameError is expected here;
    # anything else (e.g. KeyboardInterrupt) should propagate rather than be swallowed.
    pass
encoder_en = Encoder("english").to(device=device)
encoder_fr = Encoder("french").to(device=device)

In [13]:
# Read the vocabulary statistics off each encoder's own embedding layer so that the
# decoder-side embeddings can be built with matching sizes.
enc_embeddings_en = encoder_en._modules['model'].embeddings
enc_embeddings_fr = encoder_fr._modules['model'].embeddings

# padding index (author's note: same for both languages)
word_padding_idx_en = enc_embeddings_en.padding_idx
word_padding_idx_fr = enc_embeddings_fr.padding_idx

# vocabulary size (author's note: en > fr)
word_vocab_size_en = enc_embeddings_en.word_embeddings.num_embeddings
word_vocab_size_fr = enc_embeddings_fr.word_embeddings.num_embeddings

# embedding width (author's note: same for both languages)
word_vec_size_en = enc_embeddings_en.word_embeddings.embedding_dim
word_vec_size_fr = enc_embeddings_fr.word_embeddings.embedding_dim

In [14]:
def _positional_embeddings(vec_size, vocab_size, padding_idx):
    '''Build an OpenNMT embedding table with position encoding enabled and move it to `device`.'''
    table = onmt.modules.embeddings.Embeddings(
        vec_size,
        vocab_size,
        padding_idx,
        position_encoding=True
    )
    return table.to(device=device)


# One embedding table per language, sized from the encoder statistics above.
embeddings_en = _positional_embeddings(word_vec_size_en, word_vocab_size_en, word_padding_idx_en)
embeddings_fr = _positional_embeddings(word_vec_size_fr, word_vocab_size_fr, word_padding_idx_fr)

In [15]:
# One decoder per target language. Both share the "small_transformer" hyper-parameters from
# the config; each one receives the embedding table of its own language.
decoder_en = Decoder(**config["small_transformer"], embeddings=embeddings_en).to(device=device)
decoder_fr = Decoder(**config["small_transformer"], embeddings=embeddings_fr).to(device=device)

In [16]:
# projection: standard_sentence_length**2 -> 1
# The training loop flattens each attention map with .view(batch_size, -1) before feeding it
# to this discriminator, hence the squared input size (presumably a maxlen x maxlen map —
# TODO confirm against the attention hooks).
attn_discriminator = Discriminator(int(sentence_token_length**2), 1).to(device=device)

Defining loss functions

In [17]:
def loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict, padding_idx=1):
    '''Standard machine translation cross entropy loss.

    Args:
        english_gt: target English token ids, indexed (batch, position).
        french_gt: target French token ids, indexed (batch, position).
        english_predict: scores for the English targets, indexed (batch, position, vocab);
            in the training loop these come from the French -> English decoder.
        french_predict: scores for the French targets, indexed (batch, position, vocab);
            in the training loop these come from the English -> French decoder.
        padding_idx: target id excluded from the loss (defaults to 1, the value
            previously hard-coded).

    Returns:
        Tuple ``(total, loss_english_to_french, loss_french_to_english)`` where
        ``total`` is the sum of the two directional losses.
    '''
    ce_loss = torch.nn.CrossEntropyLoss(ignore_index=padding_idx)  # ignoring padding tokens

    # CrossEntropyLoss wants the class dimension second, hence the transpose.
    # The loss on the *French* targets measures English -> French translation and vice
    # versa; the two names used to be attached to the opposite terms.
    loss_english_to_french = ce_loss(french_predict.transpose(1, 2), french_gt)
    loss_french_to_english = ce_loss(english_predict.transpose(1, 2), english_gt)
    return (loss_english_to_french + loss_french_to_english, loss_english_to_french, loss_french_to_english)

In [18]:
def loss_fn_single_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt, discriminator_predict):
    '''Adversarial loss: translation cross entropy plus one discriminator BCE term.

    The discriminator outputs are treated as logits (BCEWithLogitsLoss).

    Returns:
        Tuple ``(total, second, third, regularizing_term)`` where ``second`` and ``third``
        are the two per-direction losses exactly as returned by
        ``loss_fn_no_regularization`` and ``total`` is their sum plus the BCE term.
    '''
    translation_losses = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)
    translation_total, directional_loss_a, directional_loss_b = translation_losses

    adversarial_loss = torch.nn.BCEWithLogitsLoss()(discriminator_predict, discriminator_gt)
    total = translation_total + adversarial_loss

    return (total, directional_loss_a, directional_loss_b, adversarial_loss)


In [19]:
def loss_fn_multi_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt_1, discriminator_predict_1,
                                discriminator_gt_2, discriminator_predict_2,):
    '''Adversarial loss: translation cross entropy plus two discriminator BCE terms.

    Both discriminator predictions are treated as logits (BCEWithLogitsLoss).

    Returns:
        A single scalar tensor: translation loss + first BCE term + second BCE term.
    '''
    # loss_fn_no_regularization returns a tuple whose first element is the summed
    # translation loss; adding the whole tuple to a tensor was a bug.
    ce_term = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)[0]

    bce_loss = torch.nn.BCEWithLogitsLoss()
    regularizing_term_1 = bce_loss(discriminator_predict_1, discriminator_gt_1)
    regularizing_term_2 = bce_loss(discriminator_predict_2, discriminator_gt_2)

    return ce_term + regularizing_term_1 + regularizing_term_2

Importing and defining evaluation functions

In [20]:
def exact_match(prediction, gt, padding_idx=1):
    '''Fraction of non-padding ground-truth positions where the prediction equals the target.

    Args:
        prediction: predicted token ids, same shape as ``gt``.
        gt: ground-truth token ids.
        padding_idx: id in ``gt`` marking positions to ignore (defaults to 1, the value
            previously hard-coded).

    Returns:
        A Python float in [0, 1]; 0.0 when ``gt`` contains no non-padding position
        (previously a ZeroDivisionError).
    '''
    mask = gt != padding_idx
    num_scored = torch.sum(mask).item()
    if num_scored == 0:
        return 0.0
    return torch.sum((prediction == gt) & mask).item() / num_scored

In [21]:
from nltk.translate.bleu_score import sentence_bleu

Defining optimizer and hooks if required

In [22]:
def get_optimizer(regularize="None"):
    '''Create an Adam optimizer (default hyper-parameters) over the translation models.

    The parameters of both encoders and both decoders are always included.
    ``regularize="hidden_state"`` additionally includes ``discriminator`` and
    ``regularize="attention"`` additionally includes ``attn_discriminator``.
    '''
    modules = [encoder_en, encoder_fr, decoder_fr, decoder_en]

    if regularize == "hidden_state":
        # NOTE(review): `discriminator` is not defined in the notebook cells visible here;
        # confirm it exists before using this mode.
        modules.append(discriminator)

    if regularize == "attention":
        modules.append(attn_discriminator)

    trainable = [param for module in modules for param in module.parameters()]
    return torch.optim.Adam(trainable)

In [23]:
# Module-level optimizer that train() steps; built for the attention-regularized setup, so
# it also covers attn_discriminator's parameters.
optimizer = get_optimizer(regularize="attention")

Defining primary training loop

In [52]:
from torch.utils.tensorboard import SummaryWriter
from utils import write_to_tensorboard

In [77]:
def _set_requires_grad(module, flag):
    '''Switch gradient tracking on or off for every parameter of ``module``.'''
    for param in module.parameters():
        param.requires_grad = flag


def _batch_to_device(batch):
    '''Split a collated batch into its English and French triples and move every tensor to ``device``.'''
    english_batch, french_batch = batch
    english = tuple(tensor.to(device=device) for tensor in english_batch)
    french = tuple(tensor.to(device=device) for tensor in french_batch)
    return english, french


def _forward_both_directions(english, french):
    '''Run the English -> French and French -> English encoder/decoder passes (teacher forced).

    Returns ``(encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr)``.
    '''
    (english_sentences, english_sentences_no_eos, english_sentences_lengths) = english
    (french_sentences, french_sentences_no_eos, french_sentences_lengths) = french

    # Encoding - Decoding for English -> French
    encoder_outputs_en = encoder_en(english_sentences)
    decoder_fr.init_state(english_sentences.unsqueeze(2).transpose(0,1), None, None)
    dec_outs_fr, _ = decoder_fr(french_sentences_no_eos.unsqueeze(2).transpose(0,1), encoder_outputs_en[0].transpose(0,1), memory_lengths=english_sentences_lengths)

    # Encoding - Decoding for French -> English
    encoder_outputs_fr = encoder_fr(french_sentences)
    decoder_en.init_state(french_sentences.unsqueeze(2).transpose(0,1), None, None)
    dec_outs_en, _ = decoder_en(english_sentences_no_eos.unsqueeze(2).transpose(0,1), encoder_outputs_fr[0].transpose(0,1), memory_lengths=french_sentences_lengths)

    return encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr


def _discriminator_labels(batch_num, batch_size):
    '''Targets for the concatenated (English, French) discriminator outputs.

    Even batches label English 1 / French 0; odd batches use the flipped labels.
    '''
    if batch_num % 2 == 0:
        labels = torch.tensor([1.0]*batch_size + [0.0]*batch_size)
    else:
        labels = torch.tensor([0.0]*batch_size + [1.0]*batch_size)
    return labels.unsqueeze(1).to(device=device)


def _sentence_bleu_scores(tokenizer, gt_ids, predicted_ids, padding_idx=1):
    '''Per-sentence BLEU of ``predicted_ids`` against ``gt_ids``, skipping positions padded in the ground truth.'''
    scores = []
    for gt_row, predicted_row in zip(gt_ids, predicted_ids):
        keep = gt_row != padding_idx
        reference = tokenizer.decode(gt_row[keep].tolist()).split()
        hypothesis = tokenizer.decode(predicted_row[keep].tolist()).split()
        # sentence_bleu expects a *list* of tokenised references and one tokenised
        # hypothesis; passing raw strings makes it score individual characters.
        scores.append(sentence_bleu([reference], hypothesis))
    return scores


def _run_validation(val_data_iter, max_val_batches):
    '''Score up to ``max_val_batches`` validation batches.

    Returns a dict with average BLEU and exact-match per direction, or ``None`` when
    the iterator yielded no batch.
    '''
    models = (encoder_en, encoder_fr, decoder_en, decoder_fr)
    # Remember each sub-module's mode so it can be restored exactly afterwards
    # (some sub-modules may deliberately be kept in eval mode during training).
    saved_modes = [(sub, sub.training) for model in models for sub in model.modules()]
    for model in models:
        model.eval()  # no dropout while measuring

    bleu_en_fr, bleu_fr_en, em_en_fr, em_fr_en = [], [], [], []
    try:
        with torch.no_grad():
            for val_batch_num, val_batch in enumerate(val_data_iter):
                if val_batch_num >= max_val_batches:
                    break

                english, french = _batch_to_device(val_batch)
                _, _, dec_outs_en, dec_outs_fr = _forward_both_directions(english, french)

                # Calculate BLEU scores and EM
                predictions_fr = torch.argmax(dec_outs_fr, dim=2)
                predictions_en = torch.argmax(dec_outs_en, dim=2)
                french_gt = french[0][:, 1:]
                english_gt = english[0][:, 1:]

                bleu_en_fr.extend(_sentence_bleu_scores(tokenizer_fr, french_gt, predictions_fr))
                bleu_fr_en.extend(_sentence_bleu_scores(tokenizer_en, english_gt, predictions_en))
                em_en_fr.append(exact_match(predictions_fr, french_gt))
                em_fr_en.append(exact_match(predictions_en, english_gt))
    finally:
        for sub, was_training in saved_modes:
            sub.training = was_training

    if not em_en_fr or not bleu_en_fr or not bleu_fr_en:
        return None
    return {
        "bleu_en_fr": sum(bleu_en_fr)/len(bleu_en_fr),
        "bleu_fr_en": sum(bleu_fr_en)/len(bleu_fr_en),
        "em_en_fr": sum(em_en_fr)/len(em_en_fr),
        "em_fr_en": sum(em_fr_en)/len(em_fr_en),
    }


def train(train_data_iter, val_data_iter, regularize="None",
          log_every=10, val_every=10, max_val_batches=50):
    '''
    Train the encoding and decoding models for one pass over ``train_data_iter``.

    Args:
        train_data_iter: iterable of training batches; each batch is a pair of
            (sentences, sentences_no_eos, lengths) triples, English first, French second.
        val_data_iter: iterable of validation batches with the same layout.
        regularize: type of adversarial regularization, "hidden_state" or "attention";
            any other value (default "None") trains with the plain translation loss.
        log_every: print the training loss every this many batches.
        val_every: run validation every this many batches (never on batch 0).
        max_val_batches: maximum number of validation batches scored per validation run.

    Side effects: steps the module-level ``optimizer`` and writes the training losses and
    validation BLEU / exact-match to TensorBoard under "runs/regularize_attention",
    indexed by the training batch number.
    '''
    writer = SummaryWriter("runs/regularize_attention")

    try:
        for batch_num, batch in enumerate(train_data_iter):
            optimizer.zero_grad()

            # TODO: generalize this to multiple discriminators
            # Alternate: even batches update the discriminator with frozen encoders,
            # odd batches update the encoders with a frozen discriminator.
            if (regularize != "None"):
                discriminator_turn = (batch_num % 2 == 0)
                _set_requires_grad(encoder_en, not discriminator_turn)
                _set_requires_grad(encoder_fr, not discriminator_turn)
                _set_requires_grad(attn_discriminator, discriminator_turn)

            # Reading in input and moving to device
            english, french = _batch_to_device(batch)
            english_sentences, french_sentences = english[0], french[0]

            encoder_outputs_en, encoder_outputs_fr, dec_outs_en, dec_outs_fr = _forward_both_directions(english, french)

            regularizing_term = None
            if regularize in ("attention", "hidden_state"):
                if (regularize == "attention"):
                    # extracting the attention scores from the datasets; using 7th attention head
                    # as suggested by Clark et al, 2019
                    # NOTE(review): extract_attention_scores, _hooks_english and _hooks_french
                    # are not defined in the notebook cells visible here.
                    english_attention = extract_attention_scores(_hooks_english)[6]
                    french_attention = extract_attention_scores(_hooks_french)[6]
                    batch_size = english_attention.shape[0]
                    discriminator_outputs_en = attn_discriminator(english_attention.view(batch_size, -1))
                    discriminator_outputs_fr = attn_discriminator(french_attention.view(batch_size, -1))
                else:
                    # using the pooled outputs of the encoders for regularizing
                    # NOTE(review): `discriminator` is not defined in the visible cells.
                    batch_size = english_sentences.shape[0]
                    discriminator_outputs_en = discriminator(encoder_outputs_en[1])
                    discriminator_outputs_fr = discriminator(encoder_outputs_fr[1])

                discriminator_outputs_cat = torch.cat((discriminator_outputs_en, discriminator_outputs_fr))
                discriminator_labels = _discriminator_labels(batch_num, batch_size)

                all_losses = loss_fn_single_regularization(english_sentences[:, 1:],
                                                           french_sentences[:, 1:],
                                                           dec_outs_en,
                                                           dec_outs_fr,
                                                           discriminator_labels,
                                                           discriminator_outputs_cat,
                                                           )
                (loss, loss_english_to_french, loss_french_to_english, regularizing_term) = all_losses
            else:
                # The plain loss also returns a tuple; unpack it so `loss` is a tensor.
                all_losses = loss_fn_no_regularization(english_sentences[:, 1:],
                                                       french_sentences[:, 1:],
                                                       dec_outs_en,
                                                       dec_outs_fr,
                                                       )
                (loss, loss_english_to_french, loss_french_to_english) = all_losses

            if (batch_num % log_every == 0):
                print("Batch num {}: Loss {}".format(batch_num, loss.item()))
            loss.backward()
            optimizer.step()

            write_to_tensorboard("CCE", {"en-fr": loss_english_to_french.item(), "fr-en":loss_french_to_english.item()}, training=True, step=batch_num, writer=writer)
            if regularizing_term is not None:
                write_to_tensorboard("BCE", {"attention-regularizing": regularizing_term.item()}, training=True, step=batch_num, writer=writer)

            # Running validation script
            if (batch_num > 0 and batch_num % val_every == 0):
                metrics = _run_validation(val_data_iter, max_val_batches)
                if metrics is not None:
                    # step is the *training* batch number, so the curves line up with the losses
                    write_to_tensorboard("BLEU", {"en-fr": metrics["bleu_en_fr"], "fr-en":metrics["bleu_fr_en"]}, training=False, step=batch_num, writer=writer)
                    write_to_tensorboard("EM", {"en-fr": metrics["em_en_fr"], "fr-en":metrics["em_fr_en"]}, training=False, step=batch_num, writer=writer)
    finally:
        # Flush pending events even when the run is interrupted from the notebook.
        writer.close()
            

In [None]:
train(train_dataloader, val_dataloader, regularize="attention")

Batch num 0: Loss 26.897085189819336
Batch num 10: Loss 25.889673233032227
Batch num 20: Loss 22.817777633666992
Batch num 30: Loss 32.06048583984375
Batch num 40: Loss 23.33237648010254
Batch num 50: Loss 22.062509536743164
Batch num 60: Loss 27.385663986206055
Batch num 70: Loss 30.629655838012695
Batch num 80: Loss 26.126462936401367
Batch num 90: Loss 20.1998233795166
Batch num 100: Loss 26.90410041809082
Batch num 110: Loss 28.324705123901367
Batch num 120: Loss 22.771947860717773
Batch num 130: Loss 21.53239631652832
Batch num 140: Loss 24.382516860961914
Batch num 150: Loss 25.296218872070312
Batch num 160: Loss 29.16266632080078
Batch num 170: Loss 20.89948844909668
Batch num 180: Loss 22.258615493774414
Batch num 190: Loss 25.992115020751953
Batch num 200: Loss 20.665569305419922
Batch num 210: Loss 20.555999755859375
Batch num 220: Loss 25.68578338623047
Batch num 230: Loss 26.146568298339844
Batch num 240: Loss 24.552053451538086
Batch num 250: Loss 25.202714920043945
Batch 

To do for final paper: 
* Add Tensorboard stuff 
* Print out accuracy for the encoder
* Create Perplexity evaluation metric
* Run an example with discriminators over both the attention scores and the hidden state
* Factor the code into utility modules and classes