In [1]:
import google.generativeai as genai
import json
import time
import os

# --- CONFIGURATION ---
# Replace with your actual API key
API_KEY = "AIzaSyA0Gv7sDdriegProVmOhwcTABftdKjdIwA"
OUTPUT_FILE = "gemini_medical_data.json"
TOTAL_SAMPLES_NEEDED = 50
BATCH_SIZE = 5

genai.configure(api_key=API_KEY)

model = genai.GenerativeModel(
    model_name="gemini-2.5-flash", # 1.5 Flash is more stable for JSON than 2.5 preview
    generation_config={"response_mime_type": "application/json"}
)

def get_prompt(num_samples):
    return f"""
    You are an expert medical educator. Generate {num_samples} synthetic medical training examples using "Chain of Thought" reasoning.

    For each example, generate a JSON object with these EXACT keys:
    1. "instruction": Fixed string: "Analyze the patient dialogue, provide clinical reasoning, and write a SOAP note."
    2. "input": A realistic doctor-patient dialogue string.
    3. "reasoning": A string explaining the clinical thought process (e.g., "Patient has fever and RLQ pain, ruling out...").
    4. "output": A SINGLE STRING containing the full SOAP note. Do NOT use a nested object. Format it with headers like "SUBJECTIVE: ... OBJECTIVE: ...".

    CRITICAL: Ensure all values are STRINGS, not objects or lists.

    Return a raw JSON list of these objects.
    """

def validate_and_fix(entry):
    """
    Fixes common data issues before saving to ensure load_dataset doesn't crash.
    """
    # 1. Ensure 'output' is a string, not a dict
    if isinstance(entry.get('output'), dict):
        # Flatten dict to string if model messed up
        soap_dict = entry['output']
        text = ""
        for k, v in soap_dict.items():
            text += f"{k.upper()}:\n{v}\n\n"
        entry['output'] = text.strip()

    # 2. Ensure 'reasoning' exists (Backfill if missing to prevent crash)
    if 'reasoning' not in entry or not entry['reasoning']:
        entry['reasoning'] = "Clinical reasoning not provided for this case."

    # 3. Ensure all are strings
    for key in ['instruction', 'input', 'reasoning', 'output']:
        if not isinstance(entry.get(key), str):
            entry[key] = str(entry.get(key, ""))

    return entry

def generate_dataset():
    collected_data = []
    print(f"Starting generation of {TOTAL_SAMPLES_NEEDED} Chain-of-Thought samples...")

    while len(collected_data) < TOTAL_SAMPLES_NEEDED:
        remaining = TOTAL_SAMPLES_NEEDED - len(collected_data)
        current_batch = min(BATCH_SIZE, remaining)

        print(f"Requesting batch of {current_batch}...")
        try:
            response = model.generate_content(get_prompt(current_batch))
            batch_data = json.loads(response.text)

            if isinstance(batch_data, list):
                # Validate every single entry before adding
                fixed_batch = [validate_and_fix(item) for item in batch_data]
                collected_data.extend(fixed_batch)
                print(f"  -> Added {len(fixed_batch)} samples. (Total: {len(collected_data)})")

            time.sleep(1)

        except Exception as e:
            print(f"  -> Error: {e}")
            time.sleep(2)

    print(f"Saving {len(collected_data)} cleaned samples to {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, "w") as f:
        json.dump(collected_data, f, indent=2)
    print("Done! You can now run the training script.")

if __name__ == "__main__":
    generate_dataset()

Starting generation of 50 Chain-of-Thought samples...
Requesting batch of 5...
  -> Added 5 samples. (Total: 5)
Requesting batch of 5...
  -> Added 5 samples. (Total: 10)
Requesting batch of 5...
  -> Added 5 samples. (Total: 15)
Requesting batch of 5...
  -> Added 5 samples. (Total: 20)
Requesting batch of 5...
  -> Added 5 samples. (Total: 25)
Requesting batch of 5...
  -> Added 5 samples. (Total: 30)
Requesting batch of 5...
  -> Added 5 samples. (Total: 35)
Requesting batch of 5...
  -> Added 5 samples. (Total: 40)
Requesting batch of 5...
  -> Added 5 samples. (Total: 45)
Requesting batch of 5...
  -> Added 5 samples. (Total: 50)
Saving 50 cleaned samples to gemini_medical_data.json...
Done! You can now run the training script.


In [3]:
!pip install trl



In [1]:
!pip install "transformers>=4.41.0" "trl==0.9.4" peft accelerate bitsandbytes


Collecting trl==0.9.4
  Downloading trl-0.9.4-py3-none-any.whl.metadata (11 kB)
Collecting tyro>=0.5.11 (from trl==0.9.4)
  Downloading tyro-0.9.35-py3-none-any.whl.metadata (12 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl==0.9.4)
  Downloading shtab-1.8.0-py3-none-any.whl.metadata (7.3 kB)
Downloading trl-0.9.4-py3-none-any.whl (226 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.7/226.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tyro-0.9.35-py3-none-any.whl (132 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.6/132.6 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading shtab-1.8.0-py3-none-any.whl (14 kB)
Installing collected packages: shtab, tyro, trl
  Attempting uninstall: trl
    Found existing installation: trl 0.25.1
    Uninstalling trl-0.25.1:
      Successfully uninstalled trl-0.25.1
Successfully installed shtab-1.8.0 trl-0.9.4 tyro-0.9.35


In [2]:
!pip install -U bitsandbytes transformers peft accelerate datasets trl

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting peft
  Downloading peft-0.18.0-py3-none-any.whl.metadata (14 kB)
Collecting accelerate
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting trl
  Downloading trl-0.25.1-py3-none-any.whl.metadata (11 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading peft-0.18.0-py3-none-any.whl (556 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m556.4/556.4 kB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading accelerate-1.12.0-py3-none-any.whl (380 kB)
[2K  

In [2]:
from datasets import load_dataset
import unicodedata, string

DATA_FILE = "gemini_medical_data.json"
ds = load_dataset("json", data_files=DATA_FILE, split="train")

def is_mostly_printable(s, threshold=0.95):
    if s is None:
        return False
    s = str(s)
    if len(s) == 0:
        return False
    printable = set(string.printable)
    frac = sum(1 for ch in s if ch in printable) / len(s)
    return frac >= threshold

print("Total samples:", len(ds))
for i in range(min(10, len(ds))):
    inst = ds[i].get("instruction", "")[:300]
    inp  = ds[i].get("input", "")[:300]
    reason = ds[i].get("reasoning", "")[:300]
    out = ds[i].get("output", "")[:300]
    print(f"\n--- sample {i} ---")
    print("instruction OK:", is_mostly_printable(inst), "len:", len(inst))
    print("input OK:", is_mostly_printable(inp), "len:", len(inp))
    print("reasoning OK:", is_mostly_printable(reason), "len:", len(reason))
    print("output OK:", is_mostly_printable(out), "len:", len(out))
    print("instruction excerpt:", repr(inst))
    print("input excerpt:", repr(inp))
    print("reasoning excerpt:", repr(reason))
    print("output excerpt:", repr(out))


Total samples: 50

--- sample 0 ---
instruction OK: True len: 80
input OK: True len: 300
reasoning OK: True len: 300
output OK: True len: 300
instruction excerpt: 'Analyze the patient dialogue, provide clinical reasoning, and write a SOAP note.'
input excerpt: "Doctor: 'Good morning, Mr. Jones. What brings you in today?' Patient: 'Morning, Doc. I've had this really bad pain in my lower right stomach for about 12 hours now. It started off dull around my belly button, but now it's sharp and moved lower right. I also feel pretty nauseous and threw up once. Go"
reasoning excerpt: "The patient's presentation of periumbilical pain migrating to the right lower quadrant (RLQ), associated with nausea, vomiting, and subjective fever, is highly classic for acute appendicitis. Other considerations include gastroenteritis, regional enteritis, or a ruptured ovarian cyst/ectopic pregnan"
output excerpt: 'SUBJECTIVE: Mr. Jones is a 35-year-old male presenting with a 12-hour history of abdominal pain. 

In [2]:
from datasets import load_dataset, Dataset
import string, json

DATA_FILE = "gemini_medical_data.json"
OUT_CLEAN = "gemini_medical_data_clean.json"
ds = load_dataset("json", data_files=DATA_FILE, split="train")

def is_ok(s, threshold=0.95):
    if s is None: return False
    s = str(s)
    if len(s) == 0: return False
    printable = set(string.printable)
    frac = sum(1 for ch in s if ch in printable) / len(s)
    return frac >= threshold

clean = []
for ex in ds:
    if not (is_ok(ex.get("instruction")) and is_ok(ex.get("input"))
            and is_ok(ex.get("reasoning")) and is_ok(ex.get("output"))):
        continue
    # optional: require minimal lengths
    if len(str(ex.get("input","")).strip()) < 10 or len(str(ex.get("output","")).strip()) < 20:
        continue
    clean.append({
        "instruction": str(ex.get("instruction","")).strip(),
        "input": str(ex.get("input","")).strip(),
        "reasoning": str(ex.get("reasoning","")).strip(),
        "output": str(ex.get("output","")).strip(),
    })

print("Kept", len(clean), "of", len(ds))
with open(OUT_CLEAN, "w") as f:
    json.dump(clean, f, indent=2)
print("Saved cleaned dataset to", OUT_CLEAN)


Kept 50 of 50
Saved cleaned dataset to gemini_medical_data_clean.json


In [3]:
# evaluate_and_generate.py
import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ---------- CONFIG ----------
DATA_FILE = "gemini_medical_data_clean.json"  # or "gemini_medical_data.json" (your cleaned file)
BASE_MODEL = "unsloth/llama-3-8b-bnb-4bit"
OUTPUT_DIR = "./llama_medical_lora"            # where you saved adapters (trainer.model.save_pretrained)
MAX_LENGTH = 512
EVAL_BATCH = 4        # number of eval examples to compute loss on (keeps memory small)
GEN_SAMPLES = 6       # number of generation examples to show
MAX_NEW_TOKENS = 200

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")

# ---------- LOAD DATA ----------
print("[DATA] Loading dataset...")
ds = load_dataset("json", data_files=DATA_FILE, split="train")
print("Total samples:", len(ds))

# if you want a deterministic small eval/test split, use first N and next N:
eval_ds = ds.select(range(min(EVAL_BATCH, len(ds))))
gen_ds = ds.select(range(min(GEN_SAMPLES, len(ds))))

# ---------- LOAD TOKENIZER & BASE MODEL (4-bit) ----------
print("[MODEL] Loading tokenizer and base model in 4-bit (NF4)...")
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR, trust_remote_code=True) if os.path.isdir(OUTPUT_DIR) and os.path.exists(os.path.join(OUTPUT_DIR, "tokenizer_config.json")) else AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load base model in 4-bit (device_map auto)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=bnb
)

# ---------- LOAD PEFT ADAPTER ----------
print("[PEFT] Loading LoRA adapter from", OUTPUT_DIR)
# This will return a PeftModel wrapping the base model with adapters applied
model = PeftModel.from_pretrained(base, OUTPUT_DIR, device_map="auto")

# Put model in eval mode
model.eval()
for p in model.parameters():
    p.requires_grad = False

device = next(model.parameters()).device
print("Model device:", device)
print("Model dtype:", next(model.parameters()).dtype)
print("Tokenizer vocab size:", tokenizer.vocab_size)

# ---------- HELPERS ----------
def build_prompt(example):
    inst = str(example.get("instruction", "")).strip()
    inp = str(example.get("input", "")).strip()
    reason = str(example.get("reasoning", "")).strip()
    out = str(example.get("output", "")).strip()
    prompt = (
        "### Instruction:\n" + inst + "\n\n"
        "### Input:\n" + inp + "\n\n"
        "### Clinical Reasoning:\n" + reason + "\n\n"
        "### SOAP Note:\n" + out
    )
    return prompt

def tokenize_prompt_only(prompt):
    return tokenizer(prompt, truncation=True, max_length=MAX_LENGTH, return_tensors="pt", padding="longest")

# ---------- QUICK EVAL: compute avg loss over small eval set ----------
print("\n[EVAL] Computing average loss on small eval subset...")
model_loss_sum = 0.0
count = 0

for idx in range(len(eval_ds)):
    example = eval_ds[idx]
    # Build prompt in training format: (Instruction + Input + Reasoning + SOAP)
    text = build_prompt(example)
    enc = tokenizer(text, truncation=True, max_length=MAX_LENGTH, return_tensors="pt", padding="longest")
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    # For causal LM, labels = input_ids (shift handled internally)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    loss = outputs.loss.item()
    print(f"  sample {idx} loss: {loss:.4f}  (tokens: {input_ids.shape[1]})")
    model_loss_sum += loss
    count += 1

if count > 0:
    avg_loss = model_loss_sum / count
    print(f"[EVAL] Average loss over {count} samples: {avg_loss:.4f}")
    try:
        ppl = float(torch.exp(torch.tensor(avg_loss)))
        print(f"[EVAL] Approx perplexity (exp(avg_loss)): {ppl:.2f}")
    except Exception:
        print("[EVAL] Perplexity computation overflowed.")

# ---------- GENERATION: conservative settings (deterministic / low temp) ----------
print("\n[GEN] Generating SOAP notes for examples...")

for i in range(len(gen_ds)):
    ex = gen_ds[i]
    # Build prompt WITHOUT the target SOAP note so the model must generate it.
    inst = str(ex.get("instruction", "")).strip()
    inp = str(ex.get("input", "")).strip()
    reason = str(ex.get("reasoning", "")).strip()
    # Provide instruction + input + reasoning, but NOT the gold SOAP — let model generate SOAP
    gen_prompt = (
        "### Instruction:\n" + inst + "\n\n"
        "### Input:\n" + inp + "\n\n"
        "### Clinical Reasoning:\n" + reason + "\n\n"
        "### SOAP Note:\n"
    )

    # Debug: show the decoded prompt so you can confirm tokenization is sane
    enc_debug = tokenizer(gen_prompt, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
    print("\n--- GEN SAMPLE", i, "---")
    print("Prompt (decoded):\n", tokenizer.decode(enc_debug["input_ids"][0], skip_special_tokens=True)[:1000])
    input_ids = enc_debug["input_ids"].to(device)

    with torch.no_grad():
        out = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,            # greedy (deterministic) — conservative
            temperature=0.2,
            top_p=0.9,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            early_stopping=True,
        )

    gen_text = tokenizer.decode(out[0], skip_special_tokens=True)
    # The generated text may include the prompt; remove the prompt prefix for clarity:
    if gen_text.startswith(tokenizer.decode(input_ids[0], skip_special_tokens=True)):
        generated_only = gen_text[len(tokenizer.decode(input_ids[0], skip_special_tokens=True)):]
    else:
        generated_only = gen_text

    print("\n=== GENERATED SOAP NOTE ===")
    print(generated_only.strip())
    print("===========================")

print("\n[DONE] Generation complete.")


[DATA] Loading dataset...


Generating train split: 0 examples [00:00, ? examples/s]

Total samples: 50
[MODEL] Loading tokenizer and base model in 4-bit (NF4)...




[PEFT] Loading LoRA adapter from ./llama_medical_lora
Model device: cuda:0
Model dtype: torch.float16
Tokenizer vocab size: 128000

[EVAL] Computing average loss on small eval subset...
  sample 0 loss: 1.2768  (tokens: 512)
  sample 1 loss: 1.0581  (tokens: 512)
  sample 2 loss: 1.2289  (tokens: 512)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


  sample 3 loss: 1.2254  (tokens: 512)
[EVAL] Average loss over 4 samples: 1.1973
[EVAL] Approx perplexity (exp(avg_loss)): 3.31

[GEN] Generating SOAP notes for examples...

--- GEN SAMPLE 0 ---
Prompt (decoded):
 ### Instruction:
Analyze the patient dialogue, provide clinical reasoning, and write a SOAP note.

### Input:
Doctor: 'Good morning, Mr. Jones. What brings you in today?' Patient: 'Morning, Doc. I've had this really bad pain in my lower right stomach for about 12 hours now. It started off dull around my belly button, but now it's sharp and moved lower right. I also feel pretty nauseous and threw up once. Got a bit of a fever too, I think.' Doctor: 'Any changes in bowel habits? Any issues with urination?' Patient: 'No, everything seems normal there. Just this pain, nausea, and feeling really tired.' Doctor: 'Okay, let's get you examined.'

### Clinical Reasoning:
The patient's presentation of periumbilical pain migrating to the right lower quadrant (RLQ), associated with naus