### Setting up a basic training loop - not using attention

In [2]:
# Change into the BERT/ project directory so the relative paths used below
# ("data/europarl-v7/", and the sibling "config/config.yml") resolve.
import os

# Compare the final path component only: a substring test on the whole path would
# wrongly skip the chdir for parents such as "/home/BERT_runs/project".
if os.path.basename(os.path.normpath(os.getcwd())) != 'BERT':
    os.chdir('BERT')
print("Current working dir is {}".format(os.getcwd()))

Current working dir is /juice/scr/scr110/scr/nlp/mtl_bert/unidirectional-NMT/BERT


In [3]:
import pyaml
import onmt
import torch
from dataset import TextDataset, Collator
from encoder import Encoder 
from decoder import Decoder
from discriminator import Discriminator
from lib.huggingface.transformers import RobertaTokenizer, CamembertTokenizer
from torch.utils.data import DataLoader

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [4]:
# Reload edited project modules (dataset, encoder, decoder, ...) before each cell runs,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
# config.yml lives in the "config" directory next to BERT/ (i.e. in the parent of the cwd).
_config_path = os.path.join(os.path.dirname(os.getcwd()), "config", "config.yml")
# NOTE(review): the full Loader can construct arbitrary Python objects; yaml.safe_load
# would suffice if config.yml only holds plain data.
with open(_config_path, "r") as config_file:
    config = pyaml.yaml.load(config_file, Loader=pyaml.yaml.Loader)

In [6]:
# Maximum sentence length in tokens from config.yml; passed to the Collator as `maxlen`
# and used to size the discriminator input (maxlen**2).
sentence_token_length = config["maxlen"]

In [7]:
# Pretrained subword tokenizers: RoBERTa for English, CamemBERT for French.
# NOTE(review): the losses/metrics below hard-code 1 as the padding id; presumably this is
# the pad_token_id of both tokenizers — verify.
tokenizer_en = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = CamembertTokenizer.from_pretrained('camembert-base')

In [8]:
# Collate function shared by the train and validation DataLoaders, configured with the
# maximum sentence length.
collator = Collator(maxlen=sentence_token_length)

In [9]:
# Europarl v7 parallel data read from the same directory; `training` picks the train vs.
# validation split. The recorded output shows examples shorter than 2 or longer than 50
# being dropped while loading.
text_dataset_train = TextDataset("data/europarl-v7/", tokenizer_en, tokenizer_fr, training=True)
text_dataset_val = TextDataset("data/europarl-v7/",  tokenizer_en, tokenizer_fr, training=False)

0
10000
20000
175 examples with length < 2 removed.
10109 examples with length > 50 removed.
0
36 examples with length < 2 removed.
2547 examples with length > 50 removed.


In [10]:
# Both loaders receive the same config["data_loader"] kwargs (presumably batch size,
# shuffling, workers — check config.yml), so validation batches use the training settings.
train_dataloader = DataLoader(text_dataset_train, **config["data_loader"], collate_fn=collator)
val_dataloader = DataLoader(text_dataset_val, **config["data_loader"], collate_fn=collator)

#### Specifying the encoding and decoding models

In [11]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = 'cuda' if _cuda_available else 'cpu'
print("Using CUDA!" if _cuda_available else "Using CPU - Played yourself!")

Using CUDA!


In [12]:
# Drop encoders left over from a previous run of this cell, so the old models can be
# released before new ones are allocated on `device`.
# NOTE(review): the Hook objects built later from these encoders may still reference the
# old modules and keep them alive — confirm what Hook stores.
try:
    del encoder_en
    del encoder_fr
except NameError:
    # First execution in a fresh kernel: there is nothing to delete yet.
    pass
encoder_en = Encoder("english").to(device=device)
encoder_fr = Encoder("french").to(device=device)

In [13]:
# Embedding hyper-parameters copied from each pretrained encoder's embedding layer.
# Author's observations: padding index and embedding dim are the same for both
# languages; the English vocabulary is larger than the French one.
_embedding_layer_en = encoder_en._modules['model'].embeddings
_embedding_layer_fr = encoder_fr._modules['model'].embeddings

word_padding_idx_en = _embedding_layer_en.padding_idx
word_padding_idx_fr = _embedding_layer_fr.padding_idx

word_vocab_size_en = _embedding_layer_en.word_embeddings.num_embeddings
word_vocab_size_fr = _embedding_layer_fr.word_embeddings.num_embeddings

word_vec_size_en = _embedding_layer_en.word_embeddings.embedding_dim
word_vec_size_fr = _embedding_layer_fr.word_embeddings.embedding_dim

In [14]:
def _build_decoder_embeddings(vec_size, vocab_size, padding_idx):
    '''OpenNMT embedding table with positional encoding, placed on `device`.'''
    return onmt.modules.embeddings.Embeddings(
        vec_size,
        vocab_size,
        padding_idx,
        position_encoding=True,
    ).to(device=device)

# Embedding tables consumed by the English and French decoders, sized like the encoders'.
embeddings_en = _build_decoder_embeddings(word_vec_size_en, word_vocab_size_en, word_padding_idx_en)
embeddings_fr = _build_decoder_embeddings(word_vec_size_fr, word_vocab_size_fr, word_padding_idx_fr)

In [15]:
# Transformer decoders built from config["small_transformer"]; each decodes INTO its own
# language: in train(), decoder_fr handles English -> French and decoder_en French -> English.
decoder_en = Decoder(**config["small_transformer"], embeddings=embeddings_en).to(device=device)
decoder_fr = Decoder(**config["small_transformer"], embeddings=embeddings_fr).to(device=device)

In [16]:
# Language discriminator producing one logit per input (trained with BCEWithLogitsLoss).
# Its input size, maxlen**2, matches one flattened maxlen x maxlen attention map per sentence.
# NOTE(review): train(regularize="hidden_state") feeds pooled encoder outputs into this same
# module; their width is presumably the encoder hidden size, not maxlen**2 — TODO confirm.
# projection: standard_sentence_length**2 -> 1
discriminator = Discriminator(sentence_token_length**2, 1).to(device=device)

#### Defining the loss functions

In [17]:
def loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict, pad_idx=1):
    '''Standard machine-translation cross-entropy loss, summed over both directions.

    Args:
        english_gt: English target token ids, assumed (batch, seq) — TODO confirm.
        french_gt: French target token ids, same layout as ``english_gt``.
        english_predict: English decoder logits, assumed (batch, seq, vocab); the last two
            dims are swapped so the class dim sits in dim 1 as CrossEntropyLoss expects.
        french_predict: French decoder logits, same layout as ``english_predict``.
        pad_idx: target id excluded from the loss (padding). Defaults to 1.

    Returns:
        Scalar tensor: CE(English -> French) + CE(French -> English).
    '''
    ce_loss = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)  # padding targets contribute nothing

    # english_predict is produced by the English decoder from the French encoding, so it
    # scores the French -> English direction; french_predict scores English -> French.
    loss_french_to_english = ce_loss(english_predict.transpose(1, 2), english_gt)
    loss_english_to_french = ce_loss(french_predict.transpose(1, 2), french_gt)
    return loss_english_to_french + loss_french_to_english

In [18]:
def loss_fn_single_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt, discriminator_predict):
    '''Adversarial loss with a single discriminator.

    Sum of the two-direction translation cross-entropy (loss_fn_no_regularization) and a
    binary cross-entropy between the discriminator logits `discriminator_predict` and the
    0/1 targets `discriminator_gt`.
    '''
    adversarial_term = torch.nn.BCEWithLogitsLoss()(discriminator_predict, discriminator_gt)
    translation_term = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)
    return translation_term + adversarial_term

In [19]:
def loss_fn_multi_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt_1, discriminator_predict_1,
                                discriminator_gt_2, discriminator_predict_2,):
    '''Adversarial loss with two discriminators.

    Sum of the two-direction translation cross-entropy (loss_fn_no_regularization) and one
    binary cross-entropy term per discriminator (logits vs. 0/1 targets).
    '''
    translation_term = loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict)

    bce = torch.nn.BCEWithLogitsLoss()
    adversarial_terms = [
        bce(discriminator_predict_1, discriminator_gt_1),
        bce(discriminator_predict_2, discriminator_gt_2),
    ]
    return translation_term + adversarial_terms[0] + adversarial_terms[1]

#### Importing and defining the evaluation functions

In [20]:
def exact_match(prediction, gt, pad_idx=1):
    '''Token-level accuracy of `prediction` against `gt`, ignoring padding positions.

    Despite the name, this is not sentence-level exact match. It returns the fraction of
    non-padding ground-truth positions whose predicted id equals the target id.

    Args:
        prediction: predicted token ids, same shape as ``gt``.
        gt: ground-truth token ids.
        pad_idx: id treated as padding and excluded from scoring. Defaults to 1.

    Returns:
        A float in [0, 1], or NaN when ``gt`` has no non-padding positions (undefined).
    '''
    mask = gt != pad_idx
    num_scored = torch.sum(mask).item()
    if num_scored == 0:
        # Nothing to score: avoid ZeroDivisionError rather than reporting a fake 0 or 1.
        return float("nan")
    num_correct = torch.sum((prediction == gt) & mask).item()
    return num_correct / num_scored

In [21]:
from nltk.translate.bleu_score import sentence_bleu

#### Defining the optimizer and (if required) the encoder hooks

In [22]:
def get_optimizer(regularize="hidden_state"):
    '''Build one Adam optimizer over both encoders and both decoders.

    For regularize="hidden_state" or "attention" the discriminator's parameters are added
    as well. Reads the notebook-level encoder_en, encoder_fr, decoder_fr, decoder_en and
    discriminator.

    NOTE(review): the encoders and the discriminator are then stepped on the same loss with
    one optimizer. Unless Discriminator reverses gradients internally, the encoders are
    pushed to help the discriminator rather than fool it — TODO confirm intent.
    '''
    modules = [encoder_en, encoder_fr, decoder_fr, decoder_en]
    if regularize in ("hidden_state", "attention"):
        modules.append(discriminator)

    params = [param for module in modules for param in module.parameters()]
    return torch.optim.Adam(params)

In [23]:
# Matches the train(..., regularize="attention") call below; with this setting the
# discriminator's parameters are included in the optimizer.
optimizer = get_optimizer(regularize="attention")

In [24]:
from hook import Hook
from utils import extract_attention_scores

In [25]:
# Wrap every submodule of each encoder (the encoder itself included) in a Hook, presumably
# capturing module outputs, so train() can extract attention scores after a forward pass.
_hooks_english = [Hook(module) for module in encoder_en.modules()]
_hooks_french = [Hook(module) for module in encoder_fr.modules()]

#### Defining the primary training loop

In [42]:
def _move_batch_to_device(batch):
    '''Unpack one collated (english_batch, french_batch) pair and move every tensor to `device`.

    Each language batch is a (sentences, sentences_no_eos, lengths) triple. Returns six
    tensors: English sentences, no-EOS, lengths, then French sentences, no-EOS, lengths.
    '''
    english_batch, french_batch = batch
    (english_sentences, english_sentences_no_eos, english_sentences_lengths) = english_batch
    (french_sentences, french_sentences_no_eos, french_sentences_lengths) = french_batch
    return tuple(tensor.to(device=device) for tensor in (
        english_sentences, english_sentences_no_eos, english_sentences_lengths,
        french_sentences, french_sentences_no_eos, french_sentences_lengths,
    ))


def _encode_decode(encoder, decoder, src_sentences, src_lengths, tgt_sentences_no_eos):
    '''Run one translation direction with teacher forcing; return (encoder_outputs, dec_outs).

    The decoder state is initialised from the source ids. It attends over the encoder's
    first output and consumes the ground-truth target without EOS. Both are transposed on
    dims 0/1, presumably into OpenNMT's time-major layout.
    '''
    encoder_outputs = encoder(src_sentences)
    decoder.init_state(src_sentences.unsqueeze(2).transpose(0, 1), None, None)
    dec_outs, _ = decoder(tgt_sentences_no_eos.unsqueeze(2).transpose(0, 1),
                          encoder_outputs[0].transpose(0, 1),
                          memory_lengths=src_lengths)
    return encoder_outputs, dec_outs


def _discriminator_targets(num_english, num_french):
    '''(num_english + num_french, 1) float targets: 1.0 for English rows, then 0.0 for French rows.'''
    return torch.tensor([1.0] * num_english + [0.0] * num_french, device=device).unsqueeze(1)


def _run_validation(val_data_iter, max_val_batches):
    '''Score English -> French on up to `max_val_batches` batches and print mean BLEU / token accuracy.

    encoder_en and decoder_fr are switched to eval mode for the pass and then restored to
    whatever train/eval mode they were in before.
    '''
    bleu_scores = []
    token_accuracies = []
    models = (encoder_en, decoder_fr)
    previous_modes = [model.training for model in models]
    for model in models:
        model.eval()
    try:
        with torch.no_grad():
            for val_batch_num, val_batch in enumerate(val_data_iter):
                if val_batch_num == max_val_batches:
                    break

                (english_sentences, _, english_sentences_lengths,
                 french_sentences, french_sentences_no_eos, _) = _move_batch_to_device(val_batch)
                _, dec_outs_fr = _encode_decode(encoder_en, decoder_fr, english_sentences,
                                                english_sentences_lengths, french_sentences_no_eos)

                predictions_fr = torch.argmax(dec_outs_fr, dim=2)
                french_targets = french_sentences[:, 1:]  # first token dropped, presumably BOS

                # sentence_bleu expects a list of tokenised references and a tokenised
                # hypothesis. Raw strings would be scored character by character, and
                # token ids must be decoded (not passed to convert_tokens_to_string).
                for gt_ids, pred_ids in zip(french_targets.tolist(), predictions_fr.tolist()):
                    reference = tokenizer_fr.decode(gt_ids, skip_special_tokens=True).split()
                    hypothesis = tokenizer_fr.decode(pred_ids, skip_special_tokens=True).split()
                    bleu_scores.append(sentence_bleu([reference], hypothesis))

                token_accuracies.append(exact_match(predictions_fr, french_targets))
    finally:
        for model, was_training in zip(models, previous_modes):
            model.train(was_training)

    if bleu_scores and token_accuracies:
        print("BLEU {}, EM: {}".format(sum(bleu_scores) / len(bleu_scores),
                                       sum(token_accuracies) / len(token_accuracies)))
    else:
        print("Validation skipped: no validation batches were available.")


def train(train_data_iter, val_data_iter, regularize="hidden_state", val_every=100, max_val_batches=20):
    '''
    Train the encoding and decoding models with teacher forcing in both directions.

    Args:
        train_data_iter: iterable of collated (english_batch, french_batch) pairs, each a
            (sentences, sentences_no_eos, lengths) triple.
        val_data_iter: iterable of the same form, used for periodic English -> French validation.
        regularize: "hidden_state" (discriminator on the encoders' pooled outputs),
            "attention" (discriminator on one extracted attention map per sentence), or
            any other value for plain cross-entropy training.
        val_every: validate after every `val_every`-th batch (batch 0 excluded).
        max_val_batches: number of validation batches scored per validation run.

    Uses the notebook-level models, `optimizer`, `device`, tokenizer and encoder hooks.
    '''
    for batch_num, batch in enumerate(train_data_iter):
        optimizer.zero_grad()

        (english_sentences, english_sentences_no_eos, english_sentences_lengths,
         french_sentences, french_sentences_no_eos, french_sentences_lengths) = _move_batch_to_device(batch)

        # English -> French, then French -> English (each encoder's hooks hold its latest pass).
        encoder_outputs_en, dec_outs_fr = _encode_decode(
            encoder_en, decoder_fr, english_sentences, english_sentences_lengths, french_sentences_no_eos)
        encoder_outputs_fr, dec_outs_en = _encode_decode(
            encoder_fr, decoder_en, french_sentences, french_sentences_lengths, english_sentences_no_eos)

        # Targets drop the first token of each sentence (presumably BOS).
        english_targets = english_sentences[:, 1:]
        french_targets = french_sentences[:, 1:]

        if regularize == "attention":
            # Index 6 of the extracted attention scores — the "7th attention head" suggested by
            # Clark et al., 2019. NOTE(review): confirm what this index selects in extract_attention_scores.
            english_attention = extract_attention_scores(_hooks_english)[6]
            french_attention = extract_attention_scores(_hooks_french)[6]
            batch_size = english_attention.shape[0]
            # `discriminator` (input size maxlen**2) is the attention discriminator; it is also
            # the module get_optimizer() adds for regularize="attention".
            discriminator_outputs_cat = torch.cat((
                discriminator(english_attention.reshape(batch_size, -1)),
                discriminator(french_attention.reshape(batch_size, -1)),
            ))
            discriminator_labels = _discriminator_targets(batch_size, batch_size)
            loss = loss_fn_single_regularization(english_targets, french_targets,
                                                 dec_outs_en, dec_outs_fr,
                                                 discriminator_labels, discriminator_outputs_cat)
        elif regularize == "hidden_state":
            # Pooled encoder outputs (element 1 of the encoder outputs) feed the discriminator.
            # NOTE(review): `discriminator` was built for maxlen**2 inputs; confirm the pooled
            # output has that width, or give this mode its own discriminator.
            discriminator_outputs_en = discriminator(encoder_outputs_en[1])
            discriminator_outputs_fr = discriminator(encoder_outputs_fr[1])
            discriminator_outputs_cat = torch.cat((discriminator_outputs_en, discriminator_outputs_fr))
            discriminator_labels = _discriminator_targets(discriminator_outputs_en.shape[0],
                                                          discriminator_outputs_fr.shape[0])
            loss = loss_fn_single_regularization(english_targets, french_targets,
                                                 dec_outs_en, dec_outs_fr,
                                                 discriminator_labels, discriminator_outputs_cat)
        else:
            loss = loss_fn_no_regularization(english_targets, french_targets, dec_outs_en, dec_outs_fr)

        # NOTE(review): the recorded run died with a CUDA device-side assert. A common cause is a
        # target id >= the size of the decoder outputs' last dim (e.g. decoders emitting d_model
        # features instead of vocab logits) — check dec_outs shapes on CPU.
        # NOTE(review): encoders and discriminator minimise this same loss with one optimizer;
        # unless Discriminator reverses gradients internally this is not adversarial.
        print("Batch num {}: Loss {}".format(batch_num, loss.item()))
        loss.backward()
        optimizer.step()

        if batch_num > 0 and batch_num % val_every == 0:
            _run_validation(val_data_iter, max_val_batches)
                    
        

In [43]:
# Train with the attention-based discriminator term; matches get_optimizer(regularize="attention") above.
train(train_dataloader, val_dataloader, regularize="attention")

torch.Size([2, 50])
torch.Size([2, 49])


RuntimeError: CUDA error: device-side assert triggered

**TODO:**
* Extract the attention scores
* Create the attention loss
* Create a combined attention + hidden-state loss
* Print out accuracy for the encoder


In [None]:
# Debug: print the French token-id tensor of the first training batch, then stop.
for batch_num, batch in enumerate(train_dataloader):
    english_batch, french_batch = batch
    english_sentences, english_sentences_no_eos, english_sentences_lengths = english_batch
    french_sentences, french_sentences_no_eos, french_sentences_lengths = french_batch
    print(french_sentences)
    break

In [None]:
# Debug: see which ids the CamemBERT tokenizer produces for a literal "<pad>" in raw text
# (to compare against the padding id 1 ignored by the losses and metrics).
tokenizer_fr.encode("bonjour <pad>")