## 1) Setup & Data

This cell:
- Installs all dependencies (Transformers, PEFT stack, datasets, etc.)
- Detects GPU device
- Loads **MedCalc-Bench v1.0** from Hugging Face
- Materializes `train_df` / `test_df` as Pandas DataFrames

In [1]:
!pip -q install datasets pandas

import torch, pandas as pd, json, re, os, numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
device

!pip -q install -U huggingface_hub fsspec pyarrow==15.0.2

from datasets import load_dataset
ds = load_dataset("ncbi/MedCalc-Bench-v1.0")

train_df = ds["train"].to_pandas()
test_df  = ds["test"].to_pandas()

!pip -q install -U "transformers>=4.51.0" "accelerate>=0.34.2" "bitsandbytes>=0.45.0"

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from collections import defaultdict
import time

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/193.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2025.9.0 which is incompatible.
datasets 4.0.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.[0m[31m
[0m

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Generating train split:   0%|          | 0/10053 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1047 [00:00<?, ? examples/s]

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m50.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h

## 2) Models: Qwen3 loading & quantization

We load two Qwen3 variants with 4-bit NF4 quantization (QLoRA-friendly).  
This keeps memory within Colab limits while preserving reasonable generation quality.


In [2]:


MODEL_ID_1 = "Qwen/Qwen3-1.7B"
MODEL_ID_2 = "Qwen/Qwen3-0.6B"


def load_qwen(model_id, load_4bit=True):
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    ) if load_4bit else None

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=bnb,
        trust_remote_code=True,
    )
    return tok, model

tok1, model1 = load_qwen(MODEL_ID_1, load_4bit=True)
tok2, model2 = load_qwen(MODEL_ID_2, load_4bit=True)


models = {
    "Qwen3-1.7B": (tok1, model1),
    "Qwen3-0.6B": (tok2, model2),
}


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/622M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

## 3) Parsing & Scoring helpers

The benchmark requires strict exact match on several output types:
- decimal/integer (with tolerance interval if provided),
- date in MM/DD/YY,
- gestational age in X weeks, Y days.

This cell defines:
- An Answer: line extractor (we always prompt the model to output exactly one such line).
- Parsers for numeric, date, and weeks/days formats.
- A comparator that checks predicted vs. ground truth using the benchmark’s rules.


In [3]:
import re, numpy as np

NUM_RE = r'([\-+]?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d+)?(?:e[+\-]?\d+)?)'

def _last_answer_line(text: str):
    # Return the last "Answer: ..." line (stripped), or None
    matches = list(re.finditer(r'(?im)^\s*answer\s*:\s*(.+)\s*$', text))
    return matches[-1].group(1).strip() if matches else None

def parse_numeric(text: str):
    # Prefer the last explicit Answer: line; normalize commas
    ans = _last_answer_line(text)
    hay = ans if ans is not None else text
    # Look for a number on that line/text
    m = re.search(NUM_RE, hay)
    if not m and ans is None:
        # If no Answer: line at all, as a last resort use the LAST number in whole text
        m_all = list(re.finditer(NUM_RE, text))
        m = m_all[-1] if m_all else None
    if not m:
        return np.nan
    s = m.group(1).replace(',', '')
    try:
        return float(s)
    except:
        return np.nan

def parse_date_mmddyy(s: str):
    def _norm(dt):
        m = re.match(r'^\s*(\d{1,2})[/-](\d{1,2})[/-](\d{2,4})\s*$', dt)
        if not m: return None
        mm, dd, yy = int(m.group(1)), int(m.group(2)), m.group(3)
        if len(yy) == 4: yy = yy[-2:]
        return f"{mm:02d}/{dd:02d}/{yy}"
    ans = _last_answer_line(s)
    if ans:
        m = re.search(r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})', ans)
        if m:
            n = _norm(m.group(1))
            if n: return n
    # fallback: last date anywhere
    m_all = list(re.finditer(r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})', s))
    return _norm(m_all[-1].group(1)) if m_all else None

def parse_weeks_days(s: str):
    ans = _last_answer_line(s) or s
    # strict form first
    m = re.search(r'(\d+)\s*weeks?[, ]+\s*(\d+)\s*days?', ans, re.I)
    if m:
        return f"({int(m.group(1))} weeks, {int(m.group(2))} days)"
    # compact w/d
    m = re.search(r'(\d+)\s*w(?:eeks?)?\b[\s,]*([0-9]+)\s*d(?:ays?)?\b', ans, re.I)
    if m:
        return f"({int(m.group(1))} weeks, {int(m.group(2))} days)"
    # single terms -> fill missing with 0
    mw = re.search(r'(\d+)\s*weeks?', ans, re.I)
    md = re.search(r'(\d+)\s*days?',  ans, re.I)
    if mw or md:
        w = int(mw.group(1)) if mw else 0
        d = int(md.group(1)) if md else 0
        return f"({w} weeks, {d} days)"
    return None

def extract_answer(text, output_type):
    output_type = (str(output_type) or "").strip().lower()
    if output_type in ("decimal", "integer"):
        return parse_numeric(text)
    if output_type == "date":
        return parse_date_mmddyy(text)
    if "week" in output_type or "day" in output_type:
        return parse_weeks_days(text)
    return parse_numeric(text)

def compare_pred_to_gt(pred, row, tol=1e-6):
    out_type = str(row.get("Output Type", "")).lower()
    gt = row["Ground Truth Answer"]
    lo = row.get("Lower Limit", gt)
    hi = row.get("Upper Limit", gt)

    if out_type in ("decimal", "integer"):
        try:
            pred = float(pred)
            lo = float(lo); hi = float(hi)
        except:
            return False

        return (not np.isnan(pred)) and (lo - tol <= pred <= hi + tol)

    if out_type == "date":
        gt_norm = parse_date_mmddyy(str(gt))
        return (pred is not None) and (gt_norm is not None) and (pred == gt_norm)

    if "week" in out_type or "day" in out_type:
        gt_str = str(gt)
        m = re.search(r'(\d+)\s*weeks?[, ]+\s*(\d+)\s*days?', gt_str, re.I)
        if m:
            gt_can = f"({int(m.group(1))} weeks, {int(m.group(2))} days)"
        else:

            gt_can = parse_weeks_days(gt_str)
        return (pred is not None) and (gt_can is not None) and (pred == gt_can)

    try:
        pred = float(pred); gt = float(gt)
        return abs(pred - gt) <= tol
    except:
        return False


## 4) Zero-shot prompt & smoke test

We define a zero shot prompt that:
- Restricts the model to use only ENTITIES,
- Demands a single Answer: line,
- Injects a type hint (decimal/integer, date, weeks/days) to enforce formatting.


In [4]:
def build_zero_shot(row):
    import json
    ents = row["Relevant Entities"]
    try:
        ents = json.dumps(json.loads(ents), ensure_ascii=False)
    except:
        ents = str(ents)

    question  = str(row["Question"])
    out_type  = str(row.get("Output Type", "")).lower()

    if out_type == "date":
        type_hint = "The answer must be a date in MM/DD/YY."
    elif "week" in out_type or "day" in out_type:
        type_hint = "The answer must be exactly in the form (X weeks, Y days)."
    elif out_type == "integer":
        type_hint = "The answer must be an integer (no decimals, no units)."
    else:
        type_hint = "The answer must be a decimal number (no units)."


    prompt = (
        "You are a helpful medical calculation assistant.\n"
        "For the following calculation question, use ONLY the values in [ENTITIES] to compute the answer. Ignore unrelated prose.\n"
        "Respond with EXACTLY ONE line that begins with 'Answer:' followed by the value.\n"
        "Example: Answer: 9\n"
        "Do NOT include templates, angle brackets, headings, bullet points, or extra words.\n\n"
        "[CASE]\n"
        f"{row['Patient Note']}\n\n"
        "[QUESTION]\n"
        f"{question}\n\n"
        "[ENTITIES]\n"
        f"{ents}\n"
    )
    return prompt

@torch.inference_mode()
def generate_with(model, tokenizer, prompt, max_new_tokens=12):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,           # deterministic
        temperature=0.0,
        top_p=1.0,
        repetition_penalty=1.08,   # small nudge against repeating "Answer:"
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    return text[len(prompt):].strip() if text.startswith(prompt) else text

# Smoke test on 1 random test row
row = test_df.sample(n=1, random_state=0).iloc[0]
prompt = build_zero_shot(row)
raw = generate_with(model1, tok1, prompt)
print("raw prediction:", raw)
pred = extract_answer(raw, row["Output Type"])
print("Extracted prediction:", pred)
ok = compare_pred_to_gt(pred, row)
print("GT:", row["Ground Truth Answer"])
print("Within range?", ok)

The following generation flags are not valid and may be ignored: ['temperature', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


raw prediction: Answer:
Answer: 10

The SOFA score
Extracted prediction: 10.0
GT: 9
Within range? False


## 5) Zero-shot evaluation: full test sweep & CSV export

This section runs the zero-shot prompt over all test rows, computes accuracy per category and overall, and saves raw predictions to a CSV.

What this does
- Builds a strict zero-shot prompt per row (via `build_zero_shot`).
- Generates a deterministic answer (greedy decoding).
- Parses the model output into the required type (`decimal/integer`, `date`, `(weeks, days)`).
- Compares against ground truth with the benchmark’s rules/tolerances.
- Logs interim overall accuracy every `log_every` examples.
- Saves a row-level CSV with both raw text and parsed value



In [None]:


def eval_zero_shot_full(test_df, model, tok, max_new_tokens=16, log_every=100, save_path="results/zero_shot_predictions.csv"):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    correct_by_cat = defaultdict(int)
    total_by_cat   = defaultdict(int)
    rows_out = []

    t0 = time.time()
    for i, row in test_df.reset_index().iterrows():
        prompt = build_zero_shot(row)
        raw = generate_with(model, tok, prompt, max_new_tokens=max_new_tokens)

        # directly use your parser
        pred = extract_answer(raw, row["Output Type"])
        ok = compare_pred_to_gt(pred, row)

        cat = row["Category"]
        correct_by_cat[cat] += int(ok)
        total_by_cat[cat]   += 1

        rows_out.append({
            "row_index": row["Row Number"] if "Row Number" in row else int(row["index"]),
            "category": cat,
            "calculator": row.get("Calculator Name", ""),
            "output_type": row.get("Output Type", ""),
            "question": row["Question"],
            "prediction_text": raw,       # store raw generation
            "prediction_value": pred,     # parsed value
            "ground_truth": row["Ground Truth Answer"],
            "lower_limit": row.get("Lower Limit", row["Ground Truth Answer"]),
            "upper_limit": row.get("Upper Limit", row["Ground Truth Answer"]),
            "correct": ok,
        })

        if (i+1) % log_every == 0:
            elapsed = time.time() - t0
            overall_tmp = sum(correct_by_cat.values()) / max(1, sum(total_by_cat.values()))
            print(f"[{i+1}/{len(test_df)}] interim overall acc: {overall_tmp:.3f} | elapsed {elapsed:.1f}s")

    pd.DataFrame(rows_out).to_csv(save_path, index=False)

    acc_by_cat = {c: correct_by_cat[c] / total_by_cat[c] for c in total_by_cat}
    overall = sum(correct_by_cat.values()) / max(1, sum(total_by_cat.values()))
    return acc_by_cat, overall, pd.DataFrame(rows_out)

# Run it (this will iterate over ALL test rows)
acc_by_cat, overall, preds_df = eval_zero_shot_full(
    test_df=test_df,
    model=model1,
    tok=tok1,
    max_new_tokens=16,
    log_every=10,

)

print("Overall accuracy:", round(overall, 4))
print("Per-category:", acc_by_cat)

[10/1047] interim overall acc: 0.000 | elapsed 18.1s
[20/1047] interim overall acc: 0.000 | elapsed 34.7s
[30/1047] interim overall acc: 0.000 | elapsed 49.6s
[40/1047] interim overall acc: 0.050 | elapsed 64.8s
[50/1047] interim overall acc: 0.080 | elapsed 78.0s
[60/1047] interim overall acc: 0.067 | elapsed 92.6s
[70/1047] interim overall acc: 0.057 | elapsed 107.0s
[80/1047] interim overall acc: 0.075 | elapsed 121.7s
[90/1047] interim overall acc: 0.111 | elapsed 137.0s
[100/1047] interim overall acc: 0.120 | elapsed 153.5s
[110/1047] interim overall acc: 0.127 | elapsed 169.4s
[120/1047] interim overall acc: 0.133 | elapsed 185.5s
[130/1047] interim overall acc: 0.154 | elapsed 202.3s
[140/1047] interim overall acc: 0.164 | elapsed 218.4s
[150/1047] interim overall acc: 0.160 | elapsed 232.5s
[160/1047] interim overall acc: 0.156 | elapsed 247.3s
[170/1047] interim overall acc: 0.153 | elapsed 263.3s
[180/1047] interim overall acc: 0.161 | elapsed 277.0s
[190/1047] interim overal

In [None]:
acc_cat_small, overall_small, preds_small = eval_zero_shot_full(
    test_df=test_df,
    model=model2,
    tok=tok2,
    max_new_tokens=16,
    log_every=10,
)
print("0.6B | Zero-shot overall:", round(overall_small, 4))
print("Per-category:", acc_cat_small)

[10/1047] interim overall acc: 0.000 | elapsed 15.5s
[20/1047] interim overall acc: 0.000 | elapsed 29.7s
[30/1047] interim overall acc: 0.000 | elapsed 42.6s
[40/1047] interim overall acc: 0.000 | elapsed 56.0s
[50/1047] interim overall acc: 0.020 | elapsed 68.9s
[60/1047] interim overall acc: 0.033 | elapsed 82.2s
[70/1047] interim overall acc: 0.029 | elapsed 95.2s
[80/1047] interim overall acc: 0.025 | elapsed 108.8s
[90/1047] interim overall acc: 0.022 | elapsed 122.3s
[100/1047] interim overall acc: 0.030 | elapsed 136.7s
[110/1047] interim overall acc: 0.036 | elapsed 150.7s
[120/1047] interim overall acc: 0.067 | elapsed 164.9s
[130/1047] interim overall acc: 0.062 | elapsed 178.8s
[140/1047] interim overall acc: 0.057 | elapsed 193.2s
[150/1047] interim overall acc: 0.053 | elapsed 206.5s
[160/1047] interim overall acc: 0.050 | elapsed 219.6s
[170/1047] interim overall acc: 0.047 | elapsed 233.4s
[180/1047] interim overall acc: 0.044 | elapsed 246.2s
[190/1047] interim overall

## 6) Few-shot prompting: exemplar selection & prompt builder

This section constructs **few-shot prompts** that show the model a small set of solved examples (“EXAMPLES: … Answer: …”) before asking it to solve the target case.

**Design choices & rationale**
- **Exemplar selection (`pick_exemplars`)** follows this strategy:
  1) Prefer up to two items from the same calculator as the query (highest transfer for formula & units).
  2) Then fill remaining slots from the same category + output type where units are compatible (prevents mg↔mcg or mmol/L↔mg/dL drift).
  3) If still short, add a format anchor from the same output type (teaches the exact output string shape, e.g., `MM/DD/YY`, `(X weeks, Y days)`).
  4) Finally, fall back to random items to reach `k` (ensures the prompt is always populated).
- **Unit compatibility** is enforced by extracting units from `[ENTITIES]` JSON and requiring overlaps to match; if either side has `None`, it doesn’t block the match.
- **Prompt builder (`build_few_shot`)** prints each exemplar as:

[ QUESTION ]
 [ text ]
[ ENTITIES ]
{ ...json... }
[ Answer: canonical ground truth ]

Then prints the target **[CASE]/[QUESTION]/[ENTITIES]** with a strict **type hint** and the rule “**Answer must be exactly one line beginning with ‘Answer:’**”.
- **Ground-truth canonicalization** (`_normalize_gt_for_prompt`) ensures exemplars display **exactly** the formats used by the grader (e.g., `MM/DD/YY`, `(X weeks, Y days)`, numeric without units).

In [5]:
WEAK_CATS = {"lab", "date", "dosage", "severity"}

import json, random

def _json_safe(x):
    try:
        return json.dumps(json.loads(x), ensure_ascii=False)
    except Exception:
        return str(x)

def _unit_map(ents_str):
    """Extract {key -> unit_str or None} from the ENTITIES string."""
    try:
        obj = json.loads(ents_str)
    except Exception:
        return {}
    out = {}
    if isinstance(obj, dict):
        for k, v in obj.items():
            if isinstance(v, list) and len(v) >= 2 and isinstance(v[1], str):
                out[k] = v[1].strip().lower()
            else:
                out[k] = None
    return out

def _units_compatible(ents_a, ents_b):
    """Return True if overlapping keys have the same unit strings (or both None)."""
    ua, ub = _unit_map(ents_a), _unit_map(ents_b)
    overlap = set(ua.keys()) & set(ub.keys())
    if not overlap:
        return True
    for k in overlap:
        if ua[k] is None or ub[k] is None:
            continue
        if ua[k] != ub[k]:
            return False
    return True

def pick_exemplars(train_df, row, k=3, seed=42):
    rnd = random.Random(seed)
    calc = row.get("Calculator Name", "")
    cat  = str(row["Category"]).lower()
    ot   = str(row["Output Type"]).lower()
    ents_q = row["Relevant Entities"]

    chosen = []

    # 1) Same calculator (up to 2)
    same_calc = train_df[train_df["Calculator Name"] == calc]
    same_calc = same_calc[same_calc.index != getattr(row, "name", None)]
    same_calc = same_calc.sample(n=min(2, len(same_calc)), random_state=seed) if len(same_calc) else same_calc
    chosen += list(same_calc.index)

    # 2) Same category + output type + unit-compatible
    need = k - len(chosen)
    if need > 0:
        pool = train_df[
            (train_df["Category"].str.lower() == cat) &
            (train_df["Output Type"].str.lower() == ot) &
            (~train_df.index.isin(chosen))
        ]
        pool = pool[[ _units_compatible(ents_q, e) for e in pool["Relevant Entities"] ]] if len(pool) else pool
        if len(pool) > 0:
            chosen += rnd.sample(list(pool.index), k=min(need, len(pool)))

    # 3) Format anchor: same output type + unit-compatible
    need = k - len(chosen)
    if need > 0:
        pool = train_df[
            (train_df["Output Type"].str.lower() == ot) &
            (~train_df.index.isin(chosen))
        ]
        pool = pool[[ _units_compatible(ents_q, e) for e in pool["Relevant Entities"] ]] if len(pool) else pool
        if len(pool) > 0:
            chosen += rnd.sample(list(pool.index), k=min(need, len(pool)))

    # 4) Fallback random
    need = k - len(chosen)
    if need > 0 and len(train_df[~train_df.index.isin(chosen)]) > 0:
        chosen += rnd.sample(list(train_df[~train_df.index.isin(chosen)].index), k=min(need, len(train_df)))

    return train_df.loc[chosen]


def build_few_shot(row, train_df, k=3, seed=0):
    ex_df = pick_exemplars(train_df, row, k=k, seed=seed)

    shots = []
    for _, r in ex_df.iterrows():
        ents = _json_safe(r["Relevant Entities"])
        gt   = _normalize_gt_for_prompt(r)
        shots.append(
            f"[QUESTION]\n{r['Question']}\n"
            f"[ENTITIES]\n{ents}\n"
            f"\nAnswer: {gt}\n"
        )
    fewshot_block = "\n".join(shots)


    out_type = str(row.get("Output Type","")).lower()
    if out_type == "date":
        type_hint = "The answer must be a date in MM/DD/YY."
    elif "week" in out_type or "day" in out_type:
        type_hint = "The answer must be exactly in the form (X weeks, Y days)."
    elif out_type == "integer":
        type_hint = "The answer must be an integer (no decimals, no units)."
    else:
        type_hint = "The answer must be a decimal number (no units)."

    ents_q = _json_safe(row["Relevant Entities"])
    question = str(row["Question"])

    prompt = (
        "You are a helpful medical calculation assistant. Your task is to compute the answer to a medical task that involves calculation.\n"
        f"{type_hint}\n"
        "Use ONLY the values in [ENTITIES] to compute the answer. Ignore unrelated prose.\n"
        "Respond with EXACTLY ONE line that begins with 'Answer:' followed by the value.\n"
        "Do NOT include templates, headings, or extra words.\n\n"
        "EXAMPLES:"
        f"{fewshot_block}\n"
        "### YOUR TASK\n"
        "[CASE]\n"
        f"{row['Patient Note']}\n\n"
        "[QUESTION]\n"
        f"{question}\n\n"
        "[ENTITIES]\n"
        f"{ents_q}\n"
    )
    return prompt


def _normalize_gt_for_prompt(row):
    """Make sure GT is in the benchmark's canonical display form for the prompt."""
    ot = str(row.get("Output Type","")).lower()
    gt = str(row["Ground Truth Answer"])
    if ot == "date":
        norm = parse_date_mmddyy(f"Answer: {gt}")
        return norm if norm else gt
    if "week" in ot or "day" in ot:
        norm = parse_weeks_days(f"Answer: {gt}")
        return norm if norm else gt

    return gt



## 7) General evaluator for prompt-based methods (Few-shot / Zero-shot / CoT)

This section defines a general-purpose evaluator that runs a prompt builder over the entire test set, generates answers, parses them, computes accuracy per category and overall, and exports a row-level CSV.

We can plug in any `build_fn` (e.g., `build_zero_shot`, `build_few_shot`, `build_cot_fewshot_fast`) to reuse the same scoring and logging logic across methods and models. This makes comparisons fair and reproducible.

We also evaluate the few_shot method here


In [6]:
def eval_prompt_full(test_df, model, tok, build_fn, *, max_new_tokens=16, log_every=100, save_path="results/preds.csv"):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    correct_by_cat, total_by_cat = defaultdict(int), defaultdict(int)
    rows_out = []
    t0 = time.time()

    for i, row in test_df.reset_index().iterrows():
        prompt = build_fn(row)
        raw = generate_with(model, tok, prompt, max_new_tokens=max_new_tokens)
        pred = extract_answer(raw, row["Output Type"])
        ok = compare_pred_to_gt(pred, row)

        cat = row["Category"]
        correct_by_cat[cat] += int(ok)
        total_by_cat[cat]   += 1

        rows_out.append({
            "row_index": row.get("Row Number", int(row["index"])),
            "category": cat,
            "calculator": row.get("Calculator Name",""),
            "output_type": row.get("Output Type",""),
            "question": row["Question"],
            "prediction_text": raw,
            "prediction_value": pred,
            "ground_truth": row["Ground Truth Answer"],
            "lower_limit": row.get("Lower Limit", row["Ground Truth Answer"]),
            "upper_limit": row.get("Upper Limit", row["Ground Truth Answer"]),
            "correct": ok,
        })

        if (i+1) % log_every == 0:
            overall_tmp = sum(correct_by_cat.values()) / max(1, sum(total_by_cat.values()))
            print(f"[{i+1}/{len(test_df)}] interim overall acc: {overall_tmp:.3f} | elapsed {time.time()-t0:.1f}s")

    pd.DataFrame(rows_out).to_csv(save_path, index=False)
    acc_by_cat = {c: correct_by_cat[c] / total_by_cat[c] for c in total_by_cat}
    overall = sum(correct_by_cat.values()) / max(1, sum(total_by_cat.values()))
    return acc_by_cat, overall, pd.DataFrame(rows_out)


In [None]:

from collections import defaultdict
import time

acc_fs_17b, overall_fs_17b, preds_fs_17b = eval_prompt_full(
    test_df=test_df,
    model=model1,
    tok=tok1,
    build_fn=lambda r: build_few_shot(r, train_df=train_df, k=3, seed=42),
    max_new_tokens=16,
    log_every=10,
    save_path="results/fewshot3_qwen3-1.7b_test_preds.csv"
)
print("1.7B | Few-shot(k=3) overall:", round(overall_fs_17b, 4))
print("Per-category:", acc_fs_17b)

[10/1047] interim overall acc: 0.000 | elapsed 21.4s
[20/1047] interim overall acc: 0.000 | elapsed 40.9s
[30/1047] interim overall acc: 0.000 | elapsed 57.2s
[40/1047] interim overall acc: 0.000 | elapsed 74.2s
[50/1047] interim overall acc: 0.020 | elapsed 89.3s
[60/1047] interim overall acc: 0.050 | elapsed 105.6s
[70/1047] interim overall acc: 0.071 | elapsed 121.2s
[80/1047] interim overall acc: 0.113 | elapsed 136.7s
[90/1047] interim overall acc: 0.144 | elapsed 153.1s
[100/1047] interim overall acc: 0.150 | elapsed 170.9s
[110/1047] interim overall acc: 0.182 | elapsed 188.4s
[120/1047] interim overall acc: 0.217 | elapsed 206.8s
[130/1047] interim overall acc: 0.215 | elapsed 224.4s
[140/1047] interim overall acc: 0.207 | elapsed 242.1s
[150/1047] interim overall acc: 0.200 | elapsed 258.8s
[160/1047] interim overall acc: 0.188 | elapsed 275.2s
[170/1047] interim overall acc: 0.194 | elapsed 292.3s
[180/1047] interim overall acc: 0.206 | elapsed 307.5s
[190/1047] interim overa

In [None]:
acc_fs_06b, overall_fs_06b, preds_fs_06b = eval_prompt_full(
    test_df=test_df,
    model=model2,
    tok=tok2,
    build_fn=lambda r: build_few_shot(r, train_df=train_df, k=3, seed=42),
    max_new_tokens=16,
    log_every=10,
    save_path="results/fewshot3_qwen3-0.6b_test_preds.csv"
)
print("0.6-B | Few-shot(k=3) overall:", round(overall_fs_06b, 4))
print("Per-category:", acc_fs_06b)

[10/1047] interim overall acc: 0.000 | elapsed 20.3s
[20/1047] interim overall acc: 0.000 | elapsed 37.5s
[30/1047] interim overall acc: 0.000 | elapsed 52.9s
[40/1047] interim overall acc: 0.025 | elapsed 68.5s
[50/1047] interim overall acc: 0.040 | elapsed 85.4s
[60/1047] interim overall acc: 0.050 | elapsed 100.6s
[70/1047] interim overall acc: 0.057 | elapsed 115.1s
[80/1047] interim overall acc: 0.050 | elapsed 129.6s
[90/1047] interim overall acc: 0.044 | elapsed 144.8s
[100/1047] interim overall acc: 0.070 | elapsed 160.9s
[110/1047] interim overall acc: 0.073 | elapsed 177.2s
[120/1047] interim overall acc: 0.092 | elapsed 195.3s
[130/1047] interim overall acc: 0.092 | elapsed 212.1s
[140/1047] interim overall acc: 0.121 | elapsed 228.5s
[150/1047] interim overall acc: 0.113 | elapsed 243.7s
[160/1047] interim overall acc: 0.106 | elapsed 258.8s
[170/1047] interim overall acc: 0.106 | elapsed 274.4s
[180/1047] interim overall acc: 0.111 | elapsed 288.6s
[190/1047] interim overa

## 8) Chain-of-Thought (CoT): one worked example + two-line constrained output

This section builds a **CoT** prompt. Example:
- `[CALCULATOR]`, `[QUESTION]`, `[ENTITIES]`
- a **Explanation:** from the dataset’s ground-truth rationale
- the **Answer:** in the exact format the grader expects

Then we present the **TASK** with its own `[CALCULATOR]/[QUESTION]/[ENTITIES]` and enforce a two line output contract:

Explanation: [short steps using ENTITIES and the formula]

Answer: [value only]

In [7]:

def _type_hint(ot: str) -> str:
    ot = (ot or "").strip().lower()
    if ot == "date": return "Final answer must be a date in MM/DD/YY."
    if "week" in ot or "day" in ot: return "Final answer must be exactly: (X weeks, Y days)."
    if ot == "integer": return "Final answer must be an integer (no decimals, no units)."
    return "Final answer must be a decimal number (no units)."

def _tidy_explanation(x: str):
    return re.sub(r'\s+', ' ', str(x or '')).strip()

def _json_safe(x):
    try: return json.dumps(json.loads(x), ensure_ascii=False)
    except Exception: return str(x)

def _normalize_gt_for_prompt(row):
    ot = str(row.get("Output Type","")).lower()
    gt = str(row["Ground Truth Answer"])
    if ot == "date": return parse_date_mmddyy(f"Answer: {gt}") or gt
    if "week" in ot or "day" in ot: return parse_weeks_days(f"Answer: {gt}") or gt
    return gt

def build_cot_fewshot_fast(row, train_df, seed=42):
    ex_df = pick_exemplars(train_df, row, k=1, seed=seed)

    shots = []
    for _, r in ex_df.iterrows():
        ents = _json_safe(r["Relevant Entities"])
        gt   = _normalize_gt_for_prompt(r)
        expl = str(r.get("Ground Truth Explanation","")).strip()
        shots.append(
            f"[CALCULATOR]\n{r.get('Calculator Name','')}\n"
            f"[QUESTION]\n{r['Question']}\n"
            f"[ENTITIES]\n{ents}\n"
            f"Explanation: {expl}\n"
            f"Answer: {gt}\n"
        )
    fewshot_block = "\n".join(shots)

    try:
        ents_q = json.dumps(json.loads(row["Relevant Entities"]), ensure_ascii=False)
    except Exception:
        ents_q = str(row["Relevant Entities"])

    type_hint = _type_hint(row.get("Output Type",""))

    prompt = (
        "You are a careful clinical calculator.\n"
        f"{type_hint}\n"
        "Use ONLY the numbers in TASK [ENTITIES]. Ignore numbers from EXAMPLES and any other text.\n"
        "Always convert units if needed. Do not invent values.\n"
        "Output exactly two lines for the TASK:\n"
        "Explanation: <short steps using ENTITIES and the formula>\n"
        "Answer: <value only>\n"
        "After printing the Answer line, STOP.\n\n"
        "### EXAMPLES\n"
        f"{fewshot_block}\n"
        "### TASK\n"
        f"[CALCULATOR]\n{row.get('Calculator Name','')}\n"
        "[QUESTION]\n" + row["Question"] + "\n\n"
        "[ENTITIES]\n" + ents_q + "\n"
    )
    return prompt

import torch

@torch.inference_mode()
def generate_cot_one_pass(model, tok, prompt, max_new_tokens=96):
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        top_p=1.0,
        repetition_penalty=1.00,
        use_cache=True,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
    )
    text = tok.decode(out[0], skip_special_tokens=True)
    return text[len(prompt):].strip() if text.startswith(prompt) else text





## 9) Fast batched evaluator (vectorized prompts → batched generate)

This section accelerates evaluation by:
1) Pre-building all prompts (so tokenizer can batch them),
2) Padding/truncating to a shared max_length per batch,
3) Running a single model.generate(...` call per batch,
4) Decoding and scoring each item, then exporting a row-level CSV.

Then the COT method is evaluated using this

In [8]:
def eval_prompt_full_fast(
    test_df, model, tok, build_fn,
    *, batch_size=16, max_new_tokens=48, log_every=100, save_path="results/preds_fast.csv"
):
    from collections import defaultdict
    import os, time, pandas as pd
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    prompts = [build_fn(row) for _, row in test_df.reset_index(drop=True).iterrows()]
    cats    = [row["Category"] for _, row in test_df.reset_index(drop=True).iterrows()]
    ots     = [row["Output Type"] for _, row in test_df.reset_index(drop=True).iterrows()]
    rows    = [row for _, row in test_df.reset_index(drop=True).iterrows()]

    correct_by_cat, total_by_cat = defaultdict(int), defaultdict(int)
    rows_out, t0 = [], time.time()

    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start+batch_size]
        ctx = model.config.max_position_embeddings
        max_input_len = ctx - max_new_tokens - 16
        enc = tok(
          batch_prompts,
          return_tensors="pt",
          padding=True,
          truncation=True,
          max_length=max_input_len
        ).to(model.device)

        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=0.0,
            top_p=1.0,
            repetition_penalty=1.00,
            use_cache=True,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
        texts = tok.batch_decode(out, skip_special_tokens=True)

        for i, full in enumerate(texts):
            prompt = batch_prompts[i]
            raw = full[len(prompt):].strip() if full.startswith(prompt) else full
            idx = start + i
            row = rows[idx]

            pred = extract_answer(raw, row["Output Type"])
            ok   = compare_pred_to_gt(pred, row)
            cat  = cats[idx]
            correct_by_cat[cat] += int(ok)
            total_by_cat[cat]   += 1

            rows_out.append({
                "row_index": row.get("Row Number", idx),
                "category": cat,
                "calculator": row.get("Calculator Name",""),
                "output_type": row.get("Output Type",""),
                "question": row["Question"],
                "prompt": prompt,
                "prediction_text": raw,
                "prediction_value": pred,
                "ground_truth": row["Ground Truth Answer"],
                "lower_limit": row.get("Lower Limit", row["Ground Truth Answer"]),
                "upper_limit": row.get("Upper Limit", row["Ground Truth Answer"]),
                "correct": ok,
            })

        done = start + len(batch_prompts)
        if done % log_every == 0:
            overall_tmp = sum(correct_by_cat.values()) / max(1, sum(total_by_cat.values()))
            print(f"[{done}/{len(prompts)}] interim overall acc: {overall_tmp:.3f} | {time.time()-t0:.1f}s")

    df = pd.DataFrame(rows_out)
    df.to_csv(save_path, index=False)
    acc_by_cat = {c: correct_by_cat[c] / total_by_cat[c] for c in total_by_cat}
    overall = sum(correct_by_cat.values()) / max(1, sum(total_by_cat.values()))
    return acc_by_cat, overall, df



In [None]:
acc_cot_17b, overall_cot_17b, preds_cot_17b = eval_prompt_full_fast(
    test_df=test_df,
    model=model1,
    tok=tok1,
    build_fn=lambda r: build_cot_fewshot_fast(r, train_df, seed=42),
    batch_size=16,                  # try 8/16/24 depending on VRAM
    max_new_tokens=48,              # short CoT
    log_every=10,
    save_path="results/cot_fast_qwen3-1.7b_test.csv",
)
print("1.7B | CoT fast overall:", round(overall_cot_17b, 4))
print("Per-category:", acc_cot_17b)


[80/1047] interim overall acc: 0.025 | 56.6s
[160/1047] interim overall acc: 0.075 | 102.7s
[240/1047] interim overall acc: 0.075 | 141.4s
[320/1047] interim overall acc: 0.072 | 192.2s
[400/1047] interim overall acc: 0.065 | 231.4s
[480/1047] interim overall acc: 0.065 | 278.7s
[560/1047] interim overall acc: 0.059 | 327.1s
[640/1047] interim overall acc: 0.056 | 380.8s
[720/1047] interim overall acc: 0.054 | 421.0s
[800/1047] interim overall acc: 0.051 | 485.1s
[880/1047] interim overall acc: 0.048 | 537.2s
[960/1047] interim overall acc: 0.048 | 587.9s
[1040/1047] interim overall acc: 0.044 | 641.7s
1.7B | CoT fast overall: 0.0439
Per-category: {'lab': 0.021406727828746176, 'risk': 0.06666666666666667, 'physical': 0.025, 'severity': 0.125, 'diagnosis': 0.08333333333333333, 'date': 0.0, 'dosage': 0.05}


In [None]:
print(acc_cot_06b)
print(overall_cot_06b)

{'lab': 0.039755351681957186, 'risk': 0.0375, 'physical': 0.0375, 'severity': 0.05, 'diagnosis': 0.08333333333333333, 'date': 0.0, 'dosage': 0.075}
0.041069723018147083


In [None]:
acc_cot_06b, overall_cot_06b, preds_cot_06b = eval_prompt_full_fast(
    test_df=test_df,
    model=model2,
    tok=tok2,
    build_fn=lambda r: build_cot_fewshot_fast(r, train_df, seed=42),
    batch_size=16,                  # try 8/16/24 depending on VRAM
    max_new_tokens=48,              # short CoT
    log_every=10,
    save_path="results/cot_fast_qwen3-0.6b_test.csv",
)
print("0.6B | CoT fast overall:", round(overall_cot_06b, 4))
print("Per-category:", acc_cot_17b)

[80/1047] interim overall acc: 0.025 | 42.4s
[160/1047] interim overall acc: 0.075 | 76.8s
[240/1047] interim overall acc: 0.054 | 108.7s
[320/1047] interim overall acc: 0.059 | 145.5s
[400/1047] interim overall acc: 0.068 | 176.7s
[480/1047] interim overall acc: 0.060 | 210.9s
[560/1047] interim overall acc: 0.057 | 245.6s
[640/1047] interim overall acc: 0.056 | 284.7s
[720/1047] interim overall acc: 0.054 | 316.4s
[800/1047] interim overall acc: 0.049 | 360.8s
[880/1047] interim overall acc: 0.044 | 399.6s
[960/1047] interim overall acc: 0.044 | 437.0s
[1040/1047] interim overall acc: 0.041 | 476.2s
0.6B | CoT fast overall: 0.0439
Per-category: {'lab': 0.021406727828746176, 'risk': 0.06666666666666667, 'physical': 0.025, 'severity': 0.125, 'diagnosis': 0.08333333333333333, 'date': 0.0, 'dosage': 0.05}


## 10) LoRA setup

Prepares **parameter-efficient finetuning** for both models. Using QLoRA  by default to stay within Colab memory

In [9]:
!pip -q install -U "transformers>=4.51.0" "accelerate>=0.34.2" "bitsandbytes>=0.45.0" peft datasets

import os, math, json, re, time, random
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m503.6/503.6 kB[0m [31m37.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 MB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pylibcudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.
cudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.[0m[31m
[0m

In [10]:
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

OUTDIR_06B = "ft_qwen3_06b_qlora_fewshot"
OUTDIR_17B = "ft_qwen3_17b_qlora_fewshot"

USE_QLoRA_06B = True
USE_QLoRA_17B = True

## 11) Build SFT pairs, tokenize with label-masking, and prep QLoRA-ready models

We construct a dataset where each example is:
- **prompt**: an instruction created with `build_few_shot(...)` (k=3), containing EXAMPLES and the target `[CASE]/[QUESTION]/[ENTITIES]`.
- **target**: a single  line: `Answer: <ground-truth>\n`.  
  This teaches the exact output string the grader expects (numeric/date/gestational).

### QLoRA/LoRA parameter notes
- **Rank (`r`)**: controls adapter size; higher rank = more expressive, but heavier. We use `r=16` as a balance.  
- **Alpha (`α`)**: scaling factor; effective update strength is `α / r` (here = 2.0).  
- **Dropout**: regularizes adapter updates (`0.05` default) to avoid overfitting on repetitive SFT data.  
- **Target modules**: we adapt the attention projections (`q/k/v/o`) since they govern reasoning and output formatting.  
These settings keep training lightweight on Colab while still giving models the capacity to learn stable calculation and formatting behaviors.



In [11]:
def make_direct_answer_pairs(train_df, seed=SEED):
    rows = []
    for _, row in train_df.reset_index(drop=True).iterrows():
        prompt = build_few_shot(row, train_df, k=3, seed=seed)
        ans = _normalize_gt_for_prompt(row)
        target = f"Answer: {ans}\n"
        rows.append({"prompt": prompt, "target": target})
    return Dataset.from_list(rows)

def tokenize_mask_direct(tokenizer, ex, max_len=768):
    prompt_ids = tokenizer(ex["prompt"], add_special_tokens=False).input_ids
    target_ids = tokenizer(ex["target"], add_special_tokens=False).input_ids

    input_ids = prompt_ids + target_ids
    labels    = [-100] * len(prompt_ids) + target_ids

    # truncate from left if too long
    if len(input_ids) > max_len:
        overflow = len(input_ids) - max_len
        input_ids = input_ids[overflow:]
        labels    = labels[overflow:]

    # pad to max_len
    pad_len = max_len - len(input_ids)
    if pad_len > 0:
        input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
        labels    = labels + [-100] * pad_len
        attn_mask = [1] * len(input_ids[:-pad_len]) + [0] * pad_len
    else:
        attn_mask = [1] * max_len

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attn_mask,
    }

def build_tokenized_dataset(tokenizer, ds, max_len=768):
    return ds.map(lambda ex: tokenize_mask_direct(tokenizer, ex, max_len=max_len),
                  remove_columns=list(ds.features), batched=False)

def load_qwen_for_peft(model_id: str, qlora: bool = True):
    if qlora:
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    else:
        bnb = None

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=None if qlora else torch.float16,
        quantization_config=bnb,
        trust_remote_code=True,
    )
    if qlora:
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    return tok, model

def attach_lora(model,
                r=16, alpha=32, dropout=0.05,
                target_modules=("q_proj","k_proj","v_proj","o_proj"),
                bias="none"):
    cfg = LoraConfig(
        r=r, lora_alpha=alpha, lora_dropout=dropout,
        target_modules=list(target_modules),
        task_type="CAUSAL_LM",
        bias=bias,
    )
    model = get_peft_model(model, cfg)
    model.print_trainable_parameters()
    return model

if tok2.pad_token is None:
    tok2.pad_token = tok2.eos_token

from transformers import DataCollatorForLanguageModeling
collator06 = DataCollatorForLanguageModeling(tokenizer=tok2, mlm=False)


In [12]:
train_sft = make_direct_answer_pairs(train_df)

KeyboardInterrupt: 

## 12) Fine-Tune Qwen3 Models with QLoRA and Evaluate

In this section we run parameter-efficient fine-tuning** on both selected models (0.6B and 1.7B) using QLoRA.  
- **TrainingArguments** are tuned for Colab T4 (effective batch size ≈32, 1 epoch, cosine schedule).  
- Adapters are attached only to **attention projections (`q/k/v/o`)** to keep training light.  
- Adapters and tokenizers are saved to Google Drive for persistence across sessions.  
- After training, we reload the adapters into the frozen base models using `PeftModel.from_pretrained(...)`.  
- Finally, we evaluate the finetuned models with the same few-shot evaluation pipeline used earlier to compare against zero-/few-shot baselines.


In [None]:
from transformers import TrainingArguments, default_data_collator

# 1. Tokenize dataset with max_len=512
ds06 = train_sft.map(
    lambda ex: tokenize_mask_direct(tok2, ex, max_len=512),
    remove_columns=list(train_sft.features),
    batched=False
).with_format("torch")

# 2. Training args tuned for T4
args06_fast = TrainingArguments(
    output_dir=OUTDIR_06B,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,    # 4x8 = 32 effective batch
    learning_rate=1e-4,
    num_train_epochs=1,               # only 1 epoch
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_steps=1,                  # log every step
    save_strategy="epoch",
    bf16=torch.cuda.is_available(),
    gradient_checkpointing=False,     # off for speed
    optim="paged_adamw_8bit",
    report_to="none",
    disable_tqdm=False,
)

# 3. Collator (handles padding safely)
collator06 = default_data_collator

# 4. Trainer
trainer06 = Trainer(
    model=model06,
    args=args06_fast,
    train_dataset=ds06,
    data_collator=collator06,
    tokenizer=tok2,
)

# 5. Train
trainer06.train()



Map:   0%|          | 0/10053 [00:00<?, ? examples/s]

  trainer06 = Trainer(


Step,Training Loss
1,1.3277
2,1.1208
3,1.1645
4,1.2015
5,1.1447
6,1.1255
7,1.0825
8,1.0901
9,0.929
10,0.9605


Step,Training Loss
1,1.3277
2,1.1208
3,1.1645
4,1.2015
5,1.1447
6,1.1255
7,1.0825
8,1.0901
9,0.929
10,0.9605


TrainOutput(global_step=315, training_loss=0.7285082096145267, metrics={'train_runtime': 11602.7665, 'train_samples_per_second': 0.866, 'train_steps_per_second': 0.027, 'total_flos': 1.4226247756283904e+16, 'train_loss': 0.7285082096145267, 'epoch': 1.0})

In [None]:
from google.colab import drive
drive.mount('/content/drive')
OUTDIR_06B = "/content/drive/MyDrive/qwen3_06b_lora"
trainer06.model.save_pretrained(f"{OUTDIR_06B}/adapter")
tok2.save_pretrained(f"{OUTDIR_06B}/tokenizer")

Mounted at /content/drive


('/content/drive/MyDrive/qwen3_06b_lora/tokenizer/tokenizer_config.json',
 '/content/drive/MyDrive/qwen3_06b_lora/tokenizer/special_tokens_map.json',
 '/content/drive/MyDrive/qwen3_06b_lora/tokenizer/chat_template.jinja',
 '/content/drive/MyDrive/qwen3_06b_lora/tokenizer/vocab.json',
 '/content/drive/MyDrive/qwen3_06b_lora/tokenizer/merges.txt',
 '/content/drive/MyDrive/qwen3_06b_lora/tokenizer/added_tokens.json',
 '/content/drive/MyDrive/qwen3_06b_lora/tokenizer/tokenizer.json')

In [None]:
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from transformers import TrainingArguments, default_data_collator, Trainer

OUTDIR_17B = "ft_qwen3_17b_qlora_fewshot_fast"

# 0) Safety: pad token
if tok1.pad_token is None:
    tok1.pad_token = tok1.eos_token

# 1) Prepare quantized base for training
model17 = prepare_model_for_kbit_training(model1, use_gradient_checkpointing=False)

# 2) Attach LoRA (attention-only to keep it fast)
lora_cfg17 = LoraConfig(
    r=12, lora_alpha=24, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
    task_type="CAUSAL_LM", bias="none",
)
model17 = get_peft_model(model17, lora_cfg17)
model17.print_trainable_parameters()


ds17 = train_sft.map(
    lambda ex: tokenize_mask_direct(tok1, ex, max_len=384),
    remove_columns=list(train_sft.features),
    batched=False
).with_format("torch")

# 4) Fast TrainingArguments (1 epoch, effective batch 32)
args17_fast = TrainingArguments(
    output_dir=OUTDIR_17B,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_steps=1,
    save_strategy="epoch",
    bf16=torch.cuda.is_available(),
    gradient_checkpointing=False,
    optim="paged_adamw_8bit",
    report_to="none",
    disable_tqdm=False,
)

# 5) Train
trainer17 = Trainer(
    model=model17,
    args=args17_fast,
    train_dataset=ds17,
    data_collator=default_data_collator,
    tokenizer=tok1,
)
trainer17.train()

# 6) Save adapter + tokenizer
trainer17.model.save_pretrained(f"{OUTDIR_17B}/adapter")
tok1.save_pretrained(f"{OUTDIR_17B}/tokenizer")
print("Saved:", OUTDIR_17B)



trainable params: 4,816,896 || all params: 1,725,391,872 || trainable%: 0.2792


Map:   0%|          | 0/10053 [00:00<?, ? examples/s]

  trainer17 = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss
1,2.4466
2,2.1308
3,1.9486
4,2.2692
5,2.1961
6,2.1591
7,2.2078
8,2.0128
9,1.7755
10,1.8733


Step,Training Loss
1,2.4466
2,2.1308
3,1.9486
4,2.2692
5,2.1961
6,2.1591
7,2.2078
8,2.0128
9,1.7755
10,1.8733


Saved: ft_qwen3_17b_qlora_fewshot_fast


In [None]:
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/qwen_finetunes/qwen3_17b_adapter"
trainer17.model.save_pretrained(f"{SAVE_DIR}/adapter")
tok1.save_pretrained(f"{SAVE_DIR}/tokenizer")
print("Saved to Google Drive:", SAVE_DIR)

Mounted at /content/drive
Saved to Google Drive: /content/drive/MyDrive/qwen_finetunes/qwen3_17b_adapter


In [15]:
import os, glob
from peft import PeftModel
from google.colab import drive
drive.mount('/content/drive')


os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"


#OUTDIR_06B = "/content/drive/MyDrive/qwen3_06b_lora"                 # <- adjust if different
OUTDIR_17B = "/content/drive/MyDrive/qwen_finetunes/qwen3_17b_adapter"  # <- adjust if different

#ADAPTER_06B = f"{OUTDIR_06B}/adapter"
ADAPTER_17B = f"{OUTDIR_17B}/adapter"


#print("0.6B adapter files:", glob.glob(f"{ADAPTER_06B}/*"))
print("1.7B adapter files:", glob.glob(f"{ADAPTER_17B}/*"))

#assert os.path.isfile(f"{ADAPTER_06B}/adapter_config.json"), f"Missing adapter_config.json in {ADAPTER_06B}"
assert os.path.isfile(f"{ADAPTER_17B}/adapter_config.json"), f"Missing adapter_config.json in {ADAPTER_17B}"


if tok2.pad_token is None: tok2.pad_token = tok2.eos_token
if tok1.pad_token is None: tok1.pad_token = tok1.eos_token


model2_f = PeftModel.from_pretrained(model2, ADAPTER_06B, is_trainable=False, local_files_only=True).eval()
tok2_f   = tok2

model1_f = PeftModel.from_pretrained(model1, ADAPTER_17B, is_trainable=False, local_files_only=True).eval()
tok1_f   = tok1

# 5) Evaluate (few-shot, single-line Answer)
#SEED = 42
#acc_fs_06b, overall_fs_06b, _ = eval_prompt_full_fast(
 #   test_df=test_df,
 #   model=model2_f, tok=tok2_f,
 #   build_fn=lambda r: build_few_shot(r, train_df, k=3, seed=SEED),
 #   batch_size=16, max_new_tokens=16,
 #   save_path="results/fewshot_qwen3-0.6b_finetuned.csv",
#)
#print("0.6B | Few-shot finetuned overall:", round(overall_fs_06b, 4))
#print("Per-category:", acc_fs_06b)

acc_fs_17b, overall_fs_17b, _ = eval_prompt_full_fast(
    test_df=test_df,
    model=model1_f, tok=tok1_f,
    build_fn=lambda r: build_few_shot(r, train_df, k=3, seed=SEED),
    batch_size=16, max_new_tokens=16,
    save_path="results/fewshot_qwen3-1.7b_finetuned.csv",
)
print("1.7B | Few-shot finetuned overall:", round(overall_fs_17b, 4))
print("Per-category:", acc_fs_17b)





Mounted at /content/drive
1.7B adapter files: ['/content/drive/MyDrive/qwen_finetunes/qwen3_17b_adapter/adapter/README.md', '/content/drive/MyDrive/qwen_finetunes/qwen3_17b_adapter/adapter/adapter_model.safetensors', '/content/drive/MyDrive/qwen_finetunes/qwen3_17b_adapter/adapter/adapter_config.json']
[400/1047] interim overall acc: 0.258 | 264.8s
[800/1047] interim overall acc: 0.241 | 505.5s
1.7B | Few-shot finetuned overall: 0.2073
Per-category: {'lab': 0.10091743119266056, 'risk': 0.1375, 'physical': 0.4083333333333333, 'severity': 0.15, 'diagnosis': 0.4666666666666667, 'date': 0.16666666666666666, 'dosage': 0.075}
