In [None]:
import os
from os import listdir, makedirs
from os.path import join, dirname, basename, splitext
import datetime
import ast
import yaml

import openslide
import pickle
import numpy as np
import pandas as pd
import bitarray
import glob
from bitarray import util as butil
from skimage import io
from matplotlib import pyplot as plt
from PIL import Image

# Functions and Classes

In [None]:
class BoB:
    """Bag of Barcodes: a set of binary barcodes plus the labels of their source.

    Barcodes are stored as ``bitarray`` objects so that pairwise distances can
    be computed with ``bitarray.util.count_xor``.
    """

    def __init__(self, barcodes, file_name, site, diagnosis):
        """Build a bag from an iterable of barcodes.

        Args:
            barcodes: iterable whose items expose ``tolist()`` (e.g. rows of a
                NumPy array, or existing ``bitarray`` objects).
            file_name: identifier of the slide or patch the barcodes describe.
            site: site label stored on the bag (callers may pass ``None``).
            diagnosis: diagnosis label stored on the bag (callers may pass ``None``).
        """
        # Going through tolist() lets the same constructor accept NumPy rows and
        # bitarrays alike, which select_subset relies on.
        self.barcodes = [bitarray.bitarray(b.tolist()) for b in barcodes]
        self.file_name = file_name
        self.site = site
        self.diagnosis = diagnosis

    def select_subset(self, n=3):
        """Return a new ``BoB`` holding at most ``n`` randomly chosen barcodes.

        The file name, site and diagnosis of this bag are carried over to the
        subset. Selection uses NumPy's global random state.
        """
        idx = np.arange(len(self.barcodes))
        np.random.shuffle(idx)
        idx = idx[:n]
        # The constructor requires the labels as well; omitting them (as the
        # previous version did) raised TypeError on every call.
        return BoB(
            barcodes=[self.barcodes[i] for i in idx],
            file_name=self.file_name,
            site=self.site,
            diagnosis=self.diagnosis,
        )

    def distance(self, bob):
        """Median-of-minimum distance from this bag to ``bob``.

        For every barcode in this bag, the smallest XOR bit count against any
        barcode of ``bob`` is taken; the median of those minima is returned.

        Raises:
            ValueError: if ``bob`` has no barcodes (``np.min`` of an empty list).

        Note:
            If this bag itself is empty, ``np.median`` of an empty list yields NaN.
        """
        total_dist = []
        for feat in self.barcodes:
            distances = [butil.count_xor(feat, b) for b in bob.barcodes]
            total_dist.append(np.min(distances))
        return np.median(total_dist)

In [None]:
def load_mosaic(mosaic_path):
    """Read the mosaic table stored under key ``'df'`` and return its rows as dicts.

    The ``loc``, ``wsi_loc`` and ``rgb_histogram`` columns are parsed with
    ``ast.literal_eval``; ``rgb_histogram`` is additionally wrapped in a NumPy array.

    Args:
        mosaic_path: path of the HDF5 file to read.

    Returns:
        list of dicts, one per row (``DataFrame.to_dict('records')``).
    """
    mosaic_df = pd.read_hdf(mosaic_path, 'df')

    # Columns that only need to be turned back into Python literals.
    for column in ('loc', 'wsi_loc'):
        mosaic_df[column] = mosaic_df[column].apply(ast.literal_eval)

    # The histogram is parsed the same way, then converted to an ndarray.
    mosaic_df['rgb_histogram'] = mosaic_df['rgb_histogram'].apply(
        lambda text: np.array(ast.literal_eval(text))
    )

    return mosaic_df.to_dict('records')

## WSI Retrieval

In [None]:
def prepare_features_horizontal(data, test_data, extension):
    """Group per-patch features by slide file name, for database and query sets.

    Keys of both inputs are cut at the first ``"_patch"``; database slides get a
    ``.svs`` suffix, query slides get ``.<extension>``.

    Args:
        data: mapping of database patch key -> feature.
        test_data: mapping of query patch key -> feature.
        extension: file extension (without dot) used for query slide names.

    Returns:
        ``(features_dict, test_features_dict)``, each mapping a slide file name
        to the list of its features in input order.
    """
    def _group_by_slide(patch_features, suffix):
        grouped = {}
        for patch_key, vector in patch_features.items():
            slide_file = patch_key.split("_patch")[0] + suffix
            grouped.setdefault(slide_file, []).append(vector)
        return grouped

    features_dict = _group_by_slide(data, ".svs")
    test_features_dict = _group_by_slide(test_data, f".{extension}")
    return features_dict, test_features_dict

In [None]:
def prepare_features_vertical(data, test_data, test_slide_name, site, extension):
    """Group features by slide, keeping one site of the database and one query slide.

    Database entries are kept only when the slide's ``primary_site`` (looked up
    in the global ``metadata`` and translated through ``sites_dict``) equals
    ``site``. Query entries are kept only when their slide file name equals
    ``test_slide_name``.

    Args:
        data: mapping of database patch key -> feature.
        test_data: mapping of query patch key -> feature.
        test_slide_name: file name of the query slide to keep.
        site: short site name the database slides must match.
        extension: file extension (without dot) used for query slide names.

    Returns:
        ``(features_dict, test_features_dict)`` mapping slide file name -> list
        of features.
    """
    features_dict = {}
    for patch_key, vector in data.items():
        slide_file = patch_key.split("_patch")[0] + ".svs"
        slide_site = sites_dict[metadata.loc[slide_file, "primary_site"]]
        if slide_site != site:
            continue
        features_dict.setdefault(slide_file, []).append(vector)

    test_features_dict = {}
    for patch_key, vector in test_data.items():
        slide_file = f'{patch_key.split("_patch")[0]}.{extension}'
        if slide_file != test_slide_name:
            continue
        test_features_dict.setdefault(slide_file, []).append(vector)

    return features_dict, test_features_dict

In [None]:
def prepare_bobs(features_dict, test_features_dict, experiment):
    """Turn grouped slide features into ``BoB`` objects.

    Each slide's features are stacked into an array and binarised: a bit is 1
    where the value decreases from one position to the next along axis 1.

    Database bags take their site/diagnosis from the global ``metadata``
    (translated via ``sites_dict`` / ``diagnoses_dict``), or ``None`` for both
    when the slide is not in ``metadata.index``. Query bags use ``experiment``
    for both labels.

    Returns:
        ``(BoBs, test_BoBs)`` as two lists.
    """
    def _binarize(feature_queue):
        stacked = np.array(feature_queue)
        return (np.diff(stacked, n=1, axis=1) < 0) * 1

    BoBs = []
    for slide_file, feature_queue in features_dict.items():
        barcodes = _binarize(feature_queue)
        site, diagnosis = None, None
        if slide_file in metadata.index:
            site = sites_dict[metadata.loc[slide_file, "primary_site"]]
            diagnosis = diagnoses_dict[metadata.loc[slide_file, 'project_name']]
        BoBs.append(BoB(barcodes, slide_file, site, diagnosis))

    test_BoBs = [
        BoB(_binarize(feature_queue), slide_file, experiment, experiment)
        for slide_file, feature_queue in test_features_dict.items()
    ]

    return BoBs, test_BoBs

In [None]:
def horizontal_search(data, test_data, extension, experiment, k):
    """Rank every database slide against every query slide and keep the top ``k``.

    Args:
        data: mapping of database patch key -> feature.
        test_data: mapping of query patch key -> feature.
        extension: file extension (without dot) of the query slides.
        experiment: label given to the query bags.
        k: number of closest database bags to keep per query.

    Returns:
        list of ``[query_bob, [(distance, bob), ...]]`` entries, one per query,
        with the inner list sorted by ascending distance (ties keep database
        order). Queries are omitted when there are no database bags at all.
    """
    features_dict, test_features_dict = prepare_features_horizontal(data, test_data, extension)
    database_bobs, query_bobs = prepare_bobs(features_dict, test_features_dict, experiment)

    results = []
    for query_bob in query_bobs:
        scored = [(query_bob.distance(candidate), candidate) for candidate in database_bobs]
        if not scored:
            # Nothing to compare against: this query produces no entry.
            continue
        ranked = sorted(scored, key=lambda pair: pair[0])
        results.append([query_bob, ranked[:k]])

    return results

In [None]:
def vertical_search(data, test_data, test_slide_name, extension, experiment, site, k):
    """Rank the database slides of one site against a single query slide.

    Args:
        data: mapping of database patch key -> feature.
        test_data: mapping of query patch key -> feature.
        test_slide_name: file name of the query slide.
        extension: file extension (without dot) of the query slides.
        experiment: label given to the query bag.
        site: short site name the database slides must belong to.
        k: number of closest database bags to keep.

    Returns:
        ``[query_bob, [(distance, bob), ...]]`` with the inner list sorted by
        ascending distance and truncated to ``k`` entries.

    Raises:
        IndexError: if no features were found for ``test_slide_name``.
    """
    features_dict, test_features_dict = prepare_features_vertical(
        data, test_data, test_slide_name, site, extension
    )
    database_bobs, query_bobs = prepare_bobs(features_dict, test_features_dict, experiment)

    test_bob = query_bobs[0]
    ranked = sorted(
        [(test_bob.distance(candidate), candidate) for candidate in database_bobs],
        key=lambda pair: pair[0],
    )

    return [test_bob, ranked[:k]]

In [None]:
def show_results(query, results, experiment, site_retrieval):
    """Plot a query slide next to its retrieved slides and save the images.

    The figure has one column for the query thumbnail, one column holding a
    vertical separator line, and one column per result. Each result thumbnail
    is saved as ``result_<rank>.png``; the whole figure is saved as ``all.png``
    and ``all.eps`` in the per-query output directory.

    Relies on the notebook globals ``metadata``, ``sites_dict``,
    ``diagnoses_dict`` and ``VIEW_URL``.

    Args:
        query: object with a ``file_name`` attribute naming the query slide.
        results: sized sequence of ``(distance, bob)`` pairs, best match first.
        experiment: experiment name, used in the input and output paths.
        site_retrieval: selects the ``wsi_site_retrieval`` output folder when
            true, ``wsi_vertical`` otherwise.

    Returns:
        list of ``VIEW_URL + id`` strings, one per result, in result order.
    """
    query_path = join("/raid/nejm_ai/TEST_DATA/", experiment, query.file_name)

    retrieval_kind = "wsi_site_retrieval" if site_retrieval else "wsi_vertical"
    save_dir = join(
        "../../TEST_DATA_RESULTS", experiment, "search_result_images",
        retrieval_kind, splitext(query.file_name)[0],
    )
    makedirs(save_dir, exist_ok=True)

    # Always release the slide handle, even if thumbnail generation fails.
    query_slide = openslide.open_slide(query_path)
    try:
        query_thumbnail = query_slide.get_thumbnail((300, 300))
    finally:
        query_slide.close()

    n_cols = len(results) + 2  # query + separator + one column per result
    fig = plt.figure(figsize=(20, 10))  # adjust as necessary
    try:
        ax = fig.add_subplot(1, n_cols, 1)
        ax.imshow(query_thumbnail)
        ax.axis('off')  # to hide the x and y axis
        ax.set_title(f"Query\n{query.file_name}")

        # Separator column: a black line along the left edge of slot 2.
        ax = fig.add_subplot(1, n_cols, 2)
        ax.plot([0, 0], [0, 1], color='black', transform=ax.transAxes, linewidth=2.0)
        ax.axis('off')

        links = []
        for rank, (distance, bob) in enumerate(results, 1):
            site = sites_dict[metadata.loc[bob.file_name, "primary_site"]]
            diagnosis = diagnoses_dict[metadata.loc[bob.file_name, "project_name"]]

            path = join("/raid/nejm_ai/DATABASE", site, diagnosis, bob.file_name)
            slide = openslide.open_slide(path)
            try:
                thumbnail = slide.get_thumbnail((300, 300))
            finally:
                slide.close()

            # Results start at slot 3; the previous version started at slot 2,
            # which collided with the separator and left the last slot empty.
            ax = fig.add_subplot(1, n_cols, rank + 2)
            ax.imshow(thumbnail)
            ax.axis('off')  # to hide the x and y axis
            ax.set_title(f'Result {rank}: {site} - {diagnosis}\nDistance: {distance}')  # add caption

            # returning GDC link to slides
            links.append(VIEW_URL + metadata.loc[bob.file_name, "id"])

            # saving to file
            thumbnail.save(join(save_dir, f"result_{rank}.png"))

        fig.tight_layout()
        plt.show()

        fig.savefig(join(save_dir, "all.png"), bbox_inches='tight')
        fig.savefig(join(save_dir, "all.eps"), format='eps', bbox_inches='tight')
    finally:
        # The function is called once per query; close the figure so that
        # figures do not pile up in memory.
        plt.close(fig)

    return links

## Patch Retrieval

In [None]:
def prepare_features_patch(data, test_data, test_patch_path, site, extension):
    """Collect per-patch features for one site of the database and one query patch.

    Every key is split at ``"_patch"`` into a slide base name and a patch name
    (``"patch"`` + the part after the last ``"_patch"``). Output keys are
    ``join(base_name, patch_name)``.

    Database patches are kept when their slide's ``primary_site`` (global
    ``metadata`` translated through ``sites_dict``) equals ``site``. A query
    patch is kept when its slide file name and patch name equal the last two
    components of ``test_patch_path``.

    Returns:
        ``(features_dict, test_features_dict)`` mapping patch key -> list of features.
    """
    features_dict = {}
    for fname, feature in data.items():
        pieces = fname.split("_patch")
        base_name = pieces[0]
        patch_name = "patch" + pieces[-1]
        slide_site = sites_dict[metadata.loc[base_name + ".svs", "primary_site"]]
        if slide_site == site:
            features_dict.setdefault(join(base_name, patch_name), []).append(feature)

    test_features_dict = {}
    for fname, feature in test_data.items():
        pieces = fname.split("_patch")
        base_name = pieces[0]
        patch_name = "patch" + pieces[-1]
        if f"{base_name}.{extension}" != basename(dirname(test_patch_path)):
            continue
        if patch_name != basename(test_patch_path):
            continue
        test_features_dict.setdefault(join(base_name, patch_name), []).append(feature)

    return features_dict, test_features_dict

In [None]:
def prepare_bobs_patch(features_dict, test_features_dict, experiment):
    """Turn grouped patch features into ``BoB`` objects.

    Features are stacked and binarised exactly as for whole slides: a bit is 1
    where the value decreases from one position to the next along axis 1.

    For database bags, the slide name is the part of the key before the first
    ``"/"`` plus ``".svs"``; site and diagnosis come from the global
    ``metadata`` (via ``sites_dict`` / ``diagnoses_dict``) or are ``None`` when
    the slide is not in ``metadata.index``. Query bags use ``experiment`` for
    both labels.

    Returns:
        ``(BoBs, test_BoBs)`` as two lists.
    """
    def _binarize(feature_queue):
        stacked = np.array(feature_queue)
        return (np.diff(stacked, n=1, axis=1) < 0) * 1

    BoBs = []
    for patch_key, feature_queue in features_dict.items():
        barcodes = _binarize(feature_queue)
        slide_name = patch_key.partition("/")[0] + ".svs"
        site, diagnosis = None, None
        if slide_name in metadata.index:
            site = sites_dict[metadata.loc[slide_name, "primary_site"]]
            diagnosis = diagnoses_dict[metadata.loc[slide_name, 'project_name']]
        BoBs.append(BoB(barcodes, patch_key, site, diagnosis))

    test_BoBs = [
        BoB(_binarize(feature_queue), patch_key, experiment, experiment)
        for patch_key, feature_queue in test_features_dict.items()
    ]

    return BoBs, test_BoBs

In [None]:
def vertical_search_patch(data, test_data, test_patch_path, site, experiment, extension, k):
    """Rank the database patches of one site against a single query patch.

    Args:
        data: mapping of database patch key -> feature.
        test_data: mapping of query patch key -> feature.
        test_patch_path: path of the query patch; its last two components
            select the query slide and patch.
        site: short site name the database patches must belong to.
        experiment: label given to the query bag.
        extension: file extension (without dot) of the query slides.
        k: number of closest database bags to keep.

    Returns:
        ``[query_bob, [(distance, bob), ...]]`` with the inner list sorted by
        ascending distance and truncated to ``k`` entries.

    Raises:
        IndexError: if no features were found for the query patch.
    """
    features_dict, test_features_dict = prepare_features_patch(
        data, test_data, test_patch_path, site, extension
    )
    database_bobs, query_bobs = prepare_bobs_patch(features_dict, test_features_dict, experiment)

    test_bob = query_bobs[0]
    ranked = sorted(
        [(test_bob.distance(candidate), candidate) for candidate in database_bobs],
        key=lambda pair: pair[0],
    )

    return [test_bob, ranked[:k]]

In [None]:
def show_results_patch(query, results, experiment):
    """Plot a query patch next to its retrieved patches and save the images.

    The figure has one column for the query patch, one column holding a
    vertical separator line, and one column per result. Each retrieved patch is
    re-read from its source slide, saved as ``result_<rank>.png``, and the whole
    figure is saved as ``all.png``, ``all.eps`` and ``all.pdf``.

    Relies on the notebook globals ``TEST_QUERY_PATCHES_DIR``, ``PATCHES_DIR``,
    ``metadata`` and ``VIEW_URL``, and on ``load_mosaic``.

    Args:
        query: object whose ``file_name`` is the query patch path relative to
            ``TEST_QUERY_PATCHES_DIR``.
        results: sized sequence of ``(distance, bob)`` pairs, best match first;
            each ``bob.file_name`` has the form ``"<slide>/<patch file>"``.
        experiment: experiment name, used in the output path.

    Returns:
        list of ``VIEW_URL + id`` strings, one per result, in result order.

    Raises:
        IndexError: if a result's patch location is not present in its mosaic.
    """
    query_path = join(TEST_QUERY_PATCHES_DIR, query.file_name)

    save_dir = join("../../TEST_DATA_RESULTS", experiment, "search_result_images", "patch_retrieval", splitext(query.file_name)[0])
    makedirs(save_dir, exist_ok=True)

    # Read the pixels eagerly so the image file handle can be released.
    with Image.open(query_path) as query_image:
        query_pixels = np.array(query_image)

    n_cols = len(results) + 2  # query + separator + one column per result
    # NOTE(review): dpi=1200 on a 20x10 inch figure is a very large canvas.
    fig = plt.figure(figsize=(20, 10), dpi=1200)  # adjust as necessary
    try:
        ax = fig.add_subplot(1, n_cols, 1)
        ax.imshow(query_pixels)
        ax.axis('off')  # to hide the x and y axis
        ax.set_title(f"Query\n{query.file_name}")

        # Separator column: a black line along the left edge of slot 2.
        ax = fig.add_subplot(1, n_cols, 2)
        ax.plot([0, 0], [0, 1], color='black', transform=ax.transAxes, linewidth=2.0)
        ax.axis('off')

        links = []
        for rank, (distance, bob) in enumerate(results, 1):
            site = bob.site
            diagnosis = bob.diagnosis
            slide_stem, patch_str = bob.file_name.split("/")
            # Patch file names look like "<prefix>_<x>_<y>.<ext>".
            x, y = splitext(patch_str)[0].split("_")[1:]
            mosaic = load_mosaic(join(PATCHES_DIR, site, diagnosis, slide_stem, "mosaic.h5"))
            target_loc = [int(x), int(y)]
            patch = [entry for entry in mosaic if entry['loc'] == target_loc][0]

            slide_path = join("/raid/nejm_ai/DATABASE", site, diagnosis, slide_stem + ".svs")
            slide = openslide.open_slide(slide_path)
            try:
                try:
                    objective_power = int(slide.properties['openslide.objective-power'])
                except KeyError:
                    # Slides without the property are treated as 20x.
                    objective_power = 20
                h, w = patch['wsi_loc']
                # Level-0 side length covering the same area as 1000 px at 20x.
                patch_size_20x = int((objective_power/20.)*1000)
                patch_region = slide.read_region((w, h), 0, (patch_size_20x, patch_size_20x))
            finally:
                # Previously the slide handle was never closed.
                slide.close()
            if objective_power == 40:
                new_size = (patch_size_20x // 2, patch_size_20x // 2)
                patch_region = patch_region.resize(new_size)

            # Results start at slot 3; the previous version started at slot 2,
            # which collided with the separator and left the last slot empty.
            ax = fig.add_subplot(1, n_cols, rank + 2)
            ax.imshow(patch_region.convert('RGB'))
            ax.axis('off')  # to hide the x and y axis
            ax.set_title(f'Result {rank}: {site} - {diagnosis}\nDistance: {distance}')  # add caption

            # returning GDC link to slides
            slide_id = metadata.loc[slide_stem + ".svs", "id"]
            links.append(VIEW_URL + slide_id)

            # saving to file
            patch_region.save(join(save_dir, f"result_{rank}.png"))

        fig.tight_layout()
        plt.show()

        fig.savefig(join(save_dir, "all.png"),  dpi='figure', bbox_inches='tight')
        fig.savefig(join(save_dir, "all.eps"), format='eps', dpi='figure', bbox_inches='tight')
        fig.savefig(join(save_dir, "all.pdf"), format='pdf', dpi='figure', bbox_inches='tight')
    finally:
        # Called once per query patch; close the figure so figures do not
        # accumulate in memory.
        plt.close(fig)

    return links

## Metrics

In [None]:
def get_raw_metrics(results, experiment, site=False):
    """Write one CSV row per query listing its retrieved items.

    Each row starts with the query name, its site (global ``query_slides``) and
    its true diagnosis (global ``true_diagnosis``), followed by
    ``name, site, diagnosis, dist`` for every returned item.

    The table has at least 10 result slots when ``site`` is true and at least 5
    otherwise (the historical layouts). It grows if a query returned more, and
    rows with fewer results are padded with empty cells. The previous version
    failed whenever the number of results differed from 10 or 5.

    Args:
        results: iterable of ``(query, returned)`` where ``returned`` is a
            sequence of ``(distance, bob)`` pairs.
        experiment: experiment name, used in the output path.
        site: write ``site.csv`` when true, ``sub_type.csv`` otherwise.

    Returns:
        None. The CSV is written under
        ``../../TEST_DATA_RESULTS/<experiment>/raw_metrics``.
    """
    save_dir = join("../../TEST_DATA_RESULTS", experiment, "raw_metrics")
    makedirs(save_dir, exist_ok=True)

    n_query_fields = 3
    n_result_fields = 4

    records = []
    for query, returned in results:
        row = [query.file_name, query_slides[query.file_name], true_diagnosis[query.file_name]]
        for distance, bob in returned:
            row.extend([bob.file_name, bob.site, bob.diagnosis, distance])
        records.append(row)

    # Keep the historical minimum width so existing consumers still find
    # ret_1..ret_10 / ret_1..ret_5, but widen when more results were returned.
    default_slots = 10 if site else 5
    widest = max(((len(row) - n_query_fields) // n_result_fields for row in records), default=0)
    n_slots = max(default_slots, widest)

    row_width = n_query_fields + n_result_fields * n_slots
    for row in records:
        row.extend([None] * (row_width - len(row)))

    columns = ["query_name", "query_site", "query_diagnosis"]
    for i in range(1, n_slots + 1):
        columns.extend([f"ret_{i}_name", f"ret_{i}_site", f"ret_{i}_diagnosis", f"ret_{i}_dist"])

    out_name = "site.csv" if site else "sub_type.csv"
    df = pd.DataFrame.from_records(records, columns=columns)
    df.to_csv(join(save_dir, out_name), index=False)

# Inputs

In [None]:
# Database features; the helper functions iterate over `data.items()` and split
# each key at "_patch", so keys are presumably "<slide>_patch..." strings.
# NOTE(review): pickle.load can execute arbitrary code — only load feature
# files that come from a trusted source.
with open('../../FEATURES/DATABASE/features.pkl', 'rb') as f:
    data = pickle.load(f)

# Query features for the experiment under test.
# NOTE(review): the experiment name is hard-coded in this path, while the
# `experiment` variable is only defined in a later cell; keep the two in sync.
with open('../../FEATURES/TEST_DATA/BRCA_HER2/features.pkl', 'rb') as f:
    test_data = pickle.load(f)

# Global variables

In [None]:
# Name of the experiment under test; it selects the TEST_DATA sub-folders below
# and the output folders written by the result/metric functions.
experiment = "BRCA_HER2"
# File extension (without dot) of the query slides.
extension = "svs"

# Prefix for the per-slide links returned by the show_results* functions.
VIEW_URL = "https://portal.gdc.cancer.gov/files/"
# Root of the per-slide database patch folders (each holds a mosaic.h5).
PATCHES_DIR = "../../FEATURES/DATABASE/PATCHES"
# Experiment-specific folders are derived from `experiment` so that changing
# the experiment name cannot leave these paths pointing at another data set.
TEST_PATCHES_DIR = f"../../FEATURES/TEST_DATA/{experiment}/PATCHES/"
TEST_QUERY_PATCHES_DIR = f"../../FEATURES/TEST_DATA/{experiment}/QUERY_PATCHES/"

In [None]:
# Slide metadata table, indexed by slide file name so that rows can be looked
# up with metadata.loc[<file name>, <column>].
metadata = pd.read_csv("../../FEATURES/DATABASE/sampled_metadata.csv").set_index('file_name')

In [None]:
# Maps the metadata "project_name" value to the short diagnosis code used in
# database directory names and in the result tables.
diagnoses_dict = {
    "Brain Lower Grade Glioma": "LGG",
    "Glioblastoma Multiforme": "GBM",
    "Breast Invasive Carcinoma": "BRCA",
    "Lung Adenocarcinoma": "LUAD",
    "Lung Squamous Cell Carcinoma": "LUSC",
    "Colon Adenocarcinoma": "COAD",
    "Liver Hepatocellular Carcinoma": "LIHC",
    "Cholangiocarcinoma": "CHOL",
}

# Maps the metadata "primary_site" value to the short site name used in
# database directory names and for site filtering.
sites_dict = {
    "Brain": "brain",
    "Breast": "breast",
    "Bronchus and lung": "lung",
    "Colon": "colon",
    "Liver and intrahepatic bile ducts": "liver",
}

# Diagnosis codes grouped by short site name.
# NOTE(review): not referenced by the cells visible here — confirm it is still needed.
sites_diagnoses_dict = {
    "brain": ["LGG", "GBM"],
    "breast": ["BRCA"],
    "lung": ["LUAD", "LUSC"],
    "colon": ["COAD"],
    "liver": ["LIHC", "CHOL"],
}

In [None]:
# query_slides is iterated below and in later cells as
# {query slide file name: site}.
query_slides_path = f"../../FEATURES/TEST_DATA/{experiment}/query_slides.yaml"
with open(query_slides_path, 'r') as f:
    query_slides = yaml.safe_load(f)

# Every query slide of this experiment is labelled with the same diagnosis.
true_diagnosis = {slide_name: "BRCA" for slide_name in query_slides.keys()}

# WSI Retrieval
## Site Retrieval

In [None]:
# Site retrieval: every query slide is compared against the whole database and
# the k closest slides are kept, where "closest" is the median over the query's
# barcodes of the minimum distance to the candidate's barcodes (BoB.distance).
k = 5
site_retrieval = True

results = horizontal_search(data, test_data, extension, experiment, k)
for (key, top_k_results) in results:
    # key is the query BoB; show_results plots and saves the matches and
    # returns one link per retrieved slide.
    links = show_results(key, top_k_results, experiment, site_retrieval)
    print(links)

## Sub-Type Retrieval

In [None]:
# Sub-type retrieval: each query slide is compared only against database slides
# of its own site (the value stored for it in query_slides).
k = 5
site_retrieval = False

for query, site in query_slides.items():
    # vertical_search returns [query BoB, top-k (distance, BoB) pairs].
    key, top_k_results = vertical_search(data, test_data, query, extension, experiment, site, k)
    links = show_results(key, top_k_results, experiment, site_retrieval)
    print(links)

# Metrics

In [None]:
# Raw metrics for site retrieval. k = 10 matches the ten result slots that
# get_raw_metrics lays out when site=True; the table is written to site.csv.
k = 10
site_retrieval = True
results = horizontal_search(data, test_data, extension, experiment, k)

get_raw_metrics(results, experiment, site=site_retrieval)

In [None]:
# Raw metrics for sub-type retrieval. k = 5 matches the five result slots that
# get_raw_metrics lays out when site=False; the table is written to sub_type.csv.
k = 5
site_retrieval = False

# Collect one [query BoB, top-k matches] entry per query slide, searching only
# within the query's own site.
results = []
for query, site in query_slides.items():
    key, top_k_results = vertical_search(data, test_data, query, extension, experiment, site, k)
    results.append([key, top_k_results])

get_raw_metrics(results, experiment, site=site_retrieval)