# CANVAS — Quickstart Inference (API)
This notebook demonstrates pure inference on patch-level images via the importable `Habitat_prediction.api` module.
You’ll build the model (optionally with your local fine-tuned weights), run folder-level predictions, and optionally apply stain color normalization.

## 1. Environment Check

Verify Python, PyTorch (CUDA availability), and timm versions to ensure the runtime matches expected dependencies.

In [3]:
# Quick environment check (safe to run multiple times)
import sys, torch
import timm
print("Python:", sys.version)
print("PyTorch:", torch.__version__, "CUDA:", torch.cuda.is_available())
print("timm:", timm.__version__)

Python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
PyTorch: 2.0.1+cu117 CUDA: True
timm: 0.4.12
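
Comparing version strings as plain text is error-prone: lexicographically, `"0.4.12" < "0.4.9"` is `True`. Below is a minimal sketch of a numeric check. The notebook does not state the project's required versions, so any minimums you compare against (for example, the versions logged above) are assumptions.

```python
import re


def parse_version(version: str) -> tuple[int, ...]:
    """Extract the leading numeric release, e.g. '2.0.1+cu117' -> (2, 0, 1).

    Pre-release and local suffixes (such as 'a0' or '+cu117') are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", version.strip())
    if match is None:
        raise ValueError(f"Cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if `installed` >= `minimum`, comparing release numbers numerically."""
    a, b = parse_version(installed), parse_version(minimum)
    # Pad with zeros so that e.g. (2, 0) and (2, 0, 0) compare as equal
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b
```

In the notebook you would pass `torch.__version__` or `timm.__version__` as `installed`, together with whatever minimum your setup requires.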


## 2. Configure Paths

Set the local fine-tuned checkpoint path (reference weights), the MUSK source (HF Hub or local .safetensors), and the demo image folder (PNG/JPG/TIF).

In [7]:
from pathlib import Path

# >>> EDIT THESE <<<
WEIGHTS_PATH = Path("Habitat_prediction/reference_weight.pth")   # local path to your user weights
MUSK_SOURCE  = "hf_hub:xiangjx/musk"         # or a local path to MUSK model.safetensors
DEMO_DIR     = Path("Demo_data")             # folder with PNG/JPG/TIF patches

print("Weights exists:", WEIGHTS_PATH.exists(), WEIGHTS_PATH)
print("Demo dir exists:", DEMO_DIR.exists(), DEMO_DIR)

Weights exists: True Habitat_prediction/reference_weight.pth
Demo dir exists: True Demo_data
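
Before running inference, it can help to confirm that the demo folder actually contains images. The following is a minimal sketch. The extension set is an assumption that mirrors the PNG/JPG/TIF list above, and the recursive search is optional; `predict_folder`'s own file filter may differ.

```python
from pathlib import Path

# Assumed extension set mirroring "PNG/JPG/TIF" above; predict_folder's own filter may differ
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def list_images(folder: Path, recursive: bool = True) -> list[Path]:
    """Return image files under `folder` (case-insensitive extension match), sorted by path."""
    pattern = "**/*" if recursive else "*"
    return sorted(
        p for p in Path(folder).glob(pattern)
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
```

For example, `print(len(list_images(DEMO_DIR)), "images found")` gives a quick count before the heavier inference step.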


## 3. Import API & Load Model

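Construct the MUSK backbone and, optionally, load your local fine-tuned weights by uncommenting the `weights=` argument. `load_model` returns the model and the device it was placed on.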

In [16]:
from Habitat_prediction.api import load_model, predict_folder

model, device = load_model(
    # weights=str(WEIGHTS_PATH),  # optional: uncomment to load local fine-tuned weights (local file paths only)
    musk_source=MUSK_SOURCE,      # hf_hub or local path to MUSK backbone
)
print("Loaded model on:", device)

Load ckpt from /home/zli1893/.cache/huggingface/hub/models--xiangjx--musk/snapshots/de1ffed28608c197d2903f6fa42b491a3fbf0fb8/model.safetensors
Loaded model on: cuda
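
The notebook does not show how `load_model` reads `weights=` internally. If you load a checkpoint into a model yourself, one common snag is that weights saved from a model wrapped in `torch.nn.DataParallel` or `DistributedDataParallel` carry a `module.` prefix on every key. Here is a minimal sketch that removes such a prefix. It assumes the checkpoint is a flat state dict; whether `reference_weight.pth` is stored this way is not stated here.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Copy `state_dict`, removing `prefix` from the start of any key that has it (order preserved)."""
    return OrderedDict((key.removeprefix(prefix), value) for key, value in state_dict.items())
```

With PyTorch, this could be used as `model.load_state_dict(strip_prefix(torch.load(path, map_location="cpu")), strict=False)`. `load_state_dict` returns the missing and unexpected keys, which are worth printing when `strict=False`.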


## 4. Inference

Run inference on every image in the demo folder. `predict_folder` returns a table with per-class probabilities for each image, which the cell below saves to CSV.

In [19]:
import pandas as pd  # predict_folder's result is a pandas DataFrame
# (Path was already imported in the configuration cell)

df = predict_folder(
    model, device, DEMO_DIR,
    img_size=384, batch_size=64,
)

# Save predictions as CSV next to the demo images
OUT_DIR = Path("Demo_data")
OUT_DIR.mkdir(exist_ok=True)
out_csv = OUT_DIR / "output.csv"
df.to_csv(out_csv, index=False)
print("Saved")

Saved
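
The notebook does not list the probability column names in `df`. The following minimal sketch adds the most likely class and its probability for each image. It assumes each class has its own column sharing a common prefix; the `prob_` prefix is hypothetical, so check `df.columns` for the real names.

```python
import pandas as pd


def add_top_class(df: pd.DataFrame, prob_prefix: str = "prob_") -> pd.DataFrame:
    """Return a copy of `df` with the argmax class name and its probability per row."""
    prob_cols = [c for c in df.columns if c.startswith(prob_prefix)]  # hypothetical naming scheme
    if not prob_cols:
        raise ValueError(f"No columns start with {prob_prefix!r}; check df.columns")
    out = df.copy()
    probs = out[prob_cols]
    # idxmax(axis=1) gives the column label of the largest probability; strip the prefix
    out["pred_class"] = probs.idxmax(axis=1).str[len(prob_prefix):]
    out["pred_prob"] = probs.max(axis=1)
    return out
```

Working on a copy leaves the raw probabilities in `df` untouched, so the saved CSV stays exactly as `predict_folder` produced it.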


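## 5. Color Normalization (optional)

Optionally normalize the stain colors of each patch to a reference image before inference. This cell uses the Vahadane method via tiatoolbox. It needs the extra packages listed in the cell and a reference image (`ref.png` in the working directory by default).
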
In [None]:
# Requires: pip install tiatoolbox umap-learn scikit-image opencv-python-headless
try:
    import tiatoolbox  # noqa: F401  (only checks that tiatoolbox is installed)
except ImportError as e:
    print("[Info] tiatoolbox not available:", e)
else:
    # Reference image whose stain appearance the patches are normalized to
    ref_img = Path("ref.png")
    if not ref_img.exists():
        raise FileNotFoundError(f"Reference image not found: {ref_img.resolve()}")
    df_tia = predict_folder(
        model, device, DEMO_DIR,
        img_size=384, batch_size=64,
        color_norm=True,
        reference_image=str(ref_img),
        color_norm_method="Vahadane",
        color_norm_backend="tiatoolbox",
    )
    # Path was already imported in the configuration cell
    OUT_DIR = Path("outputs")
    OUT_DIR.mkdir(exist_ok=True)
    out_csv_tia = OUT_DIR / "output.csv"
    df_tia.to_csv(out_csv_tia, index=False)
    print("Saved")
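
To judge whether color normalization changes the results, you could compare the two prediction tables image by image. The following is a minimal sketch. It assumes both DataFrames share an image-identifier column and a predicted-class column; the names `path` and `pred_class` are hypothetical, so substitute whatever `predict_folder` actually returns.

```python
import pandas as pd


def prediction_agreement(df_a: pd.DataFrame, df_b: pd.DataFrame,
                         key: str = "path", label: str = "pred_class") -> tuple[float, pd.DataFrame]:
    """Return (fraction of images with the same label, rows where the labels differ).

    Uses an inner join, so images present in only one table are ignored.
    """
    merged = df_a[[key, label]].merge(
        df_b[[key, label]], on=key, suffixes=("_a", "_b"), validate="one_to_one"
    )
    same = merged[f"{label}_a"] == merged[f"{label}_b"]
    return float(same.mean()), merged.loc[~same]
```

For example, `rate, diffs = prediction_agreement(df, df_tia)` reports the agreement rate and lists the patches whose predicted class changed after normalization.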