# RTL LoRA Adapter Training

**Radiology Trust Layer — MedGemma Impact Challenge**

This notebook fine-tunes [MedGemma-4B-IT](https://huggingface.co/google/medgemma-4b-it) with a LoRA adapter to improve:
1. **JSON schema compliance** — structured output reliability for the RTL audit pipeline
2. **Uncertainty calibration** — reduced overconfident language in alignment labels

The trained adapter is published to Hugging Face Hub for use in the [RTL live demo](https://huggingface.co/spaces/outlawpink/RadiologyTrustLayer).

**Requirements:**
- Kaggle GPU T4 accelerator
- `HF_TOKEN` added as a Kaggle secret
- MedGemma license accepted at https://huggingface.co/google/medgemma-4b-it

In [None]:
# Install dependencies
!pip install -q peft trl transformers accelerate datasets bitsandbytes huggingface_hub sentencepiece

In [None]:
# Authenticate with Hugging Face Hub
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

secrets = UserSecretsClient()
HF_TOKEN = secrets.get_secret("HF_TOKEN")
login(token=HF_TOKEN)
print("Logged in to Hugging Face Hub")

In [None]:
# ============================================================
# Generate synthetic training data
# Adapted from: hf_lora/dataset/make_synthetic.py
# ============================================================
import json
import random
from pathlib import Path

OUTPUT_DIR = Path("/kaggle/working")

CLAIM_TEMPLATES = [
    ("There is {finding} in the {location}.", "finding"),
    ("No {finding} is identified.", "absence"),
    ("The {finding} measures {size} cm.", "measurement"),
    ("Findings are consistent with {diagnosis}.", "impression"),
    ("{finding} is noted, possibly representing {diagnosis}.", "impression"),
    ("Mild {finding} is present.", "finding"),
    ("The {location} appears within normal limits.", "finding"),
]

FINDINGS = ["consolidation", "opacity", "effusion", "atelectasis", "pneumothorax",
            "infiltrate", "nodule", "mass", "cardiomegaly", "hyperinflation"]
LOCATIONS = ["right lower lobe", "left upper lobe", "bilateral lung bases",
             "right hemithorax", "left costophrenic angle", "mediastinum", "right hilum"]
DIAGNOSES = ["pneumonia", "heart failure", "COPD", "pulmonary edema", "lung cancer",
             "pleural effusion", "atelectasis"]
SIZES = [str(round(random.uniform(0.5, 4.0), 1)) for _ in range(20)]
LABELS = ["supported", "uncertain", "not_assessable", "needs_review"]

OVERCONFIDENT_PHRASES = [
    ("There is definitely", "There appears to be"),
    ("This is consistent with", "Findings may be consistent with"),
    ("Clearly shows", "Suggests"),
    ("No doubt", "Possibly"),
    ("Confirms", "May suggest"),
]


def make_claim(i):
    template, ctype = random.choice(CLAIM_TEMPLATES)
    text = template.format(
        finding=random.choice(FINDINGS),
        location=random.choice(LOCATIONS),
        size=random.choice(SIZES),
        diagnosis=random.choice(DIAGNOSES),
    )
    return {"claim_id": f"c{i+1}", "text": text,
            "sentence_span": {"start": i*60, "end": i*60+len(text)}, "claim_type": ctype}


def make_alignment_example(n_claims=4):
    claims = [make_claim(i) for i in range(n_claims)]
    alignments = []
    for claim in claims:
        label = random.choices(LABELS, weights=[0.5, 0.25, 0.15, 0.1])[0]
        alignments.append({
            "claim_id": claim["claim_id"],
            "label": label,
            "evidence": f"Visual evidence {'supports' if label=='supported' else 'does not clearly support'} this claim.",
            "confidence": round(random.uniform(0.5, 0.95), 2),
            "related_finding_ids": [f"f{random.randint(1,3)}"],
            "claim_text": claim["text"],
        })
    return {"claims": claims, "alignments": alignments}


def make_json_compliance_pair():
    example = make_alignment_example()
    claims_json = json.dumps(example["claims"], indent=2)
    schema = json.dumps({"type": "object", "required": ["alignments"],
                         "properties": {"alignments": {"type": "array"}}})
    prompt = f"Align the following claims to image findings.\nClaims:\n{claims_json}\nRespond with JSON matching: {schema}"
    good = json.dumps({"alignments": example["alignments"]}, indent=2)
    return {"prompt": prompt, "good": good}


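def make_uncertainty_pair():
    """Return an overconfident sentence and a calibrated rewrite of the same claim."""
    claim_text = make_claim(0)["text"]
    # Default: the calibrated version is the claim without the overconfident prefix.
    calibrated = claim_text
    prefix = random.choice(["Definitely", "Clearly", "Obviously", ""])
    original = f"{prefix} {claim_text}".strip()
    for over, cal in OVERCONFIDENT_PHRASES:
        # Locate the phrase case-insensitively so the check and the rewrite agree.
        idx = original.lower().find(over.lower())
        if idx != -1:
            calibrated = original[:idx] + cal + original[idx + len(over):]
            break
    return {"overconfident": original, "calibrated": calibrated}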


def generate_dataset(n_train=200, n_eval=50):
    # JSON schema compliance
    with open(OUTPUT_DIR / "train.jsonl", "w") as f:
        for _ in range(n_train):
            pair = make_json_compliance_pair()
            f.write(json.dumps({"prompt": pair["prompt"], "completion": pair["good"]}) + "\n")
    with open(OUTPUT_DIR / "eval.jsonl", "w") as f:
        for _ in range(n_eval):
            pair = make_json_compliance_pair()
            f.write(json.dumps({"prompt": pair["prompt"], "completion": pair["good"]}) + "\n")

    # Uncertainty calibration
    with open(OUTPUT_DIR / "uncertainty_train.jsonl", "w") as f:
        for _ in range(n_train):
            pair = make_uncertainty_pair()
            f.write(json.dumps({"input": pair["overconfident"], "output": pair["calibrated"]}) + "\n")
    with open(OUTPUT_DIR / "uncertainty_eval.jsonl", "w") as f:
        for _ in range(n_eval):
            pair = make_uncertainty_pair()
            f.write(json.dumps({"input": pair["overconfident"], "output": pair["calibrated"]}) + "\n")

    print(f"Generated {n_train} train + {n_eval} eval pairs")


generate_dataset()
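
The generator above draws from the global `random` state without a seed, so each run produces a different dataset. The following is a minimal sketch of how a seeded generator could make the synthetic data repeatable. `sample_claims`, the trimmed vocabulary, and the seed value are hypothetical and only illustrate the idea; they are not part of the notebook.

```python
import random

# Hypothetical, trimmed-down vocabulary standing in for FINDINGS / LOCATIONS.
FINDINGS = ["consolidation", "opacity", "effusion"]
LOCATIONS = ["right lower lobe", "left upper lobe", "mediastinum"]


def sample_claims(n, seed):
    """Draw n synthetic claim sentences from a private, seeded generator."""
    rng = random.Random(seed)  # private RNG: leaves the global random state untouched
    return [
        f"There is {rng.choice(FINDINGS)} in the {rng.choice(LOCATIONS)}."
        for _ in range(n)
    ]


first = sample_claims(5, seed=42)
second = sample_claims(5, seed=42)
print(first == second)  # same seed, same claims
```

In the notebook itself, calling `random.seed(...)` once before `SIZES` is built would have a similar effect on the global generator.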

In [None]:
# ============================================================
# Format data for Gemma chat template
# Adapted from: hf_lora/dataset/format.py
# ============================================================

def format_for_chat(input_path, output_path):
    """Convert prompt/completion pairs to Gemma chat template."""
    count = 0
    with open(input_path) as fin, open(output_path, "w") as fout:
        for line in fin:
            pair = json.loads(line.strip())
            prompt = pair.get("prompt", pair.get("input", ""))
            completion = pair.get("completion", pair.get("output", ""))
            text = (
                f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
                f"<start_of_turn>model\n{completion}<end_of_turn>"
            )
            fout.write(json.dumps({"text": text}) + "\n")
            count += 1
    print(f"Formatted {count} pairs -> {output_path}")


# Format all datasets
for split in ["train", "eval"]:
    format_for_chat(OUTPUT_DIR / f"{split}.jsonl", OUTPUT_DIR / f"{split}_chat.jsonl")
for split in ["train", "eval"]:
    format_for_chat(OUTPUT_DIR / f"uncertainty_{split}.jsonl", OUTPUT_DIR / f"uncertainty_{split}_chat.jsonl")
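
To make the formatted records concrete, here is a small sketch of the turn structure that `format_for_chat` writes, together with the matching inference-time prefix. `to_gemma_chat` and `generation_prefix` are hypothetical helper names used only for illustration. The sketch assumes the tokenizer adds any beginning-of-sequence token itself.

```python
def to_gemma_chat(prompt, completion):
    """Wrap one prompt/completion pair in Gemma-style turn markers."""
    return (
        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
        f"<start_of_turn>model\n{completion}<end_of_turn>"
    )


def generation_prefix(prompt):
    """Build an inference prefix: the user turn followed by an open model turn."""
    return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"


record = to_gemma_chat("Align the claims.", '{"alignments": []}')
print(record)

# The training text starts with exactly the prefix a prompt would use at inference.
print(record.startswith(generation_prefix("Align the claims.")))
```

Prompts wrapped this way at inference time match the text the adapter saw during training.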

In [None]:
# ============================================================
# Load MedGemma with 8-bit quantization (fits in T4 16GB)
# ============================================================
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

MODEL_ID = "google/medgemma-4b-it"

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

print(f"Loading {MODEL_ID} with 8-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer

print(f"Model loaded. Device: {model.device}")
print(f"GPU memory: {torch.cuda.memory_allocated() / 1e9:.1f} GB")
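
A rough back-of-the-envelope sketch of why 8-bit loading helps on a 16 GB T4. It counts weight storage only and assumes about 4 billion parameters (taken from the "4B" in the model name). Activations, gradients, optimizer state, and quantization overhead are not included, so the measured figure printed above will differ.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 4e9  # assumption: roughly 4 billion parameters

for label, bits in [("fp16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label}: {weight_memory_gb(N_PARAMS, bits):.1f} GB")
```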

In [None]:
# ============================================================
# Configure and apply LoRA adapter
# ============================================================
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
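
The arithmetic behind the small trainable-parameter count can be sketched as follows. For a weight matrix of shape `d_out x d_in`, LoRA adds two low-rank factors with `r * (d_in + d_out)` parameters in total, and the update is scaled by `lora_alpha / r`. The layer shape below is a hypothetical value for illustration, not the real MedGemma projection size.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters added by one LoRA pair: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


R, ALPHA = 8, 16          # values from the notebook's LoraConfig
D_IN, D_OUT = 2048, 2048  # hypothetical layer shape, for illustration only

added = lora_param_count(D_IN, D_OUT, R)
full = D_IN * D_OUT

print(f"LoRA parameters per adapted matrix: {added}")
print(f"Fraction of the full matrix: {added / full:.4f}")
print(f"Scaling factor alpha / r: {ALPHA / R}")
```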

In [None]:
# ============================================================
# Train with SFTTrainer
# ============================================================
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

CHECKPOINT_DIR = "/kaggle/working/rtl-lora-checkpoint"

dataset = load_dataset("json", data_files={
    "train": str(OUTPUT_DIR / "train_chat.jsonl"),
    "validation": str(OUTPUT_DIR / "eval_chat.jsonl"),
})

print(f"Train: {len(dataset['train'])} samples, Eval: {len(dataset['validation'])} samples")

training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    bf16=False,
    fp16=True,
    max_grad_norm=1.0,
    logging_steps=10,
    save_steps=50,
    eval_steps=50,
    save_total_limit=2,
    eval_strategy="steps",  # called `evaluation_strategy` in older transformers releases
    load_best_model_at_end=True,
    report_to="none",
)

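# Note: these keyword arguments follow older TRL releases. Newer releases expect
# `processing_class` instead of `tokenizer` and take the text field, sequence-length
# limit, and `packing` through `SFTConfig`.
trainer = SFTTrainer(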
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=1024,
    packing=False,
)

print("Starting training...")
trainer.train()
trainer.save_model(CHECKPOINT_DIR)
print(f"Training complete. Model saved to {CHECKPOINT_DIR}")

# Save training summary
import json
summary = {
    "model_id": MODEL_ID,
    "lora_r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "train_samples": len(dataset["train"]),
    "eval_samples": len(dataset["validation"]),
    "epochs": 3,
    "quantization": "8-bit",
}
with open(f"{CHECKPOINT_DIR}/training_summary.json", "w") as f:
    json.dump(summary, f, indent=2)
print("Training summary saved.")
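
The training settings above imply a fairly short schedule. The sketch below works through the arithmetic for the notebook's values (200 samples, batch size 1, 8 accumulation steps, 3 epochs, checkpoints every 50 steps). `optimizer_steps` is a hypothetical helper, and it assumes a single GPU and floor division of batches by accumulation steps. The exact rounding may differ between Trainer versions.

```python
import math


def optimizer_steps(n_samples, per_device_batch, grad_accum, epochs, n_devices=1):
    """Estimate the total number of optimizer updates for a training run."""
    batches_per_epoch = math.ceil(n_samples / (per_device_batch * n_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs


effective_batch = 1 * 8                  # per-device batch x accumulation steps
total = optimizer_steps(200, 1, 8, 3)    # notebook values
eval_points = list(range(50, total + 1, 50))  # steps where eval_steps=50 triggers

print(f"Effective batch size: {effective_batch}")
print(f"Total optimizer steps: {total}")
print(f"Evaluation/checkpoint steps: {eval_points}")
```

With these values only one intermediate checkpoint is evaluated, which limits how much `load_best_model_at_end` can choose between.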

In [None]:
# ============================================================
# Evaluate the LoRA-adapted model (JSON validity and overconfident language)
# Adapted from: hf_lora/eval/eval_lora_before_after.py
# ============================================================
import re

OVERCONFIDENT_PATTERNS = [
    r"\bdefinitely\b", r"\bclearly\b", r"\bobviously\b", r"\bconfirms\b",
    r"\bno doubt\b", r"\bwithout question\b", r"\bconclusively\b",
]


def is_json_valid(text):
    try:
        data = json.loads(text)
        return "alignments" in data and isinstance(data["alignments"], list)
    except Exception:
        return False


def try_extract_json(text):
    try:
        return json.loads(text)
    except Exception:
        m = re.search(r"\{.*\}", text, re.DOTALL)
        if m:
            try:
                return json.loads(m.group(0))
            except Exception:
                pass
    return None


def has_overconfident_language(text):
    return any(re.search(p, text.lower()) for p in OVERCONFIDENT_PATTERNS)


def evaluate_model(model_fn, test_cases):
    n = len(test_cases)
    json_valid = 0
    overconfident = 0
    schema_repaired = 0

    for case in test_cases:
        prompt = case["prompt"]
        try:
            output = model_fn(prompt)
        except Exception:
            output = ""

        if is_json_valid(output):
            json_valid += 1
        else:
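            parsed = try_extract_json(output)
            # Output that only parses after extraction counts as valid but is tracked as repaired.
            if isinstance(parsed, dict) and isinstance(parsed.get("alignments"), list):
                json_valid += 1
                schema_repaired += 1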

        if has_overconfident_language(output):
            overconfident += 1

    return {
        "n": n,
        "json_valid_rate": json_valid / n,
        "overconfidence_rate": overconfident / n,
        "schema_repair_rate": schema_repaired / n,
    }


# Load test cases
test_cases = []
with open(OUTPUT_DIR / "eval.jsonl") as f:
    for line in f:
        test_cases.append(json.loads(line.strip()))
test_cases = test_cases[:50]


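# Create inference function
def make_inference_fn(m):
    def fn(prompt):
        # Wrap the prompt in the same Gemma turn markers used for the training text,
        # ending with an open model turn so generation continues as the model's reply.
        chat_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        inputs = tokenizer(chat_prompt, return_tensors="pt").to(m.device)
        with torch.no_grad():
            out = m.generate(**inputs, max_new_tokens=512, do_sample=False)
        # Decode only the newly generated tokens, not the prompt.
        return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return fn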


# Evaluate LoRA model
print("Evaluating LoRA model on", len(test_cases), "cases...")
lora_fn = make_inference_fn(model)
lora_metrics = evaluate_model(lora_fn, test_cases)

print("\n=== LoRA Model Results ===")
print(f"JSON Valid Rate:      {lora_metrics['json_valid_rate']:.1%}")
print(f"Overconfidence Rate:  {lora_metrics['overconfidence_rate']:.1%}")
print(f"Schema Repair Rate:   {lora_metrics['schema_repair_rate']:.1%}")

# Save metrics
with open(f"{CHECKPOINT_DIR}/eval_metrics.json", "w") as f:
    json.dump({"lora_model": lora_metrics, "test_set_size": len(test_cases)}, f, indent=2)
print(f"\nMetrics saved to {CHECKPOINT_DIR}/eval_metrics.json")
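
The repair path in the evaluation can be illustrated on canned strings, without loading a model. The sketch below mirrors the strict-parse-then-regex fallback used by `try_extract_json`. `extract_alignments` and the sample outputs are hypothetical and exist only for this example.

```python
import json
import re


def extract_alignments(text):
    """Return (payload, repaired): strict parse first, then the outermost {...} span."""
    try:
        return json.loads(text), False
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(0)), True
            except json.JSONDecodeError:
                pass
    return None, False


samples = [
    '{"alignments": []}',                                               # clean JSON
    'Here is the result:\n{"alignments": [{"claim_id": "c1"}]}\nDone.',  # JSON wrapped in prose
    "no json here",                                                     # unrecoverable
]
results = [extract_alignments(s) for s in samples]
for payload, repaired in results:
    print(payload, repaired)
```

Because the pattern is greedy, it spans from the first `{` to the last `}`. An output containing two separate JSON objects would therefore fail to parse and count as invalid.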

In [None]:
# ============================================================
# Push adapter to Hugging Face Hub
# ============================================================
from huggingface_hub import HfApi, upload_folder

REPO_ID = "outlawpink/rtl-medgemma-lora"

api = HfApi()
api.create_repo(repo_id=REPO_ID, exist_ok=True, repo_type="model")

# Generate model card
metrics = lora_metrics
model_card = f"""---
library_name: peft
base_model: google/medgemma-4b-it
tags:
  - medical
  - radiology
  - lora
  - peft
  - medgemma
  - rtl
license: apache-2.0
---

# RTL LoRA Adapter for MedGemma

LoRA adapter for `google/medgemma-4b-it` trained for the **Radiology Trust Layer (RTL)** project.

## What this adapter does

- Improves **JSON schema compliance** in structured output tasks
- Reduces **overconfident language** in uncertainty alignment tasks
- Trained on synthetic radiology QA pairs (no PHI)

## Usage

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoProcessor

base = AutoModelForCausalLM.from_pretrained("google/medgemma-4b-it")
model = PeftModel.from_pretrained(base, "{REPO_ID}")
model = model.merge_and_unload()
```

## Training Details

| Setting | Value |
|---------|-------|
| Base model | `google/medgemma-4b-it` |
| LoRA rank | 8 |
| LoRA alpha | 16 |
| Target modules | q_proj, v_proj |
| Training samples | {len(dataset['train'])} |
| Quantization | 8-bit (bitsandbytes) |
| Precision | fp16 |
| Framework | PEFT + TRL SFTTrainer |

## Evaluation

| Metric | Score |
|--------|-------|
| JSON Valid Rate | {metrics['json_valid_rate']:.1%} |
| Overconfidence Rate | {metrics['overconfidence_rate']:.1%} |
| Schema Repair Rate | {metrics['schema_repair_rate']:.1%} |

## Links

- [RTL Live Demo](https://huggingface.co/spaces/outlawpink/RadiologyTrustLayer)
- [GitHub Repository](https://github.com/carmmmm/RadiologyTrustLayer)

## Disclaimer

This adapter is for **research purposes only**. Not intended for clinical use.
"""

with open(f"{CHECKPOINT_DIR}/README.md", "w") as f:
    f.write(model_card)

# Upload
url = upload_folder(
    folder_path=CHECKPOINT_DIR,
    repo_id=REPO_ID,
    commit_message="Upload RTL LoRA adapter",
)
print(f"\nAdapter published: https://huggingface.co/{REPO_ID}")

## Done!

The LoRA adapter has been published to Hugging Face Hub.

**To use it in the RTL app:**

1. Go to your HF Space settings: https://huggingface.co/spaces/outlawpink/RadiologyTrustLayer/settings
2. Add variable: `RTL_LORA_ID=outlawpink/rtl-medgemma-lora`
3. Check the "Use RTL LoRA adapter" checkbox when running an audit

**Published model:** https://huggingface.co/outlawpink/rtl-medgemma-lora
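
How the RTL app consumes `RTL_LORA_ID` is not shown in this notebook. As a purely hypothetical sketch, an app could combine the environment variable with the checkbox state like this. `resolve_adapter_id` and its parameters are assumptions for illustration, not the app's actual code.

```python
import os


def resolve_adapter_id(use_adapter, env=os.environ):
    """Return the adapter repo ID to load, or None to run the base model."""
    adapter_id = env.get("RTL_LORA_ID", "").strip()
    if use_adapter and adapter_id:
        return adapter_id
    return None


# Checkbox ticked and variable set: the adapter ID is returned.
print(resolve_adapter_id(True, {"RTL_LORA_ID": "outlawpink/rtl-medgemma-lora"}))
# Checkbox ticked but variable missing: fall back to the base model.
print(resolve_adapter_id(True, {}))
```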