In [28]:
import os
os.chdir("/Users/jasleenkaur/Desktop/translit-consistency")
print(os.getcwd())

/Users/jasleenkaur/Desktop/translit-consistency


In [29]:
from p2g import transliterate_p2g

In [30]:
import re
import torch
import json
import random
import torch.nn as nn

In [31]:
import warnings
warnings.filterwarnings("ignore")

In [32]:
with open("data/processed/aligned_pairs_high_conf.json", encoding = "utf-8") as f:
    raw_pairs = json.load(f)

print("Total Pairs:", len(raw_pairs))

pairs = [(en.lower(), hi) for en, hi, _ in raw_pairs]

random.seed(42)
random.shuffle(pairs)

n = len(pairs)
train = pairs[:int(0.8 * n)]
val = pairs[int(0.8 * n):int(0.9 * n)]
test = pairs[int(0.9 * n):]

print("Train:", len(train))
print("Val:", len(val))
print("Test:", len(test))

Total Pairs: 41044
Train: 32835
Val: 4104
Test: 4105


In [33]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [34]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx = 0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first = True)

    def forward(self, x):
        emb = self.embedding(x)
        outputs, (h, c) = self.lstm(emb)
        return outputs, h, c

In [35]:
class LuongAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.scale = 1.0 / (hidden_dim ** 0.5)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: (B, H)
        encoder_outputs: (B, src_len, H)
        """
        # (B, src_len)
        scores = torch.bmm(
            encoder_outputs,
            decoder_hidden.unsqueeze(2)
        ).squeeze(2)

        attn_weights = torch.softmax(scores * self.scale, dim=1)

        # (B, H)
        context = torch.bmm(
            attn_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)

        return context, attn_weights

In [36]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.attention = LuongAttention(hidden_dim)
        self.lstm = nn.LSTM(embed_dim + hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h, c, encoder_outputs):
        # x: (B, 1)
        emb = self.embedding(x)  # (B, 1, E)

        # Attention
        context, _ = self.attention(h[-1], encoder_outputs)  # (B, H)
        context = context.unsqueeze(1)  # (B, 1, H)

        # LSTM
        lstm_input = torch.cat([emb, context], dim=2)  # (B, 1, E+H)
        output, (h, c) = self.lstm(lstm_input, (h, c))  # output: (B, 1, H)

        logits = self.fc(output.squeeze(1))  # (B, vocab)
        return logits, h, c

In [37]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        src: (B, src_len)
        tgt: (B, tgt_len)
        """
        batch_size, tgt_len = tgt.shape
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, tgt_len, vocab_size).to(src.device)

        encoder_outputs, h, c = model.encoder(src)

        input_tok = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            logits, h, c = self.decoder(input_tok, h, c, encoder_outputs)
            outputs[:, t] = logits

            teacher_force = random.random() < teacher_forcing_ratio
            top1 = logits.argmax(1).unsqueeze(1)  

            input_tok = tgt[:, t].unsqueeze(1) if teacher_force else top1

        return outputs

In [38]:
checkpoint = torch.load("seq2seq_best.pt", map_location=device)

en_stoi = checkpoint["en_stoi"]
hi_stoi = checkpoint["hi_stoi"]
hi_itos = checkpoint["hi_itos"]
PAD = checkpoint["PAD"]
SOS = checkpoint["SOS"]
EOS = checkpoint["EOS"]

EMBED_DIM = checkpoint["EMBED_DIM"]
HIDDEN_DIM = checkpoint["HIDDEN_DIM"]

encoder = Encoder(len(en_stoi), EMBED_DIM, HIDDEN_DIM).to(device)
decoder = Decoder(len(hi_stoi), EMBED_DIM, HIDDEN_DIM).to(device)

model = Seq2Seq(encoder, decoder, hi_stoi[PAD]).to(device)

model.load_state_dict(checkpoint["model_state"])
model.eval()

for p in model.parameters():
    p.requires_grad = False

print("Frozen Seq2Seq model loaded (best)")

Frozen Seq2Seq model loaded (best)


In [39]:
def normalize_english(word):
    return {
        "delhi": "dilli",
        "bangalore": "bangalor",
        "bengaluru": "bangalor",
        "maharashta": "maharashtra",
    }.get(word.lower(), word.lower())

In [40]:
def encode(word, stoi):
    if stoi is en_stoi:
        word = normalize_english(word)
    return [stoi[SOS]] + [stoi[c] for c in word] + [stoi[EOS]]

In [41]:
COMMON_FIXES = {
    "िि": "ि",
    "ाा": "ा",
    "ुु": "ु",
    "ूू": "ू",
    "ेे": "े",
    "ोो": "ो"
}

def normalize_hindi(text):
    for b, g in COMMON_FIXES.items():
        text = text.replace(b, g)
    return text

def postprocess_hindi(text):
    # Collapse duplicated matras (िि → ि, etc.)
    for m in ["ि", "ी", "ा", "ु", "ू", "े", "ो"]:
        text = text.replace(m + m, m)

    # Fix duplicated final consonant (ष्ट्ट → ष्ट)
    if len(text) >= 2 and text[-1] == text[-2]:
        text = text[:-1]

    # PROTECT common Hindi suffixes
    protected_suffixes = (
        "स्थान", "पुर", "नगर", "गंज", "गढ़", "पुरम"
    )
    for suf in protected_suffixes:
        if text.endswith(suf):
            return text   # DO NOTHING further

    # Trim hallucinated trailing junk ONLY if long
    if len(text) >= 7:
        # remove trailing vowels like "जो", "यी", "ऊ"
        if text[-1] in {"ो", "ू", "ी"}:
            text = text[:-1]

        # remove trailing filler consonants
        if text[-1] in {"य", "र"}:
            text = text[:-1]

    return text

In [42]:
import re

def has_repetition_noise(text):
    return bool(re.search(r"(.)\1\1", text))  # aaa / ललल

In [43]:
def hybrid_transliterate(word):
    seq_out = beam_transliterate(model, word)

    # Only catastrophic failures fall back to P2G
    if (
        len(seq_out) <= 2 or
        seq_out.endswith("्") or
        seq_out.startswith(("ि", "ी")) or
        len(seq_out) > 2 * len(word)
    ):
        return transliterate_p2g(word)

    # Otherwise trust Seq2Seq
    return seq_out

In [44]:
def beam_transliterate(model, word, beam_width=3, max_len=40):
    model.eval()

    src = torch.tensor([encode(word.lower(), en_stoi)], dtype=torch.long).to(device)
    with torch.no_grad():
        encoder_outputs, h, c = model.encoder(src)

    beams = [([hi_stoi[SOS]], 0.0, h, c)]

    for _ in range(max_len):
        new_beams = []
        for seq, score, h, c in beams:
            if seq[-1] == hi_stoi[EOS]:
                new_beams.append((seq, score, h, c))
                continue

            input_tok = torch.tensor([[seq[-1]]], device=device)
            with torch.no_grad():
                logits, h_new, c_new = model.decoder(input_tok, h, c, encoder_outputs)

            log_probs = torch.log_softmax(logits[0], dim=-1)
            topk = torch.topk(log_probs, beam_width)

            for idx, val in zip(topk.indices, topk.values):
                new_beams.append(
                    (seq + [idx.item()], score + val.item(), h_new, c_new)
                )

        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(seq[-1] == hi_stoi[EOS] for seq, _, _, _ in beams):
            break

    best = beams[0][0]
    out = "".join(
        hi_itos[i] for i in best
        if i not in {hi_stoi[SOS], hi_stoi[EOS], hi_stoi[PAD]}
    )
    out = normalize_hindi(out)
    out = postprocess_hindi(out)
    return out

In [45]:
tests = [
    "Delhi", "Kolkata", "Bangalore", "Rajasthan",
    "Chandrakala", "Vishnupuram", "Maharashtra", "Kaveri"
]

for w in tests:
    print(w, "→", hybrid_transliterate(w))

Delhi → दिल्ली
Kolkata → कोलकत
Bangalore → बंगलोरो
Rajasthan → राजस्थान
Chandrakala → चं्दकरलाल
Vishnupuram → विषुणपुरम
Maharashtra → माहाष्ट्र
Kaveri → केवीरी


In [58]:
with open("data/processed/aligned_pairs_high_conf.json", encoding="utf-8") as f:
    raw_pairs = json.load(f)

def char_accuracy(pred, gold):
    m = min(len(pred), len(gold))
    return sum(pred[i] == gold[i] for i in range(m)) / max(len(gold), 1)

acc = []
for en, hi, _ in raw_pairs:
    pred = hybrid_transliterate(en)
    acc.append(char_accuracy(pred, hi))

print("Hybrid Char Accuracy:", sum(acc) / len(acc))

Hybrid Char Accuracy: 0.4795810227060876


In [47]:
def evaluate_hybrid(data, limit=None):
    """
    data: list of (en, hi) OR (en, hi, conf)
    limit: optimal int to evaluate on subset (for speed)
    """
    scores = []

    for i, item in enumerate(data):
        if limit and i >= limit:
            break

        en = item[0]
        hi = item[1]

        pred = hybrid_transliterate(en)
        scores.append(char_accuracy(pred, hi))

    return sum(scores) / len(scores)

In [48]:
print("Hybrid Train:",
      round(evaluate_hybrid(train, limit=2000) * 100, 2), "%")

print("Hybrid Val:",
      round(evaluate_hybrid(val) * 100, 2), "%")

print("Hybrid Test:",
      round(evaluate_hybrid(test) * 100, 2), "%")

Hybrid Train: 48.55 %
Hybrid Val: 48.62 %
Hybrid Test: 47.38 %


In [49]:
def evaluate_system(data, predict_fn, limit=None):
    scores = []

    for i, (en, hi) in enumerate(data):
        if limit and i >= limit:
            break

        pred = predict_fn(en)
        scores.append(char_accuracy(pred, hi))

    return sum(scores) / len(scores)

In [50]:
print("Seq2Seq Val:", evaluate_system(val, lambda w: beam_transliterate(model, w)))
print("P2G Val:", evaluate_system(val, transliterate_p2g))
print("Hybrid Val:", evaluate_system(val, hybrid_transliterate))

Seq2Seq Val: 0.4939393702734339
P2G Val: 0.33677252125717
Hybrid Val: 0.48618695915844923


In [51]:
def error_type(pred, gold):
    if len(pred) < len(gold):
        return "UNDER"
    if len(pred) > len(gold):
        return "OVER"
    if "्" in pred and "्" not in gold:
        return "HALANT"
    return "OTHER"

In [52]:
from collections import Counter

def analyze_errors(data, limit=1000):
    counter = Counter()

    for i, (en, hi) in enumerate(data):
        if i >= limit:
            break
        pred = hybrid_transliterate(en)
        counter[error_type(pred, hi)] += 1

    total = sum(counter.values())
    for k in counter:
        counter[k] = round(counter[k] / total * 100, 2)

    return counter

In [53]:
print("Hybrid Error Distribution (Val):")
print(analyze_errors(val))

Hybrid Error Distribution (Val):
Counter({'OVER': 36.7, 'OTHER': 35.2, 'UNDER': 27.3, 'HALANT': 0.8})


In [54]:
def hybrid_debug(word):
    seq_out = beam_transliterate(model, word)
    p2g_out = transliterate_p2g(word)

    chosen = "SEQ2SEQ"

    if (
        len(seq_out) <= 2 or
        seq_out.count("्") > 3 or
        "ङ" in seq_out or
        seq_out.endswith(("य", "र")) or
        seq_out in {"शाघ", "बान्गोल"} or
        len(seq_out) - len(word) >= 4
    ):
        chosen = "P2G"

    return {
        "word": word,
        "seq_out": seq_out,
        "p2g_out": p2g_out,
        "chosen": chosen
    }

for w in ["Delhi", "Kolkata", "Rajasthan", "Bangalore", "Shanghai", "Sukla"]:
    print(hybrid_debug(w))

{'word': 'Delhi', 'seq_out': 'दिल्ली', 'p2g_out': 'दिली', 'chosen': 'SEQ2SEQ'}
{'word': 'Kolkata', 'seq_out': 'कोलकत', 'p2g_out': 'कोलकात', 'chosen': 'SEQ2SEQ'}
{'word': 'Rajasthan', 'seq_out': 'राजस्थान', 'p2g_out': 'रझअशअन्', 'chosen': 'SEQ2SEQ'}
{'word': 'Bangalore', 'seq_out': 'बंगलोरो', 'p2g_out': 'बैंगअलोर्', 'chosen': 'SEQ2SEQ'}
{'word': 'Shanghai', 'seq_out': 'शाघई', 'p2g_out': 'शैंघ्', 'chosen': 'SEQ2SEQ'}
{'word': 'Sukla', 'seq_out': 'सुकला', 'p2g_out': 'सअकल', 'chosen': 'SEQ2SEQ'}


In [55]:
def collect_errors(data, limit=200):
    errors = []
    for en, hi in data[:limit]:
        pred = hybrid_transliterate(en)
        if pred != hi:
            errors.append((en, hi, pred))
    return errors

errors = collect_errors(val, 200)

for e in errors[:20]:
    print(e)

('insulin', 'इन्सुलिन', 'निसुलीन')
('sevak', 'सेवा', 'कावक')
('kamapala', 'कामपाल', 'कामपलाल')
('brahmapurana', 'ब्रह्मपुराण', 'ब्रहमपुराण')
('ananta', 'अनन्त', 'अन्तन')
('udaygiri', 'उदयगिरी', 'उदयगिरिरि')
('brhatkatha', 'बृहत्कथा', 'भीतकथाट')
('visva-bharati', 'विश्वभारती', 'विष्भारतति')
('jaaye', 'जाये', 'जयेय')
('jodhpuri', 'जोधपुरी', 'जोधपु')
('chellattamman', 'चेल्लत्तम्मन', 'चेलत्तम्तम्मा')
('tirhio', 'टिरहिओ', 'थिरोहि')
('bahujan', 'बहुजन', 'बाजुंजन')
('make', 'मौके', 'केके')
('alur', 'अलुर', 'लुरुर')
('maulaha', 'मौलाना', 'मौलाहा')
('kanya - kumari', 'कन्याकुमारी', 'क्याकुमामा')
('arthantaranyasa', 'अर्थातरन्यास', 'अर्तनातरण्यश')
('mudrikalu', 'मुद्रिकालु', 'मुदिकलालु')
('pandora', 'पंडोरा', 'पंदोरो')


In [57]:
import random

sample = random.sample(raw_pairs, 200)
for en, hi, _ in sample[:20]:
    print(en, "→", hybrid_transliterate(en), "| gold:", hi)

Bhaina → भीना | gold: भइना
Sethna → सेथना | gold: सेठना
Hallada → लाददाला | gold: हल्लद
Shikshan → शिक्षण | gold: शिक्षण
tipu → तिपु | gold: टीपू
Aghoresvara → अगोरेश्व | gold: अघोरेश्वर
Incapacity → विपाचचिट्ट | gold: ईन्चपचिट्य्
Hasbi → शबीबी | gold: हस्बी
Lokasangraha → कोगसंंगर्श | gold: लोकसंग्रह
Limaye → लिमयेय | gold: लिमये
Nelumbo → नुम्बोम्ब | gold: नेलम्बो
Ashi → ैषि | gold: काशी
Chughtai → चुतगाटै | gold: चुगताई
Hebron → वेबोरोन | gold: हेब्रोन
Panchgavya → पंचवव््यव | gold: पंचगव्य
Dalma → दम्ममा | gold: डालमा
Kalapani → कापापनि | gold: कालापानी
Gurubani → गुरुबान | gold: गुरबानी
Gujjar → गुजजर | gold: गुज्जर
Setubandha → शेतुबन्धन | gold: सेतुबंध
