In [13]:
import json, faiss, numpy as np, pandas as pd
from sentence_transformers import SentenceTransformer, CrossEncoder
from pathlib import Path
import ipywidgets as W
from IPython.display import display, Markdown

def norm_paper(x: str) -> str:
    """Normalise an arXiv paper identifier by stripping URL / "arXiv:" prefixes.

    ``None`` or empty input yields ``""``. Every occurrence of each prefix is
    removed (in the order listed below), then surrounding whitespace is trimmed.
    """
    cleaned = x or ""
    for prefix in ("http://arxiv.org/abs/", "https://arxiv.org/abs/", "arXiv:"):
        cleaned = cleaned.replace(prefix, "")
    return cleaned.strip()

INDEX_PATH   = "../../indexes/faiss_base/index.faiss"
META_PATH    = "../../indexes/faiss_base/meta.jsonl"
RERANKER_DIR = "../../outputs/reranker/minilm_ce_hard"

# Dense FAISS index. Row i is assumed to correspond to the i-th record of
# META_PATH (dense_search maps FAISS ids to docids purely by position).
index = faiss.read_index(INDEX_PATH)

# Parallel lists: one entry per JSON record in meta.jsonl.
docids, passages, paper_ids, titles = [], [], [], []
with open(META_PATH, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines instead of failing in json.loads
        o = json.loads(line)
        pid = norm_paper(o.get("paper_id", ""))
        cid = int(o.get("chunk_id"))
        title = (o.get("title") or "").strip()  # a JSON null title becomes ""
        docids.append(f"{pid}:{cid}")
        passages.append(o.get("passage") or o.get("text") or "")
        paper_ids.append(pid)
        titles.append(title)

# Positional lookup only works if both sides describe the same rows.
assert len(docids) == index.ntotal, (
    f"meta.jsonl has {len(docids)} records but the FAISS index has "
    f"{index.ntotal} vectors; FAISS row ids are mapped to docids by position"
)

# docid ("<paper_id>:<chunk_id>") -> text / paper id / title lookups.
id2text  = dict(zip(docids, passages))
id2paper = dict(zip(docids, paper_ids))
id2title = dict(zip(docids, titles))

# Bi-encoder for the first (dense) stage; cross-encoder for reranking.
dense = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
ce    = CrossEncoder(RERANKER_DIR)

def snippet(txt: str, chars: int = 220) -> str:
    """Shorten ``txt`` to its first ``chars`` characters.

    An ellipsis ("…") is appended only when characters were actually dropped;
    text that already fits is returned unchanged.
    """
    if len(txt) <= chars:
        return txt
    return txt[:chars] + "…"

In [14]:
def dense_search(query: str, topk: int = 150):
    """Retrieve up to ``topk`` nearest passages for ``query`` from the FAISS index.

    The query is embedded with the bi-encoder ``dense`` (normalised embeddings)
    and searched in ``index``; FAISS row ids are mapped back to ``docids`` /
    ``passages`` by position.

    Args:
        query: free-text query.
        topk: number of candidates requested; clamped to the index size.

    Returns:
        ``(cand_ids, cand_txts, scores)`` — parallel lists of docids, passage
        texts and the raw FAISS scores (their meaning depends on the index
        metric), in the order FAISS returned them. May be shorter than ``topk``.
    """
    k = min(int(topk), index.ntotal)
    if k <= 0:
        return [], [], []
    qemb = dense.encode([query], normalize_embeddings=True)
    D, I = index.search(np.asarray(qemb, dtype="float32"), k)
    # FAISS pads unfilled result slots with id -1; indexing docids[-1] would
    # silently return the *last* document, so drop those slots.
    hits = [(int(i), float(d)) for i, d in zip(I[0], D[0]) if i >= 0]
    cand_ids  = [docids[i] for i, _ in hits]
    cand_txts = [passages[i] for i, _ in hits]
    return cand_ids, cand_txts, [d for _, d in hits]

def ce_rerank(query: str, cand_ids, cand_txts, final_topk: int = 10):
    """Re-score ``(query, passage)`` pairs with the cross-encoder ``ce``.

    Args:
        query: free-text query.
        cand_ids: docids of the candidates (parallel to ``cand_txts``).
        cand_txts: passage texts to score.
        final_topk: how many best-scoring candidates to keep (negative -> 0).

    Returns:
        ``(reranked, scores)``: ``reranked`` is a list of ``(docid, score)``
        sorted by descending cross-encoder score, with ties keeping the input
        (dense-retrieval) order. ``scores`` is the cross-encoder output for every
        candidate, in input order.
    """
    if len(cand_txts) == 0:
        # Avoid calling ce.predict on an empty batch.
        return [], np.array([], dtype=np.float32)
    pairs = [[query, t] for t in cand_txts]
    scores = ce.predict(pairs)
    score_arr = np.asarray(scores, dtype=float)
    # Stable sort so equal scores preserve the retrieval order deterministically.
    order = np.argsort(-score_arr, kind="stable")[:max(int(final_topk), 0)]
    reranked = [(cand_ids[i], float(score_arr[i])) for i in order]
    return reranked, scores

def neighbor_chunks(docid: str, k: int = 1):
    """Return ``docid``'s chunk plus up to ``k`` neighbours on each side.

    Args:
        docid: an id of the form ``"<paper_id>:<chunk_id>"``.
        k: how many chunk positions to look before and after.

    Returns:
        List of ``(docid, text)`` for the chunk ids in ``[chunk-k, chunk+k]``
        that exist in ``id2text``, in ascending chunk order.
    """
    # Split on the LAST colon only: the chunk id is always the final field,
    # while the paper id part may itself contain ':'.
    paper, chunk = docid.rsplit(":", 1)
    center = int(chunk)
    out = []
    for c in range(center - k, center + k + 1):
        key = f"{paper}:{c}"
        if key in id2text:
            out.append((key, id2text[key]))
    return out

In [15]:
def search_and_show(query: str, faiss_topk: int = 150, final_topk: int = 10):
    """Two-stage search: dense FAISS retrieval, then cross-encoder reranking.

    Despite the name, this only builds the results table; it displays nothing.

    Args:
        query: free-text query.
        faiss_topk: size of the dense candidate pool.
        final_topk: number of reranked hits to return.

    Returns:
        DataFrame with columns ``rank_ce, score_ce, title, snippet, link``,
        one row per hit, best first (empty frame with those columns if none).
    """
    cand_ids, cand_txts, _dense_scores = dense_search(query, topk=faiss_topk)
    ce_ranked, _ = ce_rerank(query, cand_ids, cand_txts, final_topk=final_topk)

    columns = ["rank_ce", "score_ce", "title", "snippet", "link"]
    results = []
    for rank, (docid, score) in enumerate(ce_ranked, start=1):
        paper_id = id2paper.get(docid, "")
        link = f"https://arxiv.org/abs/{paper_id}" if paper_id else ""
        results.append({
            "rank_ce": rank,
            "score_ce": round(score, 4),
            # Some records have no title; fall back so link text is never empty.
            "title": id2title.get(docid) or paper_id or docid,
            "snippet": snippet(id2text.get(docid, "")),
            "link": link,
        })
    return pd.DataFrame(results, columns=columns)

In [20]:
q_box = W.Text(
    value="stellar metallicity gradient open clusters",
    placeholder="Type a query…",
    description="Query:",
    layout=W.Layout(width="100%")
)
# Size of the dense FAISS candidate pool handed to the cross-encoder.
faiss_k = W.IntSlider(value=150, min=50, max=1000, step=50, description="FAISS topK")
ce_k    = W.IntSlider(value=10,  min=5,  max=50,  step=1,  description="CE topK")
btn     = W.Button(description="Search", button_style="primary")
out     = W.Output()

def on_click(_):
    """Run the search with the current widget values and render hits as Markdown."""
    out.clear_output()
    with out:
        query = q_box.value.strip()
        if not query:
            display(Markdown("_Enter a query first._"))
            return
        df = search_and_show(query, faiss_topk=faiss_k.value, final_topk=ce_k.value)
        display(Markdown(f"### Query: `{query}`\n"))
        if df.empty:
            display(Markdown("_No results._"))
            return
        for _, row in df.iterrows():
            title_md = f"[{row['title']}]({row['link']})" if row['link'] else row['title']
            # Blank line before '---': directly under text it would turn the
            # whole entry into a setext heading instead of a horizontal rule.
            display(Markdown(f"**{row['rank_ce']}. {title_md}**  \n"
                             f"_Score:_ {row['score_ce']}  \n"
                             f"{row['snippet']}\n\n---"))

btn.on_click(on_click)
display(W.VBox([q_box, W.HBox([faiss_k, ce_k, btn]), out]))

VBox(children=(Text(value='stellar metallicity gradient open clusters', description='Query:', layout=Layout(wi…