In [1]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
import torchvision.models as models
import torch.nn as nn
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Global Config Settings

In [7]:
# configs
BATCH_SIZE = 1 # load each slide all tiles sequentially 
K_FOLDS_PATH = r'E:\\Aamir Gulzar\\dataset\\paip_data\\labels\\validation_data_MSI.csv'
DATA_PATH = r"E:\\Aamir Gulzar\\dataset\\paip_data\\Patches"
FEATURES_SAVE_DIR = r"E:/Aamir Gulzar/dataset/paip_data/Baseline_FiveCrop_Features"
# saved model path
Model_PATH = r"E:\Aamir Gulzar\existing_approaches\Baseline_Fivecrop_4Folds\MSI_vs_MSS\fold1\best0\checkpoint_best_AUC.pth"

### Model

In [8]:
model = models.resnet34()
model.fc = nn.Identity()
checkpoint = torch.load(Model_PATH)
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()
model = model.to(device)
def preprocess_resnet(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Ensure consistent size
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  # ImageNet normalization
    ])
    return transform(image)

### Data Loader

In [9]:
from dataloader import PatchLoader, SlideBatchSampler
mode = 1 # for sequentially data/patches loading we will use mode =1 and mode= 2 for random loading.

transform = transforms.Compose([
    transforms.FiveCrop(224),  # this is a list of 5 crops
    Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops]))  # convert to tensor and stack
])
def create_dataloader(label_file, data_path, transform, num_samples, mode):
    # Create the dataset
    dataset = PatchLoader(label_file=label_file, data_path=data_path, transform=transform, num_samples=num_samples, mode=mode)
    # Ensure sequential data loading by disabling shuffle
    batch_sampler = SlideBatchSampler(dataset.ntiles)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0, pin_memory=False)
    return dataloader

# Create DataLoaders
data_loader = create_dataloader(label_file=K_FOLDS_PATH,data_path=DATA_PATH,
                                     transform=transform,num_samples=None,mode=mode)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 31
Number of tiles: 29119
Length of data_loader: 31


In [10]:
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx+1}")
#     print(f"Images shape: {images.shape}")
#     print(f"Labels: {labels}")
#     if batch_idx == 5:  # Only print a few batches to check if it's working
#         break

### FiveCrop Feature Extraction and Save

In [11]:
@torch.no_grad()
def extract_embeddings_patch_by_patch(model, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, patch by patch, without averaging the five crops.
    Args:
    - model: The model used to extract embeddings.
    - dataloader: Dataloader providing WSI patches and labels.
    - save_dir: Directory where the extracted embeddings will be saved.
    Returns:
    - None: The function saves the extracted embeddings to disk.
    """
    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    for batch_idx, (images, labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
        wsi_name = dataloader.dataset.slides[batch_idx]
        print(f"Processing WSI {batch_idx+1} {wsi_name}")
        # make a new directory for each WSI
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        # Check if the WSI directory already exists, skip if processed
        if os.path.exists(save_dir_wsi):
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue
        os.makedirs(save_dir_wsi, exist_ok=True)
        batch_indices = dataloader.batch_sampler.indices[batch_idx]        
        for patch_idx, dataset_idx in enumerate(batch_indices):  # Loop through indices of patches for current WSI
            patch_name = os.path.splitext(os.path.basename(dataloader.dataset.tiles[dataset_idx]))[0]
            image = images[patch_idx]
            label = labels[patch_idx]
            image = image.permute(0, 2, 3, 1).cpu().numpy()  # Convert to (num_crops, H, W, C)
            image = (image * 255).astype(np.uint8)  # Convert back to uint8 format
            # Apply Normalization on Each Crop (Matching Virchow2 & Conch)
            image = torch.stack([preprocess_resnet(Image.fromarray(im)) for im in image])
            # Move image to the same device as the model
            image = image.to(device)
            # Extract embeddings
            with torch.inference_mode():
                # Extract embeddings for the chunk
                features = model(image)
            # Save embeddings to disk
            save_path = os.path.join(save_dir_wsi, f'{patch_name}.pt')
            # print(f"Saving features for patch {i + 1} to {save_path}")
            torch.save(features, save_path)
            # also save in simple text format in text file
            # save_path_txt = os.path.join(save_dir_wsi, f'{wsi_name}_{i}.txt')
            # np.savetxt(save_path_txt, features.cpu().numpy() , delimiter=",")

In [12]:
extract_embeddings_patch_by_patch(model, data_loader, FEATURES_SAVE_DIR)

The size of input dataloader is 31


  0%|          | 0/31 [00:00<?, ?it/s]

Processing WSI 1 validation_data_01_nonMSIH


  3%|▎         | 1/31 [00:23<11:30, 23.01s/it]

Processing WSI 2 validation_data_02_MSIH


  6%|▋         | 2/31 [01:11<18:16, 37.80s/it]

Processing WSI 3 validation_data_03_nonMSIH


 10%|▉         | 3/31 [01:29<13:24, 28.73s/it]

Processing WSI 4 validation_data_04_MSIH


 13%|█▎        | 4/31 [02:05<14:22, 31.93s/it]

Processing WSI 5 validation_data_05_nonMSIH


 16%|█▌        | 5/31 [02:56<16:45, 38.65s/it]

Processing WSI 6 validation_data_06_nonMSIH


 19%|█▉        | 6/31 [03:12<12:53, 30.94s/it]

Processing WSI 7 validation_data_07_nonMSIH


 23%|██▎       | 7/31 [03:29<10:30, 26.25s/it]

Processing WSI 8 validation_data_08_nonMSIH


 26%|██▌       | 8/31 [04:03<11:03, 28.87s/it]

Processing WSI 9 validation_data_09_nonMSIH


 29%|██▉       | 9/31 [04:38<11:18, 30.86s/it]

Processing WSI 10 validation_data_10_nonMSIH


 32%|███▏      | 10/31 [04:58<09:34, 27.34s/it]

Processing WSI 11 validation_data_11_MSIH


 35%|███▌      | 11/31 [05:28<09:27, 28.36s/it]

Processing WSI 12 validation_data_12_nonMSIH


 39%|███▊      | 12/31 [05:34<06:46, 21.39s/it]

Processing WSI 13 validation_data_13_nonMSIH


 42%|████▏     | 13/31 [06:10<07:47, 25.99s/it]

Processing WSI 14 validation_data_14_nonMSIH


 45%|████▌     | 14/31 [06:18<05:46, 20.38s/it]

Processing WSI 15 validation_data_15_MSIH


 48%|████▊     | 15/31 [06:32<04:57, 18.56s/it]

Processing WSI 16 validation_data_16_nonMSIH


 52%|█████▏    | 16/31 [06:36<03:32, 14.14s/it]

Processing WSI 17 validation_data_17_nonMSIH


 55%|█████▍    | 17/31 [07:07<04:27, 19.09s/it]

Processing WSI 18 validation_data_18_nonMSIH


 58%|█████▊    | 18/31 [07:14<03:22, 15.55s/it]

Processing WSI 19 validation_data_19_nonMSIH


 61%|██████▏   | 19/31 [07:26<02:53, 14.44s/it]

Processing WSI 20 validation_data_20_nonMSIH


 65%|██████▍   | 20/31 [07:39<02:34, 14.07s/it]

Processing WSI 21 validation_data_21_nonMSIH


 68%|██████▊   | 21/31 [07:50<02:12, 13.24s/it]

Processing WSI 22 validation_data_22_nonMSIH


 71%|███████   | 22/31 [07:57<01:41, 11.29s/it]

Processing WSI 23 validation_data_23_nonMSIH


 74%|███████▍  | 23/31 [08:06<01:24, 10.57s/it]

Processing WSI 24 validation_data_24_nonMSIH


 77%|███████▋  | 24/31 [08:39<02:00, 17.26s/it]

Processing WSI 25 validation_data_25_nonMSIH


 81%|████████  | 25/31 [09:00<01:50, 18.43s/it]

Processing WSI 26 validation_data_26_nonMSIH


 84%|████████▍ | 26/31 [09:18<01:31, 18.20s/it]

Processing WSI 27 validation_data_27_MSIH


 87%|████████▋ | 27/31 [09:48<01:27, 21.95s/it]

Processing WSI 28 validation_data_28_MSIH


 90%|█████████ | 28/31 [10:04<01:00, 20.15s/it]

Processing WSI 29 validation_data_29_MSIH


 94%|█████████▎| 29/31 [10:09<00:31, 15.62s/it]

Processing WSI 30 validation_data_30_nonMSIH


 97%|█████████▋| 30/31 [10:35<00:18, 18.60s/it]

Processing WSI 31 validation_data_31_nonMSIH


100%|██████████| 31/31 [10:50<00:00, 20.98s/it]


## Reload saved Features (same like averaging appraoch)

In [9]:
import os
import torch

def load_features_and_labels(save_dir):
    wsi_feature_list = []
    label_list = []

    # Loop through each WSI folder
    for wsi_folder in os.listdir(save_dir):
        wsi_folder_path = os.path.join(save_dir, wsi_folder)
        print(f'wsi_folder_path{wsi_folder_path}')

        if os.path.isdir(wsi_folder_path):  # Ensure it is a directory
            patch_features_list = []
            # Loop through each patch file (.pt) inside the WSI folder
            for file_name in os.listdir(wsi_folder_path):
                if file_name.endswith('.pt'):
                    file_path = os.path.join(wsi_folder_path, file_name)
                    print(f'patch file {file_path}')
                    # Load the five-crop features for the patch
                    patch_features = torch.load(file_path)  # Loaded as [[],[],[],[],[]]
                    # patch_features = torch.stack(patch_features)  # Convert list of tensors to tensor [5, feature_dim]
                    print(f'loaded patch feature vector shape{patch_features.shape}')
                    print(f'loaded patch feature vector shape{patch_features}')
                    # Average the five-crop features to get one feature for the patch
                    avg_patch_feature = patch_features.mean(dim=0)  # [feature_dim]
                    print(f'loaded patch feature shape after mean {avg_patch_feature.shape}')
                    patch_features_list.append(avg_patch_feature)
            # After all patches are processed, average them to get one feature for the WSI
            wsi_features = torch.stack(patch_features_list).mean(dim=0)  # [feature_dim]
            # Append WSI feature to the list
            wsi_feature_list.append(wsi_features)
            # Determine label based on WSI folder name
            if '_nonMSI' in wsi_folder:
                label_list.append(0)
            elif '_MSI' in wsi_folder:
                label_list.append(1)
    # Stack all WSI features and labels
    features = torch.stack(wsi_feature_list)
    labels = torch.tensor(label_list)
    return features, labels

# Define directories for train, validation, and test splits
save_dir = FEATURES_SAVE_DIR
# Load features and labels for each split
saved_features, saved_labels = load_features_and_labels(save_dir)

wsi_folder_pathE:\\KSA Project\\dataset\\testing\\CAIMAN_features\TCGA-3L-AA1B_nonMSIH
patch file E:\\KSA Project\\dataset\\testing\\CAIMAN_features\TCGA-3L-AA1B_nonMSIH\TCGA-3L-AA1B_nonMSIH_0.pt
loaded patch feature vector shapetorch.Size([5, 512])
loaded patch feature vector shapetensor([[10.3492,  7.8456, 15.9951,  ...,  3.5675,  0.1409, 22.2074],
        [10.4047,  8.2960, 17.0138,  ...,  3.4913,  0.1405, 23.3252],
        [11.1867,  8.2246, 15.1310,  ...,  3.7105,  0.1650, 22.6287],
        [11.5462,  8.3224, 15.3200,  ...,  3.7407,  0.1912, 23.1283],
        [10.4556,  8.1382, 17.7604,  ...,  3.3586,  0.1968, 24.5821]],
       device='cuda:0')
loaded patch feature shape after mean torch.Size([512])
patch file E:\\KSA Project\\dataset\\testing\\CAIMAN_features\TCGA-3L-AA1B_nonMSIH\TCGA-3L-AA1B_nonMSIH_1.pt
loaded patch feature vector shapetorch.Size([5, 512])
loaded patch feature vector shapetensor([[10.7634,  8.5093, 17.8861,  ...,  3.6829,  0.1401, 24.4079],
        [11.1720,  8

In [None]:
loaded patch feature vector shapetorch.Size([5, 512])
loaded patch feature vector shapetensor([[10.3492,  7.8456, 15.9951,  ...,  3.5675,  0.1409, 22.2074],
        [10.4047,  8.2960, 17.0138,  ...,  3.4913,  0.1405, 23.3252],
        [11.1867,  8.2246, 15.1310,  ...,  3.7105,  0.1650, 22.6287],
        [11.5462,  8.3224, 15.3200,  ...,  3.7407,  0.1912, 23.1283],
        [10.4556,  8.1382, 17.7604,  ...,  3.3586,  0.1968, 24.5821]],
       device='cuda:0')
loaded patch feature shape after mean torch.Size([512])
patch file E:\\KSA Project\\dataset\\testing\\CAIMAN_features\TCGA-3L-AA1B_nonMSIH\TCGA-3L-AA1B_nonMSIH_1.pt
loaded patch feature vector shapetorch.Size([5, 512])
loaded patch feature vector shapetensor([[10.7634,  8.5093, 17.8861,  ...,  3.6829,  0.1401, 24.4079],
        [11.1720,  8.6508, 18.1954,  ...,  3.7514,  0.1358, 24.8268],
        [11.7423,  8.5194, 15.5370,  ...,  3.7580,  0.1996, 23.3296],
        [11.7490,  8.6051, 15.7384,  ...,  3.8491,  0.1444, 23.7329],
        [10.2253,  8.1069, 17.6576,  ...,  3.2173,  0.1775, 24.2737]],
       device='cuda:0')
loaded patch feature shape after mean torch.Size([512])
patch file E:\\KSA Project\\dataset\\testing\\CAIMAN_features\TCGA-3L-AA1B_nonMSIH\TCGA-3L-AA1B_nonMSIH_2.pt
loaded patch feature vector shapetorch.Size([5, 512])
loaded patch feature vector shapetensor([[12.3652,  9.4916, 19.3000,  ...,  4.1201,  0.1584, 27.0940],
        [12.3414,  9.4774, 19.2851,  ...,  4.1033,  0.1580, 27.0423],
        [12.3530,  9.4830, 19.2985,  ...,  4.1212,  0.1567, 27.0924],
        [12.4026,  9.4249, 19.2395,  ...,  4.0286,  0.1538, 27.0516],

In [10]:
print(f'Train features shape {saved_features.shape} and Labels shape is {saved_labels.shape}')
# Print unique labels in each set to verify correctness
print(f'Unique labels in set: {torch.unique(saved_labels)}')

Train features shape torch.Size([5, 512]) and Labels shape is torch.Size([5])
Unique labels in set: tensor([0, 1])
