In [1]:
# --- standard library ---
import os
import random
import sys
import time
from os.path import join as j_

# --- third-party ---
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import transforms
from torchvision.transforms import Lambda
from tqdm import tqdm

# Synchronous CUDA kernel launches: errors are reported at the offending call.
# NOTE(review): this is a debugging aid that slows GPU execution; consider dropping it
# for long extraction runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Disable PIL's decompression-bomb pixel limit (presumably because some slide images are very large).
Image.MAX_IMAGE_PIXELS = None

# --- CONCH ---
from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# configs
# Root of the PAIP dataset. It defaults to the original workstation path. Set the PAIP_ROOT
# environment variable to run this notebook on another machine without editing this cell.
PAIP_ROOT = os.environ.get("PAIP_ROOT", r"E:\KSA Project\dataset\paip_data")

# One batch holds all tiles of one slide, loaded sequentially.
# NOTE(review): this constant is not passed to the DataLoader below; per-slide batching
# is done by SlideBatchSampler.
BATCH_SIZE = 1
K_FOLDS_PATH = os.path.join(PAIP_ROOT, "labels", "paip_few_samples.csv")
DATA_PATH = os.path.join(PAIP_ROOT, "Patches")
FEATURES_SAVE_DIR = os.path.join(PAIP_ROOT, "CONCH_FiveCrop_Features")


### Load the model with `create_model_from_pretrained`
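By default, the model's preprocessor uses 448 × 448 as the input size. To use a different input size (e.g., 336 × 336), pass the **force_img_size** argument. Note that the 224 × 224 FiveCrop crops used below are therefore resized to the preprocessor's input size before encoding.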

You can specify a cuda device by using the **device** argument, or manually move the model to a device later using **model.to(device)**.

In [2]:
# CONCH ViT-B/16 with weights from the local checkpoint; `preprocess` is its matching image
# transform (448 x 448 input by default; pass e.g. force_img_size=224 to change it).
model_cfg, checkpoint_path = 'conch_ViT-B-16', './checkpoints/pytorch_model.bin'
model, preprocess = create_model_from_pretrained(
    model_cfg,
    checkpoint_path,
    device=device,
)
# Switch dropout/normalization layers to inference behavior; binding the return value to `_`
# keeps the cell from echoing the model repr.
_ = model.eval()

### DataLoaders

In [4]:
# The project's `dataloader` module lives one directory up. Add ".." only once so that
# re-running this cell does not keep appending duplicates to sys.path.
if ".." not in sys.path:
    sys.path.append("..")
from dataloader import PatchLoader, SlideBatchSampler

# PatchLoader loading mode: 1 = sequential patch loading (used here), 2 = random loading.
mode = 1

# FiveCrop(224) turns each tile into a tuple of five 224x224 crops (four corners + center).
# Each crop is converted to a float tensor and the five are stacked along a new leading axis.
transform = transforms.Compose([
    transforms.FiveCrop(224),
    Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """Build a DataLoader over PatchLoader whose batches come from SlideBatchSampler.

    Args:
    - label_file: Label CSV path, forwarded to PatchLoader.
    - data_path: Patch directory, forwarded to PatchLoader.
    - transform: Per-tile transform applied by PatchLoader.
    - num_samples: Forwarded to PatchLoader unchanged.
    - mode: PatchLoader loading mode (1 = sequential, 2 = random).

    Returns:
    - A DataLoader driven by ``SlideBatchSampler(dataset.ntiles)``. Batch grouping and order
      come entirely from that sampler (presumably one batch per slide); ``shuffle`` is not
      used. Loading runs in the main process without pinned memory.
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    return DataLoader(
        patch_dataset,
        batch_sampler=SlideBatchSampler(patch_dataset.ntiles),
        num_workers=0,
        pin_memory=False,
    )

# Build the loader from the configured label file and patch directory.
data_loader = create_dataloader(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 4
Number of tiles: 5440
Length of data_loader: 4


In [22]:
# for batch_idx, (images, labels) in enumerate(data_loader):
#     print(f"Batch {batch_idx+1}")
#     print(f"Images shape: {images.shape}")
#     print(f"Labels: {labels}")
#     if batch_idx == 5:  # Only print a few batches to check if it's working
#         break

### Embed images 
The **.encode_image()** method encodes a batch of images into a batch of image embeddings. Note that it applies the contrastive-learning projection head and L2-normalizes the result before returning it; these normalized embeddings are the ones used to compute similarity scores (e.g., between images and texts).

### Five-Crop Feature Extraction and Saving

In [23]:
from conch.open_clip_custom import get_tokenizer, tokenize

def _tile_save_path(save_dir_wsi, tile_path):
    """Return ``<save_dir_wsi>/<tile file name without extension>.pt`` for one tile."""
    patch_name = os.path.splitext(os.path.basename(tile_path))[0]
    return os.path.join(save_dir_wsi, f'{patch_name}.pt')


def _crops_to_model_input(crops, preprocess):
    """Convert one tile's stacked crops back to PIL images and stack their preprocessed tensors.

    ``crops`` is assumed to be a (num_crops, C, H, W) float tensor in [0, 1], as produced by
    applying ``ToTensor`` to each FiveCrop crop.
    """
    # CHW -> HWC, the layout PIL expects.
    arrays = crops.permute(0, 2, 3, 1).cpu().numpy()
    # Round instead of truncating. x/255*255 in float32 can land just below the original
    # integer, and a bare astype(np.uint8) would then shift that pixel down by one.
    arrays = np.clip(np.rint(arrays * 255), 0, 255).astype(np.uint8)
    return torch.stack([preprocess(Image.fromarray(arr)) for arr in arrays])


def _encode_crops(model, batch, device):
    """Run ``model.encode_image`` on ``batch`` and return the result on the CPU.

    fp16 autocast is used only when ``device`` is CUDA. On any other device the model runs
    in its native precision instead of requesting a CUDA autocast that cannot apply.
    """
    with torch.inference_mode():
        if device.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                return model.encode_image(batch).cpu()
        return model.encode_image(batch).cpu()


@torch.no_grad()
def extract_embeddings_patch_by_patch(model, preprocess, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, tile by tile, without averaging the five crops.

    Each tile's crops are embedded as one batch, and the resulting tensor (one embedding per
    crop) is saved to ``save_dir/<wsi_name>/<tile file stem>.pt``.

    Args:
    - model: Model exposing ``encode_image``; it runs on the device of its first parameter.
    - preprocess: Callable mapping a PIL image to a tensor (e.g. the CONCH preprocessor).
    - dataloader: Yields one ``(images, labels)`` batch per WSI. ``dataset.slides[i]`` names
      WSI ``i``, ``batch_sampler.indices[i]`` lists its dataset indices, and
      ``dataset.tiles[idx]`` is a tile's file path.
    - save_dir: Root directory for the per-WSI embedding folders.

    Returns:
    - None. Embeddings are written to disk.

    A WSI is skipped only when every one of its tile files already exists. A partially
    written WSI, for example from an interrupted run, is resumed: tiles whose files
    already exist are not recomputed. Each file is written atomically (temp file + rename).

    Raises:
    - ValueError: if the number of tiles yielded for a WSI differs from the number of
      sampler indices, meaning slide names and batches are out of sync.
    """
    device = next(model.parameters()).device
    print(f'The size of input dataloader is {len(dataloader)}')

    for batch_idx, (images, _labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Assumes the sampler's batch order matches dataset.slides (one batch per WSI).
        wsi_name = dataloader.dataset.slides[batch_idx]
        save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
        batch_indices = dataloader.batch_sampler.indices[batch_idx]
        save_paths = [
            _tile_save_path(save_dir_wsi, dataloader.dataset.tiles[idx]) for idx in batch_indices
        ]

        # Only a WSI with *all* tile files present counts as processed. An existing but
        # incomplete directory is resumed instead of being skipped forever.
        already_saved = [os.path.exists(path) for path in save_paths]
        if all(already_saved):
            print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
            continue
        if len(images) != len(save_paths):
            raise ValueError(
                f"WSI {wsi_name}: dataloader yielded {len(images)} tiles but the sampler lists "
                f"{len(save_paths)}; slide names and batches appear to be out of sync."
            )
        os.makedirs(save_dir_wsi, exist_ok=True)

        for crops, save_path, done in zip(images, save_paths, already_saved):
            if done:
                continue
            batch = _crops_to_model_input(crops, preprocess).to(device)
            embeddings = _encode_crops(model, batch, device)
            # Write to a temp file and rename it, so a crash mid-write never leaves a
            # truncated .pt that a later run would mistake for a finished tile.
            tmp_path = save_path + '.tmp'
            torch.save(embeddings, tmp_path)
            os.replace(tmp_path, save_path)
        

In [24]:
# Writes one .pt file per tile under FEATURES_SAVE_DIR/<slide name>/.
# NOTE(review): the saved output below reports 31 batches while the DataLoader cell printed 4,
# so this notebook was executed out of order; re-run from a fresh kernel before trusting results.
extract_embeddings_patch_by_patch(
    model=model,
    preprocess=preprocess,
    dataloader=data_loader,
    save_dir=FEATURES_SAVE_DIR,
)

The size of input dataloader is 31


100%|██████████| 31/31 [24:51<00:00, 48.11s/it]  
