In [1]:
import os
import torch 
from io import open
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
from torch import optim,nn
import time
import math
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
from matplotlib import font_manager

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Special tokens
PAD_TOKEN = "<PAD>"
EOS_TOKEN = "<EOS>"
SOS_TOKEN = "<SOS>"
UNK_TOKEN = "<UNK>"

# Update the function to create mappings to include the special tokens
def create_mappings(vocab):
    vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + sorted(vocab)
    word2int = {word: i for i, word in enumerate(vocab)}
    int2word = {i: word for word, i in word2int.items()}
    return word2int, int2word

def wordEncoder(words,encodelist):
    n_letters = len(encodelist)
    tensor = torch.zeros(len(words), n_letters)
    for i,word in enumerate(words):
        tensor[i][encodelist[word]] = 1
    return tensor
    
def tokenise(word, wordMap):
    return torch.tensor([wordMap[SOS_TOKEN]] + [wordMap[letter] for letter in word] + [wordMap[EOS_TOKEN]], dtype=torch.long)

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [4]:
types = {"train":'mr.translit.sampled.train.tsv',"val":'mr.translit.sampled.dev.tsv',"test":"mr.translit.sampled.test.tsv"}
with open(os.path.join("lexicons/",types["train"]), "r", encoding="utf-8") as f:
    lines = f.readlines()
train_data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])
with open(os.path.join("lexicons/",types["val"]), "r", encoding="utf-8") as f:
    lines = f.readlines()
val_data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])
with open(os.path.join("lexicons/",types["test"]), "r", encoding="utf-8") as f:
    lines = f.readlines()
test_data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])
test_data_point = np.array([["अनुज","anuj"],["निर्णयप्रक्रियेत","nirnayaprakriyet"]])
merged_data = np.concatenate((train_data,val_data,test_data))
len(merged_data)  

67643

In [5]:
test_data_point

array([['अनुज', 'anuj'],
       ['निर्णयप्रक्रियेत', 'nirnayaprakriyet']], dtype='<U16')

In [6]:
devnagri2int,latinList2int = {letter: idx for idx, letter in enumerate(set("".join(merged_data[:, 0])))},{letter: idx for idx, letter in enumerate(set("".join(merged_data[:, 1])))}
int2devnagri,int2latinList = {idx: letter for letter, idx in devnagri2int.items()},{idx: letter for letter, idx in latinList2int.items()}

In [7]:
data = np.array([[text.split("\t")[0],text.split("\t")[1][:-1]] for text in lines if not text.split("\t")[0] == '</s>'])

In [8]:
# Update the vocabularies
devnagri2int, int2devnagri = create_mappings(set("".join(merged_data[:, 0])))
latin2int, int2latin = create_mappings(set("".join(merged_data[:, 1])))

In [9]:
class LangDataset(Dataset):
    def __init__(self,type:str):
        types = {"train":train_data,"val":val_data,"test":test_data, "test_ponit":test_data_point}
        data = types[type]
        self.X,self.Y,self.X_encoded,self.Y_encoded = [],[],[],[]
        for word in data:
            self.X.append(word[1])
            self.Y.append(word[0])
            self.X_encoded.append(tokenise(word[1],latin2int))
            self.Y_encoded.append(tokenise(word[0],devnagri2int))
        
    def __getitem__(self, idx):
        latin_word= self.X[idx]
        devnagri_word = self.Y[idx]
        latin_tensor = self.X_encoded[idx]
        devnagri_tensor = self.Y_encoded[idx]

        return latin_word, devnagri_word, latin_tensor, devnagri_tensor

    def __len__(self):
        return len(self.X)

In [None]:
class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, nonlinearity="tanh", dropout_p=0.1, layer="rnn"):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.layer = layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        
        if layer == "rnn":
            self.cell = nn.RNN(embed_size, hidden_size, num_layers, nonlinearity, batch_first=True) 
        elif layer == "gru":   
            self.cell = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        elif layer == "lstm":
            self.cell = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input, input_lengths, hidden=None):
        embedded = self.dropout(self.embedding(input))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths, batch_first=True, enforce_sorted=True
        )
        
        if self.layer == "lstm":
            output, (hidden, cell) = self.cell(packed, hidden)
        else:
            output, hidden = self.cell(packed, hidden)
            cell = None
            
        # Unpack sequence
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        
        return output, hidden, cell

In [12]:
class BeamSearchNode:
    def __init__(self, hidden_state, cell_state, prev_node, token_id, log_prob, length):
        self.hidden = hidden_state
        self.cell = cell_state
        self.prev = prev_node
        self.token = token_id
        self.logp = log_prob
        self.length = length

    def get_score(self, length_normalize=True):
        if length_normalize:
            return self.logp / float(self.length + 1e-6)
        return self.logp

class DecoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, nonlinearity="tanh", layer="rnn"):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.layer = layer
        
        if layer == "rnn":
            self.cell = nn.RNN(embed_size, hidden_size, num_layers, nonlinearity, batch_first=True) 
        elif layer == "gru":   
            self.cell = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        elif layer == "lstm":
            self.cell = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)


    def beam_search(self, encoder_outputs, encoder_hidden, encoder_cell, 
                    beam_width=5, max_length=30):
        # Get device from model parameters
        device = next(self.parameters()).device
        
        batch_size = encoder_outputs.size(0)
        decoded_batch = []

        # Process each example in the batch separately
        for idx in range(batch_size):
            # Ensure tensors are on correct device
            if isinstance(encoder_hidden, tuple):
                enc_hid = (
                    encoder_hidden[0][:, idx:idx+1].contiguous().to(device),
                    encoder_hidden[1][:, idx:idx+1].contiguous().to(device)
                )
            else:
                enc_hid = encoder_hidden[:, idx:idx+1].contiguous().to(device)
            
            enc_cell = encoder_cell[:, idx:idx+1].contiguous().to(device) if encoder_cell is not None else None
            enc_out = encoder_outputs[idx].unsqueeze(0).to(device)

            # Initialize beam
            start_token = devnagri2int[SOS_TOKEN]
            initial_node = BeamSearchNode(enc_hid, enc_cell, None, start_token, 0, 1)
            beams = [initial_node]
            finished = []

            for _ in range(max_length):
                candidates = []
                for node in beams:
                    if node.token == devnagri2int[EOS_TOKEN] and node.prev is not None:
                        finished.append(node)
                        continue

                    # Ensure we're using the correct device
                    if isinstance(node.hidden, tuple):
                        hidden = (node.hidden[0].contiguous().to(device), 
                                 node.hidden[1].contiguous().to(device))
                    else:
                        hidden = node.hidden.contiguous().to(device)
                        
                    cell = node.cell.contiguous().to(device) if node.cell is not None else None

                    # Forward step with device-aware tensors
                    with torch.no_grad():
                        input_tensor = torch.tensor([[node.token]], device=device)
                        embedded = self.embedding(input_tensor)
                        
                        if self.layer == "lstm":
                            output, (new_hidden, new_cell) = self.cell(embedded, (hidden, cell))
                            new_hidden = (new_hidden.contiguous(), new_cell.contiguous())
                        else:
                            output, new_hidden = self.cell(embedded, hidden)
                            new_hidden = new_hidden.contiguous()
                        
                        output = self.fc(output)
                        log_prob = F.log_softmax(output, dim=-1)
                        
                    # Get top candidates
                    log_prob = log_prob.squeeze()
                    top_log_probs, top_indices = log_prob.topk(beam_width * 2)

                    for i in range(top_log_probs.size(0)):
                        token = top_indices[i].item()
                        logp = node.logp + top_log_probs[i].item()

                        new_node = BeamSearchNode(
                            hidden, cell, node, token, logp, node.length + 1
                        )
                        candidates.append(new_node)

                # Sort and prune candidates
                candidates.sort(key=lambda x: x.get_score(), reverse=True)
                beams = candidates[:beam_width]

                if all(node.token == devnagri2int[EOS_TOKEN] for node in beams):
                    finished += beams
                    break

            # Collect finished beams
            finished += beams
            finished.sort(key=lambda x: x.get_score(), reverse=True)

            # Backtrack to get sequence
            best_node = finished[0]
            seq = []
            while best_node.prev is not None:
                seq.append(best_node.token)
                best_node = best_node.prev
            seq = seq[::-1]  # Reverse to get correct order

            # Convert to tensor and add to batch results
            decoded_batch.append(torch.tensor(seq, device=device))

        # Pad sequences and convert to tensor
        decoded_batch = nn.utils.rnn.pad_sequence(decoded_batch, batch_first=True, 
                                                 padding_value=devnagri2int[PAD_TOKEN])
        return decoded_batch

    def forward(self, encoder_outputs, encoder_hidden, encoder_cell, 
                target_tensor=None, MAX_LENGTH=None, beam_width=1):
        batch_size = encoder_outputs.size(0)
        
        # Initialize decoder states
        if self.layer == "lstm":
            decoder_hidden = encoder_hidden
            decoder_cell = encoder_cell
        else:
            decoder_hidden = encoder_hidden
            decoder_cell = None

        if target_tensor is not None or beam_width == 1:
            # Teacher forcing: Process all timesteps at once
            MAX_LENGTH = target_tensor.size(1)
            
            # Create shifted sequences for teacher forcing
            decoder_input = torch.cat([
                torch.full((batch_size, 1), devnagri2int[SOS_TOKEN], device=device),
                target_tensor[:, :-1]
            ], dim=1)
            
            # Process entire sequence
            embedded = self.embedding(decoder_input)
            
            # Run through RNN
            if self.layer == "lstm":
                decoder_outputs, (decoder_hidden, decoder_cell) = self.cell(embedded, (decoder_hidden, decoder_cell))
            else:
                decoder_outputs, decoder_hidden = self.cell(embedded, decoder_hidden)
            
            # Apply output projection
            decoder_outputs = self.fc(decoder_outputs)
            decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
            return decoder_outputs, decoder_hidden, decoder_cell, None     
        else:
            return self.beam_search(encoder_outputs, encoder_hidden, encoder_cell, 
                                  beam_width, MAX_LENGTH or 30), None, None, None

    def forward_step(self, x, hidden):
        out = self.embedding(x)
        
        if self.layer == "lstm":
            if hidden is None:
                h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
                c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
                hidden = (h0, c0)
            output, (hidden, cell) = self.cell(out, hidden)
            output = self.fc(output)
            return output, hidden, cell
        else:
            if hidden is None:
                hidden = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
                output, hidden = self.cell(out, hidden)
            else:
                output, hidden = self.cell(out, hidden[0])
            output = self.fc(output)
            return output, hidden, None

In [None]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, decoder_hidden, encoder_outputs):
        
        query = decoder_hidden[-1].unsqueeze(1)  # (batch, 1, hidden)
        
        # Proper additive attention
        energy = torch.tanh(self.Wa(query) + self.Ua(encoder_outputs))  # (batch, seq_len, hidden)
        scores = self.Va(energy).squeeze(-1)  # (batch, seq_len)
        
        weights = F.softmax(scores, dim=1).unsqueeze(1)  # (batch, 1, seq_len)
        context = torch.bmm(weights, encoder_outputs)  # (batch, 1, hidden)
        
        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, nonlinearity="tanh", layer="lstm"):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.layer = layer
        
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = BahdanauAttention(hidden_size)
        
        # Input size is embed_size + hidden_size (for attention context)
        if layer == "lstm":
            self.rnn = nn.LSTM(embed_size + hidden_size, hidden_size, 
                              num_layers, batch_first=True)
        elif layer == "gru":
            self.rnn = nn.GRU(embed_size + hidden_size, hidden_size,
                             num_layers, batch_first=True)
        else:  # rnn
            self.rnn = nn.RNN(embed_size + hidden_size, hidden_size,
                            num_layers, nonlinearity, batch_first=True)
            
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, encoder_outputs, encoder_hidden, encoder_cell, 
                target_tensor=None, MAX_LENGTH=None):
        batch_size = encoder_outputs.size(0)
        
        # Initialize decoder states
        if self.layer == "lstm":
            decoder_hidden = encoder_hidden
            decoder_cell = encoder_cell
        else:
            decoder_hidden = encoder_hidden
            decoder_cell = None
        
        # Determine sequence length
        if target_tensor is not None:
            MAX_LENGTH = target_tensor.size(1)
        else:
            MAX_LENGTH = MAX_LENGTH or 30
            
        # Create initial input tensor with SOS tokens
        decoder_input = torch.full((batch_size, 1), 
                                 devnagri2int[SOS_TOKEN], 
                                 device=device)
        
        if target_tensor is not None:
            # Teacher forcing: Process all timesteps at once
            # Shift target tensor right by 1 to include SOS token at start
            decoder_input = torch.cat([
                decoder_input,
                target_tensor[:, :-1]
            ], dim=1)
            
            # Process entire sequence
            embedded = self.embedding(decoder_input)  # (batch, seq_len, embed_size)
            
            # Calculate attention for all timesteps
            context_vectors = []
            attentions = []
            
            # Process each timestep (can be parallelized further with einsum)
            for t in range(MAX_LENGTH):
                context, attn = self.attention(decoder_hidden, encoder_outputs)
                context_vectors.append(context)
                attentions.append(attn)
            
            # Stack contexts and attentions
            context_vectors = torch.cat(context_vectors, dim=1)
            attentions = torch.cat(attentions, dim=1)
            
            # Combine embeddings with context vectors
            rnn_input = torch.cat([embedded, context_vectors], dim=2)
            
            # Process entire sequence through RNN
            if self.layer == "lstm":
                outputs, (decoder_hidden, decoder_cell) = self.rnn(rnn_input, (decoder_hidden, decoder_cell))
            else:
                outputs, decoder_hidden = self.rnn(rnn_input, decoder_hidden)
            
            # Apply output projection
            decoder_outputs = self.fc(outputs)
            
        else:
            # Inference mode: Generate one token at a time
            decoder_outputs = []
            attentions = []
            
            for t in range(MAX_LENGTH):
                # Single timestep processing
                decoder_output, decoder_hidden, decoder_cell, attn = self.forward_step(
                    decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
                
                decoder_outputs.append(decoder_output)
                attentions.append(attn)
                
                # Get next input token
                decoder_input = decoder_output.argmax(dim=-1)
                
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
            attentions = torch.cat(attentions, dim=1)
            
        # Apply log softmax
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        
        return decoder_outputs, decoder_hidden, decoder_cell, attentions
    
    def forward_step(self, input, hidden, cell, encoder_outputs):
        embedded = self.embedding(input)  # (batch, 1, embed)
        
        # Calculate attention
        context, attn_weights = self.attention(hidden, encoder_outputs)
        
        # Combine input with context
        rnn_input = torch.cat((embedded, context), dim=2)
        
        # Add dropout
        rnn_input = F.dropout(rnn_input, p=0.3, training=self.training)
        
        # RNN step
        if self.layer == "lstm":
            output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        else:
            output, hidden = self.rnn(rnn_input, hidden)
        
        output = self.fc(output)
        
        return output, hidden, cell, attn_weights


In [35]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, 
                decoder_optimizer):
    total_loss = 0
    for data in dataloader:
        _, _, input_tensor, target_tensor, input_lengths, target_lengths = data
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        # Encoder forward
        encoder_outputs, encoder_hidden, encoder_cell = encoder(input_tensor, input_lengths)
        
        # Decoder forward (use target tensor without last token)
        decoder_outputs, _, _, attention = decoder(
            encoder_outputs, encoder_hidden, encoder_cell,
            target_tensor=target_tensor[:, :-1] if target_tensor is not None else None
        )
        
        # Calculate loss with masking
        loss = masked_cross_entropy(
            decoder_outputs, 
            target_tensor[:, 1:],  # Shift targets
            devnagri2int[PAD_TOKEN]
        )
        
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        
        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def masked_cross_entropy(logits, target, pad_idx):
    # logits: (batch_size, seq_len, vocab_size)
    # target: (batch_size, seq_len)
    mask = (target != pad_idx).float()
    logits_flat = logits.view(-1, logits.size(-1))
    target_flat = target.reshape(-1)
    loss = F.nll_loss(logits_flat, target_flat, reduction='none')
    total_non_pad = mask.sum()
    loss = (loss * mask.view(-1)).sum() / (total_non_pad + 1e-6)
    return loss

def train(train_dataloader, val_dataloader,encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=1, plot_every=100):
    encoder.train() 
    decoder.train()
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            # print(f"Word Validation Accuracy {evaluate_model(encoder,decoder,val_dataloader,int2devnagri,device,False)}")
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

def evaluate_model(encoder, decoder, dataloader, int2devnagri, device, show_confusion=True):
    # Set up plotting with proper font handling
    font_path = 'C:/Users/aksha/Downloads/Noto_Sans_Devanagari/NotoSansDevanagari-VariableFont_wdth,wght.ttf'  # Adjust path if needed
    font_manager.fontManager.addfont(font_path)
    plt.rcParams['font.family'] = 'Noto Sans Devanagari'
    
    encoder.eval()
    decoder.eval()

    correct_words = 0
    total_words = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for _, _, inputs, targets, input_lengths, _ in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            encoder_outputs, encoder_hidden, encoder_cell = encoder(inputs, input_lengths)
            encoder_hidden = tuple(h.to(device) for h in encoder_hidden) if isinstance(encoder_hidden, tuple) \
                            else encoder_hidden.to(device)
            encoder_cell = encoder_cell.to(device) if encoder_cell is not None else None
            decoder_outputs, _, _, _ = decoder(encoder_outputs, encoder_hidden, encoder_cell)

            predicted_indices = decoder_outputs.argmax(dim=-1)

            for pred_seq, true_seq in zip(predicted_indices, targets):
                pred_list = [i.item() for i in pred_seq if i.item() != 0]
                true_list = [i.item() for i in true_seq if i.item() != 0]

                pred_str = ''.join([int2devnagri[i] for i in pred_list])
                true_str = ''.join([int2devnagri[i] for i in true_list])
                if pred_str == true_str:
                    correct_words += 1
                total_words += 1

                min_len = min(len(pred_list), len(true_list))
                y_true.extend(true_list[:min_len])
                y_pred.extend(pred_list[:min_len])

    word_accuracy = correct_words / total_words if total_words > 0 else 0.0
    
    fig = plt.figure(figsize=(18, 8), constrained_layout=True)
    
    # Word Accuracy Plot (Left)
    ax1 = fig.add_subplot(1, 2, 1)
    bars = ax1.bar(['Correct', 'Incorrect'], 
                 [correct_words, total_words - correct_words],
                 color=['#4CAF50', '#F44336'])
    
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:,}\n({height/total_words:.1%})',
                ha='center', va='bottom')
    
    ax1.set_title(f'Word Accuracy: {word_accuracy:.2%}\nTotal Words: {total_words:,}', pad=20)
    ax1.set_ylabel('Count', labelpad=10)
    ax1.grid(axis='y', linestyle='--', alpha=0.7)
    
    # Confusion Matrix (Right)
    if show_confusion and y_true and y_pred:
        ax2 = fig.add_subplot(1, 2, 2)
        
        labels = sorted(list(set(y_true + y_pred)))
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        
        # Plot every nth character to reduce crowding
        step = max(1, len(labels)//20)  # Show ~20 labels max
        display_labels = [int2devnagri[label] if i%step==0 else '' 
                         for i, label in enumerate(labels)]
        
        sns.heatmap(cm_normalized, ax=ax2,
                   cmap='YlOrRd',
                   cbar_kws={'label': 'Accuracy Percentage', 'shrink': 0.7},
                   xticklabels=display_labels,
                   yticklabels=display_labels,
                   annot=False,
                   square=True)
        
        ax2.set_title('Character Prediction Patterns', pad=20)
        ax2.set_xlabel('Predicted Characters', labelpad=10)
        ax2.set_ylabel('True Characters', labelpad=10)
        
        # Rotate labels and adjust spacing
        plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax2.get_yticklabels(), rotation=0)
        
        # Add divider lines
        for _, spine in ax2.spines.items():
            spine.set_visible(True)
            spine.set_color('gray')
    
    plt.show()
    plt.savefig("acc.png")
    return word_accuracy

In [15]:
train_dataset = LangDataset("train")
val_dataset = LangDataset("val")
test_dataset = LangDataset("test")

In [16]:
def collate_fn(batch):
    # Sort by input sequence length (descending)
    batch.sort(key=lambda x: len(x[2]), reverse=True)
    
    latin_words, devnagri_words, latin_tensors, devnagri_tensors = zip(*batch)
    
    # Get sequence lengths
    input_lengths = [len(seq) for seq in latin_tensors]
    target_lengths = [len(seq) for seq in devnagri_tensors]
    
    # Pad sequences
    latin_tensors = nn.utils.rnn.pad_sequence(latin_tensors, batch_first=True, padding_value=latin2int[PAD_TOKEN])
    devnagri_tensors = nn.utils.rnn.pad_sequence(devnagri_tensors, batch_first=True, padding_value=devnagri2int[PAD_TOKEN])
    
    return (latin_words, devnagri_words, 
            latin_tensors.to(device), devnagri_tensors.to(device),
            input_lengths, target_lengths)

# Update DataLoader initialization to use the collate function
train_dataloader = DataLoader(train_dataset, 
                            batch_size=64, 
                            shuffle=True, 
                            collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, 
                          batch_size=64, 
                          shuffle=True, 
                          collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, 
                           batch_size=64, 
                           shuffle=True, 
                           collate_fn=collate_fn)

In [18]:
import gc
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
# Initialize with proper parameters
# Model Architecture
encoder = EncoderRNN(
    vocab_size=len(latin2int),
    embed_size=256,
    hidden_size=512,
    num_layers=2,
    layer="lstm",
    dropout_p=0.3
).to(device)

decoder = AttnDecoderRNN(
    vocab_size=len(devnagri2int),
    embed_size=256,
    hidden_size=512,
    num_layers=2,
    layer="lstm"
).to(device)


# Training Schedule
train(train_dataloader, val_dataloader, encoder, decoder,
      n_epochs=5,  # Increased epochs
      learning_rate=0.001,
      print_every=2,
      plot_every=10)

2m 56s (- 4m 24s) (2 40%) 0.9251
6m 53s (- 1m 43s) (4 80%) 0.3082


In [None]:
train(train_dataloader, val_dataloader,encoder, decoder, n_epochs=5, learning_rate=0.0001)

1m 17s (- 5m 9s) (1 20%) 0.0862
3m 5s (- 4m 37s) (2 40%) 0.0679
4m 51s (- 3m 14s) (3 60%) 0.0637
6m 29s (- 1m 37s) (4 80%) 0.0617
8m 26s (- 0m 0s) (5 100%) 0.0603


  fig, ax = plt.subplots()


In [None]:
evaluate_model(encoder,decoder,val_dataloader,int2devnagri,device,True)

In [24]:
test_data_1point = LangDataset(type="test_ponit")
test_data_point_loader = DataLoader(test_data_1point, 
                           batch_size=64, 
                           shuffle=True, 
                           collate_fn=collate_fn)

In [None]:
with torch.no_grad():
    y_true, y_pred, total_words, correct_words = [], [], 0, 0
    for lw, dw, inputs, targets, input_lengths, _ in test_data_point_loader:
        print("Input:", lw)
        print("Target:", dw)
        
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Encoder forward
        encoder_outputs, encoder_hidden, encoder_cell = encoder(inputs, input_lengths)
        
        # Decoder forward with beam search
        beam_outputs, _, _, _ = decoder(
            encoder_outputs, 
            encoder_hidden, 
            encoder_cell, 
            # beam_width=5
        )
        # Process beam search results
        for batch_idx in range(beam_outputs.size(0)):
            # Get top sequence for this batch item
            top_sequence = beam_outputs[batch_idx]
            
            # If top_sequence contains logits, get predicted indices
            if top_sequence.dim() > 1:
                pred_indices = top_sequence.argmax(dim=-1)
            else:
                pred_indices = top_sequence

            # Convert indices to characters
            pred_chars = []
            for token_id in pred_indices:
                token_id = token_id.item() if isinstance(token_id, torch.Tensor) else token_id
                if token_id == devnagri2int[EOS_TOKEN]:
                    break
                if token_id != devnagri2int[PAD_TOKEN] and token_id != devnagri2int[SOS_TOKEN]:
                    pred_chars.append(int2devnagri[token_id])
            
            # Get true translation
            true_chars = []
            for token_id in targets[batch_idx]:
                if token_id == devnagri2int[EOS_TOKEN]:
                    break
                if token_id != devnagri2int[PAD_TOKEN] and token_id != devnagri2int[SOS_TOKEN]:
                    true_chars.append(int2devnagri[token_id.item()])
            
            # Convert to strings
            pred_str = ''.join(pred_chars)
            true_str = ''.join(true_chars)
            
            print(f"Predicted: {pred_str}")
            print(f"Expected: {true_str}\n")
            
            # Update metrics
            if pred_str == true_str:
                correct_words += 1
            total_words += 1
            
            # For confusion matrix
            y_true.extend(true_chars)
            y_pred.extend(pred_chars[:len(true_chars)])

word_accuracy = correct_words / total_words if total_words > 0 else 0.0
print(f"Word Accuracy: {word_accuracy:.2%}")

Input: ('nirnayaprakriyet', 'anuj')
Target: ('निर्णयप्रक्रियेत', 'अनुज')
Predicted: निरणयर्यप्ती्ाी
Expected: निर्णयप्रक्रियेत

Predicted: अनजज
Expected: अनुज

Word Accuracy: 0.00%
