# In [None]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import random
import gc
import wandb
import traceback
import os
import io
import csv
import copy
import heapq
from IPython.display import HTML, display

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- WANDB SETUP ---
# SECURITY: the API key must not live in source control. It is read from the
# WANDB_API_KEY environment variable (on Kaggle, expose it through a notebook
# secret). Any key that was ever committed to this file should be revoked.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
if WANDB_API_KEY:
    wandb.login(key=WANDB_API_KEY)
else:
    # No explicit login: wandb falls back to its own credential lookup when a
    # run is initialised.
    print("WANDB_API_KEY is not set; skipping explicit wandb login.")

# Special tokens
START_TOKEN = '<'
END_TOKEN = '>'
PAD_TOKEN = '_'

# Teacher forcing ratio
TEACHER_FORCING_RATIO = 0.5

# Data paths for Tamil
DATA_PATHS = {
    "train": "/kaggle/input/aksharantar-sampled/aksharantar_sampled/tam/tam_train.csv",
    "test": "/kaggle/input/aksharantar-sampled/aksharantar_sampled/tam/tam_test.csv",
    "valid": "/kaggle/input/aksharantar-sampled/aksharantar_sampled/tam/tam_valid.csv"
}

# In [None]:
class DataProcessor:
    """Build character vocabularies and encode word pairs as index tensors.

    All state lives in ``self.data``: the source/target character lists, the
    char<->index maps, the vocabulary sizes and, once ``process_data`` has
    run, the padded lengths and the encoded training tensors.
    """

    def __init__(self):
        """Seed both vocabularies with START, END and PAD at indices 0, 1, 2."""
        self.data = {
            "source_chars": [START_TOKEN, END_TOKEN, PAD_TOKEN],
            "target_chars": [START_TOKEN, END_TOKEN, PAD_TOKEN],
            "source_char_index": {START_TOKEN: 0, END_TOKEN: 1, PAD_TOKEN: 2},
            "source_index_char": {0: START_TOKEN, 1: END_TOKEN, 2: PAD_TOKEN},
            "target_char_index": {START_TOKEN: 0, END_TOKEN: 1, PAD_TOKEN: 2},
            "target_index_char": {0: START_TOKEN, 1: END_TOKEN, 2: PAD_TOKEN},
            "source_len": 3,
            "target_len": 3
        }

    def load_data(self, data_paths):
        """Load header-less CSV files.

        Args:
            data_paths: mapping of split name -> CSV path.

        Returns:
            ``(data_frames, data_pairs)`` where ``data_frames[split]`` is the
            raw DataFrame and ``data_pairs[split]`` is the tuple
            ``(column 0 array, column 1 array)``.
        """
        data_frames = {}
        data_pairs = {}

        for split, path in data_paths.items():
            # Every cell is a word. With pandas' default NA handling, words
            # such as "nan", "null", "NA" or "None" are parsed as float NaN and
            # later break string concatenation, so NA detection is disabled and
            # everything is read as str.
            df = pd.read_csv(path, header=None, dtype=str, keep_default_na=False)
            data_frames[split] = df
            data_pairs[split] = (df[0].to_numpy(), df[1].to_numpy())
            print(f"Loaded {split} data: {len(df)} examples")

        return data_frames, data_pairs

    def add_padding(self, sequences, max_length):
        """Wrap each string in START/END, then truncate or PAD-fill to ``max_length``."""
        padded_strings = []
        for seq in sequences:
            wrapped = START_TOKEN + seq + END_TOKEN
            # Truncation may cut off the END token of an over-long word.
            padded_strings.append(wrapped[:max_length].ljust(max_length, PAD_TOKEN))
        return padded_strings

    def chars_to_indices(self, string, char_index_dict):
        """Encode one string as a 1-D long tensor on ``device``.

        Characters missing from ``char_index_dict`` are mapped to the PAD index.
        """
        pad_idx = char_index_dict[PAD_TOKEN]
        char_indices = [char_index_dict.get(char, pad_idx) for char in string]
        return torch.tensor(char_indices, dtype=torch.long, device=device)

    def generate_sequence_from_string(self, strings, char_index_dict):
        """Encode several strings into one batch-first tensor, PAD-filled to equal length."""
        sequences = [self.chars_to_indices(string, char_index_dict) for string in strings]
        return pad_sequence(sequences, batch_first=True, padding_value=char_index_dict[PAD_TOKEN])

    @staticmethod
    def _extend_vocab(strings, chars, char_index, index_char):
        """Append every not-yet-seen character of ``strings`` to one vocabulary."""
        for string in strings:
            for c in string:
                if c not in char_index:
                    chars.append(c)
                    idx = len(chars) - 1
                    char_index[c] = idx
                    index_char[idx] = c

    def update_char_dictionaries(self, padded_source, padded_target):
        """Grow the source and target vocabularies with characters from the given strings."""
        self._extend_vocab(padded_source, self.data["source_chars"],
                           self.data["source_char_index"], self.data["source_index_char"])
        self._extend_vocab(padded_target, self.data["target_chars"],
                           self.data["target_char_index"], self.data["target_index_char"])

    def process_data(self, source_data, target_data):
        """Fit the vocabularies on the given pairs and encode them.

        Stores the raw data, the padded lengths (longest word + 2 for the
        START and END tokens), the encoded tensors and the vocabulary sizes in
        ``self.data`` and returns that dictionary.
        """
        # Store original data
        self.data["source_data"] = source_data
        self.data["target_data"] = target_data

        # Calculate max lengths (+2 for START and END tokens)
        self.data["INPUT_MAX_LENGTH"] = max(len(s) for s in source_data) + 2
        self.data["OUTPUT_MAX_LENGTH"] = max(len(t) for t in target_data) + 2

        print(f"Input max length: {self.data['INPUT_MAX_LENGTH']}")
        print(f"Output max length: {self.data['OUTPUT_MAX_LENGTH']}")

        # Add padding
        padded_source = self.add_padding(source_data, self.data["INPUT_MAX_LENGTH"])
        padded_target = self.add_padding(target_data, self.data["OUTPUT_MAX_LENGTH"])

        # Update character dictionaries
        self.update_char_dictionaries(padded_source, padded_target)

        # Generate sequences
        self.data["source_data_seq"] = self.generate_sequence_from_string(padded_source, self.data["source_char_index"])
        self.data["target_data_seq"] = self.generate_sequence_from_string(padded_target, self.data["target_char_index"])

        # Update lengths
        self.data["source_len"] = len(self.data["source_chars"])
        self.data["target_len"] = len(self.data["target_chars"])

        print(f"Source vocabulary size: {self.data['source_len']}")
        print(f"Target vocabulary size: {self.data['target_len']}")

        return self.data

    def process_validation(self, val_source, val_target):
        """Encode held-out pairs with the existing vocabularies and max lengths.

        The vocabularies are not extended: unseen characters become PAD and
        longer words are truncated. Returns ``(source_tensor, target_tensor)``.
        """
        padded_val_source = self.add_padding(val_source, self.data["INPUT_MAX_LENGTH"])
        padded_val_target = self.add_padding(val_target, self.data["OUTPUT_MAX_LENGTH"])

        val_source_seq = self.generate_sequence_from_string(padded_val_source, self.data["source_char_index"])
        val_target_seq = self.generate_sequence_from_string(padded_val_target, self.data["target_char_index"])

        return val_source_seq, val_target_seq

    def indices_to_string(self, indices, index_char_dict):
        """Decode indices to text, dropping START, END, PAD and unknown indices."""
        special = (PAD_TOKEN, START_TOKEN, END_TOKEN)
        chars = []
        for idx in indices:
            if isinstance(idx, torch.Tensor):
                idx = idx.item()
            if idx in index_char_dict:
                char = index_char_dict[idx]
                if char not in special:
                    chars.append(char)
        return "".join(chars)

    def prepare_input_for_prediction(self, input_string):
        """Encode one source word as a ``(1, INPUT_MAX_LENGTH)`` tensor on ``device``."""
        # Reuse the training-time padding and encoding so both paths stay in sync.
        padded_input = self.add_padding([input_string], self.data["INPUT_MAX_LENGTH"])[0]
        input_tensor = self.chars_to_indices(padded_input, self.data["source_char_index"])
        return input_tensor.unsqueeze(0)  # Add batch dimension

class TransliterationDataset(Dataset):
    """Pairs pre-encoded source and target sequences for use with a DataLoader."""

    def __init__(self, source_seq, target_seq):
        self.source_seq, self.target_seq = source_seq, target_seq

    def __len__(self):
        # The source side defines how many examples exist.
        return len(self.source_seq)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` pair stored at position ``idx``."""
        src_item = self.source_seq[idx]
        tgt_item = self.target_seq[idx]
        return src_item, tgt_item

class DataManager:
    """Loads the CSV splits and wraps the encoded data in DataLoaders."""

    def __init__(self, data_paths):
        self.data_paths = data_paths
        self.processor = DataProcessor()

    def load_all_data(self):
        """Read the train/valid/test splits and fit the processor on the training pairs.

        Returns the processor's data dictionary (also kept as ``self.data``).
        """
        pairs_by_split = self.processor.load_data(self.data_paths)[1]
        self.train_source, self.train_target = pairs_by_split["train"]
        self.val_source, self.val_target = pairs_by_split["valid"]
        self.test_source, self.test_target = pairs_by_split["test"]

        # Vocabularies and max lengths come from the training split only.
        self.data = self.processor.process_data(self.train_source, self.train_target)
        return self.data

    def _loader_for(self, dataset, h_params, shuffle):
        """Build a single-process, non-pinned DataLoader with the configured batch size."""
        return DataLoader(
            dataset,
            batch_size=h_params["batch_size"],
            shuffle=shuffle,
            num_workers=0,
            pin_memory=False
        )

    def create_dataloaders(self, h_params):
        """Return ``(train_loader, val_loader)``; only the training loader shuffles."""
        train_set = TransliterationDataset(self.data["source_data_seq"], self.data["target_data_seq"])

        # Validation pairs are encoded with the vocabularies fitted on training data.
        encoded_val = self.processor.process_validation(self.val_source, self.val_target)
        val_set = TransliterationDataset(encoded_val[0], encoded_val[1])

        train_loader = self._loader_for(train_set, h_params, shuffle=True)
        val_loader = self._loader_for(val_set, h_params, shuffle=False)
        return train_loader, val_loader

    def create_test_dataloader(self, h_params):
        """Return an unshuffled DataLoader over the encoded test split."""
        encoded_test = self.processor.process_validation(self.test_source, self.test_target)
        test_set = TransliterationDataset(encoded_test[0], encoded_test[1])
        return self._loader_for(test_set, h_params, shuffle=False)

# Class for beam search decoding
class BeamSearchNode:
    """One hypothesis step in beam search.

    Attributes:
        hidden_state: decoder hidden state after producing ``word_id``.
        previous_node: parent node, or ``None`` for the root.
        word_id: token produced at this step.
        log_prob: cumulative log-probability of the path ending here.
        length: number of nodes on the path, root included.
    """
    def __init__(self, hidden_state, previous_node, word_id, log_prob, length):
        self.hidden_state = hidden_state
        self.previous_node = previous_node
        self.word_id = word_id
        self.log_prob = log_prob
        self.length = length

    def eval(self, alpha=1.0):
        """Return the length-normalised score ``log_prob / (length - 1) ** alpha``."""
        # The 1e-6 keeps the root node (length 1) from dividing by zero.
        return self.log_prob / float(self.length - 1 + 1e-6) ** alpha

    def __lt__(self, other):
        """Order nodes so that the more probable path sorts first.

        Nodes are stored in ``heapq`` as ``(score, node)`` tuples; when two
        scores tie, Python compares the nodes themselves, which raises
        ``TypeError`` unless an ordering is defined.
        """
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return self.log_prob > other.log_prob

def get_chars_string_from_nodes(nodes, target_index_char):
    """Join the characters of a node sequence, skipping the special tokens.

    Args:
        nodes: iterable of objects with a ``word_id`` attribute, either a
            plain int or a one-element tensor (as produced by beam search).
        target_index_char: mapping of int index -> character.

    Returns:
        The concatenated characters of all nodes whose index is not 0, 1 or 2
        (START, END and PAD).
    """
    chars = []
    for n in nodes:
        word_id = n.word_id
        # A tensor never matches an int dictionary key, so unwrap it first.
        if isinstance(word_id, torch.Tensor):
            word_id = word_id.item()
        if word_id in (0, 1, 2):  # Exclude special tokens
            continue
        chars.append(target_index_char[word_id])
    return ''.join(chars)

# In [None]:
class RNNType:
    """Namespace of the supported recurrent cell names."""
    RNN = "RNN"
    LSTM = "LSTM"
    GRU = "GRU"

    @staticmethod
    def get_cell(cell_type):
        """Return the ``torch.nn`` recurrent class registered under ``cell_type``.

        Raises:
            ValueError: if ``cell_type`` is none of the supported names.
        """
        known_cells = (
            (RNNType.RNN, nn.RNN),
            (RNNType.LSTM, nn.LSTM),
            (RNNType.GRU, nn.GRU),
        )
        for label, module_cls in known_cells:
            if cell_type == label:
                return module_cls
        raise ValueError(f"Unsupported cell type: {cell_type}")

class Encoder(nn.Module):
    """Embeds source character indices and runs them through a recurrent stack."""
    def __init__(self, h_params, data):
        super().__init__()
        self.h_params = h_params

        embed_dim = h_params["char_embed_dim"]
        layer_count = h_params["num_layers"]
        drop_p = h_params["dropout"]

        # One embedding row per source-vocabulary character
        self.embedding = nn.Embedding(data["source_len"], embed_dim)

        # Recurrent stack; inter-layer dropout is only passed when there is
        # more than one layer
        rnn_cls = RNNType.get_cell(h_params["cell_type"])
        self.rnn_cell = rnn_cls(
            input_size=embed_dim,
            hidden_size=h_params["hidden_size"],
            num_layers=layer_count,
            dropout=drop_p if layer_count > 1 else 0,
            batch_first=True
        )

        # Dropout applied to the embeddings
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x, hidden=None):
        """Embed ``x``, apply dropout and return the RNN's ``(output, hidden)``."""
        tokens = self.embedding(x)
        tokens = self.dropout(tokens)
        return self.rnn_cell(tokens, hidden)

    def init_hidden(self, batch_size):
        """Return an all-zero initial state on ``device``.

        The state is an ``(h, c)`` tuple for LSTM cells and a single tensor
        otherwise; each tensor has shape ``(num_layers, batch_size, hidden_size)``.
        """
        state_shape = (
            self.h_params["num_layers"],
            batch_size,
            self.h_params["hidden_size"],
        )
        if self.h_params["cell_type"] != RNNType.LSTM:
            return torch.zeros(state_shape, device=device)

        h0 = torch.zeros(state_shape, device=device)
        c0 = torch.zeros(state_shape, device=device)
        return (h0, c0)

class Decoder(nn.Module):
    """Decoder model for sequence-to-sequence learning.

    Besides the training-time ``forward`` step it offers two inference
    routines, ``greedy_decode`` and ``beam_search_decode``, which both decode
    a single example (batch size 1) into a string.
    """
    def __init__(self, h_params, data):
        super(Decoder, self).__init__()
        self.h_params = h_params
        self.data = data

        # Embedding layer
        self.embedding = nn.Embedding(data["target_len"], h_params["char_embed_dim"])

        # RNN layer
        self.rnn_cell = RNNType.get_cell(h_params["cell_type"])(
            input_size=h_params["char_embed_dim"],
            hidden_size=h_params["hidden_size"],
            num_layers=h_params["num_layers"],
            dropout=h_params["dropout"] if h_params["num_layers"] > 1 else 0,
            batch_first=True
        )

        # Output layer
        self.fc = nn.Linear(h_params["hidden_size"], data["target_len"])

        # Dropout
        self.dropout = nn.Dropout(h_params["dropout"])

    def forward(self, x, hidden):
        """Run the decoder on token indices ``x`` starting from ``hidden``.

        ``x`` may be a scalar, a ``(batch,)`` vector or a ``(batch, steps)``
        matrix; the first two are reshaped to ``(batch, 1)``.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` is the log-softmax over
            the target vocabulary (last dimension).
        """
        # Handle different input dimensions
        if x.dim() == 0:  # scalar
            x = x.unsqueeze(0).unsqueeze(0)  # (1, 1)
        elif x.dim() == 1:  # (batch,)
            x = x.unsqueeze(1)  # (batch, 1)

        # Embed input
        embedded = self.dropout(self.embedding(x))
        embedded = F.relu(embedded)

        # Pass through RNN
        output, hidden = self.rnn_cell(embedded, hidden)

        # Get prediction
        prediction = self.fc(output)

        return F.log_softmax(prediction, dim=2), hidden

    def _special_indices(self):
        """Return the target-vocabulary indices of ``(START, END, PAD)``."""
        char_index = self.data["target_char_index"]
        return char_index[START_TOKEN], char_index[END_TOKEN], char_index[PAD_TOKEN]

    def greedy_decode(self, encoder_outputs, encoder_hidden, processor, max_length=None):
        """Decode one example by always taking the most probable next token.

        Args:
            encoder_outputs: unused; kept for interface compatibility.
            encoder_hidden: encoder final state, used as the initial decoder state.
            processor: unused; kept for interface compatibility.
            max_length: maximum number of decoding steps
                (defaults to ``OUTPUT_MAX_LENGTH``).

        Returns:
            The decoded string. Decoding stops at the first END token; START
            and PAD predictions are fed back to the decoder but never written
            to the output.
        """
        start_idx, end_idx, pad_idx = self._special_indices()
        index_char = self.data["target_index_char"]

        if max_length is None:
            max_length = self.data["OUTPUT_MAX_LENGTH"]

        decoder_input = torch.tensor([start_idx], device=device)
        decoder_hidden = encoder_hidden
        output_tokens = []

        # Inference only: no autograd graph is needed across decoding steps.
        with torch.no_grad():
            for _ in range(max_length):
                decoder_output, decoder_hidden = self(decoder_input, decoder_hidden)

                # Get best prediction as a one-element vector
                _, topi = decoder_output.topk(1)
                predicted = topi.view(-1)
                predicted_idx = predicted.item()

                # Stop if END token
                if predicted_idx == end_idx:
                    break

                # Target sequences are encoded with a leading START token and
                # trailing PAD, so the decoder may predict them; they are not
                # part of the transliteration.
                if predicted_idx not in (start_idx, pad_idx) and predicted_idx in index_char:
                    output_tokens.append(index_char[predicted_idx])

                # Next input is predicted token
                decoder_input = predicted

        return ''.join(output_tokens)

    def beam_search_decode(self, encoder_outputs, encoder_hidden, processor, beam_width=5, max_length=None):
        """Decode one example with a best-first beam search.

        Nodes live in a priority queue keyed by the negated length-normalised
        score. Each iteration pops the best node and pushes its ``beam_width``
        most probable continuations; at most ``max_length`` nodes are popped.

        Args:
            encoder_outputs: unused; kept for interface compatibility.
            encoder_hidden: encoder final state, used as the initial decoder state.
            processor: unused; kept for interface compatibility.
            beam_width: number of continuations kept per expansion and number
                of finished hypotheses that ends the search early.
            max_length: maximum number of queue pops
                (defaults to ``OUTPUT_MAX_LENGTH``).

        Returns:
            The string of the best finished hypothesis, or of the best open
            hypothesis if none finished. START, END and PAD are left out.
        """
        if max_length is None:
            max_length = self.data["OUTPUT_MAX_LENGTH"]

        start_idx, end_idx, pad_idx = self._special_indices()
        index_char = self.data["target_index_char"]

        # Start with START_TOKEN
        root = BeamSearchNode(encoder_hidden, None,
                              torch.tensor([start_idx], device=device), 0, 1)

        # Heap entries are (score, insertion order, node). The unique
        # insertion counter breaks score ties so heapq never has to compare
        # two nodes.
        nodes = [(-root.eval(), 0, root)]
        pushed = 1

        # Finished hypotheses as (score, node)
        end_nodes = []

        with torch.no_grad():
            for _ in range(max_length):
                if not nodes:
                    break

                # Fetch the best node
                score, _order, current_node = heapq.heappop(nodes)

                # If we reached the END token, the hypothesis is complete
                if current_node.word_id.item() == end_idx and current_node.previous_node is not None:
                    end_nodes.append((score, current_node))
                    if len(end_nodes) >= beam_width:
                        break
                    continue

                # Decode for one step
                decoder_output, decoder_hidden = self(current_node.word_id, current_node.hidden_state)

                # Flatten (1, 1, vocab) -> (vocab,) and take the best tokens;
                # topk cannot return more entries than the vocabulary holds.
                flat_log_probs = decoder_output.reshape(-1)
                k = min(beam_width, flat_log_probs.size(0))
                log_probs, indexes = torch.topk(flat_log_probs, k)

                # Put them into the queue
                for log_p, token in zip(log_probs.tolist(), indexes):
                    child = BeamSearchNode(decoder_hidden, current_node, token.view(1),
                                           current_node.log_prob + log_p, current_node.length + 1)
                    heapq.heappush(nodes, (-child.eval(), pushed, child))
                    pushed += 1

        # If nothing finished, fall back to the best open hypotheses
        if not end_nodes:
            for _ in range(min(beam_width, len(nodes))):
                score, _order, open_node = heapq.heappop(nodes)
                end_nodes.append((score, open_node))

        if not end_nodes:
            return ""

        # Lowest (negated) score is the best sequence
        best_node = min(end_nodes, key=lambda entry: entry[0])[1]

        # Traverse back to the root to recover the token sequence
        sequence = []
        current = best_node
        while current.previous_node is not None:
            sequence.append(current.word_id.item())
            current = current.previous_node

        # Reverse into generation order and drop the special tokens
        skipped = (start_idx, end_idx, pad_idx)
        return ''.join(index_char[idx] for idx in reversed(sequence)
                       if idx in index_char and idx not in skipped)

class Seq2SeqModel:
    """Wrapper that owns an Encoder/Decoder pair and runs them end to end."""
    def __init__(self, h_params, data):
        self.encoder = Encoder(h_params, data).to(device)
        self.decoder = Decoder(h_params, data).to(device)
        self.h_params = h_params
        self.data = data

    def forward(self, source, target=None, teacher_forcing_ratio=0.5):
        """Encode ``source`` and, if ``target`` is given, decode step by step.

        Args:
            source: batch of encoded source sequences (batch first).
            target: batch of encoded target sequences, or ``None`` to run the
                encoder only.
            teacher_forcing_ratio: probability that this whole batch is
                decoded with ground-truth inputs instead of the model's own
                predictions.

        Returns:
            ``(encoder_outputs, encoder_hidden)`` when ``target`` is ``None``;
            otherwise ``(decoder_outputs, encoder_hidden)`` where
            ``decoder_outputs`` concatenates the per-step log-probabilities
            along dimension 1 (``OUTPUT_MAX_LENGTH`` steps).
        """
        batch_size = source.size(0)

        # Initialize encoder hidden state
        encoder_hidden = self.encoder.init_hidden(batch_size)

        # Encoder forward pass
        encoder_outputs, encoder_hidden = self.encoder(source, encoder_hidden)

        # Return early if no target (for inference)
        if target is None:
            return encoder_outputs, encoder_hidden

        # Prepare decoder input and hidden state
        decoder_input = torch.tensor([self.data["target_char_index"][START_TOKEN]] * batch_size, device=device)
        decoder_hidden = encoder_hidden

        # Teacher forcing is decided once per batch
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        # Storage for decoder outputs
        decoder_outputs = []

        # Decoder forward pass (one character at a time)
        for t in range(self.data["OUTPUT_MAX_LENGTH"]):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if use_teacher_forcing:
                # Output step t is scored against target[:, t], so the input
                # for the next step is that same ground-truth token. Feeding
                # target[:, t + 1] would hand the decoder the very label it
                # is asked to predict next.
                decoder_input = target[:, t]
            else:
                # Next input is the model's own prediction
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        # Concatenate the per-step outputs along the time dimension
        decoder_outputs = torch.cat(decoder_outputs, dim=1)

        return decoder_outputs, encoder_hidden

    def predict(self, input_string, processor, use_beam_search=False, beam_width=5):
        """Transliterate a single input string.

        Args:
            input_string: raw source word.
            processor: object providing ``prepare_input_for_prediction``.
            use_beam_search: use beam search instead of greedy decoding.
            beam_width: beam width passed to the beam search decoder.

        Returns:
            The decoded string.
        """
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            # Prepare input
            input_tensor = processor.prepare_input_for_prediction(input_string)

            # Encoder forward pass from a zero state
            encoder_hidden = self.encoder.init_hidden(1)
            encoder_outputs, encoder_hidden = self.encoder(input_tensor, encoder_hidden)

            # Decode
            if use_beam_search:
                prediction = self.decoder.beam_search_decode(encoder_outputs, encoder_hidden, processor, beam_width)
            else:
                prediction = self.decoder.greedy_decode(encoder_outputs, encoder_hidden, processor)

        return prediction

# In [None]:
class Trainer:
    """Class for training sequence-to-sequence models.

    Owns one optimizer and one ``ReduceLROnPlateau`` scheduler per sub-model
    (encoder, decoder), an ``NLLLoss`` criterion and the best validation
    accuracy seen so far.
    """
    def __init__(self, model, h_params, data, train_loader, val_loader):
        self.model = model
        self.h_params = h_params
        self.data = data
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Initialize optimizers; any name other than "adam"/"nadam" falls back to SGD
        optimizer_cls = {
            "adam": optim.Adam,
            "nadam": optim.NAdam,
        }.get(h_params["optimizer"].lower(), optim.SGD)
        self.encoder_optimizer = optimizer_cls(model.encoder.parameters(), lr=h_params["learning_rate"])
        self.decoder_optimizer = optimizer_cls(model.decoder.parameters(), lr=h_params["learning_rate"])

        # Loss function (the decoder emits log-probabilities)
        self.criterion = nn.NLLLoss()

        # Learning rate schedulers
        self.encoder_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.encoder_optimizer, 'min', patience=2)
        self.decoder_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.decoder_optimizer, 'min', patience=2)

        # Best model tracking
        self.best_val_accuracy = 0

    def _loss_and_correct(self, decoder_outputs, target):
        """Return ``(loss, correct)`` for one batch.

        ``loss`` is the sum over the ``OUTPUT_MAX_LENGTH`` steps of the
        criterion applied to each step; ``correct`` counts the sequences whose
        predicted tokens match ``target`` at every position.
        """
        loss = 0
        for t in range(self.data["OUTPUT_MAX_LENGTH"]):
            loss += self.criterion(decoder_outputs[:, t, :], target[:, t])

        # squeeze(-1) drops only the top-k dimension, so a batch with a single
        # example keeps its (batch, seq_len) shape.
        _, topi = decoder_outputs.topk(1)
        predictions = topi.squeeze(-1)
        correct = (predictions == target).all(dim=1).sum().item()
        return loss, correct

    def _log_to_wandb(self, payload):
        """Best-effort ``wandb.log``: skipped without an active run, never raises."""
        if wandb.run is None:
            return
        try:
            wandb.log(payload)
        except Exception as e:
            print(f"Warning: Could not log to wandb: {e}")

    def train_epoch(self):
        """Train for one epoch.

        Returns:
            ``(mean per-step loss over the processed batches, exact-match
            accuracy)``, or ``(0, 0)`` if no batch could be processed.
        """
        self.model.encoder.train()
        self.model.decoder.train()

        epoch_loss = 0
        total_correct = 0
        total_examples = 0
        batches_done = 0

        for batch_idx, (source, target) in enumerate(self.train_loader):
            batch_size = source.size(0)

            # Zero gradients
            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            try:
                # Forward pass
                decoder_outputs, _ = self.model.forward(source, target, TEACHER_FORCING_RATIO)
                loss, correct = self._loss_and_correct(decoder_outputs, target)

                # Backpropagation
                loss.backward()

                # Clip gradients to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.encoder.parameters(), 1)
                torch.nn.utils.clip_grad_norm_(self.model.decoder.parameters(), 1)

                # Update parameters
                self.encoder_optimizer.step()
                self.decoder_optimizer.step()
            except Exception as e:
                # Deliberate best-effort: report the failing batch and move on.
                print(f"Error in batch {batch_idx}: {e}")
                traceback.print_exc()
                continue

            # Only batches that completed an update contribute to the metrics
            epoch_loss += loss.item() / self.data["OUTPUT_MAX_LENGTH"]
            total_correct += correct
            total_examples += batch_size
            batches_done += 1

            # Run garbage collection periodically (emptying the CUDA cache on
            # every batch forces needless re-allocation)
            if batch_idx % 10 == 0:
                gc.collect()
                torch.cuda.empty_cache()

            # Print progress periodically
            if (batch_idx + 1) % 100 == 0:
                print(f"Batch {batch_idx + 1}/{len(self.train_loader)}, Loss: {epoch_loss/batches_done:.4f}")

        if total_examples == 0:
            return 0, 0  # Avoid division by zero if every batch failed

        return epoch_loss / batches_done, total_correct / total_examples

    def evaluate(self, data_loader):
        """Evaluate model on validation or test data.

        The decoder runs on its own predictions (no teacher forcing), so the
        metrics reflect what inference produces.

        Returns:
            ``(exact-match accuracy, mean per-step loss)`` -- note the order is
            the reverse of ``train_epoch`` -- or ``(0, 0)`` if no batch could
            be processed.
        """
        self.model.encoder.eval()
        self.model.decoder.eval()

        epoch_loss = 0
        total_correct = 0
        total_examples = 0
        batches_done = 0

        with torch.no_grad():
            for batch_idx, (source, target) in enumerate(data_loader):
                batch_size = source.size(0)

                try:
                    decoder_outputs, _ = self.model.forward(source, target, teacher_forcing_ratio=0.0)
                    loss, correct = self._loss_and_correct(decoder_outputs, target)
                except Exception as e:
                    print(f"Error in evaluation batch {batch_idx}: {e}")
                    continue

                # Update totals
                total_correct += correct
                total_examples += batch_size
                epoch_loss += loss.item() / self.data["OUTPUT_MAX_LENGTH"]
                batches_done += 1

                # Run garbage collection periodically
                if batch_idx % 10 == 0:
                    gc.collect()
                    torch.cuda.empty_cache()

        if total_examples == 0:
            return 0, 0  # Avoid division by zero if every batch failed

        return total_correct / total_examples, epoch_loss / batches_done

    def train(self, test_source=None, test_target=None):
        """Train the model for ``h_params["epochs"]`` epochs.

        Each epoch trains, evaluates on the validation loader, steps both
        schedulers on the validation loss, logs to wandb and saves a
        checkpoint when validation accuracy improves. If both ``test_source``
        and ``test_target`` are given, ``test`` runs after the last epoch.
        """
        # WandB initialization
        run_name = f"{self.h_params['cell_type']}_{self.h_params['optimizer']}_noattn_layers{self.h_params['num_layers']}"
        try:
            wandb.init(project="Tamil-Transliteration-NoAttn", name=run_name, config=self.h_params)
        except Exception as e:
            print(f"Error initializing wandb: {e}")
            print("Continuing without wandb tracking...")

        for epoch in range(self.h_params["epochs"]):
            print(f"\nEpoch {epoch+1}/{self.h_params['epochs']}")
            print("-" * 30)

            # Train for one epoch
            train_loss, train_acc = self.train_epoch()

            # Evaluate on validation set
            val_acc, val_loss = self.evaluate(self.val_loader)

            # Update learning rate
            self.encoder_scheduler.step(val_loss)
            self.decoder_scheduler.step(val_loss)

            # Log metrics
            current_lr = self.encoder_optimizer.param_groups[0]['lr']
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            print(f"Learning Rate: {current_lr:.6f}")

            self._log_to_wandb({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "learning_rate": current_lr
            })

            # Save best model
            if val_acc > self.best_val_accuracy:
                self.best_val_accuracy = val_acc
                self.save_model(f'best_model_{run_name}.pt', epoch, val_acc)

            # Run garbage collection between epochs
            gc.collect()
            torch.cuda.empty_cache()

        # Test on test set if provided
        if test_source is not None and test_target is not None:
            self.test(test_source, test_target)

        try:
            wandb.finish()
        except Exception as e:
            print(f"Warning: Could not finish wandb run: {e}")

    def save_model(self, filename, epoch, val_accuracy):
        """Save model and optimizer state dicts to ``filename``.

        Failures are reported on stdout and do not interrupt training.
        """
        try:
            torch.save({
                'epoch': epoch,
                'encoder_state_dict': self.model.encoder.state_dict(),
                'decoder_state_dict': self.model.decoder.state_dict(),
                'encoder_optimizer': self.encoder_optimizer.state_dict(),
                'decoder_optimizer': self.decoder_optimizer.state_dict(),
                'val_accuracy': val_accuracy,
            }, filename)
            print(f"Model saved with validation accuracy: {val_accuracy:.4f}")
        except Exception as e:
            print(f"Error saving model: {e}")

    def test(self, test_source, test_target):
        """Test model on test set.

        Computes batch metrics with ``evaluate``, then generates per-example
        predictions through ``Evaluator`` and logs everything to wandb. Any
        failure is printed with its traceback instead of being raised.
        """
        try:
            print("Testing model on test set...")

            # Create DataProcessor for test data, sharing a copy of the fitted
            # vocabularies and lengths
            processor = DataProcessor()
            processor.data = copy.deepcopy(self.data)

            # Process test data
            test_source_seq, test_target_seq = processor.process_validation(test_source, test_target)
            test_dataset = TransliterationDataset(test_source_seq, test_target_seq)
            test_loader = DataLoader(
                test_dataset,
                batch_size=self.h_params["batch_size"],
                shuffle=False,
                num_workers=0,
                pin_memory=False
            )

            # Evaluate on test set
            test_acc, test_loss = self.evaluate(test_loader)
            print(f"Test Accuracy: {test_acc:.4f} | Test Loss: {test_loss:.4f}")

            # Log test metrics to wandb
            self._log_to_wandb({
                "test_loss": test_loss,
                "test_accuracy": test_acc
            })

            # Generate predictions and compare with targets
            evaluator = Evaluator(self.model, processor)
            csv_path, results, csv_accuracy = evaluator.generate_predictions_csv(test_source, test_target)

            # Log final metrics to wandb
            self._log_to_wandb({
                "final_test_accuracy": test_acc,
                "final_test_loss": test_loss,
                "csv_accuracy": csv_accuracy
            })

            try:
                # Create a wandb Table for sample predictions (first 100 rows)
                prediction_table = wandb.Table(columns=["Input", "Target", "Prediction", "Correct"])
                for row in results[:100]:
                    correct = row['Prediction'] == row['Target']
                    prediction_table.add_data(row['Input'], row['Target'],
                                              row['Prediction'], correct)
            except Exception as e:
                print(f"Warning: Could not log to wandb: {e}")
            else:
                self._log_to_wandb({"prediction_samples": prediction_table})

            # Display sample predictions
            evaluator.visualize_samples(test_source, test_target, num_samples=10)

        except Exception as e:
            print(f"Error in testing: {e}")
            traceback.print_exc()
class Evaluator:
    """Evaluate a transliteration model on (source, target) word pairs.

    Wraps a model and its data processor and offers single-word prediction,
    bulk prediction with CSV export, and a readable display of sample
    predictions (HTML in a notebook, plain text otherwise).
    """

    def __init__(self, model, processor):
        """Keep the model and processor used for every prediction.

        Args:
            model: Object exposing ``predict(input_string, processor,
                use_beam_search, beam_width)``.
            processor: Data processor; its ``data`` attribute is also stored
                on ``self.data``.
        """
        self.model = model
        self.processor = processor
        self.data = processor.data

    def predict(self, input_string, use_beam_search=False, beam_width=5):
        """Return the model's prediction for a single input string.

        Args:
            input_string: Source word to transliterate.
            use_beam_search: Forwarded to ``model.predict``.
            beam_width: Forwarded to ``model.predict``.
        """
        return self.model.predict(input_string, self.processor, use_beam_search, beam_width)

    def generate_predictions_csv(self, test_source, test_target, use_beam_search=False,
                                 csv_path='tamil_transliteration_predictions.csv'):
        """Predict every test example, report exact-match accuracy, save a CSV.

        Besides the CSV, a small ``encoding_info.txt`` summary is written in
        the same directory as the CSV (the working directory for the default
        path).

        Args:
            test_source: Indexable collection of input strings.
            test_target: Indexable collection of reference strings, aligned
                with ``test_source``.
            use_beam_search: Forwarded to ``predict``.
            csv_path: Destination of the predictions CSV.

        Returns:
            Tuple ``(csv_path, results, accuracy)`` where ``results`` is a list
            of dicts with keys ``'Input'``, ``'Prediction'`` and ``'Target'``.
        """
        total = len(test_source)
        print(f"Generating predictions for {total} test examples...")

        results = []
        # Index-based access is kept on purpose so that any container
        # supporting len() and integer indexing keeps working.
        for i in range(total):
            if i % 1000 == 0:
                print(f"Processing test example {i+1}/{total}")

            input_str = test_source[i]
            results.append({
                'Input': input_str,
                'Prediction': self.predict(input_str, use_beam_search),
                'Target': test_target[i]
            })

        # Exact string match between prediction and target
        correct = sum(1 for r in results if r['Prediction'] == r['Target'])
        accuracy = correct / len(results) if results else 0
        print(f"Test Accuracy: {accuracy:.4f} ({correct}/{len(results)})")

        # Logging is best-effort: a wandb failure must not lose the predictions
        try:
            wandb.log({"CSV_Test_Accuracy": accuracy})
        except Exception as e:
            print(f"Warning: Could not log to wandb: {e}")

        # Save to CSV
        with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['Input', 'Prediction', 'Target'])
            writer.writeheader()
            writer.writerows(results)

        print(f"Saved predictions to {csv_path}")

        # Also save some metadata to help with file encoding issues
        info_lines = [
            "File Encoding: UTF-8\n",
            f"Total Samples: {len(results)}\n",
            f"Correct Predictions: {correct}\n",
            f"Accuracy: {accuracy:.4f}\n\n",
            "Sample Predictions (for debugging):\n",
        ]
        for row in results[:5]:
            info_lines.append(f"Input: {row['Input']}\n")
            info_lines.append(f"Prediction: {row['Prediction']}\n")
            info_lines.append(f"Target: {row['Target']}\n")
            info_lines.append("-" * 30 + "\n")

        info_path = os.path.join(os.path.dirname(csv_path), 'encoding_info.txt')
        with open(info_path, 'w', encoding='utf-8') as f:
            f.writelines(info_lines)

        return csv_path, results, accuracy

    def visualize_samples(self, test_source, test_target, num_samples=10, use_beam_search=False):
        """Predict a random subset of the test data and display the results.

        Args:
            test_source: Indexable collection of input strings.
            test_target: Indexable collection of reference strings.
            num_samples: Upper bound on how many examples are drawn; clamped
                to the number of available examples.
            use_beam_search: Forwarded to ``predict``.

        Returns:
            The statistics dict produced by ``display_prediction_results``.
        """
        sample_count = min(num_samples, len(test_source))
        samples_indices = random.sample(range(len(test_source)), sample_count)

        input_strings = []
        predictions = []
        targets = []

        for idx in samples_indices:
            input_str = test_source[idx]
            input_strings.append(input_str)
            predictions.append(self.predict(input_str, use_beam_search))
            targets.append(test_target[idx])

        return self.display_prediction_results(input_strings, predictions, targets)

    def display_prediction_results(self, input_strings, predictions, targets, num_examples=None):
        """Show predictions next to their targets and return summary statistics.

        The results are rendered as HTML through IPython's ``display``; if
        that fails for any reason the same information is printed as plain
        text.

        Args:
            input_strings: Input strings, aligned with the other two lists.
            predictions: Predicted strings.
            targets: Reference strings.
            num_examples: Maximum number of examples to show; ``None`` shows
                all of them.

        Returns:
            Dict with keys ``'accuracy'``, ``'correct'`` and ``'total'``,
            computed over all predictions (not only the displayed ones).
        """
        from html import escape

        if num_examples is None:
            num_examples = len(input_strings)

        # Limit examples
        n = min(num_examples, len(input_strings))
        rows = [(input_strings[i], predictions[i], targets[i]) for i in range(n)]

        # Create HTML with Tamil web font
        page_parts = ["""
        <html>
        <head>
            <meta charset="UTF-8">
            <style>
                @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Tamil:wght@400;700&display=swap');
                body {
                    font-family: 'Noto Sans Tamil', sans-serif;
                    font-size: 16px;
                    line-height: 1.6;
                }
                .result {
                    margin-bottom: 20px;
                    padding: 15px;
                    border: 1px solid #ddd;
                    border-radius: 5px;
                }
                .correct { color: green; }
                .incorrect { color: red; }
                .input { font-weight: bold; }
                .unicode { color: #888; font-size: 12px; font-family: monospace; }
            </style>
        </head>
        <body>
            <h2>Transliteration Results</h2>
        """]

        for input_str, pred, target in rows:
            status_class = "correct" if pred == target else "incorrect"
            codepoints = ' '.join(f'U+{ord(c):04X}' for c in pred)

            # Escape the text: the special tokens '<' and '>' (or any other
            # markup character) must be shown literally, not parsed as HTML.
            page_parts.append(f"""
            <div class="result">
                <div class="input">Input: {escape(str(input_str))}</div>
                <div class="{status_class}">Prediction: {escape(str(pred))}</div>
                <div>Target: {escape(str(target))}</div>
                <div class="unicode">Unicode (Pred): {codepoints}</div>
            </div>
            """)

        page_parts.append("""
        </body>
        </html>
        """)

        # Display the HTML in the notebook
        try:
            display(HTML("".join(page_parts)))
        except Exception:
            # Fallback if IPython display is not available
            print("HTML output not available, showing plain text results:")
            for i, (input_str, pred, target) in enumerate(rows):
                status = "✓" if pred == target else "✗"
                print(f"Sample {i+1} {status}")
                print(f"Input: {input_str}")
                print(f"Prediction: {pred}")
                print(f"Target: {target}")
                print("-" * 30)

        # Return statistics
        correct_count = sum(1 for pred, target in zip(predictions, targets) if pred == target)
        return {
            "accuracy": correct_count / len(predictions) if predictions else 0,
            "correct": correct_count,
            "total": len(predictions)
        }

In [None]:
class HyperparameterTuner:
    """Tune hyperparameters with wandb sweeps.

    Each sweep run reads its hyperparameters from ``wandb.config``, builds a
    ``Seq2SeqModel`` and trains it with ``Trainer``.
    """

    def __init__(self, data_manager):
        """Store the data manager that supplies data and dataloaders.

        Args:
            data_manager: Object providing ``load_all_data()``,
                ``create_dataloaders(h_params)``, ``test_source`` and
                ``test_target``.
        """
        self.data_manager = data_manager

    def get_sweep_config(self):
        """Return the wandb sweep configuration.

        The sweep uses the ``bayes`` method and maximizes the metric named
        ``val_accuracy``.
        NOTE(review): this relies on the training code logging a metric with
        exactly that name - confirm against Trainer.
        """
        sweep_config = {
            'method': 'bayes',
            'name': 'tamil-transliteration-sweep',
            'metric': {
                'goal': 'maximize',
                'name': 'val_accuracy'
            },
            'parameters': {
                'learning_rate': {
                    'values': [0.001, 0.0005, 0.0001]
                },
                'batch_size': {
                    'values': [32, 64, 128, 256]
                },
                'char_embed_dim': {
                    'values': [64, 128, 256]
                },
                'hidden_size': {
                    'values': [128, 256, 512]
                },
                'num_layers': {
                    'values': [1, 2, 3, 4]
                },
                'cell_type': {
                    'values': [RNNType.RNN, RNNType.LSTM, RNNType.GRU]
                },
                'dropout': {
                    'values': [0.0, 0.1, 0.2, 0.3]
                },
                'optimizer': {
                    'values': ['adam', 'nadam']
                },
                'epochs': {
                    'values': [10, 15, 20]
                }
            }
        }
        return sweep_config

    def sweep_agent(self):
        """Run one sweep trial: build, train and evaluate a single model.

        Errors raised while loading data, building the model or training are
        printed (with traceback) and swallowed so the sweep can continue with
        the next trial. The wandb run is always closed afterwards.
        """
        wandb.init()

        # Access sweep config
        config = wandb.config

        # Define hyperparameters
        h_params = {
            "char_embed_dim": config.char_embed_dim,
            "hidden_size": config.hidden_size,
            "batch_size": config.batch_size,
            "num_layers": config.num_layers,
            "learning_rate": config.learning_rate,
            "epochs": config.epochs,
            "cell_type": config.cell_type,
            "dropout": config.dropout,
            "optimizer": config.optimizer
        }

        try:
            # Process data
            data = self.data_manager.load_all_data()

            # Create dataloaders
            train_loader, val_loader = self.data_manager.create_dataloaders(h_params)

            # Create model
            model = Seq2SeqModel(h_params, data)

            # Train model
            trainer = Trainer(model, h_params, data, train_loader, val_loader)
            trainer.train(self.data_manager.test_source, self.data_manager.test_target)

        except Exception as e:
            print(f"Error in sweep agent: {e}")
            traceback.print_exc()

        finally:
            # Close the run on every exit path, including interrupts, so a
            # failed trial never leaves a run open for the next one.
            wandb.finish()

    def run_sweep(self, count=10, project="Tamil-Transliteration-NoAttn"):
        """Create a sweep and run ``count`` trials of ``sweep_agent``.

        Args:
            count: Number of trials passed to ``wandb.agent``.
            project: wandb project the sweep is created in.
        """
        sweep_config = self.get_sweep_config()
        sweep_id = wandb.sweep(sweep=sweep_config, project=project)
        wandb.agent(sweep_id, function=self.sweep_agent, count=count)

In [None]:
class Experiment:
    """Run a training experiment, either a single run or a wandb sweep."""

    def __init__(self):
        """Set the hyperparameters used when no sweep is requested."""
        # Default hyperparameters
        self.default_hyperparams = {
            "char_embed_dim": 256,
            "hidden_size": 256,
            "batch_size": 256,
            "num_layers": 2,
            "learning_rate": 0.001,
            "epochs": 15,
            "cell_type": RNNType.GRU,
            "dropout": 0.2,
            "optimizer": "nadam"
        }

    def run(self, use_sweep=False, num_sweep_runs=10, use_beam_search=False):
        """Run the experiment.

        Any exception is printed with its traceback instead of propagating.

        Args:
            use_sweep: If True, run a hyperparameter sweep through
                ``HyperparameterTuner``; otherwise train once with
                ``self.default_hyperparams``.
            num_sweep_runs: Number of sweep trials (only used with
                ``use_sweep``).
            use_beam_search: If True, after training compare greedy and beam
                search predictions on a few random test examples.
        """
        try:
            # Initial CUDA memory reset
            torch.cuda.empty_cache()
            gc.collect()

            # Create data manager
            data_manager = DataManager(DATA_PATHS)

            if use_sweep:
                # Run hyperparameter sweep
                HyperparameterTuner(data_manager).run_sweep(count=num_sweep_runs)
                return

            h_params = self.default_hyperparams

            # Load data
            data = data_manager.load_all_data()

            # Create dataloaders
            train_loader, val_loader = data_manager.create_dataloaders(h_params)

            # Create model
            model = Seq2SeqModel(h_params, data)

            # Initialize wandb (best-effort: training proceeds without it)
            run_name = f"{h_params['cell_type']}_{h_params['optimizer']}_noattn_layers{h_params['num_layers']}"
            try:
                wandb.init(project="Tamil-Transliteration-NoAttn", name=run_name, config=h_params)
            except Exception as e:
                print(f"Error initializing wandb: {e}")

            try:
                # Train model
                print(f"Starting training with {h_params['cell_type']} cell, {h_params['optimizer']} optimizer (No Attention)")
                trainer = Trainer(model, h_params, data, train_loader, val_loader)
                trainer.train(data_manager.test_source, data_manager.test_target)

                # Test with beam search if requested
                if use_beam_search:
                    self._compare_decoding(model, data,
                                           data_manager.test_source, data_manager.test_target)
            finally:
                # Close the wandb run even when training or evaluation failed
                try:
                    wandb.finish()
                except Exception as e:
                    print(f"Warning: Could not close wandb run: {e}")

        except Exception as e:
            print(f"Error in experiment: {e}")
            traceback.print_exc()

    def _compare_decoding(self, model, data, test_source, test_target, num_samples=10):
        """Display greedy vs beam search predictions for random test examples."""
        print("\nTesting with beam search...")
        processor = DataProcessor()
        processor.data = data
        evaluator = Evaluator(model, processor)

        # Compare greedy vs beam search predictions
        print("Comparing greedy vs beam search predictions...")
        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        sample_count = min(num_samples, len(test_source))
        samples_indices = random.sample(range(len(test_source)), sample_count)

        inputs = []
        targets = []
        greedy_preds = []
        beam_preds = []

        for idx in samples_indices:
            input_str = test_source[idx]
            inputs.append(input_str)
            targets.append(test_target[idx])

            # Get predictions
            greedy_preds.append(evaluator.predict(input_str, use_beam_search=False))
            beam_preds.append(evaluator.predict(input_str, use_beam_search=True, beam_width=5))

        # Display comparison
        print("\nGreedy Search Results:")
        evaluator.display_prediction_results(inputs, greedy_preds, targets)

        print("\nBeam Search Results:")
        evaluator.display_prediction_results(inputs, beam_preds, targets)

# Run experiment
def main():
    """Script entry point: launch one experiment with the default settings.

    Flags passed to ``Experiment.run``:
      * ``use_sweep`` - True runs a hyperparameter sweep, False runs with the
        default hyperparameters.
      * ``use_beam_search`` - True also tests beam search decoding.
    """
    Experiment().run(use_sweep=False, use_beam_search=True)

    # To run a sweep with 20 runs instead, call:
    # Experiment().run(use_sweep=True, num_sweep_runs=20)


if __name__ == "__main__":
    main()