In [1]:
import h5py
import os
import timm
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from torch.utils.data import DataLoader, Dataset

In [2]:
# for multi-gpu
print(os.environ["CUDA_VISIBLE_DEVICES"])
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

0,1


In [3]:
model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU
)
model.eval().cuda()

config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**config)

In [4]:
class TileFolderDataset(Dataset):
    def __init__(self, folder):
        self.paths = sorted([
            os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".png")
        ])
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(img), self.paths[idx]

In [5]:
# ───────────── Embedding + Saving ─────────────
def extract_and_save(tile_folder, h5_output_path, batch_size=96):
    dataset = TileFolderDataset(tile_folder)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=20, pin_memory=True, prefetch_factor=2, persistent_workers=True)

    all_embeddings = []
    all_coords = []

    for batch_imgs, batch_paths in tqdm(dataloader, desc=os.path.basename(tile_folder)):
        batch_imgs = batch_imgs.cuda()
        with torch.no_grad():
            out = model(batch_imgs)  # shape: (B, 261, 1280)

        cls = out[:, 0]
        patch_tokens = out[:, 5:]  # skip register tokens
        mean = patch_tokens.mean(dim=1)
        embedding = torch.cat([cls, mean], dim=-1)  # (B, 2560)
        all_embeddings.append(embedding.cpu())

        # Extract x, y from filename (e.g., TCGA-XX_L1_1232_2048.png)
        for path in batch_paths:
            base = os.path.splitext(os.path.basename(path))[0]
            try:
                x, y = map(int, base.split("_")[-2:])
            except:
                x, y = 0, 0
            all_coords.append((x, y))

    all_embeddings = torch.cat(all_embeddings, dim=0)     # (N, 2560)
    all_coords = torch.tensor(all_coords)                 # (N, 2)

    with h5py.File(h5_output_path, "w") as f:
        f.create_dataset("features", data=all_embeddings.numpy())
        f.create_dataset("coords", data=all_coords.numpy())

    print(f"✅ Saved {all_embeddings.shape[0]} embeddings to {h5_output_path}")

In [6]:
%pwd

'/orcd/data/edboyden/002/ezh/uni'

In [None]:
tile_root_dir = "virchow_tiles_gpu0"               # root directory containing subfolders for each WSI
output_dir = "virchow_features"                   # where to save .h5 files
os.makedirs(output_dir, exist_ok=True)

for slide_folder in sorted(os.listdir(tile_root_dir)):
    slide_path = os.path.join(tile_root_dir, slide_folder)
    if not os.path.isdir(slide_path):
        continue  # skip files

    h5_output_path = os.path.join(output_dir, f"{slide_folder}.h5")

    if os.path.exists(h5_output_path):
        print(f"✅ Skipping {slide_folder}, already exists.")
        continue

    try:
        extract_and_save(slide_path, h5_output_path, batch_size=196)
    except Exception as e:
        print(f"❌ Failed to process {slide_folder}: {e}")

.ipynb_checkpoints: 0it [00:00, ?it/s]


❌ Failed to process .ipynb_checkpoints: torch.cat(): expected a non-empty list of Tensors
✅ Skipping TCGA-CK-4952, already exists.
✅ Skipping TCGA-CK-5912, already exists.
✅ Skipping TCGA-CK-5913, already exists.
✅ Skipping TCGA-CK-5914, already exists.
✅ Skipping TCGA-CK-5915, already exists.
✅ Skipping TCGA-CK-5916, already exists.
✅ Skipping TCGA-CK-6746, already exists.
✅ Skipping TCGA-CK-6747, already exists.
✅ Skipping TCGA-CK-6748, already exists.
✅ Skipping TCGA-CK-6751, already exists.
✅ Skipping TCGA-CL-4957, already exists.
✅ Skipping TCGA-CL-5917, already exists.
✅ Skipping TCGA-CL-5918, already exists.
✅ Skipping TCGA-CM-4743, already exists.
✅ Skipping TCGA-CM-4744, already exists.
✅ Skipping TCGA-CM-4746, already exists.
✅ Skipping TCGA-CM-4747, already exists.
✅ Skipping TCGA-CM-4748, already exists.
✅ Skipping TCGA-CM-4750, already exists.
✅ Skipping TCGA-CM-4751, already exists.
✅ Skipping TCGA-CM-4752, already exists.
✅ Skipping TCGA-CM-5341, already exists.
✅ Skippi

TCGA-D5-6539: 100%|█████████████████████████████████████████████████████| 96/96 [06:27<00:00,  4.03s/it]


✅ Saved 18715 embeddings to virchow_features/TCGA-D5-6539.h5


TCGA-D5-6540: 100%|█████████████████████████████████████████████████████| 96/96 [06:25<00:00,  4.01s/it]


✅ Saved 18715 embeddings to virchow_features/TCGA-D5-6540.h5


TCGA-D5-6541: 100%|█████████████████████████████████████████████████████| 96/96 [06:25<00:00,  4.02s/it]


✅ Saved 18715 embeddings to virchow_features/TCGA-D5-6541.h5


TCGA-D5-6898: 100%|█████████████████████████████████████████████████████| 25/25 [01:41<00:00,  4.05s/it]


✅ Saved 4832 embeddings to virchow_features/TCGA-D5-6898.h5


TCGA-D5-6920: 100%|█████████████████████████████████████████████████████| 19/19 [01:16<00:00,  4.05s/it]


✅ Saved 3630 embeddings to virchow_features/TCGA-D5-6920.h5


TCGA-D5-6922: 100%|█████████████████████████████████████████████████████| 25/25 [01:39<00:00,  3.97s/it]


✅ Saved 4710 embeddings to virchow_features/TCGA-D5-6922.h5


TCGA-D5-6923: 100%|█████████████████████████████████████████████████████| 21/21 [01:25<00:00,  4.05s/it]


✅ Saved 4050 embeddings to virchow_features/TCGA-D5-6923.h5


TCGA-D5-6924: 100%|█████████████████████████████████████████████████████| 28/28 [01:54<00:00,  4.07s/it]


✅ Saved 5460 embeddings to virchow_features/TCGA-D5-6924.h5


TCGA-D5-6926: 100%|█████████████████████████████████████████████████████| 27/27 [01:48<00:00,  4.03s/it]


✅ Saved 5184 embeddings to virchow_features/TCGA-D5-6926.h5


TCGA-D5-6927: 100%|█████████████████████████████████████████████████████| 25/25 [01:40<00:00,  4.02s/it]


✅ Saved 4788 embeddings to virchow_features/TCGA-D5-6927.h5


TCGA-D5-6928: 100%|███████████████████████████████████████████████████████| 9/9 [00:36<00:00,  4.01s/it]


✅ Saved 1680 embeddings to virchow_features/TCGA-D5-6928.h5


TCGA-D5-6929:  22%|███████████▊                                         | 17/76 [01:11<03:56,  4.01s/it]