In [6]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import random
# print(torch.version)
# print(torch.version.cuda)
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
from PIL import Image
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's pixel-count limit
# (the decompression-bomb check), so arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from uni import get_encoder
from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe
from uni.downstream.eval_patch_features.fewshot import eval_knn, eval_fewshot
from uni.downstream.eval_patch_features.protonet import ProtoNet, prototype_topk_vote
from uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics
from uni.downstream.utils import concat_images
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# configs
# NOTE(review): BATCH_SIZE is not referenced by any cell visible in this
# notebook (the DataLoader is driven by a batch sampler) — confirm it is still needed.
BATCH_SIZE = 1  # load each slide's tiles sequentially, one slide per batch
# Raw strings keep backslashes literally, so every path separator is written
# exactly once (doubling them inside r"..." produces "\\" in the actual path).
K_FOLDS_PATH = r"E:\KSA Project\dataset\splits\kfolds_IDARS.csv"
DATA_PATH = r"E:\KSA Project\dataset\testing\Patches"
FEATURES_SAVE_DIR = r"E:\KSA Project\dataset\testing\uni_features"

### Downloading UNI weights + Creating Model

The function `get_encoder` performs the commands above, downloading the checkpoint to the `./assets/ckpts/` directory (relative to the root of this GitHub repository).

In [None]:
# `get_encoder` is already imported in the first cell; the import is repeated
# here so this cell can be run on its own.
from uni import get_encoder
# Build the UNI encoder on `device`; the call returns the model together with a transform.
# NOTE(review): the name `transform` is reassigned in the "Data Loaders" cell,
# so the transform returned here is not the one handed to the dataset — confirm this is intended.
model, transform = get_encoder(enc_name='uni', device=device)

### Data Loaders

In [4]:
from dataloader import PatchLoader, SlideBatchSampler
mode = 1  # mode=1 loads data/patches sequentially; mode=2 loads them randomly.

# ImageNet channel statistics. UNI's reference evaluation transform normalizes
# inputs with these values, so the custom FiveCrop pipeline below must do the
# same for the encoder to see inputs in the range it expects.
# NOTE(review): assumes PatchLoader applies `transform` as given and performs
# no normalization of its own — confirm against dataloader.py.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Per-crop conversion: PIL image -> float tensor in [0, 1] -> normalized tensor.
_crop_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

transform = transforms.Compose([
    transforms.FiveCrop(224),  # returns a tuple of 5 crops (four corners + center)
    # Convert every crop and stack them into one [5, C, 224, 224] tensor.
    Lambda(lambda crops: torch.stack([_crop_to_tensor(crop) for crop in crops])),
])

def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """Build a DataLoader over a PatchLoader dataset, batched by SlideBatchSampler.

    Args:
        label_file: Forwarded to PatchLoader as `label_file`.
        data_path: Forwarded to PatchLoader as `data_path`.
        transform: Forwarded to PatchLoader as `transform`.
        num_samples: Forwarded to PatchLoader as `num_samples`.
        mode: Forwarded to PatchLoader as `mode`.

    Returns:
        A DataLoader whose batches are defined by
        SlideBatchSampler(dataset.ntiles); loading happens in the main
        process (num_workers=0) without pinned memory.
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    return DataLoader(
        patch_dataset,
        batch_sampler=SlideBatchSampler(patch_dataset.ntiles),
        num_workers=0,
        pin_memory=False,
    )

# Build the DataLoader from the configured split file and patch directory,
# then report how many batches it yields.
loader_options = dict(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
data_loader = create_dataloader(**loader_options)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 290
Number of tiles: 176575
Length of train_loader: 290
Number of Slides: 58
Number of tiles: 36278
Length of val_loader: 58
Number of Slides: 61
Number of tiles: 31773
Length of test_loader: 61


In [35]:
# Sanity check: pull the first three batches and show their shapes and labels.
for batch_idx, (images, labels) in enumerate(data_loader):
    print(
        f"Batch {batch_idx+1}",
        f"Images shape: {images.shape}",
        f"Labels: {labels}",
        sep="\n",
    )
    if batch_idx >= 2:  # three batches are enough to confirm the loader works
        break

Batch 1
Images shape: torch.Size([50, 5, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
Batch 2
Images shape: torch.Size([1156, 5, 3, 224, 224])
Labels: tensor([0, 0, 0,  ..., 0, 0, 0])
Batch 3
Images shape: torch.Size([1645, 5, 3, 224, 224])
Labels: tensor([0, 0, 0,  ..., 0, 0, 0])


### DataLoader testing

In [13]:

def print_unique_class_representation(loader, loader_name="Test"):
    """Print the distinct labels yielded by a loader and how often each occurs.

    The loader is iterated once from start to end, so every batch (inputs
    included) is loaded just to read its labels.

    Args:
        loader: Iterable yielding `(inputs, labels)` pairs. `labels` may be a
            tensor of any shape (including a scalar) or anything
            `torch.as_tensor` accepts; the inputs are ignored.
        loader_name: Name used in the printed header, e.g. "Train" or "Test".
            Defaults to "Test", which reproduces the original header.

    Returns:
        None. The labels and their counts are printed.
    """
    all_labels = []
    for _, labels in loader:
        # Flatten first so scalar (0-d) labels are handled like 1-d batches.
        all_labels.extend(torch.as_tensor(labels).reshape(-1).tolist())

    unique_labels, counts = torch.unique(torch.tensor(all_labels), return_counts=True)

    print(f"\n{loader_name} Loader Unique Labels and Counts:")
    print(f"Unique labels: {unique_labels}")
    print(f"Counts: {counts}")

# Example usage: iterate the whole `data_loader` once and print its label counts.
print_unique_class_representation(data_loader)
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx}: {len(images)} images")
#     print(f"Images shape: {images.shape}")
#     unique_labels, counts = torch.unique(labels, return_counts=True)
#     print(f"Unique labels: {unique_labels}")
#     print(f"Counts: {counts}")
#     if batch_idx == 2:  # Only print a few batches to check if it's working
#         break


Test Loader Unique Labels and Counts:
Unique labels: tensor([0, 1])
Counts: tensor([26839,  4934])


### ROI Feature Extraction on FiveCrop Patches Using Averaging Approach

In [5]:
@torch.no_grad()
def extract_wsi_features_with_chunks(model, dataloader, chunk_size=1000, show_progress=True):
    """
    Extract one feature vector per dataloader batch by averaging model embeddings.

    Every batch must have shape [num_patches, num_crops, C, H, W] (for example
    the stacked output of a FiveCrop transform). The crops are flattened, run
    through `model` in chunks to bound memory use, averaged over the crops of
    each patch, and the per-patch vectors are then averaged into a single
    vector for the batch.

    Args:
    model: torch.nn.Module mapping [N, C, H, W] images to [N, embedding_dim]
        features. It is run on the device of its first parameter. The model is
        switched to eval mode for the extraction and the previous train/eval
        flag of every submodule is restored before returning.
    dataloader: Sized iterable of (batch, target) pairs. All patches in a batch
        are treated as one slide: only target[0] is kept as the slide label.
    chunk_size: Maximum number of crop images per forward pass. It is rounded
        down to a multiple of num_crops (but never below one whole patch) so
        the crops of a patch are never split across two chunks.
    show_progress: When True (default) the loop is wrapped in a tqdm progress bar.

    Returns:
    asset_dict: Dictionary with
        "embeddings": float32 array of shape [num_batches, embedding_dim],
        "labels": array of shape [num_batches] holding target[0] of each batch.

    Raises:
    ValueError: If chunk_size is smaller than 1 or the dataloader yields no batches.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    # Feature extraction must be deterministic (no dropout, frozen norm stats).
    # Remember each submodule's mode so the caller's model state is left untouched.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    all_embeddings = []
    all_labels = []
    batches = tqdm(dataloader, total=len(dataloader)) if show_progress else dataloader

    try:
        for batch, target in batches:
            num_patches, num_crops, c, h, w = batch.size()  # [num_patches, num_crops, C, H, W]

            # Flatten the crops into individual images: [num_patches * num_crops, C, H, W]
            crops = batch.reshape(-1, c, h, w)

            # Keep whole patches together: a chunk whose length is not a
            # multiple of num_crops could not be reshaped back per patch.
            crops_per_chunk = max(num_crops, (chunk_size // num_crops) * num_crops)

            patch_embeddings = []
            for start_idx in range(0, crops.size(0), crops_per_chunk):
                chunk = crops[start_idx:start_idx + crops_per_chunk].to(device)

                with torch.inference_mode():
                    chunk_embeddings = model(chunk).detach().cpu()  # [chunk, embedding_dim]
                    # Group the crops of each patch and average them.
                    chunk_embeddings = chunk_embeddings.reshape(
                        -1, num_crops, chunk_embeddings.size(-1)
                    ).mean(dim=1)  # [patches_in_chunk, embedding_dim]
                patch_embeddings.append(chunk_embeddings)

            # Mean across all patches of the slide -> [embedding_dim]
            slide_embedding = torch.cat(patch_embeddings, dim=0).mean(dim=0)

            all_embeddings.append(slide_embedding.numpy())
            # One label per slide: every patch in the batch is assumed to share it.
            all_labels.append(target[0].item())
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    if not all_embeddings:
        raise ValueError("dataloader yielded no batches; there are no features to stack")

    asset_dict = {
        "embeddings": np.stack(all_embeddings).astype(np.float32),  # [num_slides, embedding_dim]
        "labels": np.array(all_labels),
    }

    return asset_dict

In [None]:

# Extract one averaged feature vector per slide for every batch of `data_loader`
# (returns a dictionary with "embeddings" and "labels" arrays).
start = time.time()

train_features = extract_wsi_features_with_chunks(model, data_loader)

# Convert to torch tensors with explicit dtypes. torch.tensor copies the data
# and casts integer labels straight to long instead of going through float32,
# which is what the legacy torch.Tensor(...) constructor does.
feats = torch.tensor(train_features['embeddings'], dtype=torch.float32)
labels = torch.tensor(train_features['labels'], dtype=torch.long)

elapsed = time.time() - start
print(f'Took {elapsed:.03f} seconds')
print(f'Features shape {feats.shape} and Labels shape is {labels.shape}')

The size of input dataloader is 290


  x = F.scaled_dot_product_attention(
100%|██████████| 290/290 [1:53:09<00:00, 23.41s/it]  


The size of input dataloader is 58


100%|██████████| 58/58 [24:07<00:00, 24.95s/it] 


The size of input dataloader is 61


100%|██████████| 61/61 [21:44<00:00, 21.39s/it]

Took 9541.786 seconds
Train features shape torch.Size([290, 1024]) and Labels shape is torch.Size([290])
Valid features shape torch.Size([58, 1024]) and Labels shape is torch.Size([58])
Test features shape torch.Size([61, 1024]) and Labels shape is torch.Size([61])





## Save Features

In [53]:

def save_features(features, dataloader, save_dir):
    """Write one `.pt` file per slide containing that slide's feature vector.

    Args:
        features: Indexable collection (e.g. a [num_slides, dim] tensor) where
            `features[i]` holds the features for `dataloader.dataset.slides[i]`.
        dataloader: Object exposing `dataset.slides`, a sequence of slide file
            names. The i-th entry is assumed to line up with `features[i]` —
            this holds only if the features were extracted in that same order.
        save_dir: Output directory; created if it does not exist.

    Raises:
        ValueError: If the number of feature rows differs from the number of slides.

    Note:
        The output name is everything before the FIRST dot of the slide name,
        so slide names that differ only after a dot map to the same file.
    """
    slides = dataloader.dataset.slides
    if len(features) != len(slides):
        raise ValueError(
            f"Got {len(features)} feature rows for {len(slides)} slides; "
            "they must match one-to-one."
        )

    os.makedirs(save_dir, exist_ok=True)
    for i, slide in enumerate(slides):
        slide_feats = features[i]
        if isinstance(slide_feats, torch.Tensor):
            # Indexing returns a view; torch.save on a view serializes the whole
            # underlying storage (i.e. every slide's features). Clone so each
            # file contains only this slide's vector.
            slide_feats = slide_feats.clone()
        slide_name = slide.split('.')[0]
        # save as torch .pt file
        save_path = os.path.join(save_dir, f'{slide_name}.pt')
        torch.save(slide_feats, save_path)

# Write each row of `feats` to its own .pt file under FEATURES_SAVE_DIR,
# named after the corresponding entry of data_loader.dataset.slides.
save_features(feats, data_loader, FEATURES_SAVE_DIR)