# ELI5 Demo

This notebook runs **one** example end-to-end:
1) Two-stage LLM reasoning → `Final: X`
2) Optional constrained scoring + confidence
3) Train a tiny surrogate on a small batch and show an ELI5 explanation

In [1]:
!pip -q install -U pandas pyarrow scikit-learn eli5 matplotlib
!pip -q install -U transformers accelerate sentencepiece safetensors bitsandbytes

In [2]:
import os, re, time
import pandas as pd
import eli5
from IPython.display import display, HTML

## Load Dataset

In [10]:
DATA_PATH = "compiled_df.parquet"
df = pd.read_parquet(DATA_PATH)
print("Rows:", len(df))
print("Columns:", df.columns.tolist())

row = df.iloc[50]
print(row)

Rows: 15642
Columns: ['dataset_name', 'id_in_dataset', 'question', 'options', 'answer_label', 'question_type', 'prompt_text']
dataset_name                                               medmcqa
id_in_dataset                                                 7217
question                        True regarding colovesical fistula
options          Answer Choices:\nA. Most commonly presents wit...
answer_label                                                     A
question_type                                                  MCQ
prompt_text      Question:\nTrue regarding colovesical fistula\...
Name: 50, dtype: object


## Load LLM wrapper

In [6]:
from medical_llm_wrapper_fixed import MedicalLLMWrapper

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "BioMistral/BioMistral-7B"

llm = MedicalLLMWrapper(
    model_name=MODEL_ID,
    device="cuda",
    token=HF_TOKEN,
)

[MedicalLLMWrapper] Loading model: BioMistral/BioMistral-7B


config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Exception in thread Thread-auto_conversion:
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/transformers/safetensors_conversion.py", line 117, in auto_conversion
    raise e
  File "/usr/local/lib/python3.12/dist-packages/transformers/safetensors_conversion.py", line 96, in auto_conversion
    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/safetensors_conversion.py", line 72, in get_conversion_pr_reference
    spawn_conversion(token, private, model_id)
  File "/usr/local/lib/python3.12/dist-packages/transformers/safetensors_conversion.py", line 48, in spawn_con

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.bfloat16
[MedicalLLMWrapper]   Option token IDs - AB: [330, 365], ABCD: [330, 365, 334, 384]


## prompt + helper

In [7]:
def parse_answer_letter_strict(text):
    if text is None:
        return None
    t = str(text).strip().upper()
    # prefer "ANSWER: X"
    m = re.search(r"(?:ANSWER\s*[:\-]?\s*)([ABCD])\b", t)
    if m:
        return m.group(1)
    # fallback: first standalone letter
    m = re.search(r"\b([ABCD])\b", t)
    return m.group(1) if m else None

def force_letter_suffix():
    return "\n\nReturn ONLY ONE LETTER (A, B, C, or D). No other text.\nAnswer:"

def render_mcq_prompt(question, options=None, prompt_text=None):
    q = str(question).strip()
    header = (str(prompt_text).strip() + "\n\n") if prompt_text else ""

    # normalize options to A-D lines if possible
    opts_block = ""
    if isinstance(options, (list, tuple)):
        letters = ["A","B","C","D"]
        lines = [f"{letters[i]}. {str(options[i]).strip()}" for i in range(min(4, len(options)))]
        opts_block = "\n".join(lines)
    elif isinstance(options, dict):
        # accept {"A": "..."} or {"a": "..."}
        lines = []
        for k in ["A","B","C","D"]:
            if k in options:
                lines.append(f"{k}. {str(options[k]).strip()}")
            elif k.lower() in options:
                lines.append(f"{k}. {str(options[k.lower()]).strip()}")
        opts_block = "\n".join(lines)
        if not opts_block:
            # fallback stringify
            opts_block = "\n".join([f"{k}. {v}" for k, v in options.items()])
    elif options is not None:
        opts_block = str(options).strip()

    prompt = (
        f"{header}"
        f"Question:\n{q}\n\n"
        f"Options:\n{opts_block}\n\n"
        "Select the best answer. Reply with ONLY one letter: A, B, C, or D.\n"
        "Answer:"
    )
    return prompt

## run 1 example demo

In [11]:
prompt = render_mcq_prompt(
    str(row["question"]),
    row.get("options", None),
    row.get("prompt_text", None),
)

gold = str(row.get("answer_label", "")).strip().upper()

# FREE
t0 = time.perf_counter()
# raw_free = llm.generate_free(prompt + reasoning_suffix_mcq(), 256)
raw_f = llm.generate(prompt + force_letter_suffix())
t1 = time.perf_counter()
pred_free = parse_answer_letter_strict(raw_f)

# SCORED / constrained
try: llm.set_task("mcq")
except: pass
try: llm.set_mode("answer_only")
except: pass

t2 = time.perf_counter()
raw_scored = llm.generate(prompt)
t3 = time.perf_counter()
pred_scored = parse_answer_letter_strict(raw_scored)

conf = getattr(llm, "last_confidence", None)
option_probs = getattr(llm, "last_option_probs", None)

final_pred = pred_scored if (conf is not None and float(conf) >= 0.65 and pred_scored in {"A","B","C","D"}) else pred_free

In [12]:
print("=== QUESTION ===\n", row["question"])
print("\n=== OPTIONS ===\n", row.get("options", None))
print("\n=== GOLD ===", gold)

print("\n--- FREE REASONING OUTPUT ---\n", raw_f)
print("Parsed (free):", pred_free, f"| time {(t1-t0):.2f}s")

print("\n--- SCORED OUTPUT ---\n", raw_scored)
print("Parsed (scored):", pred_scored, f"| time {(t3-t2):.2f}s")
print("Confidence:", conf)
print("Option probs:", option_probs)

print("\n>>> FINAL PRED:", final_pred)

=== QUESTION ===
 True regarding colovesical fistula

=== OPTIONS ===
 Answer Choices:
A. Most commonly presents with pneumaturia
B. Most commonly caused by colonic cancer
C. More common in females
D. Readily diagnosed on barium enema

=== GOLD === A

--- FREE REASONING OUTPUT ---
 
Parsed (free): None | time 1.08s

--- SCORED OUTPUT ---
 Answer: A
Parsed (scored): A | time 0.09s
Confidence: 0.7483415603637695
Option probs: {'A': 0.7483415603637695, 'B': 0.047839872539043427, 'C': 0.06538935750722885, 'D': 0.13842926919460297}

>>> FINAL PRED: A


## Surrogate

In [20]:
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

def clean_for_surrogate(text):
    # minimal cleanup to reduce boilerplate impact
    t = str(text)
    t = re.sub(r"\s+", " ", t).strip()
    return t

def build_surrogate_text_from_row(r):
    # Keep it simple: question + options block (no extra suffixes)
    return render_mcq_prompt(r["question"], r.get("options", None), r.get("prompt_text", None))

def reasoning_suffix_mcq():
    return (
        "\n\nThink step-by-step briefly, then conclude with the final answer letter.\n"
        "Final Answer:"
    )

# build training examples
rows = df.head(200)
X_text = []
y_label = []

In [26]:
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import eli5
from IPython.display import display

for _, r in rows.iterrows():
    p = render_mcq_prompt(str(r["question"]), r.get("options", None), r.get("prompt_text", None))

    # scored prediction + gate to free if low confidence
    raw_s = llm.generate(p)
    ps = parse_answer_letter_strict(raw_s)
    c = getattr(llm, "last_confidence", None)

    raw_f = llm.generate_free(p + reasoning_suffix_mcq(), 128)
    pf = parse_answer_letter_strict(raw_f)

    y = ps if (c is not None and float(c) >= 0.65 and ps in {"A","B","C","D"}) else pf
    if y not in {"A","B","C","D"}:
        continue

    X_text.append(clean_for_surrogate(build_surrogate_text_from_row(r)))
    y_label.append(y)

print("\nSurrogate training size:", len(y_label))
if len(set(y_label)) < 2 or len(y_label) < 30:
    print("Skipping surrogate: not enough data / class variety.")
else:
    X_train, X_test, y_train, y_test = train_test_split(
        X_text, y_label, test_size=0.25, random_state=0, stratify=y_label
    )

    surrogate = Pipeline([
        ("tfidfvectorizer", TfidfVectorizer(ngram_range=(1,2), min_df=2, max_features=50000)),
        ("logisticregression", LogisticRegression(max_iter=1000))
    ])

    surrogate.fit(X_train, y_train)
    fidelity = (surrogate.predict(X_test) == y_test).mean()
    print(f"[Surrogate] Fidelity to LLM FINAL predictions: {fidelity:.3f} (n_test={len(y_test)})")

    # ELI5 LOCAL EXPLANATION
    import eli5
    from IPython.display import display

    vec = surrogate.named_steps["tfidfvectorizer"]
    clf = surrogate.named_steps["logisticregression"]

    doc = clean_for_surrogate(prompt)   # raw string
    display(eli5.show_prediction(clf, doc, vec=vec, top=20))


Surrogate training size: 600
[Surrogate] Fidelity to LLM FINAL predictions: 0.787 (n_test=150)


Contribution?,Feature
+0.201,<BIAS>
… 22 more positive …,… 22 more positive …
… 3 more negative …,… 3 more negative …
-0.126,Highlighted in text (sum)

Contribution?,Feature
+0.017,Highlighted in text (sum)
… 24 more positive …,… 24 more positive …
… 1 more negative …,… 1 more negative …
-0.602,<BIAS>

Contribution?,Feature
+0.166,Highlighted in text (sum)
… 25 more positive …,… 25 more positive …
-0.524,<BIAS>

Contribution?,Feature
+0.924,<BIAS>
… 2 more positive …,… 2 more positive …
… 23 more negative …,… 23 more negative …
-0.073,Highlighted in text (sum)


In [22]:
raw_scored = llm.generate(prompt)
print("Option probs:", getattr(llm, "last_option_probs", None))
print("Confidence:", getattr(llm, "last_confidence", None))

Option probs: {'A': 0.7483415603637695, 'B': 0.047839872539043427, 'C': 0.06538935750722885, 'D': 0.13842926919460297}
Confidence: 0.7483415603637695


In [36]:
from eli5.formatters.text import format_as_text

exp = eli5.explain_prediction(clf, doc, vec=vec, top=500, top_targets=4)
print(format_as_text(exp))

Explained as: linear model
y='D' (probability=0.487, score=0.813) top features
Contribution  Feature       
------------  --------------
      +0.924  <BIAS>        
      +0.051  more          
      +0.028  caused by     
      +0.028  caused        
      +0.019  by            
      +0.014  commonly      
      +0.011  most          
      +0.009  most commonly 
      +0.002  the           
      +0.001  true          
      -0.001  only          
      -0.001  one           
      -0.001  with only     
      -0.001  reply with    
      -0.001  letter        
      -0.001  answer reply  
      -0.001  or answer     
      -0.001  one letter    
      -0.001  options answer
      -0.001  letter or     
      -0.001  options       
      -0.001  best answer   
      -0.001  only one      
      -0.001  reply         
      -0.001  cancer        
      -0.002  the best      
      -0.002  select        
      -0.002  select the    
      -0.002  or            
      -0.002  best    