# Prov-GigaPath Inference (Tile encoder → Slide encoder)

This notebook follows the recommended slide-level inference pipeline:

1. **Tile the whole slide** into `N` image tiles (and keep each tile's `(x, y)` coordinate).
2. **Encode each tile** into an embedding with the **tile encoder**.
3. **Feed tile embeddings + coordinates** into the **slide encoder** to obtain a **slide-level representation**.

> You can also use the tile encoder alone for tile-level tasks (classification, retrieval, etc.).
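
Step 1 above can be sketched as a simple grid walk over the slide. This is a minimal illustration, not the tiler used to train the model: the helper name `iter_tile_origins` and the 256-pixel tile size are assumptions, and reading pixels from the slide file and filtering out background tiles are left out.

```python
from typing import Iterator, Tuple


def iter_tile_origins(width: int, height: int, tile_size: int = 256) -> Iterator[Tuple[int, int]]:
    """Yield the top-left (x, y) pixel coordinate of each full tile, row by row.

    Partial tiles at the right and bottom edges are skipped.
    tile_size=256 is an illustrative assumption; use your tiler's value.
    """
    for y in range(0, height - tile_size + 1, tile_size):
        for x in range(0, width - tile_size + 1, tile_size):
            yield x, y


# A 1000x600 region holds 3 columns x 2 rows of full 256-pixel tiles.
origins = list(iter_tile_origins(width=1000, height=600, tile_size=256))
print(len(origins), origins[:4])
```

Keeping the `(x, y)` origin next to each tile image from the start makes step 3 straightforward, because the slide encoder needs both.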


In [1]:
# (Optional) Install dependencies
# If you're in a managed environment, you may already have these installed.
# !pip install -U timm torch torchvision pillow gigapath

import os
import torch
import timm
from PIL import Image
from torchvision import transforms

print("torch:", torch.__version__)
print("timm:", timm.__version__)




torch: 2.2.1+cu121
timm: 1.0.3


In [2]:
# Helpers

def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
device


device(type='cuda')

## 1) Load the tile encoder + preprocessing transform

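The tile encoder expects a **224×224** RGB tensor. The transform below resizes the tile to 256 pixels on the short side, center-crops to 224, and applies standard ImageNet normalization.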


In [3]:
# Load tile encoder (Hugging Face Hub via timm)
tile_encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
tile_encoder = tile_encoder.to(device).eval()

# Preprocessing transform (as in the model card)
transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

tile_encoder


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): GluMlp(
        (fc1): Linear(in_features=1536, out_features=8192, bias=True)
        (act): SiLU()
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
    

## 2) Tile-level inference (single image)

Update `img_path` to any tile image (PNG/JPG) you want to encode.


In [4]:
img_path = "images/prov_normal_000_1.png"  # TODO: change to your tile image path

assert os.path.exists(img_path), f"Not found: {img_path}"

sample_input = transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    tile_emb = tile_encoder(sample_input).squeeze(0)  # shape: (D,)

tile_emb.shape, tile_emb.dtype, tile_emb.device


(torch.Size([1536]), torch.float32, device(type='cuda', index=0))

## 3) Tile-level inference (batch over a folder)

If you already tiled a whole slide into a directory of tile images, you can encode them in batches.

**You must also have coordinates** for each tile. The common pattern is:
- filename encodes coordinates (e.g., `x1234_y5678.png`), **or**
- you keep a separate CSV/JSON mapping `filename → (x, y)`.

This cell gives a template you can adapt.


In [5]:
from pathlib import Path
from typing import List, Tuple
import re

def parse_xy_from_name(p: Path) -> Tuple[int, int]:
    """Example parser: expects filenames like .../x1234_y5678.png
    Change this to match *your* tiler naming convention.
    """
    m = re.search(r"x(-?\d+)_y(-?\d+)", p.stem)
    if m is None:
        raise ValueError(f"Cannot parse x/y from filename: {p.name}")
    return int(m.group(1)), int(m.group(2))

def load_tiles_and_coords(tile_dir: str, exts={'.png', '.jpg', '.jpeg'}) -> Tuple[List[Path], torch.Tensor]:
    tile_dir = Path(tile_dir)
    paths = sorted([p for p in tile_dir.rglob('*') if p.suffix.lower() in exts])
    if len(paths) == 0:
        raise FileNotFoundError(f"No images found under: {tile_dir}")
    coords = [parse_xy_from_name(p) for p in paths]  # list of (x, y)
    coords = torch.tensor(coords, dtype=torch.long)  # (N, 2)
    return paths, coords

@torch.no_grad()
def encode_tiles(paths: List[Path], batch_size: int = 64) -> torch.Tensor:
    """Returns tile embeddings: (N, D)"""
    embs = []
    for i in range(0, len(paths), batch_size):
        batch_paths = paths[i:i+batch_size]
        imgs = []
        for p in batch_paths:
            # Use a context manager so file handles are closed promptly on large tile sets
            with Image.open(p) as im:
                imgs.append(transform(im.convert('RGB')))
        x = torch.stack(imgs, dim=0).to(device)  # (B, 3, 224, 224)
        y = tile_encoder(x)  # (B, D) or model-dependent
        embs.append(y.detach().cpu())
    return torch.cat(embs, dim=0)

# ---- Example usage (uncomment and edit) ----
# tile_dir = "tiles/slide_001"  # TODO: set your tiled slide folder
# paths, coordinates = load_tiles_and_coords(tile_dir)
# tile_embed = encode_tiles(paths, batch_size=64)  # (N, D) on CPU
# coordinates.shape, tile_embed.shape
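
If your coordinates live in a separate CSV rather than in the filenames (the second pattern listed above), a loader along these lines could replace `parse_xy_from_name`. This is a sketch under assumptions: the column names `filename`, `x`, `y` and the helper name `read_coords_csv` are hypothetical, so adapt them to whatever your tiler writes.

```python
import csv
import io
from typing import Dict, Tuple


def read_coords_csv(fh) -> Dict[str, Tuple[int, int]]:
    """Read rows with columns filename,x,y into {filename: (x, y)}.

    The column names are an assumption about the CSV layout.
    """
    mapping: Dict[str, Tuple[int, int]] = {}
    for row in csv.DictReader(fh):
        name = row["filename"]
        if name in mapping:
            raise ValueError(f"Duplicate filename in coordinate table: {name}")
        mapping[name] = (int(row["x"]), int(row["y"]))
    return mapping


# In-memory stand-in for a real file opened with open(path, newline="")
demo = io.StringIO("filename,x,y\ntile_a.png,0,0\ntile_b.png,256,0\n")
coord_map = read_coords_csv(demo)
print(coord_map["tile_b.png"])
```

With the mapping in hand, the coordinate tensor can be built in the same order as `paths`, for example `torch.tensor([coord_map[p.name] for p in paths], dtype=torch.long)`.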


## 4) Load the slide encoder

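The slide encoder consumes:
- `tile_embed`: float tile embeddings of shape `(B, N, D)`, where `B` is the number of slides in the batch (add a leading dimension of 1 for a single slide) and `D = 1536` matches the tile encoder output
- `coordinates`: tile coordinates of shape `(B, N, 2)` (typically pixel coords in the WSI plane)

It outputs one **slide-level embedding** per slide. Depending on the `gigapath` version, the call may return a single tensor or a list of per-layer tensors, so check the return type before indexing.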
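
Before handing coordinates to the slide encoder, a quick sanity check can catch mixed or inconsistent coordinate systems. The sketch below is illustrative only: `check_coords` is a hypothetical helper, and the grid-alignment test assumes non-overlapping tiles on a 256-pixel grid anchored at the origin, which may not match your tiler.

```python
from typing import List, Sequence, Tuple


def check_coords(coords: Sequence[Tuple[int, int]], tile_size: int = 256) -> List[str]:
    """Return human-readable problems; an empty list means none were found.

    tile_size=256 and the grid-alignment rule are assumptions about the tiler.
    """
    problems = []
    if len(set(coords)) != len(coords):
        problems.append("duplicate coordinates")
    if any(x < 0 or y < 0 for x, y in coords):
        problems.append("negative coordinates")
    if any(x % tile_size or y % tile_size for x, y in coords):
        problems.append(f"coordinates not aligned to a {tile_size}-pixel grid")
    return problems


good = [(0, 0), (256, 0), (0, 256)]
bad = [(0, 0), (0, 0), (100, 256)]
print(check_coords(good))
print(check_coords(bad))
```

For a `(N, 2)` tensor, convert first with `[tuple(c) for c in coordinates.tolist()]`, since the check relies on hashable tuples.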


In [None]:
from gigapath.slide_encoder import create_model

# Create slide encoder
# Positional arguments: (<hub_id>, <arch_name>, <in_dim>)
# in_dim=1536 matches the tile encoder's embedding size
slide_encoder = create_model(
    "hf_hub:prov-gigapath/prov-gigapath",
    "gigapath_slide_enc12l768d",
    1536,
)

slide_encoder = slide_encoder.to(device).eval()
slide_encoder


dilated_ratio:  [1, 2, 4, 8, 16]
segment_length:  [1024, 5792, 32768, 185363, 1048576]
Number of trainable LongNet parameters:  85148160
Global Pooling: False
Successfully Loaded Pretrained GigaPath model from hf_hub:prov-gigapath/prov-gigapath


LongNetViT(
  (patch_embed): PatchEmbed(
    (proj): Linear(in_features=1536, out_features=768, bias=True)
    (norm): Identity()
  )
  (encoder): LongNetEncoder(
    (dropout_module): Dropout(p=0.25, inplace=False)
    (layers): ModuleList(
      (0): LongNetEncoderLayer(
        (self_attn): DilatedAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (inner_attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout_module): Dropout(p=0.0, inplace=False)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout_module): Dropout(p=0.25, inplace=False)
        (drop_path): DropPath(p=0.0)
        (ffn): FeedForwardNetwork(
          (activation_dropout

## 5) Slide-level inference (example)

Below is a minimal, end-to-end example using a **single tile** embedding + a dummy coordinate.
For real slides, use the folder/batch encoder above to produce `tile_embed` and `coordinates`.


In [8]:
# If you didn't run the folder/batch section, we can build a tiny example from the single-tile output.
# In real usage, you should have:
#   tile_embed: (N, D) float tensor
#   coordinates: (N, 2) long/int tensor

if 'tile_emb' in globals():
    tile_embed = tile_emb.unsqueeze(0).detach().cpu()  # (1, D) on CPU
    coordinates = torch.tensor([[0, 0]], dtype=torch.long)  # (1, 2)
else:
    raise RuntimeError("Run the single-tile inference cell first to create tile_emb.")

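# Move inputs to the slide encoder's device and add a leading batch dimension:
# the encoder expects (B, N, D) embeddings and (B, N, 2) coordinates.
# Passing 2-D inputs fails with:
#   ValueError: not enough values to unpack (expected 3, got 2)
tile_embed_dev = tile_embed.unsqueeze(0).to(device)    # (1, N, D)
coordinates_dev = coordinates.unsqueeze(0).to(device)  # (1, N, 2)

with torch.no_grad():
    out = slide_encoder(tile_embed_dev, coordinates_dev)

# Assumption: some gigapath versions return a list of per-layer embeddings;
# take the last entry in that case, otherwise use the tensor directly.
slide_emb = out[-1] if isinstance(out, (list, tuple)) else out
slide_emb = slide_emb.squeeze(0)

slide_emb.shape, slide_emb.dtype, slide_emb.device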

## Notes & tips

- **Coordinates matter**: the slide encoder uses them to model spatial relationships. Make sure your coordinate system is consistent across tiles.
- **Batching**: tile encoding is the expensive part; encode tiles in batches on GPU and keep embeddings on CPU if memory is tight.
- **Precision**: wrapping tile encoding in `torch.autocast(device_type='cuda', dtype=torch.float16)` can speed it up on GPU; compare a few embeddings against the float32 output before relying on it.
- **Saving**: store `tile_embed` and `coordinates` (e.g., `.pt` or `.npz`) so you can re-run slide encoder quickly.
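
The saving tip can be sketched with NumPy's `.npz` format. This is one possible layout, not a format the model requires: the array keys and the file name are assumptions, and random arrays stand in for real embeddings. With real CPU tensors, pass `tile_embed.numpy()` and `coordinates.numpy()` instead.

```python
import os
import tempfile

import numpy as np

# Stand-in arrays; with real tensors use tile_embed.numpy() and coordinates.numpy()
tile_embed_np = np.random.rand(4, 1536).astype(np.float32)
coordinates_np = np.array([[0, 0], [256, 0], [0, 256], [256, 256]], dtype=np.int64)

# Hypothetical file name; one archive per slide keeps embeddings and coords together
out_path = os.path.join(tempfile.mkdtemp(), "slide_001_tiles.npz")
np.savez_compressed(out_path, tile_embed=tile_embed_np, coordinates=coordinates_np)

# Reload later without re-running the tile encoder
with np.load(out_path) as cached:
    restored_embed = cached["tile_embed"]
    restored_coords = cached["coordinates"]

print(restored_embed.shape, restored_coords.shape)
```

`torch.from_numpy(restored_embed)` turns the arrays back into tensors for the slide encoder.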
