In [9]:
import torch
import torchvision
import os
import sys
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from uni import get_encoder
from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe
from uni.downstream.eval_patch_features.fewshot import eval_knn, eval_fewshot
from uni.downstream.eval_patch_features.protonet import ProtoNet, prototype_topk_vote
from uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics
from uni.downstream.utils import concat_images
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
# configs
BATCH_SIZE = 1 # load each slide all tiles sequentially 
K_FOLDS_PATH = r'E:\\Aamir Gulzar\\dataset\\paip_data\\labels\\validation_data_MSI.csv'
DATA_PATH = r"E:\\Aamir Gulzar\\dataset\\paip_data\\Patches"
FEATURES_SAVE_DIR = r"E:/Aamir Gulzar/dataset/paip_data/UNI_FiveCrop_Features"


### Downloading UNI weights + Creating Model

The function `get_encoder` performs the commands above, downloading in the checkpoint in the `./assets/ckpts/` relative path of this GitHub repository.

In [None]:
from uni import get_encoder
model, transform = get_encoder(enc_name='uni', device=device)

### Data Loaders

In [12]:
# import dataloader from one step back directory there is a fine named dataloader.py
sys.path.append("..")
from dataloader import PatchLoader, SlideBatchSampler
from torchvision import transforms

mode = 1 # for sequentially data/patches loading we will use mode =1 and mode= 2 for random loading.

# transform = transforms.Compose([
#     transforms.FiveCrop(224),  # this is a list of 5 crops
#     Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops]))  # convert to tensor and stack
# ])

simple_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL to Tensor
])
def create_dataloader(label_file, data_path, transform, num_samples, mode):
    # Create the dataset
    dataset = PatchLoader(label_file=label_file, data_path=data_path, transform=transform, num_samples=num_samples, mode=mode)
    # Ensure sequential data loading by disabling shuffle
    batch_sampler = SlideBatchSampler(dataset.ntiles)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0, pin_memory=False)
    return dataloader

# Create DataLoaders
data_loader = create_dataloader(label_file=K_FOLDS_PATH,data_path=DATA_PATH,
                                     transform=simple_transform,num_samples=None,mode=mode)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 31
Number of tiles: 29119
Length of data_loader: 31


In [13]:
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx+1}")
#     print(f"Images shape: {images.shape}")
#     print(f"Labels: {labels}")
#     if batch_idx == 5:  # Only print a few batches to check if it's working
#         break

### DataLoader testing

In [14]:

def print_unique_class_representation(dataloader):
    def get_unique_classes(loader):
        all_labels = []
        for _, labels in loader:
            all_labels.extend(labels.tolist())
        unique_labels, counts = torch.unique(torch.tensor(all_labels), return_counts=True)
        return unique_labels, counts
    unique_labels, counts = get_unique_classes(dataloader)

    print("\nTest Loader Unique Labels and Counts:")
    print(f"Unique labels: {unique_labels}")
    print(f"Counts: {counts}")

# Example usage:
# print_unique_class_representation(data_loader)
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx}: {len(images)} images")
#     print(f"Images shape: {images.shape}")
#     unique_labels, counts = torch.unique(labels, return_counts=True)
#     print(f"Unique labels: {unique_labels}")
#     print(f"Counts: {counts}")
#     if batch_idx == 2:  # Only print a few batches to check if it's working
#         break

### ROI Feature Extraction on FiveCrop Patches Level and Save

In [15]:
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt

def extract_embeddings_patch_by_patch(model, dataloader, save_dir):
    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    for batch_idx, (images, labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
        wsi_name = dataloader.dataset.slides[batch_idx]

        # Make a new directory for each WSI
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        if os.path.exists(save_dir_wsi):
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue
        os.makedirs(save_dir_wsi, exist_ok=True)

        batch_indices = dataloader.batch_sampler.indices[batch_idx]        

        for patch_idx, dataset_idx in enumerate(batch_indices):  # Loop through patches
            patch_name = os.path.splitext(os.path.basename(dataloader.dataset.tiles[dataset_idx]))[0]
            image = images[patch_idx]  # Extract patch
            label = labels[patch_idx]

            # Convert to PIL before applying FiveCrop
            image_pil = to_pil_image(image)

            # Apply FiveCrop manually before passing to model
            fivecrop_transform = transforms.FiveCrop(224)
            cropped_images = fivecrop_transform(image_pil)  # List of 5 crops as PIL images
            # List to store features of all five crops
            crop_features = []

            for i, crop in enumerate(cropped_images):  # Process each crop separately
                crop_transformed = transform(crop)  # Apply the UNI model’s transformation                
                crop_transformed = crop_transformed.unsqueeze(0).to(device)  # Add batch dimension

                with torch.inference_mode():
                    embedding = model(crop_transformed).detach().cpu().squeeze(0)  # Ensure 1D tensor
                    crop_features.append(embedding)  # Store each crop’s features
            
            # Stack all five crops into a single tensor (5, feature_dim)
            stacked_features = torch.stack(crop_features)

            # Save the tensor in the correct format
            save_path = os.path.join(save_dir_wsi, f'{patch_name}.pt')
            torch.save(stacked_features, save_path)

In [16]:
extract_embeddings_patch_by_patch(model, data_loader, FEATURES_SAVE_DIR)

The size of input dataloader is 31


100%|██████████| 31/31 [34:58<00:00, 67.68s/it]  


In [13]:
import torch

# Define the path to the .pt file
file_path = r"E:\Aamir Gulzar\dataset\paip_data\UNI_FiveCrop_Features_old\validation_data_01_nonMSIH\validation_data_01_nonMSIH_x0_y1536_3.pt"

try:
    data = torch.load(file_path)
    shape = torch.tensor(data).shape  # Get shape
    print("Shape of the data:", shape)
    print("Values:\n", data)  # Print values
except Exception as e:
    print("Error loading file:", e)


Shape of the data: torch.Size([5, 1024])
Values:
 tensor([[-0.2399, -0.6515, -0.5076,  ...,  0.3224,  1.9748,  0.5773],
        [-1.0118, -0.7169, -1.3044,  ...,  0.7268,  1.6948, -0.1991],
        [-0.8957, -1.6086, -0.7708,  ...,  0.4210,  1.4487,  0.2510],
        [-0.9476,  0.0703, -1.5778,  ...,  0.1153,  0.1450, -1.3170],
        [-0.4141, -1.7474, -1.4875,  ...,  0.9523,  1.3589,  0.9226]])


  shape = torch.tensor(data).shape  # Get shape


In [None]:
@torch.no_grad()
def extract_embeddings_patch_by_patch_old(model, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, patch by patch, without averaging the five crops.
    Args:
    - model: The model used to extract embeddings.
    - preprocess: Preprocessing function to apply to the images.
    - dataloader: Dataloader providing WSI patches and labels.
    - save_dir: Directory where the extracted embeddings will be saved.
    Returns:
    - None: The function saves the extracted embeddings to disk.
    """
    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    for batch_idx, (images, labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
        wsi_name = dataloader.dataset.slides[batch_idx]
        # make a new directory for each WSI
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        # Check if the WSI directory already exists, skip if processed
        if os.path.exists(save_dir_wsi):
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue
        os.makedirs(save_dir_wsi, exist_ok=True)
        batch_indices = dataloader.batch_sampler.indices[batch_idx]        
        for patch_idx, dataset_idx in enumerate(batch_indices):  # Loop through indices of patches for current WSI
            patch_name = os.path.splitext(os.path.basename(dataloader.dataset.tiles[dataset_idx]))[0]
            image = images[patch_idx]
            label = labels[patch_idx]
            # Reshape image to combine batch and fivecrop dimensions
            num_crops, channels, height, width = image.shape
            image = image.view(num_crops, channels, height, width).to(device)
            # print shape of the patch after reshaping
            with torch.inference_mode():
                embeddings = model(image).detach().cpu()  # Extract features for the image
            # Save embeddings to disk
            save_path = os.path.join(save_dir_wsi, f'{patch_name}.pt')
            torch.save(embeddings, save_path)
            # also save in simple text format in text file
            # save_path_txt = os.path.join(save_dir_wsi, f'{wsi_name}_{i}.txt')
            # np.savetxt(save_path_txt, embeddings.cpu().numpy())
            # print the saved path
            # print(f"Embeddings saved to {save_path_txt}")