In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split

from io import open
import unicodedata
import re

In [2]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
# Index 0 is passed in both cases.
device=torch.device(type='cuda' if torch.cuda.is_available() else 'cpu', index=0)

In [3]:
def normalizeString(s):
    """Return ``s`` reduced to ASCII letters plus free-standing ``!`` / ``?``.

    Steps, in order:
      1. Decompose to NFD and drop combining marks (category ``Mn``), so
         accented letters become their base letter.
      2. Insert a space before each ``.``, ``!`` or ``?``.
      3. Replace every run of characters other than ``a-z``, ``A-Z``, ``!``
         and ``?`` with a single space (this also removes the ``.``).
      4. Strip leading and trailing whitespace.
    """
    decomposed = unicodedata.normalize('NFD', s)
    without_marks = "".join(
        ch for ch in decomposed if unicodedata.category(ch) != 'Mn'
    )
    punctuation_spaced = re.sub(r"([.!?])", r" \1", without_marks)
    letters_only = re.sub(r"[^a-zA-Z!?]+", r" ", punctuation_spaced)
    return letters_only.strip()

def createNormalizedPairs(lines=None):
    """Turn raw tab-separated lines into normalized ``[first, second]`` pairs.

    Args:
        lines: Iterable of strings, each holding at least two tab-separated
            fields. When omitted, the module-level ``data`` list is used, so
            existing zero-argument calls keep working.

    Returns:
        A list of two-element lists. Each element is the lower-cased,
        stripped and ``normalizeString``-processed text of the corresponding
        field.

    Raises:
        ValueError: If a line has fewer than two tab-separated fields.
    """
    if lines is None:
        lines = data

    initpairs = []
    for pair in lines:
        # Only the first two fields are used; any extra tab-separated columns
        # are ignored instead of breaking the unpacking.
        s1, s2 = pair.split('\t')[:2]
        initpairs.append([
            normalizeString(s1.lower().strip()),
            normalizeString(s2.lower().strip()),
        ])
    return initpairs

# Sentences must have strictly fewer than this many space-separated tokens.
max_length = 10
def filterPairs(initpairs):
    """Keep only short pairs whose first sentence starts with a known prefix.

    A pair is kept when both sentences have fewer than ``max_length``
    space-separated tokens and the lower-cased first sentence starts with one
    of the prefixes below. The number of surviving pairs is printed.

    Args:
        initpairs: Iterable of ``[first, second]`` string pairs.

    Returns:
        List of the pairs that passed the filter, in their original order.
    """
    eng_prefixes = (
        "i am ", "i m ",
        "he is", "he s ",
        "she is", "she s ",
        "you are", "you re ",
        "we are", "we re ",
        "they are", "they re "
    )

    def is_short(sentence):
        return len(sentence.split(" ")) < max_length

    pairs = [
        candidate
        for candidate in initpairs
        if is_short(candidate[0])
        and is_short(candidate[1])
        and candidate[0].lower().startswith(eng_prefixes)
    ]

    print("Number of pairs after filtering:", len(pairs))
    return pairs

In [4]:
class Vocab:
    """Word <-> index mapping for one language, with per-word counts.

    Indices 0 and 1 are reserved for the ``'SOS'`` and ``'EOS'`` markers;
    regular words are numbered from 2 in order of first appearance.
    """

    def __init__(self, name):
        """Create an empty vocabulary labelled ``name``."""
        self.name=name
        self.word2index={'SOS':0, 'EOS':1}
        self.index2word={0:'SOS', 1:'EOS'}
        # Counts only words seen through buildVocab; the reserved markers are
        # not pre-seeded here.
        self.word2count={}
        self.nwords=2

    def buildVocab(self,s):
        """Register every space-separated token of ``s`` and update its count.

        Tokens already present in ``word2index`` keep their index. This
        includes the reserved ``'SOS'`` / ``'EOS'`` markers, which are counted
        without raising even though they start with no ``word2count`` entry.
        """
        for word in s.split(" "):
            if word not in self.word2index:
                self.word2index[word]=self.nwords
                self.index2word[self.nwords]=word
                self.nwords+=1
            # .get covers both brand-new words and the pre-seeded markers.
            self.word2count[word]=self.word2count.get(word,0)+1

In [5]:
class Encoder(nn.Module):
    """Embedding -> dropout -> single GRU (batch-first) sequence encoder."""

    def __init__(self, input_size, embed_size, hidden_size, dropout_p=0.1):
        """
        Args:
            input_size: Number of rows in the embedding table.
            embed_size: Embedding dimension, also the GRU input size.
            hidden_size: GRU hidden-state size.
            dropout_p: Dropout probability applied to the embeddings.
        """
        super().__init__()
        self.e=nn.Embedding(num_embeddings=input_size, embedding_dim=embed_size)
        self.dropout=nn.Dropout(p=dropout_p)
        self.gru=nn.GRU(input_size=embed_size, hidden_size=hidden_size, batch_first=True)

    def forward(self,x):
        """Encode token ids ``x``; returns the GRU's ``(outputs, hidden)`` tuple."""
        embedded=self.dropout(self.e(x))
        return self.gru(embedded)

In [6]:
class Decoder(nn.Module):
    """Embedding -> ReLU -> GRU -> linear -> log-softmax token decoder."""

    def __init__(self,output_size,embed_size,hidden_size):
        """
        Args:
            output_size: Number of embedding rows and width of the output
                distribution.
            embed_size: Embedding dimension, also the GRU input size.
            hidden_size: GRU hidden-state size.
        """
        super().__init__()
        self.e=nn.Embedding(num_embeddings=output_size, embedding_dim=embed_size)
        self.relu=nn.ReLU()
        self.gru=nn.GRU(input_size=embed_size, hidden_size=hidden_size, batch_first=True)
        self.lin=nn.Linear(in_features=hidden_size, out_features=output_size)
        self.lsoftmax=nn.LogSoftmax(dim=-1)

    def forward(self,x,prev_hidden):
        """Decode token ids ``x`` starting from ``prev_hidden``.

        Returns:
            ``(log_probs, hidden)``: log-softmax scores over the last
            dimension for every input position, and the GRU's final hidden
            state.
        """
        activated=self.relu(self.e(x))
        gru_out,hidden=self.gru(activated,prev_hidden)
        log_probs=self.lsoftmax(self.lin(gru_out))
        return log_probs, hidden

In [7]:
def get_input_ids(sentence,langobj):
    """Convert a space-separated sentence into a 1-D tensor of vocabulary ids.

    For the vocabulary named ``'fre'`` the ids are followed by ``EOS`` only.
    For any other vocabulary they are wrapped as ``SOS ... EOS``.

    Args:
        sentence: Space-separated words, all present in ``langobj.word2index``.
        langobj: Object exposing ``name`` and ``word2index``.

    Raises:
        KeyError: If a word is missing from ``langobj.word2index``.
    """
    lookup=langobj.word2index
    token_ids=[lookup[token] for token in sentence.split(" ")]

    if langobj.name!='fre':
        token_ids=[lookup['SOS']]+token_ids
    token_ids.append(lookup['EOS'])
    return torch.tensor(token_ids)

In [8]:
class CustomDataset(Dataset):
    """Dataset view over the module-level ``pairs`` list.

    Each item is ``(s_input_ids, t_input_ids)``: fixed-length int64 tensors,
    zero-padded on the right. The first holds ``pairs[idx][1]`` encoded with
    the ``fre`` vocabulary, the second ``pairs[idx][0]`` encoded with ``eng``.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        # `length` is a module-level value set alongside `pairs`.
        return length

    def __getitem__(self,idx):
        target_text,source_text=pairs[idx][0],pairs[idx][1]

        # The source gets one extra slot and the target two, matching the
        # markers get_input_ids adds for each vocabulary.
        source_len=len(source_text.split(" "))+1
        target_len=len(target_text.split(" "))+2

        s_input_ids=torch.zeros(max_length+1, dtype=torch.int64)
        t_input_ids=torch.zeros(max_length+2, dtype=torch.int64)
        s_input_ids[:source_len]=get_input_ids(source_text,fre)
        t_input_ids[:target_len]=get_input_ids(target_text,eng)

        return s_input_ids, t_input_ids

In [9]:
def train_one_epoch():
    """Run one pass over ``train_dataloader`` and return the mean batch loss.

    Uses the module-level ``encoder``, ``decoder``, ``loss_fn``, ``opte``,
    ``optd`` and ``device``. The decoder is fed the ground-truth target
    sequence minus its last position, and is scored against the same sequence
    shifted left by one.
    """
    encoder.train()
    decoder.train()
    running_loss=0

    for s_ids,t_ids in train_dataloader:
        s_ids,t_ids=s_ids.to(device),t_ids.to(device)

        # The encoder's final hidden state initialises the decoder.
        _,decoder_hidden=encoder(s_ids)
        yhats,decoder_hidden=decoder(t_ids[:,0:-1],decoder_hidden)

        # Flatten batch and time so the loss sees one row per predicted token.
        flat_predictions=yhats.view(-1,yhats.shape[-1])
        flat_targets=t_ids[:,1:].reshape(-1)

        loss=loss_fn(flat_predictions,flat_targets)
        running_loss+=loss.item()

        opte.zero_grad()
        optd.zero_grad()
        loss.backward()
        opte.step()
        optd.step()

    return running_loss/len(train_dataloader)

In [10]:
def ids2Sentence(ids,vocab):
    """Render a single sequence of token ids as a space-terminated string.

    Id 0 is skipped. Decoding stops after the first id 1, whose word is still
    included in the output. Every emitted word is followed by one space, so a
    non-empty result ends with a trailing space.

    Args:
        ids: Tensor holding one sequence of token ids. Any shape is accepted
            and flattened, so ``(1, N)``, ``(N,)`` and single-element tensors
            all work.
        vocab: Object exposing ``index2word`` (id -> word).

    Returns:
        The decoded sentence, or ``""`` when nothing is emitted.
    """
    words=[]
    # reshape(-1) instead of squeeze(): squeeze() turns a one-token input into
    # a 0-d tensor, which cannot be iterated.
    for token_id in ids.reshape(-1).tolist():
        if token_id==0:
            continue
        words.append(vocab.index2word[token_id])
        if token_id==1:
            break
    return "".join(word + " " for word in words)

In [11]:
def eval_one_epoch(e,n_epochs):
    """Greedy-decode every item of ``test_dataloader`` and return the mean loss.

    Unlike training, the decoder is fed its own previous prediction at each
    step rather than the ground-truth token. On the final epoch
    (``e + 1 == n_epochs``) the source, ground-truth and predicted sentences
    are printed for each item.

    Uses the module-level ``encoder``, ``decoder``, ``loss_fn``, ``device``,
    ``eng``, ``fre``, ``max_length`` and ``test_dataloader``. The ``.item()``
    calls below require one sequence per batch.

    Args:
        e: Zero-based index of the current epoch.
        n_epochs: Total number of epochs; used only to detect the last one.

    Returns:
        Sum of per-batch losses divided by ``len(test_dataloader)``.
    """
    encoder.eval()
    decoder.eval()
    track_loss=0
    with torch.no_grad():
        for i, (s_ids,t_ids) in enumerate(test_dataloader):
            s_ids=s_ids.to(device)
            t_ids=t_ids.to(device)
            encoder_outputs, encoder_hidden=encoder(s_ids)
            # The encoder's final hidden state initialises the decoder.
            decoder_hidden=encoder_hidden 
            # First target position is the decoder's start token.
            input_ids=t_ids[:,0]
            yhats=[]
            if e+1==n_epochs:
                pred_sentence=""
            # At most max_length+1 steps, i.e. every target position after the first.
            for j in range(1,max_length+2): 
                probs, decoder_hidden = decoder(input_ids.unsqueeze(1),decoder_hidden)
                yhats.append(probs)
                # Greedy choice: the highest-scoring token becomes the next input.
                _,input_ids=torch.topk(probs,1,dim=-1)
                # Drop the two trailing size-1 dims left by the decoder step and topk.
                input_ids=input_ids.squeeze(1,2) 
                if e+1==n_epochs:
                    word=eng.index2word[input_ids.item()] 
                    pred_sentence+=word + " "
                # Id 1 is 'EOS' in Vocab; stop decoding once it is produced.
                if input_ids.item() == 1: 
                    break
                                
            if e+1==n_epochs:
                src_sentence=ids2Sentence(s_ids,fre) 
                gt_sentence=ids2Sentence(t_ids[:,1:],eng)

                print("\n-----------------------------------")
                print("Source Sentence:",src_sentence)
                print("GT Sentence:",gt_sentence)
                print("Predicted Sentence:",pred_sentence)
            
            yhats_cat=torch.cat(yhats,dim=1)
            yhats_reshaped=yhats_cat.view(-1,yhats_cat.shape[-1])
            # `j` keeps its last loop value, so this slice has exactly as many
            # target positions as decoding steps were taken.
            gt=t_ids[:,1:j+1]
            gt=gt.view(-1)
            

            loss=loss_fn(yhats_reshaped,gt)
            track_loss+=loss.item()
            
        if e+1==n_epochs:    
            print("-----------------------------------")
        return track_loss/len(test_dataloader)    

In [12]:
# Seed PyTorch's global RNG so the random train/test split and the weight
# initialisation are the same on every Restart & Run All.
SEED=0
torch.manual_seed(SEED)

# Read the tab-separated sentence pairs. The context manager closes the file,
# and the encoding is explicit rather than taken from the platform default.
with open("/kaggle/input/eng-fre-trans/eng-fra.txt", encoding="utf-8") as f:
    data=f.read().strip().split('\n')
print("Total number of pairs:",len(data))
initpairs=createNormalizedPairs() 
pairs=filterPairs(initpairs)
length=len(pairs)

# pair[0] feeds the 'eng' vocabulary, pair[1] the 'fre' vocabulary.
eng=Vocab('eng')
fre=Vocab('fre')
for pair in pairs:
    eng.buildVocab(pair[0])
    fre.buildVocab(pair[1])

print("English Vocab Length:",eng.nwords)
print("French Vocab Length:",fre.nwords)    

# 99% / 1% random split. Evaluation uses batch_size=1 because eval_one_epoch
# calls .item() on per-step predictions.
dataset=CustomDataset()
train_dataset,test_dataset=random_split(dataset,[0.99,0.01])
batch_size=32
# NOTE(review): shuffle=False means training batches are visited in the same
# order every epoch; consider shuffle=True.
train_dataloader=DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=False)
test_dataloader=DataLoader(dataset=test_dataset,batch_size=1, shuffle=False)

embed_size=300
hidden_size=512
# The encoder reads 'fre' ids and the decoder emits 'eng' ids.
encoder=Encoder(fre.nwords,embed_size,hidden_size).to(device) 
decoder=Decoder(eng.nwords,embed_size,hidden_size).to(device) 

# Targets equal to 0 (the zero padding written by CustomDataset) are ignored.
loss_fn=nn.NLLLoss(ignore_index=0).to(device)
lr=0.001
opte=optim.Adam(params=encoder.parameters(), lr=lr, weight_decay=0.001)
optd=optim.Adam(params=decoder.parameters(), lr=lr, weight_decay=0.001)

n_epochs=80

for e in range(n_epochs):
    print("Epoch=",e+1, sep="", end=", ")
    print("Train Loss=", round(train_one_epoch(),4), sep="", end=", ")
    print("Eval Loss=",round(eval_one_epoch(e,n_epochs),4), sep="")

Total number of pairs: 135842

Number of pairs after filtering: 11445

English Vocab Length: 2991

French Vocab Length: 4601

Epoch=1, Train Loss=3.1498, Eval Loss=3.3766

Epoch=2, Train Loss=2.4137, Eval Loss=3.1472

Epoch=3, Train Loss=2.1787, Eval Loss=2.9177

Epoch=4, Train Loss=2.0527, Eval Loss=2.7982

Epoch=5, Train Loss=1.9814, Eval Loss=2.6393

Epoch=6, Train Loss=1.9324, Eval Loss=2.549

Epoch=7, Train Loss=1.8874, Eval Loss=2.5662

Epoch=8, Train Loss=1.842, Eval Loss=2.5203

Epoch=9, Train Loss=1.7934, Eval Loss=2.4571

Epoch=10, Train Loss=1.7474, Eval Loss=2.3836

Epoch=11, Train Loss=1.7079, Eval Loss=2.4303

Epoch=12, Train Loss=1.6659, Eval Loss=2.3827

Epoch=13, Train Loss=1.6279, Eval Loss=2.518

Epoch=14, Train Loss=1.5907, Eval Loss=2.2871

Epoch=15, Train Loss=1.5633, Eval Loss=2.2281

Epoch=16, Train Loss=1.5322, Eval Loss=2.2269

Epoch=17, Train Loss=1.5078, Eval Loss=2.2466

Epoch=18, Train Loss=1.486, Eval Loss=2.193

Epoch=19, Train Loss=1.4643, Eval Loss=2.1