In this approach, we extract and save patch-level (five-crop) features for each WSI.

In [1]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# Ask CUDA to run kernel launches synchronously; presumably enabled so CUDA
# errors surface at the failing call while debugging (it slows execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
from PIL import Image
# Remove Pillow's pixel-count limit so very large images can be opened
# without triggering its decompression-bomb check.
Image.MAX_IMAGE_PIXELS = None
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
# configs
# NOTE(review): BATCH_SIZE is not referenced by any cell visible in this
# notebook (batching is driven by a batch sampler) — confirm it is still needed.
BATCH_SIZE = 1 # load each slide all tiles sequentially 
# CSV label file handed to the patch dataset as `label_file`.
# NOTE(review): these are raw strings, so every `\\` stays as two backslash
# characters in the path — confirm the doubled separators are intended.
K_FOLDS_PATH = r'E:\\Aamir Gulzar\\dataset\\paip_data\\labels\\validation_data_MSI.csv'
# Root directory of the extracted patches, handed to the dataset as `data_path`.
DATA_PATH = r"E:\\Aamir Gulzar\\dataset\\paip_data\\Patches"
# Output root: one sub-directory per WSI is created here, holding one .pt file per patch.
FEATURES_SAVE_DIR = r"E:/Aamir Gulzar/dataset/paip_data/Virchow2_FiveCrop_Features"


In [3]:
# Virchow2 must be built with its MLP layer and activation given explicitly,
# otherwise the pretrained weights are not initialised properly.
model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU,
)

# Preprocessing pipeline derived from the model's own pretrained configuration.
data_config = resolve_data_config(model.pretrained_cfg, model=model)
model_transforms = create_transform(**data_config)

# Inference only: switch to eval mode and move the weights to the chosen device.
model = model.eval().to(device)

### Data Loader

In [4]:
from dataloader import PatchLoader, SlideBatchSampler

# Loading mode forwarded to PatchLoader. Per the original author's note:
# 1 = sequential patch loading, 2 = random loading.
mode = 1

_to_tensor = transforms.ToTensor()


def _stack_five_crops(crops):
    """Turn every crop into a tensor and stack them along a new leading axis."""
    return torch.stack([_to_tensor(crop) for crop in crops])


transform = transforms.Compose([
    transforms.FiveCrop(224),   # five 224x224 crops of the patch (four corners + centre)
    Lambda(_stack_five_crops),  # crops -> one stacked tensor
])
def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """Build the DataLoader used for feature extraction.

    Args:
    - label_file: Label file forwarded to PatchLoader.
    - data_path: Patch directory forwarded to PatchLoader.
    - transform: Transform applied by the dataset to each patch.
    - num_samples: Forwarded to PatchLoader unchanged.
    - mode: Loading mode forwarded to PatchLoader.
    Returns:
    - A DataLoader over the PatchLoader dataset whose batches are decided by
      a SlideBatchSampler (presumably one batch per slide — see dataloader.py).
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    # Batching (and therefore ordering) is delegated entirely to the sampler,
    # which is built from the dataset's `ntiles`; no shuffling is requested.
    return DataLoader(
        patch_dataset,
        batch_sampler=SlideBatchSampler(patch_dataset.ntiles),
        num_workers=0,
        pin_memory=False,
    )

# Build the loader from the configured label file and patch directory,
# then report how many batches it will yield.
data_loader = create_dataloader(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 31
Number of tiles: 29119
Length of data_loader: 31


In [5]:
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx+1}")
#     print(f"Images shape: {images.shape}")
#     print(f"Labels: {labels}")
#     if batch_idx == 5:  # Only print a few batches to check if it's working
#         break

### FiveCrop Feature Extraction and Save

In [6]:
def _embed_five_crops(model, preprocess, crops, device):
    """Return one embedding per crop for a single patch.

    Args:
    - model: The model used to extract embeddings.
    - preprocess: Callable applied to each crop (as a PIL image) before the model.
    - crops: Tensor indexed as (crop, channel, height, width) with values
      scaled to [0, 1] (assumed: produced by ToTensor — confirm against the dataset).
    - device: Device the model lives on.
    Returns:
    - CPU tensor with one row per crop: the mean over the model's patch tokens.
    """
    # `preprocess` expects PIL images, so go back to channel-last uint8 first.
    # NOTE(review): astype(np.uint8) truncates rather than rounds — confirm the
    # float -> uint8 round trip reproduces the original pixel values exactly.
    crops_uint8 = (crops.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
    batch = torch.stack([preprocess(Image.fromarray(crop)) for crop in crops_uint8])
    batch = batch.to(device)

    # Mixed precision is only requested when the model actually runs on CUDA;
    # on any other device autocast is explicitly disabled.
    use_amp = device.type == "cuda"
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        output = model(batch).cpu()
        # Index 0 is the class token and indices 1-4 are the register tokens;
        # only the remaining patch tokens are pooled.
        patch_tokens = output[:, 5:]
        return patch_tokens.mean(dim=1)


@torch.no_grad()
def extract_embeddings_patch_by_patch(model, preprocess, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, patch by patch, without averaging the five crops.

    Every patch is written to ``<save_dir>/<wsi_name>/<patch_name>.pt``. Work is
    resumable at patch level: a patch whose file already exists is not
    recomputed, so a WSI left half-finished by an interrupted run is completed
    on the next run instead of being skipped because its directory exists.
    Files are written to a temporary name and then renamed, so an interrupted
    save never leaves a truncated ``.pt`` file behind.

    Args:
    - model: The model used to extract embeddings.
    - preprocess: Preprocessing function to apply to the images.
    - dataloader: Dataloader providing WSI patches and labels. Batch ``i`` is
      taken to be slide ``dataloader.dataset.slides[i]`` and its dataset
      indices are read from ``dataloader.batch_sampler.indices[i]``.
    - save_dir: Directory where the extracted embeddings will be saved.
    Returns:
    - None: The function saves the extracted embeddings to disk.
    """
    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    dataset = dataloader.dataset
    # Labels are yielded by the loader but are not needed for feature extraction.
    for batch_idx, (images, _labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
        wsi_name = dataset.slides[batch_idx]
        print(f"Processing WSI {batch_idx+1} {wsi_name}")
        # One output directory per WSI.
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        os.makedirs(save_dir_wsi, exist_ok=True)

        # Work out up front which patches of this WSI still lack a saved file.
        batch_indices = dataloader.batch_sampler.indices[batch_idx]
        pending = []
        for patch_idx, dataset_idx in enumerate(batch_indices):
            patch_name = os.path.splitext(os.path.basename(dataset.tiles[dataset_idx]))[0]
            save_path = os.path.join(save_dir_wsi, f'{patch_name}.pt')
            if not os.path.exists(save_path):
                pending.append((patch_idx, save_path))

        if not pending:
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue

        for patch_idx, save_path in pending:
            embeddings = _embed_five_crops(model, preprocess, images[patch_idx], device)
            # Save under a temporary name, then rename: the final .pt path only
            # ever appears once the file has been written completely.
            tmp_path = f'{save_path}.tmp'
            torch.save(embeddings, tmp_path)
            os.replace(tmp_path, save_path)

In [None]:
# Run the extraction over every batch of data_loader, using the model's own
# preprocessing transforms, and write the per-patch .pt files under FEATURES_SAVE_DIR.
extract_embeddings_patch_by_patch(model, model_transforms, data_loader, FEATURES_SAVE_DIR)