## Gemma 2 (2B, instruction-tuned) on HotpotQA — 1,200-question sample

In [None]:
import os
import json
import csv
import pandas as pd
import torch
from typing import List, Optional, Any, Iterable
from transformers import AutoModelForCausalLM, AutoTokenizer

# ========= CONFIG =========
# --- Input / output files ---
# The input CSV must provide the question column and both retrieved-context columns below.
INPUT_CSV = 'output_files/hotpot/hotpotqa_sampled_with_full_texts_1200.csv'
OUTPUT_HYBRID_CSV = "output_files/hotpot/hotpot_answers_hybrid_gemma2b_1200rows.csv"
OUTPUT_CONTRIEVER_CSV = "output_files/hotpot/hotpot_answers_contriever_gemma2b_1200rows.csv"

# Process only this many leading rows (handy for smoke tests); None means every row.
TEST_LIMIT: Optional[int] = None

# --- Column names expected in INPUT_CSV ---
QUESTION_COL = "question"
HYBRID_CTX_COL = "hybrid_full_text"
CONTRIEVER_CTX_COL = "contriever_full_text"

# --- Model and generation ---
MODEL_ID = "google/gemma-2-2b-it"  # instruction-tuned Gemma 2, 2B parameters
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_NEW_TOKENS = 256  # upper bound on tokens generated per answer
SEED = 42

# --- Prompt debugging ---
DEBUG_PROMPTS = True  # show the exact prompt that is fed to the model
PROMPT_LOG = None  # set to a file path (e.g. "prompt_log.txt") to append prompts there instead of printing

# --- Context handling ---
MAX_CTX_CHARS_PER_ITEM = 2000  # per-item character clip to bound prompt length (and memory); None disables
# Delimiter for plain-text (non-JSON) context cells. When None, a cell is parsed as JSON
# if possible, otherwise the whole cell is used as a single context item.
CONTEXT_SPLIT_DELIM: Optional[str] = None  # e.g. " ||| " or "\n\n"

# --- CSV I/O ---
ENCODING = "utf-8"
NEWLINE = ""  # the csv module expects files opened with newline="" on every platform

torch.manual_seed(SEED)
# ==========================

# ---------- Load model + tokenizer (plain transformers, NO bitsandbytes quantization) ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Fallback chat template for tokenizers that ship without one.
# NOTE(review): this script feeds prompts as plain text and never calls
# apply_chat_template, so this only matters for other code that does.
if tokenizer.chat_template is None:
    tokenizer.chat_template = "{{ bos_token }}{{ messages[0]['content'] }}{{ eos_token }}"
# Make sure a pad token exists; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Use half precision only on GPU. fp16 on CPU is slow and lacks kernels for many ops,
# so CPU runs load in float32. (Previously both branches were float16.)
# NOTE(review): Gemma 2 checkpoints are published in bfloat16; consider torch.bfloat16
# on GPUs that support it if fp16 outputs look degenerate.
_MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=_MODEL_DTYPE,
    device_map="auto",  # may offload layers to CPU if VRAM is insufficient
)

# ---------- Helpers to normalize context into a list of items ----------
def normalize_context_to_list(val: Any) -> List[str]:
    """
    Convert a raw context cell into a list of non-empty context strings.

    Accepts:
      - None / NaN -> []
      - list whose items are str or dict with a string "full_text" (other items are skipped)
      - dict wrapping such a list under "docs", "results", "items" or "data",
        or a single dict with a string "full_text"
      - str holding JSON of any of the above (tried when it starts with "[" or "{")
      - str split on CONTEXT_SPLIT_DELIM when that delimiter is set and present
      - any other non-blank str -> [that string, stripped]
      - any other value -> [str(value)]

    Empty and whitespace-only strings yield [] so they never become a blank
    "[context 1]" entry in the prompt.
    """
    if val is None or (isinstance(val, float) and pd.isna(val)):
        return []

    if isinstance(val, list):
        items = []
        for item in val:
            if isinstance(item, dict):
                text = item.get("full_text")
                if isinstance(text, str):
                    items.append(text)
            elif isinstance(item, str):
                items.append(item)
        # Drop empty / whitespace-only passages.
        return [s for s in items if s.strip()]

    if isinstance(val, dict):
        for k in ["docs", "results", "items", "data"]:
            if k in val and isinstance(val[k], list):
                return normalize_context_to_list(val[k])
        if "full_text" in val and isinstance(val["full_text"], str):
            return [val["full_text"]]
        return []

    if isinstance(val, str):
        s = val.strip()
        if not s:
            # Empty cell: no contexts, rather than one blank context item.
            return []
        if s.startswith(("[", "{")):
            try:
                parsed = json.loads(s)
            except (ValueError, RecursionError):
                # Not JSON after all (e.g. prose that merely begins with "["); fall through.
                pass
            else:
                return normalize_context_to_list(parsed)
        if CONTEXT_SPLIT_DELIM is not None and CONTEXT_SPLIT_DELIM in s:
            parts = [p.strip() for p in s.split(CONTEXT_SPLIT_DELIM)]
            return [p for p in parts if p]
        return [s]

    return [str(val)]

# ---------- Structured Prompt builder (plain text; no chat_template) ----------
def build_structured_prompt(question: str, context_items: List[str]) -> str:
    """
    Render the plain-text prompt (no chat template) for a single question.

    Each context item is stringified, clipped to MAX_CTX_CHARS_PER_ITEM characters
    (unless that limit is None), has all whitespace runs collapsed to single spaces,
    and is labelled "[context i] - ..." with i counting from 1. The question and the
    labelled contexts sit between a fixed instruction preamble and a fixed
    step-by-step task description ending in "Answer:", so the model continues with
    its answer.
    """

    def render_item(number: int, raw_item: Any) -> str:
        text = str(raw_item)
        if MAX_CTX_CHARS_PER_ITEM is not None:
            text = text[:MAX_CTX_CHARS_PER_ITEM]
        # Clip first, then flatten newlines/tabs/repeated spaces onto one line.
        flattened = " ".join(text.split())
        return f"[context {number}] - {flattened}"

    contexts_block = "\n".join(
        render_item(number, raw_item)
        for number, raw_item in enumerate(context_items, start=1)
    )

    preamble = (
        "You are a helpful assistant. Follow the user's instructions carefully and answer the question in one or max two sentences. "
        "Do not repeat any part of the prompt in your final answer and answer strictly based on provided contexts.\n\n"
    )
    question_section = f"Prompt:\nQuestion:- {question}\nContexts:-\n{contexts_block}\n\n"
    task_steps = (
        "Your task is to answer the given question by thinking progressively following the steps:\n"
        "Step 0: Process the answer within your memory and only provide the necessary\n"
        "answer.\n"
        "Step 1: Carefully read and understand the question and given contexts which can support you to answer better.\n"
        "Step 2: Analyze whether the given contexts provides sufficient information to answer the question.\n"
        "- If the given contexts do not provide sufficient information, respond with: “The context does not provide\n"
        "sufficient information to answer the question.”\n"
        "- If the given contexts provide sufficient information, proceed to Step 3.\n"
        "Step 3: Generate an accurate and grounded response strictly based on the provided contexts. Avoid\n"
        "guessing or providing incorrect/hallucinated responses.\n"
        "Step 4: Only when you are sure of Step 3, Clearly state the final answer in one or two sentences.\n\n"
        "Answer:"
    )
    return preamble + question_section + task_steps

def generate_answer(question: str, context_text: Any) -> str:
    """
    Answer `question` with the loaded model, grounded on the raw `context_text` cell.

    The cell is normalized into context items, formatted with the structured
    plain-text prompt, and decoded greedily (do_sample=False) for at most
    MAX_NEW_TOKENS new tokens. When DEBUG_PROMPTS is set, the exact prompt is
    appended to PROMPT_LOG, or printed if PROMPT_LOG is falsy.

    Returns the newly generated text only (prompt tokens excluded), with special
    tokens removed and surrounding whitespace stripped.
    """
    prompt = build_structured_prompt(question, normalize_context_to_list(context_text))

    if DEBUG_PROMPTS:
        framed = "\n=== PROMPT START ===\n" + prompt + "\n=== PROMPT END ===\n"
        if not PROMPT_LOG:
            print(framed)
        else:
            with open(PROMPT_LOG, "a", encoding=ENCODING) as log_file:
                log_file.write(framed)

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded.input_ids.shape[1]
    # Prefer EOS as the padding id; fall back to the tokenizer's pad id if EOS is undefined.
    pad_id = tokenizer.pad_token_id if tokenizer.eos_token_id is None else tokenizer.eos_token_id

    # If generation runs out of memory, passing use_cache=False here may help (at a speed cost).
    generated = model.generate(
        **encoded,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=pad_id,
    )
    # generate() returns prompt + continuation for the single sequence; keep the continuation.
    continuation = generated[0, prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()

# ---------- Incremental CSV writing using csv module ----------
def ensure_header(out_path: str, header_cols: List[str]):
    """
    Create the CSV at out_path containing just the header row, unless it already has content.

    Missing parent directories are created first, so a fresh checkout without the
    output folder does not fail with FileNotFoundError. An existing non-empty file
    is left untouched; its header is not checked here.

    Args:
        out_path: Destination CSV path.
        header_cols: Column names written as the first row of a new/empty file.
    """
    parent_dir = os.path.dirname(out_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
        return

    with open(out_path, "w", encoding=ENCODING, newline=NEWLINE) as f:
        csv.writer(f).writerow(header_cols)

def append_row(out_path: str, row_values: Iterable[Any]):
    """
    Append one record to the CSV at out_path and force it to disk.

    Values are written in the order given, so callers must supply them in the
    same column order as the file's header. The file is flushed and fsync'ed
    before closing, so each completed row is durable even if the long-running
    job is killed later.
    """
    with open(out_path, "a", encoding=ENCODING, newline=NEWLINE) as handle:
        csv.writer(handle).writerow(row_values)
        handle.flush()
        os.fsync(handle.fileno())

def _count_existing_data_rows(out_path: str, header_cols: List[str]) -> int:
    """
    Return how many data rows (excluding the header) out_path already holds; 0 if missing/empty.

    Raises:
        ValueError: if the file's header differs from header_cols, because appending
            rows with a different column layout would silently corrupt the CSV.
    """
    if not os.path.exists(out_path) or os.path.getsize(out_path) == 0:
        return 0
    with open(out_path, "r", encoding=ENCODING, newline=NEWLINE) as f:
        reader = csv.reader(f)
        existing_header = next(reader, None)
        expected_header = [str(c) for c in header_cols]
        if existing_header is not None and existing_header != expected_header:
            raise ValueError(
                f"Existing output {out_path!r} has header {existing_header}, expected "
                f"{expected_header}; move or delete it before re-running."
            )
        return sum(1 for _ in reader)


def run_inference_for_column_incremental(
    df: pd.DataFrame,
    ctx_col: str,
    out_col: str,
    out_path: str,
    source_label: str,
    resume: bool = True,
):
    """
    Answer every question in df using the ctx_col contexts, streaming rows to out_path.

    Each output row contains the input row's values (in df.columns order) followed
    by the generated answer under out_col. Rows are appended one at a time with
    append_row, so an interrupted run keeps every finished row.

    Args:
        df: Input rows; must contain QUESTION_COL and ctx_col.
        ctx_col: Column holding the raw context cell for each question.
        out_col: Name of the appended answer column.
        out_path: Output CSV path (created with a header if missing or empty).
        source_label: Label used only in progress/debug messages.
        resume: If True (default), data rows already present in out_path are
            treated as the first rows of df, in order, and are skipped. Re-running
            after a crash therefore continues where it stopped instead of
            appending duplicate answers. This assumes df is the same input
            (same rows, same order) as in the earlier run.

    Generation failures are recorded in the answer cell as
    "[ERROR DURING GENERATION: ...]", printed, and the run continues.

    Raises:
        ValueError: if resuming into an existing file whose header is not
            df.columns + [out_col].
    """
    header_cols = list(df.columns) + [out_col]
    ensure_header(out_path, header_cols)

    already_done = _count_existing_data_rows(out_path, header_cols) if resume else 0
    if already_done:
        print(f"[{source_label}] {already_done} row(s) already in {out_path}; resuming after them.")

    for idx, row in df.iloc[already_done:].iterrows():
        question = row.get(QUESTION_COL, "")
        context_text = row.get(ctx_col, "")

        if DEBUG_PROMPTS and not PROMPT_LOG:
            preview_q = str(question)[:120].replace("\n", " ")
            print(f"\n--- Row {idx} | Source: {source_label} | Question: {preview_q} ---")

        try:
            ans = generate_answer(question, context_text)
        except Exception as e:  # per-row best effort: record the failure and keep going
            ans = f"[ERROR DURING GENERATION: {e}]"
            print(f"Row {idx} ({source_label}) failed: {e}")
            # Release cached GPU memory so one OOM does not cascade into the next rows.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Values must follow the header order: input columns, then the answer.
        row_values = [row.get(c, "") for c in df.columns] + [ans]
        append_row(out_path, row_values)

def main():
    """
    Run both retrieval variants (hybrid, then Contriever) over INPUT_CSV.

    Checks that the question and both context columns exist, optionally keeps only
    the first TEST_LIMIT rows, and streams answers into one CSV per variant.
    Test runs (TEST_LIMIT is not None) write to separate "<name>_test<N>.csv"
    files, so their rows never end up in the full-run outputs.

    Raises:
        ValueError: if a required column is missing from INPUT_CSV.
    """
    df = pd.read_csv(INPUT_CSV)

    required = [QUESTION_COL, HYBRID_CTX_COL, CONTRIEVER_CTX_COL]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required column(s) in CSV: {missing}")

    if TEST_LIMIT is not None:
        df = df.head(TEST_LIMIT)

    def output_path(full_run_path: str) -> str:
        # Keep test-run output separate so it cannot contaminate the full-run CSV.
        if TEST_LIMIT is None:
            return full_run_path
        stem, ext = os.path.splitext(full_run_path)
        return f"{stem}_test{TEST_LIMIT}{ext}"

    # Answer columns are named after the model that produced them (MODEL_ID = Gemma 2 2B).
    # NOTE(review): these were previously "answer_llama31_8b_*"; update any downstream
    # scripts that read the old column names from these files.

    # 1) HYBRID pass
    out_hybrid = output_path(OUTPUT_HYBRID_CSV)
    run_inference_for_column_incremental(
        df,
        ctx_col=HYBRID_CTX_COL,
        out_col="answer_gemma2_2b_hybrid",
        out_path=out_hybrid,
        source_label="hybrid_full_text",
    )
    print(f"Wrote incrementally: {out_hybrid}")

    # 2) CONTRIEVER pass
    out_contr = output_path(OUTPUT_CONTRIEVER_CSV)
    run_inference_for_column_incremental(
        df,
        ctx_col=CONTRIEVER_CTX_COL,
        out_col="answer_gemma2_2b_contriever",
        out_path=out_contr,
        source_label="contriever_full_text",
    )
    print(f"Wrote incrementally: {out_contr}")


if __name__ == "__main__":
    main()