### Setting up a basic training loop - not using attention

In [None]:
''' Changing directories '''
import os 
# Move into the BERT folder unless the current path already contains "BERT".
# NOTE(review): this is a substring test on the whole path, so any directory
# whose path merely contains "BERT" also counts as "already inside".
if 'BERT' not in os.getcwd():
    os.chdir('BERT')
# Report where the notebook ended up, since the relative paths used later
# (e.g. "data/...") are resolved against this directory.
print("Current working dir is {}".format(os.getcwd()))

In [None]:
import pyaml
import onmt
import torch
from dataset import TextDataset
from encoder import Encoder 
from decoder import Decoder
from discriminator import Discriminator
from lib.huggingface.transformers import RobertaTokenizer, CamembertTokenizer
from torch.utils.data import DataLoader

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Load the experiment configuration from "config/config.yml" located in the
# parent of the current working directory.
# NOTE(review): pyaml.yaml.Loader is presumably PyYAML's full loader, which can
# construct arbitrary Python objects - only point this at a trusted file.
with open(os.path.join(os.path.dirname(os.getcwd()), "config", "config.yml"), "r") as fd:
    config = pyaml.yaml.load(fd, Loader=pyaml.yaml.Loader)

In [None]:
# Both datasets read from the same directory; the is_train flag presumably
# selects the train vs. validation split (TextDataset lives in dataset.py).
text_dataset_train = TextDataset("data/data-30k-default/", is_train=True)
text_dataset_val = TextDataset("data/data-30k-default/", is_train=False)

In [None]:
# Pretrained tokenizers, one per language: 'roberta-base' for English and
# 'camembert-base' for French. The collate function below reads both as globals.
tokenizer_en = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer_fr = CamembertTokenizer.from_pretrained('camembert-base')

In [None]:
def collate(data):
    '''Collating function to be passed into the dataloader.

    Encodes every (input, output) sentence pair of the batch with the English
    and French tokenizers and pads all sequences, on the right, to one common
    length: the longest encoded sequence found in either language.

    Args:
        data: iterable of (input_sentence, output_sentence) pairs.

    Returns:
        ((input_idx_tensor, input_lengths), (output_idx_tensor, output_lengths))
        where both index tensors are long tensors of shape
        (batch_size, max_length) and both length tensors hold the unpadded
        encoded length of every sentence.
    '''
    input_sentences, output_sentences = zip(*data)

    # Encode each sentence exactly once and measure the encoded sequence
    # itself. Guessing the length as len(sentence) + 2 is only right when the
    # sentence is already a token list and exactly two special tokens are added.
    input_ids = [tokenizer_en.encode(sentence) for sentence in input_sentences]
    output_ids = [tokenizer_fr.encode(sentence) for sentence in output_sentences]

    input_lengths = [len(ids) for ids in input_ids]
    output_lengths = [len(ids) for ids in output_ids]

    batch_size = len(input_sentences)
    # Both languages share one width so the two tensors have identical shapes.
    max_length = max(max(input_lengths), max(output_lengths))

    def _pad(encoded, pad_id):
        '''Right-pad the encoded sequences with pad_id into one long tensor.'''
        padded = torch.full((batch_size, max_length), pad_id, dtype=torch.long)
        for row, ids in enumerate(encoded):
            padded[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        return padded

    # Ask each tokenizer for its own padding id instead of hard-coding 1.
    input_idx_tensor = _pad(input_ids, tokenizer_en.pad_token_id)
    output_idx_tensor = _pad(output_ids, tokenizer_fr.pad_token_id)

    return ((input_idx_tensor, torch.tensor(input_lengths)), (output_idx_tensor, torch.tensor(output_lengths)))

In [None]:
# Both loaders take their keyword arguments from the "data_loader" section of
# the config and batch samples with the custom collate function defined above.
train_dataloader = DataLoader(text_dataset_train, **config["data_loader"], collate_fn=collate)
val_dataloader = DataLoader(text_dataset_val, **config["data_loader"], collate_fn=collate)

#### Specifying the encoding and decoding models

In [None]:
# Pick the compute device once and announce the choice; every model and
# tensor below is moved to `device`.
if torch.cuda.is_available():
    device = 'cuda'
    print("Using CUDA!")
else:
    device = 'cpu'
    print("Using CPU - Played yourself!")

In [None]:
# Drop encoders left over from a previous run of this cell so their memory
# can be reclaimed before new ones are built.
try:
    del encoder_en
    del encoder_fr
except NameError:
    # First execution of the cell: there is nothing to release yet. Only
    # NameError is expected here; anything else should surface.
    pass
# One encoder per language, both placed on the selected device.
encoder_en = Encoder("english").to(device=device)
encoder_fr = Encoder("french").to(device=device)

In [None]:
# Read the embedding hyper-parameters off each encoder's wrapped model (the
# submodule registered as 'model') so the decoder-side embeddings built later
# can reuse them.

# Padding index of each embedding table (original note: "same" - presumably
# identical for both languages; verify).
word_padding_idx_en = encoder_en._modules['model'].embeddings.padding_idx
word_padding_idx_fr = encoder_fr._modules['model'].embeddings.padding_idx

# Vocabulary sizes (original note: "en > fr" - presumably the English
# vocabulary is the larger one; verify).
word_vocab_size_en = encoder_en._modules['model'].embeddings.word_embeddings.num_embeddings
word_vocab_size_fr = encoder_fr._modules['model'].embeddings.word_embeddings.num_embeddings

# Embedding dimensionality (original note: "same" - presumably identical for
# both languages; verify).
word_vec_size_en = encoder_en._modules['model'].embeddings.word_embeddings.embedding_dim
word_vec_size_fr = encoder_fr._modules['model'].embeddings.word_embeddings.embedding_dim

In [None]:
# OpenNMT embedding layers handed to the decoders below. Each one is sized
# from the matching encoder's vector size, vocabulary size and padding index,
# with positional encoding switched on.
embeddings_en = onmt.modules.embeddings.Embeddings(
    word_vec_size_en, 
    word_vocab_size_en, 
    word_padding_idx_en, 
    position_encoding=True
).to(device=device)

# Same construction for French.
embeddings_fr = onmt.modules.embeddings.Embeddings(
    word_vec_size_fr, 
    word_vocab_size_fr, 
    word_padding_idx_fr, 
    position_encoding=True
).to(device=device)

In [None]:
# One decoder per target language. Both share the "small_transformer" section
# of the config and differ only in the embedding layer they receive.
decoder_en = Decoder(**config["small_transformer"], embeddings=embeddings_en).to(device=device)
decoder_fr = Decoder(**config["small_transformer"], embeddings=embeddings_fr).to(device=device)

In [None]:
# Discriminator built with (d_model, 1); its outputs are later fed to a
# binary cross-entropy-with-logits loss.
# Original note: "projection: 768 -> 1" - presumably d_model is 768 in the
# config; confirm.
discriminator = Discriminator(config["small_transformer"]['d_model'], 1).to(device=device)

#### Beginning the training loop

In [None]:
def loss_fn_no_regularization(english_gt, french_gt, english_predict, french_predict,
                              *, ignore_index=-100):
    '''Standard machine translation cross entropy loss.

    Args:
        english_gt: long tensor of English target token ids.
        french_gt: long tensor of French target token ids.
        english_predict: scores for the English targets, with the class
            (vocabulary) dimension last.
        french_predict: scores for the French targets, same layout.
        ignore_index: target id that contributes nothing to the loss, e.g. the
            padding id. The default (-100) is PyTorch's own default, so no
            real token id is ignored unless the caller asks for it.

    Returns:
        Scalar tensor: the sum of the mean cross entropy of each language.
    '''
    ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
    # CrossEntropyLoss expects the class dimension at index 1, hence the
    # transpose of the last two dimensions.
    loss_english = ce_loss(english_predict.transpose(1, 2), english_gt)
    loss_french = ce_loss(french_predict.transpose(1, 2), french_gt)
    return loss_english + loss_french

In [None]:
def loss_fn_hidden_regularization(english_gt, french_gt, english_predict, french_predict,
                                discriminator_gt, discriminator_predict):
    '''Adversarial loss: translation cross entropy plus a discriminator term.

    The first four arguments are forwarded unchanged to
    loss_fn_no_regularization. The discriminator term is the binary cross
    entropy between the raw discriminator outputs (logits) and their labels.

    Returns:
        Scalar tensor: translation loss + discriminator loss.
    '''
    translation_loss = loss_fn_no_regularization(
        english_gt, french_gt, english_predict, french_predict
    )
    discriminator_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        discriminator_predict, discriminator_gt
    )
    return translation_loss + discriminator_loss

In [None]:
def _batch_to_device(batch):
    '''Unpack one collated batch and move its four tensors to `device`.'''
    (english, english_lengths), (french, french_lengths) = batch
    return (english.to(device=device), english_lengths.to(device=device),
            french.to(device=device), french_lengths.to(device=device))


def _encode_decode(encoder, decoder, source, source_lengths, target):
    '''Encode `source`, then run `decoder` over `target` with the last step dropped.

    Returns (encoder_outputs, decoder_outputs).
    '''
    encoder_outputs = encoder(source)
    # The collated batch is (batch, seq); unsqueeze + transpose turns it into
    # the (seq, batch, 1) layout handed to the decoder.
    decoder.init_state(source.unsqueeze(2).transpose(0, 1), None, None)
    decoder_outputs, _ = decoder(
        target.unsqueeze(2).transpose(0, 1)[:-1],
        encoder_outputs[0].transpose(0, 1),
        memory_lengths=source_lengths,
    )
    return encoder_outputs, decoder_outputs


def train(train_data_iter, val_data_iter, regularize="hidden_state"):
    '''
    Train the encoding and decoding models. User needs to pass in a valid iterator over the data,
    and also specify a type of adversarial regularization. regularize = ["hidden_state", "attention"]

    Only "hidden_state" is implemented; any other value (including
    "attention") trains with the plain translation loss.

    Args:
        train_data_iter: iterable of collated training batches, traversed once.
        val_data_iter: iterable of collated validation batches, traversed in
            full on every 100th training batch.
        regularize: "hidden_state" adds the discriminator loss on the encoder
            outputs at index 1.
    '''
    params = list(encoder_en.parameters()) + list(encoder_fr.parameters()) +\
             list(decoder_fr.parameters()) + list(decoder_en.parameters())

    use_discriminator = regularize == "hidden_state"
    if use_discriminator:
        # NOTE(review): the discriminator shares the optimizer and the loss
        # with the encoders, so all of them minimise the same BCE term.
        params += list(discriminator.parameters())

    optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)

    for batch_num, batch in enumerate(train_data_iter):
        optimizer.zero_grad()

        english, english_lengths, french, french_lengths = _batch_to_device(batch)

        # English -> French, then French -> English
        encoder_outputs_en, dec_outs_fr = _encode_decode(
            encoder_en, decoder_fr, english, english_lengths, french)
        encoder_outputs_fr, dec_outs_en = _encode_decode(
            encoder_fr, decoder_en, french, french_lengths, english)

        # The decoders consumed target[:-1], so the labels are the targets
        # shifted by one position. This holds for both loss variants.
        english_targets = english[:, 1:]
        french_targets = french[:, 1:]

        if use_discriminator:
            # using the pooled outputs of the encoders for regularizing
            discriminator_outputs_en = discriminator(encoder_outputs_en[1])
            discriminator_outputs_fr = discriminator(encoder_outputs_fr[1])
            discriminator_outputs_cat = torch.cat((discriminator_outputs_en, discriminator_outputs_fr))
            # Label English rows 1 and French rows 0, in concatenation order.
            discriminator_labels = torch.tensor(
                [1.0] * discriminator_outputs_en.shape[0] + [0.0] * discriminator_outputs_fr.shape[0]
            ).unsqueeze(1).to(device=device)

            loss = loss_fn_hidden_regularization(english_targets,
                                                 french_targets,
                                                 dec_outs_en,
                                                 dec_outs_fr,
                                                 discriminator_labels,
                                                 discriminator_outputs_cat)
        else:
            loss = loss_fn_no_regularization(english_targets,
                                             french_targets,
                                             dec_outs_en,
                                             dec_outs_fr)

        print("Batch num {}: Loss {}".format(batch_num, loss.item()))

        # Update before validating so the training graph is released before
        # the whole validation set is run through the models.
        loss.backward()
        optimizer.step()

        if batch_num % 100 == 0:
            with torch.no_grad():
                # Separate names keep the validation pass from overwriting the
                # training loop's batch counter and tensors.
                for val_batch in val_data_iter:
                    english_val, english_lengths_val, french_val, _ = _batch_to_device(val_batch)

                    # Encoding - Decoding for English -> French
                    _encode_decode(encoder_en, decoder_fr, english_val, english_lengths_val, french_val)

                    # TODO: calculate BLEU scores and exact match (EM) from the decoder outputs
        

In [None]:
# Run one pass over the training loader with the default regularize="hidden_state".
train(train_dataloader, val_dataloader)

To do: 
* Train four variants: without any regularizing terms, with regularizing terms on the hidden layers, with regularizing terms on the attention, and with regularizing terms on both
* Implement the validation metrics (BLEU score and exact match) and print the validation results at regular intervals
