In [1]:
import torch
from dataclasses import dataclass
import torch.nn as nn
import gzip
import matplotlib.pyplot as plt
import requests
import shutil
from torch.utils.tensorboard import SummaryWriter

In [2]:
!pip install tensorboard



In [3]:
@dataclass
class ModelArgs():
    batch_size = 32
    seq_len = 32
    h_t_size = 32
    c_t_size = 32
    no_of_hidden_units_lstm = h_t_size
    
    max_lr = 1e-4
    epochs = 5000
    en_vocab_size = None
    de_vocab_size = None
    device = "cuda" if torch.cuda.is_available() else "cpu"


In [4]:
base_url = "https://github.com/multi30k/dataset/raw/refs/heads/master/data/task1/raw/"

train_url = ("train.de.gz","train.en.gz")
val_url = ("val.de.gz","val.en.gz",)
test_url = ("test_2016_flickr.de.gz","test_2016_flickr.en.gz",)

from time import sleep
def download_data(file_name, url, retries=3):
    for attempt in range(retries):
        try:
            with requests.get(url, stream=True, timeout=10) as r:
                r.raise_for_status()
                with open(file_name, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            print(f"Downloaded: {file_name}")
            break
        except (requests.exceptions.RequestException, ConnectionResetError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            sleep(2)  # wait before retrying
            if attempt == retries - 1:
                print(f"Failed to download {file_name} after {retries} attempts.")
    return file_name

train_paths = [download_data(path,base_url+path) for path in train_url]
val_paths = [download_data(path,base_url+path) for path in val_url]
test_paths = [download_data(path,base_url+path) for path in test_url]

Attempt 1 failed: ('Connection aborted.', ConnectionResetError(10054, 'An existing connection was forcibly closed by the remote host', None, 10054, None))
Attempt 2 failed: ('Connection aborted.', ConnectionResetError(10054, 'An existing connection was forcibly closed by the remote host', None, 10054, None))
Downloaded: train.de.gz
Downloaded: train.en.gz
Attempt 1 failed: ('Connection aborted.', ConnectionResetError(10054, 'An existing connection was forcibly closed by the remote host', None, 10054, None))
Downloaded: val.de.gz
Downloaded: val.en.gz
Attempt 1 failed: ('Connection aborted.', ConnectionResetError(10054, 'An existing connection was forcibly closed by the remote host', None, 10054, None))
Downloaded: test_2016_flickr.de.gz
Downloaded: test_2016_flickr.en.gz


In [5]:
def extract_data(in_file,out_file):
    with gzip.open(in_file,"rb") as f_in:
        with open(out_file,"wb") as f_out:
            shutil.copyfileobj(f_in,f_out)
    return out_file
train_paths = [extract_data(file,file[:-3]) for file in train_paths]
val_paths = [extract_data(file,file[:-3]) for file in val_paths]
test_paths = [extract_data(file,file[:-3]) for file in test_paths]

In [None]:
from collections import defaultdict,Counter
import spacy
import io

de_tokenizer = spacy.load("de_core_news_sm")
en_tokenizer = spacy.load("en_core_web_sm")

def tokenize(text,tokenizer):
    tokens = tokenizer(text)
    return [token.text.lower for token in tokens if not token.is_space]

def build_vocab(file_name,tokenizer,min_freq=1,special_tokens=["<bos>","<unk>","<pad>","<eos>"]):
    counter = Counter()
    with io.open(file_name,encoding="utf-8") as f:
        for string_ in f:
            tokens = tokenize(string_,tokenizer)
            counter.update(tokens)
            
    tokens = [token for token,freq in counter.items() if freq >= min_freq]
    vocab = {token:idx for idx,token in enumerate(tokens+special_tokens)}
    unk_idx = vocab["<unk>"]
    vocab = defaultdict(lambda:unk_idx,vocab)
    
    return vocab
    
de_vocab = build_vocab(train_paths[0],tokenizer=de_tokenizer)
en_vocab = build_vocab(train_paths[1],en_tokenizer)

ModelArgs.en_vocab_size = len(en_vocab)+1
ModelArgs.de_vocab_size = len(de_vocab)+1

KeyboardInterrupt: 

: 

In [None]:
def data_process(filenames):
    raw_de_iter = iter(io.open(filenames[0],encoding="utf-8"))
    raw_en_iter = iter(io.open(filenames[1],encoding="utf-8"))
    data = []
    for raw_de,raw_en in zip(raw_de_iter,raw_en_iter):
        de_tensor = torch.tensor([de_vocab[token] for token in tokenize(raw_de,de_tokenizer)])
        en_tensor = torch.tensor([en_vocab[token] for token in tokenize(raw_en,en_tokenizer)])
        
        en_tensor = torch.cat([torch.tensor([en_vocab["<bos>"]]),en_tensor,torch.tensor([en_vocab["<eos>"]])])
        
        de_tensor = torch.flip(de_tensor,dims=[0])
        
        data.append((de_tensor,en_tensor))
    
    return data

train_data = data_process(train_paths)
val_data = data_process(val_paths)
test_data = data_process(test_paths)

In [None]:
from torch.utils.data import Dataset,DataLoader
class TranslationDataset(Dataset):
    def __init__(self,data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):
        return self.data[idx]

train_dataset = TranslationDataset(train_data)
val_dataset = TranslationDataset(val_data)
test_dataset = TranslationDataset(test_data)

In [None]:
def collate_fn(batch,seq_len=ModelArgs.seq_len):
    de_batch,en_batch = zip(*batch)
    
    def pad_or_truncate(sequence,pad_value):
        if len(sequence) >= seq_len:
            return sequence[:seq_len]
        else:
            pad_len = seq_len - len(sequence)
            padding = torch.full([pad_len],fill_value=pad_value,dtype=sequence.dtype)
            # print(f"sequence : {sequence.shape}")
            # print(f"padding : {padding.shape}")
            return torch.cat([sequence,padding])
    de_batch = [pad_or_truncate(sample,pad_value=de_vocab["<pad>"]) for sample in de_batch]
    en_batch = [pad_or_truncate(sample,pad_value=en_vocab["<pad>"]) for sample in en_batch]
    
    de_batch = torch.stack(de_batch)
    en_batch = torch.stack(en_batch)
    
    return de_batch,en_batch

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=ModelArgs.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn,
                              drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=ModelArgs.batch_size,
                            shuffle=True,
                            collate_fn=collate_fn,
                            drop_last=True)
test_dataloader = DataLoader(dataset=test_data,
                             batch_size=ModelArgs.batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             drop_last=True)

In [None]:
sample_batch = next(iter(train_dataloader))
sample_de = sample_batch[0]
sample_en = sample_batch[1]

sample_de.shape,sample_en.shape

(torch.Size([32, 32]), torch.Size([32, 32]))

In [None]:
class ForgetGate(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.sigma_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Sigmoid()
        )
    def forward(self,X_t,h_t):
        # print(f"h_t : {h_t.shape}")
        # print(f"X_t : {X_t.shape}")
        combined = torch.cat([h_t,X_t],dim=1)
        # print(f"combined : {combined.shape} \n required : {self.sigma_nn}")
        f_t = self.sigma_nn(combined)
        
        return f_t

In [None]:
class InputGate(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.sigma_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Sigmoid()
        )
        self.tanh_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Tanh()
        )
    def forward(self,X_t,h_t):
        combined = torch.cat([h_t,X_t],dim=1)
        
        i_t = self.sigma_nn(combined)
        c_t_dash = self.tanh_nn(combined)
        
        candidate_hidden_state = i_t * c_t_dash
        return candidate_hidden_state

In [None]:
class OutputGate(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.sigma_nn = nn.Sequential(
            nn.Linear(in_features=h_t_size+embedding_dim,out_features=h_t_size),
            nn.Sigmoid()
        )
    def forward(self,X_t,h_t):
        combined = torch.cat([h_t,X_t],dim=1)
        o_t = self.sigma_nn(combined)
        return o_t

In [None]:
class LSTMCell(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.forget_gate = ForgetGate(h_t_size=h_t_size,embedding_dim=embedding_dim)
        self.input_gate = InputGate(h_t_size=h_t_size,embedding_dim=embedding_dim)
        self.output_gate = OutputGate(h_t_size=h_t_size,embedding_dim=embedding_dim)
        
    def forward(self,X_t,h_t,c_t):
        f_t = self.forget_gate(X_t,h_t)
        c_t = c_t * f_t
        candidate_hidden_state = self.input_gate(X_t,h_t)
        c_t = c_t + candidate_hidden_state
        o_t = self.output_gate(X_t,h_t)
        h_t_new = torch.tanh(c_t)*o_t
        return h_t_new,c_t

In [None]:
class LSTMModel(nn.Module):
    def __init__(self,h_t_size,embedding_dim):
        super().__init__()
        self.lstm_cell = LSTMCell(h_t_size=h_t_size,embedding_dim=embedding_dim)
    def forward(self,X_t,h_t=None,c_t=None):
        if h_t is None:
            h_t = torch.zeros(size=[ModelArgs.batch_size,ModelArgs.h_t_size],device=ModelArgs.device)
        if c_t is None:
            c_t = torch.zeros(size=[ModelArgs.batch_size,ModelArgs.c_t_size],device=ModelArgs.device)
        
        h_t,c_t = self.lstm_cell(X_t,h_t,c_t) # for a single time a step
        
        return h_t,c_t
        

In [None]:
class EmbeddingTable(nn.Module):
    def __init__(self,vocab_size,embedding_dim):
        super().__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embedding_dim)
    def forward(self,X):
        return self.embedding_layer(X)

In [None]:
class Encoder(nn.Module):
    def __init__(self,h_t_size,embedding_dim,no_of_layers,vocab_size):
        super().__init__()
        self.embedding_layer = EmbeddingTable(vocab_size=vocab_size,embedding_dim=embedding_dim)
        self.encoder = nn.ModuleList([LSTMModel(h_t_size=h_t_size,embedding_dim=embedding_dim)])
        for i in range(no_of_layers-1):
            self.encoder.append(LSTMModel(h_t_size=h_t_size,embedding_dim=h_t_size)) # the input to the upper layers is h_t
    
    def forward(self,X):
        outputs = []
        for timestep in range(ModelArgs.seq_len):
            X_t = X[:,timestep]
            e_i = self.embedding_layer(X_t)
            for layer in range(len(self.encoder)):
                if timestep==0:
                    h_t,c_t = self.encoder[layer](e_i)
                else:
                    h_t,c_t = self.encoder[layer](e_i,h_t,c_t)
                e_i = h_t
            outputs.append(h_t)
        return torch.stack(outputs) , h_t, c_t# src_len * [batch_size,h_t_size] -> [batch_size,src_len,h_t_size]

In [None]:
class LoungAttention(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,encoder_outputs,decoder_output):
        # encoder_outputs -> [batch_size,src_len,h_t_size]
        # decoder_output -> [batch_size,h_t_size]
        decoder_output = decoder_output.unsqueeze(1) # [batch_size,1,h_t_size]
        # print(f"decoder_output : {decoder_output.shape}")
        # print(f"encoder_outputs : {encoder_outputs.shape}")
        scores = torch.bmm(decoder_output,encoder_outputs.transpose(1,2)) # [batch_size,1,src_len]
        attn_weights = torch.softmax(scores,dim=-1) # [batch_size,1,src_len]
        
        context = torch.bmm(attn_weights,encoder_outputs) # [batch_size,1,h_t_size]
        return context.squeeze(),attn_weights.squeeze()

h_t_tilde = tanh(Wc[ct:ht])

In [None]:
class Decoder(nn.Module):
    def __init__(self,h_t_size,embedding_dim,no_of_layers,vocab_size):
        super().__init__()
        self.embedding_layer = EmbeddingTable(vocab_size=vocab_size,embedding_dim=embedding_dim)
        self.decoder = nn.ModuleList([LSTMModel(h_t_size=h_t_size,embedding_dim=embedding_dim+h_t_size)]) # embedding_dim+h_t_size -> for decoder
        for i in range(len(self.decoder)):
            self.decoder.append(LSTMModel(h_t_size=h_t_size,embedding_dim=h_t_size))
        self.attention = LoungAttention()
        self.concat_context = nn.Linear(in_features=h_t_size+h_t_size,out_features=h_t_size)
        self.classification_head = nn.Linear(in_features=h_t_size,out_features=vocab_size)
    
    def forward(self,encoder_outputs,h_t,c_t,X=None):
        outputs = []
        h_t_prev,c_t_prev = h_t,c_t
        h_t_tilde_prev = torch.zeros(size=[ModelArgs.batch_size,ModelArgs.seq_len],device=ModelArgs.device)
        for timestep in range(ModelArgs.seq_len):
            if X is not None:
                X_t = X[:,timestep]
            else:
                if timestep == 0:
                    X_t = torch.full(size=[ModelArgs.batch_size],fill_value=en_vocab["<bos>"],device=ModelArgs.device)
                else:
                    preds = torch.softmax(logits,dim=-1)
                    X_t = torch.argmax(preds,dim=1)
            e_i = self.embedding_layer(X_t)
            input_t = torch.cat([e_i,h_t_tilde_prev],dim=1) # [batch_size,embed_dim+h_t_size]
            for layer in range(len(self.decoder)):
                h_t_prev,c_t_prev = self.decoder[layer](input_t,h_t_prev,c_t_prev)
                input_t = h_t_prev
            context,attn_weights = self.attention(encoder_outputs,h_t_prev) # h_t_prev is the current h_t generated by the decoder which is used by the Loung Attention to compute c_t (context vector)
            concat = torch.cat([h_t_prev,context],dim=1) # h_t_prev is the current h_t generated by the decoder
            h_t_tilde = torch.tanh(self.concat_context(concat))
            logits = self.classification_head(h_t_tilde)
            outputs.append(logits)
            h_t_tilde_prev = h_t_tilde # input feeding approach
        
        return torch.stack(outputs) # [batch_size , seq_len ,vocab_size ]
            

In [None]:
class Seq2Seq(nn.Module):
    def __init__(self,h_t_size,embedding_dim,src_vocab_size,dest_vocab_size,no_of_layers):
        super().__init__()
        self.encoder = Encoder(h_t_size=h_t_size,embedding_dim=embedding_dim,no_of_layers=no_of_layers,vocab_size=src_vocab_size)
        self.decoder = Decoder(h_t_size=h_t_size,embedding_dim=embedding_dim,no_of_layers=no_of_layers,vocab_size=dest_vocab_size)
        
    def forward(self,X,y=None):
        encoder_outputs,h_t,c_t = self.encoder(X)
        logits = self.decoder(encoder_outputs,h_t,c_t,y)
        
        return logits
        

In [None]:
ModelArgs.embedding_dim = 16
ModelArgs.no_of_layers = 4
model = Seq2Seq(h_t_size=ModelArgs.h_t_size,embedding_dim=ModelArgs.embedding_dim,src_vocab_size=ModelArgs.de_vocab_size,dest_vocab_size=ModelArgs.en_vocab_size,no_of_layers=ModelArgs.no_of_layers)
model = model.to(ModelArgs.device)

In [None]:
res = model(sample_de.to(ModelArgs.device))
res.shape

torch.Size([32, 32, 297118])

In [None]:
res = model(sample_de.to(ModelArgs.device),sample_en.to(ModelArgs.device))
res.shape

torch.Size([32, 32, 297118])

In [None]:
def train(model,model_name,criterion,optimizer,train_dataloader,val_dataloader,epochs,min_val_loss,device,writer,lr_scheduler):
    try:
        train_global_step,val_global_step = 1,1
        metrics = {"train_loss":[],"val_loss":[],"train_acc":[],"val_acc":[]}
        from tqdm import tqdm
        best_val_loss = float("inf")
        model = model.to(device)
        for epoch in range(epochs):
            model.train()
            train_loss,correct,total = 0.0,0,0
            train_progress = tqdm(train_dataloader,desc="Training")
            for idx,(de_batch,en_batch) in enumerate(train_progress):
                de_batch = de_batch.to(device)
                en_batch = en_batch.to(device)
                optimizer.zero_grad()
                all_logits = model(de_batch,en_batch)
                # print(all_logits.shape)
                all_logits = all_logits.view(-1,ModelArgs.en_vocab_size)
                en_batch = en_batch.view(-1)
                
                loss = criterion(all_logits,en_batch)
                loss.backward()
                optimizer.step()
                
                
                
                pred_probs = torch.softmax(all_logits,dim=-1)
                preds = torch.argmax(pred_probs,dim=1)
                train_loss += loss.item()
                correct += (preds == en_batch).sum()
                total += en_batch.shape[0]
                
                train_progress.set_postfix({"loss":f"{loss.item():.2f}"})
                
                lr_scheduler.step(train_loss/(idx+1))
                
                metrics["train_loss"].append(train_loss/(idx+1))
                metrics["train_acc"].append(correct/total)
                
                writer.add_scalar("loss/train_iter",train_loss/(idx+1),train_global_step)
                writer.add_scalar("accuracy/train_iter",correct/total,train_global_step)
                train_global_step += 1
                
                
            train_loss /= len(train_dataloader)
            train_acc = correct/total
        
            with torch.inference_mode():
                model.eval()
                val_loss,correct,total = 0.0,0,0
                val_progress = tqdm(val_dataloader,desc="Evaluation")
                for idx,(de_batch,en_batch) in enumerate(val_progress):
                    de_batch = de_batch.to(device)
                    en_batch = en_batch.to(device)
                    
                    all_logits = model(de_batch,en_batch)
                    all_logits = all_logits.view(-1,ModelArgs.en_vocab_size)
                    en_batch = en_batch.view(-1)
                    
                    loss = criterion(all_logits,en_batch)
                    
                    pred_probs = torch.softmax(all_logits,dim=-1)
                    preds = torch.argmax(pred_probs,dim=1)
                    
                    val_loss += loss.item()
                    correct += (preds == en_batch).sum()
                    total += en_batch.shape[0]
                    
                    val_progress.set_postfix({"loss":f"{loss.item():.2f}"})
                    
                    lr_scheduler.step(val_loss/(idx+1))
                    
                    metrics["val_loss"].append(val_loss/(idx+1))
                    metrics["val_acc"].append(correct/total)
                    
                    writer.add_scalar("loss/val_iter",val_loss/(idx+1),val_global_step)
                    writer.add_scalar("accuracy/val_iter",correct/total,val_global_step)
                    val_global_step += 1
                    
                    
                val_loss /= len(val_dataloader)
                val_acc = correct/total
                
            print(f"Epoch : {epoch}/{epochs} \n Train Loss : {train_loss:.4f} , Train Acc : {train_acc:.4f} \n Val Loss : {val_loss:.4f} Val Acc : {val_acc:.4f}")
            
            # metrics["train_loss"].append(train_loss)
            # metrics["val_loss"].append(val_loss)
            # metrics["train_acc"].append(train_acc)
            # metrics["val_acc"].append(val_acc)
            
            writer.add_scalar("loss/train_epoch",train_loss,epoch)
            writer.add_scalar("loss/val_epoch",val_loss,epoch)
            writer.add_scalar("accuracy/train_epoch",train_acc,epoch)
            writer.add_scalar("accuracy/val_epoch",val_acc,epoch)
                    
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(),model_name)
            if val_loss < min_val_loss:
                print("Model trained successfully....")
                break
        return metrics
    except KeyboardInterrupt:
        return metrics

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=ModelArgs.max_lr)
writer = SummaryWriter("runs/loung_attention")
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode="min",factor=0.1,patience=3,verbose=True)
metrics = train(model=model,
                model_name="loung_attention.pth",
                criterion=criterion,
                optimizer=optimizer,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                epochs=ModelArgs.epochs,
                min_val_loss=1e-3,
                device=ModelArgs.device,
                writer=writer,
                lr_scheduler=lr_scheduler)

Training: 100%|██████████| 906/906 [52:02<00:00,  3.45s/it, loss=3.74] 
Evaluation: 100%|██████████| 31/31 [00:22<00:00,  1.35it/s, loss=3.82]


Epoch : 0/5000 
 Train Loss : 7.5311 , Train Acc : 0.4885 
 Val Loss : 3.7572 Val Acc : 0.5227


Training: 100%|██████████| 906/906 [52:20<00:00,  3.47s/it, loss=1.44]
Evaluation: 100%|██████████| 31/31 [00:59<00:00,  1.91s/it, loss=1.52]


Epoch : 1/5000 
 Train Loss : 2.2047 , Train Acc : 0.5281 
 Val Loss : 1.5142 Val Acc : 0.5237


Training: 100%|██████████| 906/906 [53:59<00:00,  3.58s/it, loss=1.30] 
Evaluation: 100%|██████████| 31/31 [00:23<00:00,  1.33it/s, loss=1.33]


Epoch : 2/5000 
 Train Loss : 1.3770 , Train Acc : 0.5281 
 Val Loss : 1.3076 Val Acc : 0.5238


Training:  11%|█         | 100/906 [17:55<2:24:25, 10.75s/it, loss=1.32] 


In [None]:
import torch
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [None]:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode="min",factor=0.1,patience=3,verbose=True)

model.load_state_dict(torch.load("loung_attention.pth",weights_only=True))

metrics = train(model=model,
                model_name="loung_attention.pth",
                criterion=criterion,
                optimizer=optimizer,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                epochs=ModelArgs.epochs,
                min_val_loss=1e-3,
                device=ModelArgs.device,
                writer=writer,
                lr_scheduler=lr_scheduler)


Training:  54%|█████▍    | 488/906 [30:47<26:22,  3.79s/it, loss=1.31] 


In [None]:
model

<All keys matched successfully>

In [None]:
len(metrics["val_loss"])

0