<a href="https://colab.research.google.com/github/manasdeshpande125/da6401_assignment_3/blob/main/DLASG3_CombinedAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch

# Pick the accelerator once and report it.
_cuda_available = torch.cuda.is_available()
print("CUDA available:", _cuda_available)
device = torch.device("cuda" if _cuda_available else "cpu")
print("Using device:", device)

# Seed Python's hash/`random`, NumPy and PyTorch (CPU and every CUDA device) for repeatability.
SEED = 42
os.environ["PYTHONHASHSEED"] = str(SEED)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
import urllib.request, tarfile, pathlib, shutil

URL = "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"
TAR = "dakshina.tar"
# Only the Marathi lexicon files are extracted from the (large) archive.
_LEXICON_PREFIX = "dakshina_dataset_v1.0/mr/lexicons/"

if not pathlib.Path(TAR).exists():
    # Download under a temporary name and rename only once the transfer has completed, so an
    # interrupted download cannot leave a truncated archive that the existence check above
    # would accept on the next run.
    _partial = pathlib.Path(TAR + ".part")
    urllib.request.urlretrieve(URL, str(_partial))
    _partial.replace(TAR)
    print("Downloaded.")

with tarfile.open(TAR) as t:
    members = [m for m in t.getmembers() if m.name.startswith(_LEXICON_PREFIX)]
    if hasattr(tarfile, "data_filter"):
        # The archive comes from the network: use the "data" extraction filter (PEP 706) to
        # reject absolute paths, ".." components and links pointing outside the target dir.
        t.extractall(members=members, filter="data")
    else:
        # Older Python without extraction filters.
        t.extractall(members=members)
DATA_ROOT = pathlib.Path("dakshina_dataset_v1.0/mr/lexicons")
print("Files:", os.listdir(DATA_ROOT))

In [None]:
import wandb, yaml, json
# Authenticate without embedding a secret in the notebook: with no `key` argument, wandb.login()
# uses the WANDB_API_KEY environment variable or saved credentials, and otherwise prompts.
# NOTE: the API key previously hard-coded on this line was committed in plain text and should
# be revoked/rotated.
wandb.login()

In [None]:
# Install PyTorch Lightning into the notebook environment (IPython shell escape, not plain Python).
!pip install lightning

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import wandb
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import torch.nn.functional as F

@dataclass
class ModelConfig:
    """Hyperparameters for the PyTorch seq2seq model and its Lightning training run.

    ``batch_size`` and ``epochs`` are read by the training helper (``config.batch_size`` for the
    DataLoaders, ``config.epochs`` for ``max_epochs``), and callers construct the config with
    ``batch_size=...``, so both must be declared fields.
    """
    vocab_size: int
    embedding_dim: int
    hidden_dim: int
    cell_type: str = 'LSTM'  # 'LSTM' builds LSTMs; any other value builds GRUs
    num_encoder_layers: int = 1
    num_decoder_layers: int = 1
    dropout_rate: float = 0.0
    learning_rate: float = 0.001
    batch_size: int = 64
    epochs: int = 10

class CustomDataset(Dataset):
    """Map-style dataset over pre-encoded, padded seq2seq arrays.

    The three arrays are converted once to int64 tensors. ``target_outputs`` carries a trailing
    singleton axis (as produced by ``np.expand_dims(..., -1)`` in the preprocessing) which is
    dropped here. Each item is a dict keyed 'encoder_input', 'decoder_input' and 'target'.
    """

    @staticmethod
    def _to_long_tensor(array: np.ndarray) -> torch.Tensor:
        """Wrap a NumPy array as a torch int64 tensor."""
        return torch.from_numpy(array).long()

    def __init__(self, encoder_inputs: np.ndarray, decoder_inputs: np.ndarray, target_outputs: np.ndarray):
        as_long = self._to_long_tensor
        self.encoder_inputs = as_long(encoder_inputs)
        self.decoder_inputs = as_long(decoder_inputs)
        self.target_outputs = as_long(np.squeeze(target_outputs, axis=-1))

    def __len__(self) -> int:
        return self.encoder_inputs.size(0)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return dict(
            encoder_input=self.encoder_inputs[idx],
            decoder_input=self.decoder_inputs[idx],
            target=self.target_outputs[idx],
        )

class CustomAttention(nn.Module):
    """Additive attention that scores one query vector against every key of a sequence.

    The score of key ``k`` is ``w2 . tanh(W1 [query; k] + b1) + b2``; scores are softmax-normalised
    over the sequence and used to average the keys into a context vector.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        # The scorer sees the concatenation [query; key], so query and keys must both have
        # hidden_dim features (2 * hidden_dim inputs in total).
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, query: torch.Tensor, keys: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from ``query`` over ``keys``.

        Args:
            query: (batch, hidden_dim) or (batch, 1, hidden_dim).
            keys: (batch, seq_len, hidden_dim).
            mask: optional (batch, seq_len) tensor; truthy entries mark positions that may be
                attended to (e.g. non-padding tokens). Masked positions get zero weight.

        Returns:
            context: (batch, 1, hidden_dim) weighted average of ``keys``.
            attention_weights: (batch, seq_len, 1), summing to 1 over ``seq_len``.

        Raises:
            ValueError: if ``query`` is neither 2-D nor of shape (batch, 1, hidden_dim).
        """
        if query.dim() == 2:
            query = query.unsqueeze(1)
        elif query.dim() != 3 or query.size(1) != 1:
            raise ValueError(
                f"query must be (batch, hidden) or (batch, 1, hidden), got {tuple(query.shape)}"
            )

        # Pair the single query with every key position: (batch, seq_len, 2 * hidden_dim).
        pairs = torch.cat([query.expand(-1, keys.size(1), -1), keys], dim=2)
        energy = self.attention(pairs)  # (batch, seq_len, 1)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf so that a fully masked
            # row yields uniform weights instead of NaNs.
            energy = energy.masked_fill(~mask.bool().unsqueeze(-1), torch.finfo(energy.dtype).min)
        attention_weights = F.softmax(energy, dim=1)
        context = torch.bmm(attention_weights.transpose(1, 2), keys)
        return context, attention_weights

class CustomSeq2Seq(nn.Module):
    """Bidirectional-encoder / unidirectional-decoder seq2seq with one attention context.

    The encoder and decoder share a single embedding table (``config.vocab_size`` rows, index 0
    as padding). One context vector per sequence is computed by attending over the encoder
    outputs with the encoder's final hidden state (last layer, both directions) as the query;
    that context is concatenated to every decoder input embedding. Encoder padding positions
    are not masked in the attention.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)

        # 'LSTM' selects LSTMs; any other cell_type selects GRUs.
        rnn_cls = nn.LSTM if config.cell_type == 'LSTM' else nn.GRU
        self.encoder = rnn_cls(
            config.embedding_dim,
            config.hidden_dim,
            num_layers=config.num_encoder_layers,
            dropout=config.dropout_rate if config.num_encoder_layers > 1 else 0,
            batch_first=True,
            bidirectional=True
        )
        # Decoder input = token embedding + encoder context (2 * hidden_dim, bidirectional).
        self.decoder = rnn_cls(
            config.embedding_dim + config.hidden_dim * 2,
            config.hidden_dim,
            num_layers=config.num_decoder_layers,
            dropout=config.dropout_rate if config.num_decoder_layers > 1 else 0,
            batch_first=True
        )

        # Query (final encoder state) and keys (encoder outputs) are both 2 * hidden_dim wide.
        self.attention = CustomAttention(config.hidden_dim * 2)
        # The unidirectional decoder emits hidden_dim features per step.
        self.output_layer = nn.Sequential(
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_dim, config.vocab_size)
        )

    def forward(self, encoder_input: torch.Tensor, decoder_input: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, decoder_len, vocab_size) under teacher forcing.

        Args:
            encoder_input: (batch, encoder_len) token ids.
            decoder_input: (batch, decoder_len) token ids (start token first).
        """
        enc_emb = self.embedding(encoder_input)
        enc_output, enc_hidden = self.encoder(enc_emb)  # enc_output: (batch, enc_len, 2*hidden)

        # nn.LSTM returns (h_n, c_n); nn.GRU returns h_n alone.
        h_n = enc_hidden[0] if isinstance(enc_hidden, tuple) else enc_hidden
        # h_n is (num_layers * 2, batch, hidden); its last two rows are the final layer's
        # forward and backward states.
        query = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (batch, 2 * hidden)
        context, _ = self.attention(query, enc_output)  # (batch, 1, 2 * hidden)

        dec_emb = self.embedding(decoder_input)
        dec_input = torch.cat([dec_emb, context.expand(-1, dec_emb.size(1), -1)], dim=2)
        dec_output, _ = self.decoder(dec_input)  # (batch, dec_len, hidden)
        return self.output_layer(dec_output)

class CustomModel(L.LightningModule):
    """Lightning wrapper that trains ``CustomSeq2Seq`` with teacher-forced cross-entropy.

    Target index 0 is padding: it is ignored by the loss and excluded from the logged
    validation accuracy (which the checkpoint callback monitors).
    """

    _PAD_IDX = 0

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Every config field ends up in self.hparams; learning_rate is read back in
        # configure_optimizers.
        self.save_hyperparameters(config.__dict__)
        self.model = CustomSeq2Seq(config)
        self.criterion = nn.CrossEntropyLoss(ignore_index=self._PAD_IDX)

    def forward(self, encoder_input: torch.Tensor, decoder_input: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, decoder_len, vocab_size)."""
        return self.model(encoder_input, decoder_input)

    def _token_loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Cross-entropy over all target positions, skipping padding via ``ignore_index``."""
        return self.criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        output = self(batch['encoder_input'], batch['decoder_input'])
        loss = self._token_loss(output, batch['target'])
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        target = batch['target']
        output = self(batch['encoder_input'], batch['decoder_input'])
        loss = self._token_loss(output, target)

        # Token accuracy over real tokens only: counting padding positions would inflate the
        # metric (and therefore the checkpoint selection) for short words.
        non_pad = target != self._PAD_IDX
        hits = (output.argmax(dim=-1) == target) & non_pad
        accuracy = hits.sum().float() / non_pad.sum().clamp(min=1)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', accuracy, prog_bar=True)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

def train_custom_model(train_data: Dict[str, np.ndarray], val_data: Dict[str, np.ndarray], config: ModelConfig) -> CustomModel:
    """Fit a ``CustomModel`` with Lightning and return the trained module.

    Logs to the W&B project "DL_A3" and keeps checkpoints in ./checkpoints, ranked by
    ``val_accuracy`` (higher is better). Besides the model hyperparameters it reads
    ``config.batch_size`` and ``config.epochs``. The returned object is the in-process module
    after the final epoch, not a reload of the best checkpoint.
    """
    def _make_loader(split: Dict[str, np.ndarray], shuffle: bool) -> DataLoader:
        dataset = CustomDataset(split['encoder_input'], split['decoder_input'], split['target'])
        return DataLoader(dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=4)

    train_loader = _make_loader(train_data, shuffle=True)
    val_loader = _make_loader(val_data, shuffle=False)

    model = CustomModel(config)

    logger = WandbLogger(project="DL_A3", name=f"custom_{config.cell_type}")
    best_by_accuracy = ModelCheckpoint(
        dirpath='checkpoints',
        filename='custom-model-{epoch:02d}-{val_accuracy:.2f}',
        monitor='val_accuracy',
        mode='max'
    )
    trainer = L.Trainer(
        max_epochs=config.epochs,
        logger=logger,
        callbacks=[best_by_accuracy],
        accelerator='auto',
        devices='auto'
    )

    trainer.fit(model, train_loader, val_loader)
    return model

def decode_custom_sequence(model: "CustomModel", input_seq: np.ndarray, target_tokenizer: object, max_len: int = 30) -> str:
    """Greedily decode one encoded input word into target-script text.

    Args:
        model: trained model called as ``model(encoder_input, decoder_input)`` and returning
            logits of shape (1, decoder_len, vocab). It is switched to eval mode.
        input_seq: encoded input word, either a 1-D row (seq_len,) or a one-row batch
            (1, seq_len), e.g. ``encoder_inputs[i]`` or ``encoder_inputs[i:i+1]``.
        target_tokenizer: Keras-style tokenizer exposing ``word_index`` and ``index_word``;
            the tab character is the start token and the newline character the end token.
        max_len: maximum number of tokens to generate.

    Returns:
        The decoded text. The start and end tokens are not included, and ids missing from
        ``index_word`` map to the empty string.
    """
    start_id = target_tokenizer.word_index['\t']
    end_id = target_tokenizer.word_index['\n']
    # Run on whatever device the model lives on (falls back to CPU for parameter-less models).
    device = next(model.parameters(), torch.empty(0)).device

    model.eval()
    with torch.no_grad():
        encoder_input = torch.from_numpy(np.asarray(input_seq)).long()
        if encoder_input.dim() == 1:
            encoder_input = encoder_input.unsqueeze(0)
        encoder_input = encoder_input.to(device)
        target_seq = torch.tensor([[start_id]], dtype=torch.long, device=device)

        decoded = []
        for _ in range(max_len):
            # The whole prefix is re-fed each step; the last position predicts the next token.
            logits = model(encoder_input, target_seq)
            token = int(logits[0, -1].argmax())
            if token == end_id:
                break
            decoded.append(token)
            next_token = torch.tensor([[token]], dtype=torch.long, device=device)
            target_seq = torch.cat([target_seq, next_token], dim=1)

    return ''.join(target_tokenizer.index_word.get(idx, '') for idx in decoded)

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Embedding, Dense, LSTM, GRU, Bidirectional, Layer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.optimizers import Adam
import torch
import wandb
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F

# Enable eager execution: tf.function-wrapped code (including Keras train/eval steps) runs
# op-by-op instead of as a compiled graph.
# NOTE(review): useful for debugging, but typically much slower for real training runs —
# consider removing once the Keras model is stable.
tf.config.run_functions_eagerly(True)

# from pytorch_implementation import (
#     CustomDataset,
#     CustomModel,
#     ModelConfig,
#     train_custom_model,
#     decode_custom_sequence
# )

@dataclass
class TFModelConfig:
    """Hyperparameters for the Keras encoder-decoder produced by ``build_tf_model``.

    ``vocab_size`` sizes the embeddings and the softmax output; ``batch_size`` is used when
    fitting the model.
    """
    # Required sizes.
    vocab_size: int
    embedding_dim: int
    hidden_dim: int
    # Architecture: 'LSTM' selects LSTM layers; any other value selects GRU layers.
    cell_type: str = 'LSTM'
    num_encoder_layers: int = 1
    num_decoder_layers: int = 1
    # Regularisation / optimisation.
    dropout_rate: float = 0.0
    learning_rate: float = 0.001
    batch_size: int = 64

class BahdanauAttention(Layer):
    """Additive (Bahdanau) attention in which every decoder step attends over all encoder steps.

    Inputs:
        query:  (batch, dec_len, hidden) decoder outputs.
        values: (batch, enc_len, hidden) encoder outputs.
    Returns:
        context vectors (batch, dec_len, hidden) and attention weights (batch, dec_len, enc_len),
        normalised over enc_len.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = Dense(units)  # projects the encoder values
        self.W2 = Dense(units)  # projects the decoder queries
        self.V = Dense(1)       # reduces each tanh feature vector to a scalar score

    def call(self, query, values):
        # Broadcast a dec_len x enc_len grid of (query, value) pairs.
        keys = tf.expand_dims(values, 1)    # (batch, 1, enc_len, hidden)
        queries = tf.expand_dims(query, 2)  # (batch, dec_len, 1, hidden)

        projected_keys = self.W1(keys)
        projected_queries = self.W2(queries)
        scores = self.V(tf.nn.tanh(projected_keys + projected_queries))  # (batch, dec_len, enc_len, 1)

        weights = tf.nn.softmax(scores, axis=2)
        context = tf.reduce_sum(weights * keys, axis=2)  # (batch, dec_len, hidden)
        return context, tf.squeeze(weights, -1)

def _tf_recurrent_layer(config: TFModelConfig):
    """Return a fresh Keras recurrent layer that outputs its full sequence plus final state(s).

    'LSTM' selects ``LSTM``; any other ``cell_type`` selects ``GRU``.
    """
    rnn_cls = LSTM if config.cell_type == 'LSTM' else GRU
    return rnn_cls(
        config.hidden_dim,
        return_sequences=True,
        return_state=True,
        dropout=config.dropout_rate
    )


def build_tf_model(config: TFModelConfig) -> Model:
    """Build the Keras encoder-decoder with Bahdanau attention.

    * Encoder: masked Embedding followed by ``num_encoder_layers`` stacked LSTM/GRU layers
      (at least one).
    * Decoder: masked Embedding followed by ``num_decoder_layers`` stacked layers (at least
      one). The first decoder layer is initialised with the final state(s) of the last encoder
      layer.
    * Attention: each decoder output attends over the last encoder layer's outputs. The
      context is concatenated with the decoder output and fed to a softmax over
      ``vocab_size``.

    The symbolic attention-weights tensor is attached as ``model.attention_weights``.
    """
    # Encoder
    encoder_inputs = Input(shape=(None,))
    encoder_outputs = Embedding(config.vocab_size, config.embedding_dim, mask_zero=True)(encoder_inputs)
    encoder_states = None
    for _ in range(max(1, config.num_encoder_layers)):
        # LSTM returns [outputs, h, c]; GRU returns [outputs, h].
        encoder_outputs, *encoder_states = _tf_recurrent_layer(config)(encoder_outputs)

    # Decoder
    decoder_inputs = Input(shape=(None,))
    decoder_outputs = Embedding(config.vocab_size, config.embedding_dim, mask_zero=True)(decoder_inputs)
    for layer_idx in range(max(1, config.num_decoder_layers)):
        rnn = _tf_recurrent_layer(config)
        if layer_idx == 0:
            decoder_outputs, *_ = rnn(decoder_outputs, initial_state=encoder_states)
        else:
            decoder_outputs, *_ = rnn(decoder_outputs)

    # Attention
    attention_layer = BahdanauAttention(config.hidden_dim)
    context_vector, attention_weights = attention_layer(decoder_outputs, encoder_outputs)

    # Concatenate context vector and decoder output
    concat = tf.keras.layers.Concatenate()([decoder_outputs, context_vector])

    # Output layer
    outputs = Dense(config.vocab_size, activation='softmax')(concat)

    model = Model([encoder_inputs, decoder_inputs], outputs)

    # Store attention weights for visualization
    model.attention_weights = attention_weights

    return model

def load_and_preprocess_data(data_dir: str) -> Tuple[Dict[str, Dict[str, np.ndarray]], Tokenizer, Tokenizer, int, int]:
    """Load the Dakshina Marathi lexicons and encode them as padded index arrays.

    Each TSV row is read as ``native<TAB>latin...``. The Latin column is the encoder input and
    the native (Marathi) column the target. Decoder inputs are prefixed with a tab (start
    token) and targets are suffixed with a newline (end token).

    Both tokenizers are character level. The input tokenizer is fitted on train + dev Latin
    text; the target tokenizer is fitted on the training Marathi text only (presumably
    characters unseen in training are dropped by ``texts_to_sequences`` — confirm).

    Sequences are padded at the end with 0. The training split is padded to its own longest
    word. Dev/test are padded to at least that width and widened further when they contain
    longer words, so ``pad_sequences`` never truncates them (truncation would drop characters,
    including the decoder start token).

    Returns:
        ``(data, input_tokenizer, target_tokenizer, vocab_size_input, vocab_size_target)``,
        where ``data[split]`` for split in 'train'/'val'/'test' holds 'encoder_input',
        'decoder_input' (2-D) and 'target' (3-D, trailing axis of size 1).
    """
    def load_tsv(path: str) -> List[List[str]]:
        with open(path, encoding='utf-8') as f:
            lines = f.read().strip().split('\n')
        return [line.split('\t') for line in lines if '\t' in line]

    def tokenize_pairs(pairs: List[List[str]]) -> Tuple[List[str], List[str], List[str]]:
        latin_texts = [x[1] for x in pairs]
        marathi_texts = [x[0] for x in pairs]
        marathi_texts_in = ['\t' + t for t in marathi_texts]
        marathi_texts_out = [t + '\n' for t in marathi_texts]
        return latin_texts, marathi_texts_in, marathi_texts_out

    # Split name -> file stem of the sampled lexicon TSVs.
    split_files = {'train': "train", 'val': "dev", 'test': "test"}
    texts = {
        split: tokenize_pairs(load_tsv(os.path.join(data_dir, f"mr.translit.sampled.{stem}.tsv")))
        for split, stem in split_files.items()
    }
    train_lat, train_mr_in, train_mr_out = texts['train']
    val_lat = texts['val'][0]

    # Create tokenizers
    input_tokenizer = Tokenizer(char_level=True, lower=False)
    input_tokenizer.fit_on_texts(train_lat + val_lat)

    target_tokenizer = Tokenizer(char_level=True, lower=False)
    target_tokenizer.fit_on_texts(train_mr_in + train_mr_out)

    # +1 because index 0 is reserved for padding.
    vocab_size_input = len(input_tokenizer.word_index) + 1
    vocab_size_target = len(target_tokenizer.word_index) + 1

    # Training widths (characters; an upper bound on the encoded length).
    maxlen_input = max(map(len, train_lat))
    maxlen_target = max(map(len, train_mr_out))

    def encode_and_pad(texts: List[str], tokenizer: Tokenizer, maxlen: Optional[int] = None) -> np.ndarray:
        return pad_sequences(tokenizer.texts_to_sequences(texts), padding='post', maxlen=maxlen)

    def build_split(lat: List[str], mr_in: List[str], mr_out: List[str]) -> Dict[str, np.ndarray]:
        # Never narrower than the training arrays, and wide enough for this split's longest
        # word. Decoder input and target have equal lengths, so one width serves both.
        enc_len = max(maxlen_input, max(map(len, lat), default=0))
        dec_len = max(maxlen_target, max(map(len, mr_out), default=0))
        return {
            'encoder_input': encode_and_pad(lat, input_tokenizer, enc_len),
            'decoder_input': encode_and_pad(mr_in, target_tokenizer, dec_len),
            'target': np.expand_dims(encode_and_pad(mr_out, target_tokenizer, dec_len), -1)
        }

    data = {split: build_split(*split_texts) for split, split_texts in texts.items()}

    return data, input_tokenizer, target_tokenizer, vocab_size_input, vocab_size_target

def train_models(data: Dict[str, Dict[str, np.ndarray]], vocab_size_input: int, vocab_size_target: int) -> Tuple[Model, CustomModel]:
    """Train the Keras and the PyTorch attention models with one shared hyperparameter set.

    Opens a W&B run ("combined_implementation") that receives the Keras per-epoch logs; the
    Keras model is trained for 10 epochs. ``vocab_size_input`` is accepted but unused: both
    models size every embedding and their output layer with ``vocab_size_target``.

    Returns:
        ``(tf_model, pt_model)``.
    """
    shared_hparams = dict(
        vocab_size=vocab_size_target,
        embedding_dim=256,
        hidden_dim=256,
        cell_type='LSTM',
        num_encoder_layers=2,
        num_decoder_layers=2,
        dropout_rate=0.2,
        learning_rate=0.001,
        batch_size=64,
    )
    tf_config = TFModelConfig(**shared_hparams)
    pt_config = ModelConfig(**shared_hparams)

    wandb.init(project="DA6401_Assignment_3", name="combined_implementation")

    tf_model = build_tf_model(tf_config)
    tf_model.compile(
        optimizer=Adam(learning_rate=tf_config.learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    class CustomWandbCallback(tf.keras.callbacks.Callback):
        """Forward each epoch's Keras metrics to the active W&B run."""
        def on_epoch_end(self, epoch, logs=None):
            if logs:
                wandb.log(logs)

    train_split, val_split = data['train'], data['val']
    tf_model.fit(
        [train_split['encoder_input'], train_split['decoder_input']],
        train_split['target'],
        validation_data=([val_split['encoder_input'], val_split['decoder_input']], val_split['target']),
        batch_size=tf_config.batch_size,
        epochs=10,
        callbacks=[CustomWandbCallback()]
    )

    pt_model = train_custom_model(train_split, val_split, pt_config)
    return tf_model, pt_model

def evaluate_models(tf_model: "Model", pt_model: "CustomModel", test_data: Dict[str, np.ndarray], target_tokenizer: "Tokenizer", batch_size: int = 64) -> Tuple[float, float]:
    """Report token-level test accuracy for the Keras model and the PyTorch model.

    Args:
        tf_model: compiled Keras model whose ``evaluate`` returns ``[loss, accuracy]``.
        pt_model: PyTorch model returning logits of shape (batch, decoder_len, vocab).
        test_data: dict with 'encoder_input' and 'decoder_input' of shape (N, len), and
            'target' of shape (N, len, 1).
        target_tokenizer: unused; kept for interface compatibility.
        batch_size: batch size for the PyTorch evaluation pass.

    Returns:
        ``(tf_acc, pt_acc)``. The PyTorch accuracy counts only non-padding target positions
        (index 0 is padding, the same index the training loss ignores); it is 0.0 when there
        are no such positions.
    """
    _tf_loss, tf_acc = tf_model.evaluate(
        [test_data['encoder_input'], test_data['decoder_input']],
        test_data['target']
    )
    print(f"TensorFlow Test Accuracy: {tf_acc:.4f}")

    encoder_input = torch.from_numpy(test_data['encoder_input']).long()
    decoder_input = torch.from_numpy(test_data['decoder_input']).long()
    target = torch.from_numpy(test_data['target'].squeeze(-1)).long()
    # Evaluate on whatever device the model lives on (CPU if it has no parameters).
    device = next(pt_model.parameters(), torch.empty(0)).device

    correct = 0
    total = 0
    pt_model.eval()
    with torch.no_grad():
        for start in range(0, encoder_input.size(0), batch_size):
            stop = start + batch_size
            logits = pt_model(encoder_input[start:stop].to(device), decoder_input[start:stop].to(device))
            batch_target = target[start:stop].to(device)
            non_pad = batch_target != 0
            correct += ((logits.argmax(dim=-1) == batch_target) & non_pad).sum().item()
            total += non_pad.sum().item()

    pt_acc = correct / total if total else 0.0
    print(f"PyTorch Test Accuracy: {pt_acc:.4f}")

    return tf_acc, pt_acc

def plot_attention_heatmap(input_text: str, output_text: str, attention_weights: np.ndarray, idx: int = 1):
    """Display one attention matrix as a heatmap.

    Columns are labelled with the characters of ``input_text`` and rows with those of
    ``output_text``; ``attention_weights`` is presumably (len(output_text), len(input_text))
    — confirm against the caller.
    """
    heatmap_style = dict(cmap='magma', cbar=False, linewidths=0.5, annot=False)

    plt.figure(figsize=(6, 5))
    sns.heatmap(
        attention_weights,
        xticklabels=list(input_text),
        yticklabels=list(output_text),
        **heatmap_style
    )
    plt.xlabel("Input (Latin)")
    plt.ylabel("Output (Hindi)")
    plt.title(f"Sample {idx} Attention")
    plt.tight_layout()
    plt.show()

def plot_9_grid(test_lat: List[str], test_encoder_input: np.ndarray, model: CustomModel, target_tokenizer: Tokenizer):
    """Plot attention heatmaps for the first 9 test words in a 3x3 grid.

    Args:
        test_lat: Latin input words; presumably ``test_lat[i]`` matches row i of
            ``test_encoder_input`` — confirm.
        test_encoder_input: encoded, padded encoder inputs; one-row slices are decoded.
        model: trained PyTorch model passed to ``decode_custom_sequence``.
        target_tokenizer: tokenizer used to map predicted ids back to characters.

    NOTE(review): the loop unpacks ``decode_custom_sequence(...)`` into
    ``(output_text, attn_weights)``, but ``decode_custom_sequence`` in this file is annotated
    ``-> str`` and returns a joined string. The unpacking therefore fails (or silently splits a
    two-character word into two chars). A decoder variant that also returns per-step attention
    weights is needed before this plot can work.
    NOTE(review): the y-axis says "Hindi", but the data loaded in this file is Marathi (mr).
    """
    plt.figure(figsize=(18, 15))
    for i in range(9):
        input_text = test_lat[i]
        input_seq = test_encoder_input[i:i+1]  # one-row 2-D slice
        output_text, attn_weights = decode_custom_sequence(model, input_seq, target_tokenizer)
        # Presumably one attention row per decoded character — confirm once the decoder exposes weights.
        attn_matrix = np.stack(attn_weights)

        plt.subplot(3, 3, i+1)
        sns.heatmap(attn_matrix, xticklabels=list(input_text), yticklabels=list(output_text), cmap='coolwarm', cbar=False)
        plt.title(f"Input: {input_text}")
        plt.xlabel("Latin chars")
        plt.ylabel("Hindi chars")
    plt.tight_layout()
    plt.show()

def save_attention_heatmaps(test_lat: List[str], test_encoder_input: np.ndarray, model: CustomModel, target_tokenizer: Tokenizer, num_samples: int = 10):
    """Save one attention heatmap per test word to attention_heatmaps/sample_<k>.png.

    Args:
        test_lat: Latin input words; presumably ``test_lat[i]`` matches row i of
            ``test_encoder_input`` — confirm.
        test_encoder_input: encoded, padded encoder inputs; one-row slices are decoded.
        model: trained PyTorch model passed to ``decode_custom_sequence``.
        target_tokenizer: tokenizer used to map predicted ids back to characters.
        num_samples: number of leading test words to plot (files are numbered from 1).

    NOTE(review): like ``plot_9_grid``, this unpacks ``decode_custom_sequence(...)`` into
    ``(output_text, attn_weights)``, while ``decode_custom_sequence`` in this file returns a
    single ``str``, so the unpacking fails. It needs a decoder that also returns attention
    weights.
    NOTE(review): the y-axis says "Hindi", but the data loaded in this file is Marathi (mr).
    """
    os.makedirs("attention_heatmaps", exist_ok=True)
    for i in range(num_samples):
        input_text = test_lat[i]
        input_seq = test_encoder_input[i:i+1]  # one-row 2-D slice
        output_text, attn_weights = decode_custom_sequence(model, input_seq, target_tokenizer)
        attn_matrix = np.stack(attn_weights)

        plt.figure(figsize=(6, 5))
        sns.heatmap(attn_matrix, xticklabels=list(input_text), yticklabels=list(output_text), cmap='plasma')
        plt.title(f"Sample {i+1}")
        plt.xlabel("Input (Latin)")
        plt.ylabel("Output (Hindi)")
        plt.tight_layout()
        plt.savefig(f"attention_heatmaps/sample_{i+1}.png")
        # Close each figure so many samples do not accumulate open figures.
        plt.close()

def create_wandb_table(test_lat: List[str], decoded_preds: List[str], decoded_refs: List[str], num_samples: int = 10):
    """Log a W&B table comparing predicted and reference words, colouring the predictions.

    Opens a run in project "seq2seq_sweep" with ``wandb.init`` and closes it with
    ``wandb.finish`` after logging.

    Args:
        test_lat: input (Latin) words.
        decoded_preds: predicted words, aligned with ``test_lat``.
        decoded_refs: reference words, aligned with ``test_lat``.
        num_samples: maximum number of rows. It is capped at the shortest input list, so
            fewer predictions than ``num_samples`` do not raise IndexError.

    A prediction is shown green when it equals the reference after stripping surrounding
    whitespace (e.g. a trailing end-of-sequence newline), red otherwise.
    """
    import html

    n_rows = min(num_samples, len(test_lat), len(decoded_preds), len(decoded_refs))

    wandb.init(project="seq2seq_sweep", name="prediction_samples_colored_table")

    table = wandb.Table(columns=["Input Word", "Predicted Word", "Target Word"])

    for input_word, pred_word, target_word in zip(test_lat[:n_rows], decoded_preds[:n_rows], decoded_refs[:n_rows]):
        is_match = pred_word.strip() == target_word.strip()
        color = "#00FF00" if is_match else "#FF0000"
        # Wrap in wandb.Html so the table renders the colour instead of showing raw markup;
        # the word itself is HTML-escaped.
        colored_pred = wandb.Html(f'<span style="color: {color}">{html.escape(pred_word)}</span>')

        table.add_data(input_word, colored_pred, target_word)

    wandb.log({"Prediction Samples Colored Table": table})
    wandb.finish()

def sweep_train_tf(data: Dict[str, Dict[str, np.ndarray]]):
    """Run one W&B sweep trial for the Keras model.

    Hyperparameters come from ``wandb.config`` of the run opened here. The run is renamed after
    the main ones, and each epoch's Keras metrics are forwarded to W&B.
    """
    config_fields = (
        'vocab_size', 'embedding_dim', 'hidden_dim', 'cell_type',
        'num_encoder_layers', 'num_decoder_layers', 'dropout_rate',
        'learning_rate', 'batch_size',
    )
    with wandb.init(project="DA6401_Assignment_3", entity="cs24m024-iit-madras"):
        cfg = wandb.config
        wandb.run.name = f"tf_{cfg.cell_type}_emb{cfg.embedding_dim}_hid{cfg.hidden_dim}_enc{cfg.num_encoder_layers}_dec{cfg.num_decoder_layers}_drop{int(cfg.dropout_rate*100)}"

        tf_config = TFModelConfig(**{name: getattr(cfg, name) for name in config_fields})

        model = build_tf_model(tf_config)
        model.compile(
            optimizer=Adam(learning_rate=cfg.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        class CustomWandbCallback(tf.keras.callbacks.Callback):
            """Forward each epoch's Keras metrics to the active W&B run."""
            def on_epoch_end(self, epoch, logs=None):
                if logs:
                    wandb.log(logs)

        train_split, val_split = data['train'], data['val']
        model.fit(
            [train_split['encoder_input'], train_split['decoder_input']],
            train_split['target'],
            validation_data=([val_split['encoder_input'], val_split['decoder_input']], val_split['target']),
            batch_size=cfg.batch_size,
            epochs=cfg.epochs,
            callbacks=[CustomWandbCallback()]
        )

# Dataset shared with the W&B sweep functions, which wandb.agent calls without arguments.
# Assigned by run_sweeps(); a module-level `global` statement is a no-op, so declare the name
# explicitly instead.
data: Optional[Dict[str, Dict[str, np.ndarray]]] = None

def sweep_train_pt(config=None):
    """Run one W&B sweep trial for the PyTorch model.

    Hyperparameters (including the sweep's ``epochs`` budget) come from ``wandb.config``. The
    dataset is read from the module-level ``data``, which ``run_sweeps`` assigns before
    starting the agent.

    Returns:
        The trained ``CustomModel``.
    """
    global data
    with wandb.init(config=config, project="DA6401_Assignment_3", entity="cs24m024-iit-madras"):
        config = wandb.config
        wandb.run.name = f"pt_{config.cell_type}_emb{config.embedding_dim}_hid{config.hidden_dim}_enc{config.num_encoder_layers}_dec{config.num_decoder_layers}_drop{int(config.dropout_rate*100)}"

        pt_config = ModelConfig(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            cell_type=config.cell_type,
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            dropout_rate=config.dropout_rate,
            learning_rate=config.learning_rate,
            batch_size=config.batch_size,
            # Forward the sweep's epoch budget; otherwise every trial trains for the config default.
            epochs=config.epochs
        )

        model = train_custom_model(data['train'], data['val'], pt_config)
        return model

def run_sweeps(data_arg: Dict[str, Dict[str, np.ndarray]], vocab_size: int):
    """Run one Bayesian W&B sweep trial for the Keras model, then one for the PyTorch model.

    ``data_arg`` is published as the module-level ``data`` so the sweep functions, which
    ``wandb.agent`` calls without arguments, can reach it. Both sweeps maximise
    ``val_accuracy`` over the same search space and differ only in the batch sizes tried.
    NOTE(review): 'RNN' in the cell_type options builds GRU layers in both model builders
    (any value other than 'LSTM' does).
    """
    global data
    data = data_arg

    def _sweep_config(batch_sizes):
        return {
            'method': 'bayes',
            'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
            'parameters': {
                'vocab_size': {'value': vocab_size},
                'embedding_dim': {'values': [16, 32, 64, 256]},
                'hidden_dim': {'values': [16, 32, 64, 256]},
                'cell_type': {'values': ['RNN', 'LSTM']},
                'num_encoder_layers': {'values': [1, 2]},
                'num_decoder_layers': {'values': [1, 2]},
                'dropout_rate': {'values': [0.2, 0.3]},
                'batch_size': {'values': batch_sizes},
                'epochs': {'value': 1},
                'learning_rate': {'value': 0.001}
            }
        }

    tf_sweep_config = _sweep_config([32, 64])
    pt_sweep_config = _sweep_config([64])
    sweep_target = {'project': "DA6401_Assignment_3", 'entity': "cs24m024-iit-madras"}

    def sweep_train_tf_wrapper():
        sweep_train_tf(data)

    tf_sweep_id = wandb.sweep(tf_sweep_config, **sweep_target)
    wandb.agent(tf_sweep_id, function=sweep_train_tf_wrapper, count=1)

    pt_sweep_id = wandb.sweep(pt_sweep_config, **sweep_target)
    wandb.agent(pt_sweep_id, function=sweep_train_pt, count=1)

def _read_test_words(data_dir: str) -> Tuple[List[str], List[str]]:
    """Return ``(latin_words, marathi_words)`` from the test lexicon, in file order.

    Rows are parsed the same way ``load_and_preprocess_data`` parses its TSVs (rows without a
    tab are skipped), so index i lines up with row i of the encoded test arrays.
    """
    path = os.path.join(data_dir, "mr.translit.sampled.test.tsv")
    with open(path, encoding='utf-8') as f:
        rows = [line.split('\t') for line in f.read().strip().split('\n') if '\t' in line]
    return [row[1] for row in rows], [row[0] for row in rows]


def main(data_dir: str = "/kaggle/working/dakshina_dataset_v1.0/mr/lexicons", num_table_samples: int = 10):
    """Run the pipeline: preprocess, sweep, train, evaluate, save, then report and visualise.

    Args:
        data_dir: folder containing mr.translit.sampled.{train,dev,test}.tsv.
        num_table_samples: number of test words greedily decoded for the W&B prediction table.
    """
    data, input_tokenizer, target_tokenizer, vocab_size_input, vocab_size_target = load_and_preprocess_data(data_dir)
    # Raw test words for the table and plots (not returned by the preprocessing step).
    test_lat, test_mr = _read_test_words(data_dir)

    run_sweeps(data, vocab_size_target)

    tf_model, pt_model = train_models(data, vocab_size_input, vocab_size_target)

    tf_acc, pt_acc = evaluate_models(tf_model, pt_model, data['test'], target_tokenizer)

    # Persist the models before any reporting/plotting step that might fail.
    tf_model.save("best_tf_model.keras")
    torch.save(pt_model.state_dict(), "best_pt_model.pt")

    # Log now, while the run opened in train_models is presumably still active:
    # create_wandb_table() below ends with wandb.finish(), after which there is no run to log to.
    wandb.log({
        "tf_test_accuracy": tf_acc,
        "pt_test_accuracy": pt_acc
    })

    # Only the rows shown in the table are decoded; each 1-D encoded row is decoded separately.
    n_samples = min(num_table_samples, len(test_lat))
    decoded_preds = [
        decode_custom_sequence(pt_model, data['test']['encoder_input'][i], target_tokenizer).rstrip('\n')
        for i in range(n_samples)
    ]
    decoded_refs = test_mr[:n_samples]
    create_wandb_table(test_lat, decoded_preds, decoded_refs, num_samples=n_samples)

    plot_9_grid(test_lat, data['test']['encoder_input'], pt_model, target_tokenizer)
    save_attention_heatmaps(test_lat, data['test']['encoder_input'], pt_model, target_tokenizer)

# Script entry point. NOTE(review): inside a notebook kernel __name__ is presumably "__main__"
# as well, so executing this cell runs the whole pipeline.
if __name__ == "__main__":
    main()