# HistoLab — MedGemma Demo on Kaggle GPU

Runs the fine-tuned MedGemma-4B histopathology classifier on a **free Kaggle T4 GPU**  
and exposes it via a public Gradio link (valid for up to ~72 hours, and only while this notebook session keeps running).

### Before running:
1. Enable **GPU** in *Session options → Accelerator → GPU T4 x1*
2. Enable **Internet** in *Session options → Internet → On*
3. Add two **Kaggle Secrets** (the lock icon on the left sidebar):
   - `HF_TOKEN` — your HuggingFace token (needs read access to `google/medgemma-4b-it`)
   - `ADAPTER_REPO_ID` — e.g. `karadi97/medgemma-histolab-5k`
4. Add the **demo image dataset**: click *Add data* (➕) → search `histolab-medgemma-demo-samples` → Add  
   *(enables the BACH / CRC / PCAM quick-load buttons in the Gradio UI)*
5. **Run All** and copy the `gradio.live` URL from the last cell output.

In [None]:
# Verify GPU is available.
# On a session without a GPU the nvidia-smi binary may be missing entirely,
# which makes subprocess.run raise instead of returning empty output. Treat
# "cannot launch" and "exited with an error" the same as "no GPU" so the hint
# below is printed rather than a traceback or an nvidia-smi error message.
import subprocess

try:
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
        capture_output=True, text=True,
    )
except OSError:
    # FileNotFoundError / PermissionError: the tool could not be started.
    result = None

# Only trust stdout when the query actually succeeded.
gpu_info = result.stdout.strip() if result is not None and result.returncode == 0 else ''
print('GPU:', gpu_info or 'NOT FOUND — enable GPU in Session options!')

In [None]:
# Copy the two Kaggle Secrets (lock icon in the left sidebar) into environment
# variables of the same name, then print a short confirmation.
import os
from kaggle_secrets import UserSecretsClient

secrets = UserSecretsClient()
for secret_name in ('HF_TOKEN', 'ADAPTER_REPO_ID'):
    os.environ[secret_name] = secrets.get_secret(secret_name)

# The token itself is never printed, only whether it is non-empty;
# the adapter repo id is printed in full.
hf_token_present = bool(os.environ.get('HF_TOKEN'))
adapter_repo = os.environ.get('ADAPTER_REPO_ID')
print('HF_TOKEN set:        ', hf_token_present)
print('ADAPTER_REPO_ID set: ', adapter_repo)

In [None]:
# Install dependencies (takes ~2-3 min on first run)
!pip install -q \
    'transformers>=4.49.0' \
    'accelerate>=0.25.0' \
    'bitsandbytes>=0.41.0' \
    'peft>=0.7.0' \
    'gradio>=5.0.0,<7.0.0' \
    'huggingface_hub>=0.19.0' \
    'Pillow>=10.0.0'
print('Dependencies installed.')

# Show versions for debugging
import transformers, peft, torch
print(f'  transformers: {transformers.__version__}')
print(f'  peft:         {peft.__version__}')
print(f'  torch:        {torch.__version__}')
print(f'  CUDA:         {torch.version.cuda}')
if torch.cuda.is_available():
    print(f'  GPU:          {torch.cuda.get_device_name(0)}')
    print(f'  bf16 native:  {torch.cuda.is_bf16_supported()}')

In [None]:
# Clone the repo (or pull latest if already cloned)
import os

REPO_URL = 'https://github.com/karaditya/medgemma-histolab.git'
REPO_DIR = '/kaggle/working/medgemma-histolab'

if not os.path.exists(REPO_DIR):
    !git clone --depth 1 {REPO_URL} {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull --ff-only origin main

os.chdir(REPO_DIR)
print('Working directory:', os.getcwd())


In [None]:
# Install the local histolab package
# `.` is resolved against the current working directory, which the clone cell
# sets to the repository checkout via os.chdir(REPO_DIR).
#   -e         editable install of that checkout
#   --no-deps  do not let pip resolve/install dependencies here; presumably
#              they are the ones installed by the earlier dependency cell
#              (NOTE(review): confirm against the project's packaging metadata)
!pip install -q -e . --no-deps
print('histolab package installed.')

In [None]:
# Link the demo image dataset into the path the app expects.
# Auto-detects the dataset from /kaggle/input/ (handles all mount formats)
# and symlinks each of crc/bach/pcam into <repo>/data/datasets/.
import os
import pathlib

data_dir = pathlib.Path('/kaggle/working/medgemma-histolab/data/datasets')
data_dir.mkdir(parents=True, exist_ok=True)

# Search /kaggle/input/ recursively for a folder containing crc/bach/pcam.
# Kaggle mounts datasets at varying depths:
#   /kaggle/input/<slug>/                          (classic)
#   /kaggle/input/datasets/<user>/<slug>/          (new API)
kaggle_input = pathlib.Path('/kaggle/input')
dataset_root = None
target_dirs = {'crc', 'bach', 'pcam'}

for root, dirs, files in os.walk(str(kaggle_input)):
    # os.walk yields entries in arbitrary order; sorting in place makes the
    # traversal (and therefore which mount wins) deterministic.
    dirs.sort()
    root_path = pathlib.Path(root)
    if target_dirs & set(dirs):
        dataset_root = root_path
        break
    # Don't recurse deeper than 4 levels
    if len(root_path.relative_to(kaggle_input).parts) >= 4:
        dirs.clear()

if dataset_root is None:
    print('WARNING: No demo dataset found under /kaggle/input/')
    print('  Add it via: Add data (+) -> search histolab-medgemma-demo-samples -> Add')
    # Debug: show what IS mounted (directories up to 3 levels deep).
    # A depth-limited walk avoids touching every file in the mounted datasets.
    for root, dirs, _files in os.walk(str(kaggle_input)):
        dirs.sort()
        depth = len(pathlib.Path(root).relative_to(kaggle_input).parts)
        if depth >= 3:
            dirs.clear()  # children would be level 4, which is not listed
        if depth >= 1:
            print(f'    {root}')
else:
    print(f'Found dataset at: {dataset_root}')
    for ds in ['crc', 'bach', 'pcam']:
        src = dataset_root / ds
        dst = data_dir / ds
        # A symlink whose target no longer exists (e.g. the dataset was
        # re-mounted at a different path) must be replaced, not reported as
        # "already linked".
        if dst.is_symlink() and not dst.exists():
            print(f'  replacing broken link: {ds}')
            dst.unlink()
        if dst.exists():
            print(f'  already linked: {ds}')
        elif src.exists():
            dst.symlink_to(src)
            n_files = sum(1 for entry in src.rglob('*') if entry.is_file())
            print(f'  linked {ds}: {n_files} files')
        else:
            print(f'  {ds}: not in dataset (skipping)')


In [None]:
# Pre-load the model so errors show here (not hidden behind Gradio UI)
# This is the slow step: ~3-5 min on first run (downloads base model + merges adapter)
# Uses `os` and REPO_DIR defined by the repo-clone cell.
import sys, gc, shutil, torch
import traceback

sys.path.insert(0, REPO_DIR)

# Force-clear any stale merged model cache from a previous run
# (ensures the merge runs fresh with the latest code/adapter)
adapter_merged = os.path.join(REPO_DIR, 'adapter', 'merged_model')
if os.path.exists(adapter_merged):
    print(f'Clearing stale merged model cache: {adapter_merged}')
    shutil.rmtree(adapter_merged)

import app as histolab_app

print('Loading model...')
try:
    histolab_app.wrapper.load()
except Exception as e:
    # Re-raise after the summary line: a swallowed failure would let
    # "Run All" continue to the launch cell, which marks the model as loaded.
    print(f'ERROR loading model: {e}')
    raise

print('Model loaded successfully!')
if torch.cuda.is_available():
    alloc = torch.cuda.memory_allocated() / 1024**3
    total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'GPU memory: {alloc:.1f} / {total:.1f} GiB')

# Post-load diagnostics are best-effort: a failure here is reported (with its
# own, accurate label) but does not stop the notebook, as before.
try:
    # Verify chat template is present
    proc = histolab_app.wrapper.processor
    has_template = getattr(proc.tokenizer, 'chat_template', None) is not None
    print(f'Chat template loaded: {has_template}')
    if not has_template:
        print('WARNING: No chat template! Model will likely produce empty output.')

    # Quick test inference to verify model works
    print('\n--- Test inference ---')
    from PIL import Image
    # Flat single-colour image: only checks that generation yields text.
    test_img = Image.new('RGB', (224, 224), color=(180, 120, 160))
    test_result = histolab_app.wrapper.analyze_patch(
        image=test_img,
        prompt='What type of tissue is this? Answer briefly.',
        max_new_tokens=32
    )
    raw = test_result.get('raw_response', '')
    print(f'Test output ({len(raw)} chars): {repr(raw[:200])}')
    if not raw.strip():
        print('WARNING: Model produced empty output! Check adapter merge and chat template.')
    else:
        print('Model is generating text correctly.')
except Exception as e:
    print(f'ERROR during post-load checks: {e}')
    traceback.print_exc()

In [None]:
# Launch Gradio (model already loaded above)
# Expects `histolab_app` from the pre-load cell; the flag below is set
# unconditionally, so only run this cell after that one succeeded.
import inspect

histolab_app.app_instance.model_loaded = True

# The dependency cell allows gradio 5.x and 6.x. Pick the "hide the API link"
# keyword from the installed launch() signature instead of hard-coding one,
# so an unsupported keyword never raises TypeError.
# NOTE(review): Gradio 6 is understood to replace `show_api` with
# `footer_links` — confirm against the installed version's documentation.
launch_kwargs = {'share': True, 'quiet': False}
launch_params = inspect.signature(histolab_app.demo.launch).parameters
if 'show_api' in launch_params:
    launch_kwargs['show_api'] = False
elif 'footer_links' in launch_params:
    # Keep the default footer entries except the API one.
    launch_kwargs['footer_links'] = ['gradio', 'settings']

histolab_app.demo.launch(**launch_kwargs)
