In [1]:
import argparse
import time
import os
from os import makedirs
from os.path import join, basename, splitext
import pickle
import glob
import operator
import copy
import math
from collections import Counter, defaultdict
import requests
import json
import io
import datetime
import sys
import yaml

import h5py
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import openslide
from matplotlib import pyplot as plt

# Functions

In [2]:
def unpickle_object(file_path):
    """Open ``file_path`` in binary mode and return the object pickled in it.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_path, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded


def query(results, site, topK_mMV):
    """Rank database slides for every query slide via mosaic-level voting.

    For each query slide, each mosaic's retrieval list is scored with
    ``Uncertainty_Cal``. The mosaics are then pruned with ``Clean`` and
    ``Filtered_BY_Prediction``, and the surviving retrievals are merged into a
    single ranked list.

    Args:
        results (dict): Query slide name -> dict whose ``'results'`` entry is a
            list (one item per mosaic) of retrieval lists. Each retrieval is a
            dict with ``'slide_name'``, ``'hamming_dist'`` and ``'site'`` (when
            ``site == "organ"``) or ``'diagnosis'`` (otherwise).
        site (str): ``"organ"`` for anatomic-site retrieval, otherwise the site
            whose diagnoses are retrieved. Also passed to ``calculate_weights``.
        topK_mMV (int): Maximum number of retrieved slides kept per query.

    Returns:
        list[tuple]: ``(query_slide, [(slide_name, hamming_dist, label), ...])``,
        ranked by (mosaic entropy, hamming distance). Query slides whose mosaics
        all retrieved nothing are skipped.
    """
    is_organ = site == "organ"
    label_key = 'site' if is_organ else 'diagnosis'
    # The weights depend only on `site`, but computing them globs the whole
    # patch database, so do it once. It is computed lazily so that no
    # filesystem work happens when every slide is skipped.
    weight = None

    query_results = []
    for test_slide, slide_entry in results.items():
        test_slide_result = slide_entry['results']

        # Filter out complete failure cases (every mosaic failed to retrieve a
        # patch meeting the criteria).
        if not any(len(bag) for bag in test_slide_result):
            continue

        if weight is None:
            weight = calculate_weights(site)

        bag_summary = []
        len_info = []
        label_count_summary = {}
        for idx, bag in enumerate(test_slide_result):
            ent, label_cnt, dist = Uncertainty_Cal(bag, weight, is_organ=is_organ)
            if ent is not None:
                label_count_summary[idx] = label_cnt
                bag_summary.append((idx, ent, dist, len(bag)))
                len_info.append(len(bag))

        bag_summary, hamming_thrsh = Clean(len_info, bag_summary)
        bag_removed = Filtered_BY_Prediction(bag_summary, label_count_summary)

        # Merge the surviving mosaics' retrievals (most certain mosaic first).
        # A database slide is kept only at its first accepted occurrence.
        # Zero-entropy mosaics keep all their hits; the others keep only hits
        # within the Clean hamming threshold. Note that `visited` is filled
        # before `bag_removed` is applied, exactly as before.
        ret_final = []
        visited = set()
        for bag_index, uncertainty, *_ in bag_summary:
            for r in test_slide_result[bag_index]:
                name = r['slide_name']
                if name in visited:
                    continue
                if uncertainty == 0 or r['hamming_dist'] <= hamming_thrsh:
                    ret_final.append((name, r['hamming_dist'], r[label_key], uncertainty, bag_index))
                    visited.add(name)

        ranked = sorted(ret_final, key=lambda x: (x[3], x[1]))
        top_hits = [(name, dist, label)
                    for name, dist, label, _, bag_index in ranked
                    if bag_index not in bag_removed][0:topK_mMV]

        query_results.append((test_slide, top_hits))
    return query_results

def calculate_weights(site, data_dir=None):
    """Inverse-frequency label weights derived from the patch database layout.

    Entries matching ``<data_dir>/PATCHES/.../patches/*`` are counted per label.
    For ``site == "organ"`` the label is the organ directory directly under
    PATCHES. Otherwise it is the diagnosis directory under ``PATCHES/<site>``.
    Each label's weight is proportional to ``1 / count``, scaled so that all
    weights sum to 30 (organ) or 10 (diagnosis).

    Args:
        site (str): ``"organ"``, or the site directory whose diagnoses are weighted.
        data_dir (str, optional): Database root. Defaults to the module-level
            ``DATA_DIR``.

    Returns:
        dict: label -> weight. Labels with no entries are omitted, because their
        inverse count is undefined.

    Raises:
        ValueError: If no entries are found for any label.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    patches_root = join(data_dir, "PATCHES")

    if site == "organ":
        factor = 30
        # Label = organ directory: PATCHES/<organ>/<diagnosis>/*/patches/<entry>
        label_glob = join(patches_root, "*")
        entry_glob = join(patches_root, "*", "*", "*", "patches", "*")
        label_depth = 5
    else:
        factor = 10
        # Label = diagnosis directory: PATCHES/<site>/<diagnosis>/*/patches/<entry>
        label_glob = join(patches_root, site, "*")
        entry_glob = join(patches_root, site, "*", "*", "patches", "*")
        label_depth = 4

    total_slide = {basename(p): 0 for p in glob.glob(label_glob)}
    for entry_path in glob.glob(entry_glob):
        # Split on the platform separator (after normpath) rather than "/" so
        # that paths returned by glob on Windows are handled too.
        label = os.path.normpath(entry_path).split(os.sep)[-label_depth]
        total_slide[label] += 1

    # A label with zero entries (e.g. a stray file next to the label
    # directories) would make 1/v divide by zero. Leave it out of the weighting.
    counts = {k: v for k, v in total_slide.items() if v > 0}
    if not counts:
        raise ValueError(f"no patch entries found for site {site!r} under {patches_root!r}")

    # Inverse-count weights, normalized so they sum to `factor` (30 or 10).
    sum_inv = 0
    for v in counts.values():
        sum_inv += (1. / v)
    norm_fact = factor / sum_inv
    return {k: norm_fact * 1. / v for k, v in counts.items()}

def Uncertainty_Cal(bag, weight, is_organ=False, min_count=None):
    """
    Implementation of Weighted-Uncertainty-Cal in the paper.

    Each retrieved result votes for its label (``'site'`` when ``is_organ``,
    otherwise ``'diagnosis'``) with ``weight[label] / position`` (1-based
    position in ``bag``). The votes are normalized to sum to 1, and the base-2
    Shannon entropy of that distribution is the mosaic's uncertainty.

    Input:
        bag (list): Search results for one mosaic; each is a dict with
            ``'hamming_dist'`` and ``'site'``/``'diagnosis'`` keys. Earlier
            entries get larger votes (presumably the results are in rank order).
        weight (dict): Label -> weight, e.g. from ``calculate_weights``.
        is_organ (bool): Vote on ``'site'`` instead of ``'diagnosis'``.
        min_count (float or None): If given, weighted counts below this value
            are raised to it before normalization. ``None`` (default) applies no
            clamp, which matches the historical behavior of this function.
    Output:
        ent (float): The entropy of the mosaic retrieval results
        label_count (defaultdict): Label -> normalized weighted vote (sums to 1)
        hamming_dist (list): Hamming distance of each result, in ``bag`` order
    Returns ``(None, None, None)`` for an empty bag.
    """
    if len(bag) < 1:
        return None, None, None

    label_key = 'site' if is_organ else 'diagnosis'
    labels = [bres[label_key] for bres in bag]
    hamming_dist = [bres['hamming_dist'] for bres in bag]

    # Position-discounted, class-weighted vote count per label.
    label_count = defaultdict(float)
    for position, lb in enumerate(labels, start=1):
        label_count[lb] += (1. / position) * weight[lb]

    # NOTE: this step was previously commented "If the count is less than 1,
    # round to 1", but the loop only rebound a local variable and never changed
    # label_count. The default (min_count=None) keeps that no-clamp behavior so
    # existing results are reproduced. Pass min_count=1.0 to apply the clamp.
    if min_count is not None:
        for k, v in label_count.items():
            if v < min_count:
                label_count[k] = min_count

    # Normalize to a probability distribution, then take the entropy.
    # Explicit accumulation keeps the exact float summation order.
    total = 0
    for v in label_count.values():
        total += v
    for k in label_count:
        label_count[k] = label_count[k] / total
    ent = 0
    for v in label_count.values():
        ent += (-v * np.log2(v))
    return ent, label_count, hamming_dist

def Clean(len_info, bag_summary):
    """
    Implementation of Clean in the paper
    Input:
        len_info (list): The length of retrieval results for each mosaic.
            Ignored: the lengths are recomputed from ``bag_summary[i][-1]``
            (in ``query`` both hold the same values). Kept for compatibility.
        bag_summary (list): Tuples of (positional index of mosaic, entropy,
            hamming distance list, length of retrieval results)
    Output:
        bag_summary (list): The surviving tuples, sorted by entropy (ascending).
            If the lengths take more than 3 distinct values, mosaics whose
            length is not strictly between the 5th and 95th percentiles are
            dropped. Mosaics whose mean top-5 hamming distance exceeds the
            average over mosaics are dropped.
        top5_hamming_distance (float): The mean of average hamming distance in
            top 5 retrieval results of all (length-filtered) mosaics; NaN when
            ``bag_summary`` is empty
    """
    LOW_FREQ_THRSH = 3
    LOW_PRECENT_THRSH = 5
    HIGH_PERCENT_THRSH = 95

    if not bag_summary:
        # Avoid numpy's mean-of-empty RuntimeWarning; result is NaN either way.
        return [], float('nan')

    len_info = [b[-1] for b in bag_summary]
    if len(set(len_info)) > LOW_FREQ_THRSH:
        # Percentiles are loop-invariant: compute once, not per element.
        low = np.percentile(len_info, LOW_PRECENT_THRSH)
        high = np.percentile(len_info, HIGH_PERCENT_THRSH)
        bag_summary = [b for b in bag_summary if low < b[-1] < high]

    # Remove the mosaic if its top5 mean hamming distance is bigger than average
    top5_means = [np.mean(b[2][0:5]) for b in bag_summary]
    top5_hamming_dist = np.mean(top5_means)
    kept = [b for b, m in zip(bag_summary, top5_means) if m <= top5_hamming_dist]

    # Stable sort by entropy; filtering first gives the same order as sorting first.
    return sorted(kept, key=lambda x: x[1]), top5_hamming_dist

def Filtered_BY_Prediction(bag_summary, label_count_summary):
    """
    Implementation of Filtered_By_Prediction in the paper

    The normalized label votes of the first five mosaics are summed. Candidate
    labels are then tried from most to least voted; the first candidate that is
    the majority label of at least one of those mosaics wins. Every top-5
    mosaic whose own majority label differs from the winner is marked removed.

    Input:
        bag_summary (list): The same as the output from Clean
        label_count_summary (dict): The dictionary storing the diagnosis occurrence
        of the retrieval result in each mosaic
    Output:
        bag_removed (dict): Keys are positional mosaic indices that should not be
        considered among the top5 (values are 1). Empty if ``bag_summary`` is empty.
    """
    top_bags = bag_summary[0:5]
    if not top_bags:
        # Previously raised IndexError on the empty candidate list.
        return {}

    voting_board = defaultdict(float)
    for b in top_bags:
        for k, v in label_count_summary[b[0]].items():
            voting_board[k] += v
    final_vote_candidates = sorted(voting_board.items(), key=lambda x: -x[1])

    # Each bag's majority label (first one wins on ties) does not depend on the
    # candidate being tried, so compute it once.
    majorities = [
        (b[0], max(label_count_summary[b[0]].items(), key=operator.itemgetter(1))[0])
        for b in top_bags
    ]

    bag_removed = {}
    for final_vote, _ in final_vote_candidates:
        bag_removed = {idx: 1 for idx, majority in majorities if majority != final_vote}
        if len(bag_removed) != len(top_bags):
            return bag_removed
    # Not reached for non-empty input: every bag's majority label is itself a
    # candidate, and trying it keeps that bag.
    return bag_removed

# Inputs

In [3]:
# Full project name, as found in the metadata "project_name" column ->
# short diagnosis code written into the result CSVs.
diagnoses_dict = {
    "Brain Lower Grade Glioma": "LGG",
    "Glioblastoma Multiforme": "GBM",
    "Breast Invasive Carcinoma": "BRCA",
    "Lung Adenocarcinoma": "LUAD",
    "Lung Squamous Cell Carcinoma": "LUSC",
    "Colon Adenocarcinoma": "COAD",
    "Liver Hepatocellular Carcinoma": "LIHC",
    "Cholangiocarcinoma": "CHOL",
}

# Primary-site name, as found in the metadata "primary_site" column ->
# lowercase site key written into the result CSVs.
sites_dict = {
    "Brain": "brain",
    "Breast": "breast",
    "Bronchus and lung": "lung",
    "Colon": "colon",
    "Liver and intrahepatic bile ducts": "liver",
}

In [4]:
DATA_DIR = "FEATURES/DATABASE/"

# NOTE(review): DATASETS is not referenced in this notebook and lists "breast"
# twice; confirm whether it is still needed.
DATASETS = ["brain", "breast", "breast", "colon", "liver", "lung"]
# EXPERIMENTS[i] is paired with EXTENSIONS[i] (the query-slide file extension).
EXPERIMENTS = ["UCLA", "READER_STUDY", "BRCA_HER2", "BRCA_TRASTUZUMAB", "ABLATION_BRCA_TRASTUZUMAB", "GBM_MICROSCOPE_CPTAC", "GBM_MICROSCOPE_UPENN"]
EXTENSIONS = ["svs", "svs", "svs", "svs", "svs", "svs", "ndpi"]
# The retrieval cells iterate zip(EXPERIMENTS, EXTENSIONS), which would silently
# drop experiments if the lists drift apart. Fail loudly instead.
if len(EXPERIMENTS) != len(EXTENSIONS):
    raise ValueError(
        f"EXPERIMENTS ({len(EXPERIMENTS)}) and EXTENSIONS ({len(EXTENSIONS)}) must have the same length"
    )

# Database slide metadata, indexed by file name (e.g. "<slide>.svs").
metadata = pd.read_csv(join(DATA_DIR, "sampled_metadata.csv"))
metadata = metadata.set_index('file_name')

# Site Retrieval

In [6]:
# Site (organ) retrieval: rank database slides for every query slide and write
# one row per query to TEST_DATA_RESULTS/<experiment>/site.csv.
dataset = "organ"
k = 10  # number of retrieved slides kept per query

# Row layout: query metadata, then k groups of (name, site, diagnosis, distance).
columns = ["query_name", "query_site", "query_diagnosis"]
for i in range(1, k + 1):
    columns.extend([f"ret_{i}_name", f"ret_{i}_site", f"ret_{i}_diagnosis", f"ret_{i}_dist"])

for experiment, extension in zip(EXPERIMENTS, EXTENSIONS):
    with open(f"FEATURES/TEST_DATA/{experiment}/query_slides.yaml", 'r') as f:
        QUERY_SLIDES = yaml.safe_load(f)
    with open(f"FEATURES/TEST_DATA/{experiment}/query_subtypes.yaml", 'r') as f:
        QUERY_SUBTYPES = yaml.safe_load(f)

    save_dir = join("TEST_DATA_RESULTS", experiment)
    makedirs(save_dir, exist_ok=True)

    results_path = join("FEATURES/TEST_DATA/", experiment, "Results", dataset, "results.pkl")
    results = unpickle_object(results_path)
    query_results = query(results, dataset, k)

    records = []
    for query_name, result in query_results:
        # Query slides are keyed by file name including their extension.
        query_name = f"{query_name}.{extension}"
        record = [query_name, QUERY_SLIDES[query_name], QUERY_SUBTYPES[query_name]]

        for r in result:
            # Database slides are looked up in `metadata` as "<name>.svs".
            result_name = f"{r[0]}.svs"
            record.extend([
                result_name,
                sites_dict[metadata.loc[result_name, "primary_site"]],
                diagnoses_dict[metadata.loc[result_name, "project_name"]],
                r[1],  # hamming distance
            ])

        # Pad queries with fewer than k retrievals so every row matches `columns`.
        record.extend([None] * (len(columns) - len(record)))
        records.append(record)

    df = pd.DataFrame.from_records(records, columns=columns)
    df.to_csv(join(save_dir, "site.csv"), index=False)

# Sub Type Retrieval

In [7]:
# Subtype (diagnosis) retrieval: for every query slide, search within its site
# and write one row per query to TEST_DATA_RESULTS/<experiment>/sub_type.csv.
k = 5  # number of retrieved slides kept per query

# Row layout: query metadata, then k groups of (name, site, diagnosis, distance).
columns = ["query_name", "query_site", "query_diagnosis"]
for i in range(1, k + 1):
    columns.extend([f"ret_{i}_name", f"ret_{i}_site", f"ret_{i}_diagnosis", f"ret_{i}_dist"])

for experiment, extension in zip(EXPERIMENTS, EXTENSIONS):
    with open(f"FEATURES/TEST_DATA/{experiment}/query_slides.yaml", 'r') as f:
        QUERY_SLIDES = yaml.safe_load(f)
    with open(f"FEATURES/TEST_DATA/{experiment}/query_subtypes.yaml", 'r') as f:
        QUERY_SUBTYPES = yaml.safe_load(f)

    save_dir = join("TEST_DATA_RESULTS", experiment)
    makedirs(save_dir, exist_ok=True)

    # query() ranks every query slide in a site's results file at once, so load
    # and rank each site only once instead of once per query slide.
    # site -> {query file name (with extension) -> ranked results}
    ranked_by_site = {}

    records = []
    for slide, site in QUERY_SLIDES.items():
        print(f"processing {site}_{slide} started ...")
        if site not in ranked_by_site:
            results_path = join("FEATURES/TEST_DATA/", experiment, "Results", site, "results.pkl")
            results = unpickle_object(results_path)
            ranked_by_site[site] = {
                f"{name}.{extension}": result for name, result in query(results, site, k)
            }

        # Keep the query's identity even when query() skipped it (no mosaic
        # retrieved anything); previously such rows came out entirely None.
        record = [slide, site, QUERY_SUBTYPES[slide]]
        for r in ranked_by_site[site].get(slide, []):
            # Database slides are looked up in `metadata` as "<name>.svs".
            result_name = f"{r[0]}.svs"
            record.extend([
                result_name,
                sites_dict[metadata.loc[result_name, "primary_site"]],
                diagnoses_dict[metadata.loc[result_name, "project_name"]],
                r[1],  # hamming distance
            ])

        # Pad queries with fewer than k retrievals so every row matches `columns`.
        record.extend([None] * (len(columns) - len(record)))
        records.append(record)

    df = pd.DataFrame.from_records(records, columns=columns)
    df.to_csv(join(save_dir, "sub_type.csv"), index=False)

processing brain_7316UP-19.ndpi started ...
processing brain_7316UP-99.ndpi started ...
processing brain_7316UP-213.ndpi started ...
processing brain_7316UP-1302.ndpi started ...
processing brain_7316UP-426.ndpi started ...
processing brain_7316UP-435.ndpi started ...
processing brain_7316UP-743.ndpi started ...
processing brain_7316UP-1206.ndpi started ...
processing brain_7316UP-613.ndpi started ...
processing brain_7316UP-2726.ndpi started ...
processing brain_7316UP-1405.ndpi started ...
processing brain_7316UP-393.ndpi started ...
processing brain_7316UP-3301.ndpi started ...
processing brain_7316UP-2058.ndpi started ...
processing brain_7316UP-498.ndpi started ...
processing brain_7316UP-961.ndpi started ...
processing brain_7316UP-895.ndpi started ...
processing brain_7316UP-1883.ndpi started ...
processing brain_7316UP-302.ndpi started ...
processing brain_7316UP-1135.ndpi started ...
processing brain_7316UP-1108.ndpi started ...
processing brain_7316UP-243.ndpi started ...
pro