# NMT with Adversarial Regularization

#### 1. Load data

Load the roughly 2 million French–English sentence pairs of Europarl v7.

#### 2. Build model

Initialize the encoder, decoder, and discriminator architectures

    experiment parameters:
        - encoder = {Transformer, RoBERTa, CamemBERT}
        - decoder = {Transformer}

#### 3. Define loss/metric functions

Define the sequence cross entropy and adversarial loss functions

    experiment parameters:
        - regularization = {encoder attention, latent variable, both}

#### 4. Define training logic

Define the optimizer and training loop for an arbitrary configuration


In [None]:
# Enable the IPython autoreload extension (mode 2) so that edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import defaultdict

import torch
from torch.cuda import is_available
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from modules import utils
from modules.data import TextDataset, Collator
from modules.lib.huggingface import transformers
from modules.model import Embeddings, Encoder, Decoder, Discriminator, Hook

# TODO: replace with the name of the experiment to run.
experiment = "transformer_none"
# NOTE(review): every experiment currently loads the same config file. If
# per-experiment configs are intended, parameterize this path by `experiment`.
config = utils.load_config("config/config.yml")

# Pick the compute device once; the rest of the notebook reads `device`.
if is_available():
    device = "cuda"
else:
    device = "cpu"
    print("WARNING: CUDA IS NOT AVAILABLE")

# Output locations for checkpoints and tensorboard event files.
ckpt_dir = "experiments/{}/checkpoints".format(experiment)
runs_dir = "experiments/{}/tensorboard".format(experiment)

# Create each directory independently: an already existing checkpoint
# directory must not prevent the tensorboard directory from being created.
for out_dir in (ckpt_dir, runs_dir):
    if os.path.isdir(out_dir):
        print("File already exists")
    os.makedirs(out_dir, exist_ok=True)

## 1. Load data

In [None]:
# Pretrained tokenizers: RoBERTa for English, CamemBERT for French.
tokenizer_en = transformers.RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = transformers.CamembertTokenizer.from_pretrained('camembert-base')

# The train and valid datasets are built from the same path and tokenizers
# and share the same length bounds; only the `training` flag differs.
data_path = utils.data_path("europarl-v7")
_dataset_args = (data_path, tokenizer_en, tokenizer_fr)
_length_bounds = {"minlen": config["minlen"], "maxlen": config["maxlen"]}
dataset_train = TextDataset(*_dataset_args, training=True, **_length_bounds)
dataset_valid = TextDataset(*_dataset_args, training=False, **_length_bounds)

# Both loaders share one collator and the `data_loader` section of the config.
collator = Collator(maxlen=config["maxlen"])
dataloader_train, dataloader_valid = (
    DataLoader(dataset, collate_fn=collator, **config["data_loader"])
    for dataset in (dataset_train, dataset_valid)
)

## 2. Build Model

1. Using the Output Embedding to Improve Language Models - http://arxiv.org/abs/1608.05859

In [None]:
# Pretrained BERT-style encoders, one per language.
bert_en, bert_fr = (
    Encoder.init_bert(language).to(device=device)
    for language in ("english", "french")
)

# Embeddings initialised from the input embeddings of the BERT encoders.
embeddings_en, embeddings_fr = (
    Embeddings.from_pretrained(bert.model.get_input_embeddings()).to(device=device)
    for bert in (bert_en, bert_fr)
)

if not config["use_bert"]:
    # Vanilla Transformer encoders on top of the pretrained embeddings; the
    # BERT encoders are no longer needed once the embeddings are extracted.
    del bert_en
    del bert_fr
    encoder_en = Encoder.init_vanilla(embeddings=embeddings_en, **config["vanilla_encoder"]).to(device=device)
    encoder_fr = Encoder.init_vanilla(embeddings=embeddings_fr, **config["vanilla_encoder"]).to(device=device)
else:
    # Use the BERT encoders directly.
    encoder_en = bert_en
    encoder_fr = bert_fr

# One Hook per (sub)module of each encoder.
hooks_en = [Hook(module) for _, module in encoder_en.named_modules()]
hooks_fr = [Hook(module) for _, module in encoder_fr.named_modules()]

# Vanilla Transformer decoders; each receives the same embeddings object as
# the encoder of its language.
decoder_en = Decoder(embeddings=embeddings_en, **config["vanilla_decoder"]).to(device=device)
decoder_fr = Decoder(embeddings=embeddings_fr, **config["vanilla_decoder"]).to(device=device)

# One discriminator per configured regularization type.
discriminators = {
    regularization: Discriminator(regularization, **config["discriminator"]).to(device=device)
    for regularization in config["regularization"]["type"]
}

## 3. Define loss/metric functions

In [None]:
def loss_fn(real_en, real_fr, pred_en, pred_fr, real_pred_ys=None, ignore_index=1):
    '''
    Translation loss plus optional adversarial regularization losses.

    The translation term is the token-level cross entropy in both directions;
    each regularizer adds a binary cross entropy (with logits) on top of the
    corresponding discriminator outputs.

    Args:
        real_en, real_fr: target token ids for English / French.
        pred_en, pred_fr: decoder scores; dims 1 and 2 are swapped before the
            cross entropy so that the class dimension comes second, i.e. the
            class dimension is expected last on input.
        real_pred_ys: optional mapping ``name -> (real_y, pred_y)`` holding the
            discriminator targets and logits of each regularizer.
        ignore_index: target id excluded from the cross entropy.

    Returns:
        ``(total, loss_en2fr, loss_fr2en, reg_losses)``. ``reg_losses`` is a
        defaultdict that yields a zero tensor for regularizers not in use.
    '''
    # Avoid a shared mutable default argument.
    if real_pred_ys is None:
        real_pred_ys = {}

    cce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
    loss_en2fr = cce_loss(pred_fr.transpose(1, 2), real_fr)
    loss_fr2en = cce_loss(pred_en.transpose(1, 2), real_en)

    bce_loss = torch.nn.BCEWithLogitsLoss()
    # Missing regularizers read as 0 so callers can log them unconditionally.
    zero_device = pred_en.device
    reg_losses = defaultdict(lambda: torch.tensor(0.0, device=zero_device))
    for regularization, (real_y, pred_y) in real_pred_ys.items():
        reg_losses[regularization] = bce_loss(pred_y, real_y)

    # Sum the computed losses (not the loss module) into the total objective.
    total = loss_en2fr + loss_fr2en + sum(reg_losses.values())
    return total, loss_en2fr, loss_fr2en, reg_losses

In [None]:
def exact_match(pred, real, ignore_index=1):
    '''
    Fraction of non-ignored target positions where the prediction matches.

    Args:
        pred: predicted ids, same shape as ``real``.
        real: ground-truth ids.
        ignore_index: target id excluded from the score.

    Returns:
        A float in [0, 1]; 0.0 when every target position is ignored.
    '''
    mask = real != ignore_index
    n_scored = torch.sum(mask).item()
    # A batch made only of ignored positions has nothing to score.
    if n_scored == 0:
        return 0.0
    n_matched = torch.sum((pred == real) & mask).item()
    return n_matched / n_scored

## 4. Define training logic

1. What Does BERT Look At? An Analysis of BERT's Attention - https://arxiv.org/abs/1906.04341

In [None]:
def get_optimizer(encoder_en, encoder_fr, decoder_en, decoder_fr, discriminators, **kwargs):
    '''
    Build a single Adam optimizer over all trainable modules.

    Args:
        encoder_en, encoder_fr, decoder_en, decoder_fr: modules whose
            parameters are optimized.
        discriminators: mapping ``regularization -> module``; entries that are
            ``None`` or keyed ``"none"`` are skipped.
        **kwargs: forwarded to ``Adam`` (e.g. ``lr``).

    Returns:
        The configured ``Adam`` instance.
    '''
    modules = [encoder_en, encoder_fr, decoder_fr, decoder_en]
    modules += [
        discriminator
        for regularization, discriminator in discriminators.items()
        if discriminator is not None and regularization != "none"
    ]

    # Modules may share parameters (e.g. an embedding used by both an encoder
    # and a decoder); hand each tensor to the optimizer only once.
    params = []
    seen = set()
    for module in modules:
        for param in module.parameters():
            if id(param) not in seen:
                seen.add(id(param))
                params.append(param)
    return Adam(params, **kwargs)

def switch_trainable(model, step):
    '''
    Alternate which parameters are trainable between steps.

    On even steps the encoders are trainable and the discriminators frozen;
    on odd steps it is the other way round. Other entries of ``model`` (the
    decoders) are left untouched, and nothing changes when no discriminator
    is configured.

    Args:
        model: mapping with keys ``"encoder_en"``, ``"encoder_fr"`` and
            ``"discriminators"`` (itself a mapping of modules).
        step: current training step.
    '''
    if not model["discriminators"]:
        return

    train_encoders = step % 2 == 0

    for discriminator in model["discriminators"].values():
        for param in discriminator.parameters():
            param.requires_grad = not train_encoders

    # Exact key membership: only these two entries are toggled.
    for name in ("encoder_en", "encoder_fr"):
        if name in model:
            for param in model[name].parameters():
                param.requires_grad = train_encoders

def save_weights(model, step):
    '''
    Write one checkpoint file per module into ``ckpt_dir``.

    Files are named ``<step>.<name>.pt``. Entries of ``model`` that are
    themselves mappings (the discriminators) are expanded, and their files are
    named ``<step>.<name>.<key>.pt``.

    Args:
        model: mapping of names to modules, or to mappings of modules.
        step: training step used as the file name prefix.
    '''
    for name, module in model.items():
        # A plain dict of modules has no save method of its own; save each
        # contained module separately.
        if isinstance(module, dict):
            targets = {"{}.{}".format(name, key): sub for key, sub in module.items()}
        else:
            targets = {name: module}

        for label, target in targets.items():
            if target is None:
                continue
            # NOTE(review): `save_state_dict` is not part of torch.nn.Module;
            # presumably the project classes define it - confirm.
            target.save_state_dict(os.path.join(ckpt_dir, "{}.{}.pt".format(step, label)))
        
def forward(model, batch):
    '''
    Run both translation directions on one batch.

    Args:
        model: mapping with keys ``"encoder_en"``, ``"encoder_fr"``,
            ``"decoder_en"`` and ``"decoder_fr"``.
        batch: pair ``(batch_en, batch_fr)``; each element unpacks into
            ``(sents, sents_no_eos, lengths)`` tensors.

    Returns:
        ``(sents_en, sents_fr, enc_out_en, enc_out_fr, dec_out_en, dec_out_fr)``
        with the sentence tensors already moved to ``device``.
    '''
    # Use the modules that were passed in rather than the notebook globals.
    encoder_en, encoder_fr = model["encoder_en"], model["encoder_fr"]
    decoder_en, decoder_fr = model["decoder_en"], model["decoder_fr"]

    # Unpack batch and move to device
    batch_en, batch_fr = batch
    sents_en, sents_no_eos_en, lengths_en = (t.to(device=device) for t in batch_en)
    sents_fr, sents_no_eos_fr, lengths_fr = (t.to(device=device) for t in batch_fr)

    # English -> French: encode English, decode French from it.
    enc_out_en = encoder_en(sents_en, lengths=lengths_en)
    decoder_fr.init_state(sents_en)
    dec_out_fr = decoder_fr(sents_no_eos_fr, enc_out_en, memory_lengths=lengths_en)

    # French -> English: encode French, decode English from it.
    enc_out_fr = encoder_fr(sents_fr, lengths=lengths_fr)
    decoder_en.init_state(sents_fr)
    dec_out_en = decoder_en(sents_no_eos_en, enc_out_fr, memory_lengths=lengths_fr)

    return sents_en, sents_fr, enc_out_en, enc_out_fr, dec_out_en, dec_out_fr

In [None]:
def _discriminator_targets(switch, batch_size):
    '''Targets for a concatenated [English; French] discriminator batch.

    English rows get ``float(switch)`` and French rows ``float(not switch)``,
    so the labels flip between alternating steps.
    '''
    labels = [float(switch)] * batch_size + [float(not switch)] * batch_size
    return torch.tensor(labels).unsqueeze(-1).to(device=device)


def _mean(values):
    '''Arithmetic mean of ``values``; 0.0 for an empty sequence.'''
    return sum(values) / len(values) if values else 0.0


def _sentence_bleus(tokenizer, real, pred, ignore_index=1):
    '''Per-sentence BLEU between target ids ``real`` and predicted ids ``pred``.

    Positions whose target equals ``ignore_index`` are dropped from both
    sequences, mirroring the mask used by ``exact_match``.
    '''
    # NOTE(review): `sentence_bleu` is not imported anywhere in this notebook.
    # It presumably refers to nltk.translate.bleu_score.sentence_bleu, whose
    # signature is (list_of_reference_token_lists, hypothesis_tokens) - make it
    # available before running validation.
    scores = []
    for real_ids, pred_ids in zip(real, pred):
        keep = real_ids != ignore_index
        reference = tokenizer.convert_ids_to_tokens(real_ids[keep].tolist())
        hypothesis = tokenizer.convert_ids_to_tokens(pred_ids[keep].tolist())
        scores.append(sentence_bleu([reference], hypothesis))
    return scores


def _validate(model, step, writer):
    '''Evaluate on up to ``config["n_valid"]`` validation batches and log
    BLEU, exact match and cross entropy for both directions to tensorboard.'''
    modules = [model[name] for name in ("encoder_en", "encoder_fr", "decoder_en", "decoder_fr")]
    was_training = [module.training for module in modules]
    # Evaluate in eval mode (e.g. no dropout) and restore the mode afterwards.
    for module in modules:
        module.eval()

    loss_en2fr_vals, loss_fr2en_vals = [], []
    bleu_en2fr, bleu_fr2en = [], []
    em_en2fr, em_fr2en = [], []
    try:
        with torch.no_grad():
            for batch_j, batch in enumerate(dataloader_valid):
                if batch_j == config["n_valid"]:
                    break

                # Unpack the batch, run the encoders, run the decoders
                sents_en, sents_fr, _, _, dec_out_en, dec_out_fr = forward(model, batch)

                # Greedy predictions for both directions
                pred_fr = torch.argmax(dec_out_fr, dim=-1)
                pred_en = torch.argmax(dec_out_en, dim=-1)

                _, loss_en2fr_val, loss_fr2en_val, _ = loss_fn(
                    sents_en[:, 1:], sents_fr[:, 1:],
                    dec_out_en, dec_out_fr,
                    ignore_index=1
                )
                loss_en2fr_vals.append(loss_en2fr_val.item())
                loss_fr2en_vals.append(loss_fr2en_val.item())

                # Targets drop the first and last position; predictions drop
                # the last position so that both line up.
                real_fr, real_en = sents_fr[:, 1:-1], sents_en[:, 1:-1]
                hyp_fr, hyp_en = pred_fr[:, 0:-1], pred_en[:, 0:-1]

                bleu_en2fr.extend(_sentence_bleus(tokenizer_fr, real_fr, hyp_fr))
                bleu_fr2en.extend(_sentence_bleus(tokenizer_en, real_en, hyp_en))

                em_en2fr.append(exact_match(hyp_fr, real_fr))
                em_fr2en.append(exact_match(hyp_en, real_en))
    finally:
        for module, mode in zip(modules, was_training):
            module.train(mode)

    # Average over the number of collected values, not the batch size.
    bleu_metrics = {"en-fr": _mean(bleu_en2fr), "fr-en": _mean(bleu_fr2en)}
    utils.write_to_tensorboard("BLEU", bleu_metrics, training=False, step=step, writer=writer)

    em_metrics = {"en-fr": _mean(em_en2fr), "fr-en": _mean(em_fr2en)}
    utils.write_to_tensorboard("EM", em_metrics, training=False, step=step, writer=writer)

    loss_val_metrics = {"en-fr": _mean(loss_en2fr_vals), "fr-en": _mean(loss_fr2en_vals)}
    utils.write_to_tensorboard("CCE", loss_val_metrics, training=False, step=step, writer=writer)


def train():
    '''
    Train the encoders, decoders and discriminators built above.

    Takes no arguments: the modules, data loaders, hooks and ``config`` are
    read from the notebook globals. Each step alternates which parameters are
    trainable, computes the translation loss plus one adversarial loss per
    configured discriminator ("attention" and/or "hidden"), and applies one
    optimizer update. Checkpoints, training logs and validation runs happen
    at the frequencies given in ``config``; training stops once
    ``config["max_step_num"]`` steps have been run.
    '''
    model = {
        "encoder_en": encoder_en,
        "encoder_fr": encoder_fr,
        "decoder_en": decoder_en,
        "decoder_fr": decoder_fr,
        "discriminators": discriminators,
    }

    optimizer = get_optimizer(**model, **config["adam"])

    writer = SummaryWriter(runs_dir)

    try:
        for batch_i, batch in enumerate(dataloader_train):

            # Clear optimizer
            optimizer.zero_grad()

            # Alternate which of the encoder / discriminator parameters are
            # trainable on this step
            switch_trainable(model, batch_i)

            # Save weights and terminate training
            if batch_i >= config["max_step_num"]:
                save_weights(model, batch_i)
                break

            # Save weights and continue training
            if batch_i > 0 and batch_i % config["checkpoint_frequency"] == 0:
                save_weights(model, batch_i)

            # Unpack the batch, run the encoders, run the decoders
            sents_en, sents_fr, enc_out_en, enc_out_fr, dec_out_en, dec_out_fr = forward(model, batch)

            # Discriminator targets, sized from the actual batch so that a
            # smaller final batch is handled as well
            batch_size = sents_en.size(0)
            switch = batch_i % 2 == 0
            y_real = _discriminator_targets(switch, batch_size)
            real_pred_ys = {}

            # Gather attention discriminator labels/predictions
            if "attention" in model["discriminators"]:
                # Use the attention scores at index 6 of the extracted list
                attention_en = utils.extract_attention_scores(hooks_en)[6].view(batch_size, -1)
                attention_fr = utils.extract_attention_scores(hooks_fr)[6].view(batch_size, -1)
                y_attn_pred_en = model["discriminators"]["attention"](attention_en)
                y_attn_pred_fr = model["discriminators"]["attention"](attention_fr)
                y_attn_pred = torch.cat([y_attn_pred_en, y_attn_pred_fr])
                real_pred_ys["attention"] = y_real, y_attn_pred

            # Gather hidden discriminator labels/predictions
            if "hidden" in model["discriminators"]:
                # Use the encoder outputs for regularization
                y_hddn_pred_en = model["discriminators"]["hidden"](enc_out_en)
                y_hddn_pred_fr = model["discriminators"]["hidden"](enc_out_fr)
                y_hddn_pred = torch.cat([y_hddn_pred_en, y_hddn_pred_fr])
                real_pred_ys["hidden"] = y_real, y_hddn_pred

            loss, loss_en2fr, loss_fr2en, reg_losses = loss_fn(
                sents_en[:, 1:], sents_fr[:, 1:],
                dec_out_en, dec_out_fr,
                real_pred_ys,
                ignore_index=1
            )

            # Optimize trainable parameters
            loss.backward()
            optimizer.step()

            # Write training losses/metrics to stdout and tensorboard
            if batch_i % config["logging_frequency"] == 0:
                print("Batch {}: Loss {}".format(batch_i, loss.item()))
                cce_metrics = {"en-fr": loss_en2fr.item(), "fr-en": loss_fr2en.item()}
                utils.write_to_tensorboard("CCE", cce_metrics, training=True, step=batch_i, writer=writer)
                bce_metrics = {"attn": reg_losses["attention"].item(), "hddn": reg_losses["hidden"].item()}
                utils.write_to_tensorboard("BCE", bce_metrics, training=True, step=batch_i, writer=writer)

            # Run validation
            if batch_i > 0 and batch_i % config["val_frequency"] == 0:
                _validate(model, batch_i, writer)
    finally:
        # Flush and release the tensorboard event file even on failure.
        writer.close()

In [None]:
# Start training with the globals (config, data loaders, modules) defined in
# the cells above.
train()