In [12]:
import torch
import torchvision
import os
import sys
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (synchronous CUDA
# kernel launches) — confirm it should stay enabled for full extraction runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# configs
# NOTE: these are raw strings, so backslashes must NOT be doubled — a doubled
# backslash inside r"..." is kept literally and yields paths with "\\" separators.
BATCH_SIZE = 1  # each slide's tiles are loaded sequentially as one batch
K_FOLDS_PATH = r"E:\KSA Project\dataset\splits\kfolds_IDARS.csv"
DATA_PATH = r"E:\KSA Project\dataset\testing\Patches"
FEATURES_SAVE_DIR = r"E:\KSA Project\dataset\testing\conch_features"

### Load the model "create_model_from_pretrained"
By default, the model preprocessor uses 448 x 448 as the input size. To specify a different image size (e.g. 336 x 336), use the **force_img_size** argument.

You can specify a cuda device by using the **device** argument, or manually move the model to a device later using **model.to(device)**.

In [3]:
# Model configuration name and local checkpoint file handed to the CONCH loader.
model_cfg = 'conch_ViT-B-16'
checkpoint_path = './checkpoints/pytorch_model.bin'
# Returns the model together with its image preprocessor; `device` selects where the model is placed.
model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, device=device)
# Alternative call: force a 224 x 224 input size and pin a specific GPU.
# model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, force_img_size=224, device='cuda:2')
# Switch to evaluation mode; binding the result to `_` keeps the cell from echoing the module repr.
_ = model.eval()

### DataLoaders

In [None]:
from dataloader import PatchLoader, SlideBatchSampler
# mode = 1 loads the data/patches sequentially; mode = 2 loads them in random order.
mode = 1


def _stack_five_crops(crops):
    """Convert every crop to a tensor and stack them along a new leading dimension."""
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(crop) for crop in crops])


# Each image is split into five 224-pixel crops, which are then stacked into one tensor.
transform = transforms.Compose([
    transforms.FiveCrop(224),
    Lambda(_stack_five_crops),
])

def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """Build a DataLoader over a PatchLoader dataset, batched by a SlideBatchSampler.

    All arguments are forwarded unchanged to ``PatchLoader``. The batch sampler is
    built from the dataset's ``ntiles`` attribute (presumably the tile count of each
    slide, so that one batch corresponds to one slide — confirm in dataloader.py).
    Loading happens in the main process (``num_workers=0``) without pinned memory.
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    slide_sampler = SlideBatchSampler(patch_dataset.ntiles)
    return DataLoader(
        patch_dataset,
        batch_sampler=slide_sampler,
        num_workers=0,
        pin_memory=False,
    )

# Build the loader over the configured split file and patch directory
# (num_samples=None is passed through to the dataset unchanged).
data_loader = create_dataloader(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
print(f"Length of data_loader: {len(data_loader)}")

### DataLoaders Testing

In [7]:
# Smoke test: show the image shape and labels of the first three batches only.
for batch_idx, (images, labels) in enumerate(data_loader):
    print(
        f"Batch {batch_idx+1}",
        f"Images shape: {images.shape}",
        f"Labels: {labels}",
        sep="\n",
    )
    if batch_idx == 2:  # stop after the third batch; this is just a sanity check
        break

Batch 1
Images shape: torch.Size([5, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 0])
Batch 2
Images shape: torch.Size([5, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 0, 0])
Batch 3
Images shape: torch.Size([5, 3, 224, 224])
Labels: tensor([1, 1, 1])


### ROI Feature Extraction on FiveCrop Patches Using Averaging Approach at WSI Level
Extract features for the five crops of each patch and average them into a single patch-level feature; then average the patch-level features of each WSI into one slide-level feature.

In [7]:
@torch.no_grad()
def extract_embeddings_with_chunks(model, preprocess, dataloader, chunk_size=1000):
    """Compute one averaged embedding per slide from multi-crop patch batches.

    Every batch yielded by ``dataloader`` is treated as one slide. The crops of
    each patch are encoded with ``model.encode_image``, averaged into a patch
    embedding, and all patch embeddings of the batch are averaged into a single
    slide embedding.

    Args:
        model: Module exposing ``encode_image``; its first parameter determines
            the device used for inference.
        preprocess: Callable mapping a PIL image to a tensor accepted by
            ``model.encode_image``.
        dataloader: Iterable with ``len()`` yielding ``(images, labels)`` where
            ``images`` has shape ``[num_patches, num_crops, C, H, W]`` (values
            assumed to be floats in [0, 1], as produced by ``ToTensor``) and
            ``labels[0]`` is taken as the slide label.
        chunk_size: Upper bound on the number of crop images per forward pass.
            It is rounded down to a multiple of ``num_crops`` (but never below
            ``num_crops``) so that the crops of one patch always stay together.

    Returns:
        dict with ``"embeddings"`` (float32 array ``[num_slides, embedding_dim]``)
        and ``"labels"`` (array with one label per slide).

    Raises:
        ValueError: if ``chunk_size`` is not positive, a batch is not 5-D, or the
            dataloader yields no batches.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    device = next(model.parameters()).device
    # float16 autocast is only applied on CUDA; on CPU the model runs in full precision.
    use_amp = device.type == "cuda"
    all_embeddings = []
    all_labels = []

    for images, labels in tqdm(dataloader, total=len(dataloader)):
        if images.dim() != 5:
            raise ValueError(
                "expected images of shape [num_patches, num_crops, C, H, W], "
                f"got {tuple(images.shape)}"
            )
        num_patches, num_crops, channels, height, width = images.shape

        # Flatten patches x crops and convert to uint8 HWC arrays for the PIL-based
        # preprocessor. Round (instead of truncating) so that values such as
        # 254.99998 map back to 255 rather than 254.
        crops = images.reshape(num_patches * num_crops, channels, height, width)
        crops = crops.permute(0, 2, 3, 1).mul(255).round().clamp(0, 255)
        crops = crops.to(torch.uint8).contiguous().cpu().numpy()

        # Preprocess on the CPU; only one chunk at a time is moved to the device,
        # so device memory stays bounded by the chunk size.
        crops = torch.stack([preprocess(Image.fromarray(crop)) for crop in crops])

        # Chunk boundaries must fall between patches, otherwise the per-patch
        # reshape below would mix crops of different patches (or fail).
        step = max(chunk_size // num_crops, 1) * num_crops
        patch_embeddings = []
        for start in range(0, crops.size(0), step):
            chunk = crops[start:start + step].to(device)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                chunk_embeddings = model.encode_image(chunk)
            # Average in float32 to avoid half-precision accumulation error.
            chunk_embeddings = chunk_embeddings.float().cpu()
            chunk_embeddings = chunk_embeddings.reshape(-1, num_crops, chunk_embeddings.size(-1))
            patch_embeddings.append(chunk_embeddings.mean(dim=1))  # average over the crops

        # [num_patches, embedding_dim] -> [embedding_dim]
        slide_embedding = torch.cat(patch_embeddings, dim=0).mean(dim=0)
        all_embeddings.append(slide_embedding.numpy())
        all_labels.append(labels[0].item())  # all labels of a batch are assumed identical

    if not all_embeddings:
        raise ValueError("dataloader yielded no batches; nothing to embed")

    return {
        "embeddings": np.stack(all_embeddings).astype(np.float32),  # [num_slides, embedding_dim]
        "labels": np.array(all_labels),
    }

In [None]:
# Time only the feature extraction itself.
start = time.time()
features = extract_embeddings_with_chunks(model, preprocess, data_loader)
elapsed = time.time() - start

# Convert the NumPy results to tensors with explicit dtypes. Labels go straight to
# int64 instead of taking a detour through float32 via the legacy torch.Tensor constructor.
feats = torch.tensor(features['embeddings'], dtype=torch.float32)
labels = torch.tensor(features['labels'], dtype=torch.long)

print(f'Took {elapsed:.03f} seconds')
print(f'Train features shape {feats.shape} and Labels shape is {labels.shape}')

### Save Features

In [None]:
def save_features(features, dataloader, save_dir):
    """Save one ``.pt`` file per slide containing that slide's feature row.

    Args:
        features: Indexable collection of per-slide features (e.g. a tensor of
            shape ``[num_slides, embedding_dim]``); row ``i`` is written for
            ``dataloader.dataset.slides[i]``.
        dataloader: Object whose ``dataset.slides`` lists the slide names in the
            same order as ``features``.
        save_dir: Output directory; created if it does not exist.

    Raises:
        ValueError: if the number of feature rows differs from the number of slides.
    """
    slides = dataloader.dataset.slides
    if len(features) != len(slides):
        raise ValueError(
            f"got {len(features)} feature rows for {len(slides)} slides; "
            "features and slides must be aligned one-to-one"
        )

    os.makedirs(save_dir, exist_ok=True)
    for i, slide in enumerate(slides):
        slide_feats = features[i]
        if torch.is_tensor(slide_feats):
            # A row taken from a larger tensor is a view; torch.save would serialize
            # the whole underlying storage (all slides) into every file. Clone so
            # each file holds only its own row.
            slide_feats = slide_feats.clone()
        # Strip only the final extension so slide names containing dots stay intact.
        slide_name = os.path.splitext(slide)[0]
        # save as torch .pt file
        save_path = os.path.join(save_dir, f'{slide_name}.pt')
        torch.save(slide_feats, save_path)

# Write one feature file per slide into FEATURES_SAVE_DIR.
save_features(feats, data_loader, FEATURES_SAVE_DIR)