In [1]:
# %%
from transformers import AutoModel, AutoTokenizer
import pickle
from torchtext.vocab import vocab
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
import numpy as np
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer
max_length = 50


In [2]:
gu_tokenizer = AutoTokenizer.from_pretrained('ai4bharat/indic-bert')
model = AutoModel.from_pretrained('ai4bharat/indic-bert')

# %%
def build_en_vocab():
    counter = Counter()
    for fp in [train_data,val_data,test_data]:
        for i in fp:
            counter.update(i[1])
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])


# %%
train_filepaths = ['en-gu/train.en', 'en-gu/train.gu']
val_filepaths = ['en-gu/dev.en', 'en-gu/dev.gu']
test_filepaths = ['en-gu/test.en', 'en-gu/test.gu']



# Loading vocab to embedding converter of indicbert
vocab_to_embedding_convertor = model.get_input_embeddings()

# Tokenizer for english words
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


Some weights of the model checkpoint at ai4bharat/indic-bert were not used when initializing AlbertModel: ['predictions.bias', 'predictions.dense.bias', 'predictions.decoder.bias', 'predictions.dense.weight', 'predictions.decoder.weight', 'sop_classifier.classifier.bias', 'predictions.LayerNorm.bias', 'sop_classifier.classifier.weight', 'predictions.LayerNorm.weight']
- This IS expected if you are initializing AlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
file = open("../machine_translation/train_data.obj",'rb')
train_data = pickle.load(file)
file.close()

file = open("../machine_translation/val_data.obj",'rb')
val_data = pickle.load(file)
file.close()

file = open("../machine_translation/test_data.obj",'rb')
test_data = pickle.load(file)
file.close()

glove_embeddings = torchtext.vocab.GloVe(name='6B', dim=300)
en_vocab = build_en_vocab()
itos = en_vocab.get_itos() 

en_embeddings = []
for i in range(len(itos)):
    en_embeddings.append(glove_embeddings.get_vecs_by_tokens(itos[i], lower_case_backup=True).numpy())

# Convert embeddings to numpy array
en_embeddings = np.array(en_embeddings)

In [7]:
vocab_to_embedding_convertor = model.get_input_embeddings()

#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

In [8]:
from torch import Tensor
import math


# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 128):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 en_embeddings, vocab_to_embedding_convertor,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size_en: int,
                 emb_size_gu: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=300,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first = True)
        self.generator = nn.Linear(300, tgt_vocab_size)
        self.src_tok_emb = nn.Embedding.from_pretrained(torch.from_numpy(en_embeddings).float().to(device))
        self.tgt_tok_emb = vocab_to_embedding_convertor
        self.positional_encoding_en = PositionalEncoding(emb_size_en, dropout=dropout)
        self.positional_encoding_gu = PositionalEncoding(emb_size_gu, dropout=dropout)
        self.lin_layer = nn.Linear(128, 300)

    def forward(self,
                src,
                trg,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding_en(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding_gu(self.tgt_tok_emb(trg.clone().detach()))
        tgt_emb = self.lin_layer(tgt_emb)
        outs = self.transformer(src_emb, tgt_emb)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
        

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)

    src_padding_mask = (src == 1).transpose(0, 1)
    tgt_padding_mask = (tgt == 0).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


# %%
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(en_vocab)
TGT_VOCAB_SIZE = gu_tokenizer.vocab_size
EMB_SIZE_EN = 300
EMB_SIZE_GU = 128
NHEAD = 5
FFN_HID_DIM = 512
BATCH_SIZE = 2
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(en_embeddings, vocab_to_embedding_convertor, 
                                 NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE_EN, EMB_SIZE_GU,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(device)


In [None]:
def generate_batch(data_batch):
  en_batch_tokens, gu_tokens = [], []
  max_length = 40
  for (gu_item, en_item, _) in data_batch:
    gu_tokens.append(torch.tensor(gu_item))
    en_tokens = torch.tensor(en_vocab(en_item))
    en_batch_tokens.append(en_tokens)
    
  gu_batch_tokens = pad_sequence(gu_tokens, batch_first=True, padding_value=0)
  en_batch_tokens = pad_sequence(en_batch_tokens,batch_first=True, padding_value=1)
  
  if gu_batch_tokens.shape[1] > max_length:
    gu_batch_tokens = gu_batch_tokens[:, :max_length]

  if en_batch_tokens.shape[1] > max_length:
    en_batch_tokens = en_batch_tokens[:, :max_length]

  return en_batch_tokens, gu_batch_tokens

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)


loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')


import time
wandb.init()

def train(model, iterator, optimizer, criterion):
    model.train()
    losses = 0
    count = 0
    for src, trg in iterator:
    
        src, trg = src.to(device), trg.to(device)

        #trg_input = trg[:-1, :]
        trg_input = trg

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, trg_input)

        logits = model(src, trg_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()
        
        trg_out = trg
        loss = criterion(logits.reshape(-1, logits.shape[-1]), trg_out.reshape(-1))
        loss.backward()
        
        optimizer.step()

        losses += loss.item()
        
        count+=1
        if count%10==0:
          wandb.log({'Train Loss':loss.item()})
    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    model.eval()
    losses = 0
    with torch.no_grad():
        count=0
        for src, trg in enumerate(iterator):
            src, trg = src.to(device), trg.to(device)

            #trg_input = tgt[:-1, :]
            trg_input = tgt
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, trg_input)

            logits = model(src, trg_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
            print(logits)
            print()
            

            trg_out = trg[1:, :]
            print(trg_out)
            loss = criterion((logits.reshape(-1, logits.shape[-1]), trg_out.reshape(-1)))
            losses += loss.item()
            count+=1
            if count%10==0:
              wandb.log({'Val Loss':loss.item()})
    return epoch_loss / len(iterator)


def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss = train(transformer, train_iter, optimizer, loss_fn)
    valid_loss = evaluate(transformer, valid_iter, loss_fn)

    torch.save(transformer.state_dict(), 'transformer.pt')
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

test_loss = evaluate(transformer, test_iter, loss_fn)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')



The model has 33,443,584 trainable parameters


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnehamjain[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  tgt_emb = self.positional_encoding_gu(self.tgt_tok_emb(torch.tensor(trg)))


In [None]:
from torchtext.data.metrics import bleu_score

print(bleu_score(pred_captions, gt_captions, max_n=1, weights=(1, 0, 0, 0)))
print(bleu_score(pred_captions, gt_captions, max_n=2, weights=(0, 1, 0, 0)))
print(bleu_score(pred_captions, gt_captions, max_n=3, weights=(0, 0, 1, 0)))
print(bleu_score(pred_captions, gt_captions, max_n=4, weights=(0, 0, 0, 1)))