In [1]:
!pip install -q transformers peft bitsandbytes trl accelerate datasets sentence-transformers rank_bm25 faiss-cpu scikit-learn pandas numpy

print("✅ Dependencies installed")


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m465.5/465.5 kB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m107.7 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m72.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m42.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m90.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m 

In [6]:
# =========================
# DiagnosisGPT v3.3 - Full updated notebook cell (drop-in)
# - Paste and run in one cell.
# =========================

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json, gc, random, re, pickle, time, math
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import faiss
import torch

from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, LogitsProcessorList
from peft import LoraConfig, get_peft_model

torch.cuda.empty_cache(); gc.collect()

print("Device:", "cuda" if torch.cuda.is_available() else "cpu")

# --------------------
# REPRODUCIBILITY
# --------------------
# Seed every global RNG this notebook relies on (Python `random`, NumPy, torch on all
# devices) so Restart & Run All reproduces the same shuffles and weight initialisation.
# Calls that take an explicit seed (e.g. Dataset.shuffle(seed=42)) are unaffected.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --------------------
# OPTIONS / FLAGS (edit as needed)
# --------------------
RUN_BUILD_RETRIEVER = True       # build KB + retriever (loaded from CACHE_DIR when fully cached)
RUN_BUILD_SYNTHETIC = True       # build / cache synthetic dataset
RUN_BUILD_REAL = True            # build / cache real dataset (MedDialog)
RUN_COMBINE = True               # combine synthetic + real into final dataset
RUN_TRAIN = True                 # set True only when you want to actually train (heavy)
ENABLE_LLM_SCORING = False       # if True, use LLM to compute candidate logprobs (slower)
SAVE_TRANSCRIPT = True           # save JSON transcript of consultations
EARLY_STOPPING_ENTROPY_EPS = 0.01  # if entropy decrease < eps across two rounds -> stop
CACHE_DIR = "/kaggle/working/cod_v3_3_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# --------------------
# CONFIG
# --------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"   # change if needed
TARGET_SYNTHETIC = 20000        # max synthetic (CoD) training rows
TARGET_REAL = 5000              # max real-world (MedDialog) training rows
RETRIEVER_TOPK = 15             # diseases retrieved per real-world dialogue
ALPHA_RETRIEVER = 0.6           # dense-score weight in hybrid_retrieve (1 - alpha goes to BM25)
PER_DEVICE_BATCH = 1
GRAD_ACCUM = 4
TRAIN_MAX_STEPS = 100
LEARNING_RATE = 2e-4
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
CONFIDENCE_THRESHOLD = 0.70
INQUIRY_THRESHOLD = 0.50
MAX_CONSULTATION_ROUNDS = 5

# Paths (retriever cache)
KB_DF_PATH = os.path.join(CACHE_DIR, "kb_df.pkl")
BM25_PATH = os.path.join(CACHE_DIR, "bm25.pkl")
BM25_CORPUS_PATH = os.path.join(CACHE_DIR, "bm25_corpus.pkl")
FAISS_INDEX_PATH = os.path.join(CACHE_DIR, "faiss_index.bin")
DISEASE_SYMPTOM_MAP_PATH = os.path.join(CACHE_DIR, "disease_symptom_map.pkl")

# Paths (dataset cache)
SYNTHETIC_PATH = os.path.join(CACHE_DIR, "synthetic_data.jsonl")
REAL_PATH = os.path.join(CACHE_DIR, "real_data.jsonl")
COMBINED_PATH = os.path.join(CACHE_DIR, "combined_data.jsonl")
FINAL_DATASET_DIR = os.path.join(CACHE_DIR, "final_dataset")

# Paths (model outputs)
MODEL_OUTPUT_DIR = "diagnosis_gpt_v3_3"
FINAL_MODEL_DIR = os.path.join(MODEL_OUTPUT_DIR, "model_final")

REAL_WORLD_CSV = "/kaggle/input/meddialog-new2/meddialog_all_dialogues_all.csv"  # update if needed
TRANSCRIPT_SAVE_DIR = os.path.join(CACHE_DIR, "consult_transcripts")
os.makedirs(TRANSCRIPT_SAVE_DIR, exist_ok=True)

# Utility
def dedup(items):
    """Return ``items`` as a list with repeats removed, keeping first-occurrence order.

    Items must be hashable. Any iterable (including generators) is accepted.
    """
    # dict preserves insertion order and keeps the first key among equal ones.
    return list(dict.fromkeys(items))

# =========================
# STEP A: Build / Load KB + Retriever (BM25 + FAISS + sentence-transformer)
# =========================
# Dense encoder shared by the cached and the freshly-built retriever.
EMBEDDER_NAME = "sentence-transformers/all-mpnet-base-v2"

# Every artifact the retriever needs. A partially-populated cache is treated as
# "no cache" instead of failing halfway through loading.
RETRIEVER_CACHE_FILES = (
    KB_DF_PATH,
    BM25_PATH,
    BM25_CORPUS_PATH,
    FAISS_INDEX_PATH,
    DISEASE_SYMPTOM_MAP_PATH,
)


def _retriever_cache_complete():
    """Return True when every file in RETRIEVER_CACHE_FILES exists."""
    return all(os.path.exists(path) for path in RETRIEVER_CACHE_FILES)


def _load_pickle(path):
    """Unpickle ``path``.

    SECURITY: pickle executes arbitrary code on load. These files are written by this
    notebook; never point the cache paths at files from an untrusted source.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_retriever_from_cache():
    """Load the artifacts written by the build branch below.

    Returns:
        (kb_df, disease_symptom_map, bm25, tokenized_corpus, faiss_index, embedder)
    """
    return (
        pd.read_pickle(KB_DF_PATH),
        _load_pickle(DISEASE_SYMPTOM_MAP_PATH),
        _load_pickle(BM25_PATH),
        _load_pickle(BM25_CORPUS_PATH),
        faiss.read_index(FAISS_INDEX_PATH),
        SentenceTransformer(EMBEDDER_NAME, device="cpu"),
    )


if _retriever_cache_complete():
    print("Loading KB & retriever from cache...")
    kb_df, disease_symptom_map, bm25, tokenized_corpus, index, embedder = _load_retriever_from_cache()
    print("KB & retriever loaded.")
elif RUN_BUILD_RETRIEVER:
    print("Downloading Disease DB (FreedomIntelligence/Disease_Database) and building retriever...")
    kb_ds = load_dataset("FreedomIntelligence/Disease_Database", "en", split="train")
    kb_df = pd.DataFrame(kb_ds)

    # Normalise the schema to two string columns: "symptom_text" and "disease".
    if "symptom_text" not in kb_df.columns:
        if "common_symptom" in kb_df.columns:
            kb_df = kb_df.rename(columns={"common_symptom": "symptom_text"})
        else:
            possible = [c for c in kb_df.columns if "symptom" in str(c).lower()]
            if possible:
                kb_df = kb_df.rename(columns={possible[0]: "symptom_text"})
            else:
                kb_df["symptom_text"] = ""
    kb_df["symptom_text"] = kb_df["symptom_text"].fillna("").astype(str)
    if "disease" not in kb_df.columns:
        if "title" in kb_df.columns:
            kb_df = kb_df.rename(columns={"title": "disease"})
        else:
            kb_df["disease"] = kb_df.index.astype(str)
    kb_df["disease"] = kb_df["disease"].astype(str)

    # disease -> lower-cased symptom text (used for negative-finding filtering).
    # When a disease name repeats, the last row wins.
    disease_symptom_map = dict(zip(kb_df["disease"], kb_df["symptom_text"].str.lower()))

    print("Building BM25 (tokenized corpus)...")
    tokenized_corpus = [re.findall(r"\w+", doc.lower()) for doc in kb_df["symptom_text"]]
    bm25 = BM25Okapi(tokenized_corpus)

    print("Loading embedder (all-mpnet-base-v2) and encoding KB...")
    embedder = SentenceTransformer(EMBEDDER_NAME, device="cpu")
    symptom_texts = kb_df["symptom_text"].tolist()
    train_embeddings = embedder.encode(symptom_texts, convert_to_numpy=True, show_progress_bar=True, batch_size=32)
    faiss.normalize_L2(train_embeddings)  # unit vectors: inner product == cosine similarity
    dim = train_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(train_embeddings)

    # cache
    kb_df.to_pickle(KB_DF_PATH)
    for path, obj in (
        (DISEASE_SYMPTOM_MAP_PATH, disease_symptom_map),
        (BM25_PATH, bm25),
        (BM25_CORPUS_PATH, tokenized_corpus),
    ):
        with open(path, "wb") as f:
            pickle.dump(obj, f)
    faiss.write_index(index, FAISS_INDEX_PATH)
    print("KB & retriever cached.")
else:
    raise RuntimeError("Retriever build skipped and cache not found. Set RUN_BUILD_RETRIEVER=True.")

# hybrid_retrieve maps FAISS row ids directly to kb_df positions, so they must line up.
assert index.ntotal == len(kb_df), "FAISS index and KB are out of sync; delete CACHE_DIR and rebuild."

# hybrid retrieve
def hybrid_retrieve(user_symptoms, k=5, alpha=ALPHA_RETRIEVER, negative_findings=None):
    """Return up to ``k`` distinct disease names ranked by a BM25 + dense hybrid score.

    Args:
        user_symptoms: free-text symptom description. ``None`` or blank text is
            tolerated and replaced by the generic query "symptoms".
        k: maximum number of diseases to return (``k <= 0`` returns ``[]``).
        alpha: weight of the dense (FAISS) score; ``1 - alpha`` goes to BM25. Each
            score vector is scaled so its maximum is 1 (when positive) before mixing.
        negative_findings: symptoms the patient denies. A disease is skipped when any
            of them occurs (case-insensitive substring) in its KB symptom text.
            ``None`` and blank entries are ignored.

    Returns:
        list[str]: disease names, best first, without duplicates.

    Uses the module-level ``bm25``, ``embedder``, ``index``, ``kb_df`` and
    ``disease_symptom_map`` built in STEP A.
    """
    import warnings

    query = "" if user_symptoms is None else str(user_symptoms).strip().lower()
    if not query:
        query = "symptoms"  # generic fallback so BM25 / the encoder always get a query

    # Blank negatives must be dropped: "" is a substring of every string and would
    # otherwise filter out the entire KB.
    negatives = [
        str(neg).strip().lower()
        for neg in (negative_findings or [])
        if neg is not None and str(neg).strip()
    ]

    n_docs = len(kb_df)
    if k <= 0 or n_docs == 0:
        return []

    # ---------- BM25 (lexical), scaled so the best document scores 1 ----------
    bm25_scores = np.array(bm25.get_scores(re.findall(r"\w+", query)), dtype=float)
    if bm25_scores.max() > 0:
        bm25_scores /= (bm25_scores.max() + 1e-12)

    # ---------- FAISS (dense cosine), scaled the same way ----------
    faiss_scores = np.zeros(n_docs)
    try:
        query_emb = embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)
        sims, ids = index.search(query_emb, n_docs)
        hit = ids[0] >= 0  # FAISS pads missing results with id -1
        faiss_scores[ids[0][hit]] = sims[0][hit]
        if faiss_scores.max() > 0:
            faiss_scores /= (faiss_scores.max() + 1e-12)
    except Exception:
        # Best-effort: rank with BM25 alone if the dense side fails, but say so.
        warnings.warn("hybrid_retrieve: dense retrieval failed; falling back to BM25 only.", RuntimeWarning)
        faiss_scores = np.zeros(n_docs)

    hybrid_scores = alpha * faiss_scores + (1 - alpha) * bm25_scores

    # ---------- Collect the top-k distinct diseases ----------
    diseases = kb_df["disease"]
    candidates, seen = [], set()
    for i in np.argsort(hybrid_scores)[::-1]:
        disease = diseases.iat[i]
        if disease in seen:
            continue  # repeated KB rows must not use up the k slots
        seen.add(disease)
        symptom_text = disease_symptom_map.get(disease, "")
        if any(neg in symptom_text for neg in negatives):
            continue
        candidates.append(disease)
        if len(candidates) >= k:
            break
    return candidates


print("Hybrid retriever ready.")

# =========================
# STEP B: Build Synthetic Dataset (CoD) -> up to TARGET_SYNTHETIC rows
# =========================
# provided_symptom entries are [symptom, flag] pairs; the flag is compared as a lower-cased string.
_SYMPTOM_PRESENT_FLAGS = ("true", "1", "yes")
_SYMPTOM_ABSENT_FLAGS = ("false", "0", "no")


def _cod_row_for_gpt_turn(turn, current_symptoms, negative_symptoms):
    """Build one SFT row (CoD++ 5-step JSON target) for a gpt turn.

    Args:
        turn: gpt turn dict; reads "confidence_assessment", "decision",
            "disease_recall" and "value".
        current_symptoms / negative_symptoms: symptoms reported so far in the conversation.

    Returns:
        {"input": str, "output": JSON str}, or None when the turn recalls no diseases.
    """
    dist = turn.get("confidence_assessment") or {}
    if not isinstance(dist, dict):
        dist = {}
    recall = turn.get("disease_recall") or []
    if not isinstance(recall, list):
        recall = []
    candidates = [str(d["disease"]) for d in recall if isinstance(d, dict) and d.get("disease")]
    if not candidates:
        return None

    if turn.get("decision", "") == "diagnosis":
        # Diagnose the disease with the highest confidence score.
        top_disease = max(dist, key=dist.get) if dist else "Unknown"
        action = {"judge": True, "disease": top_disease}
    else:
        # Not a diagnosis: the turn's text (a follow-up question) is the target.
        action = {"judge": False, "symptom": turn.get("value", "")}

    # CoD++ richer output
    output_json = {
        "step_1_symptom_abstraction": {
            "extracted_symptoms": current_symptoms[:6],
            "negative_findings": negative_symptoms[:3]
        },
        "step_2_candidate_recall": {
            "retrieved_diseases": candidates[:10],
            "retrieval_method": "Hybrid (BM25 + Dense Embedding)"
        },
        "step_3_diagnostic_reasoning": {
            "comparison": f"Comparing {len(candidates)} diseases with {len(current_symptoms)} symptoms",
            "top_3_candidates": candidates[:3]
        },
        "step_4_confidence_assessment": {
            "scores": dist,
            "max_confidence": max(dist.values()) if dist else 0.0,
        },
        "step_5_decision_making": action
    }

    input_text = (
        f"Patient Symptoms: {', '.join(current_symptoms)}.\n"
        f"Negative Findings: {', '.join(negative_symptoms) if negative_symptoms else 'None'}.\n"
        f"Candidates: {', '.join(candidates)}."
    )
    return {"input": input_text, "output": json.dumps(output_json)}


def _cod_rows_from_conversation(conversations):
    """Yield one SFT row per usable gpt turn of a CoD conversation.

    Human turns accumulate positive / negative symptoms, so every gpt turn is paired
    with everything reported before it. Malformed (non-dict) turns are skipped.
    """
    current_symptoms, negative_symptoms = [], []
    for turn in conversations:
        if not isinstance(turn, dict):
            continue
        if turn.get("from") == "human":
            provided = turn.get("provided_symptom")
            for sym_pair in provided if isinstance(provided, list) else []:
                if not isinstance(sym_pair, (list, tuple)) or len(sym_pair) < 2:
                    continue
                flag = str(sym_pair[1]).lower()
                if flag in _SYMPTOM_PRESENT_FLAGS:
                    current_symptoms.append(str(sym_pair[0]))
                elif flag in _SYMPTOM_ABSENT_FLAGS:
                    negative_symptoms.append(str(sym_pair[0]))
        elif turn.get("from") == "gpt":
            row = _cod_row_for_gpt_turn(turn, current_symptoms, negative_symptoms)
            if row is not None:
                yield row


def _iter_cod_rows(entries):
    """Yield SFT rows from dataset entries, decoding string-encoded CoD_conversations."""
    for entry in entries:
        conversations = entry.get("CoD_conversations")
        if isinstance(conversations, str):
            try:
                conversations = json.loads(conversations)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
        if conversations:
            yield from _cod_rows_from_conversation(conversations)


synthetic_data = []
if os.path.exists(SYNTHETIC_PATH):
    print("Loading synthetic cache...")
    with open(SYNTHETIC_PATH, "r", encoding="utf-8") as f:
        synthetic_data = [json.loads(line) for line in f]
    print("Loaded synthetic:", len(synthetic_data))
elif RUN_BUILD_SYNTHETIC:
    print("Building synthetic dataset from FreedomIntelligence/CoD-PatientSymDisease...")
    cod_ds = load_dataset("FreedomIntelligence/CoD-PatientSymDisease", "en", split="train")
    shuffled_cod = cod_ds.shuffle(seed=42)

    # One conversation yields several rows, so cap per row (not per conversation)
    # to land exactly on TARGET_SYNTHETIC.
    for row in _iter_cod_rows(tqdm(shuffled_cod, desc="Parsing Synthetic")):
        if len(synthetic_data) >= TARGET_SYNTHETIC:
            break
        synthetic_data.append(row)

    # Save
    with open(SYNTHETIC_PATH, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in synthetic_data)
    print("Saved synthetic:", len(synthetic_data))
else:
    print("No synthetic dataset available. Enable RUN_BUILD_SYNTHETIC to create it.")
print("Synthetic dataset size:", len(synthetic_data))

# =========================
# STEP C: Build Real-world Dataset (MedDialog CSV) using retriever-based symptom abstraction
# =========================
# Number of retrieved candidates shown in the prompt. The symptom abstraction, the
# inferred label and the score distribution are all restricted to these, so the
# training target never names a disease the input did not list.
REAL_SHOWN_CANDIDATES = 5


def parse_turns(turns_json):
    """Split a MedDialog ``turns_json`` string into ``(patient_text, doctor_text)``.

    Each turn is expected to be a dict with "speaker" and "utterance"; a turn is the
    patient's / doctor's when its speaker contains that word (case-insensitive).
    Returns ``(None, None)`` for malformed rows so they can be dropped with ``dropna``.
    """
    try:
        turns = json.loads(turns_json)
        patient_text = " ".join(t["utterance"] for t in turns if "patient" in str(t.get("speaker", "")).lower())
        doctor_text = " ".join(t["utterance"] for t in turns if "doctor" in str(t.get("speaker", "")).lower())
    except (TypeError, ValueError, KeyError, AttributeError):
        return None, None
    return patient_text.strip(), doctor_text.strip()


def _infer_true_disease(candidates, doctor_text):
    """Return the first candidate named in the doctor's reply, else "Unknown" (very heuristic).

    Matching is case-insensitive and whole-word, so e.g. "Flu" does not match "fluid".
    """
    reply = doctor_text.lower()
    for cand in candidates:
        name = cand.lower().strip()
        if name and re.search(r"(?<!\w)" + re.escape(name) + r"(?!\w)", reply):
            return cand
    return "Unknown"


def _score_distribution(candidates, true_disease):
    """0.8 on ``true_disease`` with 0.2 split over the rest; uniform when it is "Unknown"."""
    if true_disease == "Unknown":
        return {c: 1 / len(candidates) for c in candidates} if candidates else {}
    dist = {true_disease: 0.8}
    rest = 0.2 / (len(candidates) - 1) if len(candidates) > 1 else 0.0
    dist.update((c, rest) for c in candidates if c != true_disease)
    return dist


def _real_row(src, tgt, candidates):
    """Build one SFT row from a patient/doctor exchange and its retrieved candidates."""
    shown = candidates[:REAL_SHOWN_CANDIDATES]

    # Retriever-based symptom abstraction: KB symptoms of the shown candidates.
    symptoms = []
    for cand in shown:
        text = disease_symptom_map.get(cand, "")
        symptoms.extend(s.strip() for s in re.split(r"[;,\n\.]", text) if s.strip())
    symptoms = list(dict.fromkeys(symptoms))[:6]

    input_text = (
        f"Patient Symptoms: {', '.join(symptoms) if symptoms else src[:256]}.\n"
        "Negative Findings: None.\n"
        f"Candidates: {', '.join(shown)}."
    )

    true_disease = _infer_true_disease(shown, tgt)
    dist = _score_distribution(shown, true_disease)
    output_json = {
        "step_1_symptom_abstraction": {"extracted_symptoms": symptoms},
        "step_2_candidate_recall": {"retrieved_diseases": shown},
        "step_3_diagnostic_reasoning": {"doctor_text_summary": tgt[:400]},
        "step_4_confidence_assessment": {"scores": dist, "max_confidence": max(dist.values()) if dist else 0.0},
        "step_5_decision_making": {"judge": True, "disease": true_disease}
    }
    return {"input": input_text, "output": json.dumps(output_json)}


real_data = []
if os.path.exists(REAL_PATH):
    print("Loading real cache...")
    with open(REAL_PATH, "r", encoding="utf-8") as f:
        real_data = [json.loads(line) for line in f]
    print("Loaded real:", len(real_data))
elif RUN_BUILD_REAL:
    print("Loading MedDialog CSV:", REAL_WORLD_CSV)
    if not os.path.exists(REAL_WORLD_CSV):
        print("Real-world CSV not found at path. Skipping real data creation.")
    else:
        med_df = pd.read_csv(REAL_WORLD_CSV)
        # Parse turns_json into patient (src) / doctor (tgt) text.
        parsed = pd.DataFrame(
            med_df["turns_json"].map(parse_turns).tolist(),
            columns=["src", "tgt"],
            index=med_df.index,
        )
        med_df = med_df.assign(src=parsed["src"], tgt=parsed["tgt"]).dropna(subset=["src", "tgt"])
        # Dialogues without a patient or a doctor turn parse to "" (not None); drop them too.
        med_df = med_df[(med_df["src"] != "") & (med_df["tgt"] != "")].reset_index(drop=True)

        sample_n = min(TARGET_REAL, len(med_df))
        med_df = med_df.sample(n=sample_n, random_state=42).reset_index(drop=True)

        print("Preparing hybrid retrieval candidates for real data...")
        all_candidates = [
            hybrid_retrieve(text, k=RETRIEVER_TOPK)
            for text in tqdm(med_df["src"], desc="Retrieval")
        ]

        real_data = [
            _real_row(src, tgt, cands)
            for src, tgt, cands in zip(med_df["src"], med_df["tgt"], all_candidates)
        ]

        # Save
        with open(REAL_PATH, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(row) + "\n" for row in real_data)
        print("Saved real:", len(real_data))
else:
    print("No real data available. Enable RUN_BUILD_REAL to create it.")

print("Real dataset size:", len(real_data))

# =========================
# STEP D: Merge Datasets (TARGET_SYNTHETIC synthetic + TARGET_REAL real)
# =========================
MERGE_SHUFFLE_SEED = 42


def _merge_datasets(synthetic_rows, real_rows):
    """Cap each source at its target size, concatenate and shuffle deterministically.

    A dedicated ``random.Random`` keeps the order independent of how much of the
    global ``random`` state earlier cells consumed.
    """
    merged = synthetic_rows[:TARGET_SYNTHETIC] + real_rows[:TARGET_REAL]
    random.Random(MERGE_SHUFFLE_SEED).shuffle(merged)
    return merged


if os.path.exists(COMBINED_PATH):
    print("Loading combined cache...")
    with open(COMBINED_PATH, "r", encoding="utf-8") as f:  # `with` closes the handle
        combined_data = [json.loads(line) for line in f]
elif RUN_COMBINE:
    print("Merging synthetic + real...")
    combined_data = _merge_datasets(synthetic_data, real_data)
    with open(COMBINED_PATH, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in combined_data)
    print("Saved combined:", len(combined_data))
else:
    # Same capping/shuffling as the saved path, just not written to disk.
    combined_data = _merge_datasets(synthetic_data, real_data)
    print("Combined created in-memory:", len(combined_data))

print("Combined size:", len(combined_data))

# =========================
# STEP E: Prepare HF Dataset for SFT (Alpaca-style prompt)
# =========================
# Two positional slots: the example "input" first, then the target "output" (a JSON string).
alpaca_prompt = (
    "Below is a medical consultation example.\n"
    "Follow the 5-step diagnostic pipeline and output the JSON content exactly.\n"
    "\n"
    "### Input:\n"
    "{}\n"
    "\n"
    "### Response:\n"
    "{}"
)


def format_for_sft_batch(ex):
    """Batched ``Dataset.map`` callback that renders each example with ``alpaca_prompt``.

    Args:
        ex: batch dict holding parallel lists under "input" and "output".

    Returns:
        {"text": [...]}: one rendered prompt per example, in order.
    """
    pairs = zip(ex["input"], ex["output"])
    return {"text": [alpaca_prompt.format(inp, out) for inp, out in pairs]}

# Prefer the cached, already-formatted dataset; otherwise wrap the merged rows
# ("input"/"output" columns) and format them later, once a tokenizer exists.
if not os.path.exists(FINAL_DATASET_DIR):
    final_dataset = Dataset.from_list(combined_data)
    print("Created final Dataset object (unformatted). Samples:", len(final_dataset))
else:
    final_dataset = Dataset.load_from_disk(FINAL_DATASET_DIR)
    print("Loaded final_dataset from disk:", len(final_dataset))

# =========================
# STEP F: MODEL LOAD (4-bit) + LoRA (safe settings)  -- optional heavy
# =========================
print("\nLoading base model + LoRA (skipping heavy train if RUN_TRAIN=False)...")
# 4-bit NF4 weights with nested (double) quantization; compute runs in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Both stay None unless loading fully succeeds; later steps gate on them.
model = None
tokenizer = None
if ENABLE_LLM_SCORING or RUN_TRAIN:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,  # runs model code shipped with the Hub repo
        )
        model.config.use_cache = False  # KV cache is not used in training and conflicts with gradient checkpointing
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        # Keep the tokenizer's own pad token when it has one. Reusing EOS as padding
        # means a collator that masks labels by pad_token_id also masks the real
        # EOS, so the model would never learn to stop.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        peft_config = LoraConfig(
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj"],
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        print("Model & LoRA ready.")
    except Exception as e:
        print("Could not fully load model. Error:", e)
        # Don't leave a half-initialised base model / tokenizer behind: the training
        # step would otherwise run on a model without LoRA adapters.
        model = None
        tokenizer = None
        gc.collect(); torch.cuda.empty_cache()
else:
    print("Skipping model load (ENABLE_LLM_SCORING=False and RUN_TRAIN=False).")

# Format final dataset if tokenizer present (optional).
# A cached FINAL_DATASET_DIR already holds the formatted "text" column and was loaded
# in STEP E, so formatting only happens on a cold cache.
if tokenizer is not None and not os.path.exists(FINAL_DATASET_DIR):
    print("Formatting final HF dataset (this may take a while)...")
    eos = tokenizer.eos_token

    def _format_with_eos(batch):
        """Render prompts and terminate each with EOS so the model learns where to stop."""
        return {"text": [text + eos for text in format_for_sft_batch(batch)["text"]]}

    # Single pass: render + append EOS, dropping the raw input/output columns.
    final_dataset = final_dataset.map(_format_with_eos, batched=True, remove_columns=final_dataset.column_names)
    final_dataset.save_to_disk(FINAL_DATASET_DIR)
    print("Saved formatted final_dataset to disk.")
elif tokenizer is None and not os.path.exists(FINAL_DATASET_DIR):
    print("Tokenizer unavailable: final_dataset left unformatted (no 'text' column).")

print("Final_dataset size:", len(final_dataset) if final_dataset is not None else "N/A")

# =========================
# STEP G (optional): TRAIN (SMOKE RUN)
# =========================
# The tokenizer is required as well: it is handed to the trainer and saved with the adapter.
if RUN_TRAIN and model is not None and tokenizer is not None and final_dataset is not None:
    from trl import SFTTrainer, SFTConfig

    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
    sft_cfg = SFTConfig(
        output_dir=MODEL_OUTPUT_DIR,
        dataset_text_field="text",
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        warmup_steps=30,
        max_steps=TRAIN_MAX_STEPS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=10,
        save_steps=200,           # > max_steps: only the final save_model below is written
        save_total_limit=2,
        optim="paged_adamw_8bit",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        report_to="none",
        packing=False,
    )
    # Train with the same tokenizer instance that is saved below; without
    # processing_class, SFTTrainer loads its own tokenizer from the model id.
    trainer = SFTTrainer(
        model=model,
        train_dataset=final_dataset,
        args=sft_cfg,
        processing_class=tokenizer,
    )
    gc.collect(); torch.cuda.empty_cache()
    print("Starting smoke-run training (max_steps=%d) ..." % TRAIN_MAX_STEPS)
    trainer.train()
    trainer.save_model(FINAL_MODEL_DIR)
    tokenizer.save_pretrained(FINAL_MODEL_DIR)
    print("Saved model to", FINAL_MODEL_DIR)
else:
    print("Skipping training step (RUN_TRAIN=False or model missing).")

# =========================
# STEP H: Paper-aligned Interactive Consultation (improved) — implemented in the next cell (InteractiveConsultationPaperV3)
# =========================


Device: cuda
Loading KB & retriever from cache...
KB & retriever loaded.
Hybrid retriever ready.
Loading synthetic cache...
Loaded synthetic: 20002
Synthetic dataset size: 20002
Loading real cache...
Loaded real: 5000
Real dataset size: 5000
Loading combined cache...
Combined size: 25000
Loaded final_dataset from disk: 25000

Loading base model + LoRA (skipping heavy train if RUN_TRAIN=False)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 3,440,640 || all params: 7,619,057,152 || trainable%: 0.0452
Model & LoRA ready.
Final_dataset size: 25000


Adding EOS to train dataset:   0%|          | 0/25000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/25000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/25000 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Starting smoke-run training (max_steps=100) ...


Step,Training Loss
10,1.6121
20,1.4317
30,1.1836
40,0.8426
50,0.5601
60,0.5186
70,0.4378
80,0.3709
90,0.4046
100,0.4014


Saved model to diagnosis_gpt_v3_3/model_final


In [11]:
# ==============================
# Updated InteractiveConsultationPaperV3
# - Minimum 3 follow-ups
# - Strong age/gender gating
# - Red-flag detection
# - Clean reasoning (Streamlit-friendly)
# - Gemini integration (optional)
# - Entropy-based follow-up selection
# ==============================

import re, math, time, textwrap, os, numpy as np


# -----------------------------------------------------
# RED FLAG DETECTION
# -----------------------------------------------------
def detect_red_flags(text, spo2_threshold=94):
    """Return True if patient free text mentions an emergency warning sign.

    Two case-insensitive checks:
      * keyword phrases (severe chest pain, cyanosis, hemoptysis, ...);
      * any SpO2 / oxygen-saturation reading strictly below ``spo2_threshold``
        (e.g. "SpO2 88", "spo2: 91%", "Sp02 of 90", "oxygen saturation 89").

    A bare mention of SpO2 is NOT a red flag; only a low reading is. Keyword
    matching is plain substring search and ignores negation ("no confusion"
    still matches "confusion").

    Args:
        text: patient text; None / "" return False. Non-strings are str()-ed.
        spo2_threshold: readings below this value trigger a flag (default 94).

    Returns:
        bool
    """
    RED_FLAGS = (
        "severe chest pain", "severe shortness of breath", "unable to breathe",
        "blue lips", "cyanosis", "blood in sputum", "hemoptysis",
        "confusion", "loss of consciousness", "fainting", "very low oxygen",
        "respiratory distress",
    )
    if not text:
        return False
    t = str(text).lower()

    # simple phrase match
    if any(rf in t for rf in RED_FLAGS):
        return True

    # Low-saturation readings: label, optional ":"/"="/"is"/"of"..., then a 2-3 digit value.
    spo2_pattern = (
        r"(?:sp[o0]2|oxygen saturation|o2 sat(?:s|uration)?)"
        r"\s*(?:[:=~]|is|was|of|at)?\s*(\d{2,3})"
    )
    return any(int(m.group(1)) < spo2_threshold for m in re.finditer(spo2_pattern, t))


# -----------------------------------------------------
# GEMINI LLM NARRATIVE (optional)
# -----------------------------------------------------
def call_external_narrative_llm(name, age, gender, symptoms, candidates, probs, model_name=None):
    """Ask Gemini for a 3–4 paragraph diagnostic-reasoning narrative.

    Best-effort and optional: returns None (never raises) when the
    ``google.generativeai`` package is missing, ``GEMINI_API_KEY`` is unset, the
    API call fails, or the reply contains no text.

    Args:
        name, age, gender: patient details interpolated into the prompt.
            PRIVACY: these, including the name, are sent to an external service.
        symptoms, candidates: iterables of names; None is treated as empty.
        probs: probability scores, rendered with str() into the prompt.
        model_name: Gemini model id. Defaults to $GEMINI_MODEL, else
            "gemini-2.5-flash" (the previously hard-coded "gemini-pro" may no
            longer be served; override if this default is retired too).

    Returns:
        str | None: the stripped narrative text.
    """
    try:
        import google.generativeai as genai
    except Exception:
        return None

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return None

    model_id = model_name or os.environ.get("GEMINI_MODEL") or "gemini-2.5-flash"

    try:
        genai.configure(api_key=api_key)

        # Tolerate None / non-string entries instead of failing inside str.join.
        symptoms_text = ", ".join(str(s) for s in (symptoms if symptoms is not None else []))
        candidates_text = ", ".join(str(c) for c in (candidates if candidates is not None else []))

        prompt = f"""
Write a polished medical diagnostic reasoning narrative in 3–4 paragraphs:

Patient Info:
- Name: {name}
- Age: {age}
- Gender: {gender}

Reported Symptoms:
{symptoms_text}

Candidate Diseases:
{candidates_text}

Probability Scores:
{probs}

Guidelines:
- Compare each disease briefly.
- Match hallmark symptoms.
- Mention age/gender incompatibility.
- Explain why the final diagnosis is most likely.
Return plain text only.
"""

        gen_model = genai.GenerativeModel(model_id)
        resp = gen_model.generate_content(prompt)
        # Reading .text can raise (e.g. a blocked reply); the except below handles it.
        text = getattr(resp, "text", None)
    except Exception:
        return None

    if isinstance(text, str) and text.strip():
        return text.strip()
    # No usable text: report failure rather than returning the response object's repr.
    return None



# -----------------------------------------------------
#  Main Consultation Class
# -----------------------------------------------------
class InteractiveConsultationPaperV3:
    """Entropy-guided interactive diagnostic consultation.

    ``consult`` seeds a confirmed-symptom list ``S`` from the retriever. Each
    round it then:

    - re-retrieves candidates for ``S``;
    - scores them against ``disease_symptom_map``, with hard age/gender gating
      and a penalty for denied symptoms;
    - asks the yes/no question whose assumed "yes" answer most reduces the
      entropy of the candidate distribution.

    The session ends on confidence, an entropy plateau, no useful question,
    or ``max_rounds``.

    Relies on module-level names defined elsewhere in the notebook:
    ``hybrid_retrieve``, ``disease_symptom_map``, ``detect_red_flags``,
    ``call_external_narrative_llm`` and ``EARLY_STOPPING_ENTROPY_EPS``.
    """

    MIN_INQUIRIES_BEFORE_DIAG = 3    # questions required before any diagnosis may be issued
    NEGATIVE_EVIDENCE_PENALTY = 1.5  # heuristic: score removed per denied symptom found in a disease's KB text

    def __init__(self,
                 model=None,
                 tokenizer=None,
                 llm_scoring=False,
                 device=None,
                 k_candidates=6,
                 k_symptoms_per_candidate=6,
                 confidence_threshold=0.75,
                 max_rounds=6,
                 save_transcript=True):
        """
        Args:
            model, tokenizer: stored on the instance; not used by this class's methods.
            llm_scoring: stored flag; not used by this class's methods.
            device: defaults to "cuda" when available, else "cpu".
            k_candidates: number of candidates requested from ``hybrid_retrieve``;
                also sizes the small rank bonus in scoring.
            k_symptoms_per_candidate: KB symptom phrases taken per candidate.
            confidence_threshold: top probability that ends the consultation
                (only after ``MIN_INQUIRIES_BEFORE_DIAG`` questions).
            max_rounds: maximum number of question rounds.
            save_transcript: record the initial message and each Q/A in
                ``self.transcript``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.k = k_candidates
        self.k_symptoms_per_candidate = k_symptoms_per_candidate
        self.conf_threshold = confidence_threshold
        self.max_rounds = max_rounds
        self.llm_scoring = llm_scoring
        self.save_transcript = save_transcript

        self._reset_session()

    def _reset_session(self):
        """Clear per-consultation state so a reused session starts fresh."""
        self.S = []              # confirmed symptoms (seeded + answered "yes")
        self.negatives = set()   # lower-cased symptoms answered "no"
        self.asked = set()       # lower-cased symptoms already asked about
        self.round = 0
        self.transcript = {}

    # -----------------------------------------------------
    #  Utility Functions
    # -----------------------------------------------------
    def _softmax(self, score_map, temp=0.32):
        """Temperature softmax over a {key: score} map; {} for empty input."""
        if not score_map:
            return {}
        keys = list(score_map.keys())
        arr = np.array([score_map[k] for k in keys], float)
        arr -= arr.max()  # shift for numerical stability
        ex = np.exp(arr / temp)
        ex /= (ex.sum() + 1e-12)
        return {k: float(v) for k, v in zip(keys, ex)}

    @staticmethod
    def _entropy(probs):
        """Shannon entropy (nats) of a {key: probability} map."""
        return -sum(p * math.log(p + 1e-12) for p in probs.values())

    @staticmethod
    def _parse_age(age):
        """Best-effort integer age ("18", 18, "18.0"); None when missing or unparseable."""
        if age is None or str(age).strip() == "":
            return None
        try:
            return int(float(age))
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _wrap_preserving_lines(text, width=100):
        """Wrap each line to ``width`` separately.

        ``textwrap.fill`` on the whole text would collapse the report's newlines.
        """
        import textwrap
        return "\n".join(
            textwrap.fill(line, width) if line.strip() else line
            for line in text.splitlines()
        )

    def _candidate_symptoms_pool(self, candidates):
        """Collect KB symptom phrases for ``candidates``.

        Takes up to ``k_symptoms_per_candidate`` phrases per candidate, in
        candidate order. Phrases are de-duplicated case-insensitively (first
        spelling wins) and limited to 1-8 words.
        """
        pool = []
        for d in candidates or []:
            txt = disease_symptom_map.get(d, "") or ""
            parts = [p.strip() for p in re.split(r"[;,\n\.]", txt) if p.strip()]
            pool.extend(parts[: self.k_symptoms_per_candidate])
        out, seen = [], set()
        for p in pool:
            lp = p.lower()
            if lp not in seen and 1 <= len(p.split()) <= 8:
                out.append(p)
                seen.add(lp)
        return out

    # -----------------------------------------------------
    #  Strong Scoring + Age/Gender Gating
    # -----------------------------------------------------
    def _score_candidates(self, candidates, symptoms, age=None, gender=None):
        """Score ``candidates`` against ``symptoms`` and return softmax probabilities.

        Score components:

        - Candidates whose name conflicts with age/gender are pinned to 1e-12.
        - Otherwise the score rises with exact and partial (token) symptom
          matches against the disease's KB text.
        - A small rank bonus follows retrieval order.
        - Each denied symptom (``self.negatives``) found in the KB text lowers
          the score.
        - Acute/chronic mismatches and tumour-like names are dampened.

        Returns {} when ``candidates`` is empty.
        """
        if not candidates:
            return {}

        s_text = " ".join(symptoms).lower()
        acute = bool(re.search(r"\b(day|days|acute|hours)\b", s_text))
        chronic = bool(re.search(r"\b(month|months|year|chronic|long)\b", s_text))

        age_i = self._parse_age(age)
        g = str(gender).lower().strip() if gender else None
        negatives = [n for n in (getattr(self, "negatives", None) or ()) if n]

        scores = {}
        for i, d in enumerate(candidates):
            dl = d.lower()
            kb = (disease_symptom_map.get(d, "") or "").lower()

            # ---- hard age/gender exclusions ----
            incompatible = False
            if age_i is not None:
                if age_i < 40 and any(x in dl for x in ["elderly", "senile", "geriatric"]):
                    incompatible = True
                if age_i >= 18 and any(x in dl for x in ["pediatric", "child", "infant"]):
                    incompatible = True
            if g:
                if g.startswith("m") and any(x in dl for x in ["pregnan", "ovarian", "uterine", "breast"]):
                    incompatible = True
                if g.startswith("f") and any(x in dl for x in ["prostate", "testicular"]):
                    incompatible = True
            if incompatible:
                scores[d] = 1e-12
                continue

            # ---- matching score ----
            exact = sum(1 for s in symptoms if s.lower() in kb)

            partial = 0
            for s in symptoms:
                toks = re.findall(r"\w+", s.lower())
                hit = sum(1 for t in toks if t in kb)
                if hit >= max(1, len(toks) // 2):
                    partial += 1

            score = 1.0 + 2.0 * exact + 1.1 * partial
            score += 0.22 * (self.k - i)  # small ranking bonus

            # Denied symptoms are evidence against diseases that list them.
            neg_hits = sum(1 for n in negatives if n in kb)
            score -= self.NEGATIVE_EVIDENCE_PENALTY * neg_hits

            # acute/chronic mismatch
            if acute and "chronic" in dl:
                score *= 0.5
            if chronic and "acute" in dl:
                score *= 0.65

            # tumor penalty
            if any(x in dl for x in ["tumor", "malign", "carcinoma", "sarcoma"]):
                score *= 0.65

            scores[d] = max(score, 1e-12)

        return self._softmax(scores, temp=0.32)

    def _rank(self, age=None, gender=None):
        """Retrieve candidates for the current confirmed symptoms and score them."""
        candidates = hybrid_retrieve(", ".join(self.S), k=self.k)
        probs = self._score_candidates(candidates, self.S, age=age, gender=gender)
        return candidates, probs

    # -----------------------------------------------------
    #  Yes/No interpretation
    # -----------------------------------------------------
    def interpret_answer_as_bool(self, ans, symptom):
        """Interpret a free-text answer as yes (True) or no (False).

        Rules, in order:

        - Leading "y" means yes and leading "n" means no (after stripping).
        - An answer mentioning a duration ("day"/"week") counts as yes.
        - Otherwise the answer is yes if it contains any token of the symptom.
        """
        if not ans:
            return False
        a = ans.strip().lower()
        if a.startswith("y"):
            return True
        if a.startswith("n"):
            return False
        if "day" in a or "week" in a:
            return True
        toks = re.findall(r"\w+", symptom.lower())
        return any(tok in a for tok in toks)

    # -----------------------------------------------------
    #  Entropy-based follow-up selection
    # -----------------------------------------------------
    def _simulate(self, candidates, S, symptom, age=None, gender=None):
        """Posterior over ``candidates`` if ``symptom`` were confirmed on top of ``S``."""
        # Order-preserving de-duplication keeps the simulation deterministic.
        hypothetical = list(dict.fromkeys([*S, symptom]))
        probs = self._score_candidates(candidates, hypothetical, age=age, gender=gender)
        total = sum(probs.values())
        if total <= 0:
            return probs
        return {k: v / total for k, v in probs.items()}

    def choose_inquiry_by_entropy(self, candidates, age=None, gender=None):
        """Return the unasked symptom whose assumed "yes" most lowers entropy.

        Symptoms already confirmed, asked, or denied are skipped. Returns None
        when no symptom reduces entropy. ``age``/``gender`` apply the same
        gating as the main scoring loop.
        """
        base = self._score_candidates(candidates, list(self.S), age=age, gender=gender)
        h0 = self._entropy(base)

        excluded = {s.lower() for s in self.S} | set(self.asked) | set(self.negatives)

        best, best_delta = None, 0
        for s in self._candidate_symptoms_pool(candidates):
            if s.lower() in excluded:
                continue
            post = self._simulate(candidates, self.S, s, age=age, gender=gender)
            delta = h0 - self._entropy(post)
            if delta > best_delta:
                best, best_delta = s, delta
        return best

    # -----------------------------------------------------
    #  Final Report
    # -----------------------------------------------------
    def generate_final_cod_report(self, candidates, probs, symptoms,
                                  name=None, age=None, gender=None, use_gemini=True):
        """Print and return the final diagnostic report.

        Uses the Gemini narrative when ``use_gemini`` is true and it succeeds.
        Otherwise it builds a templated narrative from ``disease_symptom_map``.
        ``candidates`` is accepted for interface compatibility; the ranking is
        taken from ``probs``.

        Returns:
            dict with ``patient_meta``, ``symptoms``, ``candidates`` (sorted by
            probability), ``confidence``, ``final_diagnosis`` (None when
            ``probs`` is empty) and ``narrative``.
        """
        sorted_probs = sorted((probs or {}).items(), key=lambda x: x[1], reverse=True)
        top_d, top_p = sorted_probs[0] if sorted_probs else (None, 0.0)
        top5 = sorted_probs[:5]

        # Try Gemini
        narrative = None
        if use_gemini:
            narrative = call_external_narrative_llm(
                name, age, gender, list(symptoms),
                [c for c, _ in top5],
                {c: p for c, p in sorted_probs}
            )

        # Fallback narrative
        if not narrative:
            lines = []
            lines.append(f"**Patient Summary**\nName: {name}\nAge: {age}\nGender: {gender}\n")
            lines.append(f"Reported symptoms: {', '.join(symptoms)}\n")

            lines.append("**Considered Conditions**")
            for d, _ in top5:
                snippet = (disease_symptom_map.get(d, "") or "").split(".")[0][:180]
                lines.append(f"- **{d}** — {snippet}")

            lines.append("\n**Reasoning Summary**")
            for d, _ in top5:
                kb = (disease_symptom_map.get(d, "") or "").lower()
                matched = [s for s in symptoms if s.lower() in kb]
                if matched:
                    lines.append(f"- {d}: matches {matched}.")
                else:
                    lines.append(f"- {d}: limited symptom overlap.")

            lines.append(
                f"\n**Final Assessment:** Most likely diagnosis = **{top_d}** "
                f"(confidence {top_p:.2f})."
            )

            narrative = "\n".join(lines)

        print("\n================ FINAL DIAGNOSTIC REPORT ================\n")
        print(self._wrap_preserving_lines(narrative, 100))
        print("\n==========================================================\n")

        return {
            "patient_meta": {"name": name, "age": age, "gender": gender},
            "symptoms": list(symptoms),
            "candidates": [c for c, p in sorted_probs],
            "confidence": {c: float(p) for c, p in sorted_probs},
            "final_diagnosis": top_d,
            "narrative": narrative
        }

    # -----------------------------------------------------
    #  Main Consultation Loop
    # -----------------------------------------------------
    def consult(self, initial_message=None, name=None, age=None, gender=None, use_gemini=True):
        """Run one interactive consultation and return the final report dict.

        Any of ``initial_message``/``name``/``age``/``gender`` left as None is
        prompted for with ``input()``. Per-session state (``S``, ``negatives``,
        ``asked``, ``round``, ``transcript``) is reset first, so a session
        object can be reused.
        """
        # Ask demographics first
        if name is None:
            name = input("Patient Name: ").strip()
        if age is None:
            age = input("Age: ").strip()
        if gender is None:
            gender = input("Gender (M/F/Other): ").strip()

        if initial_message is None:
            initial_message = input("Describe your symptoms: ").strip()

        self._reset_session()
        if self.save_transcript:
            self.transcript = {"initial_message": initial_message, "questions": []}

        # Red flag
        if detect_red_flags(initial_message):
            print("\n⚠ RED FLAG: URGENT medical attention advised.\n")

        # Seed S with KB phrases of the top retrieved candidates.
        # NOTE(review): the patient's own words are not added to S. That includes
        # durations like "7 days" that the acute/chronic regex looks for. The
        # seeded phrases are also reported as symptoms without being confirmed.
        init_cands = hybrid_retrieve(initial_message, k=self.k)
        self.S = self._candidate_symptoms_pool(init_cands)[:3]

        inquiry_count = 0
        prev_entropy = None

        for round_idx in range(self.max_rounds):
            self.round = round_idx + 1
            candidates, probs = self._rank(age, gender)
            if not probs:
                break  # nothing retrieved; report on whatever is left

            best_d = max(probs, key=probs.get)
            best_p = probs[best_d]

            print(f"\nROUND {self.round} — Best candidate: {best_d} ({best_p:.3f})")

            may_finalize = inquiry_count >= self.MIN_INQUIRIES_BEFORE_DIAG

            # Finalize on confidence, only after enough questions
            if may_finalize and best_p >= self.conf_threshold:
                return self.generate_final_cod_report(candidates, probs, self.S, name, age, gender, use_gemini)

            # Entropy-based early stop, only after enough questions
            entropy = self._entropy(probs)
            if may_finalize and prev_entropy is not None:
                if (prev_entropy - entropy) < EARLY_STOPPING_ENTROPY_EPS:
                    print("\nEntropy plateau — finalizing early.")
                    return self.generate_final_cod_report(candidates, probs, self.S, name, age, gender, use_gemini)
            prev_entropy = entropy

            # Choose next symptom (same age/gender gating as the scoring above)
            st = self.choose_inquiry_by_entropy(candidates, age=age, gender=gender)
            if not st:
                return self.generate_final_cod_report(candidates, probs, self.S, name, age, gender, use_gemini)

            inquiry_count += 1
            self.asked.add(st.lower())

            q = f"Have you experienced '{st}'? (Yes/No)"
            ans = input(q + ": ")

            positive = self.interpret_answer_as_bool(ans, st)
            if positive:
                self.S.append(st)
            else:
                self.negatives.add(st.lower())

            if self.save_transcript:
                self.transcript["questions"].append(
                    {"round": self.round, "symptom": st, "answer": ans, "positive": positive}
                )

        # Re-rank so the report reflects the final round's answer; this also
        # covers max_rounds == 0, where the loop body never ran.
        candidates, probs = self._rank(age, gender)
        return self.generate_final_cod_report(candidates, probs, self.S, name, age, gender, use_gemini)



# -----------------------------------------------------
# Convenience wrapper
# -----------------------------------------------------
def run_diagnosis(initial_message=None, name=None, age=None, gender=None, session=None, use_gemini=True):
    """
    Run one consultation, building a default session when none is supplied.

    The default session picks up the notebook globals ``model``, ``tokenizer``
    and ``ENABLE_LLM_SCORING`` when they exist. It uses 6 candidates, a 0.75
    confidence threshold and at most 6 rounds.

    Returns whatever ``session.consult`` returns.
    """
    if session is None:
        default_settings = dict(
            model=globals().get("model", None),
            tokenizer=globals().get("tokenizer", None),
            llm_scoring=globals().get("ENABLE_LLM_SCORING", False),
            k_candidates=6,
            confidence_threshold=0.75,
            max_rounds=6,
            save_transcript=True,
        )
        session = InteractiveConsultationPaperV3(**default_settings)

    consult_args = (initial_message, name, age, gender, use_gemini)
    return session.consult(*consult_args)


In [12]:
# Interactive run: with no arguments, run_diagnosis() prompts via input() for
# name, age, gender and symptoms, then asks follow-up yes/no questions. The
# returned report dict is displayed as the cell output.
run_diagnosis()

Patient Name:  haasiwth
Age:  18
Gender (M/F/Other):  M
Describe your symptoms:  i have cough and fever for 7 days



ROUND 1 — Best candidate: Pulmonary Infection (0.999)


Have you experienced 'dyspnea'? (Yes/No):  yes



ROUND 2 — Best candidate: Scrub Typhus Pneumonia (0.992)


Have you experienced 'persistent high fever'? (Yes/No):  yes



ROUND 3 — Best candidate: Influenza Virus Pneumonia (0.957)


Have you experienced 'rales'? (Yes/No):  no



ROUND 4 — Best candidate: Influenza Virus Pneumonia (0.957)


**Patient Summary** Name: haasiwth Age: 18 Gender: M  Reported symptoms: high fever, cough,
expectoration, dyspnea, persistent high fever  **Considered Conditions** - **Influenza Virus
Pneumonia** — cough, fever with chills, sore throat, headache, chest pain, hemoptysis, myalgia,
persistent high fever, dyspnea, cyanosis - **Pneumonia** — cough, expectoration, chest pain,
dyspnea, fever - **Pulmonary Infection** — high fever, cough, expectoration - **Elderly Pulmonary
Tuberculosis** — rales, cough, expectoration, hemoptysis, dyspnea, persistent high fever -
**Pediatric Viral Pneumonia** — persistent high fever, dyspnea, cyanosis, paroxysmal cough, and
scanty hemoptysis sputum  **Reasoning Summary** - Influenza Virus Pneumonia: matches ['high fever',
'cough', 'dyspnea', 'persistent high fever']. - Pneumonia: matches ['cough', 'expectoration',
'dyspnea']. - Pulmonary Infection: matches ['high fever', 'cough', 'expectoration'].

{'patient_meta': {'name': 'haasiwth', 'age': '18', 'gender': 'M'},
 'symptoms': ['high fever',
  'cough',
  'expectoration',
  'dyspnea',
  'persistent high fever'],
 'candidates': ['Influenza Virus Pneumonia',
  'Pneumonia',
  'Pulmonary Infection',
  'Elderly Pulmonary Tuberculosis',
  'Pediatric Viral Pneumonia',
  'Aspiration Pneumonia in the Elderly'],
 'confidence': {'Influenza Virus Pneumonia': 0.9565885010904421,
  'Pneumonia': 0.028886469744809094,
  'Pulmonary Infection': 0.014525029163792176,
  'Elderly Pulmonary Tuberculosis': 1.5757554348766385e-19,
  'Pediatric Viral Pneumonia': 1.5757554348766385e-19,
  'Aspiration Pneumonia in the Elderly': 1.5757554348766385e-19},
 'final_diagnosis': 'Influenza Virus Pneumonia',
 'narrative': "**Patient Summary**\nName: haasiwth\nAge: 18\nGender: M\n\nReported symptoms: high fever, cough, expectoration, dyspnea, persistent high fever\n\n**Considered Conditions**\n- **Influenza Virus Pneumonia** — cough, fever with chills, sore throat, 

In [5]:
# Save final PEFT LoRA model + tokenizer.
# The target directory defaults to the Kaggle working dir. It can be overridden
# with the DIAGNOSISGPT_SAVE_DIR environment variable so the cell can run
# outside Kaggle.
save_dir = os.environ.get("DIAGNOSISGPT_SAVE_DIR", "/kaggle/working/diagnosisgpt_v3_3_model")

os.makedirs(save_dir, exist_ok=True)

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print("Model saved to:", save_dir)


Model saved to: /kaggle/working/diagnosisgpt_v3_3_model


In [13]:
# Relative output directory for the final model artifacts. Nothing is written
# here yet; the save cells below write to it.
MODEL_DIR = "diagnosis_gpt_v3_3/model_final"
print("Model output directory:", MODEL_DIR)


Saved model directory: diagnosis_gpt_v3_3/model_final


In [14]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Write the model first, then the tokenizer, into MODEL_DIR.
for artifact in (model, tokenizer):
    artifact.save_pretrained(MODEL_DIR)

print("✔ Model + tokenizer saved to:", MODEL_DIR)


✔ Model + tokenizer saved to: diagnosis_gpt_v3_3/model_final


In [15]:
import os
import shutil

ZIP_PATH = "diagnosis_gpt_v3_3_model.zip"

# make_archive expects the base name without the extension. Derive it from
# ZIP_PATH so the two can never drift apart. Report the path make_archive
# actually returns.
zip_base, _ = os.path.splitext(ZIP_PATH)
zip_file = shutil.make_archive(zip_base, "zip", MODEL_DIR)

print("✔ ZIP created:", zip_file)


✔ ZIP created: diagnosis_gpt_v3_3_model.zip


In [None]:
!pip install datasets tqdm


In [None]:
from datasets import load_dataset

dx = load_dataset("FreedomIntelligence/DxBench", "en")
# Prefer the held-out split. If the dataset ships without one, fall back to
# train with a visible warning, so results are not silently computed on
# training data.
split_name = "test" if "test" in dx else "train"
test_data = dx[split_name]
if split_name != "test":
    print("⚠ No 'test' split found; evaluating on 'train'.")
print("Samples:", len(test_data))


In [None]:
def build_symptom_text(sample):
    """
    Turn a DxBench-style sample into a first-person symptom statement.

    Collects symptom names flagged true from ``explicit_symptoms``, then from
    ``implicit_symptoms``. For example,
    ``{"explicit_symptoms": [["cough", "True"]]}`` -> ``"I have cough"``.

    Each entry is expected to be a ``[symptom, flag]`` pair. The flag may be a
    string such as "True"/"true" (surrounding whitespace ignored) or a Python
    bool. Missing lists and malformed entries (fewer than two items) are
    skipped instead of raising.
    """
    def positives(key):
        # str() lets bool flags through; the original .lower() crashed on them.
        return [
            str(entry[0])
            for entry in (sample.get(key) or [])
            if len(entry) >= 2 and str(entry[1]).strip().lower() == "true"
        ]

    pos = positives("explicit_symptoms") + positives("implicit_symptoms")
    return "I have " + ", ".join(pos)
