# NMT with Adversarial Regularization

#### 1. Load data

Load 2 Mil. Europarl v7 fr-en sentence pairs

#### 2. Build model

Initialize the encoder, decoder, and discriminator architectures

    experiment parameters:
        - encoder = {Transformer, RoBERTa, CamemBERT}
        - decoder = {Transformer}

#### 3. Define loss/metric functions

Define the sequence cross entropy and adversarial loss functions

    experiment parameters:
        - regularization = {encoder attention, latent variable, both}

#### 4. Define training logic

Define the optimizer and training loop for an arbitrary configuration


In [None]:
%load_ext autoreload
%autoreload 2

In [1]:
import os

from torch.cuda import is_available
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

from modules.lib.huggingface import transformers
from modules.data import TextDataset, Collator
from modules.model import Embeddings, Encoder, Decoder, Discriminator, Hook 
from modules import utils

# TODO replace: 
experiment = "transformer_none"
config = utils.load_config("config/{}.yml".format(experiment))

if is_available():
    device = "cuda"
else:
    device = "cpu"
    print("WARNING: CUDA IS NOT AVAILABLE".format(device))
    
ckpt_dir = "experiments/{}/checkpoints".format(experiment)
runs_dir = "experiments/{}/tensorboard".format(experiment)

try: 
    os.makedirs(ckpt_dir)
    os.makedirs(runs_dir)
except FileExistsError: 
    print("File already exists")

File already exists


## 1. Load data

In [2]:
# Build tokenizer for English and French
tokenizer_en = transformers.RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = transformers.CamembertTokenizer.from_pretrained('camembert-base')

# Build TextDataset for train and valid
data_path = utils.data_path("europarl-v7")
dataset_train = TextDataset(
    data_path, 
    tokenizer_en, 
    tokenizer_fr, 
    training=True, 
    minlen=config["minlen"],
    maxlen=config["maxlen"],
    size=config["n_train"]
)
dataset_valid = TextDataset(
    data_path, 
    tokenizer_en, 
    tokenizer_fr, 
    training=False, 
    minlen=config["minlen"],
    maxlen=config["maxlen"],
    size=config["n_valid"]
)

# Build DataLoader for train and valid
collator = Collator(maxlen=config["maxlen"])
dataloader_train = DataLoader(dataset_train, **config["data_loader"], collate_fn=collator)
dataloader_valid = DataLoader(dataset_valid, **config["data_loader"], collate_fn=collator)

**Loading in pre-saved file: /home/john_kamalu/unidirectional-NMT/data/europarl-v7/data.train.pkl
**Loading in pre-saved file: /home/john_kamalu/unidirectional-NMT/data/europarl-v7/data.val.pkl


## 2. Build Model

1. Using the Output Embedding to Improve Language Models - http://arxiv.org/abs/1608.05859

In [4]:
# Init BERT encoder w/pretrained weights
bert_en = Encoder.init_bert("english").to(device=device)
bert_fr = Encoder.init_bert("french").to(device=device)

# Init embeddings w/pretrained weights from BERT encoder
embeddings_en = Embeddings.from_pretrained(bert_en.model.get_input_embeddings()).to(device=device)
embeddings_fr = Embeddings.from_pretrained(bert_fr.model.get_input_embeddings()).to(device=device)

if config["use_bert"]:
    # Use BERT encoder
    encoder_en = bert_en
    encoder_fr = bert_fr
else:
    # Init vanilla Transformer encoder w/pretrained embeddings from BERT encoder
    del bert_en
    del bert_fr
    encoder_en = Encoder.init_vanilla(**config["vanilla_encoder"], embeddings=embeddings_en).to(device=device)
    encoder_fr = Encoder.init_vanilla(**config["vanilla_encoder"], embeddings=embeddings_fr).to(device=device)

# Init vanilla Transformer decoder w/pretrained embeddings from BERT encoder
decoder_en = Decoder(**config["vanilla_decoder"], embeddings=embeddings_en).to(device=device)
decoder_fr = Decoder(**config["vanilla_decoder"], embeddings=embeddings_fr).to(device=device)

# Init Discriminator
attn_discriminator = None
hidden_discriminator = None
if (config["regularization"]["type"] == "attention"):
    attn_discriminator = Discriminator(config["maxlen"] ** 2, 1, config["regularization"]["n_affine"]).to(device=device)
if (config["regularization"]["type"] == "hidden"): 
    hidden_discriminator = Discriminator(config["d_model"], 1).to(device=device)
if (config["regularization"]["type"] == "both"): 
    attn_discriminator = Discriminator(config["maxlen"] ** 2, 1, config["regularization"]["n_affine"]).to(device=device)
    hidden_discriminator = Discriminator(config["d_model"], 1).to(device=device)

## 3. Define loss/metric functions

In [5]:
def loss_fn_no_regularization(real_en, real_fr, pred_en, pred_fr, ignore_index=1):
    '''
    Standard machine translation cross entropy loss
    '''
    cce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
    loss_en2fr = cce_loss(pred_fr.transpose(1,2), real_fr)
    loss_fr2en = cce_loss(pred_en.transpose(1,2), real_en)
    return loss_en2fr + loss_fr2en, loss_en2fr, loss_fr2en

In [6]:
def loss_fn_single_regularization(real_en, real_fr, pred_en, pred_fr, real_y, pred_y):
    '''
    Adversarial Loss: standard loss with binary cross entropy on top of the discriminator outputs
    '''
    crossentropy_term, loss_en2fr, loss_fr2en = loss_fn_no_regularization(real_en, real_gt, pred_en, pred_fr)
    bce_loss = torch.nn.BCEWithLogitsLoss()
    regularizing_term = bce_loss(pred_y, real_y)
    return crossentropy_term + regularizing_term, loss_en2fr, loss_fr2en, regularizing_term

In [7]:
def loss_fn_multi_regularization(real_en, real_fr, pred_en, pred_fr, real_ys, pred_ys):
    '''
    Adversarial Loss: standard loss with binary cross entropy on top of the discriminator outputs
    '''
    crossentropy_term, loss_en2fr, loss_fr2en = loss_fn_no_regularization(real_en, real_gt, pred_en, pred_fr)    
    bce_loss = torch.nn.BCEWithLogitsLoss()
    regularizing_terms = []
    for real_y, pred_y in zip(real_ys, pred_ys):
        regularizing_terms.append(bce_loss(pred_y, real_y))
    return crossentropy_term + sum(regularizing_terms), loss_en2fr, loss_fr2en, regularizing_terms


In [8]:
def exact_match(pred, real, ignore_index=1):
    '''
    Evaluate percent exact match between predictions and ground truth
    '''
    mask = real != ignore_index
    return torch.sum((pred == real) * mask).item() / torch.sum(mask).item()

In [9]:
from nltk.translate.bleu_score import sentence_bleu

## Defining training logic

1. What Does BERT Look At? An Analysis of BERT's Attention - https://arxiv.org/abs/1906.04341

In [10]:
def get_optimizer(encoder_en, encoder_fr, decoder_en, decoder_fr, attn_discriminator=None, hidden_discriminator=None, **kwargs):
    params = (
        list(encoder_en.parameters()) + list(encoder_fr.parameters()) +
        list(decoder_fr.parameters()) + list(decoder_en.parameters())
    )
    if hidden_discriminator is not None:
        params += list(hidden_discriminator.parameters())
    if attn_discriminator is not None: 
        params += list(attn_discriminator.parameters())
    return Adam(params, **kwargs)

In [11]:
optimizer = get_optimizer(
    encoder_en, encoder_fr, 
    decoder_en, decoder_fr,
    attn_discriminator = attn_discriminator,
    hidden_discriminator = hidden_discriminator,
    **config["adam"],)

In [12]:
_hooks_en = [Hook(layer[1]) for layer in list(encoder_en.named_modules())]
_hooks_fr = [Hook(layer[1]) for layer in list(encoder_fr.named_modules())]

In [31]:
model = {
    "encoder_en": encoder_en,
    "encoder_fr": encoder_fr,
    "decoder_en": decoder_en,
    "decoder_fr": decoder_fr,
    "attn_discriminator": attn_discriminator,
    "hidden_discriminator": hidden_discriminator,
}

def train(model, dataloader_train, dataloader_valid, optimizer, regularization): 
    ''' 
    Train the encoding and decoding models. User needs to pass in a valid iterator over the data,
    and also specify a type of adversarial regularization. regularize = ["hidden", "attention", "both"]
    '''
    writer = SummaryWriter(runs_dir)
                                                   
    for batch_i, batch in enumerate(dataloader_train):
        
        if (batch_i == config["checkpoint_frequency"]):   
            for _module_name, _model in model.items(): 
                _model.save_state_dict(os.path.join(ckpt_dir, "{}.{}.pt".format(batch_i,_module_name)))
        
        if (batch_i >= config["max_step_num"]):
            for _module_name, _model in model.items(): 
                _model.save_state_dict(os.path.join(ckpt_dir, "{}.{}.pt".format(batch_i,_module_name)))
            break 
        
        optimizer.zero_grad()
        
        switch = batch_i % 2 == 0
        
        if regularization["type"] != "none":
            print(regularization["type"], type(regularization["type"]))
            for module in model:
                if "discriminator" in module:
                    for param in model[module].parameters():
                        param.requires_grad = not switch
                elif "encoder" in module:
                    for param in model[module].parameters():
                        param.requires_grad = switch

        # Read in input and move to device
        batch_en, batch_fr = batch
        sents_en, sents_no_eos_en, lengths_en = batch_en
        sents_fr, sents_no_eos_fr, lengths_fr = batch_fr
        
        if not config["use_bert"]:
            sents_en = sents_en.unsqueeze(-1)
            sents_fr = sents_fr.unsqueeze(-1)

        sents_en = sents_en.to(device=device)
        sents_no_eos_en = sents_no_eos_en.to(device=device)
        lengths_en = lengths_en.to(device=device)
        
        sents_fr = sents_fr.to(device=device)
        sents_no_eos_fr = sents_no_eos_fr.to(device=device)
        lengths_fr = lengths_fr.to(device=device)

        # Encoding/Decoding for en -> fr
        if not config["use_bert"]:
            enc_out_en = encoder_en(sents_en.transpose(0,1), lengths=lengths_en)
        else:
            enc_out_en = encoder_en(sents_en)
        if not config["use_bert"]:
            enc_out_en[1] = enc_out_en[1].transpose(0,1)
        decoder_fr.init_state(sents_en.unsqueeze(2).transpose(0,1), None, None)
        dec_out_fr, _ = decoder_fr(
            sents_no_eos_fr.unsqueeze(2).transpose(0,1), 
            enc_out_en[1].transpose(0,1), 
            memory_lengths=lengths_en
        )

        # Encoding/Decoding for fr -> en
        if not config["use_bert"]:
            enc_out_fr = encoder_fr(sents_fr.transpose(0,1), lengths=lengths_fr)
        else:
            enc_out_fr = encoder_fr(sents_fr)
        if not config["use_bert"]:
            enc_out_fr[1] = enc_out_fr[1].transpose(0,1)
        decoder_en.init_state(sents_fr.unsqueeze(2).transpose(0,1), None, None) 
        dec_out_en, _ = decoder_en(
            sents_no_eos_en.unsqueeze(2).transpose(0,1), 
            enc_out_fr[1].transpose(0,1), 
            memory_lengths=lengths_fr
        )
        
        # Initial default values for regularization 
        attention_regularization = torch.tensor(0.0)
        hidden_regularization = torch.tensor(0.0)

        if regularization["type"] == "attention":
            
            batch_size = len(sents_en)
            
            attention_en = utils.extract_attention_scores(_hooks_en)[6]
            attention_fr = utils.extract_attention_scores(_hooks_fr)[6]
            attention_en = attention_en.view(batch_size, -1)
            attention_fr = attention_fr.view(batch_size, -1)
            
            discriminator_output_en = attn_discriminator(attention_en)
            discriminator_output_fr = attn_discriminator(attention_fr)
            discriminator_output = torch.cat([discriminator_output_en, discriminator_output_fr])
            
            switch = batch_i % 2 == 0
            discriminator_labels = torch.tensor([float(switch)] * batch_size + [float(not switch)] * batch_size)
            discriminator_labels = discriminator_labels.unsqueeze(1).to(device=device)

            losses = loss_fn_single_regularization(
                sents_en[:, 1:],
                sents_fr[:, 1:],
                dec_outs_en,
                dec_outs_fr,
                discriminator_labels,
                discriminator_output
            )
            
            loss, loss_en2fr, loss_fr2en, attention_regularization = losses
            
        elif regularization["type"] == "hidden":
            # Use the pooled outputs of the encoders for regularization
            discriminator_output_en = hidden_discriminator(enc_out_en[1])
            discriminator_output_fr = hidden_discriminator(enc_out_fr[1])
            discriminator_output = torch.cat((discriminator_output_en, discriminator_output_fr))
            
            switch = batch_i % 2 == 0
            discriminator_labels = torch.tensor([float(switch)] * batch_size + [float(not switch)] * batch_size)
            discriminator_labels = discriminator_labels.unsqueeze(1).to(device=device)

            losses = loss_fn_single_regularization(
                sents_en[:, 1:],
                sents_fr[:, 1:],
                dec_outs_en,
                dec_outs_fr,
                discriminator_labels,
                discriminator_output
            )
            
            loss, loss_en2fr, loss_fr2en, hidden_regularization = losses
        elif regularization["type"] == "both":
            
            batch_size = len(sents_en)
            
            # Applying attention regularization  
            attention_en = utils.extract_attention_scores(_hooks_en)[6]
            attention_fr = utils.extract_attention_scores(_hooks_fr)[6]
            attention_en = attention_en.view(batch_size, -1)
            attention_fr = attention_fr.view(batch_size, -1)
            
            attn_discriminator_output_en = attn_discriminator(attention_en)
            attn_discriminator_output_fr = attn_discriminator(attention_fr)
            attn_discriminator_output = torch.cat([attn_discriminator_output_en, attn_discriminator_output_fr])
            
            # Applying hidden regularization
            hidden_discriminator_output_en = hidden_discriminator(enc_out_en[1])
            hidden_discriminator_output_fr = hidden_discriminator(enc_out_fr[1])
            hidden_discriminator_output = torch.cat((hidden_discriminator_output_en, hidden_discriminator_output_fr))
            
            # Creating labels
            switch = batch_i % 2 == 0
            discriminator_labels = torch.tensor([float(switch)] * batch_size + [float(not switch)] * batch_size)
            discriminator_labels = [discriminator_labels.unsqueeze(1).to(device=device),\
                                    discriminator_labels.unsqueeze(1).to(device=device)]
            
            discriminator_output = [attn_discriminator_output, hidden_discriminator_output]

            losses = loss_fn_multi_regularization(
                sents_en[:, 1:],
                sents_fr[:, 1:],
                dec_outs_en,
                dec_outs_fr,
                discriminator_labels,
                discriminator_output
            )
            loss = losses[0]
            attention_regularization, hidden_regularization = losses[-1]

        else:
            losses = loss_fn_no_regularization(
                sents_en[:, 1:],
                sents_fr[:, 1:],
                dec_outs_en,
                dec_outs_fr
            )
            
            loss, loss_en2fr, loss_fr2en = losses
            
        if (batch_i % 50 == 0):
            print("Batch {}: Loss {}".format(batch_i, loss.item()))

        # Optimizer weights
        loss.backward()
        optimizer.step()
        
        # Write training losses/metrics to tensorboard
        cce_metrics = {"en-fr": loss_en2fr.item(), "fr-en": loss_fr2en.item()}
        utils.write_to_tensorboard("CCE", cce_metrics, training=True, step=batch_i, writer=writer)
        if regularization is not None:
            bce_metrics = {"attention_regularization": attention_regularization.item(),
                          "hidden_regularization": hidden_regularization.item()}
            utils.write_to_tensorboard("BCE", bce_metrics, training=True, step=batch_i, writer=writer)

        # Running validation script  
        if (batch_i > 0 and batch_i % 500 == 0):
            with torch.no_grad():
                _blue_scores_en2fr = []
                _exact_matches_en2fr = []
                _blue_scores_fr2en = []
                _exact_matches_fr2en = []
                _val_loss_en2fr = []
                _val_loss_en2fr = []
                for batch_j, batch in enumerate(dataloader_valid):
                    
                    if (batch_j == 50):
                        break
                    
                    # Read in input and move to device
                    batch_en, batch_fr = batch
                    sents_en, sents_no_eos_en, lengths_en = batch_en
                    sents_fr, sents_no_eos_fr, lengths_fr = batch_fr

                    sents_en = sents_en.to(device=device)
                    sents_no_eos_en = sents_no_eos_en.to(device=device)
                    lengths_en = lengths_en.to(device=device)

                    sents_fr = sents_fr.to(device=device)
                    sents_no_eos_fr = sents_no_eos_fr.to(device=device)
                    lengths_fr = lengths_fr.to(device=device)

                    # Encoding/Decoding for en -> fr
                    enc_out_en = encoder_en(sents_en)
                    decoder_fr.init_state(sents_en.unsqueeze(2).transpose(0,1), None, None)
                    dec_out_fr, _ = decoder_fr(
                        sents_no_eos_fr.unsqueeze(2).transpose(0,1), 
                        enc_out_en[0].transpose(0,1), 
                        memory_lengths=lengths_en
                    )

                    # Encoding/Decoding for fr -> en
                    enc_out_fr = encoder_fr(sents_fr)
                    decoder_en.init_state(sents_fr.unsqueeze(2).transpose(0,1), None, None) 
                    dec_out_en, _ = decoder_en(
                        sents_no_eos_en.unsqueeze(2).transpose(0,1), 
                        enc_out_fr[0].transpose(0,1), 
                        memory_lengths=lengths_fr
                    )
                    
                    # Calculate BLUE Scores, EM and Perplexity
                    preds_fr = torch.argmax(dec_out_fr, dim=2)
                    preds_en = torch.argmax(dec_out_en, dim=2)
                    
                    _, val_loss_en2fr, val_loss_fr2en = loss_fn_no_regularization(sents_en[:, 1:], sents_fr[:, 1:], dec_outs_en, dec_outs_fr)
                    _val_loss_en2fr.append(val_loss_en2fr.item())
                    _val_loss_fr2en.append(val_loss_fr2en.item())
                    
                    for idx in range(batch_size):
                        detokenized_real_fr = tokenizer_fr.convert_tokens_to_string(sents_fr[idx, 1:].tolist())
                        detokenized_pred_fr = tokenizer_fr.convert_tokens_to_string(preds_fr[idx].tolist())
                        _blue_scores_en2fr.append(sentence_bleu(detokenized_real_fr, detokenized_pred_fr))
                        
                        detokenized_real_en = tokenizer_en.convert_tokens_to_string(sents_en[idx, 1:].tolist())
                        detokenized_pred_en = tokenizer_en.convert_tokens_to_string(preds_en[idx].tolist())
                        _blue_scores_fr2en.append(sentence_bleu(detokenized_real_en, detokenized_pred_en))
                        
                    _exact_matches_en2fr.append(exact_match(preds_fr, sents_fr[:, 1:]))
                    _exact_matches_fr2en.append(exact_match(preds_en, sents_en[:, 1:]))
                    
                                                   
                avg_bleu_en2fr = sum(_blue_scores_en2fr) / len(_blue_scores_en2fr)
                avg_bleu_fr2en = sum(_blue_scores_fr2en) / len(_blue_scores_fr2en)
                avg_em_en2fr = sum(_exact_matches_en2fr) / len(_exact_matches_en2fr)
                avg_em_fr2en = sum(_exact_matches_fr2en) / len(_exact_matches_fr2en)
                avg_loss_en2fr = sum(_val_loss_en2fr) / len(_val_loss_en2fr)
                avg_loss_fr2en = sum(_val_loss_fr2en) / len(_val_loss_fr2en)
                
                bleu_metrics = {"en-fr": avg_bleu_en2fr, "fr-en":avg_bleu_fr2en}
                write_to_tensorboard("BLEU", bleu_metrics, training=False, step=batch_i, writer=writer)
                
                exact_match_metrics = {"en-fr": avg_em_en2fr, "fr-en":avg_em_fr2en}
                write_to_tensorboard("EM", exact_match_metrics, training=False, step=batch_i, writer=writer)
                
                val_loss_metrics = {"en-fr": avg_loss_en2fr, "fr-en":avg_loss_fr2en}
                write_to_tensorboard("CE_LOSS", val_loss_metrics, training=False, step=batch_i, writer=writer)

In [32]:
train(model, dataloader_train, dataloader_valid, optimizer, config["regularization"])

tensor([[    0],
        [ 9226],
        [11459],
        [   45],
        [  129],
        [    7],
        [    5],
        [ 7401],
        [    8],
        [13135],
        [    9],
        [15569],
        [   53],
        [   67],
        [    7],
        [   49],
        [ 1318],
        [    4],
        [    2],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1],
        [    1]], device='cuda:0') tensor(19, device='cuda:0')
sents_en.shape = torch.Size([16, 50, 1])
sents_en.transpose(0,1).shape = torch.Size([50, 16, 1])
sents_n

RuntimeError: The size of tensor a (48) must match the size of tensor b (50) at non-singleton dimension 3