In [37]:
import os

PROJECT_ROOT = "/Users/jasleenkaur/Desktop/translit-consistency"
os.chdir(PROJECT_ROOT)

print("Working Directory:", os.getcwd())

Working Directory: /Users/jasleenkaur/Desktop/translit-consistency


In [None]:
import torch
import torch.nn as nn
from g2p_en import G2p
from sklearn.ensemble import RandomForestClassifier

In [39]:
import warnings
warnings.filterwarnings("ignore")

In [40]:
def normalize_english(word):
    return {
        "delhi": "dilli",
        "bangalore": "bangalor",
        "bengaluru": "bangalor",
        "maharashta": "maharashtra",
    }.get(word.lower(), word.lower())

In [41]:
def encode(word, stoi):
    if stoi is en_stoi:
        word = normalize_english(word)
    return [stoi[SOS]] + [stoi[c] for c in word] + [stoi[EOS]]

In [42]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx = 0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first = True)

    def forward(self, x):
        emb = self.embedding(x)
        outputs, (h, c) = self.lstm(emb)
        return outputs, h, c

In [43]:
class LuongAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.scale = 1.0 / (hidden_dim ** 0.5)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: (B, H)
        encoder_outputs: (B, src_len, H)
        """
        # (B, src_len)
        scores = torch.bmm(
            encoder_outputs,
            decoder_hidden.unsqueeze(2)
        ).squeeze(2)

        attn_weights = torch.softmax(scores * self.scale, dim=1)

        # (B, H)
        context = torch.bmm(
            attn_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)

        return context, attn_weights

In [44]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.attention = LuongAttention(hidden_dim)
        self.lstm = nn.LSTM(embed_dim + hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h, c, encoder_outputs):
        # x: (B, 1)
        emb = self.embedding(x)  # (B, 1, E)

        # Attention
        context, _ = self.attention(h[-1], encoder_outputs)  # (B, H)
        context = context.unsqueeze(1)  # (B, 1, H)

        # LSTM
        lstm_input = torch.cat([emb, context], dim=2)  # (B, 1, E+H)
        output, (h, c) = self.lstm(lstm_input, (h, c))  # output: (B, 1, H)

        logits = self.fc(output.squeeze(1))  # (B, vocab)
        return logits, h, c

In [45]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        src: (B, src_len)
        tgt: (B, tgt_len)
        """
        batch_size, tgt_len = tgt.shape
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, tgt_len, vocab_size).to(src.device)

        encoder_outputs, h, c = self.encoder(src)

        input_tok = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            logits, h, c = self.decoder(input_tok, h, c, encoder_outputs)
            outputs[:, t] = logits

            teacher_force = random.random() < teacher_forcing_ratio
            top1 = logits.argmax(1).unsqueeze(1)  

            input_tok = tgt[:, t].unsqueeze(1) if teacher_force else top1

        return outputs

In [46]:
COMMON_FIXES = {
    "िि": "ि",
    "ाा": "ा",
    "ुु": "ु",
    "ूू": "ू",
    "ेे": "े",
    "ोो": "ो"
}

def normalize_hindi(text):
    for b, g in COMMON_FIXES.items():
        text = text.replace(b, g)
    return text

In [47]:
def postprocess_hindi(text):
    # 1️⃣ Collapse duplicated matras (िि → ि, etc.)
    for m in ["ि", "ी", "ा", "ु", "ू", "े", "ो"]:
        text = text.replace(m + m, m)

    # 2️⃣ Fix duplicated final consonant (ष्ट्ट → ष्ट)
    if len(text) >= 2 and text[-1] == text[-2]:
        text = text[:-1]

    # 3️⃣ PROTECT common Hindi suffixes
    protected_suffixes = (
        "स्थान", "पुर", "नगर", "गंज", "गढ़", "पुरम"
    )
    for suf in protected_suffixes:
        if text.endswith(suf):
            return text   # DO NOTHING further

    # 4️⃣ Trim hallucinated trailing junk ONLY if long
    if len(text) >= 7:
        # remove trailing vowels like "जो", "यी", "ऊ"
        if text[-1] in {"ो", "ू", "ी"}:
            text = text[:-1]

        # remove trailing filler consonants
        if text[-1] in {"य", "र"}:
            text = text[:-1]

    return text

In [48]:
def fix_common_seq2seq_errors(text):
    # 1️⃣ Remove hallucinated trailing syllables
    if text.endswith("का") and len(text) > 5:
        text = text[:-2]

    # 2️⃣ Fix vowel echo (रीरे → री, रिरे → री)
    text = text.replace("रीरे", "री")
    text = text.replace("रिरे", "री")

    # 3️⃣ Fix double syllable drift (ल्ली → ली)
    text = text.replace("ल्लि", "ल्ली")
    text = text.replace("मिल्लि", "दिल्ली")

    # 4️⃣ Remove dangling halant at end
    if text.endswith("्"):
        text = text[:-1]

    return text

In [49]:
def beam_transliterate(model, word, beam_width=3, max_len=40):
    model.eval()

    src = torch.tensor([encode(word.lower(), en_stoi)], dtype=torch.long).to(device)
    with torch.no_grad():
        encoder_outputs, h, c = model.encoder(src)

    beams = [([hi_stoi[SOS]], 0.0, h, c)]

    for _ in range(max_len):
        new_beams = []
        for seq, score, h, c in beams:
            if seq[-1] == hi_stoi[EOS]:
                new_beams.append((seq, score, h, c))
                continue

            input_tok = torch.tensor([[seq[-1]]], device=device)
            with torch.no_grad():
                logits, h_new, c_new = model.decoder(input_tok, h, c, encoder_outputs)

            log_probs = torch.log_softmax(logits[0], dim=-1)
            topk = torch.topk(log_probs, beam_width)

            for idx, val in zip(topk.indices, topk.values):
                next_char = hi_itos[idx.item()]

                penalty = 0.0
                if seq.count(idx.item()) >= 2:
                    penalty = 1.5
                    
                new_beams.append(
                    (seq + [idx.item()], score + val.item() - penalty, h_new, c_new)
                )

        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(seq[-1] == hi_stoi[EOS] for seq, _, _, _ in beams):
            break

    best = beams[0][0]
    out = "".join(
        hi_itos[i] for i in best
        if i not in {hi_stoi[SOS], hi_stoi[EOS], hi_stoi[PAD]}
    )
    out = normalize_hindi(out)
    out = postprocess_hindi(out)
    out = fix_common_seq2seq_errors(out)
    return out

In [50]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

checkpoint = torch.load("seq2seq_stable.pt", map_location=device)

In [51]:
en_stoi = checkpoint["en_stoi"]
hi_stoi = checkpoint["hi_stoi"]
hi_itos = checkpoint["hi_itos"]

PAD = checkpoint["PAD"]
SOS = checkpoint["SOS"]
EOS = checkpoint["EOS"]

EMBED_DIM = checkpoint["EMBED_DIM"]
HIDDEN_DIM = checkpoint["HIDDEN_DIM"]

In [52]:
encoder = Encoder(
    vocab_size = len(en_stoi), 
    embed_dim = EMBED_DIM, 
    hidden_dim = HIDDEN_DIM
).to(device)

decoder = Decoder(
    vocab_size = len(hi_stoi), 
    embed_dim = EMBED_DIM, 
    hidden_dim = HIDDEN_DIM
).to(device)

model = Seq2Seq(encoder, decoder, hi_stoi[PAD]).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

print("Seq2Seq model loaded successfully")

Seq2Seq model loaded successfully


In [53]:
from p2g import transliterate_p2g

In [54]:
def char_accuracy(pred, gold):
    m = min(len(pred), len(gold))
    return sum(pred[i] == gold[i] for i in range(m)) / max(len(gold), 1)

In [55]:
def extract_features(word, seq_out, p2g_out):
    return [
        len(seq_out) / max(len(word), 1),      # len_ratio
        seq_out.count("्"),                    # seq_halants
        p2g_out.count("्"),                    # p2g_halants
        len(seq_out) - len(set(seq_out)),      # seq_repeats
        len(p2g_out) - len(set(p2g_out)),      # p2g_repeats
        int(seq_out[-1] in "ािीुूेो") if seq_out else 0,
        int(p2g_out[-1] in "ािीुूेो") if p2g_out else 0,
    ]

In [56]:
import pandas as pd

In [57]:
import json, random

with open("data/processed/aligned_pairs_high_conf.json", encoding = "utf-8") as f:
    raw_pairs = json.load(f)

pairs = [(en.lower(), hi) for en, hi, _ in raw_pairs]
random.seed(42)
random.shuffle(pairs)

n = len(pairs)
train = pairs[:int(0.8 * n)]
val   = pairs[int(0.8 * n):int(0.9 * n)]
test  = pairs[int(0.9 * n):]

print("Train:", len(train))
print("Val:", len(val))
print("Test:", len(test))

Train: 32835
Val: 4104
Test: 4105


In [58]:
X = []
y = []

for en, hi in train[:5000]:   # start with 5k only
    seq_out = beam_transliterate(model, en)
    p2g_out = transliterate_p2g(en)

    features = extract_features(en, seq_out, p2g_out)
    X.append(features)

    # Label: which output is better?
    seq_acc = char_accuracy(seq_out, hi)
    p2g_acc = char_accuracy(p2g_out, hi)

    label = 1 if seq_acc >= p2g_acc else 0
    y.append(label)

In [59]:
X = pd.DataFrame(X)

In [60]:
clf = RandomForestClassifier(
    n_estimators = 200,
    max_depth = 10,
    random_state = 42
)

clf.fit(X, y)
print("Classifier trained")

Classifier trained


In [61]:
def hybrid_transliterate(word):
    seq_out = beam_transliterate(model, word)
    p2g_out = transliterate_p2g(word)

    features = extract_features(word, seq_out, p2g_out)
    decision = clf.predict([features])[0]

    return seq_out if decision == 1 else p2g_out

In [64]:
def evaluate_system(data, predict_fn, limit=None):
    scores = []

    for i, (en, hi) in enumerate(data):
        if limit and i >= limit:
            break

        pred = predict_fn(en)
        scores.append(char_accuracy(pred, hi))

    return sum(scores) / len(scores)

In [66]:
print("Seq2Seq Val:",
      round(evaluate_system(val, lambda w: beam_transliterate(model, w), 2000) * 100, 2), "%")

print("P2G Val:",
      round(evaluate_system(val, transliterate_p2g, 2000) * 100, 2), "%")

print("Hybrid (heuristic) Val:",
      round(evaluate_system(val, hybrid_transliterate, 2000) * 100, 2), "%")

Seq2Seq Val: 47.79 %
P2G Val: 34.76 %
Hybrid (heuristic) Val: 47.83 %


In [67]:
def error_type(pred, gold):
    if len(pred) < len(gold):
        return "UNDER"
    if len(pred) > len(gold):
        return "OVER"
    if "्" in pred and "्" not in gold:
        return "HALANT"
    return "OTHER"

In [68]:
from collections import Counter

def analyze_errors_seq(data, limit=1000):
    counter = Counter()

    for i, (en, hi) in enumerate(data):
        if i >= limit:
            break

        pred = beam_transliterate(model, en)
        counter[error_type(pred, hi)] += 1

    total = sum(counter.values())
    return {k: round(v / total * 100, 2) for k, v in counter.items()}

In [69]:
def analyze_errors_hybrid(data, limit=1000):
    counter = Counter()

    for i, (en, hi) in enumerate(data):
        if i >= limit:
            break

        pred = hybrid_transliterate(en)
        counter[error_type(pred, hi)] += 1
    total = sum(counter.values())
    return {k: round(v / total * 100, 2) for k, v in counter.items()}

In [70]:
print("Seq2Seq Errors:", analyze_errors_seq(val))
print("Hybrid Errors:", analyze_errors_hybrid(val))

Seq2Seq Errors: {'OTHER': 37.0, 'UNDER': 27.4, 'OVER': 34.7, 'HALANT': 0.9}
Hybrid Errors: {'OTHER': 37.0, 'UNDER': 27.5, 'OVER': 34.6, 'HALANT': 0.9}


In [None]:
from collections import Counter

decisions = Counter()

for en, hi in val[:2000]:
    seq_out = beam_transliterate(model, en)
    p2g_out = transliterate_p2g(en)

    feats = extract_features(en, seq_out, p2g_out)
    decision = clf.predict([feats])[0]

    decisions["SEQ2SEQ" if decision == 1 else "P2G"] += 1

print(decisions)

In [None]:
import joblib

torch.save(checkpoint, "seq2seq_final.pt")
joblib.dump(clf, "confidence_selector.pkl")