In [None]:
# installing modules
!pip install indic-nlp-library
!pip install rouge-score

Collecting indic-nlp-library
  Downloading https://files.pythonhosted.org/packages/2f/51/f4e4542a226055b73a621ad442c16ae2c913d6b497283c99cae7a9661e6c/indic_nlp_library-0.71-py3-none-any.whl
Collecting morfessor
  Downloading https://files.pythonhosted.org/packages/39/e6/7afea30be2ee4d29ce9de0fa53acbb033163615f849515c0b1956ad074ee/Morfessor-2.0.6-py3-none-any.whl
Installing collected packages: morfessor, indic-nlp-library
Successfully installed indic-nlp-library-0.71 morfessor-2.0.6
Collecting rouge-score
  Downloading https://files.pythonhosted.org/packages/1f/56/a81022436c08b9405a5247b71635394d44fe7e1dbedc4b28c740e09c2840/rouge_score-0.0.4-py2.py3-none-any.whl
Installing collected packages: rouge-score
Successfully installed rouge-score-0.0.4


In [None]:
'''
    Code for training transformer for the cross lingual summarization task
'''
import math
import os
import random

import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchtext
from google.colab import drive
from indicnlp.tokenize import indic_tokenize
from rouge_score import rouge_scorer
from torch.utils.tensorboard import SummaryWriter
from torchtext.vocab import Vectors
from tqdm.notebook import tqdm

# word_tokenize (used by english_tokenize) needs the Punkt tokenizer models
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# setting up the configurations
drive.mount('/content/drive')

# seeding the Python, NumPy and PyTorch RNGs so that repeated Restart & Run All
# sessions start from the same random state (some CUDA kernels may still be
# nondeterministic)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# setting the device variable: first GPU if available, otherwise CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

Mounted at /content/drive


In [None]:
# Replacement for torchtext's FastText loader: downloads the fastText
# "wiki.<language>.vec" files from new_url_base.
class FastText(Vectors):
    new_url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec'

    def __init__(self, language="en", **kwargs):
        """Load (downloading if needed) the wiki fastText vectors for `language`.

        The cache file name is the basename of the download URL; any extra
        keyword arguments are forwarded unchanged to ``Vectors``.
        """
        download_url = self.new_url_base.format(language)
        super().__init__(os.path.basename(download_url), url=download_url, **kwargs)

# downloading the embeddings
# English: GloVe 6B, 300-dimensional; Hindi: fastText wiki vectors (language code 'hi')
glove = torchtext.vocab.GloVe(dim=300, name='6B')
fasttext_hindi = FastText(language='hi')

In [None]:
# tokenizers for both the languages
def english_tokenize(sentence):
    """Split `sentence` (coerced with str()) into lower-cased NLTK word tokens."""
    raw_tokens = nltk.tokenize.word_tokenize(str(sentence))
    return list(map(str.lower, raw_tokens))

def hindi_tokenize(sentence):
    """Split `sentence` (coerced with str()) with indic-nlp's trivial tokenizer for Hindi."""
    text = str(sentence)
    return indic_tokenize.trivial_tokenize(text, lang='hi')

# parsing the dataset using the torchtext utility
def parse_using_torchtext(csv_file_name, batch_size=16, english_vocab_size=10000, hindi_vocab_size=10000):
    """Build fields, vocabularies and a bucketed iterator for a parallel CSV corpus.

    The CSV columns are looked up by name ('english_sentence' and
    'hindi_sentence'). Each vocabulary keeps tokens seen at least twice,
    capped at the given size.

    Returns:
        (english_field, hindi_field, train_data, train_iterator)
    """

    def make_field(tokenizer):
        # sequence-first tensors, each sentence wrapped in <sos> ... <eos>
        return torchtext.data.Field(
            sequential=True,
            init_token='<sos>',
            eos_token='<eos>',
            tokenize=tokenizer,
            batch_first=False
        )

    english_field = make_field(english_tokenize)
    hindi_field = make_field(hindi_tokenize)

    column_map = {
        'english_sentence': ('english_sentence', english_field),
        'hindi_sentence': ('hindi_sentence', hindi_field),
    }
    (train_data,) = torchtext.data.TabularDataset.splits(
        path=os.path.dirname(csv_file_name),
        train=os.path.basename(csv_file_name),
        format='csv',
        fields=column_map,
        skip_header=False
    )

    # vocabularies only count tokens appearing at least twice
    english_field.build_vocab(train_data, max_size=english_vocab_size, min_freq=2)
    hindi_field.build_vocab(train_data, max_size=hindi_vocab_size, min_freq=2)

    # batches are bucketed by English sentence length
    (train_iterator,) = torchtext.data.BucketIterator.splits(
        (train_data,),
        (batch_size,),
        device=device,
        sort_key=lambda example: len(example.english_sentence)
    )

    return english_field, hindi_field, train_data, train_iterator

# construction of bucket iterator using predefined field
def parse_using_field(csv_file_name, english_field, hindi_field, english_col_name, hindi_col_name, batch_size=16):
    """Wrap a CSV in a bucketed iterator using fields whose vocabularies already exist.

    The columns named `english_col_name` / `hindi_col_name` are exposed as the
    `english_sentence` / `hindi_sentence` attributes of each example, the same
    attribute names parse_using_torchtext uses. No vocabulary is built here.

    Returns:
        (train_data, train_iterator)
    """
    directory, file_name = os.path.split(csv_file_name)
    column_map = {
        english_col_name: ('english_sentence', english_field),
        hindi_col_name: ('hindi_sentence', hindi_field),
    }

    (train_data,) = torchtext.data.TabularDataset.splits(
        path=directory,
        train=file_name,
        format='csv',
        fields=column_map,
        skip_header=False
    )

    # batches are bucketed by English sentence length
    (train_iterator,) = torchtext.data.BucketIterator.splits(
        (train_data,),
        (batch_size,),
        device=device,
        sort_key=lambda example: len(example.english_sentence)
    )

    return train_data, train_iterator

# for loading the dataset in an appropriate format
def parse_dataset(csv_file_name, english_col_name, hindi_col_name, max_num=None):
    """Read a CSV and tokenize two of its columns row by row.

    Args:
        csv_file_name: path handed to pandas.read_csv.
        english_col_name: column tokenized with english_tokenize (value str()-ed).
        hindi_col_name: column tokenized with hindi_tokenize (value str()-ed).
        max_num: stop before the row whose index label equals this value;
            None reads every row.

    Returns:
        (english_sentences, hindi_sentences): parallel lists of token lists.
    """
    frame = pd.read_csv(csv_file_name)
    english_sentences, hindi_sentences = [], []

    for row_label, record in tqdm(frame.iterrows()):
        reached_limit = max_num is not None and row_label == max_num
        if reached_limit:
            break
        english_sentences.append(english_tokenize(str(record[english_col_name])))
        hindi_sentences.append(hindi_tokenize(str(record[hindi_col_name])))

    return english_sentences, hindi_sentences

In [None]:
# some utility function
def positional_encoding_1d(d_model, length):
    """Sinusoidal positional-encoding table.

    :param d_model: dimension of the model (must be even)
    :param length: number of positions
    :return: float tensor of shape (length, d_model); column 2k holds
             sin(pos * w_k) and column 2k+1 holds cos(pos * w_k), where
             w_k = 10000 ** (-2k / d_model)
    :raises ValueError: if d_model is odd
    """
    if d_model % 2:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))

    positions = torch.arange(0, length).unsqueeze(1).float()
    frequencies = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
    )
    angles = positions * frequencies

    table = torch.zeros(length, d_model)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table

# model definition
class CustomTransformer(nn.Module):
    """Baseline encoder-decoder nn.Transformer over pretrained English/Hindi word vectors.

    forward() takes sequence-first index tensors -- english (src_len, batch)
    and hindi (tgt_len, batch) -- and returns (tgt_len, batch, len(hindi_vocab.itos))
    logits over the Hindi vocabulary.
    """

    def __init__(
        self,
        english_embedding,
        english_vocab,
        hindi_embedding,
        hindi_vocab,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dropout=0.1,
        multitask=False
    ):
        """
        :param english_embedding: English vector table indexable by token string, with a `.dim` attribute
        :param english_vocab: English vocabulary with `itos` / `stoi` (must contain '<pad>')
        :param hindi_embedding: Hindi vector table with the same interface as english_embedding
        :param hindi_vocab: Hindi vocabulary with `itos` / `stoi` (must contain '<pad>')
        :param num_heads: attention heads; must divide the embedding dimension
        :param num_encoder_layers: number of encoder layers
        :param num_decoder_layers: number of decoder layers
        :param dropout: dropout on the summed embeddings and inside nn.Transformer
        :param multitask: unused; kept so existing call sites keep working
        :raises ValueError: if the two embedding dimensions differ (nn.Transformer
            uses a single d_model and fc_out consumes its output)
        """
        super(CustomTransformer, self).__init__()
        if english_embedding.dim != hindi_embedding.dim:
            raise ValueError(
                "english and hindi embeddings must share a dimension "
                "(got {} and {})".format(english_embedding.dim, hindi_embedding.dim)
            )

        # storing the input arguments
        self.english_vocab = english_vocab
        self.hindi_vocab = hindi_vocab
        self.english_embedding = english_embedding
        self.hindi_embedding = hindi_embedding

        # trainable embedding tables initialised from the pretrained vectors
        self.english_word_embedding = nn.Embedding.from_pretrained(
            self._pretrained_weight(english_embedding, english_vocab),
            padding_idx=english_vocab.stoi['<pad>'],
            freeze=False
        )
        self.hindi_word_embedding = nn.Embedding.from_pretrained(
            self._pretrained_weight(hindi_embedding, hindi_vocab),
            padding_idx=hindi_vocab.stoi['<pad>'],
            freeze=False
        )

        # initializing the transformer
        self.english_pad_index = english_vocab.stoi['<pad>']
        self.dropout = nn.Dropout(dropout)
        # sized by itos: stoi may be a defaultdict (as in torchtext's legacy Vocab)
        # that gains a key whenever an unknown token is looked up, so len(stoi)
        # can exceed the real vocabulary size
        self.fc_out = nn.Linear(hindi_embedding.dim, len(hindi_vocab.itos))
        self.transformer = nn.Transformer(
            d_model=english_embedding.dim,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout
        )

    @staticmethod
    def _pretrained_weight(vectors, vocab):
        """Return a (len(vocab.itos), vectors.dim) float matrix whose row i is vectors[vocab.itos[i]]."""
        weight = torch.zeros(len(vocab.itos), vectors.dim, dtype=torch.float)
        for word_id, word in enumerate(vocab.itos):
            weight[word_id] = vectors[word]
        return weight

    # for creating the english mask
    def make_english_mask(self, english_batch):
        """(batch, src_len) boolean mask, True where the source token is padding."""
        english_mask = english_batch.transpose(0, 1) == self.english_pad_index
        return english_mask

    # forward pass
    def forward(self, english_batch, hindi_batch):
        """Decoder logits for `hindi_batch` (teacher forcing) conditioned on `english_batch`.

        Both inputs are sequence-first index tensors with the same batch size;
        all auxiliary tensors are created on the inputs' device.
        """
        english_seq_length, batch_size_english = english_batch.shape
        hindi_seq_length, batch_size_hindi = hindi_batch.shape
        assert batch_size_english == batch_size_hindi
        run_device = english_batch.device

        # (seq_len, 1, d_model) position tables broadcast over the batch dimension
        english_pos_embedding = positional_encoding_1d(
            self.english_word_embedding.embedding_dim, english_seq_length).unsqueeze(1).to(run_device)
        hindi_pos_embedding = positional_encoding_1d(
            self.hindi_word_embedding.embedding_dim, hindi_seq_length).unsqueeze(1).to(run_device)

        # producing the final embedding
        english_final_embedding = self.dropout(self.english_word_embedding(english_batch) + english_pos_embedding)
        hindi_final_embedding = self.dropout(self.hindi_word_embedding(hindi_batch) + hindi_pos_embedding)

        # producing the masks
        english_padding_mask = self.make_english_mask(english_batch)
        hindi_mask = self.transformer.generate_square_subsequent_mask(hindi_seq_length).to(run_device)

        # src_key_padding_mask only covers encoder self-attention; the decoder's
        # cross-attention needs memory_key_padding_mask or it attends to padded
        # source positions
        output = self.transformer(
            english_final_embedding,
            hindi_final_embedding,
            src_key_padding_mask=english_padding_mask,
            tgt_mask=hindi_mask,
            memory_key_padding_mask=english_padding_mask
        )

        return self.fc_out(output)

# function for training the model
def train_transformer(model, train_iterator, pad_index, num_epoches=1000, learning_rate=1e-4, save_name=None,
                      test_sentence='a horse goes under a bridge next to a boat.'):
    """Teacher-forced training of a CustomTransformer-style model.

    Each batch feeds hindi[:-1] to the decoder and scores the logits against
    hindi[1:] with cross-entropy. After every epoch the summed loss is printed
    and passed to ReduceLROnPlateau, `test_sentence` is greedily decoded and
    printed, and the weights are saved to `save_name` when given.

    Args:
        model: called as model(english_batch, hindi_input) -> (tgt_len, batch, vocab) logits.
        train_iterator: yields batches with sequence-first `english_sentence`
            and `hindi_sentence` index tensors.
        pad_index: target index ignored by the loss; the targets are Hindi
            tokens, so this should be the Hindi vocabulary's '<pad>' index.
        num_epoches: number of passes over `train_iterator`.
        learning_rate: initial Adam learning rate.
        save_name: optional path for torch.save(model.state_dict()) after each epoch.
        test_sentence: English sentence decoded after each epoch as a sanity check.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

    for epoch in tqdm(range(num_epoches)):
        total_loss = 0

        # produce_output below switches the model to eval mode, so re-enable training each epoch
        model.train()
        for batch_index, batch in tqdm(enumerate(train_iterator)):
            optimizer.zero_grad()

            english_batch = batch.english_sentence.to(device)
            hindi_batch = batch.hindi_sentence.to(device)

            # teacher forcing: the decoder sees tokens [0, t] and predicts token t+1
            predict_logits = model(english_batch, hindi_batch[:-1, :])
            predict_logits = predict_logits.reshape(-1, predict_logits.shape[2])
            actual_hindi_batch = hindi_batch[1:].reshape(-1)

            loss = criterion(predict_logits, actual_hindi_batch)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

            # dropping references so this batch's tensors can be freed before the next one is loaded
            del english_batch, hindi_batch, predict_logits, actual_hindi_batch, loss
            del batch, batch_index

        print('Loss at {}th epoch: {}'.format(epoch, total_loss))
        scheduler.step(total_loss)

        # qualitative check on a fixed sentence
        hindi_tokens = produce_output(model, test_sentence)
        print(test_sentence, '-', ' '.join(hindi_tokens))

        if save_name is not None:
            torch.save(model.state_dict(), save_name)

# function for producing the output sentence for a given input
def produce_output(model, sentence, max_length=100):
    """Greedily decode a Hindi token sequence for one English input.

    Args:
        model: module exposing `english_vocab` / `hindi_vocab` (with `stoi` and
            `itos`) and callable as model(src, tgt) -> (tgt_len, 1, vocab) logits.
        sentence: raw English string (tokenized with english_tokenize) or an
            already-tokenized sequence of strings (lower-cased here; the
            caller's sequence is not modified).
        max_length: maximum number of tokens to generate.

    Returns:
        List of Hindi tokens without the leading '<sos>'. It ends with '<eos>'
        only if the model emitted it within `max_length` steps.

    Side effect: leaves `model` in eval mode.
    """
    if isinstance(sentence, str):
        tokens = english_tokenize(sentence)
    else:
        tokens = [token.lower() for token in sentence]
    tokens = ['<sos>'] + tokens + ['<eos>']

    # run on whatever device the model's weights live on
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # membership test first: indexing a defaultdict-backed stoi with an unknown
    # token would insert it; unknown tokens map to the '<unk>' index instead
    english_stoi = model.english_vocab.stoi
    source_indices = [english_stoi[token] if token in english_stoi else english_stoi['<unk>'] for token in tokens]
    source = torch.tensor(source_indices, dtype=torch.long, device=run_device).unsqueeze(1)

    eos_index = model.hindi_vocab.stoi["<eos>"]
    outputs = [model.hindi_vocab.stoi["<sos>"]]

    model.eval()
    with torch.no_grad():
        for _ in range(max_length):
            target = torch.tensor(outputs, dtype=torch.long, device=run_device).unsqueeze(1)
            logits = model(source, target)

            # greedy choice for the most recent position
            best_guess = logits.argmax(2)[-1, :].item()
            outputs.append(best_guess)

            if best_guess == eos_index:
                break

    # drop the start token
    return [model.hindi_vocab.itos[idx] for idx in outputs[1:]]

# utilities for performance computation
def report_performance(model, english_sentences, hindi_sentences):
    """Average sentence-level BLEU, ROUGE-1 recall and ROUGE-L recall of greedy outputs.

    Each English input is decoded with produce_output (a trailing '<eos>' is
    removed only if present) and compared with its paired reference token list.
    BLEU uses NLTK's smoothing method1, so that sentences with no higher-order
    n-gram overlap neither trigger warnings nor get degenerate scores.
    ROUGE is computed directly on the token lists: rouge_score's default
    tokenizer keeps only [a-z0-9] characters and would drop every Devanagari
    token.

    Args:
        model: model accepted by produce_output.
        english_sentences: iterable of English inputs (strings or token lists).
        hindi_sentences: iterable of reference Hindi token lists, paired
            position-by-position with `english_sentences`.

    Returns:
        dict with keys 'bleu_score', 'rouge1_score' and 'rougeL_score',
        averaged over the scored pairs.

    Raises:
        ValueError: if there are no sentence pairs to score.
    """
    smoothing = nltk.translate.bleu_score.SmoothingFunction().method1

    total_bleu = 0.0
    total_rouge1 = 0.0
    total_rougeL = 0.0
    num_pairs = 0

    for english_sentence, hindi_sentence in tqdm(zip(english_sentences, hindi_sentences)):
        hypothesis = _strip_eos(produce_output(model, english_sentence))
        reference = list(hindi_sentence)

        total_bleu += nltk.translate.bleu_score.sentence_bleu(
            [reference], hypothesis, smoothing_function=smoothing)
        total_rouge1 += _rouge1_recall(reference, hypothesis)
        total_rougeL += _rougeL_recall(reference, hypothesis)
        num_pairs += 1

    if num_pairs == 0:
        raise ValueError('report_performance needs at least one sentence pair to score')

    # normalizing by the number of pairs actually scored (zip stops at the shorter input)
    return {
        'bleu_score': total_bleu / num_pairs,
        'rouge1_score': total_rouge1 / num_pairs,
        'rougeL_score': total_rougeL / num_pairs,
    }


def _strip_eos(tokens):
    """Drop a trailing '<eos>' from a decoded token list, only if it is there."""
    return tokens[:-1] if tokens[-1:] == ['<eos>'] else tokens


def _rouge1_recall(reference, hypothesis):
    """Clipped unigram overlap divided by the reference length (0.0 if either side is empty)."""
    from collections import Counter
    if not reference or not hypothesis:
        return 0.0
    overlap = Counter(reference) & Counter(hypothesis)
    return sum(overlap.values()) / len(reference)


def _rougeL_recall(reference, hypothesis):
    """Longest-common-subsequence length divided by the reference length (0.0 if either side is empty)."""
    if not reference or not hypothesis:
        return 0.0
    # rolling single-row dynamic programme for the LCS length
    previous = [0] * (len(hypothesis) + 1)
    for ref_token in reference:
        current = [0]
        for j, hyp_token in enumerate(hypothesis, start=1):
            if ref_token == hyp_token:
                current.append(previous[j - 1] + 1)
            else:
                current.append(max(previous[j], current[j - 1]))
        previous = current
    return previous[-1] / len(reference)

### Pretraining the Baseline Transformer for Machine Translation Task

In [None]:
# for pretraining using the HindEnCorp parallel corpus
# (default batch size of 16 and 10k-token vocabularies)
english_field, hindi_field, train_data, train_iterator = parse_using_torchtext(
    csv_file_name='drive/MyDrive/cs626_dataset/Hindi_English_Truncated_Corpus.csv'
)

In [None]:
# testing with model: baseline CustomTransformer (6 attention heads) over GloVe / fastText vectors
transformer = CustomTransformer(
    glove,
    english_field.vocab,
    fasttext_hindi,
    hindi_field.vocab,
    num_heads=6,
).to(device)

<All keys matched successfully>

In [None]:
# testing the training procedure (80 done)
# the loss is computed on Hindi target tokens, so the ignored padding index
# must come from the Hindi vocabulary
train_transformer(transformer, train_iterator, hindi_field.vocab.stoi['<pad>'], num_epoches=20, save_name='/content/drive/MyDrive/cs626_dataset/transformer_6_6_6.pt')

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Loss at 0th epoch: 5427.155629694462
a horse goes under a bridge next to a boat. - जहाज के नीचे भरत की ओर <unk> है . <eos>


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

## Fine-tuning the Baseline Transformer for Cross-Lingual Summarization

In [None]:
# Training the model for cross lingual summarization
# CLS_dataset.csv: English source in the 'text' column, Hindi target in the
# 'summary' column; the fields (and vocabularies) built for the MT corpus are reused.

# parsing the dataset
train_cls_data, train_cls_iterator = parse_using_field('drive/MyDrive/cs626_dataset/CLS_dataset.csv', english_field, hindi_field, 'text', 'summary')

In [None]:
# training the transformer for cross lingual summarization
# (the loss is on Hindi target tokens, so the ignored index is the Hindi '<pad>')
train_transformer(transformer, train_cls_iterator, hindi_field.vocab.stoi['<pad>'], num_epoches=20, save_name='/content/drive/MyDrive/cs626_dataset/transformer_6_6_6_cls.pt')

### Performance of the Baseline Transformer for Cross Lingual Summarization

In [None]:
# loading the dataset in the form of tokens
english_sentences, hindi_sentences = parse_dataset(
    'drive/MyDrive/cs626_dataset/CLS_dataset_test.csv', 'text', 'summary'
)

# loading the model: rebuild the baseline architecture, then restore the
# weights saved by the cross-lingual summarization fine-tuning run
transformer = CustomTransformer(
    glove, english_field.vocab, fasttext_hindi, hindi_field.vocab, num_heads=6
).to(device)
transformer.load_state_dict(
    torch.load('drive/MyDrive/cs626_dataset/transformer_6_6_6_cls.pt', map_location=device)
)

# obtaining the results (averages over every row of the test file)
report_performance(transformer, english_sentences, hindi_sentences)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().





{'bleu_score': 0.41357651166166304,
 'rouge1_score': 0.09561944848830092,
 'rougeL_score': 0.0955456041521615}

### Some examples of the Baseline Transformer Model

##### The Bombay High Court on Monday summoned the Maharashtra Women and Child Development Secretary after 42 children went missing over the last three years from a Mumbai remand home. The court criticised the Maharashtra government for lack of 'pro-active action' in the matter. The Bombay High Court is hearing a PIL on the allegations of corruption in the remand home.

In [None]:
bombay_summary = produce_output(transformer, "The Bombay High Court on Monday summoned the Maharashtra Women and Child Development Secretary after 42 children went missing over the last three years from a Mumbai remand home. The court criticised the Maharashtra government for lack of 'pro-active action' in the matter. The Bombay High Court is hearing a PIL on the allegations of corruption in the remand home.")
# drop the trailing <eos> only when decoding produced one; if max_length was
# reached there is none, and an unconditional [:-1] would cut a real token
' '.join(bombay_summary[:-1] if bombay_summary[-1:] == ['<eos>'] else bombay_summary)

'42 बच्चों के घर से बाहर निकलने के बाद बॉम्बे <unk> ने सचिव को बुलाया'

##### As many as 76 passengers were rescued from cable cars suspended over a river in German city Cologne after a gondola crashed into a support pillar on Sunday. Passengers were left stranded, and children were seen clinging to parents while dangling as many as 40 metres above the river. The fire department lowered them to safety from the cable cars.

In [None]:
cologne_summary = produce_output(transformer, "As many as 76 passengers were rescued from cable cars suspended over a river in German city Cologne after a gondola crashed into a support pillar on Sunday. Passengers were left stranded, and children were seen clinging to parents while dangling as many as 40 metres above the river. The fire department lowered them to safety from the cable cars")
# drop the trailing <eos> only when decoding produced one; if max_length was
# reached there is none, and an unconditional [:-1] would cut a real token
' '.join(cologne_summary[:-1] if cologne_summary[-1:] == ['<eos>'] else cologne_summary)

'केबल कार <unk> से उतरने के बाद 76 यात्री निलंबित'

##### An 11-year-old tribal boy allegedly committed suicide on Tuesday by hanging himself near his school, after he was caught stealing ₹30 from his classmate in Maharashtra's Mokhada. The boy was reportedly ashamed of his act and had tried to force a classmate to commit suicide with him, but he refused. Police said the boy has a history of criminal activities.

In [None]:
mokhada_summary = produce_output(transformer, "An 11-year-old tribal boy allegedly committed suicide on Tuesday by hanging himself near his school, after he was caught stealing ?30 from his classmate in Maharashtra's Mokhada. The boy was reportedly ashamed of his act and had tried to force a classmate to commit suicide with him, but he refused. Police said the boy has a history of criminal activities.")
# drop the trailing <eos> only when decoding produced one; if max_length was
# reached there is none, and an unconditional [:-1] would cut a real token
' '.join(mokhada_summary[:-1] if mokhada_summary[-1:] == ['<eos>'] else mokhada_summary)

'19 - वर्षीय आदिवासी लड़के ने <unk> पकड़ा , 30 से 30 पकड़े जाने के बाद आत्म हत्या कर ली'

##### Four labourers on Monday were reportedly injured after a tree branch fell on them at Dombivli station road in Mumbai. They were admitted to hospital with injuries and were later declared out of danger. Reportedly, tree fall cases are on rise in Kalyan-Dombivli. "Last year fewer cases were reported. We have been getting complaints of tree falls daily,"

In [None]:
dombivli_summary = produce_output(transformer, "Four labourers on Monday were reportedly injured after a tree branch fell on them at Dombivli station road in Mumbai. They were admitted to hospital with injuries and were later declared out of danger. Reportedly, tree fall cases are on rise in Kalyan-Dombivli. Last year fewer cases were reported. We have been getting complaints of tree falls daily")
# drop the trailing <eos> only when decoding produced one; if max_length was
# reached there is none, and an unconditional [:-1] would cut a real token
' '.join(dombivli_summary[:-1] if dombivli_summary[-1:] == ['<eos>'] else dombivli_summary)

'मुंबई में पेड़ की शाखा गिरने से 4 मजदूर घायल हो गए'

### Multitask-Objective Transformer

In [None]:
class MultitaskTransformer(nn.Module):
    """Shared Transformer encoder with two decoders: machine translation ('mt') and cross-lingual summarization ('cls').

    forward() takes sequence-first index tensors -- english (src_len, batch)
    and hindi (tgt_len, batch) -- and returns (tgt_len, batch, len(hindi_vocab.itos))
    logits from the decoder selected by `mode`.
    """

    def __init__(
        self,
        english_embedding,
        english_vocab,
        hindi_embedding,
        hindi_vocab,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dropout=0.1,
    ):
        """
        :param english_embedding: English vector table indexable by token string, with a `.dim` attribute
        :param english_vocab: English vocabulary with `itos` / `stoi` (must contain '<pad>')
        :param hindi_embedding: Hindi vector table with the same interface as english_embedding
        :param hindi_vocab: Hindi vocabulary with `itos` / `stoi` (must contain '<pad>')
        :param num_heads: attention heads; must divide the embedding dimension
        :param num_encoder_layers: layers in the shared encoder
        :param num_decoder_layers: layers in each of the two decoders
        :param dropout: dropout on the summed embeddings and inside every layer
        :raises ValueError: if the two embedding dimensions differ (decoders
            attend over encoder outputs of the English dimension)
        """
        super(MultitaskTransformer, self).__init__()
        if english_embedding.dim != hindi_embedding.dim:
            raise ValueError(
                "english and hindi embeddings must share a dimension "
                "(got {} and {})".format(english_embedding.dim, hindi_embedding.dim)
            )

        # storing the input arguments
        self.english_vocab = english_vocab
        self.hindi_vocab = hindi_vocab
        self.english_embedding = english_embedding
        self.hindi_embedding = hindi_embedding

        # trainable embedding tables initialised from the pretrained vectors
        self.english_word_embedding = nn.Embedding.from_pretrained(
            self._pretrained_weight(english_embedding, english_vocab),
            padding_idx=english_vocab.stoi['<pad>'],
            freeze=False
        )
        self.hindi_word_embedding = nn.Embedding.from_pretrained(
            self._pretrained_weight(hindi_embedding, hindi_vocab),
            padding_idx=hindi_vocab.stoi['<pad>'],
            freeze=False
        )

        # initializing a single encoder shared by both tasks
        encoder_layer = nn.TransformerEncoderLayer(english_embedding.dim, num_heads, dropout=dropout)
        encoder_norm = nn.LayerNorm(english_embedding.dim)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # machine translation decoder
        mt_decoder_layer = nn.TransformerDecoderLayer(hindi_embedding.dim, num_heads, dropout=dropout)
        mt_decoder_norm = nn.LayerNorm(hindi_embedding.dim)
        self.mt_decoder = nn.TransformerDecoder(mt_decoder_layer, num_decoder_layers, mt_decoder_norm)

        # cross lingual summarization decoder
        cls_decoder_layer = nn.TransformerDecoderLayer(hindi_embedding.dim, num_heads, dropout=dropout)
        cls_decoder_norm = nn.LayerNorm(hindi_embedding.dim)
        self.cls_decoder = nn.TransformerDecoder(cls_decoder_layer, num_decoder_layers, cls_decoder_norm)

        # output projections; sized by itos because stoi may be a defaultdict
        # (as in torchtext's legacy Vocab) that gains a key whenever an unknown
        # token is looked up, so len(stoi) can exceed the real vocabulary size
        self.english_pad_index = english_vocab.stoi['<pad>']
        self.dropout = nn.Dropout(dropout)
        hindi_vocab_size = len(hindi_vocab.itos)
        self.mt_fc_out = nn.Linear(hindi_embedding.dim, hindi_vocab_size)
        self.cls_fc_out = nn.Linear(hindi_embedding.dim, hindi_vocab_size)

    @staticmethod
    def _pretrained_weight(vectors, vocab):
        """Return a (len(vocab.itos), vectors.dim) float matrix whose row i is vectors[vocab.itos[i]]."""
        weight = torch.zeros(len(vocab.itos), vectors.dim, dtype=torch.float)
        for word_id, word in enumerate(vocab.itos):
            weight[word_id] = vectors[word]
        return weight

    # for creating the english mask
    def make_english_mask(self, english_batch):
        """(batch, src_len) boolean mask, True where the source token is padding."""
        english_mask = english_batch.transpose(0, 1) == self.english_pad_index
        return english_mask

    def generate_square_subsequent_mask(self, sz):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    # forward pass
    def forward(self, english_batch, hindi_batch, mode='cls'):
        """Decoder logits for `hindi_batch` (teacher forcing) from the 'cls' (default) or 'mt' head.

        :raises ValueError: for any other `mode` (previously this silently returned None)
        """
        if mode == 'cls':
            decoder, fc_out = self.cls_decoder, self.cls_fc_out
        elif mode == 'mt':
            decoder, fc_out = self.mt_decoder, self.mt_fc_out
        else:
            raise ValueError("mode must be 'cls' or 'mt' (got {!r})".format(mode))

        english_seq_length, batch_size_english = english_batch.shape
        hindi_seq_length, batch_size_hindi = hindi_batch.shape
        assert batch_size_english == batch_size_hindi
        run_device = english_batch.device

        # (seq_len, 1, d_model) position tables broadcast over the batch dimension
        english_pos_embedding = positional_encoding_1d(
            self.english_word_embedding.embedding_dim, english_seq_length).unsqueeze(1).to(run_device)
        hindi_pos_embedding = positional_encoding_1d(
            self.hindi_word_embedding.embedding_dim, hindi_seq_length).unsqueeze(1).to(run_device)

        # producing the final embedding
        english_final_embedding = self.dropout(self.english_word_embedding(english_batch) + english_pos_embedding)
        hindi_final_embedding = self.dropout(self.hindi_word_embedding(hindi_batch) + hindi_pos_embedding)

        # producing the masks
        english_padding_mask = self.make_english_mask(english_batch)
        hindi_mask = self.generate_square_subsequent_mask(hindi_seq_length).to(run_device)

        # the decoder's cross-attention must also ignore padded source positions
        memory = self.encoder(english_final_embedding, src_key_padding_mask=english_padding_mask)
        output = decoder(
            hindi_final_embedding,
            memory,
            tgt_mask=hindi_mask,
            memory_key_padding_mask=english_padding_mask
        )
        return fc_out(output)

# function for training the multi task transformer
def train_multitask_transformer(model, mt_iterator, cls_iterator, pad_index, num_epoches=1000, learning_rate=1e-4, save_name=None,
                                test_sentence='a horse goes under a bridge next to a boat.'):
    """Jointly train the 'mt' and 'cls' heads of a MultitaskTransformer.

    Each step takes one batch from `mt_iterator` and one from `cls_iterator`,
    computes teacher-forced cross-entropy for both decoders and optimizes
    their sum. An epoch is one pass over `mt_iterator`; the CLS data is
    consumed as a continuous stream that restarts whenever it runs out.

    Args:
        model: called as model(english, hindi_input, mode='mt' or 'cls').
        mt_iterator, cls_iterator: yield batches with sequence-first
            `english_sentence` / `hindi_sentence` index tensors.
        pad_index: target index ignored by the loss; both losses are on Hindi
            tokens, so this should be the Hindi vocabulary's '<pad>' index.
        num_epoches: number of passes over `mt_iterator`.
        learning_rate: initial Adam learning rate (ReduceLROnPlateau on the summed epoch loss).
        save_name: optional path for torch.save(model.state_dict()) after each epoch.
        test_sentence: English sentence decoded (default 'cls' mode) after each epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

    # one running pass over the CLS data, restarted when exhausted; calling
    # next(iter(cls_iterator)) on every step would redo the iterator's whole
    # epoch setup just to take its first batch
    cls_stream = iter(cls_iterator)

    def next_cls_batch():
        # next CLS batch, starting a fresh pass when the current one ends
        nonlocal cls_stream
        try:
            return next(cls_stream)
        except StopIteration:
            cls_stream = iter(cls_iterator)
            return next(cls_stream)

    def teacher_forced_loss(batch, mode):
        # cross-entropy of predicting hindi[1:] from hindi[:-1] with the given head
        english_batch = batch.english_sentence.to(device)
        hindi_batch = batch.hindi_sentence.to(device)
        logits = model(english_batch, hindi_batch[:-1, :], mode=mode)
        logits = logits.reshape(-1, logits.shape[2])
        return criterion(logits, hindi_batch[1:].reshape(-1))

    for epoch in tqdm(range(num_epoches)):
        total_loss = 0
        total_mt_loss = 0
        total_cls_loss = 0

        # produce_output below switches the model to eval mode, so re-enable training each epoch
        model.train()
        for batch_index, mt_batch in tqdm(enumerate(mt_iterator)):
            optimizer.zero_grad()
            cls_batch = next_cls_batch()

            mt_loss = teacher_forced_loss(mt_batch, 'mt')
            cls_loss = teacher_forced_loss(cls_batch, 'cls')
            loss = mt_loss + cls_loss
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            total_mt_loss += mt_loss.item()
            total_cls_loss += cls_loss.item()

            # dropping references so this step's tensors can be freed before the next batch
            del loss, mt_loss, cls_loss
            del mt_batch, cls_batch, batch_index

        print('Loss at {}th epoch: {}'.format(epoch, total_loss))
        print('MT: {}'.format(total_mt_loss))
        print('CLS: {}'.format(total_cls_loss))
        scheduler.step(total_loss)

        # qualitative check on a fixed sentence
        hindi_tokens = produce_output(model, test_sentence)
        print(test_sentence, '-', ' '.join(hindi_tokens))

        if save_name is not None:
            torch.save(model.state_dict(), save_name)

### Training the Multitask Transformer


In [None]:
# multitask setup: 20k-token vocabularies are built from the MT corpus and the
# same fields are reused for the CLS data
english_field, hindi_field, mt_train_data, mt_train_iterator = parse_using_torchtext(
    'drive/MyDrive/cs626_dataset/Hindi_English_Truncated_Corpus.csv',
    batch_size=16,
    english_vocab_size=20000,
    hindi_vocab_size=20000,
)
cls_train_data, cls_train_iterator = parse_using_field(
    'drive/MyDrive/cs626_dataset/CLS_dataset.csv', english_field, hindi_field, 'text', 'summary'
)


100%|█████████▉| 157992/158016 [00:35<00:00, 9004.34it/s][A

In [None]:
# initializing the multitask transformer
MULTITASK_CHECKPOINT = '/content/drive/MyDrive/cs626_dataset/multitasktransformer_6_6_6_cls.pt'
transformer = MultitaskTransformer(glove, english_field.vocab, fasttext_hindi, hindi_field.vocab, num_heads=6).to(device)

# resume from an earlier run's checkpoint when one exists; on a fresh run this
# file is only written by the training cell that follows, so loading it
# unconditionally would raise FileNotFoundError
if os.path.exists(MULTITASK_CHECKPOINT):
    transformer.load_state_dict(torch.load(MULTITASK_CHECKPOINT, map_location=device))

<All keys matched successfully>

In [None]:
# both losses are on Hindi target tokens, so the ignored index is the Hindi '<pad>'
train_multitask_transformer(transformer, mt_train_iterator, cls_train_iterator, hindi_field.vocab.stoi['<pad>'], num_epoches=60, save_name='/content/drive/MyDrive/cs626_dataset/multitasktransformer_6_6_6_cls.pt')

In [None]:
# loading the dataset in the form of tokens
# NOTE: max_num=100 scores only the first 100 test rows, whereas the baseline
# evaluation earlier in the notebook scores the whole test file — the two result
# dicts are not directly comparable.
english_sentences, hindi_sentences = parse_dataset(
    'drive/MyDrive/cs626_dataset/CLS_dataset_test.csv',
    'text',
    'summary',
    max_num=100,
)

# obtaining the results
report_performance(transformer, english_sentences, hindi_sentences)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().





{'bleu_score': 0.3406240250489478,
 'rouge1_score': 0.015,
 'rougeL_score': 0.015}

### Demo

In [None]:
# demo: rebuild the fields from the HindEnCorp parallel corpus with the same
# default settings used for pretraining, so the vocabulary sizes presumably
# match the saved baseline checkpoint
english_field, hindi_field, train_data, train_iterator = parse_using_torchtext(
    csv_file_name='drive/MyDrive/cs626_dataset/Hindi_English_Truncated_Corpus.csv'
)

In [None]:
# demo: baseline CustomTransformer architecture (6 attention heads)
transformer = CustomTransformer(
    glove,
    english_field.vocab,
    fasttext_hindi,
    hindi_field.vocab,
    num_heads=6,
).to(device)

In [None]:
# loading the pretrained model (the checkpoint saved after cross-lingual summarization fine-tuning)
transformer.load_state_dict(
    torch.load('drive/MyDrive/cs626_dataset/transformer_6_6_6_cls.pt', map_location=device)
)

<All keys matched successfully>

In [None]:
# output example
ghaziabad_summary = produce_output(transformer, 'The Ghaziabad Police has booked 14 people including 3 Congress workers for resorting to violence during protests in front of the Ala Hazrat Haj House on Monday. Nearly 500 protesters had pelted stones at the police during the protest demanding that the facility be opened to pilgrims immediately. The facility has remained sealed since its inauguration last year.')
# drop the trailing <eos> only when decoding produced one; if max_length was
# reached there is none, and an unconditional [:-1] would cut a real token
' '.join(ghaziabad_summary[:-1] if ghaziabad_summary[-1:] == ['<eos>'] else ghaziabad_summary)

'<unk> हज हाउस के सामने हिंसा भड़काने के लिए <unk> पुलिस ने <unk> की'