In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import pickle
import openslide as ops
from collections import Counter
from sklearn.cluster import KMeans

In [2]:
def cluster_embeddings(embeddings, n_clusters):
    """Assign a K-Means cluster id to every embedding record.

    Args:
        embeddings: Sized sequence of dicts, each holding an
            ``'embedding_vector'`` entry. Every vector must contribute exactly
            one row to the stacked matrix (a flat vector is treated as a
            single sample).
        n_clusters: Number of clusters passed to ``KMeans``.

    Returns:
        The same ``embeddings`` object. Each dict is mutated in place and
        gains a ``'label'`` key holding its cluster id.

    Raises:
        ValueError: If the stacked matrix does not have exactly one row per
            record, in which case labels could not be matched to records.
    """
    # atleast_2d lets a flat vector count as one sample instead of being
    # spliced element-wise into its neighbours by the concatenation.
    all_embeddings = np.concatenate(
        [np.atleast_2d(data['embedding_vector']) for data in embeddings], axis=0
    )

    # Labels are matched to records by position, so a record that contributes
    # several rows would silently shift every label after it.
    if all_embeddings.shape[0] != len(embeddings):
        raise ValueError(
            f"Expected one embedding row per record, got {all_embeddings.shape[0]} "
            f"rows for {len(embeddings)} records."
        )

    # Fixed random_state keeps the cluster ids reproducible between runs.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(all_embeddings)
    for data, label in zip(embeddings, kmeans.labels_):
        data["label"] = label
    return embeddings

In [3]:
def plot_patches(num_patches_per_cluster, dataset, title):
    """Show a grid of randomly chosen patches, one row per cluster label.

    Args:
        num_patches_per_cluster: Number of grid columns, i.e. the maximum
            number of patches drawn for each label.
        dataset: List of dicts with the keys ``'label'``, ``'slide_name'``,
            ``'x'``, ``'y'``, ``'patch_size'`` and ``'resize'``.
        title: Figure-level title.

    Patches are read from level 0 of the slide found under ``./wsi``. After
    plotting, the label frequencies of the whole dataset are printed.
    """
    # Sorted so the rows appear in a stable, ascending label order.
    unique_labels = sorted(set(data['label'] for data in dataset))

    # Keep the picks grouped per label: the row of a patch must come from its
    # label, not from its position in a flattened list, otherwise a cluster
    # with too few members shifts every following patch into the wrong row.
    patches_per_label = []
    for label in unique_labels:
        members = [data for data in dataset if data['label'] == label]
        random.shuffle(members)
        patches_per_label.append(members[:num_patches_per_cluster])

    print("Number of selected patches: ", sum(len(members) for members in patches_per_label))

    if patches_per_label:  # plt.subplots cannot build a grid with zero rows
        # squeeze=False keeps axs two-dimensional even for a single row/column.
        fig, axs = plt.subplots(
            len(unique_labels), num_patches_per_cluster,
            figsize=(20, 8), dpi=200, squeeze=False,
        )
        fig.suptitle(title, fontsize=16)

        for row, members in enumerate(patches_per_label):
            for col in range(num_patches_per_cluster):
                ax = axs[row, col]
                if col >= len(members):
                    # Fewer patches than columns for this label: blank the cell.
                    ax.axis("off")
                    continue

                data = members[col]
                patch_title = f"Label: {data['label']}\n {data['slide_name']}_{data['x']}_{data['y']}"
                ax.set_title(patch_title, fontsize=8)

                slide_path = os.path.join("./wsi", f"{data['slide_name']}")
                region = (data['x'], data['y'])
                # The context manager closes the slide handle after the read.
                with ops.open_slide(slide_path) as slide:
                    patch_image = (
                        slide.read_region(region, 0, data["patch_size"])
                        .convert("RGB")
                        .resize(data["resize"])
                    )
                ax.imshow(patch_image)

        plt.tight_layout()
        plt.show()

    counter = Counter(data['label'] for data in dataset)
    print(counter)
    print("-" * 50)
    

In [4]:
def plot_selected_patches(embeddings, num_patches, title):
    """Show one row of randomly sampled patches.

    Args:
        embeddings: List of dicts with the keys ``'label'``, ``'slide_name'``,
            ``'x'``, ``'y'``, ``'patch_size'`` and ``'resize'``. The list is
            not modified.
        num_patches: Number of grid columns, i.e. the maximum number of
            patches drawn.
        title: Figure-level title.

    Patches are read from level 0 of the slide found under ``./wsi``.
    """
    # squeeze=False keeps axs two-dimensional even when num_patches == 1.
    fig, axs = plt.subplots(1, num_patches, figsize=(20, 3), dpi=200, squeeze=False)
    fig.suptitle(title, fontsize=16)
    row_axes = axs[0]

    # Sample without replacement instead of shuffling, so the caller's list
    # keeps its order.
    pool = list(embeddings)
    chosen = random.sample(pool, min(num_patches, len(pool)))

    for ax, data in zip(row_axes, chosen):
        patch_title = f"Label: {data['label']}\n {data['slide_name']}_{data['x']}_{data['y']}"
        ax.set_title(patch_title, fontsize=8)

        slide_path = os.path.join("./wsi", f"{data['slide_name']}")
        region = (data['x'], data['y'])
        # The context manager closes the slide handle after the read.
        with ops.open_slide(slide_path) as slide:
            patch_image = (
                slide.read_region(region, 0, data["patch_size"])
                .convert("RGB")
                .resize(data["resize"])
            )
        ax.imshow(patch_image)

    # Fewer records than columns: blank the cells that received no patch.
    for ax in row_axes[len(chosen):]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Cluster every embeddings file in ./embeddings and store the labelled
# records under ./clustered_embeddings.
NUM_CLUSTERS = 5

# NOTE: os.listdir returns entries in arbitrary order. Code that pairs these
# names with positional settings (such as per-file label lists) depends on
# that order, so verify the pairing when files are added or moved.
embed_dataset = os.listdir('./embeddings')

os.makedirs('./clustered_embeddings', exist_ok=True)

for embeddings_name in embed_dataset:
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # embeddings files from a trusted source.
    with open(f'./embeddings/{embeddings_name}', 'rb') as embeddings_file:
        dataset = pickle.load(embeddings_file)

    # Stacked only to report the shape; the clustering builds its own matrix.
    all_embeddings = np.concatenate([data['embedding_vector'] for data in dataset], axis=0)
    print("Embeddings: ", embeddings_name)
    print("Shape of embeddings: ", all_embeddings.shape)

    clustered_embeddings = cluster_embeddings(dataset, NUM_CLUSTERS)
    # The with-block guarantees the pickle is flushed and the handle closed.
    with open(f"./clustered_embeddings/clustered_{embeddings_name}", "wb") as clustered_file:
        pickle.dump(clustered_embeddings, clustered_file)

In [None]:
# Preview ten random patches per cluster for every clustered embeddings file.
for embeddings_name in embed_dataset:
    # Open via a with-block so the file handle is closed after loading.
    with open(f'./clustered_embeddings/clustered_{embeddings_name}', 'rb') as clustered_file:
        dataset = pickle.load(clustered_file)
    plot_patches(num_patches_per_cluster=10, dataset=dataset, title=embeddings_name)

In [7]:
# Output directory for the selected records.
os.makedirs('./selected_embeddings', exist_ok=True)

def select_patches(dataset, label_id, embeddings_name):
    """Save the records of ``dataset`` that carry the requested label(s).

    Args:
        dataset: Iterable of dicts that each have a ``'label'`` entry.
        label_id: A single label, or a list/tuple/set/frozenset of labels
            whose records should all be kept.
        embeddings_name: Name used to build the output file
            ``./selected_embeddings/selected_<embeddings_name>``.

    Each call overwrites the output file for ``embeddings_name``. To keep
    several labels for one dataset, pass them together in a single call.
    """
    if isinstance(label_id, (list, tuple, set, frozenset)):
        wanted_labels = list(label_id)
    else:
        wanted_labels = [label_id]

    # List membership compares with ==, exactly like the single-label case.
    selected_data = [data for data in dataset if data['label'] in wanted_labels]

    # Re-check the directory so the function also works if it was removed
    # after this cell first ran.
    os.makedirs('./selected_embeddings', exist_ok=True)
    with open(f"./selected_embeddings/selected_{embeddings_name}", "wb") as selected_file:
        pickle.dump(selected_data, selected_file)

In [None]:
# Cluster labels to keep for each embeddings file; entry i belongs to
# embed_dataset[i].
label_ids = [[2,3], [0,2]]

os.makedirs('./selected_embeddings', exist_ok=True)

for i, embeddings_name in enumerate(embed_dataset):
    with open(f'./clustered_embeddings/clustered_{embeddings_name}', 'rb') as clustered_file:
        dataset = pickle.load(clustered_file)

    # Collect all wanted labels at once and write the file a single time.
    # select_patches overwrites its output file on every call, so calling it
    # once per label kept only the records of the last label in the list.
    wanted_labels = label_ids[i]
    selected_data = [data for data in dataset if data['label'] in wanted_labels]
    with open(f"./selected_embeddings/selected_{embeddings_name}", "wb") as selected_file:
        pickle.dump(selected_data, selected_file)

    print(f"Selected patches for {embeddings_name} saved successfully.")
    print("-" * 50)

In [None]:
# Preview ten random patches from every file of selected records.
selected_dataset = os.listdir('./selected_embeddings')

for selected_embeddings_name in selected_dataset:
    # Open via a with-block so the file handle is closed after loading.
    with open(f'./selected_embeddings/{selected_embeddings_name}', 'rb') as selected_file:
        dataset = pickle.load(selected_file)
    print("Length of selected dataset: ", len(dataset))
    plot_selected_patches(dataset, num_patches=10, title=selected_embeddings_name)

In [10]:
# Concatenate every file of selected records into one list and save it.
merged_dataset = []

# selected_dataset is the directory listing taken in the preview cell above.
for selected_embeddings_name in selected_dataset:
    with open(f'./selected_embeddings/{selected_embeddings_name}', 'rb') as selected_file:
        merged_dataset.extend(pickle.load(selected_file))

os.makedirs('./merged_embeddings', exist_ok=True)

# The with-block guarantees the pickle is flushed and the handle closed.
with open("./merged_embeddings/merged_dataset.pkl", "wb") as merged_file:
    pickle.dump(merged_dataset, merged_file)

In [None]:
# Report how many records the merged dataset contains.
print("Merged dataset saved successfully.")
print("Merged dataset size: ", len(merged_dataset))