# SLM Search Prototype Notebook

Quick end-to-end prototype for the rewrite pipeline from `design.md`.

Includes:
- normalization
- model/SKU-like detection
- catalog lookup
- gating policy (`NONE | SPELL_ONLY | BRAND_MODEL_ENRICH | NLU_REWRITE`)
- constrained rewrite generation
- lightweight evaluation and slices

This notebook now targets the established retail relevance dataset `tasksource/esci` from Hugging Face, with synthetic fallback if loading fails.

In [4]:
from __future__ import annotations

import re
import unicodedata
from dataclasses import dataclass
from difflib import SequenceMatcher
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

try:
    from datasets import load_dataset
except Exception:
    load_dataset = None

In [5]:
# Load an established retail relevance dataset from Hugging Face.
# Primary choice requested: tasksource/esci
HF_DATASET_ID = "tasksource/esci"
HF_SPLIT_PREFERENCE = ["train", "validation", "test"]


def infer_category_from_title(title: str) -> str:
    t = title.lower()
    if "heater" in t:
        return "water heater"
    if "drill" in t:
        return "power tool"
    if "paint" in t:
        return "paint"
    if "faucet" in t:
        return "faucet"
    return "home improvement"


def infer_brand_from_title(title: str) -> str:
    tokens = re.findall(r"[a-zA-Z0-9]+", title.lower())
    return tokens[0] if tokens else "unknown"


def infer_model_from_title(title: str) -> str:
    tokens = re.findall(r"[a-zA-Z0-9-]{5,}", title.lower())
    for tok in tokens:
        if any(c.isalpha() for c in tok) and any(c.isdigit() for c in tok):
            return tok
    return ""


def pick_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    lower_to_actual = {c.lower(): c for c in df.columns}
    for name in candidates:
        if name.lower() in lower_to_actual:
            return lower_to_actual[name.lower()]
    return None


catalog_rows: List[Dict[str, Any]] = []
query_pool: List[str] = []
loaded_hf_dataset = None
loaded_hf_name = None
loaded_hf_split = None
loaded_hf_columns: List[str] = []

if load_dataset is not None:
    try:
        ds = load_dataset(HF_DATASET_ID)
        split = next((s for s in HF_SPLIT_PREFERENCE if s in ds), list(ds.keys())[0])
        raw_df = ds[split].to_pandas()
        loaded_hf_columns = list(raw_df.columns)

        # Robust mapping for ESCI-like schemas.
        query_col = pick_col(raw_df, ["query", "search_term", "search_query", "raw_query"])
        title_col = pick_col(raw_df, ["product_title", "title", "product_name", "document"])
        product_id_col = pick_col(raw_df, ["product_id", "product_uid", "doc_id", "asin", "id"])
        brand_col = pick_col(raw_df, ["product_brand", "brand"])
        model_col = pick_col(raw_df, ["model_number", "model", "mpn", "sku"])
        category_col = pick_col(raw_df, ["product_type", "category", "product_category"])

        if title_col is None:
            raise ValueError("No product title-like column found in tasksource/esci split")

        work = raw_df.copy()
        if product_id_col is None:
            work["_product_id"] = work.index.astype(str)
            product_id_col = "_product_id"

        if query_col is not None:
            query_pool = [q for q in work[query_col].astype(str).tolist() if str(q).strip()]

        keep_cols = [product_id_col, title_col]
        for c in [brand_col, model_col, category_col]:
            if c is not None and c not in keep_cols:
                keep_cols.append(c)

        cat_df = work[keep_cols].drop_duplicates(subset=[product_id_col]).head(5000).copy()

        if brand_col is None:
            cat_df["_brand"] = cat_df[title_col].astype(str).apply(infer_brand_from_title)
            brand_col = "_brand"
        if model_col is None:
            cat_df["_model_number"] = cat_df[title_col].astype(str).apply(infer_model_from_title)
            model_col = "_model_number"
        if category_col is None:
            cat_df["_category"] = cat_df[title_col].astype(str).apply(infer_category_from_title)
            category_col = "_category"

        cat_df["_attributes"] = [{} for _ in range(len(cat_df))]

        catalog_rows = [
            {
                "product_id": str(getattr(r, product_id_col)),
                "title": str(getattr(r, title_col)),
                "brand": str(getattr(r, brand_col)),
                "model_number": str(getattr(r, model_col)),
                "category": str(getattr(r, category_col)),
                "attributes": getattr(r, "_attributes"),
            }
            for r in cat_df.itertuples(index=False)
        ]

        loaded_hf_dataset = ds
        loaded_hf_name = HF_DATASET_ID
        loaded_hf_split = split
    except Exception as e:
        print(f"HF load failed for {HF_DATASET_ID}: {e}")

if not catalog_rows:
    # Fallback synthetic catalog
    catalog_rows = [
        {
            "product_id": "p1",
            "title": "Rheem Performance 50 Gal Electric Water Heater",
            "brand": "rheem",
            "model_number": "xe50t06st45u1",
            "category": "water heater",
            "attributes": {"capacity_gal": 50, "voltage": 240},
        },
        {
            "product_id": "p2",
            "title": "AO Smith Signature 40 Gallon Water Heater",
            "brand": "ao smith",
            "model_number": "eg6-40r45dv",
            "category": "water heater",
            "attributes": {"capacity_gal": 40, "voltage": 240},
        },
        {
            "product_id": "p3",
            "title": "Milwaukee M18 Fuel Hammer Drill Kit",
            "brand": "milwaukee",
            "model_number": "2804-20",
            "category": "power tool",
            "attributes": {"voltage": 18, "color": "red"},
        },
    ]
    query_pool = []

catalog_df = pd.DataFrame(catalog_rows)
print("Loaded HF dataset:", loaded_hf_name if loaded_hf_name else "None (using synthetic fallback)")
if loaded_hf_split:
    print("HF split:", loaded_hf_split)
if loaded_hf_columns:
    print("HF columns:", loaded_hf_columns[:20])
print("Catalog size:", len(catalog_df), "| Query pool size:", len(query_pool))
catalog_df.head(10)

HF load failed for tasksource/esci: 'Pandas' object has no attribute '_model_number'
Loaded HF dataset: None (using synthetic fallback)
HF columns: ['example_id', 'query', 'query_id', 'product_id', 'product_locale', 'esci_label', 'small_version', 'large_version', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_text']
Catalog size: 3 | Query pool size: 0


Unnamed: 0,product_id,title,brand,model_number,category,attributes
0,p1,Rheem Performance 50 Gal Electric Water Heater,rheem,xe50t06st45u1,water heater,"{'capacity_gal': 50, 'voltage': 240}"
1,p2,AO Smith Signature 40 Gallon Water Heater,ao smith,eg6-40r45dv,water heater,"{'capacity_gal': 40, 'voltage': 240}"
2,p3,Milwaukee M18 Fuel Hammer Drill Kit,milwaukee,2804-20,power tool,"{'voltage': 18, 'color': 'red'}"


In [6]:
def normalize_query(q: str) -> str:
    """Fast normalizer: unicode, case, whitespace and dash normalization."""
    q = unicodedata.normalize("NFKC", q)
    q = q.lower().strip()
    q = re.sub(r"[\u2010-\u2015]", "-", q)  # normalize unicode dashes
    q = re.sub(r"\s+", " ", q)
    q = q.replace("/", " ")
    q = re.sub(r"\s*-\s*", "-", q)
    return q


def token_features(q: str) -> Dict[str, Any]:
    digits = sum(ch.isdigit() for ch in q)
    alpha = sum(ch.isalpha() for ch in q)
    n = max(len(q), 1)
    return {
        "len": len(q),
        "digit_ratio": digits / n,
        "alpha_ratio": alpha / n,
        "has_dash": "-" in q,
        "token_count": len(q.split()),
    }


MODEL_PATTERN = re.compile(r"^[a-z0-9][a-z0-9-]{4,}$")


def detect_model_like(q: str) -> bool:
    compact = q.replace(" ", "")
    has_digit = any(c.isdigit() for c in compact)
    has_alpha = any(c.isalpha() for c in compact)
    return bool(MODEL_PATTERN.match(compact)) and has_digit and has_alpha


def similarity(a: str, b: str) -> float:
    return SequenceMatcher(None, a, b).ratio()

In [7]:
def resolve_catalog_candidates(norm_q: str, top_n: int = 3) -> List[Dict[str, Any]]:
    """
    Resolve product candidates by exact/near model match and title overlap.
    This is intentionally simple for notebook testing.
    """
    q_compact = norm_q.replace(" ", "")
    cands: List[Tuple[float, Dict[str, Any]]] = []

    for row in catalog_rows:
        model = row["model_number"].lower()
        title = row["title"].lower()
        brand = row["brand"].lower()

        model_score = max(
            similarity(q_compact, model.replace(" ", "")),
            1.0 if q_compact == model.replace(" ", "") else 0.0,
        )
        title_score = similarity(norm_q, title)
        brand_boost = 0.1 if brand in norm_q else 0.0
        score = 0.65 * model_score + 0.25 * title_score + brand_boost
        cands.append((score, row))

    cands.sort(key=lambda x: x[0], reverse=True)
    out = []
    for score, row in cands[:top_n]:
        item = dict(row)
        item["candidate_score"] = round(score, 4)
        out.append(item)
    return out

In [8]:
@dataclass
class RewriteResult:
    original: str
    rewrite: str
    rewrite_type: str
    confidence: float
    signals: Dict[str, Any]


def gate_rewrite(norm_q: str, cands: List[Dict[str, Any]]) -> Tuple[str, float, Dict[str, Any]]:
    """
    Notebook heuristic gate approximating a tiny classifier.
    Returns (rewrite_type, confidence, signals)
    """
    feats = token_features(norm_q)
    model_like = detect_model_like(norm_q)
    best = cands[0] if cands else None
    best_score = best["candidate_score"] if best else 0.0

    if model_like and best_score >= 0.8:
        return "BRAND_MODEL_ENRICH", min(0.99, 0.7 + best_score * 0.3), {
            "model_number_detected": True,
            "candidate_score": best_score,
            "features": feats,
        }

    if not model_like and feats["token_count"] >= 4 and feats["digit_ratio"] < 0.2:
        return "NLU_REWRITE", 0.72, {
            "model_number_detected": False,
            "candidate_score": best_score,
            "features": feats,
        }

    if feats["token_count"] <= 2 and best_score < 0.45:
        return "SPELL_ONLY", 0.6, {
            "model_number_detected": model_like,
            "candidate_score": best_score,
            "features": feats,
        }

    return "NONE", 0.85, {
        "model_number_detected": model_like,
        "candidate_score": best_score,
        "features": feats,
    }


def constrained_rewrite(original_q: str) -> RewriteResult:
    norm_q = normalize_query(original_q)
    cands = resolve_catalog_candidates(norm_q)
    rewrite_type, conf, signals = gate_rewrite(norm_q, cands)
    rewrite = norm_q

    # Safety constraint: only enrich brand/model when candidate evidence is strong.
    if rewrite_type == "BRAND_MODEL_ENRICH" and cands:
        best = cands[0]
        rewrite = f"{best['brand']} {best['model_number']}"
        signals["brand_source"] = "catalog_lookup"
    elif rewrite_type == "SPELL_ONLY":
        # Simple spell path placeholder: currently only normalized form.
        rewrite = norm_q
    elif rewrite_type == "NLU_REWRITE" and cands:
        # Safe constrained expansion using top candidate category tokens.
        cat = cands[0]["category"]
        if cat not in norm_q:
            rewrite = f"{norm_q} {cat}"
    else:
        rewrite = norm_q

    return RewriteResult(
        original=original_q,
        rewrite=rewrite,
        rewrite_type=rewrite_type,
        confidence=round(float(conf), 3),
        signals=signals,
    )

In [9]:
# Use real search terms from HF when available, plus a few stress-test queries.
seed_queries = [
    "xe50t06st45u1",
    "xe 50 t06 st45 u1",
    "Rheem 50 gal electric heater",
    "milwukee 2804 20",
    "40 gal hot water heater ao smith",
    "drill red 18v",
]

hf_sample_queries = []
if query_pool:
    seen = set()
    for q in query_pool:
        q = str(q).strip()
        if q and q not in seen:
            seen.add(q)
            hf_sample_queries.append(q)
        if len(hf_sample_queries) >= 12:
            break

test_queries = hf_sample_queries + seed_queries

rows = []
for q in test_queries:
    r = constrained_rewrite(q)
    rows.append(
        {
            "original": r.original,
            "rewrite": r.rewrite,
            "rewrite_type": r.rewrite_type,
            "confidence": r.confidence,
            "model_detected": r.signals["model_number_detected"],
            "candidate_score": r.signals["candidate_score"],
        }
    )

pd.DataFrame(rows).head(25)

Unnamed: 0,original,rewrite,rewrite_type,confidence,model_detected,candidate_score
0,xe50t06st45u1,xe50t06st45u1,NONE,0.85,True,0.6924
1,xe 50 t06 st45 u1,xe 50 t06 st45 u1,NONE,0.85,True,0.7214
2,Rheem 50 gal electric heater,rheem 50 gal electric heater,NONE,0.85,True,0.4649
3,milwukee 2804 20,milwukee 2804 20,NONE,0.85,True,0.4793
4,40 gal hot water heater ao smith,40 gal hot water heater ao smith,NONE,0.85,True,0.3424
5,drill red 18v,drill red 18v,NONE,0.85,True,0.1737


In [10]:
# Tiny evaluation harness (replace with proper labels/metrics in later phases)
gold = pd.DataFrame(
    [
        {"query": "xe50t06st45u1", "expected_type": "BRAND_MODEL_ENRICH", "expected_rewrite": "rheem xe50t06st45u1"},
        {"query": "xe 50 t06 st45 u1", "expected_type": "BRAND_MODEL_ENRICH", "expected_rewrite": "rheem xe50t06st45u1"},
        {"query": "drill red 18v", "expected_type": "NLU_REWRITE", "expected_rewrite": "drill red 18v power tool"},
    ]
)

pred_rows = []
for q in gold["query"]:
    r = constrained_rewrite(q)
    pred_rows.append({"query": q, "pred_type": r.rewrite_type, "pred_rewrite": r.rewrite})

pred = pd.DataFrame(pred_rows)
merged = gold.merge(pred, on="query", how="left")
merged["type_ok"] = merged["expected_type"] == merged["pred_type"]
merged["rewrite_exact"] = merged["expected_rewrite"] == merged["pred_rewrite"]

summary = {
    "type_accuracy": merged["type_ok"].mean(),
    "exact_match": merged["rewrite_exact"].mean(),
    "n_examples": len(merged),
}

print(summary)
merged

{'type_accuracy': 0.0, 'exact_match': 0.0, 'n_examples': 3}


Unnamed: 0,query,expected_type,expected_rewrite,pred_type,pred_rewrite,type_ok,rewrite_exact
0,xe50t06st45u1,BRAND_MODEL_ENRICH,rheem xe50t06st45u1,NONE,xe50t06st45u1,False,False
1,xe 50 t06 st45 u1,BRAND_MODEL_ENRICH,rheem xe50t06st45u1,NONE,xe 50 t06 st45 u1,False,False
2,drill red 18v,NLU_REWRITE,drill red 18v power tool,NONE,drill red 18v,False,False


## Next Steps

1. If needed, install Hugging Face datasets in your notebook env: `pip install datasets`.
2. Keep `HF_DATASET_ID = "tasksource/esci"` as the baseline benchmark; only change if needed.
3. Replace inferred `brand/model/category` extraction with a cleaner catalog builder script.
4. Swap heuristic `gate_rewrite()` with a trained classifier.
5. Replace heuristic NLU rewrite path with an SLM rewrite model.
6. Expand evaluation with rewrite harm metrics:
   - wrong-brand enrichment rate
   - category drift rate
   - over-rewrite rate