In [1]:
import requests
import torch
import gzip
import shutil


In [2]:
def set_seed(seed:int=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

In [3]:
base_url = "https://github.com/multi30k/dataset/raw/refs/heads/master/data/task1/raw/"

train_url = ("train.de.gz","train.en.gz")
val_url = ("val.de.gz","val.en.gz",)
test_url = ("test_2016_flickr.de.gz","test_2016_flickr.en.gz",)

In [4]:
def download(file_path,url):
    with open(file_path,"wb") as f:
        r = requests.get(url)
        f.write(r.content)

    return file_path

train_paths = [download(url,base_url+url) for url in train_url]
val_paths = [download(url,base_url+url) for url in val_url]
test_paths = [download(url,base_url+url) for url in test_url]

In [5]:
def extract(in_file,out_file):
    with gzip.open(in_file,"rb") as f_in:
        with open(out_file,"wb") as f_out:
            shutil.copyfileobj(f_in,f_out)

    return out_file

train_paths = [extract(file,file[:-3]) for file in train_paths]
val_paths = [extract(file,file[:-3]) for file in val_paths]
test_paths = [extract(file,file[:-3]) for file in test_paths]

In [6]:
from dataclasses import dataclass
@dataclass
class ModelArgs():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 32
    seq_len = 32
    no_of_hidden_units_gru = 32
    embedding_dim = 16
    h_t_size = 32
    en_vocab_size = None
    de_vocab_size = None

    embedding_dim =16
    no_of_layers=4

In [7]:
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m93.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting de-core-news-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m92.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0

In [8]:
import spacy
from collections import deque,Counter,defaultdict
import io

de_tokenizer = spacy.load("de_core_news_sm")
en_tokenizer = spacy.load("en_core_web_sm")

def tokenize(text,tokenizer):
    tokens = tokenizer(text)
    return [token.text.lower for token in tokens if not token.is_space]

def build_vocab(file_path,tokenizer,min_freq=1,special_tokens=["<bos>","<unk>","<pad>","<eos>"]):
    counter = Counter()
    with io.open(file_path,encoding="utf-8") as f:
        for string_ in f:
            tokens = tokenize(string_,tokenizer)
            counter.update(tokens)

    print("Completed extracting tokens...")
    tokens = [tok for tok,freq in counter.items() if freq>=min_freq]
    vocab = {tok:idx for idx,tok in enumerate(tokens+special_tokens)}

    
    unk_idx = vocab["<unk>"]

    vocab = defaultdict(lambda : unk_idx,vocab)

    return vocab

de_vocab = build_vocab(train_paths[0],de_tokenizer)
en_vocab = build_vocab(train_paths[1],en_tokenizer)

ModelArgs.de_vocab_size = len(de_vocab)+1
ModelArgs.en_vocab_size = len(en_vocab)+1

ModelArgs.de_vocab_size,ModelArgs.en_vocab_size            

Completed extracting tokens...
Completed extracting tokens...


(322646, 297118)

In [9]:
def data_process(file_paths):
    raw_de_iter = iter(io.open(file_paths[0],encoding="utf-8"))
    raw_en_iter = iter(io.open(file_paths[1],encoding="utf-8"))

    en_bos_idx = en_vocab["<bos>"]
    en_eos_idx = en_vocab["<eos>"]

    data = []
    
    for raw_de,raw_en in zip(raw_de_iter,raw_en_iter):
        de_tensor = torch.tensor([de_vocab[token] for token in tokenize(raw_de,de_tokenizer)])
        en_tensor = torch.tensor([en_vocab[token] for token in tokenize(raw_en,en_tokenizer)])

        en_tensor = torch.cat([torch.tensor([en_bos_idx]),en_tensor,torch.tensor([en_eos_idx])])

        de_tensor = torch.flip(de_tensor,dims=[0])

        data.append((de_tensor,en_tensor))

    return data

train_data = data_process(train_paths)
val_data = data_process(val_paths)
test_data = data_process(test_paths)

In [10]:
from torch.utils.data import Dataset,DataLoader
class TranslationDataset(Dataset):
    def __init__(self,data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):
        return self.data[idx]

train_dataset = TranslationDataset(train_data)
val_dataset = TranslationDataset(val_data)
test_dataset = TranslationDataset(test_data)

In [44]:
def collate_fn(batch,seq_len=ModelArgs.seq_len):

    de_batch,en_batch = zip(*batch)
    def pad_or_truncate(sequence,pad_value):
        if len(sequence) > seq_len:
            return sequence[:seq_len]
        else:
            pad_len = seq_len - len(sequence)
            return torch.cat([sequence,torch.full(size=[pad_len],fill_value=pad_value,dtype=sequence.dtype)])
    de_batch = [pad_or_truncate(sample,pad_value=de_vocab["<pad>"]) for sample in de_batch]
    en_batch = [pad_or_truncate(sample,pad_value=en_vocab["<pad>"]) for sample in en_batch]

    de_batch = torch.stack(de_batch)
    en_batch = torch.stack(en_batch)

    return de_batch,en_batch

train_dataloader = DataLoader(dataset=train_dataset,
                             batch_size=ModelArgs.batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=ModelArgs.batch_size,
                            shuffle=True,
                            collate_fn=collate_fn,
                           drop_last=True)
test_dataloader = DataLoader(dataset=test_dataset,
                            batch_size=ModelArgs.batch_size,
                            shuffle=True,
                             collate_fn=collate_fn,
                            drop_last=True)

In [12]:
import torch.nn as nn
class ResetGate(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.sigma_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Sigmoid()
        )

    def forward(self,X_t,h_t):
        # print("Reset")
        # print(f"X_t : {X_t.shape} h_t : {h_t.shape}")
        combined = torch.cat([h_t,X_t],dim=1)
        # print(f"combined : {combined.shape}")
        # print(self.sigma_nn)
        r_t = self.sigma_nn(combined)

        return r_t


In [13]:
class UpdateGate(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.sigma_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Sigmoid()
        )
    def forward(self,X_t,h_t):
        # print("Update")
        # print(f"X_t : {X_t.shape} h_t : {h_t.shape}")
        combined = torch.cat([h_t,X_t],dim=1)
        # print(f"combined : {combined.shape}")
        # print(f"in features : {self.sigma_nn}")
        z_t = self.sigma_nn(combined)

        return z_t

In [14]:
class GRUCell(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.reset_gate = ResetGate(h_t_size=h_t_size,embedding_dim=embedding_dim)
        self.update_gate = UpdateGate(h_t_size=h_t_size,embedding_dim=embedding_dim)
        self.tanh_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Tanh()
        )
    def forward(self,X_t,h_t):

        combined = torch.cat([h_t,X_t],dim=1)
        
        r_t = self.reset_gate(X_t,h_t)

        modulated_hidden_state = h_t * r_t

        modulated_hidden_state_X_t = torch.cat([modulated_hidden_state,X_t],dim=1)

        candidate_hidden_state = self.tanh_nn(modulated_hidden_state_X_t)

        z_t = self.update_gate(X_t,h_t)

        h_t_f = (1-z_t) * h_t + z_t * candidate_hidden_state

        return h_t_f

In [15]:
class GRUModel(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        """
        it take input for a single time step and perform the GRU operation return the hidden state
        """
        super().__init__()
        
        self.gru_cell = GRUCell(h_t_size=h_t_size,embedding_dim=embedding_dim)

    def forward(self,X_i,h_t=None):
        if h_t == None:
            h_t = torch.zeros(size=[ModelArgs.batch_size,ModelArgs.h_t_size],device=ModelArgs.device)

        h_t_f = self.gru_cell(X_i,h_t)

        return h_t

In [16]:
class Embeddings(nn.Module):
    def __init__(self,vocab_size,embedding_dim):
        super().__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embedding_dim)
    def forward(self,X):
        return self.embedding_layer(X)

In [17]:
class Encoder(nn.Module):
    def __init__(self,h_t_size,embedding_dim,no_of_layers,vocab_size):
        super().__init__()
        self.embedding_layer = Embeddings(vocab_size=vocab_size,embedding_dim=embedding_dim)
        self.encoder = nn.ModuleList([GRUModel(h_t_size=h_t_size,embedding_dim=embedding_dim)])
        for layer in range(no_of_layers-1):
            self.encoder.append(GRUModel(h_t_size=h_t_size,embedding_dim=h_t_size))

    def forward(self,X_t):
        hidden_states = []
        s_0 = None
        for timestep in range(ModelArgs.seq_len):
            X_t_i = X_t[:,timestep]
            e_i = self.embedding_layer(X_t_i)
            # print(f"e_i : {e_i.shape}")
            for layer in range(len(self.encoder)):
                if timestep == 0:
                    h_t = self.encoder[layer](e_i)
                else:
                    h_t = self.encoder[layer](e_i,h_t)
                e_i = h_t
            hidden_states.append(h_t)
            s_0 = h_t # for now assume that s_0 (s_i-1 for 1st timestep to calculate attention for decoder ) is the hidden state output of last layer of last dimension
        return torch.stack(hidden_states) , s_0
            

In [33]:
class Decoder(nn.Module):
    def __init__(self,h_t_size,embedding_dim,vocab_size,no_of_layers):
        super().__init__()
        self.embedding_layer = Embeddings(vocab_size=vocab_size,embedding_dim=embedding_dim)
        self.decoder = nn.ModuleList([GRUModel(h_t_size=h_t_size,embedding_dim=embedding_dim)])
        for i in range(no_of_layers-1):
            self.decoder.append(GRUModel(h_t_size=h_t_size,embedding_dim=h_t_size))
        self.attention = BhandanauAttention(h_t_size,attn_dim=h_t_size)
        self.classification_head = nn.Linear(in_features=h_t_size,out_features=vocab_size)

    def forward(self,s_0,h_t_all,X_t=None):
        all_logits = []
        for timestep in range(ModelArgs.seq_len):
            if X_t is not None:
                X_t_i = X_t[:,timestep]
                e_i = self.embedding_layer(X_t_i)
            else:
                if timestep == 0:
                    X_t_i = torch.full(size=[ModelArgs.batch_size],fill_value=en_vocab["<bos>"],device=ModelArgs.device)
                    e_i = self.embedding_layer(X_t_i)
                else:
                    # print(f"logits : {logits_curr_timestep.shape} timestep : {timestep}")
                    preds = torch.softmax(logits_curr_timestep,dim=-1) # on vocab dimension
                    # print(f"preds : {preds.shape}")
                    preds = torch.argmax(preds,dim=1) 
                    # print(f"preds 2: {preds.shape}")""
                    
                    e_i = self.embedding_layer(preds)
                    
            c_t = self.attention(s_0,h_t_all)
            h_t = c_t # already we send h_t (encoder) s_i-1 (previous step decoder output) so no need to again send h_t to layer of decoder
            for layer in range(len(self.decoder)):
                h_t = self.decoder[layer](e_i,h_t)
                e_i = h_t
            logits_curr_timestep = self.classification_head(h_t)
            all_logits.append(logits_curr_timestep)
            
        return torch.stack(all_logits)
        
        

e-tj = v_a.T * tanh( W_a * S_i-1 + U_a * h_j )

In [19]:
class BhandanauAttention(nn.Module):
    def __init__(self,h_t_size,attn_dim):
        super().__init__()
        self.W_a = nn.Linear(in_features=h_t_size,out_features=attn_dim)
        self.U_a = nn.Linear(in_features=h_t_size,out_features=attn_dim)
        self.V_a = nn.Linear(in_features=attn_dim,out_features=1)

    def forward(self,s,h):
        # print(f"h :{h.shape}")
        s_proj = self.W_a(s) 
        h_proj = self.U_a(h)
        # print(f"s_proj : {s_proj.shape}")
        # print(f"h_proj : {h_proj.shape}")
        energy = torch.tanh(s_proj + h_proj)
        # print(f"energy : {energy.shape}")
        e_tj = self.V_a(energy)

        # print(f"e_tj : {e_tj.shape}")
        alpha_tj = torch.softmax(e_tj,dim=1).squeeze()
        # print(f"alpha_tj : {alpha_tj.shape}")
        context = torch.bmm(alpha_tj.unsqueeze(1),h).squeeze()
        # print(f"context : {context.shape}")
        return context

In [20]:
class Seq2Seq(nn.Module):
    def __init__(self,h_t_size,embedding_dim,src_vocab_size,dst_vocab_size,no_of_layers):
        super().__init__()
        self.encoder = Encoder(h_t_size=h_t_size,embedding_dim=embedding_dim,no_of_layers=no_of_layers,vocab_size=src_vocab_size)
        self.decoder = Decoder(h_t_size=h_t_size,embedding_dim=embedding_dim,no_of_layers=no_of_layers,vocab_size=dst_vocab_size)
        

    def forward(self,X,y=None):
        h_t_all,s_0 = self.encoder(X)
        # generate c_t for every time step since s_i changes (previous time step output) 
        if y is not None:
            outputs = self.decoder(s_0=s_0,h_t_all=h_t_all,X_t=y)
        else:
            outputs = self.decoder(s_0=s_0,h_t_all=h_t_all)
            
        return outputs

In [21]:
sample_batch = next(iter(train_dataloader))

In [22]:
sample_de = sample_batch[0]
sample_en = sample_batch[1]

In [23]:
sample_de.shape

torch.Size([32, 32])

In [24]:
sample_en.shape

torch.Size([32, 32])

In [34]:
model = Seq2Seq(h_t_size=ModelArgs.h_t_size,
                embedding_dim=ModelArgs.embedding_dim,
                src_vocab_size=ModelArgs.de_vocab_size,
                dst_vocab_size=ModelArgs.en_vocab_size,
                no_of_layers=ModelArgs.no_of_layers)
model = model.to(ModelArgs.device)

In [38]:
all_logits = model(sample_de.to(ModelArgs.device))

In [39]:
all_logits.shape

torch.Size([32, 32, 297118])

In [40]:
all_logits = model(sample_de.to(ModelArgs.device),sample_en.to(ModelArgs.device))
all_logits.shape

torch.Size([32, 32, 297118])

In [45]:
def train(model,model_name,criterion,optimizer,train_dataloader,val_dataloader,epochs,min_val_loss,device):
    from tqdm import tqdm
    model = model.to(device)
    best_val_loss = float('inf')
    for epoch in range(epochs):
        train_loss,correct,total = 0.0,0,0
        model.train()
        train_progress = tqdm(train_dataloader,desc="Training")
        for de_batch,en_batch in train_progress:
            de_batch = de_batch.to(device)
            en_batch = en_batch.to(device)

            optimizer.zero_grad()

            all_logits = model(de_batch,en_batch)
            all_logits = all_logits.view(-1,ModelArgs.en_vocab_size)
            en_batch = en_batch.view(-1)

            loss = criterion(all_logits,en_batch)
            loss.backward()
            optimizer.step()

            preds = torch.softmax(all_logits,dim=-1)
            preds = torch.argmax(preds,dim=-1)
            
            train_loss += loss.item()
            correct += (preds == en_batch).sum()
            total += en_batch.shape[0]

            train_progress.set_postfix({"loss":f"{loss.item():.2f}"})
        train_loss /= len(train_dataloader)
        train_acc = correct/total
        with torch.inference_mode():
            val_loss,correct,total = 0.0,0,0
            val_progress = tqdm(val_dataloader,desc="Evaluation")
            for de_batch,en_batch in val_progress:
                de_batch = de_batch.to(device)
                en_batch = en_batch.to(device)

                all_logits = model(de_batch,en_batch)
                all_logits = all_logits.view(-1,ModelArgs.en_vocab_size)
                en_batch = en_batch.view(-1)
                loss = criterion(all_logits,en_batch)

                preds = torch.softmax(all_logits,dim=-1)
                preds = torch.argmax(preds,dim=-1)

                val_loss += loss.item()
                correct += (preds==en_batch).sum()
                total += en_batch.shape[0]

            val_loss /= len(val_dataloader)
            val_acc = correct/total

        print(f"Epoch : {epoch+1}/{epochs}\n Train Loss : {train_loss:.5f} Train Acc : {train_acc:.4f} \n Val Loss : {val_loss:.5f} Val Acc : {val_acc:.4f} \n\n")
        if best_val_loss > val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(),model_name)
        if val_loss < min_val_loss:
            print("Model trained successfully...")
            break
        
            

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
train(model=model,
      model_name="Bhandanau_Attention_scrath.pth",
      criterion=criterion,
     optimizer=optimizer,
     train_dataloader=train_dataloader,
     val_dataloader=val_dataloader,
     epochs=500,
     min_val_loss=1e-2,
     device=ModelArgs.device)