In [1]:
import os
import torch 
from io import open
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
from torch import optim,nn
import time
import math
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
from matplotlib import font_manager
import wandb
import gc

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Special tokens
PAD_TOKEN = "<PAD>"
EOS_TOKEN = "<EOS>"
SOS_TOKEN = "<SOS>"
UNK_TOKEN = "<UNK>"

# Update the function to create mappings to include the special tokens
def create_mappings(vocab):
    vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + sorted(vocab)
    word2int = {word: i for i, word in enumerate(vocab)}
    int2word = {i: word for word, i in word2int.items()}
    return word2int, int2word

def wordEncoder(words,encodelist):
    n_letters = len(encodelist)
    tensor = torch.zeros(len(words), n_letters)
    for i,word in enumerate(words):
        tensor[i][encodelist[word]] = 1
    return tensor
    
def tokenise(word, wordMap):
    return torch.tensor([wordMap[SOS_TOKEN]] + [wordMap[letter] for letter in word] + [wordMap[EOS_TOKEN]], dtype=torch.long)

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [4]:
types = {"train":'hi.translit.sampled.train.tsv',"val":'hi.translit.sampled.dev.tsv',"test":"hi.translit.sampled.test.tsv"}
with open(os.path.join("lexicons/",types["train"]), "r", encoding="utf-8") as f:
    lines = f.readlines()
train_data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])
with open(os.path.join("lexicons/",types["val"]), "r", encoding="utf-8") as f:
    lines = f.readlines()
val_data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])
with open(os.path.join("lexicons/",types["test"]), "r", encoding="utf-8") as f:
    lines = f.readlines()
test_data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])
test_data_point = np.array([["अनुज","anuj"],["निर्णयप्रक्रियेत","nirnayaprakriyet"]])
merged_data = np.concatenate((train_data,val_data))
len(merged_data)  

48562

In [5]:
test_data_point

array([['अनुज', 'anuj'],
       ['निर्णयप्रक्रियेत', 'nirnayaprakriyet']], dtype='<U16')

In [6]:
devnagri2int,latinList2int = {letter: idx for idx, letter in enumerate(set("".join(merged_data[:, 0])))},{letter: idx for idx, letter in enumerate(set("".join(merged_data[:, 1])))}
int2devnagri,int2latinList = {idx: letter for letter, idx in devnagri2int.items()},{idx: letter for letter, idx in latinList2int.items()}

In [7]:
data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])

In [8]:
# Update the vocabularies
devnagri2int, int2devnagri = create_mappings(set("".join(merged_data[:, 0])))
latin2int, int2latin = create_mappings(set("".join(merged_data[:, 1])))

In [9]:
class LangDataset(Dataset):
    def __init__(self,type:str):
        types = {"train":train_data,"val":val_data,"test":test_data, "test_ponit":test_data_point}
        data = types[type]
        self.X,self.Y,self.X_encoded,self.Y_encoded = [],[],[],[]
        for word in data:
            self.X.append(word[1])
            self.Y.append(word[0])
            self.X_encoded.append(tokenise(word[1],latin2int))
            self.Y_encoded.append(tokenise(word[0],devnagri2int))
        
    def __getitem__(self, idx):
        latin_word= self.X[idx]
        devnagri_word = self.Y[idx]
        latin_tensor = self.X_encoded[idx]
        devnagri_tensor = self.Y_encoded[idx]

        return latin_word, devnagri_word, latin_tensor, devnagri_tensor

    def __len__(self):
        return len(self.X)

In [10]:
class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, nonlinearity="tanh", dropout_p=0.1, layer="rnn"):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.layer = layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        
        if layer == "rnn":
            self.cell = nn.RNN(embed_size, hidden_size, num_layers, nonlinearity, batch_first=True) 
        elif layer == "gru":   
            self.cell = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        elif layer == "lstm":
            self.cell = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input, input_lengths, hidden=None):
        embedded = self.dropout(self.embedding(input))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths, batch_first=True, enforce_sorted=True
        )
        
        if self.layer == "lstm":
            output, (hidden, cell) = self.cell(packed, hidden)
        else:
            output, hidden = self.cell(packed, hidden)
            cell = None
            
        # Unpack sequence
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        
        return output, hidden, cell

In [11]:
# class BeamSearchNode:
#     def __init__(self, hidden_state, cell_state, prev_node, token_id, log_prob, length):
#         self.hidden = hidden_state
#         self.cell = cell_state
#         self.prev = prev_node
#         self.token = token_id
#         self.logp = log_prob
#         self.length = length

#     def get_score(self, length_normalize=True):
#         if length_normalize:
#             return self.logp / float(self.length + 1e-6)
#         return self.logp

# class DecoderRNN(nn.Module):
#     def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, nonlinearity="tanh", layer="rnn"):
#         super().__init__()
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.embedding = nn.Embedding(vocab_size, embed_size)
#         self.layer = layer
        
#         if layer == "rnn":
#             self.cell = nn.RNN(embed_size, hidden_size, num_layers, nonlinearity, batch_first=True) 
#         elif layer == "gru":   
#             self.cell = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
#         elif layer == "lstm":
#             self.cell = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
#         self.fc = nn.Linear(hidden_size, vocab_size)

#     def forward(self, encoder_outputs, encoder_hidden, encoder_cell, 
#                 target_tensor=None, MAX_LENGTH=None):
#         batch_size = encoder_outputs.size(0)
        
#         # Initialize decoder states
#         if self.layer == "lstm":
#             decoder_hidden = encoder_hidden
#             decoder_cell = encoder_cell
#         else:
#             decoder_hidden = encoder_hidden
#             decoder_cell = None
#         if target_tensor is not None:
#             # Teacher forcing: Process all timesteps at once
#             MAX_LENGTH = target_tensor.size(1)
            
#             # Create shifted sequences for teacher forcing
#             decoder_input = torch.cat([
#                 torch.full((batch_size, 1), devnagri2int[SOS_TOKEN], device=device),
#                 target_tensor[:, :-1]
#             ], dim=1)
            
#             # Process entire sequence
#             embedded = self.embedding(decoder_input)
            
#             # Run through RNN
#             if self.layer == "lstm":
#                 decoder_outputs, (decoder_hidden, decoder_cell) = self.cell(embedded, (decoder_hidden, decoder_cell))
#             else:
#                 decoder_outputs, decoder_hidden = self.cell(embedded, decoder_hidden)
#             # Apply output projection
#             decoder_outputs = self.fc(decoder_outputs)
#             decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
#             return decoder_outputs, decoder_hidden, decoder_cell, None 
#         else:
#             MAX_LENGTH = MAX_LENGTH or 30
#             sos = devnagri2int[SOS_TOKEN]
#             eos = devnagri2int[EOS_TOKEN]
#             # start with all SOS tokens
#             input_token = torch.full((batch_size,1), sos,
#                                      device=device, dtype=torch.long)
#             preds = []
#             hidden, cell = decoder_hidden, decoder_cell

#             for _ in range(MAX_LENGTH):
#                 emb = self.embedding(input_token)    # (B,1,E)
#                 if self.layer == "lstm":
#                     out, (hidden, cell) = self.cell(emb, (hidden, cell))
#                 else:
#                     out, hidden = self.cell(emb, hidden)
#                 logits = self.fc(out.squeeze(1))     # (B, V)
#                 next_tok = logits.argmax(dim=-1, keepdim=True)  # (B,1)
#                 preds.append(next_tok)
#                 input_token = next_tok
#                 # once all batches have produced EOS, we can break early
#                 if (next_tok == eos).all():
#                     break

#             # concatenate predictions into (B, T)
#             predicted_seqs = torch.cat(preds, dim=1)
#             return predicted_seqs, None, None, None    
        
#     def forward_step(self, x, hidden):
#         out = self.embedding(x)
        
#         if self.layer == "lstm":
#             if hidden is None:
#                 h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
#                 c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
#                 hidden = (h0, c0)
#             output, (hidden, cell) = self.cell(out, hidden)
#             output = self.fc(output)
#             return output, hidden, cell
#         else:
#             if hidden is None:
#                 hidden = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
#                 output, hidden = self.cell(out, hidden)
#             else:
#                 output, hidden = self.cell(out, hidden[0])
#             output = self.fc(output)
#             return output, hidden, None

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

class BeamSearchNode:
    def __init__(self, hidden_state, cell_state, prev_node, token_id, log_prob, length):
        self.hidden = hidden_state
        self.cell   = cell_state
        self.prev   = prev_node
        self.token  = token_id
        self.logp   = log_prob
        self.length = length

    def get_score(self, length_normalize=True):
        return (self.logp / (self.length + 1e-6)) if length_normalize else self.logp

    def __lt__(self, other):
        return self.get_score() < other.get_score()


class DecoderRNN(nn.Module):
    def __init__(self,vocab_size,embed_size,hidden_size,num_layers=1,nonlinearity="tanh",layer="rnn",pad_token_id=0):
        super().__init__()
        self.hidden_size  = hidden_size
        self.num_layers   = num_layers
        self.layer        = layer
        self.embedding    = nn.Embedding(vocab_size, embed_size, padding_idx=pad_token_id)
        self.vocab_size   = vocab_size
        self.pad_token_id = pad_token_id

        if layer == "rnn":
            self.cell = nn.RNN(embed_size, hidden_size, num_layers,
                               nonlinearity, batch_first=True)
        elif layer == "gru":
            self.cell = nn.GRU(embed_size, hidden_size, num_layers,
                               batch_first=True)
        elif layer == "lstm":
            self.cell = nn.LSTM(embed_size, hidden_size, num_layers,
                                batch_first=True)

        self.fc = nn.Linear(hidden_size, vocab_size)


    def forward(self,encoder_outputs,encoder_hidden,encoder_cell, target_tensor=None,MAX_LENGTH=None,teacher_forcing_prob=0.5,beam_width=5):
        B = encoder_outputs.size(0)
        device = encoder_outputs.device

        # 1) init states
        if self.layer == "lstm":
            decoder_hidden = encoder_hidden.contiguous()
            decoder_cell   = encoder_cell.contiguous()
        else:
            decoder_hidden = encoder_hidden.contiguous()
            decoder_cell   = None

        # 2) training branch
        if target_tensor is not None:
            T = target_tensor.size(1)
            input_tok = torch.full((B, 1),
                                   devnagri2int[SOS_TOKEN],
                                   dtype=torch.long,
                                   device=device)
            outputs = []
            for t in range(T):
                emb = self.embedding(input_tok)
                if self.layer == "lstm":
                    h = decoder_hidden.contiguous()
                    c = decoder_cell.contiguous()
                    out, (decoder_hidden, decoder_cell) = self.cell(emb, (h, c))
                else:
                    hx = decoder_hidden.contiguous()
                    out, decoder_hidden = self.cell(emb, hx)

                logits = self.fc(out.squeeze(1))
                logp   = F.log_softmax(logits, dim=1)
                outputs.append(logp.unsqueeze(1))

                if random.random() < teacher_forcing_prob:
                    input_tok = target_tensor[:, t].unsqueeze(1)
                else:
                    input_tok = logp.argmax(1).unsqueeze(1)

            return torch.cat(outputs, dim=1), decoder_hidden, decoder_cell, None

        # 3) inference with **batched** beam search
        else:
            K = beam_width
            V = self.vocab_size
            max_len = MAX_LENGTH or 30
            sos = devnagri2int[SOS_TOKEN]
            eos = devnagri2int[EOS_TOKEN]

            # a) expand hidden/cell: (layers, B, H) → (layers, B*K, H)
            if self.layer == "lstm":
                h0 = encoder_hidden.contiguous().unsqueeze(2).repeat(1, 1, K, 1)
                c0 = encoder_cell.contiguous().unsqueeze(2).repeat(1, 1, K, 1)
                hidden = h0.view(self.num_layers, B*K, self.hidden_size)
                cell   = c0.view(self.num_layers, B*K, self.hidden_size)
            else:
                h0 = encoder_hidden.contiguous().unsqueeze(2).repeat(1, 1, K, 1)
                hidden = h0.view(self.num_layers, B*K, self.hidden_size)
                cell   = None

            # b) init scores & sequences
            scores = torch.zeros(B, K, device=device)
            scores[:,1:] = -1e9
            seqs = torch.full((B, K, max_len),
                              self.pad_token_id,
                              dtype=torch.long,
                              device=device)
            seqs[:,:,0] = sos
            input_tok = torch.full((B*K,1), sos, dtype=torch.long, device=device)

            # c) step time
            for t in range(1, max_len):
                emb = self.embedding(input_tok)  # (B*K,1,E)
                if self.layer == "lstm":
                    h_in, c_in = hidden.contiguous(), cell.contiguous()
                    out, (h_out, c_out) = self.cell(emb, (h_in, c_in))
                else:
                    h_in = hidden.contiguous()
                    out, h_out = self.cell(emb, h_in)
                    c_out = None

                logits   = self.fc(out.squeeze(1))            # (B*K, V)
                logp_all = F.log_softmax(logits, dim=-1).view(B, K, V)

                total_scores = scores.unsqueeze(2) + logp_all  # (B, K, V)
                flat = total_scores.view(B, -1)               # (B, K*V)
                top_scores, top_idx = flat.topk(K, dim=-1)    # (B, K)

                beam_idx  = top_idx // V                      # (B, K)
                token_idx = top_idx %  V                      # (B, K)

                # reorder hidden
                h_beams = h_out.view(self.num_layers, B, K, self.hidden_size)
                hidden  = h_beams.gather(
                    2,
                    beam_idx.unsqueeze(0).unsqueeze(-1)
                            .expand(self.num_layers, B, K, self.hidden_size)
                ).view(self.num_layers, B*K, self.hidden_size)

                if self.layer == "lstm":
                    c_beams = c_out.view(self.num_layers, B, K, self.hidden_size)
                    cell    = c_beams.gather(
                        2,
                        beam_idx.unsqueeze(0).unsqueeze(-1)
                                .expand(self.num_layers, B, K, self.hidden_size)
                    ).view(self.num_layers, B*K, self.hidden_size)

                scores = top_scores  # update

                # reorder & append seqs
                seqs = seqs.gather(
                    1,
                    beam_idx.unsqueeze(-1)
                            .expand(B, K, max_len)
                )
                seqs[:,:,t] = token_idx
                input_tok = token_idx.view(B*K,1)

                if (token_idx == eos).all():
                    break

            # d) select best beam
            best = scores.argmax(dim=-1)  # (B,)
            preds = seqs[torch.arange(B, device=device), best]  # (B, max_len)

            return preds, None, None, None


In [13]:
# class BahdanauAttention(nn.Module):
#     def __init__(self, hidden_size):
#         super().__init__()
#         self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
#         self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
#         self.Va = nn.Linear(hidden_size, 1)

#     def forward(self, decoder_hidden, encoder_outputs):
        
#         query = decoder_hidden[-1].unsqueeze(1)  # (batch, 1, hidden)
        
#         # Proper additive attention
#         energy = torch.tanh(self.Wa(query) + self.Ua(encoder_outputs))  # (batch, seq_len, hidden)
#         scores = self.Va(energy).squeeze(-1)  # (batch, seq_len)
        
#         weights = F.softmax(scores, dim=1).unsqueeze(1)  # (batch, 1, seq_len)
#         context = torch.bmm(weights, encoder_outputs)  # (batch, 1, hidden)
        
#         return context, weights

# class AttnDecoderRNN(nn.Module):
#     def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, nonlinearity="tanh", layer="lstm"):
#         super().__init__()
#         self.hidden_size = hidden_size
#         self.embed_size = embed_size
#         self.num_layers = num_layers
#         self.layer = layer
        
#         self.embedding = nn.Embedding(vocab_size, embed_size)
#         self.attention = BahdanauAttention(hidden_size)
        
#         # Input size is embed_size + hidden_size (for attention context)
#         if layer == "lstm":
#             self.rnn = nn.LSTM(embed_size + hidden_size, hidden_size, 
#                               num_layers, batch_first=True)
#         elif layer == "gru":
#             self.rnn = nn.GRU(embed_size + hidden_size, hidden_size,
#                              num_layers, batch_first=True)
#         else:  # rnn
#             self.rnn = nn.RNN(embed_size + hidden_size, hidden_size,
#                             num_layers, nonlinearity, batch_first=True)
            
#         self.fc = nn.Linear(hidden_size, vocab_size)
        
#     def forward(self, encoder_outputs, encoder_hidden, encoder_cell, 
#                 target_tensor=None, MAX_LENGTH=None):
#         batch_size = encoder_outputs.size(0)
        
#         # Initialize decoder states
#         if self.layer == "lstm":
#             decoder_hidden = encoder_hidden
#             decoder_cell = encoder_cell
#         else:
#             decoder_hidden = encoder_hidden
#             decoder_cell = None
        
#         # Determine sequence length
#         if target_tensor is not None:
#             MAX_LENGTH = target_tensor.size(1)
#         else:
#             MAX_LENGTH = MAX_LENGTH or 30
            
#         # Create initial input tensor with SOS tokens
#         decoder_input = torch.full((batch_size, 1), 
#                                  devnagri2int[SOS_TOKEN], 
#                                  device=device)
        
#         if target_tensor is not None:
#             # Teacher forcing: Process all timesteps at once
#             # Shift target tensor right by 1 to include SOS token at start
#             decoder_input = torch.cat([
#                 decoder_input,
#                 target_tensor[:, :-1]
#             ], dim=1)
            
#             # Process entire sequence
#             embedded = self.embedding(decoder_input)  # (batch, seq_len, embed_size)
            
#             # Calculate attention for all timesteps
#             context_vectors = []
#             attentions = []
            
#             # Process each timestep (can be parallelized further with einsum)
#             for t in range(MAX_LENGTH):
#                 context, attn = self.attention(decoder_hidden, encoder_outputs)
#                 context_vectors.append(context)
#                 attentions.append(attn)
            
#             # Stack contexts and attentions
#             context_vectors = torch.cat(context_vectors, dim=1)
#             attentions = torch.cat(attentions, dim=1)
            
#             # Combine embeddings with context vectors
#             rnn_input = torch.cat([embedded, context_vectors], dim=2)
            
#             # Process entire sequence through RNN
#             if self.layer == "lstm":
#                 outputs, (decoder_hidden, decoder_cell) = self.rnn(rnn_input, (decoder_hidden, decoder_cell))
#             else:
#                 outputs, decoder_hidden = self.rnn(rnn_input, decoder_hidden)
            
#             # Apply output projection
#             decoder_outputs = self.fc(outputs)
            
#         else:
#             # Inference mode: Generate one token at a time
#             decoder_outputs = []
#             attentions = []
            
#             for t in range(MAX_LENGTH):
#                 # Single timestep processing
#                 decoder_output, decoder_hidden, decoder_cell, attn = self.forward_step(
#                     decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
                
#                 decoder_outputs.append(decoder_output)
#                 attentions.append(attn)
                
#                 # Get next input token
#                 decoder_input = decoder_output.argmax(dim=-1)
                
#             decoder_outputs = torch.cat(decoder_outputs, dim=1)
#             attentions = torch.cat(attentions, dim=1)
            
#         # Apply log softmax
#         decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        
#         return decoder_outputs, decoder_hidden, decoder_cell, attentions
    
#     def forward_step(self, input, hidden, cell, encoder_outputs):
#         embedded = self.embedding(input)  # (batch, 1, embed)
        
#         # Calculate attention
#         context, attn_weights = self.attention(hidden, encoder_outputs)
        
#         # Combine input with context
#         rnn_input = torch.cat((embedded, context), dim=2)
        
#         # Add dropout
#         rnn_input = F.dropout(rnn_input, p=0.3, training=self.training)
        
#         # RNN step
#         if self.layer == "lstm":
#             output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
#         else:
#             output, hidden = self.rnn(rnn_input, hidden)
        
#         output = self.fc(output)
        
#         return output, hidden, cell, attn_weights


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

class BeamSearchNode:
    def __init__(self, hidden_state, cell_state, prev_node, token_id, log_prob, length):
        self.hidden = hidden_state
        self.cell   = cell_state
        self.prev   = prev_node
        self.token  = token_id
        self.logp   = log_prob
        self.length = length

    def get_score(self, length_normalize=True):
        return (self.logp / (self.length + 1e-6)) if length_normalize else self.logp

    def __lt__(self, other):
        return self.get_score() < other.get_score()


class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, decoder_hidden, encoder_outputs):
        h = decoder_hidden[0] if isinstance(decoder_hidden, tuple) else decoder_hidden
        query = h[-1].unsqueeze(1)  # (B, 1, H)
        energy = torch.tanh(self.Wa(query) + self.Ua(encoder_outputs))  # (B, T, H)
        scores = self.Va(energy).squeeze(-1)  # (B, T)
        weights = F.softmax(scores, dim=1).unsqueeze(1)  # (B,1,T)
        context = torch.bmm(weights, encoder_outputs)  # (B,1,H)
        return context, weights


class AttnDecoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size,
                 num_layers=1, nonlinearity="tanh", layer="lstm", pad_token_id=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.layer = layer
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_token_id)
        self.attention = BahdanauAttention(hidden_size)
        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size
        rnn_input_dim = embed_size + hidden_size
        cell_cls = {
            'lstm': nn.LSTM,
            'gru': nn.GRU,
            'rnn': lambda *args, **kwargs: nn.RNN(*args, nonlinearity=nonlinearity, **kwargs)
        }[layer]
        self.rnn = cell_cls(rnn_input_dim, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, encoder_outputs, encoder_hidden, encoder_cell=None,
                target_tensor=None, MAX_LENGTH=None, teacher_forcing_prob=0.5,
                beam_width=5):
        B, T_enc, _ = encoder_outputs.size()
        device = encoder_outputs.device
        # init decoder state
        if self.layer == 'lstm':
            dec_hidden = encoder_hidden.contiguous()
            dec_cell = encoder_cell.contiguous()
        else:
            dec_hidden = encoder_hidden.contiguous()
            dec_cell = None

        if target_tensor is not None:
            # TRAINING with partial teacher forcing
            T = target_tensor.size(1)
            input_tok = torch.full((B,), devnagri2int[SOS_TOKEN], dtype=torch.long, device=device)
            outputs = []
            for t in range(T):
                emb = self.embedding(input_tok).unsqueeze(1)  # (B,1,E)
                context, _ = self.attention(
                    dec_hidden, encoder_outputs)
                rnn_input = torch.cat([emb, context], dim=2)

                if self.layer == 'lstm':
                    h, c = dec_hidden.contiguous(), dec_cell.contiguous()
                    out, (dec_hidden, dec_cell) = self.rnn(rnn_input, (h, c))
                else:
                    h = dec_hidden.contiguous()
                    out, dec_hidden = self.rnn(rnn_input, h)
                logits = self.fc(out.squeeze(1))
                logp = F.log_softmax(logits, dim=1)
                outputs.append(logp.unsqueeze(1))
                teacher = random.random() < teacher_forcing_prob
                top1 = logp.argmax(1)
                input_tok = target_tensor[:, t] if teacher else top1
            return torch.cat(outputs, dim=1), dec_hidden, dec_cell, None

        # INFERENCE with batched beam search
        else:
            K = beam_width
            max_len = MAX_LENGTH or 30
            sos = devnagri2int[SOS_TOKEN]; eos = devnagri2int[EOS_TOKEN]
            # expand states: (layers,B,H)->(layers,B*K,H)
            def expand(x): return x.unsqueeze(2).repeat(1,1,K,1)
            if self.layer=='lstm':
                h0, c0 = expand(dec_hidden), expand(dec_cell)
                hidden = h0.view(self.num_layers, B*K, self.hidden_size)
                cell = c0.view(self.num_layers, B*K, self.hidden_size)
            else:
                h0 = expand(dec_hidden)
                hidden, cell = h0.view(self.num_layers,B*K,self.hidden_size), None
            # beam data
            scores = torch.zeros(B, K, device=device);
            scores[:,1:] = -1e9
            seqs = torch.full((B,K,max_len), self.pad_token_id, device=device, dtype=torch.long)
            seqs[:,:,0] = sos
            input_tok = torch.full((B*K,), sos, dtype=torch.long, device=device)
            # time loop
            for t in range(1, max_len):
                emb = self.embedding(input_tok).unsqueeze(1)  # (B*K,1,E)
                # attention per beam
                h_layer = hidden.view(self.num_layers,B,K,self.hidden_size)[-1]
                h_flat = h_layer.view(B*K,self.hidden_size).unsqueeze(0)
                enc_flat = encoder_outputs.unsqueeze(1).repeat(1,K,1,1).view(B*K,T_enc,self.hidden_size)
                context, _ = self.attention(h_flat, enc_flat)
                rnn_in = torch.cat([emb, context.view(B*K,1,self.hidden_size)], dim=2)
                # RNN step
                if self.layer=='lstm':
                    out,(h_new,c_new)=self.rnn(rnn_in,(hidden.contiguous(),cell.contiguous()))
                else:
                    out,h_new=self.rnn(rnn_in,hidden.contiguous()); c_new=None
                # scores
                logp = F.log_softmax(self.fc(out.squeeze(1)),dim=1).view(B,K,self.vocab_size)
                total = scores.unsqueeze(2) + logp
                flat = total.view(B,-1)
                top_scores, top_idx = flat.topk(K,dim=-1)
                beam_idx, token_idx = top_idx//self.vocab_size, top_idx%self.vocab_size
                # reorder hidden/cell
                def gather_beams(x):
                    xb = x.view(self.num_layers,B,K,self.hidden_size)
                    return xb.gather(2,beam_idx.unsqueeze(0).unsqueeze(-1)
                                    .expand(self.num_layers,B,K,self.hidden_size))
                hidden = gather_beams(h_new).view(self.num_layers,B*K,self.hidden_size)
                if cell is not None:
                    cell = gather_beams(c_new).view(self.num_layers,B*K,self.hidden_size)
                scores = top_scores
                seqs = seqs.gather(1,beam_idx.unsqueeze(-1).expand(B,K,max_len))
                seqs[:,:,t] = token_idx
                input_tok = token_idx.view(B*K)
                if (token_idx==eos).all(): break
            best = scores.argmax(dim=-1)
            preds = seqs[torch.arange(B,device=device),best]
            return preds, None, None, None

In [15]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer):
    total_loss = 0
    for data in dataloader:
        _, _, input_tensor, target_tensor, input_lengths, target_lengths = data
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        # Encoder forward
        encoder_outputs, encoder_hidden, encoder_cell = encoder(input_tensor, input_lengths)
        # Decoder forward (use target tensor without last token)
        decoder_outputs, _, _, attention = decoder(
            encoder_outputs, encoder_hidden, encoder_cell,
            target_tensor=target_tensor[:, :-1] if target_tensor is not None else None
        )
        
        # Calculate loss with masking
        loss = masked_cross_entropy(
            decoder_outputs, 
            target_tensor[:, 1:],  # Shift targets
            devnagri2int[PAD_TOKEN]
        )
        
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        
        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def masked_cross_entropy(logits, target, pad_idx):
    # logits: (batch_size, seq_len, vocab_size)
    # target: (batch_size, seq_len)
    mask = (target != pad_idx).float()
    logits_flat = logits.view(-1, logits.size(-1))
    target_flat = target.reshape(-1)
    loss = F.nll_loss(logits_flat, target_flat, reduction='none')
    total_non_pad = mask.sum()
    loss = (loss * mask.view(-1)).sum() / (total_non_pad + 1e-6)
    return loss

def train(train_dataloader, val_dataloader,encoder, decoder, n_epochs, teacher_forcing_prob,beam_width,learning_rate=0.001,print_every=1, plot_every=100,iswandb=False):
    encoder.train() 
    decoder.train()
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            # print(f"Word Validation Accuracy {evaluate_model(encoder,decoder,val_dataloader,int2devnagri,device,False)}")
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))
        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
        
        if iswandb:
            wandb.log({"train_loss": print_loss_avg,
                    #    "train_accuracy":evaluate_model(encoder=encoder,decoder=decoder,dataloader=train_dataloader,int2devnagri=int2devnagri,device=device,show_confusion=False,iswandb=iswandb,teacher_forcing_prob=teacher_forcing_prob,beam_width=beam_width),
                       "val_accuracy":evaluate_model(encoder=encoder,decoder=decoder,dataloader=val_dataloader,int2devnagri=int2devnagri,device=device,show_confusion=True,iswandb=iswandb,teacher_forcing_prob=teacher_forcing_prob,beam_width=beam_width), 
                       "epoch": epoch}
            )
    showPlot(plot_losses)

# def evaluate_model(encoder, decoder, dataloader, int2devnagri, device, show_confusion=True):
#     # Set up plotting with proper font handling
#     font_path = 'C:/Users/aksha/Downloads/Noto_Sans_Devanagari/NotoSansDevanagari-VariableFont_wdth,wght.ttf'  # Adjust path if needed
#     font_manager.fontManager.addfont(font_path)
#     plt.rcParams['font.family'] = 'Noto Sans Devanagari'
    
#     encoder.eval()
#     decoder.eval()

#     correct_words = 0
#     total_words = 0
#     y_true = []
#     y_pred = []

#     with torch.no_grad():
#         for _, _, inputs, targets, input_lengths, _ in dataloader:
#             inputs = inputs.to(device)
#             targets = targets.to(device)

#             encoder_outputs, encoder_hidden, encoder_cell = encoder(inputs, input_lengths)
#             encoder_hidden = tuple(h.to(device) for h in encoder_hidden) if isinstance(encoder_hidden, tuple) \
#                             else encoder_hidden.to(device)
#             encoder_cell = encoder_cell.to(device) if encoder_cell is not None else None
#             decoder_outputs, _, _, _ = decoder(encoder_outputs, encoder_hidden, encoder_cell, beam_width=4)

#             predicted_indices = decoder_outputs.argmax(dim=-1)

#             for pred_seq, true_seq in zip(predicted_indices, targets):
#                 pred_list = [i.item() for i in pred_seq if i.item() != 0]
#                 true_list = [i.item() for i in true_seq if i.item() != 0]

#                 pred_str = ''.join([int2devnagri[i] for i in pred_list])
#                 true_str = ''.join([int2devnagri[i] for i in true_list])
#                 if pred_str == true_str:
#                     correct_words += 1
#                 total_words += 1

#                 min_len = min(len(pred_list), len(true_list))
#                 y_true.extend(true_list[:min_len])
#                 y_pred.extend(pred_list[:min_len])

#     word_accuracy = correct_words / total_words if total_words > 0 else 0.0
    
#     fig = plt.figure(figsize=(18, 8), constrained_layout=True)
    
#     # Word Accuracy Plot (Left)
#     ax1 = fig.add_subplot(1, 2, 1)
#     bars = ax1.bar(['Correct', 'Incorrect'], 
#                  [correct_words, total_words - correct_words],
#                  color=['#4CAF50', '#F44336'])
    
#     # Add value labels on bars
#     for bar in bars:
#         height = bar.get_height()
#         ax1.text(bar.get_x() + bar.get_width()/2., height,
#                 f'{height:,}\n({height/total_words:.1%})',
#                 ha='center', va='bottom')
    
#     ax1.set_title(f'Word Accuracy: {word_accuracy:.2%}\nTotal Words: {total_words:,}', pad=20)
#     ax1.set_ylabel('Count', labelpad=10)
#     ax1.grid(axis='y', linestyle='--', alpha=0.7)
    
#     # Confusion Matrix (Right)
#     if show_confusion and y_true and y_pred:
#         ax2 = fig.add_subplot(1, 2, 2)
        
#         labels = sorted(list(set(y_true + y_pred)))
#         cm = confusion_matrix(y_true, y_pred, labels=labels)
#         cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        
#         # Plot every nth character to reduce crowding
#         step = max(1, len(labels)//20)  # Show ~20 labels max
#         display_labels = [int2devnagri[label] if i%step==0 else '' 
#                          for i, label in enumerate(labels)]
        
#         sns.heatmap(cm_normalized, ax=ax2,
#                    cmap='YlOrRd',
#                    cbar_kws={'label': 'Accuracy Percentage', 'shrink': 0.7},
#                    xticklabels=display_labels,
#                    yticklabels=display_labels,
#                    annot=False,
#                    square=True)
        
#         ax2.set_title('Character Prediction Patterns', pad=20)
#         ax2.set_xlabel('Predicted Characters', labelpad=10)
#         ax2.set_ylabel('True Characters', labelpad=10)
        
#         # Rotate labels and adjust spacing
#         plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
#         plt.setp(ax2.get_yticklabels(), rotation=0)
        
#         # Add divider lines
#         for _, spine in ax2.spines.items():
#             spine.set_visible(True)
#             spine.set_color('gray')
    
#     plt.show()
#     plt.savefig("acc.png")
#     return word_accuracy

def evaluate_model(encoder, decoder, dataloader, int2devnagri, device, teacher_forcing_prob,beam_width,show_confusion=True, iswandb=False):
    font_path = 'C:/Users/aksha/Downloads/Noto_Sans_Devanagari/NotoSansDevanagari-VariableFont_wdth,wght.ttf'
    font_manager.fontManager.addfont(font_path)
    plt.rcParams['font.family'] = 'Noto Sans Devanagari'
    
    encoder.eval()
    decoder.eval()

    correct_words = 0
    total_words = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for _, _, inputs, targets, input_lengths, _ in dataloader:
            # _, _, input_tensor, target_tensor, input_lengths, target_lengths = data
            inputs = inputs.to(device)
            targets = targets.to(device)

            encoder_outputs, encoder_hidden, encoder_cell = encoder(inputs, input_lengths)
            # encoder_hidden = tuple(h.to(device) for h in encoder_hidden) if isinstance(encoder_hidden, tuple) \
                            # else encoder_hidden.to(device)
            # encoder_cell = encoder_cell.to(device) if encoder_cell is not None else None
            decoder_outputs, _, _, _ = decoder(encoder_outputs, encoder_hidden, encoder_cell,teacher_forcing_prob=teacher_forcing_prob,beam_width=beam_width)
            # print(decoder_outputs)
            # print(f"decoder_outputs shape: {decoder_outputs.shape}")
            if decoder_outputs.dtype == torch.long:
    # beam search: output is token IDs already
                predicted_indices = decoder_outputs
            else:
                # teacher-forcing: output is (batch, seq_len, vocab_size) logits
                predicted_indices = decoder_outputs.argmax(dim=-1)


            for i, (pred_seq, true_seq) in enumerate(zip(predicted_indices, targets)):
                # print(f"Batch {i}: pred_seq shape: {pred_seq.shape if hasattr(pred_seq, 'shape') else 'scalar'}, value: {pred_seq}")
                if pred_seq.dim() == 0:
                    pred_list = [pred_seq.item()] if pred_seq.item() != devnagri2int[PAD_TOKEN] else []
                else:
                    pred_list = [i.item() for i in pred_seq if i.item() != devnagri2int[PAD_TOKEN]]
                true_list = [i.item() for i in true_seq if i.item() != devnagri2int[PAD_TOKEN]]

                pred_str,true_str = [],[]
                for char in pred_list:
                    if char is devnagri2int[EOS_TOKEN]:
                        break
                    if char is not devnagri2int[SOS_TOKEN]:
                        pred_str.append(int2devnagri[char])
                for char in true_list:
                    if char is devnagri2int[EOS_TOKEN]:
                        break
                    if char is not devnagri2int[SOS_TOKEN]:
                        true_str.append(int2devnagri[char])
                pred_str = ''.join(pred_str)
                true_str = ''.join(true_str)
                # print(pred_str)
                # print(true_str)
                if pred_str == true_str:
                    correct_words += 1
                total_words += 1

                min_len = min(len(pred_list), len(true_list))
                y_true.extend(true_list[:min_len])
                y_pred.extend(pred_list[:min_len])

    word_accuracy = correct_words / total_words if total_words > 0 else 0.0
    
    fig = plt.figure(figsize=(18, 8), constrained_layout=True)
    
    ax1 = fig.add_subplot(1, 2, 1)
    bars = ax1.bar(['Correct', 'Incorrect'], 
                  [correct_words, total_words - correct_words],
                  color=['#4CAF50', '#F44336'])
    
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:,}\n({height/total_words:.1%})',
                ha='center', va='bottom')
    
    ax1.set_title(f'Word Accuracy: {word_accuracy:.2%}\nTotal Words: {total_words:,}', pad=20)
    ax1.set_ylabel('Count', labelpad=10)
    ax1.grid(axis='y', linestyle='--', alpha=0.7)
    
    if show_confusion and y_true and y_pred:
        ax2 = fig.add_subplot(1, 2, 2)
        labels = sorted(list(set(y_true + y_pred)))
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        
        step = max(1, len(labels)//20)
        display_labels = [int2devnagri[label] if i%step==0 else '' 
                         for i, label in enumerate(labels)]
        
        sns.heatmap(cm_normalized, ax=ax2,
                   cmap='YlOrRd',
                   cbar_kws={'label': 'Accuracy Percentage', 'shrink': 0.7},
                   xticklabels=display_labels,
                   yticklabels=display_labels,
                   annot=False,
                   square=True)
        
        ax2.set_title('Character Prediction Patterns', pad=20)
        ax2.set_xlabel('Predicted Characters', labelpad=10)
        ax2.set_ylabel('True Characters', labelpad=10)
        
        plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax2.get_yticklabels(), rotation=0)
        
        for _, spine in ax2.spines.items():
            spine.set_visible(True)
            spine.set_color('gray')
        if iswandb:
            wandb.log({"confusion_matrix": wandb.Image(fig)})
    plt.tight_layout()
    
    plt.show()
    plt.savefig("acc.png")

    encoder.train()
    decoder.train()
    return word_accuracy

In [16]:
train_dataset = LangDataset("train")
val_dataset = LangDataset("val")
test_dataset = LangDataset("test")

In [17]:
def collate_fn(batch):
    # Sort by input sequence length (descending)
    batch.sort(key=lambda x: len(x[2]), reverse=True)
    
    latin_words, devnagri_words, latin_tensors, devnagri_tensors = zip(*batch)
    
    # Get sequence lengths
    input_lengths = [len(seq) for seq in latin_tensors]
    target_lengths = [len(seq) for seq in devnagri_tensors]
    
    # Pad sequences
    latin_tensors = nn.utils.rnn.pad_sequence(latin_tensors, batch_first=True, padding_value=latin2int[PAD_TOKEN])
    devnagri_tensors = nn.utils.rnn.pad_sequence(devnagri_tensors, batch_first=True, padding_value=devnagri2int[PAD_TOKEN])
    
    return (latin_words, devnagri_words, 
            latin_tensors.to(device), devnagri_tensors.to(device),
            input_lengths, target_lengths)

# Update DataLoader initialization to use the collate function
train_dataloader = DataLoader(train_dataset, 
                            batch_size=64, 
                            shuffle=True, 
                            collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, 
                          batch_size=64, 
                          shuffle=True, 
                          collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, 
                           batch_size=64, 
                           shuffle=True, 
                           collate_fn=collate_fn)

In [18]:
# import gc
# torch.cuda.empty_cache()
# torch.cuda.synchronize()
# gc.collect()
# # Initialize with proper parameters
# # Model Architecture
# encoder = EncoderRNN(
#     vocab_size=len(latin2int),
#     embed_size=256,
#     hidden_size=512,
#     num_layers=3,
#     layer="lstm",
#     # dropout_p=0.3
# ).to(device)

# decoder = AttnDecoderRNN(
#     vocab_size=len(devnagri2int),
#     embed_size=256,
#     hidden_size=512,
#     num_layers=3,
#     layer="lstm"
# ).to(device)


# # Training Schedule
# train(train_dataloader, val_dataloader, encoder, decoder,
#       n_epochs=3,  # Increased epochs
#       learning_rate=0.001,
#       print_every=1,
#       plot_every=10)

In [19]:
# train(train_dataloader, val_dataloader,encoder, decoder, n_epochs=2, learning_rate=0.001)

In [20]:
# evaluate_model(encoder,decoder,val_dataloader,int2devnagri,device,True)

In [21]:
test_data_1point = LangDataset(type="test_ponit")
test_data_point_loader = DataLoader(test_data_1point, 
                           batch_size=64, 
                           shuffle=True, 
                           collate_fn=collate_fn)

In [22]:
with torch.no_grad():
    y_true, y_pred, total_words, correct_words = [], [], 0, 0
    for lw, dw, inputs, targets, input_lengths, _ in test_data_point_loader:
        print("Input:", lw)
        print("Target:", dw)
        
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Encoder forward
        encoder_outputs, encoder_hidden, encoder_cell = encoder(inputs, input_lengths)
        # encoder_hidden = tuple(h.to(device) for h in encoder_hidden) if isinstance(encoder_hidden, tuple) \
                            # else encoder_hidden.to(device)
        # Decoder forward with beam search
        # print(type(encoder_outputs),type(encoder_hidden), type(encoder_cell) )
        beam_outputs, _, _, _ = decoder(
            encoder_outputs, 
            encoder_hidden, 
            encoder_cell, 
        )
        # Process beam search results
        for batch_idx in range(beam_outputs.size(0)):
            # Get top sequence for this batch item
            top_sequence = beam_outputs[batch_idx]
            
            # If top_sequence contains logits, get predicted indices
            if top_sequence.dim() > 1:
                pred_indices = top_sequence.argmax(dim=-1)
            else:
                pred_indices = top_sequence

            # Convert indices to characters
            pred_chars = []
            for token_id in pred_indices:
                token_id = token_id.item() if isinstance(token_id, torch.Tensor) else token_id
                if token_id == devnagri2int[EOS_TOKEN]:
                    break
                if token_id != devnagri2int[PAD_TOKEN] and token_id != devnagri2int[SOS_TOKEN]:
                    pred_chars.append(int2devnagri[token_id])
            
            # Get true translation
            true_chars = []
            for token_id in targets[batch_idx]:
                if token_id == devnagri2int[EOS_TOKEN]:
                    break
                if token_id != devnagri2int[PAD_TOKEN] and token_id != devnagri2int[SOS_TOKEN]:
                    true_chars.append(int2devnagri[token_id.item()])
            
            # Convert to strings
            pred_str = ''.join(pred_chars)
            true_str = ''.join(true_chars)
            
            print(f"Predicted: {pred_str}")
            print(f"Expected: {true_str}\n")
            
            # Update metrics
            if pred_str == true_str:
                correct_words += 1
            total_words += 1
            
            # For confusion matrix
            y_true.extend(true_chars)
            y_pred.extend(pred_chars[:len(true_chars)])

word_accuracy = correct_words / total_words if total_words > 0 else 0.0
print(f"Word Accuracy: {word_accuracy:.2%}")

Input: ('nirnayaprakriyet', 'anuj')
Target: ('निर्णयप्रक्रियेत', 'अनुज')


NameError: name 'encoder' is not defined

In [23]:
def sweep_config(best_config=False):
    """Define the configuration for hyperparameter sweep"""
    base_params = {
        "embed_size": {"values": [128, 256, 512]},
        "num_layers": {"values": [2, 3, 4]},
        "layer": {"values": ["lstm", "gru"]},
        "hidden_size": {"values": [128, 256, 512]},
        "batch_size": {"values": [32, 64]},
        "learning_rate": {"values": [1e-4, 1e-3, 5*1e-3]},
        "dropout_p": {"values": [0.1,0.3, 0.4]},
        "activation": {"values": ["tanh"]},
        "teacher_forcing_prob": {"values": [0.8, 0.9,0.99]},
        "beam_width": {"values": [1, 2, 4]},
        "num_epochs": {"values": [6]}
    }

    if not best_config:
        return {
            "method": "bayes",
            "metric": {"name": "val_accuracy", "goal": "maximize"},
            "parameters": base_params
        }
    else:
        # Fix to best-known values
        fixed = {k: {"values": [v["values"][-1]]} for k, v in base_params.items()}
        return {
            "method": "bayes",
            "metric": {"name": "val_accuracy", "goal": "maximize"},
            "parameters": fixed
        }


def wandb_train():
    """Main training function for a wandb run"""
    # Initialize wandb
    run = wandb.init()
    config = run.config
    run.name = f"Layer-{config.layer}-Batch-{config.batch_size}-LR-{config.learning_rate}-Dropout-{config.dropout_p}-Layers-{config.num_layers}-LayerType-{config.layer}-BeamWidth-{config.beam_width}"
    run.save()
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data loaders using config.batch_size
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                              shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                            shuffle=False, collate_fn=collate_fn)

    # Instantiate models
    encoder = EncoderRNN(
        vocab_size=len(latin2int),
        embed_size=config.embed_size,
        hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        layer=config.layer,
        dropout_p=config.dropout_p
    ).to(device)

    decoder = DecoderRNN(
        vocab_size=len(devnagri2int),
        embed_size=config.embed_size,
        hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        layer=config.layer
    ).to(device)

    # Call your training loop
    train(train_loader, val_loader,encoder, decoder, n_epochs=config.num_epochs, learning_rate=config.learning_rate,teacher_forcing_prob=config.teacher_forcing_prob,beam_width=config.beam_width,print_every=1, plot_every=10,iswandb=True)
    # train(train_dataloader, val_dataloader,encoder, decoder, n_epochs=2, learning_rate=0.001)
    # Finish wandb run
    run.finish()


def run_sweep(sweep_id=None, best_config=False):
    """Create or run a wandb sweep."""
    if sweep_id is None:
        sweep_id = wandb.sweep(sweep_config(best_config), project="transliteration-sweep")
    wandb.agent(sweep_id, function=wandb_train, count=1 if best_config else 15)

In [None]:
run_sweep(sweep_id=None, best_config=False)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: kob5w4rc
Sweep URL: https://wandb.ai/me21b172-indian-institute-of-technology-madras/transliteration-sweep/sweeps/kob5w4rc


[34m[1mwandb[0m: Agent Starting Run: 5tk4d5mi with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_width: 1
[34m[1mwandb[0m: 	dropout_p: 0.3
[34m[1mwandb[0m: 	embed_size: 512
[34m[1mwandb[0m: 	hidden_size: 512
[34m[1mwandb[0m: 	layer: gru
[34m[1mwandb[0m: 	learning_rate: 0.005
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 2
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.99
[34m[1mwandb[0m: Currently logged in as: [33mme21b172[0m ([33mme21b172-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




0m 45s (- 3m 46s) (1 16%) 1.7440


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


1m 34s (- 3m 8s) (2 33%) 1.5345


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


2m 24s (- 2m 24s) (3 50%) 1.4999


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 10s (- 1m 35s) (4 66%) 1.4758


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 55s (- 0m 47s) (5 83%) 1.4784


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


4m 44s (- 0m 0s) (6 100%) 1.4601


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


0,1
epoch,▁▂▄▅▇█
train_loss,█▃▂▁▁▁
val_accuracy,▇▂▇▆▁█

0,1
epoch,6.0
train_loss,1.4601
val_accuracy,0.01744


[34m[1mwandb[0m: Agent Starting Run: pjllujdk with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_width: 1
[34m[1mwandb[0m: 	dropout_p: 0.4
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	hidden_size: 128
[34m[1mwandb[0m: 	layer: lstm
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 2
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.9


0m 35s (- 2m 56s) (1 16%) 3.0918


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


1m 34s (- 3m 8s) (2 33%) 2.5131


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


2m 22s (- 2m 22s) (3 50%) 2.2390


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 11s (- 1m 35s) (4 66%) 2.0464


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 59s (- 0m 47s) (5 83%) 1.9091


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


4m 48s (- 0m 0s) (6 100%) 1.8092


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


0,1
epoch,▁▂▄▅▇█
train_loss,█▅▃▂▂▁
val_accuracy,▁▁▁▃▆█

0,1
epoch,6.0
train_loss,1.80918
val_accuracy,0.00665


[34m[1mwandb[0m: Agent Starting Run: 9wdtqj2x with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	beam_width: 2
[34m[1mwandb[0m: 	dropout_p: 0.1
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	hidden_size: 256
[34m[1mwandb[0m: 	layer: lstm
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 4
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.8


1m 48s (- 9m 3s) (1 16%) 2.7693


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 34s (- 7m 8s) (2 33%) 1.8833


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


5m 19s (- 5m 19s) (3 50%) 1.5412


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


7m 12s (- 3m 36s) (4 66%) 1.3553


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


9m 9s (- 1m 49s) (5 83%) 1.2296


  fig = plt.figure(figsize=(18, 8), constrained_layout=True)
  plt.tight_layout()
  plt.tight_layout()
  plt.show()


11m 4s (- 0m 0s) (6 100%) 1.1384


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


0,1
epoch,▁▂▄▅▇█
train_loss,█▄▃▂▁▁
val_accuracy,▁▂▄▆▇█

0,1
epoch,6.0
train_loss,1.13836
val_accuracy,0.06379


[34m[1mwandb[0m: Agent Starting Run: 2bxqtwln with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_width: 1
[34m[1mwandb[0m: 	dropout_p: 0.1
[34m[1mwandb[0m: 	embed_size: 128
[34m[1mwandb[0m: 	hidden_size: 512
[34m[1mwandb[0m: 	layer: gru
[34m[1mwandb[0m: 	learning_rate: 0.005
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 3
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.9


0m 45s (- 3m 48s) (1 16%) 1.9766


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


1m 44s (- 3m 28s) (2 33%) 1.5826


  cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  plt.tight_layout()
  plt.tight_layout()
  plt.show()


2m 42s (- 2m 42s) (3 50%) 1.5519


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 39s (- 1m 49s) (4 66%) 1.5416


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


4m 41s (- 0m 56s) (5 83%) 1.5393


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


7m 14s (- 0m 0s) (6 100%) 1.5393


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


0,1
epoch,▁▂▄▅▇█
train_loss,█▂▁▁▁▁
val_accuracy,▁▆▆▇█▁

0,1
epoch,6.0
train_loss,1.53925
val_accuracy,0.0078


[34m[1mwandb[0m: Agent Starting Run: v9rwdz8k with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_width: 2
[34m[1mwandb[0m: 	dropout_p: 0.1
[34m[1mwandb[0m: 	embed_size: 512
[34m[1mwandb[0m: 	hidden_size: 512
[34m[1mwandb[0m: 	layer: gru
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 4
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.9


1m 39s (- 8m 17s) (1 16%) 1.5053


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


3m 48s (- 7m 37s) (2 33%) 1.0092


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


6m 20s (- 6m 20s) (3 50%) 0.8723


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


9m 5s (- 4m 32s) (4 66%) 0.7864


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


11m 26s (- 2m 17s) (5 83%) 0.7225


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


13m 59s (- 0m 0s) (6 100%) 0.6735


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


0,1
epoch,▁▂▄▅▇█
train_loss,█▄▃▂▁▁
val_accuracy,▁▅▄▇██

0,1
epoch,6.0
train_loss,0.67346
val_accuracy,0.10142


[34m[1mwandb[0m: Agent Starting Run: 98e0g30a with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_width: 1
[34m[1mwandb[0m: 	dropout_p: 0.3
[34m[1mwandb[0m: 	embed_size: 256
[34m[1mwandb[0m: 	hidden_size: 256
[34m[1mwandb[0m: 	layer: gru
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 4
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.8


1m 20s (- 6m 41s) (1 16%) 1.7693


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


2m 47s (- 5m 34s) (2 33%) 1.1071


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


4m 10s (- 4m 10s) (3 50%) 0.9407


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


5m 24s (- 2m 42s) (4 66%) 0.8377


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


6m 42s (- 1m 20s) (5 83%) 0.7553


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


8m 0s (- 0m 0s) (6 100%) 0.6915


  plt.tight_layout()
  plt.tight_layout()
  plt.show()


0,1
epoch,▁▂▄▅▇█
train_loss,█▄▃▂▁▁
val_accuracy,▁▃▄▇▅█

0,1
epoch,6.0
train_loss,0.69147
val_accuracy,0.10968


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: dn2ue5eh with config:
[34m[1mwandb[0m: 	activation: tanh
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_width: 4
[34m[1mwandb[0m: 	dropout_p: 0.3
[34m[1mwandb[0m: 	embed_size: 512
[34m[1mwandb[0m: 	hidden_size: 512
[34m[1mwandb[0m: 	layer: gru
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	num_layers: 4
[34m[1mwandb[0m: 	teacher_forcing_prob: 0.9


1m 29s (- 7m 27s) (1 16%) 1.5030
