In [1]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
import torchvision.models as models
import torch.nn as nn
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Global Config Settings

In [2]:
# configs
# NOTE(review): BATCH_SIZE is not referenced by the cells shown here; batching is done
# one-slide-per-batch by SlideBatchSampler in the Data Loader section.
BATCH_SIZE = 1

# PAIP dataset root. The doubled backslashes inside these raw strings are kept verbatim
# (the recorded run loads all 47 slides with them), so the resulting paths are unchanged.
_PAIP_ROOT = r'E:\\Aamir Gulzar\\dataset\\paip_data'
K_FOLDS_PATH = _PAIP_ROOT + r'\\labels\\traning_data_MSI.csv'  # label CSV handed to PatchLoader
DATA_PATH = _PAIP_ROOT + r'\\Patches'                           # tile root handed to PatchLoader

# Output root: one sub-directory of per-tile .pt feature files per WSI.
FEATURES_SAVE_DIR = r"E:/Aamir Gulzar/dataset/paip_data/CAIMAN_FiveCrop_Features"

# Checkpoint whose 'state_dict' entry is loaded into the CNN feature extractor.
Model_PATH = r"E:\Aamir Gulzar\existing_approaches\CAIMAN_Fivecrop_4Folds\MSI_vs_MSS_T50R50\fold1\best0\checkpoint_best_AUC.pth"

### Model

In [3]:
# Model definition: must mirror the architecture the checkpoint at Model_PATH was trained
# with, so that its state_dict keys (model_resnet.*, classifier.*, conf.*) line up.
class CNN(nn.Module):
    """ResNet-34 backbone with a classification head and a confidence head.

    In this notebook the network is used purely as a feature extractor: ``forward``
    returns the backbone features only. The two heads are still built so that the
    training checkpoint's ``classifier.*`` / ``conf.*`` weights have a place to load into.

    Args:
    - num_classes: output size of the classification head (must match the checkpoint).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model_resnet = models.resnet34(weights='DEFAULT')
        num_ftrs = self.model_resnet.fc.in_features
        # Replace the final fc layer so the backbone returns its feature vector directly.
        self.model_resnet.fc = nn.Identity()
        self.classifier = nn.Linear(num_ftrs, num_classes)
        self.conf = nn.Linear(num_ftrs, 1)

    def forward(self, x):
        """Return backbone features for ``x`` ([N, 3, H, W]) as an [N, num_ftrs] tensor.

        The classifier/confidence heads are deliberately not evaluated here: their
        outputs used to be computed and then discarded, which was pure wasted work.
        """
        return self.model_resnet(x)

# 1. Re-create the training architecture and load the fine-tuned weights into it.
num_classes = 2  # must equal the size of the checkpoint's classifier head
model = CNN(num_classes=num_classes)
# map_location lets a checkpoint saved from a GPU run also load on a CPU-only machine.
checkpoint = torch.load(Model_PATH, map_location=device)
# strict=False tolerates extra keys in the checkpoint, but it also silently ignores
# *missing* ones. If the backbone weights are missing (e.g. keys carry a 'module.' prefix),
# features would come from the untouched pretrained backbone instead -- so fail loudly.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
missing_backbone = [k for k in load_result.missing_keys if k.startswith('model_resnet.')]
if missing_backbone:
    raise RuntimeError(
        f"{len(missing_backbone)} backbone weights were not found in {Model_PATH} "
        f"(e.g. {missing_backbone[0]!r}); check the checkpoint's key names/prefix."
    )
if load_result.unexpected_keys:
    print(f"Ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys, "
          f"e.g. {load_result.unexpected_keys[0]!r}")
model.eval()
model = model.to(device)

def preprocess_resnet(image):
    """Prepare one PIL image for the ResNet backbone.

    The image is resized to 224x224, converted to a float tensor and normalised with the
    ImageNet channel statistics. An RGB input therefore yields a (3, 224, 224) tensor.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    return pipeline(image)

### Data Loader

In [4]:
from dataloader import PatchLoader, SlideBatchSampler
# PatchLoader loading mode: 1 = sequential patch loading, 2 = random loading.
mode = 1

# Cut every patch into five 224x224 crops (four corners + centre) and stack them into a
# single (5, C, 224, 224) float tensor with values in [0, 1].
_to_tensor = transforms.ToTensor()
transform = transforms.Compose([
    transforms.FiveCrop(224),
    Lambda(lambda crops: torch.stack([_to_tensor(crop) for crop in crops])),
])
def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """Build a DataLoader over a PatchLoader dataset, batched slide by slide.

    Batch order and composition are determined entirely by
    ``SlideBatchSampler(dataset.ntiles)``. It is expected to yield one batch per slide;
    the recorded run shows 47 slides -> 47 batches. Loading happens in the main process
    (num_workers=0) without pinned memory.
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    return DataLoader(
        patch_dataset,
        batch_sampler=SlideBatchSampler(patch_dataset.ntiles),
        num_workers=0,
        pin_memory=False,
    )

# Build the slide-wise loader over every labelled slide.
data_loader = create_dataloader(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 47
Number of tiles: 40630
Length of data_loader: 47


In [26]:
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx+1}")
#     print(f"Images shape: {images.shape}")
#     if batch_idx == 1:  # Only print a few batches to check if it's working
#         break

### FiveCrop Feature Extraction and Save

In [5]:
def _crops_to_resnet_input(crops):
    """Turn one tile's (num_crops, C, H, W) float crop stack in [0, 1] into a normalised model batch.

    Each crop is converted back to a uint8 PIL image so that ``preprocess_resnet``
    (resize + ImageNet normalisation) can be applied to it. The results are stacked along dim 0.
    """
    crops_hwc = crops.permute(0, 2, 3, 1).cpu().numpy()  # -> (num_crops, H, W, C)
    # Round instead of truncating: after ToTensor's /255, the float product k/255*255 can
    # land a hair below k, and a bare astype(np.uint8) would then drop that pixel by one level.
    crops_uint8 = np.clip(np.rint(crops_hwc * 255.0), 0, 255).astype(np.uint8)
    return torch.stack([preprocess_resnet(Image.fromarray(crop)) for crop in crops_uint8])


@torch.no_grad()
def extract_embeddings_patch_by_patch(model, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, patch by patch, without averaging the five crops.

    Every batch yielded by ``dataloader.batch_sampler`` is treated as one WSI. Batch ``k``
    is named ``dataloader.dataset.slides[k]``. For each tile, the model output for its
    crop stack is saved to ``<save_dir>/<wsi_name>/<tile file stem>.pt``. Tile names
    come from ``dataloader.dataset.tiles`` at the tile's real dataset index, so every
    slide gets its own tile names.

    Resumability:
    - A slide whose output directory already exists is skipped before any of its tiles
      are loaded.
    - Features are first written to ``<wsi_name>.partial``. That directory is renamed to
      ``<wsi_name>`` only after every tile is saved, so an interrupted run never leaves a
      half-filled directory that a later run would wrongly skip. A leftover ``.partial``
      directory is reused and overwritten on the next run.

    Tiles are fetched directly via ``dataset[idx]`` + ``dataloader.collate_fn``. As a
    result, DataLoader worker/pin_memory settings are not used (the loader in this
    notebook uses num_workers=0, pin_memory=False anyway).

    Args:
    - model: The model used to extract embeddings.
    - dataloader: DataLoader whose ``batch_sampler`` yields one list of dataset indices
      per WSI, and whose dataset exposes ``slides`` (one name per batch) and ``tiles``
      (tile paths indexed like the dataset itself).
    - save_dir: Directory where the extracted embeddings will be saved.
    Returns:
    - None: The function saves the extracted embeddings to disk.
    Raises:
    - RuntimeError: if a collated batch does not contain one image per sampled index.
    """
    device = next(model.parameters()).device
    dataset = dataloader.dataset
    print(f'The size of input dataloader is {len(dataloader)}')

    for batch_idx, tile_indices in tqdm(enumerate(dataloader.batch_sampler), total=len(dataloader)):
        tile_indices = list(tile_indices)
        wsi_name = dataset.slides[batch_idx]
        print(f"Processing WSI {batch_idx+1} {wsi_name}")
        # one output directory per WSI; its presence marks a fully processed slide
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        if os.path.exists(save_dir_wsi):
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue
        partial_dir = save_dir_wsi + '.partial'
        os.makedirs(partial_dir, exist_ok=True)

        # Load this slide's tiles exactly as the DataLoader would have collated them.
        images, _labels = dataloader.collate_fn([dataset[idx] for idx in tile_indices])
        if len(images) != len(tile_indices):
            raise RuntimeError(
                f"WSI {wsi_name}: got {len(images)} images for {len(tile_indices)} sampled tiles"
            )

        for tile_idx, crops in zip(tile_indices, images):
            patch_name = os.path.splitext(os.path.basename(dataset.tiles[tile_idx]))[0]
            model_input = _crops_to_resnet_input(crops).to(device)
            with torch.inference_mode():
                features = model(model_input)
            torch.save(features, os.path.join(partial_dir, f'{patch_name}.pt'))

        # Publish the slide only once all of its tiles have been written.
        os.replace(partial_dir, save_dir_wsi)

In [6]:
# Run the extraction over every slide; already-processed slides are skipped.
extract_embeddings_patch_by_patch(model=model, dataloader=data_loader, save_dir=FEATURES_SAVE_DIR)

The size of input dataloader is 47


  0%|          | 0/47 [00:00<?, ?it/s]

Processing WSI 1 training_data_01_MSIH


  2%|▏         | 1/47 [00:43<33:20, 43.50s/it]

Processing WSI 2 training_data_02_nonMSIH


  4%|▍         | 2/47 [01:46<41:20, 55.13s/it]

Processing WSI 3 training_data_03_nonMSIH


  6%|▋         | 3/47 [02:28<36:05, 49.22s/it]

Processing WSI 4 training_data_04_nonMSIH


  9%|▊         | 4/47 [02:37<23:43, 33.10s/it]

Processing WSI 5 training_data_05_MSIH


 11%|█         | 5/47 [03:54<34:11, 48.83s/it]

Processing WSI 6 training_data_06_MSIH


 13%|█▎        | 6/47 [04:20<28:12, 41.28s/it]

Processing WSI 7 training_data_07_nonMSIH


 15%|█▍        | 7/47 [04:42<23:13, 34.84s/it]

Processing WSI 8 training_data_08_nonMSIH


 17%|█▋        | 8/47 [04:59<19:04, 29.35s/it]

Processing WSI 9 training_data_09_nonMSIH


 19%|█▉        | 9/47 [05:49<22:38, 35.76s/it]

Processing WSI 10 training_data_10_nonMSIH


 21%|██▏       | 10/47 [06:41<25:03, 40.64s/it]

Processing WSI 11 training_data_11_nonMSIH


 23%|██▎       | 11/47 [06:54<19:24, 32.35s/it]

Processing WSI 12 training_data_12_MSIH


 26%|██▌       | 12/47 [07:33<19:54, 34.14s/it]

Processing WSI 13 training_data_13_nonMSIH


 28%|██▊       | 13/47 [08:08<19:36, 34.60s/it]

Processing WSI 14 training_data_14_nonMSIH


 30%|██▉       | 14/47 [08:27<16:22, 29.78s/it]

Processing WSI 15 training_data_15_nonMSIH


 32%|███▏      | 15/47 [09:21<19:43, 36.99s/it]

Processing WSI 16 training_data_16_nonMSIH


 34%|███▍      | 16/47 [09:42<16:37, 32.18s/it]

Processing WSI 17 training_data_17_nonMSIH


 36%|███▌      | 17/47 [10:02<14:19, 28.67s/it]

Processing WSI 18 training_data_18_nonMSIH


 38%|███▊      | 18/47 [10:18<12:02, 24.93s/it]

Processing WSI 19 training_data_19_nonMSIH


 40%|████      | 19/47 [11:00<14:01, 30.05s/it]

Processing WSI 20 training_data_20_MSIH


 43%|████▎     | 20/47 [11:29<13:17, 29.54s/it]

Processing WSI 21 training_data_21_nonMSIH


 45%|████▍     | 21/47 [11:57<12:37, 29.14s/it]

Processing WSI 22 training_data_22_nonMSIH


 47%|████▋     | 22/47 [12:36<13:19, 32.00s/it]

Processing WSI 23 training_data_23_nonMSIH


 49%|████▉     | 23/47 [12:50<10:40, 26.67s/it]

Processing WSI 24 training_data_24_MSIH


 51%|█████     | 24/47 [13:44<13:23, 34.95s/it]

Processing WSI 25 training_data_25_nonMSIH


 53%|█████▎    | 25/47 [14:23<13:13, 36.08s/it]

Processing WSI 26 training_data_26_nonMSIH


 55%|█████▌    | 26/47 [15:41<17:00, 48.60s/it]

Processing WSI 27 training_data_27_nonMSIH


 57%|█████▋    | 27/47 [16:30<16:17, 48.89s/it]

Processing WSI 28 training_data_28_nonMSIH


 57%|█████▋    | 27/47 [17:29<12:57, 38.86s/it]


KeyboardInterrupt: 