In [1]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
import torchvision.models as models
import torch.nn as nn
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Global Config Settings

In [2]:
# configs
# NOTE(review): BATCH_SIZE is not referenced by any other cell shown in this notebook; the
# DataLoader below is built with a batch_sampler rather than a batch size.
BATCH_SIZE = 1 # load all tiles of each slide sequentially
# NOTE(review): the two paths below are raw strings, so every `\\` stays as two literal
# backslash characters in the resulting path. Presumably this relies on the OS tolerating
# repeated separators - confirm before reusing these values elsewhere.
# Label CSV handed to the dataset as `label_file`.
K_FOLDS_PATH = r'E:\\Aamir Gulzar\\dataset\\paip_data\\labels\\traning_data_MSI.csv'
# Patch root handed to the dataset as `data_path`.
DATA_PATH = r"E:\\Aamir Gulzar\\dataset\\paip_data\\Patches"
# Root directory under which one sub-directory of feature files is written per slide.
FEATURES_SAVE_DIR = r"E:/Aamir Gulzar/dataset/paip_data/IDARS_FiveCrop_Features"
# saved model path (checkpoint file read with torch.load in the Model cell)
Model_PATH = r"E:\Aamir Gulzar\existing_approaches\IDaRS_Fivecrop_4Folds\MSI_vs_MSS_T1R10\fold1\best0\checkpoint_best_AUC.pth"

### Model

In [3]:
# Feature extractor: a ResNet-34 whose final `fc` layer is replaced by an identity, so
# model(x) returns the input of the former classification layer instead of class scores.
model = models.resnet34()
model.fc = nn.Identity()
# map_location lets a checkpoint that was saved from a GPU load even when `device`
# fell back to the CPU.
checkpoint = torch.load(Model_PATH, map_location=device)
# strict=False tolerates keys that do not line up (presumably the checkpoint's own `fc.*`
# head, which has no counterpart after the replacement above). It would equally hide a
# checkpoint that matches nothing and leave the backbone randomly initialised, so the
# mismatch is reported instead of being discarded.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model tensors were not found in the checkpoint "
          f"and keep their initial values, e.g. {load_result.missing_keys[:5]}")
unexpected_backbone_keys = [key for key in load_result.unexpected_keys if not key.startswith('fc.')]
if unexpected_backbone_keys:
    print(f"WARNING: {len(unexpected_backbone_keys)} checkpoint tensors outside 'fc.' were ignored, "
          f"e.g. {unexpected_backbone_keys[:5]}")
model.eval()
model = model.to(device)
# Built once and reused: the pipeline is the same for every call, and preprocess_resnet
# is invoked for every crop of every tile during feature extraction.
_RESNET_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),  # Ensure consistent size
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  # ImageNet normalization
])


def preprocess_resnet(image):
    """
    Resize an image to 224x224, convert it to a tensor and normalise it.

    Args:
    - image: The image to preprocess (the extraction loop passes a PIL image).
    Returns:
    - The transformed image, as produced by Resize -> ToTensor -> Normalize with the
      ImageNet mean/std values above.
    """
    return _RESNET_PREPROCESS(image)

### Data Loader

In [4]:
from dataloader import PatchLoader, SlideBatchSampler
# Loading mode forwarded to PatchLoader. Per the original note: 1 loads the patches
# sequentially, 2 loads them randomly.
mode = 1

# One shared PIL -> tensor converter, applied to each of the crops below.
_crop_to_tensor = transforms.ToTensor()
transform = transforms.Compose([
    transforms.FiveCrop(224),  # produces 5 crops per input image
    # turn every crop into a tensor and stack them along a new leading axis
    Lambda(lambda crops: torch.stack([_crop_to_tensor(one_crop) for one_crop in crops])),
])
def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """
    Build a DataLoader over a PatchLoader dataset, batched by SlideBatchSampler.

    Args:
    - label_file, data_path, transform, num_samples, mode: forwarded unchanged to PatchLoader.
    Returns:
    - DataLoader: single-process loader (num_workers=0, no pinned memory) whose batches
      are defined by SlideBatchSampler(dataset.ntiles).
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    # The batch sampler, not a shuffle flag, decides which tiles form each batch.
    return DataLoader(
        patch_dataset,
        batch_sampler=SlideBatchSampler(patch_dataset.ntiles),
        num_workers=0,
        pin_memory=False,
    )

# Build the loader over the labelled patch set and report how many batches it yields.
data_loader = create_dataloader(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 47
Number of tiles: 40630
Length of data_loader: 47


In [5]:
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx+1}")
#     print(f"Images shape: {images.shape}")
#     if batch_idx == 1:  # Only print a few batches to check if it's working
#         break

### FiveCrop Feature Extraction and Save

In [6]:
@torch.no_grad()
def extract_embeddings_patch_by_patch(model, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, patch by patch, without averaging the five crops.

    Every batch produced by ``dataloader`` is treated as one slide. For each patch in the
    batch, the crops are converted back to uint8 images, normalised with
    ``preprocess_resnet``, passed through ``model``, and the output tensor is written to
    ``<save_dir>/<wsi_name>/<patch_name>.pt``.

    Args:
    - model: The model used to extract embeddings.
    - dataloader: DataLoader built with a batch sampler; its dataset must expose ``slides``
      (indexed by batch number) and ``tiles`` (indexed by dataset index).
    - save_dir: Directory where the extracted embeddings will be saved.
    Returns:
    - None: The function saves the extracted embeddings to disk.
    Raises:
    - ValueError: if ``dataloader`` has no batch sampler to read tile indices from.
    - RuntimeError: if a batch does not contain one image per sampled tile index.
    """
    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    dataset = dataloader.dataset
    batch_sampler = dataloader.batch_sampler
    if batch_sampler is None:
        raise ValueError("dataloader must be built with a batch sampler so tiles can be matched to batches")

    # The position of a patch inside a batch is not its index in dataset.tiles, so the
    # batch sampler is walked in step with the loader to recover the dataset index of each
    # patch. NOTE(review): this assumes the sampler yields the same batches every time it
    # is iterated (no shuffling) - confirm against SlideBatchSampler.
    slide_batches = zip(batch_sampler, dataloader)
    for batch_idx, (tile_indices, (images, _labels)) in tqdm(enumerate(slide_batches), total=len(dataloader)):
        wsi_name = dataset.slides[batch_idx]
        print(f"Processing WSI {batch_idx+1} {wsi_name}")
        # make a new directory for each WSI
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        # An existing directory is treated as a finished slide. NOTE: a slide that was
        # interrupted part-way through is therefore not resumed; delete its directory to
        # have it processed again.
        if os.path.exists(save_dir_wsi):
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue

        tile_indices = list(tile_indices)
        if len(tile_indices) != len(images):
            raise RuntimeError(
                f"Batch {batch_idx} holds {len(images)} images but the sampler yielded "
                f"{len(tile_indices)} tile indices"
            )

        os.makedirs(save_dir_wsi, exist_ok=True)
        for tile_idx, patch_crops in zip(tile_indices, images):
            patch_name = os.path.splitext(os.path.basename(dataset.tiles[tile_idx]))[0]
            # (num_crops, C, H, W) floats -> (num_crops, H, W, C) uint8 so that each crop
            # can be wrapped in a PIL image for preprocess_resnet
            crops_uint8 = patch_crops.permute(0, 2, 3, 1).cpu().numpy()
            crops_uint8 = (crops_uint8 * 255).astype(np.uint8)
            # Normalise every crop, then move the stacked crops to the model's device
            model_input = torch.stack([preprocess_resnet(Image.fromarray(crop)) for crop in crops_uint8])
            model_input = model_input.to(device)
            # Gradient tracking is already disabled by the decorator on this function
            features = model(model_input)
            # Save embeddings to disk, one file per patch
            save_path = os.path.join(save_dir_wsi, f'{patch_name}.pt')
            torch.save(features, save_path)

In [7]:
# Run the extraction over every batch of data_loader; features are written under FEATURES_SAVE_DIR.
extract_embeddings_patch_by_patch(model, data_loader, FEATURES_SAVE_DIR)

The size of input dataloader is 47


  0%|          | 0/47 [00:00<?, ?it/s]

Processing WSI 1 training_data_01_MSIH


  2%|▏         | 1/47 [00:23<17:54, 23.35s/it]

Processing WSI 2 training_data_02_nonMSIH


  4%|▍         | 2/47 [01:03<25:04, 33.42s/it]

Processing WSI 3 training_data_03_nonMSIH


  6%|▋         | 3/47 [01:26<21:04, 28.73s/it]

Processing WSI 4 training_data_04_nonMSIH


  9%|▊         | 4/47 [01:32<13:55, 19.43s/it]

Processing WSI 5 training_data_05_MSIH


 11%|█         | 5/47 [02:10<18:23, 26.27s/it]

Processing WSI 6 training_data_06_MSIH


 13%|█▎        | 6/47 [02:25<15:22, 22.50s/it]

Processing WSI 7 training_data_07_nonMSIH


 15%|█▍        | 7/47 [02:35<12:10, 18.26s/it]

Processing WSI 8 training_data_08_nonMSIH


 17%|█▋        | 8/47 [02:45<10:08, 15.61s/it]

Processing WSI 9 training_data_09_nonMSIH


 19%|█▉        | 9/47 [03:12<12:11, 19.25s/it]

Processing WSI 10 training_data_10_nonMSIH


 21%|██▏       | 10/47 [03:58<17:03, 27.67s/it]

Processing WSI 11 training_data_11_nonMSIH


 23%|██▎       | 11/47 [04:11<13:44, 22.90s/it]

Processing WSI 12 training_data_12_MSIH


 26%|██▌       | 12/47 [04:53<16:54, 29.00s/it]

Processing WSI 13 training_data_13_nonMSIH


 28%|██▊       | 13/47 [05:26<17:04, 30.13s/it]

Processing WSI 14 training_data_14_nonMSIH


 30%|██▉       | 14/47 [05:49<15:24, 28.02s/it]

Processing WSI 15 training_data_15_nonMSIH


 32%|███▏      | 15/47 [06:31<17:08, 32.15s/it]

Processing WSI 16 training_data_16_nonMSIH


 34%|███▍      | 16/47 [06:56<15:27, 29.92s/it]

Processing WSI 17 training_data_17_nonMSIH


 36%|███▌      | 17/47 [07:16<13:29, 26.98s/it]

Processing WSI 18 training_data_18_nonMSIH


 38%|███▊      | 18/47 [07:37<12:13, 25.30s/it]

Processing WSI 19 training_data_19_nonMSIH


 40%|████      | 19/47 [08:11<13:00, 27.89s/it]

Processing WSI 20 training_data_20_MSIH


 43%|████▎     | 20/47 [08:36<12:07, 26.94s/it]

Processing WSI 21 training_data_21_nonMSIH


 45%|████▍     | 21/47 [09:07<12:11, 28.14s/it]

Processing WSI 22 training_data_22_nonMSIH


 47%|████▋     | 22/47 [09:42<12:33, 30.16s/it]

Processing WSI 23 training_data_23_nonMSIH


 49%|████▉     | 23/47 [09:57<10:17, 25.71s/it]

Processing WSI 24 training_data_24_MSIH


 51%|█████     | 24/47 [10:50<13:01, 33.99s/it]

Processing WSI 25 training_data_25_nonMSIH


 53%|█████▎    | 25/47 [11:27<12:46, 34.84s/it]

Processing WSI 26 training_data_26_nonMSIH


 55%|█████▌    | 26/47 [12:43<16:28, 47.07s/it]

Processing WSI 27 training_data_27_nonMSIH


 57%|█████▋    | 27/47 [13:30<15:43, 47.19s/it]

Processing WSI 28 training_data_28_nonMSIH


 60%|█████▉    | 28/47 [14:41<17:11, 54.30s/it]

Processing WSI 29 training_data_29_nonMSIH


 62%|██████▏   | 29/47 [14:52<12:20, 41.15s/it]

Processing WSI 30 training_data_30_MSIH


 64%|██████▍   | 30/47 [15:52<13:17, 46.89s/it]

Processing WSI 31 training_data_31_nonMSIH


 66%|██████▌   | 31/47 [16:33<12:02, 45.18s/it]

Processing WSI 32 training_data_32_MSIH


 68%|██████▊   | 32/47 [16:58<09:46, 39.12s/it]

Processing WSI 33 training_data_33_nonMSIH


 70%|███████   | 33/47 [17:09<07:08, 30.57s/it]

Processing WSI 34 training_data_34_MSIH


 72%|███████▏  | 34/47 [17:18<05:12, 24.05s/it]

Processing WSI 35 training_data_35_nonMSIH


 74%|███████▍  | 35/47 [17:52<05:26, 27.21s/it]

Processing WSI 36 training_data_36_nonMSIH


 77%|███████▋  | 36/47 [18:02<04:00, 21.84s/it]

Processing WSI 37 training_data_37_nonMSIH


 79%|███████▊  | 37/47 [19:17<06:18, 37.83s/it]

Processing WSI 38 training_data_38_nonMSIH


 81%|████████  | 38/47 [19:58<05:49, 38.79s/it]

Processing WSI 39 training_data_39_nonMSIH


 83%|████████▎ | 39/47 [20:26<04:45, 35.67s/it]

Processing WSI 40 training_data_40_nonMSIH


 85%|████████▌ | 40/47 [20:58<04:01, 34.53s/it]

Processing WSI 41 training_data_41_nonMSIH


 87%|████████▋ | 41/47 [21:02<02:32, 25.37s/it]

Processing WSI 42 training_data_42_MSIH


 89%|████████▉ | 42/47 [21:24<02:02, 24.47s/it]

Processing WSI 43 training_data_43_nonMSIH


 91%|█████████▏| 43/47 [21:40<01:27, 21.80s/it]

Processing WSI 44 training_data_44_MSIH


 94%|█████████▎| 44/47 [21:53<00:57, 19.27s/it]

Processing WSI 45 training_data_45_nonMSIH


 96%|█████████▌| 45/47 [21:57<00:28, 14.50s/it]

Processing WSI 46 training_data_46_nonMSIH


 98%|█████████▊| 46/47 [22:09<00:13, 13.73s/it]

Processing WSI 47 training_data_47_MSIH


100%|██████████| 47/47 [22:26<00:00, 28.64s/it]
