<a href="https://colab.research.google.com/github/judem-21/Seq2Seq-Model/blob/main/Seq2Seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Seq2Seq Model

###Importing Libraries

In [1]:
import torch,torchvision
from torch.utils.data import Dataset,DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torchvision import transforms as transforms
#from torch.torchmetrics.text.bleu import BLEUScore

In [6]:
import os
import pandas as pd
from skimage import io
import spacy
from PIL import Image
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
!python -m spacy download en
!python -m spacy download de_core_news_sm
spacy_eng = spacy.load("en_core_web_sm")
spacy_ger = spacy.load("de_core_news_sm")

[38;5;3m⚠ As of spaCy v3.0, shortcuts like 'en' are deprecated. Please use the
full pipeline package name 'en_core_web_sm' instead.[0m
Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting de-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.7.0/de_core_news_sm-3.7.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━

###Dataset (Loading and Testing)

####Dataset Loading

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [13]:
class Vocabulary:
    def __init__(self, freq_threshold=3):
        self.itos_source = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi_source = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

        self.itos_target = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi_target = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

        self.punctuation_marks = [
    '.', ',', ';', ':', '!', '?', '-', '—', '(', ')', '[', ']', '{', '}',
    "'", '"', '...', '“', '”', '‘', '’', '/', '\\', '|', '@', '#', '$', '%',
    '^', '&', '*', '_', '=', '+', '<', '>', '`', '~'
]

        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer(text,key):
      if key=='en':
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
      else:
        return [tok.text.lower() for tok in spacy_ger.tokenizer(text)]

    def build_vocabulary(self, sentence_list, key):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
          if key=='source': lookup=self.tokenizer(sentence,'de')
          else: lookup=self.tokenizer(sentence,'en')

          for word in lookup:
            if word=='\n' or word in self.punctuation_marks: continue

            if word not in frequencies:
              frequencies[word] = 1

            else:
              frequencies[word] += 1

            if frequencies[word] == self.freq_threshold:
              if key=='source':
                self.stoi_source[word] = idx
                self.itos_source[idx] = word
              elif key=='target':
                self.stoi_target[word] = idx
                self.itos_target[idx] = word
              idx += 1

    def numericalize(self, text,key):
        if key=='source':
          tokenized_text = self.tokenizer(text,'de')
          return [self.stoi_source[token] if token in self.stoi_source else self.stoi_source["<UNK>"] for token in tokenized_text]
        elif key=='target':
          tokenized_text = self.tokenizer(text,'en')
          return [self.stoi_target[token] if token in self.stoi_target else self.stoi_target["<UNK>"] for token in tokenized_text]

In [51]:
class seq2seq_dataset(Dataset):
    def __init__(self,dataset_path, num_samples=29000, freq_threshold=3):
        self.df = pd.read_csv(dataset_path,delimiter='\t',names=['English','German']).values
        self.num_samples=num_samples

        self.idx_sentences_source= {}
        self.idx_sentences_target= {}

        self.punctuation_marks = [
    '.', ',', ';', ':', '!', '?', '-', '—', '(', ')', '[', ']', '{', '}',
    "'", '"', '...', '“', '”', '‘', '’', '/', '\\', '|', '@', '#', '$', '%',
    '^', '&', '*', '_', '=', '+', '<', '>', '`', '~'
]

        idx=0
        for row in self.df:
            source_sentence = row[1]
            #if 'Tom' in source_sentence: continue

            target_sentence = row[0]

            if source_sentence[-1] in self.punctuation_marks: source_sentence=source_sentence[:-1]
            if target_sentence[-1] in self.punctuation_marks: target_sentence=target_sentence[:-1]

            self.idx_sentences_source[idx]=source_sentence
            self.idx_sentences_target[idx]=target_sentence

            idx+=1
            if idx==self.num_samples: break


        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.idx_sentences_source.values(),'source')
        self.vocab.build_vocabulary(self.idx_sentences_target.values(),'target')

    def __len__(self):
        return len(self.idx_sentences_source)

    def __getitem__(self, index):
        source_sentence = self.idx_sentences_source[index]
        target_sentence = self.idx_sentences_target[index]

        #numericalised source sentence
        numericalized_caption_source= [self.vocab.stoi_source["<SOS>"]] + self.vocab.numericalize(source_sentence,key='source') + [self.vocab.stoi_source["<EOS>"]]

        #numericalised target sentence
        numericalized_caption_target= [self.vocab.stoi_target["<SOS>"]] + self.vocab.numericalize(target_sentence,key='target') + [self.vocab.stoi_target["<EOS>"]]

        return torch.tensor(numericalized_caption_source), torch.tensor(numericalized_caption_target)

In [16]:
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        sources=[item[0] for item in batch]
        sources=pad_sequence(sources, batch_first=False, padding_value=self.pad_idx)

        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

        return sources, targets

In [17]:
def get_loader(
    dataset_path,num_samples=28000,
    freq_threshold=2,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
):
    dataset = seq2seq_dataset(dataset_path=dataset_path, num_samples=num_samples,freq_threshold=freq_threshold)

    pad_idx = dataset.vocab.stoi_source["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset

####Dataset Testing

In [43]:
train_loader, dataset = get_loader(dataset_path='/content/drive/MyDrive/Seq2SeqModel/dataset.txt',
                                   num_samples=99968,freq_threshold=2,batch_size=64,num_workers=2)

In [44]:
len(dataset)

99968

In [45]:
for batch_idx,(sources,targets) in enumerate(train_loader):
  break

  self.pid = os.fork()
  self.pid = os.fork()


In [46]:
batch_chk_idx=5
for batch_idx,(sources,targets) in enumerate(train_loader):
  if batch_idx==batch_chk_idx:break
print(f'Batch no.: {batch_idx+1}:-')

source_sentence=''
num_words_source=0
chk_idx=8

print(f'Sample {chk_idx+1} of Batch {batch_idx+1}')
print(f'Source Shape: {sources.shape}')
print(f'Target Shape: {targets.shape}')

for source_idx in sources[:,chk_idx]:
  num_words_source+=1
  source_sentence+=dataset.vocab.itos_source[source_idx.item()]+' '

target_sentence=''
num_words_target=0
for target_idx in targets[:,chk_idx]:
  num_words_target+=1
  target_sentence+=dataset.vocab.itos_target[target_idx.item()]+' '
source_sentence=source_sentence[:-1]
target_sentence=target_sentence[:-1]

print(f'Source Sentence: "{source_sentence}" with length: {num_words_source}')
print(f'Target Sentence: "{target_sentence}" with length: {num_words_target}')

Batch no.: 6:-
Sample 9 of Batch 6
Source Shape: torch.Size([11, 64])
Target Shape: torch.Size([11, 64])
Source Sentence: "<SOS> wie alt schätzt ihr sie <EOS> <PAD> <PAD> <PAD> <PAD>" with length: 11
Target Sentence: "<SOS> how old do you think she is <EOS> <PAD> <PAD>" with length: 11


In [None]:
'\n' in dataset.vocab.itos_target

False

###Model

In [37]:
class EncoderRNN(nn.Module):
  def __init__(self,input_size,hidden_size,vocab_size,num_layers=1):
    super(EncoderRNN,self).__init__()
    self.embed=nn.Embedding(num_embeddings=vocab_size,embedding_dim=input_size)
    self.lstm=nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,dropout=0.5)
    self.dropout=nn.Dropout(p=0.5)


  def forward(self,source_sentence):
    #source_sentence dim: (seq_len,batch)

    embeddings=self.dropout(self.embed(source_sentence))
    #embeddings dim: (seq_len,batch,input_size)

    outputs,(hidden_state,cell_state)=self.lstm(embeddings)
    #output dim: (seq_len,batch,hidden_size)
    #hidden_state/cell_state dim: (num_layers,batch,hidden_size)

    return hidden_state,cell_state

In [36]:
class DecoderRNN(nn.Module):
  def __init__(self,input_size,hidden_size,vocab_size,num_layers=1):
    super(DecoderRNN,self).__init__()
    self.embed=nn.Embedding(num_embeddings=vocab_size,embedding_dim=input_size)
    self.lstm=nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,dropout=0.5)
    self.linear=nn.Linear(in_features=hidden_size,out_features=vocab_size)
    self.dropout=nn.Dropout(p=0.5)


  def forward(self,word_idx,hidden_state,cell_state):
    #word_vector dim: (batch,)
    #hidden_state/cell_state dim: (num_layers,batch,hidden_size)

    embedding=self.dropout(self.embed(word_idx.unsqueeze(0)))
    #embeddingdim: (1,batch,input_size)

    output,(hidden_state,cell_state)=self.lstm(embedding,(hidden_state,cell_state))
    #output dim: (1,batch,hidden_size)
    #hidden_state/cell_state dim: (num_layers,batch,hidden_size)

    predicted_word=self.linear(output).squeeze(0)
    #predicted_word dim: (batch,vocab_size)

    return predicted_word,hidden_state,cell_state

In [35]:
class Seq2Seq(nn.Module):
  def __init__(self,input_size,hidden_size,vocab_size_source,vocab_size_target,num_layers=1):
    super(Seq2Seq,self).__init__()
    self.encoder=EncoderRNN(input_size=input_size,hidden_size=hidden_size,vocab_size=vocab_size_source,num_layers=num_layers)
    self.decoder=DecoderRNN(input_size=input_size,hidden_size=hidden_size,vocab_size=vocab_size_target,num_layers=num_layers)
    self.vocab_size_source=vocab_size_source
    self.vocab_size_target=vocab_size_target

  def forward(self,source_sentences,target_sentences,teacher_forcer_ratio=0.5):
    #source/target sentence dim: (seq_len,batch)

    hidden_state,cell_state=self.encoder(source_sentences)
    #hidden_state/cell_state dim: (num_layers,batch,hidden_size)

    target_len,batch_size=target_sentences.shape
    outputs=torch.zeros(size=(target_len,batch_size,self.vocab_size_target))
    input_word=target_sentences[0]
    #input_word dim: (batch, )

    for idx in range(1,target_len):
      prediction,hidden_state,cell_state=self.decoder(input_word,hidden_state,cell_state)
      outputs[idx]=prediction

      if random.random()<teacher_forcer_ratio:
        input_word=target_sentences[idx]
      else:
        input_word=prediction.argmax(dim=1)

    return outputs

  def translate(self,source_sentence,vocab_target,device,max_len=50):
    #source_sentence dim: (seq_len,)

    hidden_state,cell_state=self.encoder(source_sentence.unsqueeze(1).to(device))
    #hidden_state/cell_state dim: (num_layers,batch,hidden_size)

    translated='<SOS> '
    input_word=torch.tensor([source_sentence[0].item(),])
    #input_word dim: (1,)

    for pred in range(max_len):
      prediction,hidden_state,cell_state=self.decoder(input_word.to(device),hidden_state.to(device),cell_state.to(device))
      #prediction dim: (1,vocab_size)

      word_idx=prediction.argmax(dim=1)
      #word_idx dim: (1,)

      translated+=vocab_target[word_idx.item()]+' '

      if vocab_target[word_idx.item()]=='<EOS>': break

      input_word=torch.tensor([word_idx.item(),])

    return translated[:-1]


###Training

Initialisation

In [52]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#hyper-parameters
num_epochs=40
lr=1e-3
batch_size=64
input_size,hidden_size=256,1024
num_samples=99968
num_layers=1
freq_threshold=2
dataset_path='/content/drive/MyDrive/Seq2SeqModel/dataset.txt'

#dataset
'''train_loader, dataset = get_loader(source_file=source_file,
                                   target_file=target_file,
                                   num_samples=num_samples,freq_threshold=freq_threshold,batch_size=batch_size,num_workers=2)'''
train_loader,dataset=get_loader(dataset_path=dataset_path,num_samples=num_samples,freq_threshold=freq_threshold,batch_size=batch_size)

#model
#min_loss=np.Inf
save_path='/content/drive/MyDrive/Seq2SeqModel/model'
model=Seq2Seq(input_size=input_size,hidden_size=hidden_size,vocab_size_source=len(dataset.vocab.itos_source),vocab_size_target=len(dataset.vocab.itos_target),num_layers=num_layers).to(device)
min_loss=torch.load('/content/drive/MyDrive/Seq2SeqModel/model.pth',map_location=device)['loss']
model.load_state_dict(torch.load('/content/drive/MyDrive/Seq2SeqModel/model.pth',map_location=device)['model_state_dict'])
optimizer=torch.optim.Adam(model.parameters(),lr=lr)
criterion=nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi_target['<PAD>'])



In [32]:
min_loss

0.6902420901038735

In [24]:
#validation sentence
chk_idx=25000
source,target=dataset[chk_idx]

validation_sentence_source=''
for word_idx in source:
  validation_sentence_source+=dataset.vocab.itos_source[word_idx.item()]+' '
print(f'Source sentence: {validation_sentence_source[:-1]}')

validation_sentence_target=''
for word_idx in target:
  validation_sentence_target+=dataset.vocab.itos_target[word_idx.item()]+' '
print(f'Target sentence: {validation_sentence_target[:-1]}')

Source sentence: <SOS> bitte setzen sie sich <EOS>
Target sentence: <SOS> please have a seat <EOS>


Model Checkpoint

In [25]:
def save_checkpoint(epoch,model,loss,optimiser,path):
  save_path=path+'.pth'
  torch.save({
      'epoch':epoch,
      'model_state_dict':model.state_dict(),
      'optimizer_state_dict':optimiser.state_dict(),
      'loss':loss
  },save_path)

Training Block

In [None]:
print(f'Actual sentence: {validation_sentence_target[:-1]}')
epoch_losses=[]
model.train()
num_batches=len(train_loader)
for epoch in range(1,num_epochs+1):
  batch_losses=[]
  print(f'Epoch {epoch} begins:-\n')
  for batch_idx,(source_sentences,target_sentences) in enumerate(train_loader):
    inputs=source_sentences.to(device)
    targets=target_sentences.to(device)

    outputs=model(inputs,targets).to(device)


    optimizer.zero_grad()
    loss=criterion(outputs[1:].reshape(-1,outputs.shape[2]),targets[1:].reshape(-1))

    torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=1)
    loss.backward()
    optimizer.step()

    batch_losses.append(loss.item())

    if (batch_idx+1)%100==0:
      print(f'Epoch {epoch}/{num_epochs}, Batch {batch_idx+1}/{num_batches}, Batch Loss: {batch_losses[-1]:.4f}')

      translated_sentence=model.translate(source.to(device),dataset.vocab.itos_target,device)
      print(f'Translated sentence: {translated_sentence}\n')

  epoch_losses.append(np.mean(batch_losses))
  print(f'\nEpoch {epoch}/{num_epochs}, Epoch Loss: {epoch_losses[-1]:.4f}')
  current_epoch_loss=epoch_losses[-1]
  if current_epoch_loss<min_loss:
    print(f'Epoch Loss improved from {min_loss:.4f} to {current_epoch_loss:.4f}')
    min_loss=current_epoch_loss
    save_checkpoint(epoch,model,current_epoch_loss,optimizer,save_path)
    print(f'Improved Model saved at "{save_path}"\n')


  print(f'Epoch {epoch} ends!!\n\n')

Actual sentence: <SOS> please have a seat <EOS>
Epoch 1 begins:-



  self.pid = os.fork()


Epoch 1/40, Batch 100/1562, Batch Loss: 4.8075
Translated sentence: <SOS>i 's n't you <EOS>

Epoch 1/40, Batch 200/1562, Batch Loss: 4.3914
Translated sentence: <SOS>tom is a <EOS>

Epoch 1/40, Batch 300/1562, Batch Loss: 4.1464
Translated sentence: <SOS>you 're <EOS>

Epoch 1/40, Batch 400/1562, Batch Loss: 4.0596
Translated sentence: <SOS>the <UNK> <EOS>

Epoch 1/40, Batch 500/1562, Batch Loss: 4.0774
Translated sentence: <SOS>please me <EOS>

Epoch 1/40, Batch 600/1562, Batch Loss: 3.4407
Translated sentence: <SOS>please <UNK> <EOS>

Epoch 1/40, Batch 700/1562, Batch Loss: 3.3699
Translated sentence: <SOS>please you <EOS>

Epoch 1/40, Batch 800/1562, Batch Loss: 3.5245
Translated sentence: <SOS>please you <EOS>

Epoch 1/40, Batch 900/1562, Batch Loss: 3.1944
Translated sentence: <SOS>please please <EOS>

Epoch 1/40, Batch 1000/1562, Batch Loss: 2.9826
Translated sentence: <SOS>please please <EOS>

Epoch 1/40, Batch 1100/1562, Batch Loss: 2.7033
Translated sentence: <SOS>please pleas

###Testing Arena

In [95]:
print(f'Vocab size of source is: {len(dataset.vocab.itos_source)} and that of target is: {len(dataset.vocab.itos_target)}')

Vocab size of source is: 9918 and that of target is: 6525


In [53]:
num_chk_samples=100
idx=0
for sample in range(1,num_chk_samples+1):
  #idx=random.randint(0,num_samples)
  print(f'Sample no. {sample}, Dataset Index {idx}:-')
  source,target=dataset[idx]
  idx+=1
  validation_sentence_source=''
  for word_idx in source:
    validation_sentence_source+=dataset.vocab.itos_source[word_idx.item()]+' '
  validation_sentence_source=validation_sentence_source[:-1]
  validation_sentence_target=''
  validation_sentence_output=model.translate(source,dataset.vocab.itos_target,device)
  for word in target:
    validation_sentence_target+=dataset.vocab.itos_target[word.item()]+' '
  validation_sentence_target=validation_sentence_target[:-1]

  print(f'Source Sentence: {validation_sentence_source}')
  print(f'Target Sentence: {validation_sentence_target}')
  print(f'Translated Sentence: {validation_sentence_output}\n\n')

Sample no. 1, Dataset Index 0:-
Source Sentence: <SOS> hallo <EOS>
Target Sentence: <SOS> hi <EOS>
Translated Sentence: <SOS> <EOS>


Sample no. 2, Dataset Index 1:-
Source Sentence: <SOS> grüß gott <EOS>
Target Sentence: <SOS> hi <EOS>
Translated Sentence: <SOS> is some <EOS>


Sample no. 3, Dataset Index 2:-
Source Sentence: <SOS> lauf <EOS>
Target Sentence: <SOS> run <EOS>
Translated Sentence: <SOS> <EOS>


Sample no. 4, Dataset Index 3:-
Source Sentence: <SOS> potzdonner <EOS>
Target Sentence: <SOS> wow <EOS>
Translated Sentence: <SOS> <EOS>


Sample no. 5, Dataset Index 4:-
Source Sentence: <SOS> <UNK> <EOS>
Target Sentence: <SOS> wow <EOS>
Translated Sentence: <SOS> <UNK> <EOS>


Sample no. 6, Dataset Index 5:-
Source Sentence: <SOS> feuer <EOS>
Target Sentence: <SOS> fire <EOS>
Translated Sentence: <SOS> fire <EOS>


Sample no. 7, Dataset Index 6:-
Source Sentence: <SOS> hilfe <EOS>
Target Sentence: <SOS> help <EOS>
Translated Sentence: <SOS> <EOS>


Sample no. 8, Dataset Index 

In [None]:
plt.plot([i for i in range(1,len(epoch_losses)+1)],epoch_losses)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss vs Epochs')
plt.show()

In [94]:
source_inp='Wie geht es dir'
source_word_idxs=[dataset.vocab.stoi_source['<SOS>'],]
for word in [tok.text.lower() for tok in spacy_ger.tokenizer(source_inp)]:
  source_word_idxs.append(dataset.vocab.stoi_source[word])
source_word_idxs.append(dataset.vocab.stoi_source['<EOS>'])
source_word_idxs=torch.tensor(source_word_idxs)

translated_sentence=model.translate(source_word_idxs,dataset.vocab.itos_target,device)

target_sentence='<SOS> '+'how do you do'+' <EOS>'
print(f'Target sentence: {target_sentence}')
print(f'Translated sentence: {translated_sentence}')

Target sentence: <SOS> how do you do <EOS>
Translated sentence: <SOS> how are you <EOS>
