# MedAssist-Edge: Offline Agentic Clinical Co-Pilot
### Google MedGemma Kaggle Competition — Submission Notebook

> **⚠️ Clinical Decision Support Only.** All outputs are AI-generated and must be reviewed by a qualified clinician before any clinical action.

---

## Overview

| Agent | Role |
|---|---|
| SOAP Structuring Agent | Reorganises raw clinical notes into SOAP format |
| Differential Diagnosis Agent | Generates ranked DDx with evidence |
| Guideline Retrieval Agent (RAG) | Retrieves relevant local guideline passages |
| Patient Explanation Agent | Produces plain-language patient summary |

**Model: `google/medgemma-1.5-4b-it` — runs fully offline after download.**

**API note:** MedGemma 1.5 uses `AutoProcessor` + `AutoModelForImageTextToText`
and `processor.apply_chat_template()` — NOT `AutoTokenizer`/`AutoModelForCausalLM`.

## 1. Environment Setup

In [None]:
# transformers >= 4.50.0 required for Gemma 3 / MedGemma 1.5
!pip install -q 'transformers>=4.50.0' accelerate sentencepiece bitsandbytes faiss-cpu sentence-transformers pymupdf
print('Dependencies installed.')

In [None]:
import json, re, time, logging
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger('medassist')

ROOT = Path('/kaggle/working')
MODEL_CACHE  = ROOT / 'model_cache'
EMBED_CACHE  = ROOT / 'embed_cache'
VECTOR_STORE = ROOT / 'vector_store'
GUIDELINES_DIR = ROOT / 'guidelines'

for d in [MODEL_CACHE, EMBED_CACHE, VECTOR_STORE, GUIDELINES_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print('Paths configured.')

## 2. Load MedGemma 1.5 (Inference Engine)

MedGemma 1.5 requires:
- `AutoProcessor` (not `AutoTokenizer`)
- `AutoModelForImageTextToText` (not `AutoModelForCausalLM`)
- `processor.apply_chat_template()` for chat formatting
- `do_sample=False` by default (Jan 23 2026 model card update)

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

MODEL_ID = 'google/medgemma-1.5-4b-it'
USE_GPU  = torch.cuda.is_available()
print(f'GPU available: {USE_GPU}')

print(f'Loading processor: {MODEL_ID} ...')
processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=str(MODEL_CACHE))

print('Loading model weights ...')
t0 = time.time()
if USE_GPU:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        cache_dir=str(MODEL_CACHE),
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
else:
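    # CPU path: int8 quantization via bitsandbytes. This assumes a bitsandbytes
    # build with CPU support; older releases support CUDA devices only.
    # The bnb_4bit_* options apply only to 4-bit loading, so none is set here.
    bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)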
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        cache_dir=str(MODEL_CACHE),
        quantization_config=bnb_cfg,
        low_cpu_mem_usage=True,
    )

print(f'Model loaded in {time.time()-t0:.1f}s  device={model.device}')

In [None]:
# ── Shared inference utilities ────────────────────────────────────────────────

def generate(messages: list) -> str:
    """Run MedGemma 1.5 text-only inference from a messages list."""
    dtype = torch.bfloat16 if model.device.type != 'cpu' else torch.float32
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors='pt',
    ).to(model.device, dtype=dtype)
    input_len = inputs['input_ids'].shape[-1]
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=768,
            do_sample=False,          # greedy — deterministic (model card default)
            repetition_penalty=1.15,
        )
    new_tokens = output_ids[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()


def make_messages(system: str, user: str) -> list:
    """Build a MedGemma 1.5 messages list for text-only inference."""
    return [
        {'role': 'system', 'content': [{'type': 'text', 'text': system}]},
        {'role': 'user',   'content': [{'type': 'text', 'text': user}]},
    ]


def parse_json(raw: str) -> dict:
    """Extract JSON from model output, handling markdown fences."""
    text = re.sub(r'```(?:json)?', '', raw).strip().rstrip('`').strip()
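    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span if extra text surrounds the JSON.
        m = re.search(r'\{[\s\S]+\}', text)
        try:
            parsed = json.loads(m.group()) if m else {}
        except json.JSONDecodeError:
            parsed = {}
    # Agents call .get() on the result, so only a JSON object is usable.
    return parsed if isinstance(parsed, dict) else {}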


print('Inference utilities ready.')
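
The greedy `\{[\s\S]+\}` fallback in `parse_json` spans from the first `{` to the last `}`, so it can fail when the model appends text containing braces after the JSON object. A minimal alternative sketch follows. It assumes a hypothetical helper named `first_json_object` (not part of this notebook) and uses the standard library's `json.JSONDecoder.raw_decode` to try each `{` in turn:

```python
import json


def first_json_object(raw: str) -> dict:
    """Return the first decodable JSON object found in `raw`, or {} if none.

    Hypothetical helper for illustration; not part of the notebook pipeline.
    """
    decoder = json.JSONDecoder()
    pos = raw.find('{')
    while pos != -1:
        try:
            # raw_decode parses one JSON value starting at `pos` and ignores
            # whatever follows it (closing fences, trailing commentary).
            obj, _end = decoder.raw_decode(raw, pos)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        pos = raw.find('{', pos + 1)
    return {}


sample_output = '```json\n{"subjective": "dry cough"}\n```\nNote: {unverified}'
print(first_json_object(sample_output))  # {'subjective': 'dry cough'}
```

Because decoding stops at the end of the first valid object, trailing markdown fences or commentary do not need to be stripped first.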

## 3. RAG Pipeline Setup

In [None]:
# Write sample guideline content
sample = """IDIOPATHIC PULMONARY FIBROSIS (IPF) — ATS/ERS Guidelines 2022

DIAGNOSIS: IPF requires UIP pattern on HRCT (honeycombing, traction bronchiectasis) plus
exclusion of other ILD causes.

WORKUP: PFTs (restrictive pattern: FVC↓, FEV1/FVC normal/↑, DLCO↓), 6MWT, echocardiography
to exclude pulmonary hypertension. Serological panel (ANA, RF, anti-CCP) for CTD-ILD.

MANAGEMENT: Antifibrotic therapy (nintedanib or pirfenidone) slows FVC decline.
Supplemental O2 if SpO2 <88%. Lung transplant referral at diagnosis. Pulmonary rehab.
AVOID: prednisone + azathioprine + NAC (shown harmful).

MONITORING: FVC every 3-6 months; >10% decline = significant progression.
DLCO every 6-12 months.

COMMUNITY-ACQUIRED PNEUMONIA (CAP) — IDSA/ATS 2019

DIAGNOSIS: New infiltrate + fever/cough/leukocytosis. CURB-65 guides site of care.
CURB-65 ≥2 = consider hospitalisation.

WORKUP: CXR, pulse oximetry. Blood cultures x2 before antibiotics in severe CAP.
Urinary antigens (Legionella, S. pneumoniae) in severe cases. Procalcitonin for
antibiotic duration guidance.

MANAGEMENT: Empirical antibiotics (confirm with clinician and local formulary).
Duration 5 days for non-severe CAP if clinically improving.
"""
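# The sample contains non-ASCII symbols (↓, ≥), so write it as UTF-8 to match the reader below.
(GUIDELINES_DIR / 'sample_guidelines.txt').write_text(sample, encoding='utf-8')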
print('Sample guideline written.')

In [None]:
from sentence_transformers import SentenceTransformer
import faiss, numpy as np

EMBED_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
TOP_K = 4

print(f'Loading embedding model: {EMBED_MODEL} ...')
embed_model = SentenceTransformer(EMBED_MODEL, cache_folder=str(EMBED_CACHE))

def chunk_text(text, source, size=512, overlap=64):
    """Split text into overlapping character windows (stride = size - overlap)."""
    chunks, start, cid = [], 0, 0
    while start < len(text):
        t = text[start:start+size].strip()
        if t:
            chunks.append({'source': source, 'chunk_id': cid, 'text': t})
            cid += 1
        start += size - overlap
    return chunks

all_chunks = []
for f in sorted(GUIDELINES_DIR.glob('*.txt')):  # sorted so chunk order is reproducible
    all_chunks.extend(chunk_text(f.read_text(encoding='utf-8', errors='ignore'), f.name))

embeddings = embed_model.encode(
    [c['text'] for c in all_chunks], normalize_embeddings=True
).astype(np.float32)

index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, str(VECTOR_STORE / 'guidelines.faiss'))
with open(VECTOR_STORE / 'metadata.json', 'w', encoding='utf-8') as fh:
    json.dump(all_chunks, fh)

print(f'RAG index: {len(all_chunks)} chunks')

def retrieve(query: str, k: int = TOP_K):
    q = embed_model.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, idxs = index.search(q, k)
    return [{'source': all_chunks[i]['source'], 'text': all_chunks[i]['text'],
              'score': float(s)} for s, i in zip(scores[0], idxs[0]) if i >= 0]

print('RAG pipeline ready.')
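
`chunk_text` slides a character window of length `size` forward by `size - overlap` characters at a time, so with the defaults each chunk starts 448 characters after the previous one. The sketch below reproduces only that offset arithmetic; `chunk_starts` is a hypothetical name used for illustration:

```python
def chunk_starts(text_len: int, size: int = 512, overlap: int = 64) -> list:
    """Start offsets of the character windows used by a sliding-window chunker."""
    if overlap >= size:
        # A stride of zero or less would never advance through the text.
        raise ValueError('overlap must be smaller than size')
    stride = size - overlap
    return list(range(0, text_len, stride))


starts = chunk_starts(1000)
print(starts)  # [0, 448, 896]

# Each window is clipped at the end of the text.
print([(s, min(s + 512, 1000)) for s in starts])
```

The guard on `overlap >= size` is an addition in this sketch; `chunk_text` as written has no such check and would loop indefinitely with those arguments. `chunk_text` also skips windows that are empty after stripping whitespace, so the real chunk count can be lower than the number of offsets.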

## 4. Agent Definitions

Each agent uses the MedGemma 1.5 `make_messages()` + `generate()` pattern.
No raw `<start_of_turn>` strings — the processor handles chat formatting.
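
The shared pattern can be exercised without loading the model. The sketch below copies the message shape produced by `make_messages()` and swaps `generate()` for a stub; `fake_generate` and `run_agent` are hypothetical names introduced only for this illustration:

```python
import json


def make_messages(system: str, user: str) -> list:
    """Same message shape as the notebook's make_messages()."""
    return [
        {'role': 'system', 'content': [{'type': 'text', 'text': system}]},
        {'role': 'user', 'content': [{'type': 'text', 'text': user}]},
    ]


def fake_generate(messages: list) -> str:
    """Stand-in for generate(): returns a canned JSON string, no model needed."""
    return json.dumps({'summary': 'stub', 'n_messages': len(messages)})


def run_agent(system: str, user: str, generate_fn=fake_generate) -> dict:
    """Build messages, call the generator, parse JSON, keep the raw text."""
    raw = generate_fn(make_messages(system, user))
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = {}
    if not isinstance(parsed, dict):
        parsed = {}
    # Default values keep downstream code working when parsing fails.
    return {'summary': parsed.get('summary', ''), 'raw': raw}


result = run_agent('You are a test assistant.', 'Say hello.')
print(result['summary'])  # stub
```

Passing the generator as a parameter is a design choice of this sketch. It lets prompt construction and output handling be tested on machines that cannot hold the model weights.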

In [None]:
# ─── AGENT 1: SOAP Structuring ────────────────────────────────────────────────

SOAP_SYSTEM = (
    "You are a clinical documentation assistant. Reorganise the clinician's raw "
    "notes into SOAP format.\n\nRULES:\n"
    "1. Use ONLY information explicitly stated in the input. Do NOT add or infer anything.\n"
    "2. Plan section: SUGGESTIONS ONLY. Never write orders, prescriptions, or dosages.\n"
    "3. If a section is absent, write 'Not documented.'\n"
    "4. Output ONLY valid JSON:\n"
    '{"subjective": "...", "objective": "...", "assessment": "...", "plan_suggestions": "..."}'
)

def soap_agent(clinical_notes, lab_results='', radiology_text='', age=None, sex=None):
    user = (
        f"Organise into SOAP format.\n\n"
        f"CLINICAL NOTES:\n{clinical_notes}\n\n"
        f"LAB RESULTS:\n{lab_results or 'Not provided'}\n\n"
        f"RADIOLOGY:\n{radiology_text or 'Not provided'}\n\n"
        f"DEMOGRAPHICS: Age={age or 'N/A'} Sex={sex or 'N/A'}\n\n"
        "Return only the JSON object."
    )
    raw = generate(make_messages(SOAP_SYSTEM, user))
    parsed = parse_json(raw)
    return {
        'subjective': parsed.get('subjective', 'Not documented.'),
        'objective':  parsed.get('objective',  'Not documented.'),
        'assessment': parsed.get('assessment', 'Not documented.'),
        'plan_suggestions': parsed.get('plan_suggestions', 'Not documented.'),
        'raw': raw
    }

print('Agent 1 (SOAP) defined.')
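
Rule 2 in `SOAP_SYSTEM` (no orders, prescriptions, or dosages) is enforced only through the prompt. One possible post-hoc check is to scan `plan_suggestions` for dosage-like strings. The sketch below is an illustration, not a validated clinical safeguard: the pattern and the name `find_dosage_like` are assumptions, and the unit list is deliberately small.

```python
import re

# Illustrative pattern only: a number followed by a common dose unit,
# optionally followed by a "/kg", "/day" or "/dose" suffix.
DOSE_PATTERN = re.compile(
    r'\b\d+(?:\.\d+)?\s*(?:mg|mcg|g|ml|units?|iu)\b(?:\s*/\s*(?:kg|day|dose))?',
    re.IGNORECASE,
)


def find_dosage_like(text: str) -> list:
    """Return every substring of `text` that looks like a drug dose."""
    return [m.group(0) for m in DOSE_PATTERN.finditer(text)]


print(find_dosage_like('Consider antifibrotic therapy; discuss with clinician.'))  # []
print(find_dosage_like('Start prednisone 40 mg daily.'))  # ['40 mg']
```

A non-empty result could be logged or used to blank the field before display. The pattern also matches laboratory values such as `11.4 g/dL`, so it is suited to the plan text only, not the objective section.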

In [None]:
# ─── AGENT 2: Differential Diagnosis ─────────────────────────────────────────

DDX_SYSTEM = (
    "You are a clinical reasoning assistant for differential diagnosis.\n\nRULES:\n"
    "1. Generate a RANKED differential (max 5 conditions).\n"
    "2. For each: condition, likelihood (High/Moderate/Low), supporting_features, against_features.\n"
    "3. Base reasoning ONLY on provided information. Use hedged language.\n"
    "4. Do NOT confirm any diagnosis. Do NOT recommend treatments or medications.\n"
    "5. Output ONLY valid JSON:\n"
    '{"diagnoses": [{"rank":1,"condition":"...","likelihood":"High",'
    '"supporting_features":"...","against_features":"..."}],'
    '"reasoning_summary": "..."}'
)

def ddx_agent(soap, age=None, sex=None):
    user = (
        f"Generate ranked differential.\n\n"
        f"SUBJECTIVE: {soap['subjective']}\n"
        f"OBJECTIVE: {soap['objective']}\n"
        f"ASSESSMENT: {soap['assessment']}\n"
        f"DEMOGRAPHICS: Age={age or 'N/A'} Sex={sex or 'N/A'}"
    )
    raw = generate(make_messages(DDX_SYSTEM, user))
    parsed = parse_json(raw)
    return {
        'diagnoses': parsed.get('diagnoses', [])[:5],
        'reasoning_summary': parsed.get('reasoning_summary', ''),
        'raw': raw
    }

print('Agent 2 (DDx) defined.')
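
Later cells index each diagnosis entry directly (`d['condition']`, `d['likelihood']`, `d['rank']`), which raises `KeyError` if the model omits a field. A minimal sketch of a normalisation step follows. `normalise_diagnoses` and the default values are hypothetical, and the renumbering assumes the model's list order already reflects its ranking:

```python
REQUIRED_FIELDS = {
    'condition': 'Unspecified',
    'likelihood': 'Unknown',
    'supporting_features': '',
    'against_features': '',
}


def normalise_diagnoses(entries, max_items: int = 5) -> list:
    """Keep dict entries only, fill missing fields, and renumber ranks from 1."""
    if not isinstance(entries, list):
        return []
    cleaned = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue  # drop strings, numbers, or nulls emitted by the model
        item = {key: entry.get(key, default) for key, default in REQUIRED_FIELDS.items()}
        item['rank'] = len(cleaned) + 1
        cleaned.append(item)
        if len(cleaned) == max_items:
            break
    return cleaned


raw_entries = [{'condition': 'IPF', 'likelihood': 'High'}, 'garbage', {'rank': 7}]
for d in normalise_diagnoses(raw_entries):
    print(d['rank'], d['condition'], d['likelihood'])
```

Applied to `parsed.get('diagnoses', [])` inside `ddx_agent`, a step like this would let the guideline, patient, and results cells rely on every key being present.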

In [None]:
# ─── AGENT 3: Guideline Retrieval (RAG) ──────────────────────────────────────

GUIDELINE_SYSTEM = (
    "You are a clinical guideline synthesis assistant.\n\nRULES:\n"
    "1. Synthesise ONLY from provided retrieved excerpts. Do NOT use training knowledge.\n"
    "2. Attribute every recommendation to its source.\n"
    "3. Organise by category: Workup, Management, Monitoring, Follow-up.\n"
    "4. Do NOT recommend specific drug doses. Do NOT issue clinical orders.\n"
    "5. Output ONLY valid JSON:\n"
    '{"recommendations":[{"category":"...","recommendation":"...",'
    '"source":"...","confidence":"Direct|Inferred|Low-evidence"}],'
    '"retrieved_sources":["..."]}'
)

def guideline_agent(soap, ddx):
    conditions = ', '.join(d['condition'] for d in ddx['diagnoses'][:3])
    query = f"{soap['assessment']} {conditions} {soap['subjective'][:200]}"
    chunks = retrieve(query)
    chunk_block = '\n\n---\n\n'.join(
        f"[{i+1}] SOURCE: {c['source']}\n{c['text']}" for i, c in enumerate(chunks)
    ) or 'No guideline excerpts retrieved.'
    user = (
        f"CLINICAL SUMMARY:\nAssessment: {soap['assessment']}\nTop DDx: {conditions}\n\n"
        f"RETRIEVED GUIDELINE EXCERPTS:\n{chunk_block}\n\n"
        "Synthesise recommendations from excerpts ONLY."
    )
    raw = generate(make_messages(GUIDELINE_SYSTEM, user))
    parsed = parse_json(raw)
    return {
        'recommendations': parsed.get('recommendations', []),
        'retrieved_sources': [c['source'] for c in chunks],
        'raw': raw
    }

print('Agent 3 (Guidelines) defined.')
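
Rule 2 in `GUIDELINE_SYSTEM` asks the model to attribute every recommendation to its source, but nothing checks that the named source was actually retrieved. The sketch below separates recommendations whose `source` field mentions a retrieved file name from those that do not. `split_by_attribution` is a hypothetical helper, and the case-insensitive substring match is an assumption about how the model writes the field:

```python
def split_by_attribution(recommendations: list, retrieved_sources: list):
    """Separate recommendations that cite a retrieved source from the rest."""
    known = {s.lower() for s in retrieved_sources if s}
    attributed, unattributed = [], []
    for rec in recommendations:
        source = str(rec.get('source', '')).lower() if isinstance(rec, dict) else ''
        if any(name in source for name in known):
            attributed.append(rec)
        else:
            unattributed.append(rec)
    return attributed, unattributed


recs = [
    {'recommendation': 'PFTs and 6MWT', 'source': '[1] sample_guidelines.txt'},
    {'recommendation': 'Bronchoscopy', 'source': 'general knowledge'},
]
ok, flagged = split_by_attribution(recs, ['sample_guidelines.txt'])
print(len(ok), len(flagged))  # 1 1
```

Flagged items could be hidden or shown with a warning. Either choice keeps the output tied to the local guideline excerpts.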

In [None]:
# ─── AGENT 4: Patient Explanation ────────────────────────────────────────────

PATIENT_SYSTEM = (
    "You are a patient communication assistant.\n\nRULES:\n"
    "1. Write at 6th-grade reading level. Explain medical terms in brackets.\n"
    "2. Do NOT confirm any diagnosis. Say 'your doctor is considering...'\n"
    "3. Do NOT mention medications or specific procedures.\n"
    "4. Be warm, empathetic, non-alarming.\n"
    "5. End with next-steps encouraging patient to speak with their doctor.\n"
    "6. Output ONLY valid JSON:\n"
    '{"summary":"...","key_points":["..."],"next_steps_suggestion":"..."}'
)

def patient_agent(soap, ddx):
    ddx_summary = '\n'.join(
        f"- {d['condition']} is being considered ({d['likelihood']} likelihood)"
        for d in ddx['diagnoses'][:3]
    ) or 'No specific conditions documented yet.'
    user = (
        f"Translate to plain patient-friendly language.\n\n"
        f"SYMPTOMS/FINDINGS: {soap['subjective']}\n{soap['objective']}\n\n"
        f"WORKING ASSESSMENT: {soap['assessment']}\n\n"
        f"POSSIBILITIES (not confirmed): {ddx_summary}"
    )
    raw = generate(make_messages(PATIENT_SYSTEM, user))
    parsed = parse_json(raw)
    return {
        'summary': parsed.get('summary', 'Please speak with your doctor.'),
        'key_points': parsed.get('key_points', [])[:5],
        'next_steps_suggestion': parsed.get('next_steps_suggestion', 'Please discuss with your doctor.'),
        'raw': raw
    }

print('Agent 4 (Patient) defined.')

## 5. Full Pipeline — Demo Case

**Case:** 45-year-old female, progressive dyspnoea, restrictive spirometry, HRCT honeycombing

In [None]:
DEMO_CASE = dict(
    clinical_notes=(
        "45-year-old female presenting with a 3-week history of progressive dyspnea on exertion, "
        "dry cough, and fatigue. No fever, no chest pain, no haemoptysis. Non-smoker. No recent "
        "travel. Works in a textile factory. On examination: RR 22/min, SpO2 91% on room air, "
        "bilateral fine inspiratory crackles at lung bases. No clubbing, no peripheral oedema."
    ),
    lab_results=(
        "CBC: WBC 8.2, Hgb 11.4 g/dL, Plt 310. LDH 280 U/L (ref <225). "
        "ESR 68 mm/hr. CRP 2.1 mg/dL. Spirometry: FVC 68% predicted, FEV1/FVC 0.81."
    ),
    radiology_text=(
        "HRCT Chest: Bilateral ground-glass opacities in the lower lobes with honeycombing "
        "and traction bronchiectasis. No pleural effusion. No lymphadenopathy."
    ),
    age=45,
    sex='female'
)
print('Demo case loaded.')

In [None]:
DISCLAIMER = (
    '⚠️  CLINICAL DECISION SUPPORT ONLY — AI-generated output must be reviewed '
    'by a qualified clinician. Not a diagnosis, prescription, or medical order.'
)

t_start = time.time()
print('=' * 60)

print('[1/4] SOAP Structuring Agent ...')
t0 = time.time(); soap = soap_agent(**DEMO_CASE)
print(f'  {time.time()-t0:.1f}s')

print('[2/4] Differential Diagnosis Agent ...')
t0 = time.time(); ddx = ddx_agent(soap, age=DEMO_CASE['age'], sex=DEMO_CASE['sex'])
print(f'  {time.time()-t0:.1f}s — {len(ddx["diagnoses"])} entries')

print('[3/4] Guideline Retrieval Agent (RAG) ...')
t0 = time.time(); guidelines = guideline_agent(soap, ddx)
print(f'  {time.time()-t0:.1f}s — {len(guidelines["recommendations"])} recs')

print('[4/4] Patient Explanation Agent ...')
t0 = time.time(); patient = patient_agent(soap, ddx)
print(f'  {time.time()-t0:.1f}s')

print(f'\nTotal pipeline: {time.time()-t_start:.1f}s')
print('=' * 60)

## 6. Results

In [None]:
print(DISCLAIMER); print()
print('━'*60); print('AGENT 1 — SOAP NOTE'); print('━'*60)
for k, v in [('SUBJECTIVE', soap['subjective']), ('OBJECTIVE', soap['objective']),
              ('ASSESSMENT', soap['assessment']), ('PLAN SUGGESTIONS', soap['plan_suggestions'])]:
    print(f'\n{k}:\n{v}')

In [None]:
print('━'*60); print('AGENT 2 — DIFFERENTIAL DIAGNOSIS'); print('━'*60)
for d in ddx['diagnoses']:
    print(f"\n{d['rank']}. {d['condition']} — {d['likelihood']}")
    print(f"   FOR    : {d['supporting_features']}")
    print(f"   AGAINST: {d['against_features']}")
print(f"\nReasoning: {ddx['reasoning_summary']}")

In [None]:
print('━'*60); print('AGENT 3 — GUIDELINE RECOMMENDATIONS (RAG)'); print('━'*60)
print(f"Sources: {', '.join(guidelines['retrieved_sources'])}\n")
for r in guidelines['recommendations']:
    print(f"[{r['category']}] ({r['confidence']}) {r['recommendation']}")
    print(f"  Source: {r['source']}\n")

In [None]:
print('━'*60); print('AGENT 4 — PATIENT EXPLANATION'); print('━'*60)
print(patient['summary'])
print('\nKey Points:')
for i, p in enumerate(patient['key_points'], 1): print(f'  {i}. {p}')
print(f"\nNext Steps: {patient['next_steps_suggestion']}")

## 7. Summary

| Metric | Value |
|---|---|
| Model | google/medgemma-1.5-4b-it |
| Model class | AutoModelForImageTextToText |
| Processor | AutoProcessor |
| Precision / quantization | bfloat16 on GPU; int8 (bitsandbytes) on CPU |
| Agents executed | 4 |
| RAG | Local FAISS index |
| Internet dependency | None at inference time |
| Diagnoses confirmed | 0 — by design |
| Prescriptions issued | 0 — by design |

---

> **Research and competition use only. Not clinically validated. Not approved for clinical use.**