In [81]:
from pathlib import Path
import json, re

In [82]:
# Absolute path of the working directory at the time this cell runs;
# dataset paths are built relative to it.
BASE_DIR = Path(".").resolve()

# Matches one "|user|: <question>" line (MULTILINE) and captures the question.
# Only horizontal whitespace ([^\S\n]) is allowed around the colon and at the
# end of the line, so an empty "|user|:" line cannot swallow the following
# line, and the lazy group leaves trailing spaces/tabs/CR out of the capture.
USER_LINE_RE = re.compile(r'^\|user\|[^\S\n]*:[^\S\n]*(.*?)[^\S\n]*$', re.M)

In [83]:
def read_data(dataset:str) -> list[str]:
    """Read a dataset file from the "queries data" folder and return its lines.

    Args:
        dataset: File name inside ``BASE_DIR / "queries data"``.

    Returns:
        Every line of the file, in order, with line endings preserved.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    path = BASE_DIR / "queries data" / dataset
    # Decode explicitly as UTF-8: without `encoding`, open() uses the platform
    # default, which is not UTF-8 everywhere, while this notebook writes its
    # JSONL output as UTF-8.
    with open(path, "r", encoding="utf-8") as file:
        return file.readlines()
       

In [84]:
# Sanity check: read the file once, then show the first record both raw and
# parsed (previously the whole file was read twice for this).
first_line = read_data("clapnq_questions.jsonl")[0]
print(first_line)
print(json.loads(first_line)["_id"])

{"_id":"dd6b6ffd177f2b311abe676261279d2f<::>2","text":"|user|: where do the arizona cardinals play this week\n|user|: Do the Arizona Cardinals play outside the US?"}

dd6b6ffd177f2b311abe676261279d2f<::>2


In [85]:
def build_dict(lines:list[str]) -> dict:
    """Index JSONL query records by (conversation id, turn).

    Each non-blank line must be a JSON object with an ``_id`` of the form
    ``"<conversation id><::><turn>"``. Its ``text`` field is scanned with
    ``USER_LINE_RE``; the last match is treated as the current question and
    all earlier matches as the history.

    Args:
        lines: Raw JSONL lines. Blank / whitespace-only lines are ignored.

    Returns:
        A dict mapping ``(main_id, turn)`` (``turn`` as ``int``) to
        ``{"history": [...], "current": "..."}``. Records whose text contains
        no ``|user|`` line (or no ``text`` at all) are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``_id``.
        ValueError: if ``_id`` lacks the ``<::>`` separator or the turn is
            not an integer.
    """
    query_dict = {}
    for line in lines:
        # A trailing newline or stray empty line is not a record.
        if not line.strip():
            continue
        obj = json.loads(line)
        # Split on the LAST separator so an id that itself contains "<::>"
        # still yields exactly (main_id, turn).
        main_id, turn = obj["_id"].rsplit("<::>", 1)
        # `text` may be missing or null; treat that as "no questions".
        queries = USER_LINE_RE.findall(obj.get("text") or "")
        if not queries:
            continue
        query_dict[(main_id, int(turn))] = {
            "history": queries[:-1],
            "current": queries[-1],
        }
    return query_dict

In [86]:
# Parse every record of the ClapNQ questions file into a lookup keyed by
# (conversation id, turn) -> {"history": [...], "current": "..."}.
clap_dict = build_dict(read_data("clapnq_questions.jsonl"))

In [87]:
import ollama

In [88]:
# Model name passed to ollama.embed() in get_embed().
EMBEDDING_MODEL = "bge-m3"
# Model name passed to ollama.chat() in rewrite_query().
LANGUAGE_MODEL = "llama3.2"

In [89]:
def get_embed(text:str) -> list[float]:
    """Embed ``text`` with ``EMBEDDING_MODEL`` via ollama.

    Returns the first entry of the ``'embeddings'`` field of the response.
    """
    response = ollama.embed(model=EMBEDDING_MODEL, input=text)
    vectors = response['embeddings']
    return vectors[0]

In [90]:
def cosine_similarity(a, b):
    """Return the cosine similarity of two equal-length numeric sequences.

    Args:
        a: First vector (a sized sequence of numbers).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        norm (including empty vectors), where the ratio is undefined.

    Raises:
        ValueError: if the vectors have different lengths.
    """
    # zip() would silently truncate to the shorter vector while the norms
    # below use the full vectors, producing a meaningless score.
    if len(a) != len(b):
        raise ValueError(
            f"vectors must have the same length, got {len(a)} and {len(b)}"
        )
    dot_product = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x ** 2 for x in a) ** 0.5
    norm_b = sum(x ** 2 for x in b) ** 0.5
    denominator = norm_a * norm_b
    # A zero vector has no direction; report "no similarity" instead of
    # raising ZeroDivisionError.
    if denominator == 0:
        return 0.0
    return dot_product / denominator

def select_history(history, current, k=3, keep_recent=2):
    """Choose at most ``k`` earlier questions to use as rewriting context.

    The ``keep_recent`` most recent questions are always kept to preserve the
    immediate background. Any remaining slots are filled with the older
    questions whose embeddings are most similar to ``current``. The result is
    returned in conversation order (selected older questions first, then the
    recent ones).

    Args:
        history: Earlier questions of the conversation, oldest first.
        current: The question being rewritten.
        k: Maximum number of history questions to return.
        keep_recent: How many of the latest questions to always keep; it is
            capped at ``k`` so the result never exceeds ``k`` items.

    Returns:
        A list of questions taken from ``history``. When ``history`` already
        has at most ``k`` items it is returned as-is.
    """
    if not history:
        return []
    if len(history) <= k:
        return history

    # Never keep more recent turns than the overall budget allows
    # (this also maps a non-positive k to "keep nothing").
    keep_recent = max(0, min(keep_recent, k))
    # The `> 0` guards matter: history[-0:] would be the whole list.
    recent = history[-keep_recent:] if keep_recent > 0 else []
    pool = history[:-keep_recent] if keep_recent > 0 else history

    need = k - len(recent)
    if need <= 0:
        # Budget already used up by the recent turns: skip the embedding
        # calls entirely.
        return recent

    cur_vec = get_embed(current)
    scored = [
        (idx, cosine_similarity(get_embed(h), cur_vec))
        for idx, h in enumerate(pool)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    # Keep the `need` best matches, then restore their original order so the
    # rewriter sees the conversation chronologically.
    chosen = sorted(idx for idx, _ in scored[:need])
    return [pool[idx] for idx in chosen] + recent
        
        

In [91]:
# Demo: pick the context for a six-question conversation using the default
# k / keep_recent settings.
history_for_rewrite = select_history(
    [
        'where is bone marrow found what does it do for the body',
        'What happens if it does not work well?',
        'How is Sickle cell treated?',
        'Will it kill me?',
        'How about transplant?',
        "Is Huntington's disease also inherited?",
    ],
    'Any cures for it?',
)

In [92]:
def rewrite_query(history_for_rewrite:list[str], current:str):
    """Rewrite ``current`` into a standalone question using ``LANGUAGE_MODEL``.

    Args:
        history_for_rewrite: Earlier questions to give the model as context.
        current: The question to rewrite.

    Returns:
        The rewritten question as a single line. If there is no history, or
        the model returns nothing, the stripped ``current`` is returned
        instead.
    """
    # With no history there is nothing to resolve: the question is already
    # standalone, so skip the model call rather than prompt it with an empty
    # context block.
    if not history_for_rewrite:
        return current.strip()

    instruction_prompt = ("You are a query rewriter for multi-turn retrieval. "
    "Rewrite the current question into a single, standalone question that can be answered "
    "without conversation context. Do not add facts that are not stated. Keep names explicit.")
    history_block = "\n".join(f"- {h}" for h in history_for_rewrite)
    content_prompt = (
        "Given the recent conversation history and a current question, "
        "rewrite the question so it is fully self-contained and explicit.\n\n"
        f"Recent history:\n{history_block}\n\n Current question:\n{current}"
        "\n\nRewrite the question (one line, no extra explanation):"
    )
    resp = ollama.chat(
        model=LANGUAGE_MODEL,
        messages=[
            {'role':'system','content':instruction_prompt},
            {'role':'user', 'content':content_prompt}
        ]
    )
    # The result is later stored as one "|user|: ..." line, so collapse any
    # newlines / repeated whitespace the model may emit into single spaces.
    rewritten = " ".join(resp['message']['content'].split())
    # An empty reply would lose the query; fall back to the original question.
    return rewritten or current.strip()

In [93]:
# Demo: rewrite the sample follow-up question using the history selected above.
print(rewrite_query(history_for_rewrite, 'Any cures for it?'))

Is there a cure for Sickle Cell Disease?


In [94]:
import json

# Output file: one JSON object per line, same "_id" / "|user|: ..." shape as
# the input records.
OUT = Path("rewritten_last_turn.jsonl")

with OUT.open("w", encoding="utf-8") as f:

    for (cid, turn), item in clap_dict.items():
        history = item["history"]
        current = item["current"]
        # Select the context, then rewrite the last turn against it.
        selected_history = select_history(history, current, k=4, keep_recent=1)
        # If the rewriter produced nothing, keep the original question so the
        # query id is not silently missing from the output file.
        rewrite = rewrite_query(selected_history, current) or current

        # Only skipped when the original question itself is empty.
        if not rewrite:
            continue

        rec = {
            "_id": f"{cid}<::>{turn}",
            "text": f"|user|: {rewrite}"
        }
        # ensure_ascii=False keeps non-ASCII characters readable in the file.
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

print("wrote:", OUT)

wrote: rewritten_last_turn.jsonl
