<a href="https://colab.research.google.com/github/manasdeshpande125/da6401_assignment_3/blob/main/DLASG3__Word.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch, os, random, numpy as np, pandas as pd
# Report accelerator availability and choose the device used by later cells.
cuda_ready = torch.cuda.is_available()
print("CUDA available:", cuda_ready)
device = torch.device("cuda") if cuda_ready else torch.device("cpu")
print("Using device:", device)

# Seed every random number generator the notebook touches with one value.
SEED = 42
os.environ["PYTHONHASHSEED"] = str(SEED)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


**Download dataset and unzip it**

In [None]:
import urllib.request, tarfile, pathlib, shutil

URL = "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"
TAR = "dakshina.tar"
if not pathlib.Path(TAR).exists():
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated archive that the
    # existence check above would treat as complete on the next run.
    partial_tar = pathlib.Path(TAR + ".part")
    urllib.request.urlretrieve(URL, str(partial_tar))
    partial_tar.replace(TAR)
    print("Downloaded.")

# Only the Marathi lexicon files are extracted from the archive.
with tarfile.open(TAR) as t:
    members = [m for m in t.getmembers() if m.name.startswith("dakshina_dataset_v1.0/mr/lexicons/")]
    if hasattr(tarfile, "data_filter"):
        # The "data" extraction filter rejects members that would escape the
        # destination directory (absolute paths, "..", unsafe links).
        t.extractall(members=members, filter="data")
    else:
        # Older Python without extraction filters: keep the original call.
        t.extractall(members=members)
DATA_ROOT = pathlib.Path("dakshina_dataset_v1.0/mr/lexicons")
print("Files:", os.listdir(DATA_ROOT))

**Classes for loading and preparing the dataset for word level**

In [None]:
class Vocabulary:
    """Word-level source/target vocabularies built from a tab-separated lexicon.

    The file is read without a header as three string columns named
    ``src_lang``, ``trg_lang`` and ``"count"``. Every distinct word in a column
    gets an index starting at 3 (in sorted order); indices 0, 1 and 2 are
    reserved for ``<s>``, ``<pad>`` and ``<unk>``.

    Args:
        file_path: Path of the TSV lexicon file.
        src_lang: Name given to the first column (source words).
        trg_lang: Name given to the second column (target words).
    """

    def __init__(self, file_path, src_lang, trg_lang):
        # Only truly empty fields count as missing. With pandas' default NA
        # parsing, legitimate words such as "nan", "null" or "NA" would be
        # turned into NaN and silently removed by dropna().
        df = pd.read_csv(file_path, sep="\t", header=None,
                         names=[src_lang, trg_lang, "count"], dtype=str,
                         keep_default_na=False, na_values=[""]).dropna()
        self.df = df
        self.src_lang, self.trg_lang = src_lang, trg_lang

        # Create word-level vocabularies
        self.src_vocab = {word: i+3 for i, word in enumerate(sorted(set(df[src_lang])))}
        self.trg_vocab = {word: i+3 for i, word in enumerate(sorted(set(df[trg_lang])))}

        # Add special tokens
        for v in (self.src_vocab, self.trg_vocab):
            v["<pad>"] = 1
            v["<unk>"] = 2
            v["<s>"] = 0

        # word2idx aliases the vocab dicts; idx2word is the inverse mapping.
        self.s_word2idx, self.s_idx2word = self.src_vocab, {i: w for w, i in self.src_vocab.items()}
        self.t_word2idx, self.t_idx2word = self.trg_vocab, {i: w for w, i in self.trg_vocab.items()}

    def get(self):
        """Return ``(src_vocab, trg_vocab, t_word2idx, t_idx2word, s_word2idx, s_idx2word)``."""
        return (self.src_vocab, self.trg_vocab,
                self.t_word2idx, self.t_idx2word,
                self.s_word2idx, self.s_idx2word)

In [None]:
from torch.utils.data import Dataset

class TransliterationDataset(Dataset):
    """Dataset of (source word, target word) pairs encoded as single token ids.

    The file is read without a header as three string columns named
    ``src_lang``, ``trg_lang`` and ``"count"``; rows with an empty field are
    dropped.

    Args:
        file_path: Path of the TSV lexicon file.
        src_lang: Name given to the first column.
        trg_lang: Name given to the second column.
        src_vocab: Source word -> index mapping; must contain ``"<unk>"``.
        trg_vocab: Target word -> index mapping; must contain ``"<unk>"``.
        t_word2idx: Target word -> index mapping, stored on the instance.
    """

    def __init__(self, file_path, src_lang, trg_lang,
                 src_vocab, trg_vocab, t_word2idx):
        # Only truly empty fields count as missing. With pandas' default NA
        # parsing, legitimate words such as "nan", "null" or "NA" would be
        # turned into NaN and silently removed by dropna().
        df = pd.read_csv(file_path, sep="\t", header=None,
                         names=[src_lang, trg_lang, "count"], dtype=str,
                         keep_default_na=False, na_values=[""]).dropna()
        self.df = df.reset_index(drop=True)
        self.src_vocab, self.trg_vocab = src_vocab, trg_vocab
        self.t_word2idx = t_word2idx
        self.src_lang, self.trg_lang = src_lang, trg_lang

    def __len__(self):
        """Return the number of retained rows."""
        return len(self.df)

    def _encode(self, word, vocab):
        """Return ``word``'s index as a scalar long tensor (``<unk>`` if absent)."""
        return torch.tensor(vocab.get(word, vocab["<unk>"]), dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(src_id, trg_id, src_word, trg_word)`` for row ``idx``."""
        row = self.df.iloc[idx]
        src = self._encode(row[self.src_lang], self.src_vocab)
        trg = self._encode(row[self.trg_lang], self.trg_vocab)
        return src, trg, row[self.src_lang], row[self.trg_lang]

**Encoder, Decoder and Seq2Seq Classes for Transliteration**

In [None]:
import torch.nn as nn, torch.nn.functional as F
import pytorch_lightning as pl

class Encoder(nn.Module):
    """Embeds one source token per example and runs it through an RNN.

    Args:
        input_dim: Size of the source vocabulary (embedding rows).
        embed_dim: Embedding width.
        hid_dim: Hidden size of the recurrent cell.
        n_layers: Number of stacked recurrent layers.
        cell_type: ``"rnn"``, ``"gru"`` or ``"lstm"`` (case-insensitive).
        bidirectional: Whether the recurrent layer is bidirectional.
        dropout: Dropout on the embedding; also used between recurrent
            layers when ``n_layers > 1``.
    """

    def __init__(self, input_dim: int, embed_dim: int,
                 hid_dim: int, n_layers: int,
                 cell_type: str = "lstm",
                 bidirectional: bool = False,
                 dropout: float = 0.2):
        super().__init__()
        kind = cell_type.lower()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        self.dir = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_dim, embed_dim)

        recurrent_types = {"rnn": nn.RNN, "gru": nn.GRU, "lstm": nn.LSTM}
        # Inter-layer dropout only applies to stacks of two or more layers.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.rnn = recurrent_types[kind](
            embed_dim,
            hid_dim,
            n_layers,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        self.cell_type = kind

    def forward(self, src):
        """Encode ``src`` ([batch] token ids) as a length-1 sequence.

        Returns whatever the recurrent layer returns: ``(output, hidden)``,
        where ``hidden`` is an ``(h, c)`` tuple for LSTM.
        """
        token_vectors = self.embedding(src)
        # Prepend a sequence axis of length 1: [batch, emb] -> [1, batch, emb].
        step_input = self.dropout(token_vectors).unsqueeze(0)
        return self.rnn(step_input)

In [None]:
class Decoder(nn.Module):
    """Single-step decoder: embed a token, run the RNN, project to the vocabulary.

    Args:
        trg_vocab: Target word -> index mapping, kept so the loss can look up
            special-token indices.
        output_dim: Size of the target vocabulary.
        embed_dim: Embedding width.
        hid_dim: Hidden size of the recurrent cell.
        n_layers: Number of stacked recurrent layers.
        cell_type: ``"rnn"``, ``"gru"`` or ``"lstm"`` (case-insensitive).
        bidirectional: Stored on the instance; the recurrent layer itself is
            always unidirectional.
        dropout: Dropout on the embedding; also used between recurrent
            layers when ``n_layers > 1``.
    """

    def __init__(self, trg_vocab, output_dim, embed_dim, hid_dim, n_layers,
                cell_type: str = "lstm", bidirectional: bool = False, dropout: float = 0.2):
        super().__init__()
        kind = cell_type.lower()
        self.trg_vocab = trg_vocab  # kept for the loss (pad index lookup)
        self.output_dim = output_dim

        self.embedding = nn.Embedding(output_dim, embed_dim)
        recurrent_types = {"rnn": nn.RNN, "gru": nn.GRU, "lstm": nn.LSTM}
        # Inter-layer dropout only applies to stacks of two or more layers.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.rnn = recurrent_types[kind](
            embed_dim, hid_dim, n_layers, dropout=inter_layer_dropout
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.cell_type = kind
        self.bidirectional = bidirectional

    def forward(self, inp, hidden):
        """Run one decoding step.

        Args:
            inp: [batch] tensor of input token ids.
            hidden: Initial recurrent state (an ``(h, c)`` tuple for LSTM).

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` is [batch, output_dim].
        """
        step_input = self.dropout(self.embedding(inp)).unsqueeze(0)
        # Every cell type returns (output, state); for LSTM the state is (h, c).
        rnn_result = self.rnn(step_input, hidden)
        rnn_out = rnn_result[0]
        hidden = rnn_result[1]
        scores = self.fc_out(rnn_out.squeeze(0))
        return F.log_softmax(scores, dim=1), hidden

In [None]:
class Seq2SeqLightning(pl.LightningModule):
    """Word-level encoder/decoder wrapped as a LightningModule.

    Every source word and every target word is a single token id, so a
    prediction is one decoder step conditioned on the encoder's final state.

    Args:
        encoder: Module returning ``(output, hidden)`` for a [batch] tensor of
            source ids (``hidden`` is an ``(h, c)`` tuple for LSTM).
        decoder: Module mapping ``(input ids, hidden)`` to
            ``(scores, hidden)``; must expose ``trg_vocab``, ``output_dim``
            and ``rnn.num_layers``.
        t_word2idx: Target word -> index mapping.
        t_idx2word: Target index -> word mapping, used to decode predictions.
        cell_type: ``"rnn"``, ``"gru"`` or ``"lstm"`` (case-insensitive).
        bidirectional: Whether the encoder is bidirectional, i.e. whether its
            state must be merged before it is handed to the decoder.
        device: Accepted for backward compatibility and not used.
        learning_rate: Optimizer learning rate.
        optim_name: ``"adam"`` or ``"nadam"`` (case-insensitive).
        tf_ratio: Teacher-forcing ratio passed to ``forward`` during training;
            it has no effect because decoding is a single step.
    """

    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        t_word2idx: dict,
        t_idx2word: dict,
        cell_type: str = "lstm",
        bidirectional: bool = False,
        device: str = "cpu",
        learning_rate: float = 1e-3,
        optim_name: str = "adam",
        tf_ratio: float = 0.5,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.t_word2idx = t_word2idx  # Store target word to index mapping
        self.t_idx2word = t_idx2word  # Store index to target word mapping
        self.cell_type = cell_type.lower()
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.optim_name = optim_name.lower()
        self.tf_ratio = tf_ratio
        self.pad_idx = decoder.trg_vocab["<pad>"]
        # Start-of-sequence id: the only decoder input in the single step.
        self.sos_idx = decoder.trg_vocab["<s>"]
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        # Rows of {"input", "target", "predicted"} gathered during val/test.
        self.predictions = []

    def _merge_bidir(self, h):
        """Average the fwd & bwd hidden states so that
           [layers*dir, batch, hid] → [layers, batch, hid]"""
        if not self.bidirectional:
            return h

        n_layers = self.decoder.rnn.num_layers

        def merge(state):
            return state.view(n_layers, 2, -1, state.size(-1)).mean(1)

        if self.cell_type == "lstm":
            # h is tuple(hidden, cell)
            return merge(h[0]), merge(h[1])
        return merge(h)

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Predict target-word scores for a batch of source-word ids.

        Args:
            src: [batch] tensor of source token ids.
            trg: [batch] tensor of gold target ids. Kept for interface
                compatibility; it is NOT fed to the decoder.
            teacher_forcing_ratio: Unused (decoding is a single step).

        Returns:
            Tensor of shape [1, batch, vocab] with the decoder's scores.
        """
        _, enc_hidden = self.encoder(src)
        dec_hidden = self._merge_bidir(enc_hidden)

        # Start decoding from <s>. Feeding the gold target here would let the
        # model copy its input to the output, so loss and accuracy would say
        # nothing about how well the source word is transliterated.
        dec_inp = torch.full_like(src, self.sos_idx)
        dec_out, _ = self.decoder(dec_inp, dec_hidden)

        return dec_out.unsqueeze(0)  # Add sequence dimension for compatibility

    def configure_optimizers(self):
        """Build the optimizer selected by ``optim_name``.

        Raises:
            ValueError: If ``optim_name`` is neither "adam" nor "nadam".
        """
        optimizers = {"adam": torch.optim.Adam, "nadam": torch.optim.NAdam}
        opt = self.optim_name.lower()
        if opt not in optimizers:
            raise ValueError(f"Unknown optimizer '{opt}'")
        return optimizers[opt](self.parameters(), lr=self.learning_rate)

    def _accuracy(self, logits, trg):
        """Fraction of non-pad targets whose arg-max prediction is correct."""
        pred = logits.argmax(1)
        valid = trg != self.pad_idx
        n_valid = valid.sum()
        if n_valid == 0:
            return logits.new_zeros(())
        # Divide by the number of real targets, not by the batch size.
        return ((pred == trg) & valid).float().sum() / n_valid

    def _word_accuracy(self, pred_words, true_words):
        """Fraction of positions where the predicted word equals the true word."""
        if not true_words:
            return 0.0
        correct = sum(1 for p, t in zip(pred_words, true_words) if p == t)
        return correct / len(true_words)

    def _step(self, batch, stage):
        """Shared train/val/test step: compute and log loss and accuracies.

        ``batch`` is ``(src ids, trg ids, src words, trg words)``.
        """
        src, trg, src_words, trg_words = batch

        logits = self(src, trg, self.tf_ratio if stage == "train" else 0.0).squeeze(0)
        loss = self.criterion(logits, trg)

        # Calculate token-level accuracy
        acc = self._accuracy(logits, trg)

        # Calculate word-level accuracy
        pred_indices = torch.argmax(logits, dim=1).tolist()
        pred_words = [self.t_idx2word.get(idx, "<unk>") for idx in pred_indices]
        word_acc = self._word_accuracy(pred_words, trg_words)

        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        self.log(f"{stage}_word_acc", word_acc, prog_bar=True)

        # Store predictions for validation and test
        if stage in ["val", "test"]:
            for s_word, t_word, p_word in zip(src_words, trg_words, pred_words):
                self.predictions.append({
                    "input": s_word,
                    "target": t_word,
                    "predicted": p_word
                })

        return loss

    def training_step(self, batch, _): return self._step(batch, "train")
    def validation_step(self, batch, _): return self._step(batch, "val")
    def test_step(self, batch, _): return self._step(batch, "test")

    def on_validation_epoch_end(self):
        # Validation predictions are not written to disk; drop them so they
        # do not accumulate across epochs or leak into the test CSV.
        self.predictions = []

    def on_test_epoch_end(self):
        # Save final test predictions, then reset the buffer so a second
        # test run does not append duplicate rows.
        if self.predictions:
            df = pd.DataFrame(self.predictions)
            df.to_csv("test_predictions_final.csv", index=False)
        self.predictions = []

**Training Function and Sweeps**

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

def run_training(hparams: dict):
    """Train and test the word-level model for one hyper-parameter setting.

    Args:
        hparams: Mapping with the keys ``batch_size``, ``embedding_size``,
            ``hidden_size``, ``num_layers``, ``bidirectional``, ``cell_type``,
            ``dropout``, ``learning_rate``, ``optim``, ``teacher_forcing`` and
            ``epochs``.
    """
    # 1) Data
    dm = DakshinaDataModule(batch_size=hparams["batch_size"])
    dm.prepare_data(); dm.setup()

    # 2) Build encoder & decoder separately
    SRC_VOCAB = len(dm.src_vocab)
    TRG_VOCAB = len(dm.trg_vocab)

    encoder = Encoder(
        input_dim=SRC_VOCAB,
        embed_dim=hparams["embedding_size"],
        hid_dim=hparams["hidden_size"],
        n_layers=hparams["num_layers"],
        bidirectional=hparams["bidirectional"],
        cell_type=hparams["cell_type"],
        dropout=hparams["dropout"],
    )

    decoder = Decoder(
        trg_vocab=dm.trg_vocab,
        output_dim=TRG_VOCAB,
        embed_dim=hparams["embedding_size"],
        hid_dim=hparams["hidden_size"],
        n_layers=hparams["num_layers"],
        bidirectional=False,
        cell_type=hparams["cell_type"],
        dropout=hparams["dropout"],
    )

    model = Seq2SeqLightning(
        encoder=encoder,
        decoder=decoder,
        t_word2idx=dm.t_word2idx,  # Pass target word to index mapping
        t_idx2word=dm.t_idx2word,  # Pass index to target word mapping
        cell_type=hparams["cell_type"],
        bidirectional=hparams["bidirectional"],
        device="cuda" if torch.cuda.is_available() else "cpu",
        learning_rate=hparams["learning_rate"],
        optim_name=hparams["optim"],
        tf_ratio=hparams["teacher_forcing"],
    )

    # 3) Callbacks and logging: keep the checkpoint with the best validation
    #    word accuracy and stop after 3 epochs without improvement.
    ckpt = ModelCheckpoint(monitor="val_word_acc", save_top_k=1, mode="max")
    run_name = (
        f"word_level_e{hparams['epochs']}_lr{hparams['learning_rate']}_"
        f"bs{hparams['batch_size']}_emb{hparams['embedding_size']}_"
        f"hid{hparams['hidden_size']}"
    )

    wandb_logger = WandbLogger(
        project="DA6401_Assignment_3_WordLevel",
        name=run_name,
        config=hparams,
        log_model=True
    )

    early_stop = EarlyStopping(
        monitor="val_word_acc",
        patience=3,
        mode="max",
        verbose=True
    )

    trainer = Trainer(
        max_epochs=hparams["epochs"],
        callbacks=[ckpt, early_stop],
        accelerator="auto",
        devices=1,
        logger=wandb_logger
    )

    # 4) Fit, then evaluate the best checkpoint. Without ckpt_path="best" the
    #    test run would use the weights from the last epoch instead.
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm, ckpt_path="best")

# Baseline hyper-parameters for a single run_training() call.
default_hparams = {
    "cell_type": "lstm",
    "dropout": 0.2,
    "embedding_size": 256,
    "num_layers": 2,
    "batch_size": 128,
    "hidden_size": 512,
    "bidirectional": False,
    "learning_rate": 1e-3,
    "epochs": 20,
    "optim": "adam",
    "teacher_forcing": 0.5,
}

# run_training(default_hparams)

In [None]:
import wandb, yaml, json
# Credentials come from the WANDB_API_KEY environment variable (or the
# interactive prompt). An API key must never be committed to the notebook;
# the key that used to be hard-coded on this line is exposed in version
# history and should be revoked.
wandb.login()

# Bayesian sweep that maximizes "val_acc"; every hyper-parameter is a
# categorical choice over the listed values.
sweep_config = {
    "method": "bayes",           # or "random", "grid", …
    "metric": {"name": "val_acc", "goal": "maximize"},
    "parameters": {
        param_name: {"values": choices}
        for param_name, choices in (
            ("cell_type", ["lstm", "gru", "rnn"]),
            ("dropout", [0.0, 0.1, 0.2, 0.5]),
            ("embedding_size", [64, 128, 256, 512]),
            ("num_layers", [2, 3, 4]),
            ("batch_size", [32, 64, 128]),
            ("hidden_size", [128, 256, 512]),
            ("bidirectional", [True, False]),
            ("learning_rate", [1e-3, 2e-3, 1e-4, 2e-4]),
            ("epochs", [7, 10, 13]),
            ("optim", ["adam", "nadam"]),
            ("teacher_forcing", [0.2, 0.5, 0.7]),
        )
    },
}

def sweep_train():
    """Entry point for one sweep trial: train with the run's sampled config."""
    with wandb.init() as active_run:
        sampled_hparams = dict(active_run.config)
        run_training(sampled_hparams)

# Register the sweep, print its id, and start an agent that runs up to 75
# trials through sweep_train. These lines are active: comment them out to
# skip launching.
# NOTE(review): the sweep is registered under project "DA6401_Assignment_3",
# while run_training's WandbLogger names "DA6401_Assignment_3_WordLevel" —
# confirm which project the trial runs are meant to land in.
sweep_id = wandb.sweep(sweep_config, project="DA6401_Assignment_3",entity="cs24m024-iit-madras")
print(sweep_id)
wandb.agent(sweep_id, function=sweep_train, count=75)


**Listing Predictions**

In [None]:
import os
import csv

# Inverse of the tokenizer's token -> index table; index 0 maps to the empty string
index_to_char = {index: char for char, index in target_tokenizer.word_index.items()}
index_to_char[0] = ''

# Decode predicted sequences
def decode_seq(seq):
    """Join the tokens for ``seq``, skipping index 0 and newline tokens."""
    pieces = []
    for idx in seq:
        if idx == 0:
            continue
        char = index_to_char.get(idx, '')
        if char != '\n':
            pieces.append(char)
    return ''.join(pieces)

# Directory for storing predictions
os.makedirs("predictions_vanilla", exist_ok=True)

# Get model predictions and take the arg-max over the last axis
preds = best_model.predict([test_encoder_input, test_decoder_input], verbose=1)
pred_indices = np.argmax(preds, axis=-1)

# Decode predictions; remove the ' </s>' marker from the references
decoded_preds = list(map(decode_seq, pred_indices))
decoded_refs = [reference.replace(' </s>', '') for reference in test_deva_out]

# Save predictions to .txt (tab-separated: input, prediction, reference)
with open("predictions_vanilla/test_predictions.txt", "w", encoding="utf-8") as f:
    f.writelines(
        f"{source}\t{guess}\t{truth}\n"
        for source, guess, truth in zip(test_lat, decoded_preds, decoded_refs)
    )

# Save predictions to .csv
with open("predictions_vanilla/test_predictions.csv", "w", encoding="utf-8", newline='') as f_csv:
    writer = csv.writer(f_csv)
    writer.writerow(["Input", "Predicted", "Reference"])
    for record in zip(test_lat, decoded_preds, decoded_refs):
        writer.writerow(record)

# Print top 10 predictions. zip() returns an iterator, which cannot be
# sliced, so each sequence is sliced before zipping.
for i, (inp, pred, ref) in enumerate(zip(test_lat[:10], decoded_preds[:10], decoded_refs[:10]), 1):
    print(f"{i}. Input: {inp}\n   Predicted: {pred}\n   Reference: {ref}\n")


**Attention Classes**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Attention(nn.Module):
    """Additive attention over encoder outputs.

    Args:
        hidden_size: Width of the decoder hidden state (and of the internal
            energy layer).
        encoder_output_size: Feature width of ``encoder_outputs``. Defaults to
            ``2 * hidden_size``, the output width of the bidirectional encoder
            defined alongside this class.
    """

    def __init__(self, hidden_size, encoder_output_size=None):
        super().__init__()
        if encoder_output_size is None:
            encoder_output_size = hidden_size * 2
        # The energy layer sees the decoder state concatenated with one
        # encoder output, so its input width is the sum of both widths.
        self.attn = nn.Linear(hidden_size + encoder_output_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """Return attention weights of shape [src_len, batch].

        Args:
            hidden: Decoder state, [batch, hidden_size] (a leading singleton
                axis is also accepted).
            encoder_outputs: [src_len, batch, encoder_output_size].
            mask: Broadcastable to [src_len, batch]; positions equal to 0 are
                excluded from the softmax.
        """
        src_len = encoder_outputs.shape[0]
        hidden = hidden.repeat(src_len, 1, 1)  # [src_len, batch, hid_dim]

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)  # [src_len, batch]
        # Large negative score -> effectively zero weight after softmax.
        attention = attention.masked_fill(mask == 0, -1e10)
        return F.softmax(attention, dim=0)

class Encoder(nn.Module):
    """Bidirectional LSTM encoder for padded, time-major source batches.

    Args:
        input_dim: Size of the source vocabulary.
        embed_dim: Embedding width.
        hid_dim: Hidden size per direction.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout on the embeddings; also used between LSTM layers
            when ``n_layers > 1``.
    """

    def __init__(self, input_dim, embed_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, embed_dim)
        # nn.LSTM warns about non-zero dropout on a single layer, where
        # inter-layer dropout has nothing to act on.
        self.rnn = nn.LSTM(embed_dim, hid_dim, n_layers,
                           dropout=dropout if n_layers > 1 else 0.0,
                           bidirectional=True)
        self.fc = nn.Linear(hid_dim * 2, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def _bridge(self, state):
        """Merge directions: [n_layers*2, batch, hid] -> [n_layers, batch, hid].

        For each layer the forward and backward states are concatenated and
        passed through ``fc`` and ``tanh``.
        """
        batch_size = state.size(1)
        per_layer = state.reshape(self.n_layers, 2, batch_size, self.hid_dim)
        merged = torch.cat((per_layer[:, 0], per_layer[:, 1]), dim=2)
        return torch.tanh(self.fc(merged))

    def forward(self, src, src_len):
        """Encode a batch of source sequences.

        Args:
            src: [src_len, batch] tensor of token ids.
            src_len: Per-example lengths (tensor or list), in any order.

        Returns:
            ``(outputs, hidden, cell)`` with ``outputs`` of shape
            [src_len, batch, 2*hid_dim] and ``hidden``/``cell`` of shape
            [n_layers, batch, hid_dim], one state per decoder layer.
        """
        embedded = self.dropout(self.embedding(src))
        # pack_padded_sequence requires the lengths on the CPU.
        lengths = src_len.cpu() if torch.is_tensor(src_len) else src_len
        packed = pack_padded_sequence(embedded, lengths, enforce_sorted=False)
        outputs, (hidden, cell) = self.rnn(packed)
        # Pad back to the full input length so outputs stay aligned with any
        # mask computed from ``src``.
        outputs, _ = pad_packed_sequence(outputs, total_length=src.size(0))

        return outputs, self._bridge(hidden), self._bridge(cell)

class Decoder(nn.Module):
    """One-step LSTM decoder that attends over the encoder outputs.

    Args:
        output_dim: Size of the target vocabulary.
        embed_dim: Embedding width.
        hid_dim: Hidden size of the decoder LSTM; encoder outputs are expected
            to be ``2 * hid_dim`` wide.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout on the embeddings; also used between LSTM layers
            when ``n_layers > 1``.
        attention: Callable ``(hidden, encoder_outputs, mask)`` returning
            attention weights of shape [src_len, batch].
    """

    def __init__(self, output_dim, embed_dim, hid_dim, n_layers, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, embed_dim)
        # nn.LSTM warns about non-zero dropout on a single layer, where
        # inter-layer dropout has nothing to act on.
        self.rnn = nn.LSTM(embed_dim + hid_dim * 2, hid_dim, n_layers,
                           dropout=dropout if n_layers > 1 else 0.0)
        self.fc_out = nn.Linear(hid_dim * 3 + embed_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs, mask):
        """Decode one time step.

        Args:
            input: [batch] tensor of previous token ids.
            hidden: [n_layers, batch, hid_dim] LSTM hidden state.
            cell: [n_layers, batch, hid_dim] LSTM cell state.
            encoder_outputs: [src_len, batch, 2*hid_dim].
            mask: Passed through to the attention module.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds
            unnormalized scores of shape [batch, output_dim].
        """
        input = input.unsqueeze(0)  # [1, batch]
        embedded = self.dropout(self.embedding(input))  # [1, batch, embed_dim]

        # Attend with the top layer's hidden state.
        a = self.attention(hidden[-1], encoder_outputs, mask)  # [src_len, batch]
        a = a.permute(1, 0).unsqueeze(1)  # [batch, 1, src_len]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch, src_len, 2*hid_dim]
        # Context vector: attention-weighted sum of encoder outputs.
        weighted = torch.bmm(a, encoder_outputs).permute(1, 0, 2)  # [1, batch, 2*hid_dim]

        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))

        prediction = self.fc_out(torch.cat(
            (output.squeeze(0), weighted.squeeze(0), embedded.squeeze(0)), dim=1))

        return prediction, hidden, cell

In [None]:
# class Attention(nn.Module):
#     def __init__(self, hidden_size):
#         super().__init__()
#         self.attn = nn.Linear(hidden_size * 2, hidden_size)
#         self.v = nn.Linear(hidden_size, 1, bias=False)

#     def forward(self, hidden, encoder_outputs, mask):
#         # hidden = [1, batch, hid_dim]
#         # encoder_outputs = [src_len, batch, hid_dim]

#         src_len = encoder_outputs.shape[0]

#         # Repeat decoder hidden state src_len times
#         hidden = hidden.repeat(src_len, 1, 1)  # [src_len, batch, hid_dim]

#         energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
#         attention = self.v(energy).squeeze(2)  # [src_len, batch]

#         attention = attention.masked_fill(mask == 0, -1e10)
#         return F.softmax(attention, dim=0)  # [src_len, batch]

# class Encoder(nn.Module):
#     def __init__(self, input_dim, embed_dim, hid_dim, n_layers, dropout):
#         super().__init__()
#         self.hid_dim = hid_dim
#         self.n_layers = n_layers
#         self.embedding = nn.Embedding(input_dim, embed_dim)
#         self.rnn = nn.LSTM(embed_dim, hid_dim, n_layers, dropout=dropout, bidirectional=True)
#         self.fc = nn.Linear(hid_dim * 2, hid_dim)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, src, src_len):
#         # src = [src_len, batch]
#         embedded = self.dropout(self.embedding(src))  # [src_len, batch, embed_dim]

#         packed = pack_padded_sequence(embedded, src_len, enforce_sorted=False)
#         outputs, (hidden, cell) = self.rnn(packed)
#         outputs, _ = pad_packed_sequence(outputs)  # [src_len, batch, hid_dim * 2]

#         # Combine bidirectional outputs
#         hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))
#         cell = torch.tanh(self.fc(torch.cat((cell[-2,:,:], cell[-1,:,:]), dim=1)))

#         return outputs, hidden.unsqueeze(0), cell.unsqueeze(0)

# class Decoder(nn.Module):
#     def __init__(self, output_dim, embed_dim, hid_dim, n_layers, dropout, attention):
#         super().__init__()
#         self.output_dim = output_dim
#         self.hid_dim = hid_dim
#         self.attention = attention
#         self.embedding = nn.Embedding(output_dim, embed_dim)
#         self.rnn = nn.LSTM(embed_dim + hid_dim * 2, hid_dim, n_layers, dropout=dropout)
#         self.fc_out = nn.Linear(hid_dim * 3 + embed_dim, output_dim)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, input, hidden, cell, encoder_outputs, mask):
#         # input = [batch]
#         # hidden = [n_layers, batch, hid_dim]
#         # cell = [n_layers, batch, hid_dim]

#         input = input.unsqueeze(0)  # [1, batch]
#         embedded = self.dropout(self.embedding(input))  # [1, batch, embed_dim]

#         a = self.attention(hidden[-1], encoder_outputs, mask)  # [src_len, batch]
#         a = a.permute(1, 0).unsqueeze(1)  # [batch, 1, src_len]

#         encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch, src_len, hid_dim*2]
#         weighted = torch.bmm(a, encoder_outputs)  # [batch, 1, hid_dim*2]
#         weighted = weighted.permute(1, 0, 2)  # [1, batch, hid_dim*2]

#         rnn_input = torch.cat((embedded, weighted), dim=2)
#         output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))

#         embedded = embedded.squeeze(0)
#         output = output.squeeze(0)
#         weighted = weighted.squeeze(0)

#         prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))

#         return prediction, hidden, cell

class Seq2SeqLightning(pl.LightningModule):
    """Attention-based sequence-to-sequence model wrapped as a LightningModule.

    Batches are time-major: ``src`` is [src_len, batch] and ``trg`` is
    [trg_len, batch], with ``trg[0]`` holding the start token.

    Args:
        encoder: Module called as ``encoder(src, src_len)`` and returning
            ``(encoder_outputs, hidden, cell)``.
        decoder: Module called as
            ``decoder(input, hidden, cell, encoder_outputs, mask)`` and
            returning ``(scores, hidden, cell)``; must expose ``output_dim``.
        src_vocab: Source token -> index mapping; must contain ``"<pad>"``.
        trg_vocab: Target token -> index mapping; must contain ``"<pad>"``.
        learning_rate: Adam learning rate.
        device: Accepted because existing callers pass it; not used.
    """

    def __init__(self, encoder, decoder, src_vocab, trg_vocab, learning_rate=1e-3, device=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.learning_rate = learning_rate
        self.pad_idx = trg_vocab["<pad>"]
        # trg_vocab maps token -> index; decoding needs the inverse.
        self.idx2token = {idx: token for token, idx in trg_vocab.items()}
        # Markers that are never part of a decoded word.
        self._special_ids = {
            trg_vocab[token] for token in ("<pad>", "<s>", "</s>") if token in trg_vocab
        }
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        # Rows of {"input", "target", "predicted"} gathered during val/test.
        self.predictions = []

        # Save hyperparameters for logging
        self.save_hyperparameters(ignore=['encoder', 'decoder', 'src_vocab', 'trg_vocab', 'device'])

    def create_mask(self, src):
        """Return a boolean [src_len, batch] mask that is False at ``<pad>``.

        The layout matches ``src`` because the attention scores it is applied
        to are [src_len, batch].
        """
        return src != self.src_vocab["<pad>"]

    def forward(self, src, src_len, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg_len - 1`` steps and return the per-step scores.

        Args:
            src: [src_len, batch] source token ids.
            src_len: Per-example source lengths.
            trg: [trg_len, batch] target ids; ``trg[0]`` is the first input.
            teacher_forcing_ratio: Probability, per step, of feeding the gold
                token instead of the model's own prediction.

        Returns:
            Tensor [trg_len, batch, vocab]; row 0 is left as zeros.
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=src.device)

        encoder_outputs, hidden, cell = self.encoder(src, src_len)
        input = trg[0,:]  # First input is <sos>
        mask = self.create_mask(src)

        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs, mask)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1

        return outputs

    def _accuracy(self, logits, trg):
        """Fraction of non-pad target positions predicted correctly.

        ``logits`` has one more trailing axis (the vocabulary) than ``trg``.
        """
        pred = logits.argmax(-1)
        valid = trg != self.pad_idx
        n_valid = valid.sum()
        if n_valid == 0:
            return logits.new_zeros(())
        return ((pred == trg) & valid).float().sum() / n_valid

    def _word_accuracy(self, pred_words, true_words):
        """Fraction of positions where the predicted word equals the true word."""
        if not true_words:
            return 0.0
        correct = sum(1 for p, t in zip(pred_words, true_words) if p == t)
        return correct / len(true_words)

    def _decode_words(self, pred_ids, gold_ids):
        """Turn [steps, batch] predicted ids into one string per example.

        Steps where the gold target is padding are ignored and special tokens
        are dropped.
        NOTE(review): assumes target tokens concatenate directly into the
        strings in ``trg_words`` — confirm against the data module.
        """
        words = []
        for pred_col, gold_col in zip(pred_ids.t().tolist(), gold_ids.t().tolist()):
            words.append("".join(
                self.idx2token.get(p, "<unk>")
                for p, g in zip(pred_col, gold_col)
                if g != self.pad_idx and p not in self._special_ids
            ))
        return words

    def _step(self, batch, stage):
        """Shared train/val/test step: compute and log loss and accuracies."""
        src, trg, src_len, trg_len, src_words, trg_words = batch
        # Teacher forcing is a training aid; validation and test decode from
        # the model's own predictions.
        tf_ratio = 0.5 if stage == "train" else 0.0
        output = self(src, src_len, trg, tf_ratio)

        # Step 0 only carries the start token, so it is excluded.
        logits = output[1:]
        gold = trg[1:]
        output_dim = logits.shape[-1]

        loss = self.criterion(logits.reshape(-1, output_dim), gold.reshape(-1))
        acc = self._accuracy(logits, gold)

        # Calculate word-level accuracy
        pred_words = self._decode_words(logits.argmax(-1), gold)
        word_acc = self._word_accuracy(pred_words, trg_words)

        # The batch axis is dim 1 here, so tell Lightning the batch size
        # instead of letting it infer one from dim 0.
        batch_size = src.shape[1]
        self.log(f"{stage}_loss", loss, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_acc", acc, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_word_acc", word_acc, prog_bar=True, batch_size=batch_size)

        if stage in ["val", "test"]:
            for s_word, t_word, p_word in zip(src_words, trg_words, pred_words):
                self.predictions.append({
                    "input": s_word,
                    "target": t_word,
                    "predicted": p_word
                })

        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._step(batch, "test")

    def on_validation_epoch_end(self):
        # Write this epoch's validation predictions, then reset the buffer.
        if self.predictions:
            df = pd.DataFrame(self.predictions)
            df.to_csv(f"val_predictions_epoch_{self.current_epoch}.csv", index=False)
        self.predictions = []

    def on_test_epoch_end(self):
        # Write the test predictions, then reset the buffer so a second test
        # run does not append duplicate rows.
        if self.predictions:
            df = pd.DataFrame(self.predictions)
            df.to_csv("test_predictions_final.csv", index=False)
        self.predictions = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

**Training Function and Sample Run**

In [None]:
# Hyperparameters
EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
DROPOUT = 0.5
LEARNING_RATE = 0.001
BATCH_SIZE = 128
EPOCHS = 20

# Initialize data module
dm = DakshinaDataModule(batch_size=BATCH_SIZE)
dm.prepare_data()
dm.setup()

# Get vocabulary sizes from the data module
SRC_VOCAB_SIZE = len(dm.src_vocab)
TRG_VOCAB_SIZE = len(dm.trg_vocab)

# Initialize model components
attention = Attention(hidden_size=HID_DIM)
encoder = Encoder(
    input_dim=SRC_VOCAB_SIZE,
    embed_dim=EMB_DIM,
    hid_dim=HID_DIM,
    n_layers=N_LAYERS,
    dropout=DROPOUT
)
decoder = Decoder(
    output_dim=TRG_VOCAB_SIZE,
    embed_dim=EMB_DIM,
    hid_dim=HID_DIM,
    n_layers=N_LAYERS,
    dropout=DROPOUT,
    attention=attention
)

# Initialize the complete model. No `device` argument is passed: the
# attention Seq2SeqLightning.__init__ does not declare one, and the Trainer
# places the module on the accelerator it selects.
model = Seq2SeqLightning(
    encoder=encoder,
    decoder=decoder,
    src_vocab=dm.src_vocab,
    trg_vocab=dm.trg_vocab,
    learning_rate=LEARNING_RATE
)

# Training callbacks
ckpt = ModelCheckpoint(
    monitor="val_word_acc",
    mode="max",
    save_top_k=1,
    filename="best-model"
)
early_stop = EarlyStopping(
    monitor="val_word_acc",
    patience=3,
    mode="max"
)

# Trainer
trainer = Trainer(
    max_epochs=EPOCHS,
    callbacks=[ckpt, early_stop],
    accelerator="auto",
    devices=1,
    logger=WandbLogger(project="DA6401_Assignment_3")
)

# Start training, then evaluate the checkpoint with the best validation word
# accuracy rather than the weights from the last epoch.
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm, ckpt_path="best")

**Loading Best Model**

In [None]:
# Path to the best model checkpoint saved during training. The checkpoint
# callback records where it wrote the file; with WandbLogger the run
# directory is not "lightning_logs/version_0", so the literal path is only a
# fallback for when no checkpoint has been recorded.
best_model_path = ckpt.best_model_path or "lightning_logs/version_0/checkpoints/best-model.ckpt"

# Load the best model. No `device` argument is passed: the attention
# Seq2SeqLightning.__init__ does not declare one.
best_model = Seq2SeqLightning.load_from_checkpoint(
    checkpoint_path=best_model_path,
    encoder=encoder,
    decoder=decoder,
    src_vocab=dm.src_vocab,
    trg_vocab=dm.trg_vocab,
    learning_rate=LEARNING_RATE
)

In [None]:
def decode_with_attention(model, input_tensor, target_tokenizer, max_len=50, start_token='<s>', end_token='</s>'):
    """
    Performs greedy decoding with attention, returning predicted string and attention weights.

    The model is driven through the interfaces of the attention classes in
    this file: ``model.encoder(src, src_len)``, ``model.decoder.attention`` and
    ``model.decoder(input, hidden, cell, encoder_outputs, mask)``.

    Args:
        model: Trained PyTorch Lightning model
        input_tensor: Token ids of a single example, shape (1, seq_len)
            (a tensor or array-like). Every position is attended to; no
            padding mask is applied.
        target_tokenizer: tokenizer exposing ``word_index`` (token -> index)
        max_len: Maximum number of decoding steps
        start_token: Start-of-sequence token
        end_token: End-of-sequence token

    Returns:
        decoded_text: Final decoded string (the end token is not included)
        attention_weights_all: One array of shape (seq_len,) per decoded token
    """
    model.eval()
    word_index = target_tokenizer.word_index
    index_to_token = {v: k for k, v in word_index.items()}
    index_to_token[0] = ''
    end_id = word_index.get(end_token)

    with torch.no_grad():
        device = model.device
        # The encoder is time-major: (1, seq_len) -> [seq_len, 1].
        src = torch.as_tensor(input_tensor, dtype=torch.long, device=device).reshape(1, -1).t()
        src_len = torch.tensor([src.size(0)])  # lengths stay on the CPU for packing
        encoder_outputs, hidden, cell = model.encoder(src, src_len)
        mask = torch.ones(src.size(0), 1, dtype=torch.bool, device=device)
        decoder_input = torch.tensor([word_index[start_token]], device=device)

        decoded_tokens = []
        attention_weights_all = []

        for _ in range(max_len):
            # The decoder does not return its attention weights, so they are
            # recomputed from the same inputs it is about to use.
            attn_weights = model.decoder.attention(hidden[-1], encoder_outputs, mask)  # [seq_len, 1]
            logits, hidden, cell = model.decoder(decoder_input, hidden, cell, encoder_outputs, mask)

            predicted_id = torch.argmax(logits, dim=1).item()
            # Stop before recording the end token so the text and the
            # attention rows stay aligned one-to-one.
            if predicted_id == end_id:
                break

            decoded_tokens.append(predicted_id)
            attention_weights_all.append(attn_weights.squeeze(1).cpu().numpy())  # shape: (seq_len,)

            decoder_input = torch.tensor([predicted_id], device=device)

        decoded_text = ''.join([index_to_token.get(idx, '') for idx in decoded_tokens])
        return decoded_text, attention_weights_all


**Table for sample predictions**

In [None]:
def plot_colored_prediction_table(samples, save_path="predictions/colored_table.png"):
    """Render (input, predicted, target) rows as a colour-coded table image.

    The predicted cell is light green when it equals the target and light
    coral otherwise.

    Args:
        samples: Iterable of ``(input, predicted, target)`` string triples.
        save_path: Where the PNG is written; its parent directory is created
            if it does not exist.
    """
    # savefig does not create directories; create the target folder first.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.set_axis_off()

        tbl = Table(ax, bbox=[0, 0, 1, 1])
        col_widths = [0.3, 0.4, 0.3]
        headers = ["Input Word", "Predicted Word", "Target Word"]

        # Add headers
        for col_idx, title in enumerate(headers):
            tbl.add_cell(0, col_idx, col_widths[col_idx], 0.05, text=title, loc='center', facecolor='lightgrey')

        # Add prediction rows
        for row_idx, (inp, pred, ref) in enumerate(samples, start=1):
            color = 'lightgreen' if pred == ref else 'lightcoral'
            tbl.add_cell(row_idx, 0, col_widths[0], 0.05, text=inp, loc='center')
            tbl.add_cell(row_idx, 1, col_widths[1], 0.05, text=pred, loc='center', facecolor=color)
            tbl.add_cell(row_idx, 2, col_widths[2], 0.05, text=ref, loc='center')

        tbl.set_fontsize(12)
        tbl.scale(1, 1.5)
        ax.add_table(tbl)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

# First 20 (input, prediction, reference) triples for the table
sample_data = [
    triple
    for triple in zip(test_lat[:20], decoded_preds[:20], decoded_refs[:20])
]
plot_colored_prediction_table(sample_data)

# Upload the rendered table image to WandB, then close the run
table_image = wandb.Image("predictions/colored_table.png")
wandb.log({"Prediction Table Image": table_image})
wandb.finish()


**Model to create heatmaps**

In [None]:
# Build and compile the attention-based sequence-to-sequence model
# NOTE(review): build_attention_seq2seq_model, Adam, VOCAB_SIZE_INPUT,
# VOCAB_SIZE_TARGET and the train_*/val_* arrays are not defined in the
# visible part of this file — confirm they are provided before this cell runs.
attention_model = build_attention_seq2seq_model(
    vocab_size_input=VOCAB_SIZE_INPUT,
    vocab_size_target=VOCAB_SIZE_TARGET,
    embedding_dim=256,
    hidden_dim=256,
    dropout_rate=0.0
)

# Default-configured Adam, sparse categorical cross-entropy loss, accuracy metric
attention_model.compile(
    optimizer=Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model on training data with validation set
# (10 epochs, batch size 64)
attention_model.fit(
    [train_encoder_input, train_decoder_input],
    train_target_output,
    validation_data=([val_encoder_input, val_decoder_input], val_target_output),
    batch_size=64,
    epochs=10,
    verbose=2
)


**Recording the predictions with attention**

In [None]:
import os
import csv
import numpy as np

# Output folder for this section's CSV.
os.makedirs("predictions_attentions", exist_ok=True)

# NOTE(review): predictions come from `best_model`, while the model trained in
# the cell above is `attention_model` — confirm `best_model` is the intended
# (attention) model here.
# The model is fed the test decoder input as well as the encoder input.
predictions = best_model.predict([test_encoder_input, test_decoder_input])
# Most likely token id along the last axis.
predicted_token_ids = np.argmax(predictions, axis=-1)

# Invert the tokenizer's token -> index table; index 0 decodes to nothing.
word_index = target_tokenizer.word_index
index_to_token = dict(zip(word_index.values(), word_index.keys()))
index_to_token[0] = ''

def decode_prediction(sequence, token_lookup=None):
    """Convert a sequence of token ids into a string.

    Parameters:
    - sequence: iterable of integer token ids.
    - token_lookup: optional mapping from token id to token string. When
      omitted, the module-level ``index_to_token`` table is used, which is
      how existing one-argument calls behave.

    Returns:
    - str: the concatenated tokens. Id 0 is skipped, ids missing from the
      mapping contribute nothing, and decoding stops at the first id that
      maps to '\\n' (treated as the end-of-sequence token).
    """
    if token_lookup is None:
        token_lookup = index_to_token

    tokens = []
    for idx in sequence:
        if idx == 0:
            continue
        char = token_lookup.get(idx, '')
        if char == '\n':  # stop at end-of-sequence token
            break
        tokens.append(char)
    return ''.join(tokens)

# Turn each row of predicted ids into a string, and drop every ' </s>'
# occurrence from the reference strings.
decoded_predictions = list(map(decode_prediction, predicted_token_ids))
decoded_references = [target.replace(' </s>', '') for target in test_deva_out]

# One CSV row per test example: input, prediction, reference.
with open("predictions_attentions/test_predictions.csv", "w", encoding="utf-8", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["Input (Latin)", "Predicted (Devanagari)", "Actual (Devanagari)"])
    writer.writerows(zip(test_lat, decoded_predictions, decoded_references))


**Some plotting functions**

In [None]:
def plot_9_grid():
    """Show a 3x3 grid of attention heatmaps for the first nine test samples.

    Reads the module-level ``test_lat`` and ``test_encoder_input`` and calls
    ``decode_with_attention`` once per sample. Each panel uses the characters
    of the input text as x tick labels and the characters of the decoded
    output as y tick labels. Returns nothing; the figure is displayed with
    ``plt.show()``.
    """
    fig = plt.figure(figsize=(20, 16))
    sns.set(style="white")  # plain background, no grid

    # Settings shared by every panel.
    heatmap_style = dict(
        cmap='YlGnBu',
        cbar=False,
        linewidths=0.0,
        square=True,
        annot=False,
    )

    for i in range(9):
        input_text = test_lat[i]
        # Slice rather than index so a leading axis of length 1 is kept.
        output_text, attn_weights = decode_with_attention(test_encoder_input[i:i+1])
        attn_matrix = np.stack(attn_weights)

        ax = fig.add_subplot(3, 3, i + 1)
        sns.heatmap(
            attn_matrix,
            xticklabels=list(input_text),
            yticklabels=list(output_text),
            ax=ax,
            **heatmap_style,
        )

        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=10, ha='right')
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)
        ax.set_xlabel("Input (Latin)", fontsize=12)
        ax.set_ylabel("Output (Marathi)", fontsize=12)
        ax.set_title(f"Input: {input_text}", fontsize=14, fontweight='bold')

    fig.tight_layout()
    # Added after tight_layout and placed slightly above the axes area (y > 1).
    fig.suptitle("Attention Heatmaps (Latin → Marathi)", fontsize=18, fontweight='bold', y=1.02)
    plt.show()


# Draw and display the 3x3 attention-heatmap grid for the first nine test samples.
plot_9_grid()


In [None]:
import os
import seaborn as sns
import matplotlib.pyplot as plt

# Output folder for the per-sample heatmap images.
os.makedirs("attention_heatmaps", exist_ok=True)

sns.set(style="white")  # plain background, no grid

# One standalone figure per test sample, saved as sample_1.png ... sample_10.png.
# The loop names stay at module level on purpose: a later cell passes the
# module-level `attn_weights` to plot_attention_heatmap.
for i in range(10):
    input_text = test_lat[i]
    input_seq = test_encoder_input[i:i+1]
    output_text, attn_weights = decode_with_attention(input_seq)
    attn_matrix = np.stack(attn_weights)

    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(
        attn_matrix,
        xticklabels=list(input_text),
        yticklabels=list(output_text),
        cmap='rocket',
        linewidths=0,
        cbar=False,
        square=True,
        ax=ax,
    )

    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)

    ax.set_xlabel("Input (Latin)", fontsize=12)
    ax.set_ylabel("Output (Marathi)", fontsize=12)
    ax.set_title(f"Sample {i+1}: {input_text}", fontsize=14, fontweight='bold')

    fig.tight_layout()
    fig.savefig(f"attention_heatmaps/sample_{i+1}.png", dpi=300)
    # Close each figure after saving so ten figures do not stay open.
    plt.close(fig)


In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt

# Paths of the ten per-sample heatmaps (sample_1.png ... sample_10.png)
image_paths = [f"attention_heatmaps/sample_{i+1}.png" for i in range(10)]

# Image.open is lazy and keeps the underlying file open. Copy the pixel data
# out inside a `with` block so every file handle is released right away and
# `images` holds fully loaded, independent images.
images = []
for path in image_paths:
    with Image.open(path) as img:
        images.append(img.copy())

# Tile size is taken from the first image (assumes all are the same size)
img_width, img_height = images[0].size

# Grid layout: 2 rows x 5 columns = exactly 10 slots, one per heatmap
cols = 5
rows = 2

# Create a blank white canvas for the collage
grid_width = cols * img_width
grid_height = rows * img_height
collage = Image.new('RGB', (grid_width, grid_height), color='white')

# Paste images into the grid in row-major order
for idx, image in enumerate(images):
    x = (idx % cols) * img_width
    y = (idx // cols) * img_height
    collage.paste(image, (x, y))

# Save combined image
output_path = "attention_heatmaps/combined_heatmaps.png"
collage.save(output_path)

print(f"Combined image saved to {output_path}")


**Attention Values in Heatmap**

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_attention_heatmap(input_text, output_text, attention_weights):
    """
    Draw a single attention heatmap with a colour bar and display it.

    Parameters:
    - input_text: str, its characters label the x axis
    - output_text: str, its characters label the y axis
    - attention_weights: List[List[float]], shape (output_len, input_len)

    Returns nothing; the figure is shown with plt.show().
    """
    sns.set(style="white")

    x_labels = list(input_text)
    y_labels = list(output_text)

    # Grow the figure by 0.5 per character, but never below 6 wide x 4 high.
    fig_size = (
        max(6, len(x_labels) * 0.5),
        max(4, len(y_labels) * 0.5),
    )
    fig, ax = plt.subplots(figsize=fig_size)

    weights = np.array(attention_weights)
    sns.heatmap(
        weights,
        xticklabels=x_labels,
        yticklabels=y_labels,
        cmap='rocket',
        cbar=True,
        linewidths=0,
        square=True,
        ax=ax,
    )

    ax.set_xlabel("Input (Latin)", fontsize=12, weight='bold')
    ax.set_ylabel("Output (Marathi)", fontsize=12, weight='bold')
    ax.set_title("Attention", fontsize=14, weight='bold', pad=12)

    # Input characters are slanted, output characters stay upright.
    for axis_name, angle in (('x', 45), ('y', 0)):
        ax.tick_params(axis=axis_name, rotation=angle, labelsize=10)

    fig.tight_layout()
    plt.show()

In [None]:
# Plot one annotated heatmap for a hard-coded Latin/Devanagari word pair.
# NOTE(review): `attn_weights` is not computed in this cell; it is whatever the
# module-level name last held (presumably the final iteration of the heatmap
# saving loop). Confirm it really belongs to "barechda", otherwise the tick
# labels will not match the weights.
plot_attention_heatmap(
    input_text="barechda",
    output_text="बरेचदा",
    attention_weights=attn_weights)