# MedGemma Spatial Transcriptomics Analysis

**Competition**: Google – MedGemma AI Impact Challenge  
**Repository**: https://github.com/harshameghadri/medgemma-spatial

## Pipeline
1. Load Visium spatial transcriptomics data (h5ad)
2. QC filtering + preprocessing
3. Spatial Leiden clustering
4. Tissue-agnostic cell type annotation (CellTypist immune model + z-score marker panels)
5. Spatial statistics (Moran's I, entropy, neighborhood enrichment)
6. Clinical report generation using MedGemma 4B

**Privacy**: All data processing runs locally — code, packages, and model weights are downloaded, but no patient data leaves the machine.

In [None]:
# ── Setup & dependencies ────────────────────────────────────────────────────
import subprocess, sys, os, warnings, json, time
from pathlib import Path
warnings.filterwarnings('ignore')

def pip_install(pkg: str) -> None:
    """Quietly install *pkg* with pip, using the interpreter running this notebook.

    Raises:
        subprocess.CalledProcessError: if the pip process exits with a non-zero status.
    """
    # `sys.executable -m pip` targets the environment of the current kernel
    # rather than whichever `pip` happens to be first on PATH.
    pip_command = [sys.executable, '-m', 'pip', 'install', '-q']
    subprocess.check_call(pip_command + [pkg])

# Make sure the optional third-party packages are importable, installing any
# that are missing. An ImportError is the only signal treated as "not installed".
import importlib

_installed_any = False
for pkg in ['celltypist', 'squidpy']:
    try:
        importlib.import_module(pkg)
    except ImportError:
        print(f'Installing {pkg}...')
        pip_install(pkg)
        _installed_any = True

if _installed_any:
    # Packages installed while the interpreter is running may not be visible to
    # the import system until its finder caches are refreshed.
    importlib.invalidate_caches()

print('Dependencies ready')

In [None]:
# ── HuggingFace token (required for MedGemma; falls back to demo mode) ──────
# Lookup order: Kaggle secrets first, then the HF_TOKEN environment variable.
# An empty token is treated the same as a missing one, so DEMO_MODE always
# agrees with the message printed below.
HF_TOKEN = None
try:
    from kaggle_secrets import UserSecretsClient
    HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or None
except Exception:
    # Deliberately best-effort: either we are not on Kaggle (import fails) or
    # the secret is not attached to this notebook. Fall through to the env var.
    HF_TOKEN = None

if HF_TOKEN:
    # Export the token so libraries that read HF_TOKEN from the environment see it.
    os.environ['HF_TOKEN'] = HF_TOKEN
    print('HF_TOKEN loaded from Kaggle secrets')
else:
    HF_TOKEN = os.environ.get('HF_TOKEN') or None
    if HF_TOKEN:
        print('HF_TOKEN loaded from environment')
    else:
        print('No HF_TOKEN — demo mode will be used for report generation')

# True when no usable token was found; report generation then skips MedGemma.
DEMO_MODE = HF_TOKEN is None

In [None]:
# ── Clone repo and import src/ modules ────────────────────────────────────────
import importlib
import scanpy as sc
import numpy as np
import pandas as pd

# Default checkout location when running on Kaggle.
REPO_ROOT = Path('/kaggle/working/medgemma-spatial')

# Prefer a local checkout: the first of "." / ".." that contains the adapter module.
_local_root = next(
    (base for base in (Path('.'), Path('..'))
     if (base / 'src' / 'streamlit_adapter.py').exists()),
    None,
)
_src_found_locally = _local_root is not None

if _src_found_locally:
    REPO_ROOT = _local_root.resolve()
    print(f'Local src/ found at {REPO_ROOT}')
elif REPO_ROOT.exists():
    # A previous run already cloned the repo; reuse it as-is.
    print(f'Repo already at {REPO_ROOT}')
else:
    # No local checkout and nothing cloned yet — fetch a shallow copy.
    print('Cloning repo to /kaggle/working/medgemma-spatial...')
    clone_cmd = [
        'git', 'clone', '--depth=1',
        'https://github.com/harshameghadri/medgemma-spatial.git',
        str(REPO_ROOT),
    ]
    subprocess.run(clone_cmd, check=True)
    print('Clone complete')

# Put the repo root first on sys.path so `import src...` resolves to this checkout.
_repo_root_str = str(REPO_ROOT)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)

# Drop every cached `src` / `src.*` module so the imports below are resolved
# again against the sys.path set up above (e.g. after an earlier failed attempt).
_stale_modules = [
    name for name in tuple(sys.modules)
    if name.partition('.')[0] == 'src'
]
for _stale_name in _stale_modules:
    del sys.modules[_stale_name]
importlib.invalidate_caches()

# Project-local pipeline entry points used by the cells that follow.
from src.streamlit_adapter import annotate_spatial_regions, calculate_spatial_heterogeneity
from src.report_generation.prompt_builder import generate_medgemma_prompt, evaluate_report_quality
print(f'src/ modules loaded from {REPO_ROOT}')
print(f'scanpy {sc.__version__}')

In [None]:
# ── Download CellTypist models (1-3 MB each, auto-cached) ────────────────────
# Asks CellTypist to fetch its model files before the annotation step runs.
# `force_update=False` presumably reuses models that are already on disk —
# confirm against the CellTypist documentation.
# NOTE(review): no `model=` filter is passed, so this likely fetches every
# published model even though only Immune_All_High is mentioned elsewhere in
# this notebook — confirm, and consider restricting the download.
import celltypist
celltypist.models.download_models(force_update=False)
print('CellTypist models ready')

In [None]:
# ── Load Visium h5ad ──────────────────────────────────────────────────────────
# Candidate locations, in priority order: Kaggle dataset mounts first, then
# outputs/ inside the repo checkout, then paths relative to the working directory.
_KAGGLE_DATA_DIR = Path('/kaggle/input/medgemma-spatial-data')
DATA_PATHS = [
    _KAGGLE_DATA_DIR / 'annotated_visium.h5ad',
    _KAGGLE_DATA_DIR / 'visium_breast_cancer.h5ad',
    REPO_ROOT / 'outputs' / 'annotated_visium.h5ad',
    Path('../outputs/annotated_visium.h5ad'),
    Path('outputs/annotated_visium.h5ad'),
]

# Take the first candidate that exists; abort with a hint when none does.
h5ad_path = None
for _candidate in DATA_PATHS:
    if _candidate.exists():
        h5ad_path = _candidate
        break
else:
    raise FileNotFoundError(
        'No h5ad file found. Add the medgemma-spatial-data dataset to this notebook.'
    )

print(f'Loading: {h5ad_path}')
t0 = time.time()
adata = sc.read_h5ad(h5ad_path)
_load_seconds = time.time() - t0
print(f'Loaded: {adata.n_obs:,} spots × {adata.n_vars:,} genes ({_load_seconds:.1f}s)')
print(f'Obs columns: {list(adata.obs.columns)}')

In [None]:
# ── QC filter ──────────────────────────────────────────────────────────────────
# Filter spots with very low counts (critical for Visium HD 008µm bins)
MIN_TOTAL_COUNTS = 200   # spots below this total count are dropped
MAX_SPOTS = 15000        # cap on spots kept, to stay within Kaggle memory limits
SUBSAMPLE_SEED = 42      # fixed seed so the subsample is reproducible

# Compute QC metrics only when the file does not already carry them.
if 'total_counts' not in adata.obs.columns:
    sc.pp.calculate_qc_metrics(adata, inplace=True)

n_before = adata.n_obs
# .copy() turns the filtered view into an independent AnnData object.
adata = adata[adata.obs['total_counts'] >= MIN_TOTAL_COUNTS].copy()
print(f'QC (>={MIN_TOTAL_COUNTS} counts): {n_before:,} → {adata.n_obs:,} spots')

# Fail fast with a clear message instead of an obscure error further downstream.
if adata.n_obs == 0:
    raise ValueError(
        f'QC removed all {n_before:,} spots (total_counts < {MIN_TOTAL_COUNTS}); '
        'check that the h5ad file contains raw counts or lower the threshold.'
    )

# Subsample for Kaggle memory limits
if adata.n_obs > MAX_SPOTS:
    sc.pp.subsample(adata, n_obs=MAX_SPOTS, random_state=SUBSAMPLE_SEED)
    print(f'Subsampled to {adata.n_obs:,} spots')

In [None]:
# ── Tissue-agnostic spatial annotation ────────────────────────────────────────
# Steps performed by annotate_spatial_regions:
#   1. Preprocessing (normalize, HVG, scale) if not already done
#   2. Spatial Leiden clustering → spatial_region column
#   3. Z-score marker scoring for 13 compartment types
#   4. CellTypist (Immune_All_High model) on immune-enriched spots
#   5. Merges results → cell_type column
#
# tissue='Unknown' = tissue-blind operation (privacy-preserving)

print('Running spatial annotation pipeline...')
t0 = time.time()

_annotation_kwargs = dict(resolution=0.5, use_markers=True, tissue='Unknown')
adata_ann, annotation_metrics = annotate_spatial_regions(adata, **_annotation_kwargs)

_annotation_seconds = time.time() - t0
print(f'\nAnnotation complete ({_annotation_seconds:.1f}s)')
print(f'  Leiden clusters : {annotation_metrics["n_clusters"]}')
print(f'  Cell types found: {annotation_metrics["n_cell_types"]}')
print(f'  Mean confidence : {annotation_metrics.get("mean_confidence", 0):.1%}')

# List the ten most abundant cell types with their share of all annotated spots.
print('\nTop cell types:')
_ranked_cell_types = sorted(
    annotation_metrics['cell_type_counts'].items(),
    key=lambda item: item[1],
    reverse=True,
)
for ct, n in _ranked_cell_types[:10]:
    print(f'  {ct}: {n:,} ({n/adata_ann.n_obs*100:.1f}%)')

In [None]:
# ── Spatial statistics ─────────────────────────────────────────────────────────
print('Computing spatial statistics...')
t0 = time.time()

spatial_metrics = calculate_spatial_heterogeneity(adata_ann)
_stats_seconds = time.time() - t0

print(f'Spatial stats ({_stats_seconds:.1f}s):')

# Each section may be missing from the result; fall back to 0 for display.
_morans_mean = spatial_metrics.get('morans_i', {}).get('mean', 0)
print(f"  Moran's I (mean): {_morans_mean:.3f}")

_entropy_mean = spatial_metrics.get('spatial_entropy', {}).get('mean', 0)
print(f"  Spatial entropy  : {_entropy_mean:.3f}")

_n_enriched = spatial_metrics.get('neighborhood_enrichment', {}).get('n_enriched_pairs', 0)
print(f"  Enriched pairs   : {_n_enriched}")

In [None]:
# ── Visualization ──────────────────────────────────────────────────────────────
# Writes spatial_overview.png with two panels (clusters, cell types). Uses the
# tissue coordinates when available and falls back to a UMAP embedding otherwise.
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are rendered to file
import matplotlib.pyplot as plt

# Column with the cluster labels: 'leiden' when present, otherwise
# 'spatial_region'. If neither exists, 'leiden' is kept so scanpy reports the
# missing key itself.
cluster_key = next(
    (key for key in ('leiden', 'spatial_region') if key in adata_ann.obs.columns),
    'leiden',
)
has_spatial = 'spatial' in adata_ann.obsm

if not has_spatial and 'X_umap' not in adata_ann.obsm:
    # UMAP fallback when no spatial coordinates; sc.tl.umap needs a neighbour graph.
    if 'neighbors' not in adata_ann.uns:
        sc.pp.neighbors(adata_ann)
    sc.tl.umap(adata_ann)

fig, axes = plt.subplots(1, 2, figsize=(14, 6) if has_spatial else (14, 5))
try:
    if has_spatial:
        sc.pl.spatial(adata_ann, color=cluster_key, ax=axes[0], show=False,
                      title='Spatial: Leiden Clusters')
        sc.pl.spatial(adata_ann, color='cell_type', ax=axes[1], show=False,
                      title='Spatial: Cell Type Annotation')
        done_message = 'Saved: spatial_overview.png'
    else:
        sc.pl.umap(adata_ann, color=cluster_key, ax=axes[0], show=False,
                   title='UMAP: Clusters')
        sc.pl.umap(adata_ann, color='cell_type', ax=axes[1], show=False,
                   title='UMAP: Cell Types', legend_fontsize=8)
        done_message = 'No spatial coords — UMAP visualization used'

    fig.tight_layout()
    fig.savefig('spatial_overview.png', dpi=150, bbox_inches='tight')
    plt.show()
finally:
    # Release the figure even if plotting failed, so reruns do not pile up figures.
    plt.close(fig)

print(done_message)

In [None]:
# ── Build features dict for MedGemma ──────────────────────────────────────────
# Collects the annotation and spatial metrics into plain Python types (so they
# are JSON-serialisable), saves them to spatial_features.json, and echoes them.
_morans = spatial_metrics.get('morans_i', {})
_entropy = spatial_metrics.get('spatial_entropy', {})
_enrichment = spatial_metrics.get('neighborhood_enrichment', {})
_entropy_mean = float(_entropy.get('mean', 0))

features = {
    'annotation': {
        'n_spots': int(adata_ann.n_obs),
        'n_clusters': int(annotation_metrics['n_clusters']),
        'n_cell_types': int(annotation_metrics['n_cell_types']),
        # Missing confidence falls back to 0, the same default the annotation
        # summary prints, instead of a made-up 0.8 that was never measured.
        'mean_confidence': float(annotation_metrics.get('mean_confidence', 0)),
        'cell_type_counts': {k: int(v) for k, v in annotation_metrics['cell_type_counts'].items()},
    },
    'spatial_heterogeneity': {
        'morans_i_mean': float(_morans.get('mean', 0)),
        'entropy_mean': _entropy_mean,
        'n_enriched_pairs': int(_enrichment.get('n_enriched_pairs', 0)),
    },
    'uncertainty': {
        # NOTE(review): this is filled from the spatial entropy mean, not from a
        # separate prediction-entropy measurement — confirm this is intended.
        'mean_prediction_entropy': _entropy_mean,
    }
}

with open('spatial_features.json', 'w', encoding='utf-8') as f:
    json.dump(features, f, indent=2)

print('Features JSON saved: spatial_features.json')
print(json.dumps(features, indent=2))

In [None]:
# ── Generate MedGemma clinical report ──────────────────────────────────────────
# Builds the prompt from `features`, then either produces a short placeholder
# report (demo mode) or runs MedGemma 4B. The result is printed and written to
# clinical_report.txt in both cases.
prompt = generate_medgemma_prompt(features)
print(f'Prompt length: {len(prompt)} chars')

if DEMO_MODE:
    print('=== DEMO MODE (no HF_TOKEN) — prompt preview: ===')
    print(prompt[:500] + '...')
    clinical_report = (
        f'[DEMO] This tissue sample shows {annotation_metrics["n_cell_types"]} cell populations '
        f'across {annotation_metrics["n_clusters"]} Leiden clusters. '
        f'Spatial analysis reveals Moran\'s I = {features["spatial_heterogeneity"]["morans_i_mean"]:.3f}, '
        'indicating structured gene expression. '
        'To generate a full MedGemma report, set HF_TOKEN via Kaggle Secrets.'
    )
else:
    print('Loading MedGemma 4B...')
    try:
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        model_id = 'google/medgemma-4b-it'

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            token=HF_TOKEN
        )

        # NOTE(review): the prompt is tokenized as raw text; if
        # generate_medgemma_prompt does not already add chat markup, consider
        # tokenizer.apply_chat_template for this instruction-tuned model.
        # Placement is decided by device_map='auto', so send the inputs to the
        # device the model actually landed on rather than guessing it.
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=400,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        n_input_tokens = inputs['input_ids'].shape[1]
        clinical_report = tokenizer.decode(outputs[0][n_input_tokens:], skip_special_tokens=True).strip()
        print(f'Report: {len(clinical_report.split())} words')

    except Exception as e:
        # Best-effort: keep the notebook running and record the failure in place
        # of the report, clearly tagged so it cannot be mistaken for model output.
        print(f'MedGemma failed: {e}')
        clinical_report = f'[ERROR] {e}'

print('\n=== CLINICAL REPORT ===')
print(clinical_report)

# Explicit UTF-8: generated text may contain characters the platform default
# encoding cannot represent.
with open('clinical_report.txt', 'w', encoding='utf-8') as f:
    f.write(clinical_report)

In [None]:
# ── Report quality check ───────────────────────────────────────────────────────
# Only runs when a real MedGemma report was requested (skipped in demo mode).
if not DEMO_MODE:
    quality = evaluate_report_quality(clinical_report, features)
    print('Report quality:')
    # (padded display label, key in the quality dict), printed in this order.
    _quality_rows = (
        ('Word count      ', 'word_count'),
        ('Has raw numbers ', 'has_raw_numbers'),
        ('Has interpretation', 'has_interpretation'),
        ('Parroting risk  ', 'parroting_risk'),
    )
    for _label, _key in _quality_rows:
        print(f'  {_label}: {quality[_key]}')

In [None]:
# ── Pipeline summary ───────────────────────────────────────────────────────────
# Final recap of the run: key metrics, report mode, and the files written.
_rule = '=' * 60
_heterogeneity = features["spatial_heterogeneity"]
_report_mode = "DEMO" if DEMO_MODE else "MedGemma 4B"

_summary_lines = [
    _rule,
    'MEDGEMMA SPATIAL PATHOLOGY — PIPELINE SUMMARY',
    _rule,
    f'  Spots analyzed   : {adata_ann.n_obs:,}',
    f'  Leiden clusters  : {annotation_metrics["n_clusters"]}',
    f'  Cell types       : {annotation_metrics["n_cell_types"]}',
    '  Annotation model : CellTypist (Immune_All_High) + z-score markers',
    f"  Moran's I        : {_heterogeneity['morans_i_mean']:.3f}",
    f"  Spatial entropy  : {_heterogeneity['entropy_mean']:.3f}",
    f'  Report mode      : {_report_mode}',
    _rule,
    'Outputs saved:',
    '  spatial_features.json — machine-readable spatial metrics',
    '  clinical_report.txt   — pathology report',
    '  spatial_overview.png  — tissue visualization',
    'GitHub: github.com/harshameghadri/medgemma-spatial',
]
print('\n'.join(_summary_lines))