In [15]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
import torchvision.models as models
import torch.nn as nn
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Load the model "create_model_from_pretrained"
By default, the model preprocessor uses 448 x 448 as the input size. To specify a different image size (e.g. 336 x 336), use the **force_img_size** argument.

You can specify a cuda device by using the **device** argument, or manually move the model to a device later using **model.to(device)**.

In [16]:
# configs
BATCH_SIZE = 1 # load each slide all tiles sequentially 
K_FOLDS_PATH = r"E:\Aamir Gulzar\dataset\splits\kfolds_IDARS_mini.csv"
DATA_PATH = r"E:\\Aamir Gulzar\\dataset\\Patches"
FEATURES_SAVE_DIR = r"E:/Aamir Gulzar/dataset/Baseline_FiveCrop_Features"
# saved model path
Model_PATH = r"E:\Aamir Gulzar\existing_approaches\Baseline_Fivecrop_4Folds\MSI_vs_MSS\fold1\best0\checkpoint_best_AUC.pth"

In [17]:
model = models.resnet34()
model.fc = nn.Identity()
checkpoint = torch.load(Model_PATH)
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()
model = model.to(device)
def preprocess_resnet(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Ensure consistent size
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  # ImageNet normalization
    ])
    return transform(image)

### DataLoaders

In [18]:
class PatchLoader(Dataset):
    def __init__(self, label_file, data_path, transform=None, num_samples=None, img_resize=False, mode=2):     
        lib = pd.DataFrame(pd.read_csv(label_file, usecols=['WSI_Id', 'label_id'], keep_default_na=True))
        lib.dropna(inplace=True)
        if num_samples is not None:
            lib = lib.sample(n=num_samples)
        tar = lib['label_id'].values.tolist()
        allslides = lib['WSI_Id'].values.tolist()       
        slides = []
        tiles = []
        ntiles = []
        slideIDX = []
        targets = []
        j = 0
        for i, path in enumerate(allslides):
            t = []
            cpath = os.path.join(data_path, str(path))
            if not os.path.exists(cpath):
                # print('This slide does not exist: {}'.format(path))
                continue
            else:
                # print('This slide exists: {}'.format(path))
                # count = 0
                for f in os.listdir(cpath): 
                    if '.png' in f:
                        # count = count + 1
                        t.append(os.path.join(cpath, f))
                if len(t) > 0:
                    slides.append(path)
                    tiles.extend(t)
                    ntiles.append(len(t))
                    slideIDX.extend([j]*len(t))
                    targets.append(int(tar[i]))
                    j+=1
        print('Number of Slides: {}'.format(len(slides)))
        print('Number of tiles: {}'.format(len(tiles)))
        self.slides = slides
        self.slideIDX = slideIDX
        self.ntiles = ntiles
        self.tiles = tiles
        self.targets = targets
        self.transform = transform
        self.img_resize = img_resize
        self.mode = mode

    def maketraindata(self, idxs):
        self.t_data = [(self.slideIDX[x], self.tiles[x], self.targets[self.slideIDX[x]]) for x in idxs]

    def shuffletraindata(self):
        self.t_data = random.sample(self.t_data, len(self.t_data))

    def __getitem__(self, index):
        if self.mode == 1:# loads all tiles from each slide sequentially for train/validatoin set
            tile = self.tiles[index]
            img = Image.open(str(tile)).convert('RGB')
            slideIDX = self.slideIDX[index]
            target = self.targets[slideIDX]
            if self.img_resize== True:  
                img = img.resize((224, 224), Image.BILINEAR)
            if self.transform is not None:
                img = self.transform(img)
            return img, target
        elif self.mode == 2:  # used when a different trainset is prepared e.g. with given tile index    
            slideIDX, tile, target = self.t_data[index]
            img = Image.open(str(tile)).convert('RGB')
            if self.img_resize == True:
                img = img.resize((224, 224), Image.BILINEAR)
            if self.transform is not None:
                img = self.transform(img)
            return img, target

    def __len__(self):
        if self.mode == 1:
            length = len(self.tiles)
        elif self.mode == 2:
            length = len(self.t_data)
        else:
            length = 0
        # print(f"__len__ called, mode: {self.mode}, length: {length}")
        return length
    

In [19]:
from torch.utils.data import Sampler, DataLoader
mode = 1 # for sequentially data/patches loading we will use mode =1 and mode= 2 for random loading.
class SlideBatchSampler(Sampler):
    def __init__(self, ntiles):
        # ntiles contains the number of tiles per slide
        self.ntiles = ntiles
        self.indices = []
        start_idx = 0
        for num_tiles in ntiles:
            self.indices.append(list(range(start_idx, start_idx + num_tiles)))
            start_idx += num_tiles
    def __iter__(self):
        # Yield each set of indices for a single slide (batch contains all tiles for that slide)
        for batch in self.indices:
            yield batch
    def __len__(self):
        return len(self.indices)

transform = transforms.Compose([
    transforms.FiveCrop(224),  # this is a list of 5 crops
    Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops]))  # convert to tensor and stack
])

# Custom collate function to handle the batch of crops
def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)
    batch_size, num_crops, c, h, w = images.size()
    images = images.view(-1, c, h, w)  # flatten the crops into individual images
    labels = torch.tensor(labels).repeat_interleave(num_crops)  # repeat labels for each crop
    return images, labels

def create_dataloader(label_file, data_path, transform, num_samples,shuffle, img_resize, mode):
    dataset = PatchLoader(label_file=label_file, data_path=data_path, transform=transform, num_samples=num_samples, img_resize=img_resize, mode=mode)
    num_tiles = len(dataset.slideIDX)
    dataset.maketraindata(np.arange(num_tiles))
    if shuffle:
        dataset.shuffletraindata()
        # Use SlideBatchSampler instead of a fixed batch size
    batch_sampler = SlideBatchSampler(dataset.ntiles)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0, pin_memory=False)
    return dataloader

# Create DataLoaders
train_dataloader = create_dataloader(label_file="E:\Aamir Gulzar\dataset\splits\kfolds_IDARS_mini.csv",data_path="E:\\Aamir Gulzar\\dataset\\Patches",
                                     transform=transform,num_samples=7,shuffle=True,img_resize=False,mode=mode)
print(f"Length of train_loader: {len(train_dataloader)}")

Number of Slides: 7
Number of tiles: 5324
Length of train_loader: 7


  train_dataloader = create_dataloader(label_file="E:\Aamir Gulzar\dataset\splits\kfolds_IDARS_mini.csv",data_path="E:\\Aamir Gulzar\\dataset\\Patches",


### DataLoaders Testing

In [20]:
for batch_idx, (images, labels) in enumerate(train_dataloader):
    print(f"Batch {batch_idx+1}")
    print(f"Images shape: {images.shape}")
    print(f"Labels: {labels}")
    if batch_idx == 2:  # Only print a few batches to check if it's working
        break

Batch 1
Images shape: torch.Size([932, 5, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

### ROI Feature Extraction on FiveCrop Patches Using Averaging Approach at WSI Level
Extract patches of individual patch five crops then average and at the end average features of all patches of each WSI 

In [21]:
@torch.no_grad()
def extract_embeddings_with_chunks(model, preprocess, dataloader, chunk_size=1000):
    all_embeddings = []
    all_labels = []
    device = next(model.parameters()).device
    # print(f'The size of input dataloader is {len(dataloader)}')

    for batch_idx, (images, labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch_size, num_crops, channels, height, width = images.shape
        images = images.view(batch_size * num_crops, channels, height, width)
        images = images.permute(0, 2, 3, 1).cpu().numpy()  # Convert to (H, W, C) format
        images = (images * 255).astype(np.uint8)  # Convert to uint8

        # Preprocess images
        images = torch.stack([preprocess(Image.fromarray(image)) for image in images])
        images = images.to(device)

        wsi_embeddings = []
        num_patches = images.size(0)
        num_chunks = (num_patches + chunk_size - 1) // chunk_size

        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min(start_idx + chunk_size, num_patches)
            chunk = images[start_idx:end_idx]
            with torch.inference_mode():
                # Extract embeddings for the chunk
                chunk_embeddings = model(chunk).detach().cpu()
                chunk_embeddings = chunk_embeddings.view(-1, num_crops, chunk_embeddings.size(-1))  # Reshape to [chunk_size, num_crops, embedding_dim]
                chunk_embeddings = chunk_embeddings.mean(dim=1)  # Average over five-crop features
                wsi_embeddings.append(chunk_embeddings)
        # Concatenate embeddings for all chunks in this WSI
        wsi_embeddings = torch.cat(wsi_embeddings, dim=0)  # [num_patches, embedding_dim]
        # print(f'WSI embeddings shape: {wsi_embeddings.shape}')
        slide_embedding = wsi_embeddings.mean(dim=0)  # Average over all patches
        wsi_label = labels[0].item()  # Assuming all labels in the batch are the same
        all_embeddings.append(slide_embedding.numpy())
        all_labels.append(wsi_label)
        # print (f'Batch {batch_idx} done')
        # print(f'shape of all embeddings: {len(all_embeddings)}')

    # Stack the embeddings and labels
    asset_dict = {
        "embeddings": np.stack(all_embeddings).astype(np.float32),  # [num_slides, embedding_dim]
        "labels": np.array(all_labels),
    }

    return asset_dict

In [22]:
start = time.time()
train_features = extract_embeddings_with_chunks(model, preprocess_resnet, train_dataloader)
# convert these to torch
train_feats = torch.Tensor(train_features['embeddings'])
train_labels = torch.Tensor(train_features['labels']).type(torch.long)

elapsed = time.time() - start
print(f'Took {elapsed:.03f} seconds')
print(f'Train features shape {train_feats.shape} and Labels shape is {train_labels.shape}')

100%|██████████| 7/7 [02:12<00:00, 18.92s/it]


Took 132.856 seconds
Train features shape torch.Size([7, 512]) and Labels shape is torch.Size([7])


### Save Features

In [23]:
# now i wwanted to save these extracted features and nam each file with the slide name that can be get from train_dataloader.dataset.slides against each feature in train_feats
# i will use the same for val_feats and test_feats
# save the extracted features

# Define directories for train, validation, and test splits
train_dir = r"E:\Aamir Gulzar\dataset\\baseline_new_avg_features"

def save_features(features, dataloader, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    slides = dataloader.dataset.slides
    for i, slide in enumerate(slides):
        slide_feats = features[i]
        slide_name = slide.split('.')[0]
        # save as torch .pt file 
        save_path = os.path.join(save_dir, f'{slide_name}.pt')
        torch.save(slide_feats, save_path)
        # print(f'Saved features for slide {slide_name} to {save_path}')

save_features(train_feats, train_dataloader, train_dir)