### Selecting Top Tissue-Region Patches and Ranking Their Captions with CONCH

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from typing import List, Tuple
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
# Raise PIL's maximum-pixel safety limit to one billion so very large images can be opened
Image.MAX_IMAGE_PIXELS = 10 ** 9
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from torchvision import transforms
from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize
import warnings
warnings.filterwarnings("ignore")

In [5]:
# Read the CSV that lists the reviewed WSIs and their labels
wsi_df = pd.read_csv(r"E:\KSA Project\dataset\paip_data\labels\paip_reviewed_slides.csv")
# Number of WSIs, taken from the top of the CSV, to process in this run
NUM_WSIS = 4
wsi_id_list = wsi_df["WSI_Id"].iloc[:NUM_WSIS].tolist()
feature_dir = r"E:\KSA Project\dataset\paip_data\CONCH_FiveCrop_Features"  # Per-WSI folders of patch feature (.pt) files
image_dir = r"E:\KSA Project\dataset\paip_data\patches"  # Per-WSI folders of patch images (.png)
output_dir = r"E:\KSA Project\dataset\paip_data\output\prompts11_Reviewed_Patches"  # Figures and CSV results are written here
# Create the output directory if it does not exist yet
os.makedirs(output_dir, exist_ok=True)

#### Check the minimum and maximum number of patches per WSI

In [None]:
# print(wsi_df["label"].value_counts())
# # Define the path
# path = image_dir
# # Get all the directories in the specified path
# directories = [os.path.join(path, d) for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
# # Initialize a list to store file counts
# file_counts = []
# # Loop through each directory and count the files
# for dir in directories:
#     file_count = len([f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))])
#     file_counts.append(file_count)
# print(f"Minimum File Count: {min(file_counts)}")
# print(f"Maximum File Count: {max(file_counts)}")
# print(f"Average File Count: { sum(file_counts) / len(file_counts):.2f}")


In [None]:
# Load the CONCH model (and its image preprocessing transform) from a local checkpoint
model, preprocess = create_model_from_pretrained(
    model_cfg='conch_ViT-B-16', checkpoint_path='./checkpoints/pytorch_model.bin'
)
_ = model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

tokenizer = get_tokenizer()
# Class prompts: every "Prompt_i" score column and every caption downstream
# indexes into this list, so its order defines the CSV column order.
prompts = [
    "Adipose",
    "Debris",
    "Lymphocytes",
    "Plasma cells",
    "Mucin",
    "Smooth Muscle",
    "Normal Mucosa",
    "Stroma",
    "Connective tissue",
    "Adenoma",
    "Tumor"
]

# Encode all prompts once; later cells reuse `text_embeddings` for every patch.
with torch.inference_mode():
    tokenized_prompts = tokenize(texts=prompts, tokenizer=tokenizer).to(device)
    # One embedding row per prompt: shape [len(prompts), embed_dim] (11 rows here).
    text_embeddings = model.encode_text(tokenized_prompts)

### Top Patches per Cluster

In [None]:

def select_top_patches_from_clusters(patch_array, patch_filenames, num_clusters=2, num_patches_per_cluster=1):
    """
    Cluster patch features with K-Means and pick the patches closest to each centroid.

    Args:
        patch_array (np.array): Patch feature vectors, one row per patch.
        patch_filenames (list): Filenames aligned with the rows of ``patch_array``.
        num_clusters (int): Requested number of clusters. It is clamped to the
            number of patches because K-Means cannot build more clusters than
            there are samples.
        num_patches_per_cluster (int): Number of patches to pick per cluster.

    Returns:
        selected_patch_files (list): Filenames of the selected patches, grouped
            by cluster and ordered by increasing distance to the centroid.
        selected_patch_features (list): Feature vectors of the selected patches,
            aligned with ``selected_patch_files``.
        Both lists are empty when ``patch_array`` has no rows.
    """
    patch_array = np.asarray(patch_array)
    if len(patch_array) == 0:
        return [], []

    # KMeans raises ValueError when n_clusters exceeds the number of samples.
    n_clusters = min(num_clusters, len(patch_array))
    clustering_model = KMeans(n_clusters=n_clusters, random_state=42)
    clustering_model.fit(patch_array)
    cluster_labels = clustering_model.labels_
    cluster_centroids = clustering_model.cluster_centers_

    selected_patch_files = []
    selected_patch_features = []

    for cluster_idx in range(n_clusters):
        cluster_indices = np.flatnonzero(cluster_labels == cluster_idx)
        if cluster_indices.size == 0:
            continue  # Skip empty clusters

        # Euclidean distance of every cluster member to its centroid
        distances = cdist(
            patch_array[cluster_indices],
            cluster_centroids[cluster_idx:cluster_idx + 1],
            metric='euclidean',
        ).ravel()

        # Closest members first
        for local_idx in np.argsort(distances)[:num_patches_per_cluster]:
            global_idx = cluster_indices[local_idx]
            selected_patch_files.append(patch_filenames[global_idx])
            selected_patch_features.append(patch_array[global_idx])

    return selected_patch_files, selected_patch_features


def display_top_patches_with_captions_and_save_csv(
    wsi_id_list, feature_dir, image_dir, wsi_df, output_dir, num_clusters=2
):
    """
    Display the top selected patches of each WSI with their best-matching
    prompt, and save the per-patch similarity scores to a CSV file.

    For every WSI:
      * its label is looked up in ``wsi_df`` (WSIs without a label are skipped);
      * every ``.pt`` feature file is loaded; five-crop features of shape
        [5, D] are averaged to [D];
      * features and ``.png`` patch images are paired in sorted file-name order;
      * ``select_top_patches_from_clusters`` picks the patches closest to each
        K-Means centroid;
      * each picked patch is scored against the global ``text_embeddings``
        (dot product, one score per entry of the global ``prompts``);
      * a grid figure (one cell per picked patch) is saved to
        ``<output_dir>/<wsi_id>.png`` and shown.

    Args:
        wsi_id_list (list): WSI identifiers to process.
        feature_dir (str): Directory with one sub-folder of ``.pt`` files per WSI.
        image_dir (str): Directory with one sub-folder of ``.png`` patches per WSI.
        wsi_df (pd.DataFrame): Must contain ``WSI_Id`` and ``label`` columns.
        output_dir (str): Destination for the figures and the CSV.
        num_clusters (int): Number of K-Means clusters.

    Returns:
        None. Writes ``top_patch_results_top1_2cluster.csv`` to ``output_dir``.

    Note:
        Expects the 2-value (files, features) ``select_top_patches_from_clusters``
        defined in this cell; a later cell redefines that name with a 3-value
        return, so re-run this cell's helper before re-running this function.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []
    # Patches kept per cluster; MSIH and non-MSIH slides currently use the same value.
    num_patches_per_cluster = 3

    for wsi_id in wsi_id_list:
        print(f"Processing WSI: {wsi_id}")

        # Load WSI label
        wsi_label = wsi_df.loc[wsi_df["WSI_Id"] == wsi_id, "label"].values
        if len(wsi_label) == 0:
            print(f"⚠️ No label found for {wsi_id}, skipping...")
            continue
        wsi_label = wsi_label[0]

        feature_dir_path = os.path.join(feature_dir, wsi_id)
        wsi_image_dir = os.path.join(image_dir, wsi_id)
        if not os.path.isdir(feature_dir_path) or not os.path.isdir(wsi_image_dir):
            print(f"⚠️ Feature or image directory missing for {wsi_id}. Skipping WSI.")
            continue

        # Features and images are paired by position, and os.listdir returns
        # entries in arbitrary order, so both listings must be sorted.
        feature_files = sorted(f for f in os.listdir(feature_dir_path) if f.endswith('.pt'))
        if not feature_files:
            print(f"⚠️ No feature files found for {wsi_id}. Skipping WSI.")
            continue

        patch_features = []
        for feature_file in feature_files:
            # map_location="cpu" lets tensors saved from a GPU load on any machine.
            feature_data = torch.load(os.path.join(feature_dir_path, feature_file), map_location="cpu")
            if feature_data.ndim == 2 and feature_data.shape[0] == 5:  # FiveCrop averaging
                feature_data = feature_data.mean(dim=0)
            patch_features.append(feature_data)

        patch_array = torch.stack(patch_features).cpu().numpy()
        if patch_array.ndim == 3:
            patch_array = patch_array.reshape(patch_array.shape[0], -1)

        # Get patch filenames
        patch_filenames = sorted(f for f in os.listdir(wsi_image_dir) if f.endswith('.png'))

        if len(patch_filenames) != len(patch_array):
            print(f"⚠️ Mismatch: {len(patch_filenames)} images vs {len(patch_array)} feature vectors. Skipping WSI.")
            continue
        feature_stems = [os.path.splitext(f)[0] for f in feature_files]
        image_stems = [os.path.splitext(f)[0] for f in patch_filenames]
        if feature_stems != image_stems:
            print(f"⚠️ Feature/image file names differ for {wsi_id}; pairing them by sorted order.")

        # Select the patches closest to each cluster centroid
        selected_patch_files, selected_patch_features = select_top_patches_from_clusters(
            patch_array, patch_filenames, num_clusters, num_patches_per_cluster
        )

        # Score each selected patch against every prompt and keep the best caption
        ranked_captions = {}
        for patch_file, patch_feature in zip(selected_patch_files, selected_patch_features):
            # Match the prompt embeddings' dtype and device so the product is well-defined.
            patch_tensor = torch.as_tensor(
                patch_feature, dtype=text_embeddings.dtype, device=device
            ).unsqueeze(0)
            with torch.inference_mode():
                sim_scores = (patch_tensor @ text_embeddings.T).squeeze(0)  # one score per prompt
            best_caption = prompts[int(torch.argmax(sim_scores))]
            ranked_captions[patch_file] = best_caption

            # Store results in CSV format
            results.append([wsi_id, wsi_label, patch_file] + sim_scores.cpu().tolist() + [best_caption])

        # Grid with one column per patch-per-cluster and as many rows as needed
        n_cols = max(1, num_patches_per_cluster)
        n_rows = max(1, -(-len(selected_patch_files) // n_cols))  # ceiling division
        # squeeze=False keeps `axes` 2-D even for a single row or column.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3.5 * n_rows), squeeze=False)
        fig.subplots_adjust(hspace=0.3)  # Adjust spacing between rows
        for ax in axes.flat:
            ax.axis('off')  # cells that receive no patch stay blank

        for i, patch_file in enumerate(selected_patch_files):
            ax = axes[i // n_cols, i % n_cols]
            # imshow copies the pixel data, so the file can be closed right after.
            with Image.open(os.path.join(wsi_image_dir, patch_file)) as img:
                ax.imshow(img)

            # Patch name on top
            ax.set_title(patch_file, fontsize=6, color='black', pad=10)

            # Best caption below (long captions truncated)
            caption = ranked_captions.get(patch_file, "No Caption")
            truncated_caption = (caption[:50] + "...") if len(caption) > 50 else caption
            ax.text(0.5, -0.1, truncated_caption, ha='center', va='top', fontsize=8, color='blue', wrap=True,
                    transform=ax.transAxes)

        fig.suptitle(f"WSI: {wsi_id} | Top {num_patches_per_cluster} Patches Per Cluster", fontsize=10, fontweight='bold')

        # Save as high-quality PNG
        save_path = os.path.join(output_dir, f"{wsi_id}.png")
        fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
        print(f"✅ Saved visualization: {save_path}")

        plt.show()
        plt.close(fig)  # release the figure; otherwise every WSI's figure stays in memory

    # Save the results in a CSV file
    columns = ["WSI_Id", "Label", "Patch_ID"] + [f"Prompt_{i+1}" for i in range(len(prompts))] + ["Best_Caption"]
    df = pd.DataFrame(results, columns=columns)
    output_csv = os.path.join(output_dir, "top_patch_results_top1_2cluster.csv")
    df.to_csv(output_csv, index=False)
    print(f"✅ Results saved to: {output_csv}")

# Display the top 3 patches per cluster (2 clusters) for each WSI with captions
display_top_patches_with_captions_and_save_csv(wsi_id_list, feature_dir, image_dir, wsi_df, output_dir, num_clusters=2)


### Select Top Patches and Plot Their Five Crops

In [None]:
# =======================================================================
# 1) CLUSTERING HELPER: returns the top patch names, their row indices
#    and the per-patch cluster labels
# =======================================================================
def select_top_patches_from_clusters(
    patch_array,       # shape [N, D], one vector per patch
    patch_filenames,
    num_clusters=2,
    num_patches_per_cluster=1
):
    """
    Cluster patch features with K-Means and select the patches closest to each centroid.

    Args:
        patch_array: Array of shape [N, D], one feature vector per patch.
        patch_filenames: Filenames aligned with the rows of ``patch_array``.
        num_clusters: Requested number of clusters; clamped to N because
            K-Means cannot build more clusters than there are samples.
        num_patches_per_cluster: How many patches to keep per cluster
            (default 1, the closest to the centroid).

    Returns:
        (selected_patch_files, selected_patch_indices, cluster_labels):
        filenames and row indices of the selected patches, grouped by cluster
        and ordered by increasing distance to the centroid, plus the K-Means
        label of every row. All three are empty when ``patch_array`` has no rows.
    """
    patch_array = np.asarray(patch_array)
    if len(patch_array) == 0:
        return [], [], np.empty(0, dtype=int)

    # KMeans raises ValueError when n_clusters exceeds the number of samples.
    n_clusters = min(num_clusters, len(patch_array))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(patch_array)
    cluster_labels = kmeans.labels_
    centroids = kmeans.cluster_centers_

    selected_patch_files = []
    selected_patch_indices = []

    for cluster_idx in range(n_clusters):
        cluster_indices = np.flatnonzero(cluster_labels == cluster_idx)
        if cluster_indices.size == 0:
            continue

        # Distances of patches in this cluster to the centroid
        distances = cdist(
            patch_array[cluster_indices], centroids[cluster_idx:cluster_idx + 1], metric='euclidean'
        ).ravel()

        # Sort patches by distance, pick top N
        for ti in np.argsort(distances)[:num_patches_per_cluster]:
            patch_global_idx = cluster_indices[ti]
            selected_patch_files.append(patch_filenames[patch_global_idx])
            selected_patch_indices.append(patch_global_idx)

    return selected_patch_files, selected_patch_indices, cluster_labels

# =======================================================================
# 2) MAIN: AVERAGE FOR CLUSTERING, BUT STILL SHOW FIVE-CROP SUBFEATS
# =======================================================================
def display_top_patches_with_captions_and_save_csv(
    wsi_id_list,
    feature_dir,
    image_dir,
    wsi_df,
    output_dir,
    num_clusters=2
):
    """
    For each WSI:
      - Load five-crop features: each ``.pt`` file must hold a [5, D] tensor;
        files with any other shape are skipped with a warning.
      - Average them to [D] for K-Means clustering.
      - Pick the top patch(es) per cluster.
      - For each selected patch, regenerate its five 224x224 crops and caption
        each crop with the prompt whose embedding has the highest dot product
        with that crop's feature.
      - Save one figure per WSI to ``<output_dir>/<wsi_id>_topPatches.png``
        (one row per selected patch: col 0 = full patch, cols 1-5 = crops).
      - Save every crop's prompt scores to
        ``<output_dir>/top_patch_fivecrop_results.csv``.

    Uses the globals ``prompts``, ``text_embeddings`` and ``device`` and the
    3-value ``select_top_patches_from_clusters`` defined in this cell.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []

    # If your five-crop was extracted at size=224:
    five_crop_transform = transforms.FiveCrop(size=224)
    # One patch per cluster for both MSIH and non-MSIH slides.
    num_patches_per_cluster = 1
    columns = (["WSI_Id", "Label", "Patch_ID", "Crop_Index"] +
               [f"Prompt_{i+1}" for i in range(len(prompts))] +
               ["Best_Caption"])

    for wsi_id in wsi_id_list:
        print(f"Processing WSI: {wsi_id}")

        # --- Lookup label ---
        row = wsi_df.loc[wsi_df["WSI_Id"] == wsi_id]
        if len(row) < 1:
            print(f"⚠️ No label found for WSI: {wsi_id}, skipping...")
            continue
        wsi_label = row.iloc[0]["label"]

        # -------------------------------------------------------------------
        # A) LOAD FIVE-CROP FEATURES
        # -------------------------------------------------------------------
        feature_dir_path = os.path.join(feature_dir, wsi_id)
        if not os.path.isdir(feature_dir_path):
            print(f"Feature dir not found for {wsi_id}, skipping.")
            continue

        patch_filenames = []
        patch_features_5crop_list = []  # each entry is shape [5, D]
        # Sorted so that the clustering input (and hence K-Means) is reproducible.
        for ffile in sorted(f for f in os.listdir(feature_dir_path) if f.endswith('.pt')):
            # map_location="cpu" lets tensors saved from a GPU load on any machine.
            feat_5crops = torch.load(os.path.join(feature_dir_path, ffile), map_location="cpu")
            if feat_5crops.ndim != 2 or feat_5crops.shape[0] != 5:
                print(f"Unexpected feature shape {tuple(feat_5crops.shape)} in {ffile}, skipping patch.")
                continue
            patch_features_5crop_list.append(feat_5crops)
            patch_filenames.append(os.path.splitext(ffile)[0] + '.png')

        if not patch_features_5crop_list:
            print(f"No patch features for {wsi_id}, skipping.")
            continue

        # -------------------------------------------------------------------
        # B) AVERAGE sub-crops to get [D] per patch (for K-Means); the original
        #    5×D tensors stay in patch_features_5crop_list for captioning.
        # -------------------------------------------------------------------
        patch_array_for_clustering = np.stack(
            [pf.mean(dim=0).cpu().numpy() for pf in patch_features_5crop_list], axis=0
        )  # shape [N, D]

        # -------------------------------------------------------------------
        # C) K-Means, pick top patches
        # -------------------------------------------------------------------
        wsi_image_dir = os.path.join(image_dir, wsi_id)
        if not os.path.isdir(wsi_image_dir):
            print(f"Image dir not found for {wsi_id}, skipping.")
            continue

        selected_files, selected_indices, _ = select_top_patches_from_clusters(
            patch_array_for_clustering,
            patch_filenames,
            num_clusters=num_clusters,
            num_patches_per_cluster=num_patches_per_cluster
        )
        if not selected_files:
            print(f"No patches selected for {wsi_id}, skipping.")
            continue

        # -------------------------------------------------------------------
        # D) Subplot layout: one row per selected patch,
        #    col0 = original patch, col1..5 = sub-crops
        # -------------------------------------------------------------------
        n_rows = len(selected_files)
        n_cols = 6
        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.5 * n_cols, 3 * n_rows), squeeze=False)
        cluster_rows = []

        for row_idx, (patch_fname, patch_index) in enumerate(zip(selected_files, selected_indices)):
            row_axes = axes[row_idx]

            # Load the original patch image (file handle closed once converted)
            patch_path = os.path.join(wsi_image_dir, patch_fname)
            try:
                with Image.open(patch_path) as img:
                    original_img = img.convert("RGB")
            except (OSError, ValueError):
                print(f"Could not open {patch_path}, skipping row.")
                for ax in row_axes:
                    ax.axis('off')
                continue

            # Show original in col0
            row_axes[0].imshow(original_img)
            row_axes[0].axis('off')
            row_axes[0].set_title(f"ClusterRow {row_idx}\n{patch_fname}", fontsize=7, color='red')

            # Score all five crops at once, on the prompts' device and dtype
            five_crop_feats = patch_features_5crop_list[patch_index].to(
                device=device, dtype=text_embeddings.dtype
            )
            with torch.inference_mode():
                crop_sims = five_crop_feats @ text_embeddings.T  # [5, len(prompts)]

            # Generate sub-crops exactly as done during feature extraction
            subcrop_imgs = five_crop_transform(original_img)  # tuple of 5 PIL images

            for c_idx, (sub_img, sim_scores) in enumerate(zip(subcrop_imgs, crop_sims)):
                best_prompt_idx = int(torch.argmax(sim_scores))
                best_caption = prompts[best_prompt_idx]

                # Show sub-crop in col c_idx+1, titled with crop index + truncated caption
                ax = row_axes[c_idx + 1]
                ax.imshow(sub_img)
                ax.axis('off')
                short_caption = best_caption[:35] + "..." if len(best_caption) > 35 else best_caption
                ax.set_title(
                    f"Crop {c_idx}, Prompt {best_prompt_idx+1}\n{short_caption}",
                    fontsize=7, color='blue'
                )

                # CSV row: ids, crop index, every prompt's score, best caption
                cluster_rows.append(
                    [wsi_id, wsi_label, patch_fname, c_idx] + sim_scores.cpu().tolist() + [best_caption]
                )

        fig.suptitle(f"{wsi_id}: {num_clusters} cluster(s), top{num_patches_per_cluster} patch(es) each",
                     fontsize=12, fontweight='bold')
        out_fig = os.path.join(output_dir, f"{wsi_id}_topPatches.png")
        fig.savefig(out_fig, format='png', dpi=500, bbox_inches='tight')
        plt.close(fig)

        if cluster_rows:
            results.append(pd.DataFrame(cluster_rows, columns=columns))

    # =====================================================================
    # E) SAVE ALL WSIs' RESULTS TO CSV
    # =====================================================================
    if not results:
        print("No results to save.")
        return

    final_df = pd.concat(results, ignore_index=True)
    out_csv = os.path.join(output_dir, "top_patch_fivecrop_results.csv")
    final_df.to_csv(out_csv, index=False)
    print(f"All results saved => {out_csv}")

# Run the five-crop visualisation for the configured WSIs with two K-Means clusters
display_top_patches_with_captions_and_save_csv(
    wsi_id_list=wsi_id_list, feature_dir=feature_dir, image_dir=image_dir,
    wsi_df=wsi_df, output_dir=output_dir, num_clusters=2,
)

Processing WSI: training_data_01_MSIH
Feature dir not found for training_data_01_MSIH, skipping.
Processing WSI: training_data_05_MSIH
All results saved => E:\KSA Project\dataset\paip_data\output\prompts11_Reviewed_Patches\top_patch_fivecrop_results.csv


### Top Patches with a Per-Crop Similarity Bar Chart

In [None]:
import os
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

# =======================================================================
# 1) CLUSTERING HELPER: returns the top patch names, their row indices
#    and the per-patch cluster labels
# =======================================================================
def select_top_patches_from_clusters(
    patch_array,       # shape [N, D], one vector per patch
    patch_filenames,
    num_clusters=2,
    num_patches_per_cluster=1
):
    """
    Cluster patch features with K-Means and select the patches closest to each centroid.

    Args:
        patch_array: Array of shape [N, D], one feature vector per patch.
        patch_filenames: Filenames aligned with the rows of ``patch_array``.
        num_clusters: Requested number of clusters; clamped to N because
            K-Means cannot build more clusters than there are samples.
        num_patches_per_cluster: How many patches to keep per cluster
            (default 1, the closest to the centroid).

    Returns:
        (selected_patch_files, selected_patch_indices, cluster_labels):
        filenames and row indices of the selected patches, grouped by cluster
        and ordered by increasing distance to the centroid, plus the K-Means
        label of every row. All three are empty when ``patch_array`` has no rows.
    """
    patch_array = np.asarray(patch_array)
    if len(patch_array) == 0:
        return [], [], np.empty(0, dtype=int)

    # KMeans raises ValueError when n_clusters exceeds the number of samples.
    n_clusters = min(num_clusters, len(patch_array))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(patch_array)
    cluster_labels = kmeans.labels_
    centroids = kmeans.cluster_centers_

    selected_patch_files = []
    selected_patch_indices = []

    for cluster_idx in range(n_clusters):
        cluster_indices = np.flatnonzero(cluster_labels == cluster_idx)
        if cluster_indices.size == 0:
            continue

        # Distances of patches in this cluster to the centroid
        distances = cdist(
            patch_array[cluster_indices], centroids[cluster_idx:cluster_idx + 1], metric='euclidean'
        ).ravel()

        # Sort patches by distance, pick top N
        for ti in np.argsort(distances)[:num_patches_per_cluster]:
            patch_global_idx = cluster_indices[ti]
            selected_patch_files.append(patch_filenames[patch_global_idx])
            selected_patch_indices.append(patch_global_idx)

    return selected_patch_files, selected_patch_indices, cluster_labels

# =======================================================================
# 2) MAIN: AVERAGE FOR CLUSTERING, BUT STILL SHOW FIVE-CROP SUBFEATS
# =======================================================================
def display_top_patches_with_captions_and_save_csv(
    wsi_id_list,
    feature_dir,
    image_dir,
    wsi_df,
    output_dir,
    num_clusters=2
):
    """
    For each WSI:
      - Load five-crop features: each ``.pt`` file must hold a [5, D] tensor;
        files with any other shape are skipped with a warning.
      - Average them to [D] for K-Means clustering.
      - Pick the top patch(es) per cluster.
      - For each selected patch, regenerate its five 224x224 crops and caption
        each crop with the prompt whose embedding has the highest dot product
        with that crop's feature.
      - Draw one row per selected patch: col 0 = full patch, cols 1-5 = crops,
        col 6 = bar chart of each crop's best similarity score.
      - Write under ``<output_dir>/<wsi_id>/``: the figure
        ``<wsi_id>_topPatches.png`` and, per selected patch, a sub-folder with
        the patch image and ``crop0.png`` .. ``crop4.png``.
      - Save every crop's prompt scores to
        ``<output_dir>/top_patch_fivecrop_results.csv``; the score columns are
        named after the prompt texts.

    Uses the globals ``prompts``, ``text_embeddings`` and ``device`` and the
    3-value ``select_top_patches_from_clusters`` defined in this cell.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []

    # If your five-crop was extracted at size=224:
    five_crop_transform = transforms.FiveCrop(size=224)
    # One patch per cluster for both MSIH and non-MSIH slides.
    num_patches_per_cluster = 1
    # CSV header uses the prompt texts themselves as score-column names.
    columns = ["WSI_Id", "Label", "Patch_ID", "Crop_Index"] + list(prompts) + ["Best_Caption"]

    for wsi_id in wsi_id_list:
        print(f"Processing WSI: {wsi_id}")

        # --- Lookup label ---
        row = wsi_df.loc[wsi_df["WSI_Id"] == wsi_id]
        if len(row) < 1:
            print(f"⚠️ No label found for WSI: {wsi_id}, skipping...")
            continue
        wsi_label = row.iloc[0]["label"]

        # -------------------------------------------------------------------
        # A) LOAD FIVE-CROP FEATURES
        # -------------------------------------------------------------------
        feature_dir_path = os.path.join(feature_dir, wsi_id)
        if not os.path.isdir(feature_dir_path):
            print(f"Feature dir not found for {wsi_id}, skipping.")
            continue

        patch_filenames = []
        patch_features_5crop_list = []  # each entry is shape [5, D]
        # Sorted so that the clustering input (and hence K-Means) is reproducible.
        for ffile in sorted(f for f in os.listdir(feature_dir_path) if f.endswith('.pt')):
            # map_location="cpu" lets tensors saved from a GPU load on any machine.
            feat_5crops = torch.load(os.path.join(feature_dir_path, ffile), map_location="cpu")
            if feat_5crops.ndim != 2 or feat_5crops.shape[0] != 5:
                print(f"Unexpected feature shape {tuple(feat_5crops.shape)} in {ffile}, skipping patch.")
                continue
            patch_features_5crop_list.append(feat_5crops)
            patch_filenames.append(os.path.splitext(ffile)[0] + '.png')

        if not patch_features_5crop_list:
            print(f"No patch features for {wsi_id}, skipping.")
            continue

        # -------------------------------------------------------------------
        # B) AVERAGE sub-crops to get [D] per patch (for K-Means)
        # -------------------------------------------------------------------
        patch_array_for_clustering = np.stack(
            [pf.mean(dim=0).cpu().numpy() for pf in patch_features_5crop_list], axis=0
        )  # shape [N, D]

        # -------------------------------------------------------------------
        # C) K-Means, pick top patches
        # -------------------------------------------------------------------
        wsi_image_dir = os.path.join(image_dir, wsi_id)
        if not os.path.isdir(wsi_image_dir):
            print(f"Image dir not found for {wsi_id}, skipping.")
            continue

        selected_files, selected_indices, _ = select_top_patches_from_clusters(
            patch_array_for_clustering,
            patch_filenames,
            num_clusters=num_clusters,
            num_patches_per_cluster=num_patches_per_cluster
        )
        if not selected_files:
            print(f"No patches selected for {wsi_id}, skipping.")
            continue

        # Per-WSI output folder, created only for WSIs that are actually processed
        wsi_folder = os.path.join(output_dir, wsi_id)
        os.makedirs(wsi_folder, exist_ok=True)

        # -------------------------------------------------------------------
        # D) Subplot layout: one row per selected patch,
        #    col0 = original patch, col1..5 = sub-crops, col6 = bar chart
        # -------------------------------------------------------------------
        n_rows = len(selected_files)
        n_cols = 7
        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.5 * n_cols, 3 * n_rows),
                                 gridspec_kw={'wspace': 0.04, 'hspace': 0.09}, squeeze=False)
        cluster_rows = []

        for row_idx, (patch_fname, patch_index) in enumerate(zip(selected_files, selected_indices)):
            row_axes = axes[row_idx]

            # Load the original patch image (file handle closed once converted)
            patch_path = os.path.join(wsi_image_dir, patch_fname)
            try:
                with Image.open(patch_path) as img:
                    original_img = img.convert("RGB")
            except (OSError, ValueError):
                print(f"Could not open {patch_path}, skipping row.")
                for ax in row_axes:
                    ax.axis('off')
                continue

            # Sub-folder for this patch: a copy of the patch plus its five crops
            patch_folder = os.path.join(wsi_folder, os.path.splitext(patch_fname)[0])
            os.makedirs(patch_folder, exist_ok=True)
            original_img.save(os.path.join(patch_folder, patch_fname))

            # Show original in col0
            row_axes[0].imshow(original_img)
            row_axes[0].axis('off')
            row_axes[0].set_title(f"{patch_fname}", fontsize=7, color='red')

            # Score all five crops at once, on the prompts' device and dtype
            five_crop_feats = patch_features_5crop_list[patch_index].to(
                device=device, dtype=text_embeddings.dtype
            )
            with torch.inference_mode():
                crop_sims = five_crop_feats @ text_embeddings.T  # [5, len(prompts)]

            # Generate sub-crops exactly as done during feature extraction
            subcrop_imgs = five_crop_transform(original_img)  # tuple of 5 PIL images

            # Each crop's best similarity score and caption, for the bar chart
            best_scores = []
            best_captions = []

            for c_idx, (sub_img, sim_scores) in enumerate(zip(subcrop_imgs, crop_sims)):
                sub_img.save(os.path.join(patch_folder, f"crop{c_idx}.png"))

                best_prompt_idx = int(torch.argmax(sim_scores))
                best_caption = prompts[best_prompt_idx]
                best_scores.append(sim_scores[best_prompt_idx].item())
                best_captions.append(best_caption)

                # Show sub-crop in col c_idx+1, titled with crop index + truncated caption
                ax = row_axes[c_idx + 1]
                ax.imshow(sub_img)
                ax.axis('off')
                short_caption = best_caption[:35] + "..." if len(best_caption) > 35 else best_caption
                ax.set_title(f"Crop {c_idx}, {short_caption}", fontsize=7, color='blue')

                # CSV row: ids, crop index, every prompt's score, best caption
                cluster_rows.append(
                    [wsi_id, wsi_label, patch_fname, c_idx] + sim_scores.cpu().tolist() + [best_caption]
                )

            # -------------------------------------------------------------------
            # E) Bar chart in col6 with the best similarity score of each crop
            # -------------------------------------------------------------------
            ax_bar = row_axes[6]
            n_crops = len(best_scores)
            ax_bar.bar(range(n_crops), best_scores)
            for i, (score, caption) in enumerate(zip(best_scores, best_captions)):
                words = caption.split()
                # Two-word captions are stacked on two lines to keep labels narrow.
                text_for_bar = f"{words[0]}\n{words[1]}" if len(words) == 2 else caption
                ax_bar.text(i, score, text_for_bar, ha='center', va='bottom', fontsize=6)
            ax_bar.set_xticks(range(n_crops))
            ax_bar.set_xticklabels([f"Crop{i}" for i in range(n_crops)], fontsize=7, rotation=0)
            # Always include 0 and leave 10% headroom; handles negative and all-zero scores.
            lower = min(0.0, min(best_scores)) * 1.1
            upper = max(0.0, max(best_scores)) * 1.1
            if upper <= lower:
                upper = lower + 1.0
            ax_bar.set_ylim([lower, upper])
            ax_bar.tick_params(axis='y', labelsize=7)

        # Save figure into the per-WSI folder
        out_fig = os.path.join(wsi_folder, f"{wsi_id}_topPatches.png")
        fig.savefig(out_fig, format='png', dpi=500, bbox_inches='tight')
        plt.close(fig)

        if cluster_rows:
            results.append(pd.DataFrame(cluster_rows, columns=columns))

    # =====================================================================
    # F) SAVE ALL WSIs' RESULTS TO CSV
    # =====================================================================
    if not results:
        print("No results to save.")
        return

    final_df = pd.concat(results, ignore_index=True)
    out_csv = os.path.join(output_dir, "top_patch_fivecrop_results.csv")
    final_df.to_csv(out_csv, index=False)
    print(f"All results saved => {out_csv}")


# Run the bar-chart visualisation for the configured WSIs with two K-Means clusters
display_top_patches_with_captions_and_save_csv(
    wsi_id_list=wsi_id_list, feature_dir=feature_dir, image_dir=image_dir,
    wsi_df=wsi_df, output_dir=output_dir, num_clusters=2,
)

Processing WSI: training_data_01_MSIH
Feature dir not found for training_data_01_MSIH, skipping.
Processing WSI: training_data_05_MSIH
Processing WSI: training_data_06_MSIH
Feature dir not found for training_data_06_MSIH, skipping.
Processing WSI: training_data_20_MSIH
Feature dir not found for training_data_20_MSIH, skipping.
All results saved => E:\KSA Project\dataset\paip_data\output\prompts11_Reviewed_Patches\top_patch_fivecrop_results.csv


### Process All Patches of All WSIs

In [None]:

import gc  # Garbage collection module

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def _rank_prompts_for_crop(crop_feature):
    """Score one crop embedding against every prompt embedding.

    Uses the module-level ``text_embeddings`` (one row per entry of ``prompts``)
    and ``device``.

    Returns:
        (scores, best_caption, best_score): ``scores`` is a Python list of
        similarities aligned with ``prompts``; ``best_caption`` is the prompt
        with the highest similarity and ``best_score`` is that similarity.
    """
    # crop_feature is expected to be 1-D ([D]); unsqueeze makes it [1, D] so it
    # can be multiplied with text_embeddings.T.
    with torch.inference_mode():
        scores = (crop_feature.unsqueeze(0).to(device) @ text_embeddings.T).squeeze(0)
        best_idx = int(torch.argmax(scores).item())
        best_score = scores[best_idx].item()
        score_list = scores.cpu().tolist()
    return score_list, prompts[best_idx], best_score


def _save_five_crop_figure(original_img, patch_fname, crop_imgs, crop_results, out_path, font):
    """Save a 1x6 panel: the original patch followed by its crops.

    Each crop panel is titled with its (truncated) best caption and stamped
    with its best similarity score. The figure is always closed, even if
    drawing or saving raises.
    """
    fig, axes = plt.subplots(1, 6, figsize=(18, 3))  # 1 row, 6 columns
    try:
        axes[0].imshow(original_img)
        axes[0].axis('off')
        axes[0].set_title(f"{patch_fname}", fontsize=9, color='red')

        for c_idx, (crop_img, (_, caption, score)) in enumerate(zip(crop_imgs, crop_results)):
            # Stamp the score on a copy so the crop image itself is untouched.
            stamped = crop_img.copy()
            ImageDraw.Draw(stamped).text((5, 5), f"{score:.2f}", fill="blue", font=font)

            ax = axes[c_idx + 1]
            ax.imshow(stamped)
            ax.axis('off')
            short_caption = caption[:35] + "..." if len(caption) > 35 else caption
            ax.set_title(f"Crop {c_idx}, {short_caption}", fontsize=8, color='blue')

        fig.savefig(out_path, format='png', dpi=500, bbox_inches='tight')
    finally:
        plt.close(fig)


def process_all_patches_with_captions_and_save_csv(
    wsi_id_list,
    feature_dir,
    image_dir,
    wsi_df,
    output_dir
):
    """Caption every patch of every WSI (no cluster-based top-patch selection).

    For each WSI in ``wsi_id_list``:
      - looks up its label in ``wsi_df`` (``WSI_Id`` / ``label`` columns),
      - for each ``<feature_dir>/<wsi_id>/<name>.pt`` feature file, opens
        ``<image_dir>/<wsi_id>/<name>.png`` and splits it with
        ``transforms.FiveCrop(224)``,
      - scores each crop's feature row against the global ``text_embeddings``
        and picks the best entry of the global ``prompts`` list,
      - saves ``<output_dir>/<wsi_id>/<name>_five_crops.png`` (original + crops),
      - writes one CSV row per crop to
        ``<output_dir>/all_patches_fivecrop_results.csv``.

    WSIs with no label, no feature directory, no ``.pt`` files, or no image
    directory are skipped with a message; unreadable patch images are skipped.

    Relies on module-level ``prompts``, ``text_embeddings`` and ``device``.

    Args:
        wsi_id_list: WSI identifiers to process.
        feature_dir: Root directory with one feature sub-directory per WSI.
        image_dir: Root directory with one patch-image sub-directory per WSI.
        wsi_df: DataFrame with ``WSI_Id`` and ``label`` columns.
        output_dir: Directory for the CSV and per-WSI figure folders
            (created if missing).
    """
    import csv  # local import: only this function writes CSV rows by hand

    os.makedirs(output_dir, exist_ok=True)
    five_crop_transform = transforms.FiveCrop(size=224)
    # Loop-invariant: load the annotation font once, not once per crop.
    font = ImageFont.load_default(10)

    out_csv = os.path.join(output_dir, "all_patches_fivecrop_results.csv")
    # csv.writer quotes fields containing commas/quotes (e.g. prompt captions),
    # which a plain ','.join did not; newline="" is required by the csv module.
    # utf-8 matches what pd.read_csv expects by default.
    with open(out_csv, "w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(
            ["WSI_Id", "Label", "Patch_ID", "Crop_Index"]
            + [f"Prompt_{i+1}" for i in range(len(prompts))]
            + ["Best_Caption"]
        )

        for wsi_id in wsi_id_list:
            print(f"Processing WSI: {wsi_id}")

            # --- Look up the slide label; WSIs without one are skipped ---
            label_rows = wsi_df.loc[wsi_df["WSI_Id"] == wsi_id]
            if label_rows.empty:
                print(f"⚠️ No label found for WSI: {wsi_id}, skipping...")
                continue
            wsi_label = label_rows.iloc[0]["label"]

            # --- Validate all inputs before doing any heavy work ---
            feature_dir_path = os.path.join(feature_dir, wsi_id)
            if not os.path.isdir(feature_dir_path):
                print(f"Feature dir not found for {wsi_id}, skipping.")
                continue
            # Sorted so rows and figures come out in a deterministic order.
            feature_files = sorted(f for f in os.listdir(feature_dir_path) if f.endswith(".pt"))
            if not feature_files:
                print(f"No patch features for {wsi_id}, skipping.")
                continue
            wsi_image_dir = os.path.join(image_dir, wsi_id)
            if not os.path.isdir(wsi_image_dir):
                print(f"Image dir not found for {wsi_id}, skipping.")
                continue

            # Only create the per-WSI figure folder once it will actually be used.
            wsi_output_dir = os.path.join(output_dir, wsi_id)
            os.makedirs(wsi_output_dir, exist_ok=True)

            # --- Process patches one at a time; features are loaded lazily ---
            for ffile in feature_files:
                patch_stem = os.path.splitext(ffile)[0]
                patch_fname = f"{patch_stem}.png"
                patch_path = os.path.join(wsi_image_dir, patch_fname)
                try:
                    with Image.open(patch_path) as img:
                        original_img = img.convert("RGB")
                except OSError as err:  # missing/unreadable file (incl. UnidentifiedImageError)
                    print(f"Could not open {patch_path}, skipping row. ({err})")
                    continue

                # map_location lets features saved from a GPU load on a CPU-only
                # machine. Assumes one feature row per FiveCrop crop, in
                # FiveCrop order — TODO confirm against the feature extractor.
                patch_features_5crop = torch.load(
                    os.path.join(feature_dir_path, ffile), map_location=device
                )
                subcrop_imgs = five_crop_transform(original_img)  # tuple of 5 PIL images
                crop_results = [
                    _rank_prompts_for_crop(patch_features_5crop[c_idx])
                    for c_idx in range(len(subcrop_imgs))
                ]

                _save_five_crop_figure(
                    original_img,
                    patch_fname,
                    subcrop_imgs,
                    crop_results,
                    os.path.join(wsi_output_dir, f"{patch_stem}_five_crops.png"),
                    font,
                )

                # One row per crop: ids, every prompt's similarity, best caption.
                writer.writerows(
                    [wsi_id, wsi_label, patch_fname, c_idx, *scores, best_caption]
                    for c_idx, (scores, best_caption, _) in enumerate(crop_results)
                )

            # Release memory held from this WSI before moving to the next one.
            torch.cuda.empty_cache()
            gc.collect()

    print(f"All results saved => {out_csv}")

# Caption every patch of every WSI in wsi_id_list; results go to
# all_patches_fivecrop_results.csv and per-patch figures under output_dir.
# Arguments are passed in the function's declared parameter order.
process_all_patches_with_captions_and_save_csv(
    wsi_id_list, feature_dir, image_dir, wsi_df, output_dir
)

### Results Analysis

In [7]:
# Load the per-crop results written by process_all_patches_with_captions_and_save_csv.
# Each row is one FiveCrop crop, so every patch contributes five rows.
results_file = os.path.join(output_dir, "all_patches_fivecrop_results.csv")
df = pd.read_csv(results_file)

# Number of crop rows per slide label
print(df["Label"].value_counts())

# Split the rows by label
msih = df[df["Label"] == "MSIH"]
nonmsih = df[df["Label"] == "nonMSIH"]

# Number of distinct prompts that were the best match for at least one crop
print(f"Unique best captions in MSIH: {msih['Best_Caption'].nunique()}")
print(f"Unique best captions in nonMSIH: {nonmsih['Best_Caption'].nunique()}")

# How often each prompt was the best match (counts are crops, not patches)
print("Count of crops against each prompt in MSIH")
print(msih["Best_Caption"].value_counts())
print("Count of crops against each prompt in nonMSIH")
print(nonmsih["Best_Caption"].value_counts())

Label
MSIH       60
nonMSIH    45
Name: count, dtype: int64
4
5
Count of Patches against each promt in MSIH
Best_Caption
Adipose          33
Debris           19
Smooth Muscle     7
Mucin             1
Name: count, dtype: int64
Count of Patches against each promt in nonMSIH
Best_Caption
Debris               16
Smooth Muscle        14
Adipose              11
Lymphocytes           2
Connective tissue     2
Name: count, dtype: int64


### Reshape the top-patch results CSV from one row per crop to one row per WSI_Id

In [None]:
# Read the per-crop results file to be reshaped from output_dir.
input_path = os.path.join(output_dir, "top3_prompts_2cluster_top1_patch.csv")
df = pd.read_csv(filepath_or_buffer=input_path)

# Function to extract unique captions while preserving their first-occurrence order
def unique_preserve_order(text_series):
    """Merge semicolon-separated caption cells into one de-duplicated string.

    Each cell of ``text_series`` is split on ";" and each piece is stripped.
    Non-empty captions are kept in order of first appearance. Missing cells
    (NaN/None, which is how pandas reads empty CSV fields) are ignored instead
    of crashing on ``.split``.

    Args:
        text_series: Iterable of caption cells, e.g. a pandas Series.

    Returns:
        The unique captions joined with "; " (empty string if there are none).
    """
    # A dict keeps insertion order and gives O(1) membership checks,
    # avoiding the quadratic `cap not in list` scan.
    unique_caps = {}
    for text in text_series:
        if pd.isna(text):
            continue
        for part in str(text).split(";"):
            cap = part.strip()
            if cap:
                unique_caps.setdefault(cap, None)
    return "; ".join(unique_caps)

# Rows per patch in the input: one per FiveCrop crop.
CROPS_PER_PATCH = 5

# The CSVs written elsewhere in this notebook call the caption column
# "Best_Caption", while this input was originally read as "Best_Captions".
# Accept either so a naming mismatch does not surface as a KeyError mid-loop.
caption_col = "Best_Captions" if "Best_Captions" in df.columns else "Best_Caption"

results = []

# Group the data by WSI_Id (pandas keeps the original row order inside each group).
# Assumes each WSI lists the top patch's crops first, then the bottom patch's crops.
for wsi_id, group in df.groupby("WSI_Id"):
    group = group.reset_index(drop=True)
    if len(group) != 2 * CROPS_PER_PATCH:
        print(f"Warning: {wsi_id} has {len(group)} rows, expected "
              f"{2 * CROPS_PER_PATCH}; the top/bottom split may be wrong.")

    top_crops = group[caption_col].iloc[:CROPS_PER_PATCH]
    bottom_crops = group[caption_col].iloc[CROPS_PER_PATCH:2 * CROPS_PER_PATCH]

    # Agree / Missed / Remarks columns are left empty for manual review.
    results.append({
        "Final review 5 crop": wsi_id,
        "Top": unique_preserve_order(top_crops),
        "Agree": "",
        "Missed": "",
        "Remarks": "",
        "Bottom": unique_preserve_order(bottom_crops),
        "Agree_bottom": "",
        "Missed_bottom": "",
    })

# One row per WSI
final_df = pd.DataFrame(results)

output_file_path = os.path.join(output_dir, "reshaped_top3_prompts_2cluster_top1_patch.csv")
final_df.to_csv(output_file_path, index=False)
print(f"Reshaped results saved => {output_file_path}")