In [None]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import random
import gc
import wandb
import traceback
import os
import io
import csv
import copy
import heapq
from IPython.display import HTML, display
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns

# Set device: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- WANDB SETUP ---
# The API key is a credential and must never be committed with the notebook.
# It is read from the WANDB_API_KEY environment variable (on Kaggle, export it
# from a notebook secret before running this cell). When the variable is not
# set, key=None is passed and wandb.login() uses its own credential lookup.
# NOTE(review): a key was previously hard-coded here; revoke that key in the
# W&B account settings, since it has been exposed in source.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
wandb.login(key=WANDB_API_KEY)

# Special tokens (single characters that frame and pad every sequence)
START_TOKEN = '<'
END_TOKEN = '>'
PAD_TOKEN = '_'

# Data paths for Tamil
DATA_PATHS = {
    "train": "/kaggle/input/aksharantar-sampled/aksharantar_sampled/tam/tam_train.csv",
    "test": "/kaggle/input/aksharantar-sampled/aksharantar_sampled/tam/tam_test.csv",
    "valid": "/kaggle/input/aksharantar-sampled/aksharantar_sampled/tam/tam_valid.csv"
}

In [None]:
class DataProcessor:
    """Build character vocabularies and turn word pairs into index tensors.

    The shared state lives in ``self.data``: the source/target character
    lists, the char<->index maps (special tokens occupy indices 0, 1 and 2),
    the vocabulary sizes and, after ``process_data``, the maximum lengths and
    the encoded training tensors.
    """

    def __init__(self):
        """Initialize data processor with vocabularies holding only the special tokens"""
        self.data = {
            "source_chars": [START_TOKEN, END_TOKEN, PAD_TOKEN],
            "target_chars": [START_TOKEN, END_TOKEN, PAD_TOKEN],
            "source_char_index": {START_TOKEN: 0, END_TOKEN: 1, PAD_TOKEN: 2},
            "source_index_char": {0: START_TOKEN, 1: END_TOKEN, 2: PAD_TOKEN},
            "target_char_index": {START_TOKEN: 0, END_TOKEN: 1, PAD_TOKEN: 2},
            "target_index_char": {0: START_TOKEN, 1: END_TOKEN, 2: PAD_TOKEN},
            "source_len": 3,
            "target_len": 3
        }

    def load_data(self, data_paths):
        """Load every split from its header-less CSV file.

        Args:
            data_paths: mapping of split name -> CSV path.

        Returns:
            ``(data_frames, data_pairs)`` where ``data_pairs[split]`` is a
            tuple of NumPy arrays ``(column 0, column 1)``.
        """
        data_frames = {}
        data_pairs = {}

        for split, path in data_paths.items():
            # Read everything as text and disable NaN detection: by default
            # pandas converts words such as "null", "nan" or "NA" to float NaN,
            # which would later break len() and string concatenation.
            df = pd.read_csv(path, header=None, dtype=str, keep_default_na=False)
            data_frames[split] = df
            data_pairs[split] = (df[0].to_numpy(), df[1].to_numpy())
            print(f"Loaded {split} data: {len(df)} examples")

        return data_frames, data_pairs

    def add_padding(self, sequences, max_length):
        """Frame each sequence with START/END tokens and pad it to ``max_length``.

        Sequences that are too long are truncated *before* the END token is
        appended, so an over-long word still ends with END_TOKEN instead of
        being cut off mid-word without a terminator.
        """
        # Room left for real characters once START and END are accounted for.
        body_length = max(max_length - 2, 0)

        padded_strings = []
        for seq in sequences:
            framed = START_TOKEN + seq[:body_length] + END_TOKEN
            # Only matters for max_length < 2, where even the frame does not fit.
            framed = framed[:max_length]
            padded_strings.append(framed + PAD_TOKEN * (max_length - len(framed)))
        return padded_strings

    def chars_to_indices(self, string, char_index_dict):
        """Convert a string to a 1-D long tensor of character indices.

        Characters missing from ``char_index_dict`` are mapped to the PAD index.
        """
        pad_index = char_index_dict[PAD_TOKEN]
        char_indices = [char_index_dict.get(char, pad_index) for char in string]
        return torch.tensor(char_indices, dtype=torch.long, device=device)

    def generate_sequence_from_string(self, strings, char_index_dict):
        """Convert strings to one batch-first tensor of indices."""
        sequences = [self.chars_to_indices(string, char_index_dict) for string in strings]

        # Pad sequences to the same length
        return pad_sequence(sequences, batch_first=True, padding_value=char_index_dict[PAD_TOKEN])

    def _register_chars(self, strings, side):
        """Add every unseen character of ``strings`` to the ``side`` vocabulary.

        ``side`` is ``"source"`` or ``"target"``; new characters receive the
        next free index in order of first appearance.
        """
        chars = self.data[f"{side}_chars"]
        char_index = self.data[f"{side}_char_index"]
        index_char = self.data[f"{side}_index_char"]

        for string in strings:
            for c in string:
                if c not in char_index:
                    idx = len(chars)
                    chars.append(c)
                    char_index[c] = idx
                    index_char[idx] = c

    def update_char_dictionaries(self, padded_source, padded_target):
        """Update character dictionaries with new characters.

        The two sides are scanned independently, so the lists do not need to
        have the same number of entries.
        """
        self._register_chars(padded_source, "source")
        self._register_chars(padded_target, "target")

    def process_data(self, source_data, target_data):
        """Process the training pairs and populate ``self.data``.

        Computes the maximum lengths, builds both vocabularies and stores the
        encoded tensors under ``source_data_seq`` / ``target_data_seq``.

        Raises:
            ValueError: if either side contains no examples.
        """
        if len(source_data) == 0 or len(target_data) == 0:
            raise ValueError("Cannot process empty source/target data")

        # Store original data
        self.data["source_data"] = source_data
        self.data["target_data"] = target_data

        # Calculate max lengths
        self.data["INPUT_MAX_LENGTH"] = max(len(s) for s in source_data) + 2  # +2 for START and END tokens
        self.data["OUTPUT_MAX_LENGTH"] = max(len(t) for t in target_data) + 2

        print(f"Input max length: {self.data['INPUT_MAX_LENGTH']}")
        print(f"Output max length: {self.data['OUTPUT_MAX_LENGTH']}")

        # Add padding
        padded_source = self.add_padding(source_data, self.data["INPUT_MAX_LENGTH"])
        padded_target = self.add_padding(target_data, self.data["OUTPUT_MAX_LENGTH"])

        # Update character dictionaries
        self.update_char_dictionaries(padded_source, padded_target)

        # Generate sequences
        self.data["source_data_seq"] = self.generate_sequence_from_string(padded_source, self.data["source_char_index"])
        self.data["target_data_seq"] = self.generate_sequence_from_string(padded_target, self.data["target_char_index"])

        # Update lengths
        self.data["source_len"] = len(self.data["source_chars"])
        self.data["target_len"] = len(self.data["target_chars"])

        print(f"Source vocabulary size: {self.data['source_len']}")
        print(f"Target vocabulary size: {self.data['target_len']}")

        return self.data

    def process_validation(self, val_source, val_target):
        """Encode held-out pairs with the existing character maps.

        Requires ``process_data`` to have run, because the maximum lengths and
        vocabularies are read from ``self.data``. The vocabularies are not
        extended; unknown characters fall back to the PAD index.
        """
        # Add padding
        padded_val_source = self.add_padding(val_source, self.data["INPUT_MAX_LENGTH"])
        padded_val_target = self.add_padding(val_target, self.data["OUTPUT_MAX_LENGTH"])

        # Generate sequences
        val_source_seq = self.generate_sequence_from_string(padded_val_source, self.data["source_char_index"])
        val_target_seq = self.generate_sequence_from_string(padded_val_target, self.data["target_char_index"])

        return val_source_seq, val_target_seq

    def indices_to_string(self, indices, index_char_dict):
        """Convert indices to a string with all special tokens removed.

        Indices that are not in ``index_char_dict`` are skipped.
        """
        chars = []
        for idx in indices:
            if isinstance(idx, torch.Tensor):
                idx = idx.item()
            if idx in index_char_dict:
                char = index_char_dict[idx]
                if char != PAD_TOKEN:
                    chars.append(char)
        return "".join(chars).replace(START_TOKEN, "").replace(END_TOKEN, "")

    def prepare_input_for_prediction(self, input_string):
        """Prepare a single input string for prediction.

        Returns:
            ``(input_tensor, padded_input)``: the encoded input with a leading
            batch dimension of 1, and the framed/padded string it was built
            from.
        """
        # Reuse the training-time framing so both paths stay in sync.
        padded_input = self.add_padding([input_string], self.data["INPUT_MAX_LENGTH"])[0]

        # Convert to indices and add the batch dimension
        input_tensor = self.chars_to_indices(padded_input, self.data["source_char_index"]).unsqueeze(0)

        return input_tensor, padded_input

class TransliterationDataset(Dataset):
    """Pairs each encoded source sequence with its encoded target sequence"""

    def __init__(self, source_seq, target_seq):
        """Keep references to the two parallel, index-aligned collections"""
        self.source_seq, self.target_seq = source_seq, target_seq

    def __len__(self):
        """Number of examples, taken from the source side"""
        return len(self.source_seq)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` pair stored at position ``idx``"""
        source_item = self.source_seq[idx]
        target_item = self.target_seq[idx]
        return source_item, target_item

class DataManager:
    """Facade that loads the CSV splits and serves them as DataLoaders"""

    def __init__(self, data_paths):
        """Remember the split paths and create the underlying DataProcessor"""
        self.data_paths = data_paths
        self.processor = DataProcessor()

    def load_all_data(self):
        """Read every split, then fit the processor on the training pairs.

        Stores the raw train/valid/test arrays on the instance and returns the
        processor's data dictionary (also kept as ``self.data``).
        """
        _, pairs_by_split = self.processor.load_data(self.data_paths)

        (self.train_source, self.train_target) = pairs_by_split["train"]
        (self.val_source, self.val_target) = pairs_by_split["valid"]
        (self.test_source, self.test_target) = pairs_by_split["test"]

        # Vocabularies and max lengths come from the training split only
        self.data = self.processor.process_data(self.train_source, self.train_target)
        return self.data

    def create_dataloaders(self, h_params):
        """Return ``(train_loader, val_loader)`` using ``h_params["batch_size"]``.

        Only the training loader shuffles its examples.
        """
        train_dataset = TransliterationDataset(
            self.data["source_data_seq"], self.data["target_data_seq"]
        )

        # Validation pairs are encoded with the maps fitted on the training data
        encoded_val = self.processor.process_validation(self.val_source, self.val_target)
        val_dataset = TransliterationDataset(encoded_val[0], encoded_val[1])

        # Settings shared by both loaders
        loader_kwargs = {
            "batch_size": h_params["batch_size"],
            "num_workers": 0,
            "pin_memory": False,
        }
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

        return train_loader, val_loader

    def create_test_dataloader(self, h_params):
        """Return a non-shuffling DataLoader over the test split"""
        encoded_test = self.processor.process_validation(self.test_source, self.test_target)
        test_dataset = TransliterationDataset(encoded_test[0], encoded_test[1])

        return DataLoader(
            test_dataset,
            batch_size=h_params["batch_size"],
            shuffle=False,
            num_workers=0,
            pin_memory=False,
        )

# Class for beam search decoding
class BeamSearchNode:
    """One hypothesis step in beam search.

    Nodes form a singly linked list through ``previous_node``; the root node
    has ``previous_node=None``.

    Attributes:
        hidden_state: decoder hidden state after producing this token.
        previous_node: parent node, or None for the root.
        word_id: token id chosen at this step.
        log_prob: cumulative log-probability of the path ending here.
        length: number of nodes on the path, the root included.
        attn_weights: attention weights recorded for this step, if any.
    """
    def __init__(self, hidden_state, previous_node, word_id, log_prob, length, attn_weights=None):
        self.hidden_state = hidden_state
        self.previous_node = previous_node
        self.word_id = word_id
        self.log_prob = log_prob
        self.length = length
        self.attn_weights = attn_weights

    def eval(self, alpha=1.0):
        """Return the length-normalized score of this node.

        The cumulative log-probability is divided by ``(length - 1) ** alpha``
        so longer sequences are not penalized merely for having more terms;
        the 1e-6 keeps the root node (length 1) from dividing by zero.
        """
        return self.log_prob / float(self.length - 1 + 1e-6) ** alpha

    def __lt__(self, other):
        """Order nodes by score.

        Needed because nodes are stored in heaps as ``(score, node)`` tuples:
        when two scores are equal, tuple comparison falls through to the nodes
        themselves and would raise TypeError without this method.
        """
        return self.eval() < other.eval()

def get_chars_from_nodes(nodes, target_index_char):
    """Decode the path ending at the last node into a string.

    Starting from ``nodes[-1]`` the ``previous_node`` links are followed back
    to the root (the node whose ``previous_node`` is None, which is not
    decoded itself).

    Args:
        nodes: sequence of node objects; only the last one is used as the
            starting point of the walk.
        target_index_char: mapping from integer token id to character.

    Returns:
        ``(text, attn_weights_list)``: the decoded string without special
        tokens, and the attention weights of the emitted characters in
        left-to-right order.
    """
    if len(nodes) == 0:
        return '', []

    chars = []
    attn_weights_list = []

    # Traverse backwards through the linked list of nodes
    current = nodes[-1]
    while current.previous_node is not None:
        # word_id may be a one-element tensor; a tensor is hashed by identity
        # and can never be found as a dictionary key, so normalize to int.
        token_id = int(current.word_id)
        if token_id not in (0, 1, 2):  # Exclude special tokens (START, END, PAD)
            chars.append(target_index_char[token_id])
            if current.attn_weights is not None:
                attn_weights_list.append(current.attn_weights)
        current = current.previous_node

    # Reverse the lists since we traversed backwards
    chars.reverse()
    attn_weights_list.reverse()

    return ''.join(chars), attn_weights_list

In [None]:
class RNNType:
    """Names of the supported recurrent cell types (plain string constants)"""
    RNN = "RNN"
    LSTM = "LSTM"
    GRU = "GRU"

    @staticmethod
    def get_cell(cell_type):
        """Return the ``torch.nn`` recurrent class registered under ``cell_type``.

        Raises:
            ValueError: if ``cell_type`` is not one of the names above.
        """
        known_cells = (
            (RNNType.RNN, nn.RNN),
            (RNNType.LSTM, nn.LSTM),
            (RNNType.GRU, nn.GRU),
        )
        for name, module_class in known_cells:
            if cell_type == name:
                return module_class
        raise ValueError(f"Unsupported cell type: {cell_type}")

class Attention(nn.Module):
    """Additive attention: score = Va(tanh(Wa(decoder state) + Ua(encoder outputs)))"""

    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)  # projects the decoder state
        self.Ua = nn.Linear(hidden_size, hidden_size)  # projects the encoder outputs
        self.Va = nn.Linear(hidden_size, 1)            # reduces each position to one score

    def forward(self, decoder_hidden, encoder_outputs):
        """Compute the context vector and attention weights.

        ``decoder_hidden`` may be an ``(h, c)`` tuple (only ``h`` is used) or a
        tensor of rank 3, 2 or 1; it is reshaped to ``(batch, 1, hidden)``.
        A rank-2 ``encoder_outputs`` gets a leading batch dimension.

        Returns:
            ``(context, attn_weights)`` with shapes ``(batch, hidden)`` and
            ``(batch, seq_len, 1)``.
        """
        # For an LSTM state tuple only the hidden part takes part in attention
        query = decoder_hidden[0] if isinstance(decoder_hidden, tuple) else decoder_hidden

        # Bring the query to (batch, 1, hidden)
        rank = query.dim()
        if rank == 3:
            query = query[-1].unsqueeze(1)  # last layer of (num_layers, batch, hidden)
        elif rank == 2:
            query = query.unsqueeze(1)
        elif rank == 1:
            query = query.unsqueeze(0).unsqueeze(0)

        # Bring the encoder outputs to (batch, seq_len, hidden)
        keys = encoder_outputs.unsqueeze(0) if encoder_outputs.dim() == 2 else encoder_outputs

        # Reconcile the batch dimensions when they disagree
        query_batch, key_batch = query.size(0), keys.size(0)
        if query_batch != key_batch:
            if query_batch == 1:
                query = query.expand(key_batch, -1, -1)
            elif key_batch == 1:
                keys = keys.expand(query_batch, -1, -1)
            else:
                # Neither side can be broadcast: fall back to the first example only
                query, keys = query[:1], keys[:1]

        # One unnormalized score per encoder position: (batch, seq_len, 1)
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))

        # Normalize across the sequence dimension
        attn_weights = F.softmax(scores, dim=1)

        # Weighted sum of encoder outputs: (batch, 1, hidden) -> (batch, hidden)
        context = torch.bmm(attn_weights.transpose(1, 2), keys).squeeze(1)

        return context, attn_weights

class Encoder(nn.Module):
    """Character-level encoder: embedding -> dropout -> recurrent layers"""

    def __init__(self, h_params, data):
        super().__init__()
        self.h_params = h_params

        # Embedding layer over the source vocabulary
        source_vocab_size = data["source_len"]
        embed_dim = h_params["char_embed_dim"]
        self.embedding = nn.Embedding(source_vocab_size, embed_dim)

        # Recurrent layer; inter-layer dropout is only enabled for stacked RNNs
        rnn_class = RNNType.get_cell(h_params["cell_type"])
        self.rnn_cell = rnn_class(
            input_size=embed_dim,
            hidden_size=h_params["hidden_size"],
            num_layers=h_params["num_layers"],
            dropout=h_params["dropout"] if h_params["num_layers"] > 1 else 0,
            batch_first=True,
        )

        # Dropout applied to the embeddings
        self.dropout = nn.Dropout(h_params["dropout"])

    def forward(self, x, hidden=None):
        """Embed ``x``, apply dropout and run the recurrent layers.

        Returns the ``(output, hidden)`` pair produced by the recurrent module.
        """
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        output, hidden = self.rnn_cell(embedded, hidden)
        return output, hidden

    def init_hidden(self, batch_size):
        """Create an all-zero initial state on the module-level ``device``.

        Returns an ``(h, c)`` tuple for LSTM cells and a single tensor
        otherwise; each tensor is ``(num_layers, batch_size, hidden_size)``.
        """
        state_shape = (
            self.h_params["num_layers"],
            batch_size,
            self.h_params["hidden_size"],
        )
        h = torch.zeros(*state_shape, device=device)

        if self.h_params["cell_type"] != RNNType.LSTM:
            return h

        # LSTM additionally needs a cell state of the same shape
        c = torch.zeros(*state_shape, device=device)
        return (h, c)

class Decoder(nn.Module):
    """Single-step decoder that attends over the encoder outputs"""

    def __init__(self, h_params, data):
        super().__init__()
        self.h_params = h_params
        self.data = data

        target_vocab_size = data["target_len"]
        embed_dim = h_params["char_embed_dim"]
        hidden_dim = h_params["hidden_size"]

        # Embedding layer over the target vocabulary
        self.embedding = nn.Embedding(target_vocab_size, embed_dim)

        # Attention mechanism
        self.attention = Attention(hidden_dim)

        # Recurrent layer; its input is the embedding joined with the context vector
        rnn_class = RNNType.get_cell(h_params["cell_type"])
        self.rnn_cell = rnn_class(
            input_size=embed_dim + hidden_dim,
            hidden_size=hidden_dim,
            num_layers=h_params["num_layers"],
            dropout=h_params["dropout"] if h_params["num_layers"] > 1 else 0,
            batch_first=True,
        )

        # Projection from hidden state to vocabulary scores
        self.fc_out = nn.Linear(hidden_dim, target_vocab_size)

        # Dropout applied to the embeddings
        self.dropout = nn.Dropout(h_params["dropout"])

    def forward(self, x, hidden, encoder_outputs):
        """Run one decoding step.

        ``x`` may be a scalar, ``(batch,)`` or ``(batch, seq_len)`` tensor of
        token ids; lower ranks are expanded to two dimensions first.

        Returns:
            ``(log_probs, hidden, attn_weights)`` where ``log_probs`` is the
            log-softmax over the vocabulary along the last dimension.
        """
        # Normalize the token input to two dimensions
        input_rank = x.dim()
        if input_rank == 0:
            x = x.unsqueeze(0).unsqueeze(0)  # scalar -> (1, 1)
        elif input_rank == 1:
            x = x.unsqueeze(1)  # (batch,) -> (batch, 1)

        token_embedding = self.embedding(x)
        token_embedding = self.dropout(token_embedding)  # (batch, seq_len, embed_dim)

        # Add the batch dimension to unbatched encoder outputs
        if encoder_outputs.dim() == 2:
            encoder_outputs = encoder_outputs.unsqueeze(0)

        # Attend over the encoder outputs using the current hidden state
        context, attn_weights = self.attention(hidden, encoder_outputs)
        context = context.unsqueeze(1)  # (batch, 1, hidden_size)

        # Feed the embedding and the context vector to the RNN together
        rnn_input = torch.cat([token_embedding, context], dim=2)
        rnn_output, hidden = self.rnn_cell(rnn_input, hidden)

        # Vocabulary scores turned into log-probabilities
        logits = self.fc_out(rnn_output)
        return F.log_softmax(logits, dim=2), hidden, attn_weights

class Seq2SeqAttentionModel:
    """Wrapper for encoder-decoder architecture with attention.

    Owns an ``Encoder`` and a ``Decoder`` (both moved to the module-level
    ``device``) and provides the training forward pass plus greedy and
    beam-search decoding for single inputs.
    """
    def __init__(self, h_params, data):
        self.encoder = Encoder(h_params, data).to(device)
        self.decoder = Decoder(h_params, data).to(device)
        self.h_params = h_params
        self.data = data

    def forward(self, source, target=None, teacher_forcing_ratio=0.5):
        """Forward pass through the model.

        Args:
            source: batch of encoded source sequences; the batch size is read
                from its first dimension.
            target: batch of encoded target sequences, or None to run only the
                encoder.
            teacher_forcing_ratio: probability that this whole call feeds the
                ground-truth characters to the decoder.

        Returns:
            ``(encoder_outputs, encoder_hidden)`` when ``target`` is None,
            otherwise ``(decoder_outputs, attention_weights)`` stacked along
            dimension 1 over ``OUTPUT_MAX_LENGTH`` decoding steps.
        """
        batch_size = source.size(0)

        # Initialize encoder hidden state
        encoder_hidden = self.encoder.init_hidden(batch_size)

        # Encoder forward pass
        encoder_outputs, encoder_hidden = self.encoder(source, encoder_hidden)

        # Return early if no target (for inference)
        if target is None:
            return encoder_outputs, encoder_hidden

        # Prepare decoder input and hidden state
        decoder_input = torch.tensor([self.data["target_char_index"][START_TOKEN]] * batch_size, device=device)
        decoder_hidden = encoder_hidden

        # Teacher forcing is decided once per call, not per time step
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        # Storage for decoder outputs and attention weights
        decoder_outputs = []
        attention_weights = []

        # Decoder forward pass (one character at a time)
        for t in range(self.data["OUTPUT_MAX_LENGTH"]):
            decoder_output, decoder_hidden, attn_weights = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output.squeeze(1))  # Remove sequence dimension
            attention_weights.append(attn_weights)

            # Determine next input (either from target or model's prediction)
            if use_teacher_forcing and t < self.data["OUTPUT_MAX_LENGTH"] - 1:
                decoder_input = target[:, t + 1]  # Next input is next character in target
            else:
                # Get predicted character
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # Next input is model's prediction

        # Stack decoder outputs
        decoder_outputs = torch.stack(decoder_outputs, dim=1)
        attention_weights = torch.stack(attention_weights, dim=1) if attention_weights else None

        return decoder_outputs, attention_weights

    def greedy_decode(self, input_tensor, processor):
        """Perform greedy decoding with attention for a single input.

        Args:
            input_tensor: encoded input with a batch dimension of 1.
            processor: unused; kept so existing callers keep working.

        Returns:
            ``(pred_str, attn_weights_array)`` where the array has one row per
            emitted character (empty array when nothing was emitted).
        """
        end_index = self.data["target_char_index"][END_TOKEN]

        output_tokens = []
        attn_weights_list = []

        # Decoding never needs gradients; this also keeps .numpy() below valid
        # when the parameters require grad.
        with torch.no_grad():
            batch_size = input_tensor.size(0)
            encoder_hidden = self.encoder.init_hidden(batch_size)
            encoder_outputs, encoder_hidden = self.encoder(input_tensor, encoder_hidden)

            decoder_input = torch.tensor([self.data["target_char_index"][START_TOKEN]], device=device)
            decoder_hidden = encoder_hidden

            for _ in range(self.data["OUTPUT_MAX_LENGTH"]):
                decoder_output, decoder_hidden, attn_weights = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs
                )

                # Get best prediction
                _, topi = decoder_output.topk(1)
                predicted_idx = topi.squeeze().item()

                # Stop if END token
                if predicted_idx == end_index:
                    break

                # Add to outputs if it's a valid character
                if predicted_idx in self.data["target_index_char"]:
                    output_tokens.append(self.data["target_index_char"][predicted_idx])

                # Store attention weights as a flat per-position vector
                if attn_weights is not None:
                    attn_weights_list.append(
                        attn_weights.detach().squeeze(0).squeeze(-1).cpu().numpy()
                    )

                # Next input is predicted token
                decoder_input = topi.squeeze().detach()

        pred_str = ''.join(output_tokens)
        attn_weights_array = np.array(attn_weights_list) if attn_weights_list else np.array([])

        return pred_str, attn_weights_array

    def beam_search_decode(self, input_tensor, padded_input, beam_width=5):
        """Perform beam search decoding with attention for a single input.

        At every step each live hypothesis is expanded with its top
        ``beam_width`` continuations, and the ``beam_width`` best candidates by
        length-normalized score (``BeamSearchNode.eval``) are kept. Candidates
        ending in END_TOKEN are moved to the finished list. The search stops
        after ``OUTPUT_MAX_LENGTH`` steps, when no live hypothesis remains, or
        once ``beam_width`` hypotheses have finished.

        Args:
            input_tensor: encoded input with a batch dimension of 1.
            padded_input: unused; kept so existing callers keep working.
            beam_width: number of hypotheses kept per step.

        Returns:
            List of ``(pred_str, attn_weights_array)`` tuples, best first. The
            attention arrays use the same per-step layout as ``greedy_decode``.
        """
        start_index = self.data["target_char_index"][START_TOKEN]
        end_index = self.data["target_char_index"][END_TOKEN]

        finished = []

        with torch.no_grad():
            batch_size = input_tensor.size(0)
            encoder_hidden = self.encoder.init_hidden(batch_size)
            encoder_outputs, encoder_hidden = self.encoder(input_tensor, encoder_hidden)

            # Token ids are stored as plain ints so they can be compared and
            # used as dictionary keys when the path is decoded.
            beam = [BeamSearchNode(encoder_hidden, None, start_index, 0.0, 1)]

            for _ in range(self.data["OUTPUT_MAX_LENGTH"]):
                candidates = []
                for node in beam:
                    decoder_input = torch.tensor([node.word_id], device=device)
                    decoder_output, decoder_hidden, attn_weights = self.decoder(
                        decoder_input, node.hidden_state, encoder_outputs
                    )

                    # Flatten to a 1-D vector of log-probabilities over the vocabulary
                    step_log_probs = decoder_output.reshape(-1)
                    # topk raises if asked for more entries than the vocabulary holds
                    k = min(beam_width, step_log_probs.size(0))
                    log_probs, indexes = torch.topk(step_log_probs, k)

                    step_attn = (
                        attn_weights.detach().squeeze(0).squeeze(-1).cpu().numpy()
                        if attn_weights is not None else None
                    )

                    for log_p, token_id in zip(log_probs.tolist(), indexes.tolist()):
                        candidates.append(BeamSearchNode(
                            decoder_hidden, node, token_id,
                            node.log_prob + log_p, node.length + 1,
                            step_attn
                        ))

                # Higher score is better; sorting by key never compares nodes directly
                candidates.sort(key=lambda cand: cand.eval(), reverse=True)

                beam = []
                for cand in candidates[:beam_width]:
                    if cand.word_id == end_index:
                        finished.append(cand)
                    else:
                        beam.append(cand)

                if not beam or len(finished) >= beam_width:
                    break

        # Nothing reached END within the step budget: fall back to the live hypotheses
        if not finished:
            finished = beam

        best_nodes = sorted(finished, key=lambda cand: cand.eval(), reverse=True)[:beam_width]

        results = []
        for node in best_nodes:
            # get_chars_from_nodes walks back from the last node it is given
            pred_str, attn_weights_list = get_chars_from_nodes([node], self.data["target_index_char"])
            attn_weights_array = np.array(attn_weights_list) if attn_weights_list else np.array([])
            results.append((pred_str, attn_weights_array))

        return results

    def predict(self, input_string, processor, use_beam_search=False, beam_width=5):
        """Make prediction for a single input string.

        Returns ``(pred_str, attn_weights_array)`` from greedy decoding, or
        the best beam-search hypothesis when ``use_beam_search`` is True.
        """
        # Prepare input
        input_tensor, padded_input = processor.prepare_input_for_prediction(input_string)

        # Decode
        if use_beam_search:
            results = self.beam_search_decode(input_tensor, padded_input, beam_width)
            # Return best result (first in the list)
            return results[0] if results else ("", np.array([]))
        else:
            return self.greedy_decode(input_tensor, processor)

In [None]:
class AttentionVisualizer:
    """Render attention weights as matplotlib heatmaps or self-contained HTML.

    Every method is a ``@staticmethod``; the class only acts as a namespace.
    Attention matrices are indexed as ``matrix[output_position, input_position]``.
    ``PAD_TOKEN`` is removed from input labels; ``PAD_TOKEN``, ``START_TOKEN``
    and ``END_TOKEN`` are removed from output labels.
    """

    @staticmethod
    def _visible_input(chars):
        """Return the input characters with padding removed."""
        return [c for c in chars if c != PAD_TOKEN]

    @staticmethod
    def _visible_output(chars):
        """Return the output characters with pad/start/end tokens removed."""
        hidden = (PAD_TOKEN, START_TOKEN, END_TOKEN)
        return [c for c in chars if c not in hidden]

    @staticmethod
    def _crop_to_labels(attention_matrix, row_labels, col_labels):
        """Crop matrix and labels to the largest region they have in common.

        Returns:
            ``(matrix, row_labels, col_labels)`` where the matrix has exactly
            ``len(row_labels)`` rows and ``len(col_labels)`` columns.
        """
        n_rows = min(len(row_labels), attention_matrix.shape[0])
        n_cols = min(len(col_labels), attention_matrix.shape[1])
        return attention_matrix[:n_rows, :n_cols], row_labels[:n_rows], col_labels[:n_cols]

    @staticmethod
    def plot_attention_matrix(input_chars, output_chars, attention_matrix, title="Attention Matrix"):
        """Plot one attention matrix as an annotated heatmap.

        Args:
            input_chars: Sequence of input characters (column labels).
            output_chars: Sequence of output characters (row labels).
            attention_matrix: 2-D numpy array of attention weights.
            title: Figure title, also used in the diagnostic messages.

        Returns:
            The matplotlib figure, or ``None`` (after printing a message) when
            there is nothing to plot.
        """
        if attention_matrix.size == 0 or len(input_chars) == 0 or len(output_chars) == 0:
            print(f"No attention data available for {title}")
            return None

        weights, row_labels, col_labels = AttentionVisualizer._crop_to_labels(
            attention_matrix,
            AttentionVisualizer._visible_output(output_chars),
            AttentionVisualizer._visible_input(input_chars),
        )
        if not row_labels or not col_labels:
            print(f"No valid dimensions for attention matrix in {title}")
            return None

        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(weights, cmap='Blues')
        fig.colorbar(im, ax=ax)

        # Labels are cropped together with the matrix, so every tick sits on
        # an actual cell instead of stretching the axis past the image.
        ax.set_xticks(range(len(col_labels)))
        ax.set_yticks(range(len(row_labels)))
        ax.set_xticklabels(col_labels, fontsize=12)
        ax.set_yticklabels(row_labels, fontsize=12)

        ax.set_xlabel('Input Sequence')
        ax.set_ylabel('Output Sequence')
        ax.set_title(title)

        # Rotate x-axis labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        # Annotate each cell; switch to white text on dark (high-weight) cells.
        for i in range(len(row_labels)):
            for j in range(len(col_labels)):
                ax.text(j, i, f"{weights[i, j]:.2f}",
                        ha="center", va="center",
                        color="white" if weights[i, j] > 0.5 else "black")

        fig.tight_layout()
        return fig

    @staticmethod
    def plot_attention_grid(input_strings, target_strings, predictions, attention_matrices, n_samples=4):
        """Plot up to ``n_samples`` attention heatmaps in a two-column grid.

        Args:
            input_strings: Source strings, one per sample.
            target_strings: Reference outputs, one per sample.
            predictions: Predicted outputs, one per sample.
            attention_matrices: 2-D numpy arrays, one per sample.
            n_samples: Maximum number of samples to draw.

        Returns:
            The matplotlib figure, or ``None`` when there are no samples.
        """
        n_samples = min(n_samples, len(input_strings))
        if n_samples <= 0:
            # Without this guard the row computation below divides by zero.
            print("No samples available for attention grid")
            return None

        n_cols = min(2, n_samples)
        n_rows = (n_samples + n_cols - 1) // n_cols

        # squeeze=False always yields a 2-D array of axes, so no special
        # casing is needed for single-row or single-cell grids.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
        flat_axes = axes.ravel()  # row-major: index i is (i // n_cols, i % n_cols)

        for i in range(n_samples):
            ax = flat_axes[i]
            input_string = input_strings[i]
            target_string = target_strings[i]
            prediction = predictions[i]
            attention_matrix = attention_matrices[i]

            weights = attention_matrix
            row_labels, col_labels = [], []
            if attention_matrix.size != 0:
                # Crop to the labelled region, like the single-matrix plot and
                # the HTML export, so unlabeled columns are not drawn.
                weights, row_labels, col_labels = AttentionVisualizer._crop_to_labels(
                    attention_matrix,
                    AttentionVisualizer._visible_output(prediction),
                    AttentionVisualizer._visible_input(input_string),
                )

            if weights.size == 0:
                ax.text(0.5, 0.5, f"No attention data for sample {i+1}",
                        ha='center', va='center', fontsize=12)
                ax.axis('off')
                continue

            ax.imshow(weights, cmap='Blues')

            ax.set_xticks(range(len(col_labels)))
            ax.set_yticks(range(len(row_labels)))
            ax.set_xticklabels(col_labels, fontsize=10)
            ax.set_yticklabels(row_labels, fontsize=10)

            correct = prediction == target_string
            title = f"Sample {i+1} - {'✓' if correct else '✗'}\nInput: {input_string}\nTarget: {target_string}\nPred: {prediction}"
            ax.set_title(title, fontsize=10)

            # Rotate x-axis labels
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        # Remove the unused cell of an odd-sized grid.
        for ax in flat_axes[n_samples:]:
            fig.delaxes(ax)

        fig.tight_layout()
        return fig

    @staticmethod
    def export_attention_html(input_chars, output_chars, attention_matrix, title="Attention Visualization"):
        """Return an HTML fragment showing one attention matrix as a table.

        Each cell is shaded blue with the attention weight as its opacity and
        the per-row maximum is outlined in red. Output characters that have no
        matrix row (and input characters with no column) get blank cells.

        Args:
            input_chars: Sequence of input characters (column headers).
            output_chars: Sequence of output characters (row headers).
            attention_matrix: 2-D numpy array of attention weights.
            title: Heading for the fragment.

        Returns:
            An HTML string; a short ``<p>`` message when there is no data.
        """
        # Local import: the name ``html`` is used for strings elsewhere in
        # this file, and only ``escape`` is needed here.
        from html import escape

        safe_title = escape(str(title))
        if not input_chars or not output_chars or attention_matrix.size == 0:
            return f"<p>No attention data available for {safe_title}</p>"

        col_labels = AttentionVisualizer._visible_input(input_chars)
        row_labels = AttentionVisualizer._visible_output(output_chars)
        if not col_labels or not row_labels:
            return f"<p>Not enough valid characters for {safe_title}</p>"

        weights, _, _ = AttentionVisualizer._crop_to_labels(attention_matrix, row_labels, col_labels)

        # Characters are escaped because START_TOKEN/END_TOKEN are '<' and '>'
        # and may survive in the input labels.
        safe_cols = [escape(str(c)) for c in col_labels]
        header_cells = ''.join(
            f'<th style="width: 30px; text-align: center; padding: 4px;">{c}</th>'
            for c in safe_cols
        )

        parts = [f"""
        <div style="margin: 20px 0;">
            <h3>{safe_title}</h3>
            <div style="font-family: 'Noto Sans Tamil', sans-serif; margin-bottom: 10px;">
                <b>Input:</b> {''.join(safe_cols)}<br>
                <b>Output:</b> {''.join(escape(str(c)) for c in row_labels)}
            </div>

            <table style="border-collapse: collapse; font-family: 'Noto Sans Tamil', sans-serif;">
                <tr>
                    <th style="width: 30px;"></th>
                    {header_cells}
                </tr>
        """]

        blank_cell = '<td style="width: 30px; height: 30px; background-color: #f9f9f9; border: 1px solid #ddd;"></td>'

        for i, out_char in enumerate(row_labels):
            parts.append(f'<tr><th style="text-align: right; padding: 4px;">{escape(str(out_char))}</th>')

            has_row = i < weights.shape[0]
            # Row maximum is computed once per row rather than once per cell.
            row_max = np.max(weights[i]) if has_row and weights.shape[1] > 0 else None

            for j in range(len(col_labels)):
                if not (has_row and j < weights.shape[1]):
                    parts.append(blank_cell)
                    continue

                value = weights[i, j]
                bg_color = f"rgba(0, 0, 255, {value:.2f})"
                border = "2px solid red" if value == row_max else "1px solid #ddd"
                # White text is unreadable on a nearly transparent cell, so
                # use black for low weights (same 0.5 threshold as the
                # matplotlib annotation).
                text_color = "white" if value > 0.5 else "black"

                parts.append(
                    f'<td style="width: 30px; height: 30px; background-color: {bg_color}; '
                    f'border: {border}; text-align: center; color: {text_color}; '
                    f'font-size: 12px;">{value:.2f}</td>'
                )

            parts.append('</tr>')

        parts.append('</table></div>')
        return ''.join(parts)

    @staticmethod
    def create_attention_grid_html(input_strings, target_strings, predictions, attention_matrices, n_samples=9):
        """Return a full HTML page with a grid of attention tables.

        The page contains one ``export_attention_html`` fragment per sample
        followed by a text list of input/target/prediction for each sample,
        with the prediction coloured by whether it equals the target.

        Args:
            input_strings: Source strings, one per sample.
            target_strings: Reference outputs, one per sample.
            predictions: Predicted outputs, one per sample.
            attention_matrices: 2-D numpy arrays, one per sample.
            n_samples: Maximum number of samples to include.

        Returns:
            The HTML document as a string.
        """
        from html import escape

        n_samples = min(n_samples, len(input_strings))

        page_head = """
        <html>
        <head>
            <meta charset="UTF-8">
            <style>
                @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Tamil:wght@400;700&display=swap');
                body {
                    font-family: 'Noto Sans Tamil', sans-serif;
                    font-size: 14px;
                }
                .grid-container {
                    display: grid;
                    grid-template-columns: repeat(3, 1fr);
                    gap: 20px;
                }
                .attention-cell {
                    border: 1px solid #ddd;
                    padding: 10px;
                    border-radius: 5px;
                }
                .examples-text {
                    margin-top: 20px;
                    line-height: 1.6;
                }
                .correct { color: green; }
                .incorrect { color: red; }
            </style>
        </head>
        <body>
            <h2>Attention Visualizations</h2>
            <div class="grid-container">
        """

        parts = [page_head]
        summaries = []

        for i in range(n_samples):
            input_str = input_strings[i]
            target_str = target_strings[i]
            pred_str = predictions[i]

            viz_html = AttentionVisualizer.export_attention_html(
                AttentionVisualizer._visible_input(input_str),
                AttentionVisualizer._visible_output(pred_str),
                attention_matrices[i],
                f"Example {i+1}",
            )
            parts.append(f"""
            <div class="attention-cell">
                {viz_html}
            </div>
            """)

            correct = pred_str == target_str
            correct_class = "correct" if correct else "incorrect"
            summaries.append(
                f"Example {i+1}:<br>Input: {escape(str(input_str))}<br>Target: {escape(str(target_str))}<br>"
                f"Pred: <span class='{correct_class}'>{escape(str(pred_str))} {'(✓)' if correct else '(✗)'}</span>"
            )

        parts.append("""
            </div>
            <div class="examples-text">
        """)
        parts.extend(f"<p>{text}</p>" for text in summaries)
        parts.append("""
            </div>
        </body>
        </html>
        """)

        return ''.join(parts)

In [None]:
class Trainer:
    """Train, validate and test a sequence-to-sequence model with attention.

    The model must expose ``encoder`` and ``decoder`` sub-modules and a
    ``forward(source, target, teacher_forcing_ratio)`` method returning
    ``(decoder_outputs, attention)``, where ``decoder_outputs`` is indexed as
    ``[batch, time_step, vocabulary]`` and is fed to ``nn.NLLLoss`` (so it is
    presumably log-probabilities; verify against the decoder).

    Note the two different return orders, kept for backward compatibility:
    ``train_epoch`` returns ``(loss, accuracy)`` while ``evaluate`` returns
    ``(accuracy, loss)``.
    """

    def __init__(self, model, h_params, data, train_loader, val_loader):
        """Set up optimizers, loss, and learning-rate schedulers.

        Args:
            model: Object with ``encoder``, ``decoder`` and ``forward``.
            h_params: Hyperparameter dict. ``"optimizer"`` and
                ``"learning_rate"`` are read here; ``"teacher_forcing_ratio"``,
                ``"epochs"``, ``"batch_size"``, ``"cell_type"`` and
                ``"num_layers"`` are read by other methods.
            data: Data dictionary; ``"OUTPUT_MAX_LENGTH"`` gives the number of
                decoder time steps that are scored.
            train_loader: Iterable of ``(source, target)`` training batches.
            val_loader: Iterable of ``(source, target)`` validation batches.
        """
        self.model = model
        self.h_params = h_params
        self.data = data
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Any optimizer name other than "adam"/"nadam" falls back to SGD.
        optimizer_cls = {
            "adam": optim.Adam,
            "nadam": optim.NAdam,
        }.get(h_params["optimizer"].lower(), optim.SGD)
        learning_rate = h_params["learning_rate"]
        self.encoder_optimizer = optimizer_cls(model.encoder.parameters(), lr=learning_rate)
        self.decoder_optimizer = optimizer_cls(model.decoder.parameters(), lr=learning_rate)

        # Loss function
        self.criterion = nn.NLLLoss()

        # Learning rate schedulers (stepped with the validation loss)
        self.encoder_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.encoder_optimizer, 'min', patience=2)
        self.decoder_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.decoder_optimizer, 'min', patience=2)

        # Best model tracking
        self.best_val_accuracy = 0

    def _score_batch(self, decoder_outputs, target):
        """Return the summed per-step loss and the exact-match count.

        Args:
            decoder_outputs: Tensor indexed ``[batch, time_step, vocabulary]``.
            target: Tensor of target indices indexed ``[batch, time_step]``.

        Returns:
            ``(loss, correct)`` where ``loss`` is the sum over time steps of
            the criterion and ``correct`` is the number of sequences whose
            prediction matches the target at every scored step.
        """
        steps = self.data["OUTPUT_MAX_LENGTH"]

        loss = 0
        for t in range(steps):
            loss = loss + self.criterion(decoder_outputs[:, t, :], target[:, t])

        # argmax over the vocabulary axis keeps the batch dimension intact
        # even for a single-example batch.
        predictions = decoder_outputs[:, :steps, :].argmax(dim=-1)
        correct = (predictions == target[:, :steps]).all(dim=1).sum().item()
        return loss, correct

    @staticmethod
    def _wandb_log(metrics):
        """Log metrics to wandb, reporting (not raising) on failure."""
        try:
            wandb.log(metrics)
        except Exception as e:
            print(f"Warning: Could not log to wandb: {e}")

    def train_epoch(self):
        """Train for one epoch.

        Returns:
            ``(mean_loss, accuracy)``: the per-time-step loss averaged over
            the batches that were actually processed, and the exact-match
            accuracy over their examples. ``(0, 0)`` if no batch was processed.
        """
        self.model.encoder.train()
        self.model.decoder.train()

        steps = self.data["OUTPUT_MAX_LENGTH"]
        epoch_loss = 0.0
        total_correct = 0
        total_examples = 0
        batches_done = 0

        for batch_idx, (source, target) in enumerate(self.train_loader):
            batch_size = source.size(0)
            # Single-example batches are skipped, as in the original code
            # (presumably the model squeezes the batch dimension - confirm).
            if batch_size == 1:
                continue

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            try:
                decoder_outputs, _ = self.model.forward(
                    source, target, self.h_params["teacher_forcing_ratio"]
                )
                loss, correct = self._score_batch(decoder_outputs, target)

                loss.backward()

                # Clip gradients to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.encoder.parameters(), 1)
                torch.nn.utils.clip_grad_norm_(self.model.decoder.parameters(), 1)

                self.encoder_optimizer.step()
                self.decoder_optimizer.step()
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                traceback.print_exc()
                continue

            # Statistics are updated only after the whole step succeeded, so
            # a failing batch never contributes to loss or accuracy.
            epoch_loss += loss.item() / steps
            total_correct += correct
            total_examples += batch_size
            batches_done += 1

            # Free up memory
            del source, target, loss, decoder_outputs

            # Release cached memory periodically; doing it on every batch
            # only slows training down.
            if batch_idx % 10 == 0:
                gc.collect()
                torch.cuda.empty_cache()

            # Print progress periodically
            if (batch_idx + 1) % 100 == 0:
                print(f"Batch {batch_idx + 1}/{len(self.train_loader)}, Loss: {epoch_loss / batches_done:.4f}")

        if total_examples == 0:
            return 0, 0  # Avoid division by zero if all batches were skipped

        # Average over processed batches, not len(loader): skipped or failed
        # batches would otherwise drag the reported loss towards zero.
        return epoch_loss / batches_done, total_correct / total_examples

    def evaluate(self, data_loader):
        """Evaluate the model on a data loader, using full teacher forcing.

        Args:
            data_loader: Iterable of ``(source, target)`` batches.

        Returns:
            ``(accuracy, mean_loss)`` - note the order differs from
            ``train_epoch``. ``(0, 0)`` if no batch was processed.
        """
        self.model.encoder.eval()
        self.model.decoder.eval()

        steps = self.data["OUTPUT_MAX_LENGTH"]
        epoch_loss = 0.0
        total_correct = 0
        total_examples = 0
        batches_done = 0

        with torch.no_grad():
            for batch_idx, (source, target) in enumerate(data_loader):
                batch_size = source.size(0)
                if batch_size == 1:  # Skip batches with only one example
                    continue

                try:
                    # Forward pass with teacher forcing
                    decoder_outputs, _ = self.model.forward(
                        source, target, teacher_forcing_ratio=1.0
                    )
                    loss, correct = self._score_batch(decoder_outputs, target)
                    batch_loss = loss.item() / steps
                except Exception as e:
                    print(f"Error in evaluation batch {batch_idx}: {e}")
                    continue

                total_correct += correct
                total_examples += batch_size
                epoch_loss += batch_loss
                batches_done += 1

                # Free up memory
                del source, target

                # Run garbage collection periodically
                if batch_idx % 10 == 0:
                    gc.collect()
                    torch.cuda.empty_cache()

        if total_examples == 0:
            return 0, 0  # Avoid division by zero if all batches were skipped

        return total_correct / total_examples, epoch_loss / batches_done

    def train(self, test_source=None, test_target=None):
        """Train for ``h_params["epochs"]`` epochs, then optionally test.

        After every epoch the model is validated, both schedulers are stepped
        with the validation loss, metrics are logged to wandb, and a
        checkpoint is saved whenever validation accuracy improves.

        Args:
            test_source: Optional test inputs; testing runs only when both
                ``test_source`` and ``test_target`` are given.
            test_target: Optional test references.
        """
        run_name = f"{self.h_params['cell_type']}_{self.h_params['optimizer']}_attn_layers{self.h_params['num_layers']}"
        try:
            wandb.init(project="Tamil-Transliteration-Attention", name=run_name, config=self.h_params)
        except Exception as e:
            print(f"Error initializing wandb: {e}")
            print("Continuing without wandb tracking...")

        try:
            for epoch in range(self.h_params["epochs"]):
                print(f"\nEpoch {epoch+1}/{self.h_params['epochs']}")
                print("-" * 30)

                train_loss, train_acc = self.train_epoch()
                val_acc, val_loss = self.evaluate(self.val_loader)

                # Update learning rate
                self.encoder_scheduler.step(val_loss)
                self.decoder_scheduler.step(val_loss)

                current_lr = self.encoder_optimizer.param_groups[0]['lr']
                print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
                print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
                print(f"Learning Rate: {current_lr:.6f}")

                self._wandb_log({
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "train_accuracy": train_acc,
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                    "learning_rate": current_lr,
                })

                # Save best model
                if val_acc > self.best_val_accuracy:
                    self.best_val_accuracy = val_acc
                    self.save_model(f'best_model_{run_name}.pt', epoch, val_acc)

                # Run garbage collection between epochs
                gc.collect()
                torch.cuda.empty_cache()

            # NOTE(review): testing uses the weights of the last epoch, not
            # the best checkpoint saved above.
            if test_source is not None and test_target is not None:
                self.test(test_source, test_target)
        finally:
            # Close the wandb run even when training raises.
            try:
                wandb.finish()
            except Exception as e:
                print(f"Warning: Could not finish wandb run: {e}")

    def save_model(self, filename, epoch, val_accuracy):
        """Save encoder/decoder weights and optimizer states to ``filename``.

        Failures are reported on stdout and do not interrupt training.

        Args:
            filename: Destination path of the checkpoint.
            epoch: Zero-based epoch index stored in the checkpoint.
            val_accuracy: Validation accuracy stored in the checkpoint.
        """
        checkpoint = {
            'epoch': epoch,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'encoder_optimizer': self.encoder_optimizer.state_dict(),
            'decoder_optimizer': self.decoder_optimizer.state_dict(),
            'val_accuracy': val_accuracy,
        }
        try:
            torch.save(checkpoint, filename)
            print(f"Model saved with validation accuracy: {val_accuracy:.4f}")
        except Exception as e:
            print(f"Error saving model: {e}")

    def test(self, test_source, test_target):
        """Evaluate on the test set and produce predictions and reports.

        Computes teacher-forced test accuracy/loss, writes greedy and beam
        search prediction CSVs, renders attention visualizations, logs the
        results to wandb and displays a decoding comparison. Any failure is
        printed with its traceback instead of being raised.

        Args:
            test_source: Test inputs.
            test_target: Test references.
        """
        try:
            print("Testing model on test set...")

            # Work on a copy so processing the test data cannot modify the
            # trainer's own data dictionary.
            processor = DataProcessor()
            processor.data = copy.deepcopy(self.data)

            test_source_seq, test_target_seq = processor.process_validation(test_source, test_target)
            test_dataset = TransliterationDataset(test_source_seq, test_target_seq)
            test_loader = DataLoader(
                test_dataset,
                batch_size=self.h_params["batch_size"],
                shuffle=False,
                num_workers=0,
                pin_memory=False
            )

            test_acc, test_loss = self.evaluate(test_loader)
            print(f"Test Accuracy: {test_acc:.4f} | Test Loss: {test_loss:.4f}")

            self._wandb_log({
                "test_loss": test_loss,
                "test_accuracy": test_acc
            })

            evaluator = Evaluator(self.model, processor)

            # Generate predictions with both greedy and beam search
            print("Generating test predictions...")
            csv_path, greedy_results, greedy_accuracy = evaluator.generate_predictions_csv(
                test_source, test_target, use_beam_search=False
            )
            print("Generating beam search predictions...")
            beam_csv_path, beam_results, beam_accuracy = evaluator.generate_predictions_csv(
                test_source, test_target, use_beam_search=True, beam_width=5
            )

            print(f"Greedy Search Accuracy: {greedy_accuracy:.4f}")
            print(f"Beam Search Accuracy: {beam_accuracy:.4f}")

            print("Generating attention visualizations...")
            evaluator.visualize_attention(test_source, test_target, num_samples=9)

            # Log final metrics to wandb
            try:
                wandb.log({
                    "final_test_accuracy": test_acc,
                    "final_test_loss": test_loss,
                    "greedy_accuracy": greedy_accuracy,
                    "beam_accuracy": beam_accuracy
                })

                prediction_table = wandb.Table(columns=["Input", "Target", "Greedy Prediction", "Beam Prediction", "Greedy Correct", "Beam Correct"])

                # Pair up the first 100 greedy rows with their beam rows.
                for greedy_row, beam_row in zip(greedy_results[:100], beam_results):
                    target_str = greedy_row['Target']
                    greedy_pred = greedy_row['Prediction']
                    beam_pred = beam_row['Prediction']
                    prediction_table.add_data(
                        greedy_row['Input'], target_str,
                        greedy_pred, beam_pred,
                        greedy_pred == target_str, beam_pred == target_str
                    )

                wandb.log({"prediction_samples": prediction_table})

                # Log attention visualizations as images
                evaluator.log_attention_visualizations_to_wandb(num_samples=4)

            except Exception as e:
                print(f"Warning: Could not log to wandb: {e}")

            # Display sample predictions
            evaluator.compare_decoding_methods(test_source, test_target, num_samples=5)

        except Exception as e:
            print(f"Error in testing: {e}")
            traceback.print_exc()

class Evaluator:
    """Generate predictions, reports and attention visualizations for a model.

    Attributes set by methods (not by ``__init__``):
        last_results, last_attention_matrices: written by
            ``generate_predictions_csv``.
        vis_input_strings, vis_target_strings, vis_predictions,
        vis_attention_matrices: written by ``visualize_attention`` and read by
            ``log_attention_visualizations_to_wandb``.
    """

    def __init__(self, model, processor):
        """Store the model and processor.

        Args:
            model: Object whose ``predict(input_string, processor,
                use_beam_search, beam_width)`` returns
                ``(prediction, attention_matrix)``.
            processor: Data processor forwarded to ``model.predict``; its
                ``data`` attribute is kept as ``self.data``.
        """
        self.model = model
        self.processor = processor
        self.data = processor.data
        self.visualizer = AttentionVisualizer()

    def predict(self, input_string, use_beam_search=False, beam_width=5):
        """Make prediction for a single input string"""
        return self.model.predict(input_string, self.processor, use_beam_search, beam_width)

    def generate_predictions_csv(self, test_source, test_target, use_beam_search=False, beam_width=5):
        """Predict every test example and save the results to a CSV file.

        The file is named ``tamil_transliteration_predictions_<method>.csv``
        (``method`` is ``greedy`` or ``beam_search``) and written to the
        current working directory.

        Args:
            test_source: Input strings.
            test_target: Reference strings, aligned with ``test_source``.
            use_beam_search: Decode with beam search instead of greedy.
            beam_width: Beam width used when ``use_beam_search`` is true.

        Returns:
            ``(csv_path, results, accuracy)`` where ``results`` is a list of
            dicts with keys ``Input``, ``Prediction`` and ``Target`` and
            ``accuracy`` is the exact-match rate (0 for empty input).
        """
        print(f"Generating predictions for {len(test_source)} test examples...")

        results = []
        attention_matrices = []

        for i, (input_str, target_str) in enumerate(zip(test_source, test_target)):
            if i % 1000 == 0:
                print(f"Processing test example {i+1}/{len(test_source)}")

            # Both decoding modes return (prediction, attention_matrix), so a
            # single call covers them.
            pred_str, attn_matrix = self.predict(
                input_str, use_beam_search=use_beam_search, beam_width=beam_width
            )

            results.append({
                'Input': input_str,
                'Prediction': pred_str,
                'Target': target_str
            })
            attention_matrices.append(attn_matrix)

        correct = sum(1 for r in results if r['Prediction'] == r['Target'])
        accuracy = correct / len(results) if results else 0

        method = "beam_search" if use_beam_search else "greedy"
        print(f"{method.capitalize()} Accuracy: {accuracy:.4f} ({correct}/{len(results)})")

        # Logging is best-effort: predictions must still be saved when wandb
        # is unavailable.
        try:
            wandb.log({f"{method}_accuracy": accuracy})
        except Exception as e:
            print(f"Warning: Could not log to wandb: {e}")

        csv_path = f'tamil_transliteration_predictions_{method}.csv'
        with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['Input', 'Prediction', 'Target'])
            writer.writeheader()
            writer.writerows(results)

        print(f"Saved predictions to {csv_path}")

        # Store attention matrices for later use
        self.last_results = results
        self.last_attention_matrices = attention_matrices

        return csv_path, results, accuracy

    def visualize_attention(self, test_source, test_target, num_samples=9):
        """Visualize attention for randomly chosen test examples.

        Writes ``attention_visualizations.html`` and ``attention_grid.png`` to
        the current working directory and tries to display the HTML inline.

        Args:
            test_source: Input strings.
            test_target: Reference strings, aligned with ``test_source``.
            num_samples: Maximum number of examples to sample.

        Returns:
            The generated HTML document as a string.
        """
        indices = np.random.choice(len(test_source), min(num_samples, len(test_source)), replace=False)

        input_strings = []
        target_strings = []
        predictions = []
        attention_matrices = []

        for idx in indices:
            input_str = test_source[idx]
            pred_str, attn_matrix = self.predict(input_str)

            input_strings.append(input_str)
            target_strings.append(test_target[idx])
            predictions.append(pred_str)
            attention_matrices.append(attn_matrix)

        html = self.visualizer.create_attention_grid_html(
            input_strings, target_strings, predictions, attention_matrices
        )

        with open("attention_visualizations.html", "w", encoding="utf-8") as f:
            f.write(html)

        # Display in notebook if possible
        try:
            display(HTML(html))
            print("Attention visualizations displayed in notebook and saved to attention_visualizations.html")
        except Exception:
            print("Attention visualizations saved to attention_visualizations.html")

        # Plotting sits inside the try as well: a failure to draw the grid
        # (for example with no samples) should not abort the visualization.
        try:
            fig = self.visualizer.plot_attention_grid(
                input_strings, target_strings, predictions, attention_matrices, n_samples=4
            )
            fig.savefig("attention_grid.png", dpi=150, bbox_inches='tight')
            print("Attention grid saved to attention_grid.png")
        except Exception as e:
            print(f"Error saving attention grid: {e}")

        # Save results for later use
        self.vis_input_strings = input_strings
        self.vis_target_strings = target_strings
        self.vis_predictions = predictions
        self.vis_attention_matrices = attention_matrices

        return html

    def compare_decoding_methods(self, test_source, test_target, num_samples=5):
        """Compare greedy and beam search decoding on random test examples.

        Displays an HTML table in the notebook, or prints a text comparison
        when HTML display is unavailable.

        Args:
            test_source: Input strings.
            test_target: Reference strings, aligned with ``test_source``.
            num_samples: Maximum number of examples to sample.

        Returns:
            A list of dicts with keys ``Input``, ``Target``, ``Greedy`` and
            ``Beam``.
        """
        # Local import: ``html`` is used below as a variable name.
        from html import escape

        indices = np.random.choice(len(test_source), min(num_samples, len(test_source)), replace=False)

        comparison_results = []
        for idx in indices:
            input_str = test_source[idx]

            greedy_pred, _ = self.predict(input_str, use_beam_search=False)
            beam_result = self.predict(input_str, use_beam_search=True, beam_width=5)
            beam_pred = beam_result[0] if isinstance(beam_result, tuple) else beam_result

            comparison_results.append({
                'Input': input_str,
                'Target': test_target[idx],
                'Greedy': greedy_pred,
                'Beam': beam_pred
            })

        page_head = """
        <html>
        <head>
            <meta charset="UTF-8">
            <style>
                @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Tamil:wght@400;700&display=swap');
                body {
                    font-family: 'Noto Sans Tamil', sans-serif;
                    font-size: 16px;
                    line-height: 1.6;
                }
                table {
                    border-collapse: collapse;
                    width: 100%;
                    margin: 20px 0;
                }
                th, td {
                    padding: 12px;
                    text-align: left;
                    border: 1px solid #ddd;
                }
                th {
                    background-color: #f2f2f2;
                    font-weight: bold;
                }
                tr:nth-child(even) {
                    background-color: #f9f9f9;
                }
                .correct { color: green; }
                .incorrect { color: red; }
            </style>
        </head>
        <body>
            <h2>Greedy vs Beam Search Decoding Comparison</h2>
            <table>
                <tr>
                    <th>Input</th>
                    <th>Target</th>
                    <th>Greedy Prediction</th>
                    <th>Beam Search Prediction</th>
                </tr>
        """

        parts = [page_head]
        for result in comparison_results:
            greedy_correct = result['Greedy'] == result['Target']
            beam_correct = result['Beam'] == result['Target']

            greedy_class = "correct" if greedy_correct else "incorrect"
            beam_class = "correct" if beam_correct else "incorrect"
            greedy_mark = "✓" if greedy_correct else "✗"
            beam_mark = "✓" if beam_correct else "✗"

            # Strings are escaped so characters such as '<', '>' or '&' in
            # the data cannot break the table markup.
            parts.append(f"""
            <tr>
                <td>{escape(str(result['Input']))}</td>
                <td>{escape(str(result['Target']))}</td>
                <td class="{greedy_class}">{escape(str(result['Greedy']))} {greedy_mark}</td>
                <td class="{beam_class}">{escape(str(result['Beam']))} {beam_mark}</td>
            </tr>
            """)

        parts.append("""
            </table>
        </body>
        </html>
        """)
        html = ''.join(parts)

        # Display in notebook if possible
        try:
            display(HTML(html))
        except Exception:
            # Fallback
            print("HTML display not available. Showing text comparison:")
            for result in comparison_results:
                print(f"Input: {result['Input']}")
                print(f"Target: {result['Target']}")
                print(f"Greedy: {result['Greedy']} {'✓' if result['Greedy'] == result['Target'] else '✗'}")
                print(f"Beam: {result['Beam']} {'✓' if result['Beam'] == result['Target'] else '✗'}")
                print("-" * 50)

        return comparison_results

    def log_attention_visualizations_to_wandb(self, num_samples=4):
        """Log the attention grid and per-sample heatmaps to wandb.

        Uses the samples stored by ``visualize_attention`` and does nothing if
        that method has not run yet. Figures created here exist only for
        logging, so each one is closed once it has been logged. Errors are
        printed, not raised.

        Args:
            num_samples: Maximum number of stored samples to log.
        """
        try:
            if not hasattr(self, 'vis_input_strings'):
                # No visualizations generated yet
                return

            fig = self.visualizer.plot_attention_grid(
                self.vis_input_strings[:num_samples],
                self.vis_target_strings[:num_samples],
                self.vis_predictions[:num_samples],
                self.vis_attention_matrices[:num_samples],
                n_samples=num_samples
            )
            if fig is not None:
                try:
                    wandb.log({"attention_grid": wandb.Image(fig)})
                finally:
                    # Avoid accumulating open figures across repeated runs.
                    plt.close(fig)

            for i in range(min(num_samples, len(self.vis_input_strings))):
                matrix = self.vis_attention_matrices[i]
                if matrix.size == 0:
                    continue

                input_chars = [c for c in self.vis_input_strings[i] if c not in [PAD_TOKEN]]
                output_chars = [c for c in self.vis_predictions[i] if c not in [PAD_TOKEN, START_TOKEN, END_TOKEN]]

                fig_i = self.visualizer.plot_attention_matrix(
                    input_chars, output_chars, matrix,
                    f"Sample {i+1}: {self.vis_input_strings[i]} → {self.vis_predictions[i]}"
                )
                if fig_i is None:
                    continue
                try:
                    wandb.log({f"attention_matrix_{i+1}": wandb.Image(fig_i)})
                finally:
                    plt.close(fig_i)

            print("Attention visualizations logged to wandb")
        except Exception as e:
            print(f"Error logging attention visualizations to wandb: {e}")

In [None]:
class HyperparameterTuner:
    """Drive wandb hyperparameter sweeps for the attention seq2seq model."""

    def __init__(self, data_manager):
        """Keep the data manager that every sweep run loads its data through."""
        self.data_manager = data_manager

    def get_sweep_config(self):
        """Build the wandb sweep configuration.

        Returns:
            dict: A Bayesian sweep named
            ``tamil-transliteration-attention-sweep`` that maximizes
            ``val_accuracy``, with a discrete list of candidate values for
            each hyperparameter.
        """
        # Candidate values per hyperparameter; each is wrapped as
        # {'values': [...]} below.
        # NOTE(review): RNNType is defined elsewhere in the file; presumably
        # its members are plain serializable values that wandb accepts - confirm.
        search_space = {
            'learning_rate': [0.001, 0.0005, 0.0001],
            'batch_size': [32, 64, 128, 256],
            'char_embed_dim': [64, 128, 256],
            'hidden_size': [128, 256, 512],
            'num_layers': [1, 2, 3, 4],
            'cell_type': [RNNType.RNN, RNNType.LSTM, RNNType.GRU],
            'dropout': [0.0, 0.1, 0.2, 0.3],
            'optimizer': ['adam', 'nadam'],
            'epochs': [10, 15, 20],
            'teacher_forcing_ratio': [0.3, 0.5, 0.7],
        }

        return {
            'method': 'bayes',
            'name': 'tamil-transliteration-attention-sweep',
            'metric': {'goal': 'maximize', 'name': 'val_accuracy'},
            'parameters': {
                param_name: {'values': candidates}
                for param_name, candidates in search_space.items()
            },
        }

    def sweep_agent(self):
        """Execute one sweep run with the hyperparameters chosen by wandb.

        Starts a wandb run, copies the selected values out of
        ``wandb.config``, then loads data, builds the model and trains it.
        Failures during that second phase are printed with a traceback and
        the wandb run is finished either way.
        """
        wandb.init()
        sweep_choice = wandb.config

        # Names copied from the sweep config into the plain dict the model,
        # dataloaders and trainer receive.
        hyperparameter_names = (
            "char_embed_dim",
            "hidden_size",
            "batch_size",
            "num_layers",
            "learning_rate",
            "epochs",
            "cell_type",
            "dropout",
            "optimizer",
            "teacher_forcing_ratio",
        )
        h_params = {name: getattr(sweep_choice, name) for name in hyperparameter_names}

        try:
            dataset = self.data_manager.load_all_data()
            train_loader, val_loader = self.data_manager.create_dataloaders(h_params)

            network = Seq2SeqAttentionModel(h_params, dataset)

            runner = Trainer(network, h_params, dataset, train_loader, val_loader)
            runner.train(self.data_manager.test_source, self.data_manager.test_target)

            # Close wandb run
            wandb.finish()

        except Exception as e:
            print(f"Error in sweep agent: {e}")
            traceback.print_exc()
            wandb.finish()

    def run_sweep(self, count=10):
        """Register the sweep with wandb and run ``count`` agent invocations."""
        sweep_id = wandb.sweep(
            sweep=self.get_sweep_config(),
            project="Tamil-Transliteration-Attention",
        )
        wandb.agent(sweep_id, function=self.sweep_agent, count=count)

In [None]:
class Experiment:
    """Run either a single training run with default settings or a wandb sweep."""

    def __init__(self):
        """Define the hyperparameters used when no sweep is requested."""
        self.default_hyperparams = {
            "char_embed_dim": 256,
            "hidden_size": 256,
            "batch_size": 256,
            "num_layers": 2,
            "learning_rate": 0.001,
            "epochs": 15,
            "cell_type": RNNType.GRU,
            "dropout": 0.2,
            "optimizer": "nadam",
            "teacher_forcing_ratio": 0.5
        }

    def run(self, use_sweep=False, num_sweep_runs=10, use_beam_search=True):
        """Run a hyperparameter sweep or one training run.

        Args:
            use_sweep: If true, delegate to ``HyperparameterTuner.run_sweep``;
                otherwise train once with ``self.default_hyperparams``.
            num_sweep_runs: ``count`` passed to the sweep; only used when
                ``use_sweep`` is true.
            use_beam_search: If true, after training compare greedy and beam
                search decoding and visualize attention on the test split.

        Any exception is printed together with its traceback instead of being
        raised, so this method returns ``None`` whether or not the experiment
        succeeded.
        """
        try:
            # Initial CUDA memory reset
            torch.cuda.empty_cache()
            gc.collect()

            data_manager = DataManager(DATA_PATHS)

            if use_sweep:
                tuner = HyperparameterTuner(data_manager)
                tuner.run_sweep(count=num_sweep_runs)
            else:
                self._run_default(data_manager, use_beam_search)

        except Exception as e:
            # Top-level boundary: report the failure rather than propagate it.
            print(f"Error in experiment: {e}")
            traceback.print_exc()

    def _run_default(self, data_manager, use_beam_search):
        """Train once with the default hyperparameters, then optionally evaluate decoding."""
        h_params = self.default_hyperparams

        data = data_manager.load_all_data()
        train_loader, val_loader = data_manager.create_dataloaders(h_params)
        model = Seq2SeqAttentionModel(h_params, data)

        run_name = f"{h_params['cell_type']}_{h_params['optimizer']}_attn_layers{h_params['num_layers']}"
        try:
            wandb.init(project="Tamil-Transliteration-Attention", name=run_name, config=h_params)
        except Exception as e:
            # Best effort: training is still attempted when wandb cannot start.
            print(f"Error initializing wandb: {e}")

        try:
            print(f"Starting training with {h_params['cell_type']} cell, {h_params['optimizer']} optimizer (With Attention)")
            trainer = Trainer(model, h_params, data, train_loader, val_loader)
            trainer.train(data_manager.test_source, data_manager.test_target)

            if use_beam_search:
                print("\nComparing greedy vs beam search decoding...")
                # The evaluator takes a DataProcessor; give it the vocabulary
                # data that was already loaded instead of rebuilding it.
                processor = DataProcessor()
                processor.data = data
                evaluator = Evaluator(model, processor)

                evaluator.compare_decoding_methods(data_manager.test_source, data_manager.test_target, num_samples=10)

                print("\nVisualizing attention patterns...")
                evaluator.visualize_attention(data_manager.test_source, data_manager.test_target, num_samples=6)
        finally:
            # Close the wandb run even when training or evaluation raised, so a
            # failed experiment does not leave a run open.
            try:
                wandb.finish()
            except Exception as e:
                print(f"Error finishing wandb run: {e}")

# Script entry point
def main():
    """Run one experiment with the default hyperparameters.

    ``use_sweep=False`` trains once with the defaults; ``True`` starts a wandb
    hyperparameter sweep instead, for example
    ``Experiment().run(use_sweep=True, num_sweep_runs=20)``.
    ``use_beam_search=True`` additionally tests beam search decoding.
    """
    Experiment().run(use_sweep=False, use_beam_search=True)


if __name__ == "__main__":
    main()