In [None]:
from os import makedirs
from os.path import join, basename, splitext
from statistics import mode, mean
from collections import Counter
import random
import glob
import json
import pickle
import datetime
import sys
import json
import yaml

import numpy as np
from numpy.linalg import norm
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import openslide
from matplotlib import pyplot as plt
import h5py
from PIL import Image
from tqdm.notebook import tqdm

sys.path.append("../../")
from model import ccl_model

# Functions

In [None]:
def cosine_sim(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    Computed as ``dot(a, b) / (||a|| * ||b||)``. There is no guard for zero-norm
    inputs, so a zero vector produces a non-finite result.
    """
    magnitude = norm(a) * norm(b)
    return np.dot(a, b) / magnitude

In [None]:
def calculate_weights(site, factor=None):
    """Compute inverse-frequency class weights from the slide counts under PATCHES_DIR.

    Entries are counted by globbing ``PATCHES_DIR/<organ>/<diagnosis>/patches/*``.
    Each matched entry counts as one slide. Each class gets a weight proportional
    to ``1 / count``, and the weights are scaled so that they sum to ``factor``.

    Parameters
    ----------
    site : str
        ``"organ"`` weights the organs (keys of ``sites_diagnoses_dict``), counting
        every entry under any organ. Any other value must be an organ key; then only
        that organ's diagnoses (``sites_diagnoses_dict[site]``) are weighted.
    factor : float, optional
        Target sum of the weights. Defaults to 30 for ``"organ"`` and 10 otherwise.

    Returns
    -------
    dict
        Mapping class label -> weight.

    Raises
    ------
    ValueError
        If some expected class has no slides, so its inverse count is undefined.
    KeyError
        If a directory under PATCHES_DIR is not one of the expected classes.
    """
    import os

    if site == "organ":
        default_factor = 30
        pattern = join(PATCHES_DIR, "*", "*", "patches", "*")
        labels = list(sites_diagnoses_dict.keys())
        # <organ>/<diagnosis>/patches/<entry>: the organ is 4th from the end.
        label_pos = -4
    else:
        default_factor = 10
        pattern = join(PATCHES_DIR, site, "*", "patches", "*")
        labels = sites_diagnoses_dict[site]
        # <diagnosis>/patches/<entry>: the diagnosis is 3rd from the end.
        label_pos = -3
    if factor is None:
        factor = default_factor

    # Count the slides per class.
    total_slide = {label: 0 for label in labels}
    for latent_path in glob.glob(pattern):
        # Split on the OS separator rather than a hard-coded "/".
        parts = os.path.normpath(latent_path).split(os.sep)
        total_slide[parts[label_pos]] += 1

    # An empty class would otherwise crash below with an opaque ZeroDivisionError.
    empty = [label for label, count in total_slide.items() if count == 0]
    if empty:
        raise ValueError(f"No slides found under {PATCHES_DIR} for: {', '.join(empty)}")

    # Inverse-count weights, rescaled so that they sum to `factor`.
    sum_inv = sum(1. / count for count in total_slide.values())
    norm_fact = factor / sum_inv
    return {label: norm_fact * 1. / count for label, count in total_slide.items()}

In [None]:
def _mosaic_label(pool, idx, metadata, site):
    """Return the class label of mosaic row ``idx`` of ``pool``.

    The label is the organ (via ``sites_dict``) when ``site == "organ"``, otherwise
    the diagnosis code (via ``diagnoses_dict``).
    """
    file_name = pool.loc[idx, "file_name"]
    if site == "organ":
        return sites_dict[metadata.loc[file_name, "primary_site"]]
    return diagnoses_dict[metadata.loc[file_name, "project_name"]]


def _weighted_entropy(labels, sims, weight):
    """Return the entropy of a bag's label distribution.

    Each member votes for its label with weight ``(sim + 1) / 2 * weight[label]``.
    An empty bag has entropy 0.
    """
    # (sim + 1) / 2 maps a cosine similarity from [-1, 1] onto [0, 1].
    scores = [((sim + 1) / 2) * weight[label] for label, sim in zip(labels, sims)]
    total = sum(scores)
    entropy = 0
    for label in set(labels):
        p = sum(score for score, other in zip(scores, labels) if other == label) / total
        entropy -= p * np.log(p)
    return entropy


def wsi_query(mosaics, test_mosaics, metadata, site, weight, cosine_threshold, results_dir, temp_results_dir, top_k=5):
    """Retrieve database WSIs similar to each query WSI by mosaic-patch search.

    For every query WSI (each unique index label of ``test_mosaics``):

    1. Build a "bag" for each query patch: the candidate mosaics whose cosine
       similarity is >= ``cosine_threshold``, sorted by descending similarity.
    2. Compute a weighted entropy of each bag's labels and order the bags by
       descending entropy.
    3. Compute ``eta`` for each bag (the mean similarity of its top ``top_k``
       members; 0 if empty). Drop bags with ``eta < eta_threshold``, where the
       threshold is the mean eta over all patches. Empty bags are always dropped.
    4. In each remaining bag, majority-vote over the top ``top_k`` labels. Record
       the slide of the first top member with the winning label as
       ``(slide_path, its_similarity, mean_top_k_similarity)``. Each slide is kept
       only the first time it appears.

    Parameters
    ----------
    mosaics : pandas.DataFrame
        Database mosaics with "features", "file_name" and "slide_path" columns.
    test_mosaics : pandas.DataFrame
        Query mosaics indexed by query file name, with a "features" column.
    metadata : pandas.DataFrame
        Indexed by file name; provides "primary_site" and "project_name".
    site : str
        ``"organ"`` searches the whole database and labels results by organ. Any
        other value is an organ name: only that organ's mosaics are searched and
        results are labelled by diagnosis.
    weight : dict
        Label -> class weight used in the entropy (e.g. from ``calculate_weights``).
    cosine_threshold : float
        Minimum cosine similarity for a mosaic to join a bag.
    results_dir, temp_results_dir : str
        Output folders for the aggregate pickles and the per-query pickles.
        Both are created if missing.
    top_k : int, optional
        Number of top bag members used for eta and voting (previously hard-coded 5).

    Returns
    -------
    tuple
        ``(Results, Bags, Entropies, Etas)``, each keyed by query file name. For
        each query, ``Etas`` holds the eta threshold used to filter its bags.
    """
    makedirs(results_dir, exist_ok=True)
    makedirs(temp_results_dir, exist_ok=True)

    # Candidate pool: the whole database for organ search, otherwise only mosaics whose
    # slide's primary site maps to `site`. This is loop-invariant, so compute it once.
    if site == "organ":
        pool = mosaics
    else:
        in_site = metadata.loc[mosaics.loc[:, "file_name"], "primary_site"].apply(lambda x: sites_dict[x]) == site
        pool = mosaics.loc[list(in_site)].copy()
    pool_items = list(zip(pool.index, pool["features"]))

    Bags = {}       # query file name -> filtered, entropy-ordered bags
    Entropies = {}  # query file name -> {patch index: entropy}
    Etas = {}       # query file name -> eta threshold used to filter bags
    Results = {}    # query file name -> retrieved (slide_path, sim, mean_sim) tuples
    for fname in test_mosaics.index.unique():
        # Label-list lookup always yields a Series, even for a single-patch WSI.
        WSI = test_mosaics.loc[[fname], "features"].tolist()
        k = len(WSI)

        Bag, Entropy, Labels = {}, {}, {}
        for patch_idx, patch_feature in enumerate(WSI):
            # Retrieve similar mosaics, scoring each candidate only once.
            scored = ((idx, cosine_sim(patch_feature, feature)) for idx, feature in pool_items)
            bag = [item for item in scored if item[1] >= cosine_threshold]
            Bag[patch_idx] = sorted(bag, key=lambda x: x[1], reverse=True)
            Labels[patch_idx] = [_mosaic_label(pool, idx, metadata, site) for idx, _ in Bag[patch_idx]]
            Entropy[patch_idx] = _weighted_entropy(Labels[patch_idx], [sim for _, sim in Bag[patch_idx]], weight)

        # Order bags by descending entropy.
        Bag = dict(sorted(Bag.items(), key=lambda item: Entropy[item[0]], reverse=True))

        # eta = mean similarity of a bag's top members; the threshold is the mean eta.
        etas = {idx: np.mean([sim for _, sim in bag[:top_k]]) if bag else 0 for idx, bag in Bag.items()}
        eta_threshold = sum(etas[i] for i in range(k)) / k if k else 0

        # Drop low-eta bags, and empty bags (they would make mode() fail below).
        Bag = {idx: bag for idx, bag in Bag.items() if bag and not etas[idx] < eta_threshold}

        # Majority voting over the top members of each remaining bag.
        WSIRet = {}
        for idx, bag in Bag.items():
            top = bag[:top_k]
            matches = Labels[idx][:top_k]
            sims = [sim for _, sim in top]
            best = matches.index(mode(matches))
            # Slide path is the key, so each slide is reported once.
            slide_path = pool.loc[top[best][0], "slide_path"]
            if slide_path not in WSIRet:
                WSIRet[slide_path] = (slide_path, sims[best], mean(sims))
        WSIRet = list(WSIRet.values())

        stem = splitext(fname)[0]
        with open(join(temp_results_dir, f"{stem}_bag.pkl"), "wb") as f:
            pickle.dump(Bag, f)
        with open(join(temp_results_dir, f"{stem}_entropy.pkl"), "wb") as f:
            pickle.dump(Entropy, f)
        with open(join(temp_results_dir, f"{stem}_WSIRet.pkl"), "wb") as f:
            pickle.dump(WSIRet, f)

        Bags[fname] = Bag
        Entropies[fname] = Entropy
        Etas[fname] = eta_threshold
        Results[fname] = WSIRet

    for name, obj in (("Bags.pkl", Bags), ("Entropies.pkl", Entropies), ("Etas.pkl", Etas), ("Results.pkl", Results)):
        with open(join(results_dir, name), "wb") as f:
            pickle.dump(obj, f)

    return Results, Bags, Entropies, Etas

In [None]:
def show_results(test_slide, ret_final, experiment, site_retrieval):
    """Plot a query WSI thumbnail next to its retrieved WSIs and save the images.

    Parameters
    ----------
    test_slide : str
        File name of the query slide inside ``TEST_SLIDES_DIR``.
    ret_final : list of tuple
        Retrieved ``(slide_path, sim, mean_sim)`` tuples in display order, as
        produced by ``wsi_query``.
    experiment : str
        Experiment name used in the output directory.
    site_retrieval : bool
        True saves under ``wsi_site_retrieval``, False under ``wsi_vertical``.

    Returns
    -------
    list of str
        GDC portal links (``VIEW_URL`` + metadata "id") of the retrieved slides.
    """
    subdir = "wsi_site_retrieval" if site_retrieval else "wsi_vertical"
    save_dir = join("../../TEST_DATA_RESULTS", experiment, "search_result_images", subdir, splitext(test_slide)[0])
    makedirs(save_dir, exist_ok=True)

    query_slide = openslide.open_slide(join(TEST_SLIDES_DIR, test_slide))
    try:
        query_thumbnail = query_slide.get_thumbnail((300, 300))
    finally:
        query_slide.close()

    # Layout: column 1 = query, column 2 = separator line, columns 3.. = results.
    n_cols = len(ret_final) + 2
    fig = plt.figure(figsize=(20, 10))  # adjust as necessary

    plt.subplot(1, n_cols, 1)
    plt.imshow(query_thumbnail)
    plt.axis('off')  # to hide the x and y axis
    plt.title(f"Query\n{splitext(test_slide)[0]}")

    # Plot black separator line
    plt.subplot(1, n_cols, 2)
    plt.plot([0, 0], [0, 1], color='black', transform=plt.gca().transAxes, linewidth=2.0)
    plt.axis('off')

    links = []
    for rank, (path, _sim, mean_sim) in enumerate(ret_final, 1):
        file_name = basename(path)
        site = sites_dict[metadata.loc[file_name, "primary_site"]]
        diagnosis = diagnoses_dict[metadata.loc[file_name, "project_name"]]
        slide = openslide.open_slide(path)
        try:
            thumbnail = slide.get_thumbnail((300, 300))
        finally:
            slide.close()

        # Offset by 2 so results never land on the query or separator axes.
        plt.subplot(1, n_cols, rank + 2)
        plt.imshow(thumbnail)
        plt.axis('off')  # to hide the x and y axis
        plt.title(f'Result {rank}: {site} - {diagnosis}\nTop Mean Similarity: {mean_sim:.2f}')

        # returning GDC link to slides
        links.append(VIEW_URL + metadata.loc[file_name, "id"])

        # saving to file
        thumbnail.save(join(save_dir, f"result_{rank}.png"))

    plt.tight_layout()
    plt.show()

    fig.savefig(join(save_dir, "all.png"), bbox_inches='tight')
    fig.savefig(join(save_dir, "all.eps"), format='eps', bbox_inches='tight')
    # Release the figure; this function is called in a loop over query slides.
    plt.close(fig)

    return links

In [None]:
def get_features(slide, patch, checkpoint_path="../../checkpoints/best_ckpt.pth"):
    """Embed one saved query patch with the CCL encoder.

    Parameters
    ----------
    slide : str
        Query slide file name; its stem selects the sub-folder of ``TEST_PATCHES_DIR``.
    patch : str
        Patch image file name inside that sub-folder.
    checkpoint_path : str, optional
        CCL checkpoint to load (defaults to the path previously hard-coded here).

    Returns
    -------
    numpy.ndarray
        Encoder output for a batch of one image, moved to CPU
        (presumably shape ``(1, D)`` — TODO confirm against ``ccl_model``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Building the model and loading its checkpoint used to happen on every call.
    # Cache the eval-mode model per (checkpoint, device) and reuse it.
    model_cache = get_features.__dict__.setdefault("_model_cache", {})
    cache_key = (checkpoint_path, str(device))
    if cache_key not in model_cache:
        model = ccl_model(checkpoint_path=checkpoint_path).to(device)
        model.eval()
        model_cache[cache_key] = model
    ccl = model_cache[cache_key]

    # Standard ImageNet normalisation statistics.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    path = join(TEST_PATCHES_DIR, splitext(slide)[0], patch)
    # Close the file promptly, and force 3 channels because Normalize has 3 means/stds.
    with Image.open(path) as img:
        batch = preprocess(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        features = ccl(batch)

    return features.cpu().numpy()

In [None]:
def patch_query(test_slide, patch, mosaics, metadata, site, cosine_threshold):
    """Retrieve database mosaic patches similar to one query patch within an organ.

    Parameters
    ----------
    test_slide, patch : str
        Query slide and patch file names, passed to ``get_features``.
    mosaics : pandas.DataFrame
        Database mosaics with "file_name", "slide_path", "features", "patch_level",
        "patch_size", "coord1" and "coord2" columns.
    metadata : pandas.DataFrame
        Indexed by file name; provides "primary_site".
    site : str
        Organ name (a value of ``sites_dict``) restricting the search.
    cosine_threshold : float
        Minimum cosine similarity to keep a candidate.

    Returns
    -------
    tuple
        ``(Results, Bag)``, both sorted by descending similarity.
        ``Bag`` holds ``(index, sim, patch_level, patch_size, coord1, coord2)``
        tuples. ``Results`` holds the same tuples with the mosaic index replaced by
        its "slide_path".
    """
    # get_features returns a batch of one. Flatten it so each similarity is a scalar
    # rather than a 1-element array.
    patch_feature = np.ravel(get_features(test_slide, patch))

    in_site = metadata.loc[mosaics.loc[:, "file_name"], "primary_site"].apply(lambda x: sites_dict[x]) == site
    site_mosaics = mosaics.loc[list(in_site)].copy()

    bag = []
    for idx, row in site_mosaics.iterrows():
        sim = cosine_sim(patch_feature, row["features"])  # scored once per candidate
        if sim >= cosine_threshold:
            bag.append((idx, sim, row["patch_level"], row["patch_size"], row["coord1"], row["coord2"]))
    Bag = sorted(bag, key=lambda x: x[1], reverse=True)

    Results = [(site_mosaics.loc[idx, "slide_path"], *rest) for idx, *rest in Bag]
    return Results, Bag

In [None]:
def show_results_patch(test_slide, patch, results, experiment, dpi=1200):
    """Plot a query patch next to its retrieved patches and save the images.

    Parameters
    ----------
    test_slide, patch : str
        Query slide and patch file names; the patch is read from
        ``TEST_PATCHES_DIR/<slide stem>/<patch>``.
    results : list of tuple
        ``(slide_path, sim, level, patch_size, coord1, coord2)`` tuples, as
        produced by ``patch_query``.
    experiment : str
        Experiment name used in the output directory.
    dpi : int, optional
        Figure resolution (previously hard-coded to 1200).

    Returns
    -------
    list of str
        GDC portal links (``VIEW_URL`` + metadata "id") of the retrieved slides.
    """
    save_dir = join("../../TEST_DATA_RESULTS", experiment, "search_result_images", "patch_retrieval", splitext(test_slide)[0], splitext(patch)[0])
    makedirs(save_dir, exist_ok=True)

    query_path = join(TEST_PATCHES_DIR, splitext(test_slide)[0], patch)
    with Image.open(query_path) as query_image:
        query_pixels = np.array(query_image)

    # Layout: column 1 = query, column 2 = separator line, columns 3.. = results.
    n_cols = len(results) + 2
    # NOTE: 20x10 inches at the default 1200 dpi is a 24000x12000 px canvas;
    # pass a lower dpi for quick previews.
    fig = plt.figure(figsize=(20, 10), dpi=dpi)

    plt.subplot(1, n_cols, 1)
    plt.imshow(query_pixels)
    plt.axis('off')  # to hide the x and y axis
    plt.title(f"Query\n{test_slide}")

    # Plot black separator line
    plt.subplot(1, n_cols, 2)
    plt.plot([0, 0], [0, 1], color='black', transform=plt.gca().transAxes, linewidth=2.0)
    plt.axis('off')

    links = []
    for rank, (slide_path, sim, level, patch_size, coord1, coord2) in enumerate(results, 1):
        file_name = basename(slide_path)
        slide = openslide.OpenSlide(slide_path)
        try:
            # Cast to plain ints defensively: values taken from a DataFrame may be numpy scalars.
            size = int(patch_size)
            region = slide.read_region((int(coord1), int(coord2)), int(level), (size, size)).convert("RGB")
        finally:
            slide.close()

        # Offset by 2 so results never land on the query or separator axes.
        plt.subplot(1, n_cols, rank + 2)
        plt.imshow(np.array(region))
        plt.axis('off')  # to hide the x and y axis
        site = sites_dict[metadata.loc[file_name, "primary_site"]]
        diagnosis = diagnoses_dict[metadata.loc[file_name, "project_name"]]
        plt.title(f'Result {rank}: {site} - {diagnosis}\nCosine Similarity: {float(sim): .2f}')  # add caption

        # returning GDC link to slides
        links.append(VIEW_URL + metadata.loc[file_name, "id"])

        # saving to file
        region.save(join(save_dir, f"result_{rank}.png"))

    plt.tight_layout()
    plt.show()

    fig.savefig(join(save_dir, "all.png"), dpi='figure', bbox_inches='tight')
    fig.savefig(join(save_dir, "all.eps"), format='eps', dpi='figure', bbox_inches='tight')
    fig.savefig(join(save_dir, "all.pdf"), format='pdf', dpi='figure', bbox_inches='tight')
    # Release the (very large) figure before the next call.
    plt.close(fig)

    return links

# Inputs

In [None]:
# Metadata "project_name" -> TCGA diagnosis code.
diagnoses_dict = {
    "Brain Lower Grade Glioma": "LGG",
    "Glioblastoma Multiforme": "GBM",
    "Breast Invasive Carcinoma": "BRCA",
    "Lung Adenocarcinoma": "LUAD",
    "Lung Squamous Cell Carcinoma": "LUSC",
    "Colon Adenocarcinoma": "COAD",
    "Liver Hepatocellular Carcinoma": "LIHC",
    "Cholangiocarcinoma": "CHOL",
}

# Metadata "primary_site" -> organ name (also the organ directory name under PATCHES_DIR).
sites_dict = {
    "Brain": "brain",
    "Breast": "breast",
    "Bronchus and lung": "lung",
    "Colon": "colon",
    "Liver and intrahepatic bile ducts": "liver",
}

# Organ name -> diagnosis codes searched within that organ.
sites_diagnoses_dict = {
    "brain": ["LGG", "GBM"],
    "breast": ["BRCA"],
    "lung": ["LUAD", "LUSC"],
    "colon": ["COAD"],
    "liver": ["LIHC", "CHOL"],
}

# These three mappings are maintained by hand but used together: class weights are
# keyed by sites_diagnoses_dict, and they are looked up with values of sites_dict /
# diagnoses_dict. Fail fast if they drift apart.
assert set(sites_diagnoses_dict) == set(sites_dict.values()), \
    "sites_diagnoses_dict keys must match the organ names in sites_dict"
assert {code for codes in sites_diagnoses_dict.values() for code in codes} == set(diagnoses_dict.values()), \
    "sites_diagnoses_dict codes must match the diagnosis codes in diagnoses_dict"

In [None]:
# Run configuration. The feature/test-data paths are derived from `experiment`, so
# switching to another test cohort only requires changing that one name.
experiment = "GBM_MICROSCOPE_UPENN"

VIEW_URL = "https://portal.gdc.cancer.gov/files/"  # GDC file page; a slide's metadata "id" is appended
PATCHES_DIR = "../../FEATURES/DATABASE"
# NOTE: absolute, machine-specific location of the raw query slides.
TEST_SLIDES_DIR = "/raid/nejm_ai/TEST_DATA/GBM_MICROSCOPE_UPENN"
TEST_DATA_DIR = f"../../FEATURES/TEST_DATA/{experiment}"
TEST_PATCHES_DIR = f"{TEST_DATA_DIR}/visualized_patches"
RESULTS_DIR = f"{TEST_DATA_DIR}/results"
query_slides_path = f"{TEST_DATA_DIR}/query_slides.yaml"

In [None]:
# Query slides for the sub-type search. It is iterated below as (query, site) pairs,
# presumably slide file name -> organ name (TODO confirm the YAML layout).
with open(query_slides_path, 'r') as f:
    query_slides = yaml.safe_load(f)
# safe_load returns None for an empty file. Fail here with a clear message instead of
# with an AttributeError at `.items()` in the sub-type search.
assert isinstance(query_slides, dict), f"{query_slides_path} must contain a mapping of query slide -> site"

In [None]:
# Database metadata (one row per slide), database mosaics, and query-cohort mosaics.
metadata_path, mosaics_path, test_mosaics_path = (
    "../../FEATURES/DATABASE/sampled_metadata.csv",
    "../../FEATURES/DATABASE/mosaics.h5",
    "../../FEATURES/TEST_DATA/GBM_MICROSCOPE_UPENN/mosaics.json",
)

In [None]:
# Slide metadata indexed by file name, so rows can be looked up with metadata.loc[file_name, ...].
metadata = (
    pd.read_csv(metadata_path)
    .set_index('file_name')
)

In [None]:
# Database mosaics (stored under the HDF5 key 'df').
mosaics = pd.read_hdf(mosaics_path, 'df')

# Query mosaics: convert the JSON feature lists to tensors, then index the rows by
# slide file name so all patches of one query WSI can be selected together.
test_mosaics = pd.read_json(test_mosaics_path)
test_mosaics["features"] = test_mosaics["features"].apply(lambda lst: torch.tensor(lst))
# Column-wise map instead of a row-wise apply(axis=1): faster, and it returns a
# Series even when the frame is empty.
test_mosaics["file_name"] = test_mosaics["slide_path"].map(basename)
test_mosaics = test_mosaics.set_index("file_name")

# WSI Retrieval

## WSI Organ Search

In [None]:
# Organ-level search outputs, plus the per-query temporaries in its temp/ sub-folder.
results_dir, temp_results_dir = join(RESULTS_DIR, "organ"), join(RESULTS_DIR, "organ", "temp")

In [None]:
# Load the organ-level results previously written by wsi_query.
# NOTE: pickle.load can execute arbitrary code; only load files this pipeline produced.
organ_results_path = join(results_dir, "Results.pkl")
with open(organ_results_path, "rb") as file:
    Results = pickle.load(file)

In [None]:
# Show the top-N organ-level retrievals for each query slide and print their GDC links.
N = 5    # N for top-N results
for query_slide_name, retrieved in Results.items():
    top_n = retrieved[:N]
    links = show_results(query_slide_name, top_n, experiment=experiment, site_retrieval=True)
    print(links)

## WSI Sub-Type Search

In [None]:
# Sub-type search: one results folder per (query slide, site) pair from query_slides.
N = 5    # N for top-N results
for query, site in query_slides.items():
    results_dir = join(RESULTS_DIR, f"{splitext(query)[0]}_{site}")
    # NOTE: pickle.load can execute arbitrary code; only load files this pipeline produced.
    with open(join(results_dir, "Results.pkl"), "rb") as file:
        Results = pickle.load(file)

    for query_slide_name, retrieved in Results.items():
        top_n = retrieved[:N]
        links = show_results(query_slide_name, top_n, experiment=experiment, site_retrieval=False)
        print(links)