# 7) Multilingual Synonyms, Morphology & Transliteration for Query Expansion

In [None]:

%%capture
!pip -q install --upgrade pip
!pip -q install datasets transformers sentence-transformers faiss-cpu rank-bm25 torchmetrics scikit-learn lightgbm langdetect unidecode pandas matplotlib tqdm nltk

In [None]:

import numpy as np, pandas as pd, re
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from nltk.corpus import wordnet as wn
import nltk
from unidecode import unidecode
# Download the WordNet corpus data needed by the `wn` reader imported above.
nltk.download('wordnet')

In [None]:

# Load a 10% slice of the English training split and keep only the three
# columns used below; rows with a missing value in any of them are dropped.
ds = load_dataset("amazon_reviews_multi", "en", split="train[:10%]")
df = ds.to_pandas()[["product_id", "review_title", "review_body"]].dropna()

# Build one pseudo-document per product: up to 10 review titles joined with
# " | " and up to 5 review bodies joined with single spaces.
g = df.groupby("product_id")
docs = g.agg(
    {
        "review_title": lambda titles: " | ".join(titles.head(10).astype(str)),
        "review_body": lambda bodies: " ".join(bodies.head(5).astype(str)),
    }
).reset_index()

title_text = docs["review_title"].fillna("")
body_text = docs["review_body"].fillna("")
docs["doc_text"] = (title_text + " " + body_text).str.strip()

# Keep only products whose combined text is longer than 16 characters.
has_enough_text = docs["doc_text"].str.len() > 16
docs = docs[has_enough_text].reset_index(drop=True)

# Queries are (review_title, product_id) pairs for products that survived the
# filter above, de-duplicated and capped at the first 3000.
# NOTE(review): review titles are also part of each product's doc_text, so a
# query can appear verbatim in its target document; recall measured on these
# pairs is likely optimistic.
pids = set(docs["product_id"])
in_corpus = df["product_id"].isin(pids)
queries = (
    df[in_corpus][["review_title", "product_id"]]
    .dropna()
    .drop_duplicates()
    .head(3000)
    .reset_index(drop=True)
)

In [None]:

def tok(s):
    """Lowercase ``str(s)`` and return its word-character runs as a list of tokens.

    Every run of non-word characters is treated as a single separator, so
    punctuation and repeated whitespace never produce empty tokens.
    """
    lowered = str(s).lower()
    spaced = re.sub(r"\W+"," ", lowered)
    # str.split() with no argument already discards empty pieces.
    return spaced.split()
# Tokenize every product document with `tok` and index the corpus with BM25.
bm25 = BM25Okapi(list(map(tok, docs["doc_text"].tolist())))
# Row position in `docs` -> product_id, in the same order the corpus was indexed.
pid_by_idx = docs["product_id"].tolist()

In [None]:

def expand_query_en(q, max_syns=5):
    """Expand an English query with WordNet synonyms and a romanized copy.

    Args:
        q: Query text. Coerced with ``str`` so it can be concatenated and
            passed to ``unidecode``.
        max_syns: Maximum number of synonym phrases to append (non-negative).
            Defaults to 5.

    Returns:
        A single string: the original query, followed by up to ``max_syns``
        synonym phrases, followed by ``unidecode(q)`` when that differs from
        the query ignoring case. Parts are separated by single spaces.
    """
    q = str(q)
    terms = tok(q)

    # Collect synonyms in first-seen order (term -> synset -> lemma).
    # Slicing a set of strings, as this function previously did, picks a
    # different subset on every kernel restart because string hashing is
    # randomized per process, making the evaluation non-reproducible.
    seen = set(terms)  # `tok` lowercases, so compare lowercased keys below
    syns = []
    for term in terms:
        if len(syns) >= max_syns:
            break  # enough candidates; skip further WordNet lookups
        for syn in wn.synsets(term):
            for lemma in syn.lemmas():
                phrase = lemma.name().replace("_", " ")
                key = phrase.lower()
                if len(phrase) > 2 and key not in seen:
                    seen.add(key)
                    syns.append(phrase)

    parts = [q]
    parts.extend(syns[:max_syns])

    # Add the ASCII transliteration only when it contributes new text.
    roman = unidecode(q)
    if roman.lower() != q.lower():
        parts.append(roman)
    return " ".join(parts)

def eval_recall(qs, k=10, expand=False):
    """Compute Recall@k of BM25 retrieval over (query, product_id) pairs.

    Relies on the notebook-level ``bm25`` index, the ``tok`` tokenizer and the
    ``pid_by_idx`` list that maps a corpus position to its product_id.

    Args:
        qs: DataFrame with ``review_title`` (query text) and ``product_id``
            (the relevant product) columns.
        k: Number of top-scoring documents to inspect per query. Values larger
            than the corpus are clamped to the corpus size; ``k <= 0`` yields
            no hits.
        expand: When True, each query is passed through ``expand_query_en``
            before scoring.

    Returns:
        Fraction of queries whose product_id is among the top-k documents,
        or 0.0 when ``qs`` is empty.
    """
    n_queries = len(qs)
    if n_queries == 0:
        # Avoid ZeroDivisionError on an empty evaluation set.
        return 0.0

    hits = 0
    for q, pid in qs[["review_title", "product_id"]].itertuples(index=False):
        qtext = expand_query_en(q) if expand else q
        scores = np.asarray(bm25.get_scores(tok(qtext)))

        # argpartition raises if kth is out of bounds, and a slice of [-0:]
        # would select every document, so clamp k and skip non-positive values.
        top_k = min(k, len(scores))
        if top_k <= 0:
            continue
        top = np.argpartition(scores, -top_k)[-top_k:]

        # Only membership matters for recall, so the top-k are not sorted.
        if pid in {pid_by_idx[i] for i in top}:
            hits += 1
    return hits / n_queries

# Score the same first 1000 queries with and without expansion, then report
# both recalls and their difference.
eval_queries = queries.head(1000)
r_base = eval_recall(eval_queries, k=10, expand=False)
r_exp = eval_recall(eval_queries, k=10, expand=True)
print(f"Recall@10 baseline={r_base:.3f} | expanded={r_exp:.3f} | lift={(r_exp-r_base):.3f}")