In [1]:
import h5py
import os
import timm
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from torch.utils.data import DataLoader, Dataset

In [2]:
# Pin this kernel to a single GPU (physical device 0) on a multi-GPU node.
# The variable must be set before CUDA is first initialized to take effect.
# `.get` is used so the cell also runs where the variable is not exported yet
# (it then prints None instead of raising KeyError).
print(os.environ.get("CUDA_VISIBLE_DEVICES"))
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

0,1


In [None]:
# Build the Virchow2 backbone from the Hugging Face hub with pretrained
# weights, using the packed SwiGLU MLP and SiLU activation.
model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    act_layer=torch.nn.SiLU,
    mlp_layer=SwiGLUPacked,
    pretrained=True,
)

# Inference only: switch to eval mode, then move the weights to the GPU.
model.eval()
model.cuda()

# Preprocessing pipeline derived from the model's own pretrained config.
config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**config)

In [4]:
class TileFolderDataset(Dataset):
    """Dataset over the ``.png`` tiles stored directly inside one folder.

    Each item is ``(transformed_image, tile_path)``. Paths are sorted so the
    iteration order is deterministic.

    Args:
        folder: Directory whose ``.png`` files are the tiles of one slide.
        tile_transform: Callable applied to each RGB ``PIL.Image``. When
            ``None`` (the default) the notebook-level ``transform`` is used,
            which keeps existing ``TileFolderDataset(folder)`` calls working.
    """

    def __init__(self, folder, tile_transform=None):
        self.paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith(".png")
        )
        # Only fall back to the global when no transform was passed, so the
        # class can be used (and tested) without the model cell having run.
        self.transform = transform if tile_transform is None else tile_transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        # The context manager closes the file handle deterministically;
        # convert() returns a fully loaded copy that outlives the handle.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb), path

In [6]:
# ───────────── Embedding + Saving ─────────────
def _parse_tile_coords(path):
    """Return ``(x, y)`` from the last two ``_``-separated filename fields, or ``None`` if they are not integers."""
    stem = os.path.splitext(os.path.basename(path))[0]
    fields = stem.split("_")[-2:]
    if len(fields) != 2:
        return None
    try:
        return int(fields[0]), int(fields[1])
    except ValueError:
        return None


def extract_and_save(tile_folder, h5_output_path, batch_size=96, num_workers=20):
    """Embed every tile in ``tile_folder`` with the global ``model`` and save to HDF5.

    For each tile the class token (index 0) is concatenated with the mean of
    the tokens from index 5 onward (the tokens in between are skipped as
    register tokens). The result is written to ``h5_output_path`` as two
    datasets: ``features`` (one row per tile) and ``coords`` (``x, y`` parsed
    from the tile filename, e.g. ``TCGA-XX_L1_1232_2048.png`` -> ``(1232, 2048)``).

    Args:
        tile_folder: Folder containing the ``.png`` tiles of one slide.
        h5_output_path: Destination ``.h5`` file. It only appears once fully
            written, so an interrupted run never leaves a truncated file.
        batch_size: Number of tiles per forward pass.
        num_workers: DataLoader worker processes (``0`` loads in-process).

    Raises:
        RuntimeError: If ``tile_folder`` contains no ``.png`` tiles.
    """
    dataset = TileFolderDataset(tile_folder)
    if len(dataset) == 0:
        # Fail early with a readable message instead of torch.cat's error on an empty list.
        raise RuntimeError(f"No .png tiles found in {tile_folder}")

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    if num_workers > 0:
        # These two options are only valid when worker processes are used.
        loader_kwargs.update(prefetch_factor=2, persistent_workers=True)
    dataloader = DataLoader(dataset, **loader_kwargs)

    # Follow whatever device the model lives on rather than assuming CUDA.
    device = next(model.parameters()).device

    all_embeddings = []
    all_coords = []
    unparsed = 0

    for batch_imgs, batch_paths in tqdm(dataloader, desc=os.path.basename(tile_folder)):
        batch_imgs = batch_imgs.to(device, non_blocking=True)
        with torch.no_grad():
            out = model(batch_imgs)  # shape: (B, 261, 1280)

        cls = out[:, 0]
        patch_tokens = out[:, 5:]  # skip register tokens
        mean = patch_tokens.mean(dim=1)
        embedding = torch.cat([cls, mean], dim=-1)  # (B, 2560)
        all_embeddings.append(embedding.cpu())

        # Extract x, y from filename (e.g., TCGA-XX_L1_1232_2048.png)
        for path in batch_paths:
            coords = _parse_tile_coords(path)
            if coords is None:
                # Best-effort fallback kept from the original behavior, but
                # counted so the problem is reported instead of hidden.
                unparsed += 1
                coords = (0, 0)
            all_coords.append(coords)

    all_embeddings = torch.cat(all_embeddings, dim=0)     # (N, 2560)
    all_coords = torch.tensor(all_coords)                 # (N, 2)

    # Write to a temporary name and rename at the end: callers that skip
    # slides whose .h5 already exists must never see a half-written file.
    tmp_output_path = f"{h5_output_path}.tmp"
    try:
        with h5py.File(tmp_output_path, "w") as f:
            f.create_dataset("features", data=all_embeddings.numpy())
            f.create_dataset("coords", data=all_coords.numpy())
        os.replace(tmp_output_path, h5_output_path)
    except BaseException:
        if os.path.exists(tmp_output_path):
            os.remove(tmp_output_path)
        raise

    if unparsed:
        print(f"⚠️ {unparsed} tile name(s) in {tile_folder} had no parsable x_y suffix; coords set to (0, 0)")
    print(f"✅ Saved {all_embeddings.shape[0]} embeddings to {h5_output_path}")

In [7]:
# Show the kernel's current working directory; the relative tile and output
# paths used in this notebook are resolved against it.
%pwd

'/orcd/data/edboyden/002/ezh/uni'

In [None]:
tile_root_dir = "virchow_tiles_gpu0"               # root directory containing subfolders for each WSI
output_dir = "virchow_features"                   # where to save .h5 files
os.makedirs(output_dir, exist_ok=True)

# One .h5 per slide folder. Slides whose output already exists are skipped,
# so the cell can be re-run to resume an interrupted job.
for slide_folder in sorted(os.listdir(tile_root_dir)):
    slide_path = os.path.join(tile_root_dir, slide_folder)
    if not os.path.isdir(slide_path):
        continue  # skip files

    h5_output_path = os.path.join(output_dir, f"{slide_folder}.h5")

    if os.path.exists(h5_output_path):
        print(f"✅ Skipping {slide_folder}, already exists.")
        continue

    try:
        extract_and_save(slide_path, h5_output_path, batch_size=196)
    except BaseException as e:
        # The file did not exist before this attempt (checked above), so
        # anything at this path now is a partial write. Remove it, otherwise
        # the skip check would treat the slide as done on the next run.
        if os.path.exists(h5_output_path):
            os.remove(h5_output_path)
        if not isinstance(e, Exception):
            raise  # KeyboardInterrupt / SystemExit: clean up, then stop
        print(f"❌ Failed to process {slide_folder}: {e}")

✅ Skipping TCGA-CK-4952, already exists.
✅ Skipping TCGA-CK-5912, already exists.
✅ Skipping TCGA-CK-5913, already exists.
✅ Skipping TCGA-CK-5914, already exists.
✅ Skipping TCGA-CK-5915, already exists.
✅ Skipping TCGA-CK-5916, already exists.
✅ Skipping TCGA-CK-6746, already exists.
✅ Skipping TCGA-CK-6747, already exists.
✅ Skipping TCGA-CK-6748, already exists.
✅ Skipping TCGA-CK-6751, already exists.
✅ Skipping TCGA-CL-4957, already exists.
✅ Skipping TCGA-CL-5917, already exists.


TCGA-CL-5918: 100%|████████████████████████████████| 44/44 [02:55<00:00,  4.00s/it]


✅ Saved 8487 embeddings to virchow_features/TCGA-CL-5918.h5


TCGA-CM-4743: 100%|████████████████████████████████| 58/58 [03:53<00:00,  4.03s/it]


✅ Saved 11286 embeddings to virchow_features/TCGA-CM-4743.h5


TCGA-CM-4744: 100%|████████████████████████████████| 68/68 [04:32<00:00,  4.01s/it]


✅ Saved 13230 embeddings to virchow_features/TCGA-CM-4744.h5


TCGA-CM-4746: 100%|████████████████████████████████| 59/59 [03:56<00:00,  4.02s/it]


✅ Saved 11484 embeddings to virchow_features/TCGA-CM-4746.h5


TCGA-CM-4747: 100%|████████████████████████████████| 67/67 [04:27<00:00,  4.00s/it]


✅ Saved 12998 embeddings to virchow_features/TCGA-CM-4747.h5


TCGA-CM-4748: 100%|████████████████████████████████| 29/29 [01:57<00:00,  4.06s/it]


✅ Saved 5658 embeddings to virchow_features/TCGA-CM-4748.h5


TCGA-CM-4750: 100%|████████████████████████████████| 64/64 [04:14<00:00,  3.98s/it]


✅ Saved 12350 embeddings to virchow_features/TCGA-CM-4750.h5


TCGA-CM-4751: 100%|████████████████████████████████| 43/43 [02:53<00:00,  4.04s/it]


✅ Saved 8400 embeddings to virchow_features/TCGA-CM-4751.h5


TCGA-CM-4752: 100%|████████████████████████████████| 31/31 [02:04<00:00,  4.01s/it]


✅ Saved 5995 embeddings to virchow_features/TCGA-CM-4752.h5


TCGA-CM-5341: 100%|████████████████████████████████| 44/44 [02:56<00:00,  4.00s/it]


✅ Saved 8554 embeddings to virchow_features/TCGA-CM-5341.h5


TCGA-CM-5344: 100%|████████████████████████████████| 56/56 [03:42<00:00,  3.97s/it]


✅ Saved 10810 embeddings to virchow_features/TCGA-CM-5344.h5


TCGA-CM-5348: 100%|████████████████████████████████| 60/60 [04:01<00:00,  4.02s/it]


✅ Saved 11718 embeddings to virchow_features/TCGA-CM-5348.h5


TCGA-CM-5349: 100%|████████████████████████████████| 56/56 [03:44<00:00,  4.02s/it]


✅ Saved 10925 embeddings to virchow_features/TCGA-CM-5349.h5


TCGA-CM-5860: 100%|████████████████████████████████| 96/96 [06:23<00:00,  4.00s/it]


✅ Saved 18715 embeddings to virchow_features/TCGA-CM-5860.h5


TCGA-CM-5861: 100%|████████████████████████████████| 91/91 [06:02<00:00,  3.99s/it]


✅ Saved 17670 embeddings to virchow_features/TCGA-CM-5861.h5


TCGA-CM-5862: 100%|████████████████████████████████| 53/53 [03:29<00:00,  3.95s/it]


✅ Saved 10208 embeddings to virchow_features/TCGA-CM-5862.h5


TCGA-CM-5863: 100%|████████████████████████████████| 62/62 [04:08<00:00,  4.01s/it]


✅ Saved 12065 embeddings to virchow_features/TCGA-CM-5863.h5


TCGA-CM-5864: 100%|████████████████████████████████| 82/82 [05:29<00:00,  4.01s/it]


✅ Saved 16008 embeddings to virchow_features/TCGA-CM-5864.h5


TCGA-CM-5868: 100%|████████████████████████████████| 92/92 [06:07<00:00,  4.00s/it]


✅ Saved 17955 embeddings to virchow_features/TCGA-CM-5868.h5


TCGA-CM-6161: 100%|████████████████████████████████| 59/59 [03:56<00:00,  4.01s/it]


✅ Saved 11500 embeddings to virchow_features/TCGA-CM-6161.h5


TCGA-CM-6162: 100%|████████████████████████████████| 91/91 [06:04<00:00,  4.01s/it]


✅ Saved 17765 embeddings to virchow_features/TCGA-CM-6162.h5


TCGA-CM-6163: 100%|████████████████████████████████| 94/94 [06:12<00:00,  3.96s/it]


✅ Saved 18240 embeddings to virchow_features/TCGA-CM-6163.h5


TCGA-CM-6164:   0%|                                         | 0/53 [00:00<?, ?it/s]