In [None]:
import os
import pickle
import pandas as pd
from preprocessing import WholeSlideImage, construct_dataset
from clustering_analysis import visualise_clusters, significative_cluster_example, patch_closest_to_centroid, save_patch_example, SLIDES_PATH, ROIS_PATH, PATHOLOGIST_ANNOTATIONS_PATH, PATCHES_CLUSTERING_PATH

## Construct dataset and features from chosen slide

In [None]:
# Build the two objects every later cell relies on: `df`, used below as a table
# with `slide_id` and `cohort` columns, and `dataset`, which is indexed per slide
# to obtain that slide's features (first element of the returned item).
dataset_path = "./dataset.csv"
df, dataset = construct_dataset(dataset_path)

In [None]:
# Pick one slide and fetch its features from the dataset.
slide_name = "L210794"

# Locate the slide's row in the dataset's own slide table, then index the
# dataset with it; the features are the first element of the returned item.
slide_table = dataset.slide_data
slide_row = slide_table[slide_table.slide_id == slide_name].index[0]
features_resnet = dataset[slide_row][0]

In [None]:
# Load the pickled k-means model and the dictionary holding the "mean"/"std"
# statistics used to normalise the features.
# NOTE: pickle executes arbitrary code on load; only use files you produced yourself.
with open("./Results/clustering/kmeans_model_cohort_1.pickle", "rb") as model_file:
    kmeans = pickle.load(model_file)

with open("./Results/clustering/features_matrix_cohort_1.pickle", "rb") as stats_file:
    features_matrix_dict = pickle.load(stats_file)

In [None]:
# Standardise the slide features with the stored mean and std.
feature_mean = features_matrix_dict["mean"]
feature_std = features_matrix_dict["std"]
features_resnet_normalized = (features_resnet.numpy() - feature_mean) / feature_std

## Visualize clustering

In [None]:
# Assign each feature row of the chosen slide to a cluster, then render the
# cluster map for that slide.
labels = kmeans.predict(features_resnet_normalized)
img, mask, ROIs = visualise_clusters(
    slide_name,
    labels,
    coords_filtered=None,
    label=None,
)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# One Set1 colour per cluster, sampled evenly over the colormap; the last
# column of the returned array (alpha) is dropped to keep RGB only.
n_clusters = 8
set1_rgba = plt.cm.Set1(np.linspace(0, 1, n_clusters))
colors = set1_rgba[:, :-1]

# Displayed as the cell output so the palette can be inspected.
matplotlib.colors.ListedColormap(colors)

In [None]:
# Crop the cluster image to the first ROI and display it.
# The first element of the ROI entry is skipped; the remaining values are
# integer-divided by 8 to form the crop box.
# NOTE(review): the divisor 8 presumably matches the downscale factor of `img` — confirm.
first_roi_box = ROIs[0][1:] // 8
img.crop(first_roi_box)

## Save significant clusters as red-shaded patches on whole slide images

In [None]:
# Settings shared by the export cells below.
n_clusters = 8         # passed to save_patch_example
n_patch = 5            # passed to patch_closest_to_centroid and save_patch_example
sign_clusters = [6]    # cluster ids iterated over when exporting highlighted slides
downscale_lvl = 3      # passed to visualise_clusters and significative_cluster_example

In [None]:
# For every cluster in `sign_clusters` and every "COHORT 1" slide, compute the
# slide's cluster labels, build the mask for that cluster and save one PNG per
# ROI (the image returned by significative_cluster_example).
#
# The ROI table and the list of cohort slides do not change between iterations,
# so they are computed once here instead of once per (cluster, slide) pair.
# Each WholeSlideImage receives its own copy of the table so it sees a fresh
# DataFrame, exactly as it did when the CSV was re-read on every iteration.
rois_table = pd.read_csv(ROIS_PATH)
cohort_1_slides = df[df.cohort == "COHORT 1"].slide_id.unique()

for sign_cluster in sign_clusters:
    output_path = f"./Results/clustering/slides_cluster_significant/cluster_{sign_cluster}"
    os.makedirs(output_path, exist_ok=True)
    for slide_name in cohort_1_slides:
        WSI_object = WholeSlideImage(
            os.path.join(SLIDES_PATH, f"{slide_name}.mrxs"),
            rois_table.copy(),
            PATHOLOGIST_ANNOTATIONS_PATH,
        )
        # Look up the slide's features, standardise them and predict cluster labels.
        features_resnet = dataset[dataset.slide_data[dataset.slide_data.slide_id==slide_name].index[0]][0]
        features_resnet_normalized = (features_resnet.numpy()-features_matrix_dict["mean"])/features_matrix_dict["std"]
        labels = kmeans.predict(features_resnet_normalized)
        # Only the mask and the ROI list are needed here; the rendered image is discarded.
        _, mask, ROIs = visualise_clusters(slide_name, labels, label = sign_cluster, downscale_lvl=downscale_lvl)
        for roi_index in range(len(ROIs)):
            roi_image = significative_cluster_example(WSI_object, mask, roi_index, downscale_lvl)
            roi_image.save(os.path.join(output_path, f"{slide_name}_roi{roi_index}.png"))

## Save snapshots of the patches closest to each cluster centroid for all slides

In [None]:
# For every "COHORT 1" slide, save the cluster map and the example patches
# returned by patch_closest_to_centroid for that slide.
save_dir = "./Results/clustering/patch_closest_to_centroids"
os.makedirs(save_dir, exist_ok=True)

# The ROI table is loop-invariant: read it once instead of once per slide.
# Each WholeSlideImage receives its own copy so it sees a fresh DataFrame,
# exactly as it did when the CSV was re-read on every iteration.
rois_table = pd.read_csv(ROIS_PATH)

for slide_name in df[df.cohort == "COHORT 1"].slide_id.unique():
    WSI_object = WholeSlideImage(
        os.path.join(SLIDES_PATH, f"{slide_name}.mrxs"),
        rois_table.copy(),
        PATHOLOGIST_ANNOTATIONS_PATH,
    )
    # Look up the slide's features, standardise them and predict cluster labels.
    features_resnet = dataset[dataset.slide_data[dataset.slide_data.slide_id==slide_name].index[0]][0]
    features_resnet_normalized = (features_resnet.numpy()-features_matrix_dict["mean"])/features_matrix_dict["std"]
    labels = kmeans.predict(features_resnet_normalized)

    # Cluster map of the whole slide.
    img_clusters, _, _ = visualise_clusters(slide_name, labels)
    img_clusters.save(os.path.join(save_dir, f"{slide_name}_clusters.png"))

    # Example patches per cluster, written by save_patch_example.
    patchs_closest_imgs, labels_in_slide = patch_closest_to_centroid(labels, WSI_object, kmeans, features_resnet_normalized, n_patch=n_patch)
    save_patch_example(save_dir, slide_name, patchs_closest_imgs, labels_in_slide, n_clusters=n_clusters, n_patch=n_patch)

## Construct UMAP representations

In [None]:
import umap
import h5py
import seaborn as sns

In [None]:
# Table of nuclei features; the plotting cells below read its `dab_max_mean`
# and `n_nuclei` columns.
# Rows are later selected with the same `sampled_index` used to sample the
# feature matrix M, so row i here is assumed to describe row i of M — TODO confirm.
nuclei_csv = pd.read_csv("./Results/nuclei/nuclei_features.csv")

In [None]:
# Stack the standardised features of every "COHORT 1" slide into one matrix M.
# Within each slide the rows are grouped by predicted cluster id in ascending
# order (np.unique returns sorted ids), so the row order of M is
# slide-by-slide, cluster-by-cluster rather than the original row order.
feature_blocks = []  # list of per-(slide, cluster) arrays; kept separate from M so the cell is re-runnable
cohort_1 = df.slide_id[df.cohort == "COHORT 1"]
cluster_ids = np.unique(kmeans.labels_)  # loop-invariant, computed once

for slide_id in cohort_1:
    # Context manager closes the HDF5 file; previously one handle per slide was left open.
    with h5py.File(os.path.join(PATCHES_CLUSTERING_PATH, 'patches', slide_id + ".h5"), 'r') as h5_file:
        coords = h5_file['coords'][:]  # NOTE(review): read but not used in this cell

    # NOTE(review): the index is taken from `df`, whereas the earlier cells use
    # `dataset.slide_data`; this assumes both share the same index — confirm.
    slide_idx = df[df.slide_id == slide_id].index[0]
    feats = np.array(dataset[slide_idx][0])
    feats = (feats - features_matrix_dict["mean"]) / features_matrix_dict["std"]
    labels_clustering = kmeans.predict(feats)

    for cluster in cluster_ids:
        feature_blocks.append(feats[labels_clustering == cluster])

M = np.concatenate(feature_blocks)

In [None]:
# Draw a reproducible 5 % sample of the rows of M for the UMAP fit.
# np.random.randint samples with replacement, so `sampled_index` may contain
# repeated row indices.
np.random.seed(1)
n_rows = M.shape[0]
to_sample = int(n_rows*0.05)
sampled_index = np.random.randint(n_rows, size=to_sample)
M_sample = M[sampled_index, :]

In [None]:
# Fit a UMAP embedding on the sampled feature rows (fixed random_state for reproducibility).
mapper = umap.UMAP(
    min_dist=0.2,
    n_neighbors=20,
    random_state=1,
).fit(M_sample)

In [None]:
# UMAP embedding of the sampled rows, coloured by predicted cluster id.
# Create the output folder first, like the other export cells do; savefig
# raises if the directory is missing.
figures_dir = "./Results/figures_paper"
os.makedirs(figures_dir, exist_ok=True)

sns.set_theme()
plt.figure(figsize=(16,16))

# One Set1 colour per cluster (RGB only, last column dropped), using the
# notebook-wide `n_clusters` instead of a hard-coded 7+1.
cluster_palette = list(plt.cm.Set1(np.linspace(0, 1, n_clusters))[:, :-1])
sns.scatterplot(
    x=mapper.embedding_[:,0],
    y=mapper.embedding_[:,1],
    hue=kmeans.predict(M_sample),
    palette=cluster_palette,
)
plt.legend(title="Cluster id:", prop={'size': 15}, title_fontsize=15)
plt.axis('off')
plt.savefig(os.path.join(figures_dir, "umap_clusters.png"))

In [None]:
# UMAP embedding of the sampled rows, coloured by the `dab_max_mean` column of
# `nuclei_csv`, with a small inset colorbar.
# Create the output folder first, like the other export cells do; savefig
# raises if the directory is missing.
figures_dir = "./Results/figures_paper"
os.makedirs(figures_dir, exist_ok=True)

# Looked up once and reused for both scatter calls.
dab_values = nuclei_csv.loc[sampled_index, :].dab_max_mean

# Throw-away scatter: it only provides the mappable for the colorbar, so the
# implicit figure it is drawn on is closed immediately.
pre_plot = plt.scatter(x=mapper.embedding_[:,0],y=mapper.embedding_[:,1], c=dab_values, cmap="inferno")
plt.close()

fig = plt.figure(figsize=(16,16))
plot = sns.scatterplot(x=mapper.embedding_[:,0],y=mapper.embedding_[:,1], hue=dab_values, palette="inferno", legend=None)
plt.axis('off')

# The colorbar gets its own inset axes. `ax=` is not passed because matplotlib
# only uses it when `cax` is not set.
cb = fig.colorbar(pre_plot, cax=fig.add_axes([0.14, 0.74, 0.05, 0.15]))
cb.ax.tick_params(labelsize=15)
cb.outline.set_color('white')
cb.outline.set_linewidth(2)
cb.ax.set_title("Maximum DAB intensity (optical density):",size=15)
plt.savefig(os.path.join(figures_dir, "umap_dab_intensity.png"))

In [None]:
# UMAP embedding of the sampled rows, coloured by number of nuclei, with every
# count above 3 collapsed into a single top bin.
# Create the output folder first, like the other export cells do; savefig
# raises if the directory is missing.
figures_dir = "./Results/figures_paper"
os.makedirs(figures_dir, exist_ok=True)

fig = plt.figure(figsize=(16,16))

# Work on a copy so the original table keeps its real counts.
nuclei_csv_copy = nuclei_csv.copy()
nuclei_csv_copy.loc[nuclei_csv_copy.n_nuclei>3,"n_nuclei"] = 4

ax = sns.scatterplot(x=mapper.embedding_[:,0],y=mapper.embedding_[:,1], hue=nuclei_csv_copy.loc[sampled_index, :].n_nuclei, palette="inferno")
ax.axis("off")

# Only the handles are reused; the automatic labels are replaced by the fixed ones below.
# NOTE(review): the five hard-coded labels assume the legend has exactly five
# entries (values 0-4 all present in the sample) — confirm.
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["0","1","2","3","> or = 4"], title="Number of nuclei:", prop={'size': 15}, title_fontsize=15)
plt.savefig(os.path.join(figures_dir, "umap_density.png"))