In [1]:
import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from PIL import Image

In [2]:
from huggingface_hub import login

login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
import os 
import glob
png_paths = sorted(glob.glob(os.path.join("clean_tiles", "*.png")))
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
model = timm.create_model("hf-hub:paige-ai/Virchow2", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)
model = model.eval()

In [6]:
transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

In [7]:
import os
import glob
png_paths = sorted(glob.glob(os.path.join("clean_tiles", "*.png")))

In [None]:
import os
import glob
import torch
from PIL import Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import umap
from sklearn.cluster import DBSCAN
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
import pandas as pd
import ace_tools as tools

# Load Virchow2 model
model = timm.create_model("hf-hub:paige-ai/Virchow2", pretrained=True,
                          mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Get transform
transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

# Get tile paths
tile_paths = sorted(glob.glob("clean_tiles/*.png"))

# Extract per-tile embeddings
tile_embeddings = []
tile_coords = []
slide_ids = []

with torch.no_grad():
    for path in tqdm(tile_paths, desc="Extracting embeddings"):
        image = Image.open(path).convert("RGB")
        tensor = transforms(image).unsqueeze(0).to(device)
        output = model(tensor)

        class_token = output[:, 0]
        patch_tokens = output[:, 5:]
        embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)

        tile_embeddings.append(embedding.squeeze(0).cpu())

        # Extract slide ID and coordinates
        filename = os.path.basename(path)
        parts = filename.replace(".png", "").split("_")[-2:]
        #slide_id = parts[0]
        x = int(parts[0])  # e.g., x0 -> 0
        y = int(parts[1])  # e.g., y0 -> 0

        #slide_ids.append(slide_id)
        tile_coords.append((x, y))

# Convert to tensor
X = torch.stack(tile_embeddings)
X_np = X.numpy()

# UMAP reduction
umap_2d = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, metric='cosine', random_state=42)
X_umap = umap_2d.fit_transform(X_np)

# DBSCAN clustering
clustering = DBSCAN(eps=0.3, min_samples=5).fit(X_umap)
cluster_labels = clustering.labels_

# Prepare per-slide pooled embeddings
slide_to_tiles = {}
for emb, sid in zip(X, slide_ids):
    slide_to_tiles.setdefault(sid, []).append(emb)

slide_pooled = {sid: torch.stack(tiles).mean(0) for sid, tiles in slide_to_tiles.items()}

# Visualization heatmaps for each slide
heatmap_data = []
for sid in set(slide_ids):
    coords = [(x, y, label) for (s, (x, y), label) in zip(slide_ids, tile_coords, cluster_labels) if s == sid]
    if not coords:
        continue
    xs, ys, lbls = zip(*coords)
    max_x, max_y = max(xs) + 1, max(ys) + 1
    grid = np.full((max_y, max_x), -1)
    for x, y, lbl in coords:
        grid[y, x] = lbl

    # Save heatmap image to disk (in-memory visualization)
    plt.figure(figsize=(6, 6))
    plt.imshow(grid, cmap="tab10", interpolation="nearest")
    plt.title(f"Cluster heatmap: {sid}")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(f"{sid}_heatmap.png")
    plt.close()

    heatmap_data.append({
        "slide_id": sid,
        "heatmap_path": f"{sid}_heatmap.png"
    })

# Output table with slide IDs and heatmap paths
df = pd.DataFrame(heatmap_data)
tools.display_dataframe_to_user(name="Slide Heatmaps", dataframe=df)


In [8]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
device ="cpu"
def extract_deep_features_as_npy(tile_paths, output_dir="tile_npy_embeddings"):
    os.makedirs(output_dir, exist_ok=True)
    tile_coords = []

    with torch.no_grad():
        for path in tqdm(tile_paths, desc="Extracting & saving tile embeddings"):
            image = Image.open(path).convert("RGB")
            tensor = transforms(image).unsqueeze(0).to(device)
            output = model(tensor)

            class_token = output[:, 0]
            patch_tokens = output[:, 5:]
            embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)

            # Extract filename and coordinates
            filename = os.path.basename(path).replace(".png", "")
            parts = filename.split("_")[-2:]
            x = int(parts[0])
            y = int(parts[1])
            tile_coords.append((x, y))

            # Save as individual .npy file
            save_path = os.path.join(output_dir, f"{filename}.npy")
            np.save(save_path, embedding.squeeze(0).cpu().numpy())

    print(f"Saved {len(tile_paths)} tile embeddings to '{output_dir}/'")


    

In [9]:
extract_deep_features_as_npy(png_paths,"deep_features")

Extracting & saving tile embeddings: 100%|██████████████████████████████████████| 15297/15297 [3:26:47<00:00,  1.23it/s]

Saved 15297 tile embeddings to 'deep_features/'





In [1]:
import os
import numpy as np

def compute_slide_embedding(tile_dir="tile_npy_embeddings", pooling="mean", output_path="slide_embedding.npy"):
    tile_embeddings = []

    for fname in os.listdir(tile_dir):
        if fname.endswith(".npy"):
            emb = np.load(os.path.join(tile_dir, fname))
            tile_embeddings.append(emb)

    tile_embeddings = np.stack(tile_embeddings)

    # Pool across all tiles
    if pooling == "mean":
        slide_embedding = tile_embeddings.mean(axis=0)
    elif pooling == "max":
        slide_embedding = tile_embeddings.max(axis=0)
    elif pooling == "median":
        slide_embedding = np.median(tile_embeddings, axis=0)
    else:
        raise ValueError(f"Unsupported pooling method: {pooling}")

    # Save the result
    np.save(output_path, slide_embedding)
    print(f"Saved whole-slide embedding to: {output_path}")

    return slide_embedding


In [7]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import umap
from PIL import Image

def umap_reducer(x: np.ndarray, dims: int = 3, nns: int = 10) -> np.ndarray:
    reducer = umap.UMAP(
        n_neighbors=nns,
        n_components=dims,
        metric="manhattan",
        spread=0.5,
        random_state=2,
    )
    reduced = reducer.fit_transform(x)
    reduced -= reduced.min(axis=0)
    reduced /= reduced.max(axis=0)
    return reduced

def scatter_tile_umap(tile_dir, output_path, thumbnail_path=None):
    tile_paths = sorted(glob.glob(os.path.join(tile_dir, "tile_*.npy")))
    embeddings, coords = [], []

    for path in tile_paths:
        base = os.path.basename(path).replace(".npy", "")
        _, _, x, y = base.split("_")
        x, y = int(x), int(y)
        coords.append((x, y))
        embeddings.append(np.load(path))

    coords = np.array(coords)
    embeddings = np.stack(embeddings)
    reduced_colors = umap_reducer(embeddings)  # shape: (N, 3)

    # --- Plot ---
    plt.figure(figsize=(10, 10))
    if thumbnail_path:
        thumbnail = Image.open(thumbnail_path)
        plt.imshow(thumbnail)
    plt.scatter(coords[:, 0], coords[:, 1], c=reduced_colors, s=5, alpha=0.6)
    plt.gca().invert_yaxis()  # optional: match image convention
    plt.axis("off")
    plt.title("UMAP Feature Embedding (Scatter View)")
    plt.savefig(output_path, bbox_inches="tight", dpi=300)
    plt.close()

    print(f"[✓] Saved scatter UMAP view to: {output_path}")


In [8]:
scatter_tile_umap(
    tile_dir="deep_features",
    output_path="tile_umap_scatter.png"
)


  warn(


[✓] Saved scatter UMAP view to: tile_umap_scatter.png
