In this approach, we average the features of all patches of each WSI and save the resulting WSI-level feature vector.

In [None]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

### Config Files

In [2]:
# configs
# NOTE: BATCH_SIZE is not passed to the DataLoader; batching is done by
# SlideBatchSampler, which is used to load each slide's tiles as one batch.
BATCH_SIZE = 1  # load each slide all tiles sequentially
# Raw strings: every path separator is written as a single backslash.
K_FOLDS_PATH = r"E:\KSA Project\dataset\splits\kfolds_IDARS.csv"
DATA_PATH = r"E:\KSA Project\dataset\testing\Patches"
FEATURES_SAVE_DIR = r"E:\KSA Project\dataset\testing\virchow2_features"

In [3]:
# Log in to the Hugging Face hub first if downloading the model requires it:
# from huggingface_hub import login
# login()

# Virchow2 must be built with the SwiGLUPacked MLP layer and SiLU activation
# for its pretrained weights to initialize correctly.
model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU,
)
# Preprocessing pipeline derived from the model's pretrained data config.
data_cfg = resolve_data_config(model.pretrained_cfg, model=model)
model_transforms = create_transform(**data_cfg)

# Inference mode (eval) and placement on the selected device.
model = model.eval().to(device)

### Data Loader

In [None]:
from dataloader import PatchLoader, SlideBatchSampler
# Loading mode forwarded to PatchLoader: 1 = sequential patch loading, 2 = random loading.
mode = 1

# ToTensor is stateless, so one shared instance serves every crop.
_to_tensor = transforms.ToTensor()
transform = transforms.Compose([
    # FiveCrop returns a tuple of five 224x224 crops (four corners + center).
    transforms.FiveCrop(224),
    # Convert each crop to a tensor and stack them -> (5, C, 224, 224).
    Lambda(lambda crops: torch.stack([_to_tensor(c) for c in crops])),
])

def create_dataloader(label_file, data_path, transform, num_samples, mode, *, num_workers=0, pin_memory=False):
    """Build a DataLoader whose batches are produced by ``SlideBatchSampler``.

    Args:
        label_file: label/split file forwarded to ``PatchLoader``.
        data_path: root directory of the patch images, forwarded to ``PatchLoader``.
        transform: per-tile transform applied by ``PatchLoader``.
        num_samples: forwarded to ``PatchLoader`` unchanged.
        mode: loading mode forwarded to ``PatchLoader``.
        num_workers: number of DataLoader worker processes (0 = load in the main
            process). NOTE: a transform built with a ``lambda`` cannot be pickled,
            so values > 0 may fail on spawn-based platforms such as Windows.
        pin_memory: pin host memory so host-to-GPU copies can be faster.

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``PatchLoader`` dataset.
    """
    dataset = PatchLoader(label_file=label_file, data_path=data_path, transform=transform, num_samples=num_samples, mode=mode)
    # Batch composition and order are delegated entirely to SlideBatchSampler
    # (built from the per-slide tile counts); batch_sampler is mutually exclusive
    # with batch_size/shuffle/sampler/drop_last, so none of those are passed.
    batch_sampler = SlideBatchSampler(dataset.ntiles)
    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=pin_memory)


# Create DataLoaders
data_loader = create_dataloader(label_file=K_FOLDS_PATH,data_path=DATA_PATH,
                                     transform=transform,num_samples=None,mode=mode)
print(f"Length of data_loader: {len(data_loader)}")


#### DataLoader Testing

In [None]:
# Sanity-check the loader: print the image tensor shape and labels of the first six batches.
for batch_idx, (images, labels) in enumerate(data_loader):
    print(f"Batch {batch_idx+1}\nImages shape: {images.shape}\nLabels: {labels}")
    if batch_idx >= 5:
        break

### ROI Feature Extraction on FiveCrop Patches Using Averaging Approach at WSI Level
For each patch, extract features from its five crops and average them into a single patch feature; then average the features of all patches of a WSI into one slide-level feature vector.

In [5]:
def extract_embeddings_with_chunks(model, transforms, dataloader, chunk_size=1000):
    """Compute one averaged embedding per WSI.

    Each batch from ``dataloader`` is unpacked as ``(images, labels)``, where
    ``images`` has shape (num_tiles, num_crops, C, H, W). Values are assumed to be
    floats in [0, 1] (ToTensor output) — confirm against the loader's transform.
    Every crop is re-encoded as a uint8 PIL image, preprocessed with ``transforms``
    and fed to ``model``. Tokens from index 5 onward (patch tokens) are mean-pooled
    per crop. The pooled crop vectors are then averaged over the crops of each tile,
    and finally over all tiles of the slide.

    Args:
        model: model returning token sequences of shape (batch, tokens, dim);
            token 0 is the class token and tokens 1-4 are register tokens.
        transforms: callable mapping a PIL image to a tensor. This name shadows
            the torchvision ``transforms`` module inside this function.
        dataloader: iterable with ``len()`` yielding one slide per batch.
        chunk_size: maximum number of crop images sent to the model at once.
            It is rounded down to a multiple of ``num_crops`` so that no tile is
            split across chunks (minimum: one tile).

    Returns:
        dict with ``"embeddings"`` (float32, [num_slides, embedding_dim]) and
        ``"labels"`` ([num_slides]), where each label is the first tile's label.
    """
    all_embeddings = []
    all_labels = []
    device = next(model.parameters()).device
    # float16 autocast only makes sense (and is only enabled) on a CUDA device.
    use_amp = device.type == "cuda"
    print(f'The size of input dataloader is {len(dataloader)}')

    for images, labels in tqdm(dataloader, total=len(dataloader)):
        num_tiles, num_crops, channels, height, width = images.shape
        images = images.reshape(num_tiles * num_crops, channels, height, width)
        # (N, C, H, W) floats -> (N, H, W, C) uint8 for PIL. Round rather than
        # truncate: float error such as 254.99998 (from x / 255 * 255) must map
        # back to 255, not 254.
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        images = np.clip(np.rint(images * 255), 0, 255).astype(np.uint8)

        # Preprocess with the model transforms. The stacked tensor stays on the CPU;
        # only one chunk at a time is moved to the device, which bounds device memory.
        images = torch.stack([transforms(Image.fromarray(image)) for image in images])

        # Chunk length in crops, aligned to whole tiles so the per-tile reshape
        # below never mixes crops from different tiles.
        step = max(num_crops, (chunk_size // num_crops) * num_crops)
        tile_embeddings = []
        for start_idx in range(0, images.size(0), step):
            chunk = images[start_idx:start_idx + step].to(device)
            with torch.inference_mode():
                with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                    output = model(chunk)
                # Skip the class token (0) and the 4 register tokens (1-4); mean-pool
                # the remaining patch tokens in float32.
                crop_embeddings = output[:, 5:].float().mean(dim=1)  # [crops_in_chunk, dim]
                crop_embeddings = crop_embeddings.view(-1, num_crops, crop_embeddings.size(-1))
                # Mean across the crops of each tile -> [tiles_in_chunk, dim]
                tile_embeddings.append(crop_embeddings.mean(dim=1).cpu())

        # Mean across all tiles of the slide -> [embedding_dim]
        slide_embedding = torch.cat(tile_embeddings, dim=0).mean(dim=0)
        all_embeddings.append(slide_embedding.numpy())
        # One label per WSI, taken from its first tile.
        all_labels.append(labels[0].item())

    # Stack the embeddings and labels
    asset_dict = {
        "embeddings": np.stack(all_embeddings).astype(np.float32),  # [num_slides, embedding_dim]
        "labels": np.array(all_labels),
    }

    return asset_dict

In [None]:
# Time the full extraction pass over every slide.
start = time.perf_counter()
features = extract_embeddings_with_chunks(model, model_transforms, data_loader)
# Convert to torch without a float round trip: the embeddings are already float32,
# and the labels go straight to int64 (torch.Tensor(...) would first cast them to float32).
feats = torch.from_numpy(features['embeddings'])
labels = torch.as_tensor(features['labels'], dtype=torch.long)

elapsed = time.perf_counter() - start
print(f'Took {elapsed:.03f} seconds')
print(f'Train features shape {feats.shape} and Labels shape is {labels.shape}')

## Save Features

In [7]:
def save_features(features, dataloader, save_dir):
    """Save each slide's feature vector as ``<slide_name>.pt`` in ``save_dir``.

    Args:
        features: per-slide features indexable along the first axis (for example,
            a [num_slides, dim] tensor). Row ``i`` is paired with
            ``dataloader.dataset.slides[i]``. This assumes extraction visited the
            slides in ``dataset.slides`` order — verify against SlideBatchSampler.
        dataloader: loader whose ``dataset.slides`` lists the slide file names.
        save_dir: output directory; it is created if missing.

    Raises:
        ValueError: if the number of feature rows differs from the number of slides.
    """
    os.makedirs(save_dir, exist_ok=True)
    slides = dataloader.dataset.slides
    if len(features) != len(slides):
        raise ValueError(
            f"Got {len(features)} feature rows for {len(slides)} slides; "
            "cannot map features to slide names."
        )
    for slide, slide_feats in zip(slides, features):
        # Drop any directory part and only the final extension, so a name such as
        # "TCGA.01.svs" becomes "TCGA.01" rather than "TCGA".
        slide_name = os.path.splitext(os.path.basename(slide))[0]
        if isinstance(slide_feats, torch.Tensor):
            # A row of a larger tensor is a view, and torch.save serializes the whole
            # underlying storage (every slide's features). Clone the row so each file
            # holds only this slide's vector.
            slide_feats = slide_feats.clone()
        # save as torch .pt file
        torch.save(slide_feats, os.path.join(save_dir, f'{slide_name}.pt'))

# Persist one feature file per slide into FEATURES_SAVE_DIR.
save_features(features=feats, dataloader=data_loader, save_dir=FEATURES_SAVE_DIR)