In [19]:
import pickle
import yaml
import os
from os import makedirs
from os.path import join, splitext, dirname

import torch
from torch.utils.data import Dataset
import numpy as np
import pandas as pd

from ssl_encoder_training import PairCenterDataset
from utils.model.base_model import AttenHashEncoder
from utils.evaluate import Evaluator
from utils.retrieval_utils import generate_incidence
from utils.evaluate import hamming_retrieval

# Functions

In [2]:
def fine_tune(raw, model_dir):
    """Encode input features with the trained attention hash encoder.

    Builds an ``AttenHashEncoder`` from the notebook-level ``feature_in``,
    ``feature_out`` and ``depth`` settings, loads ``model_best.pth`` from
    ``model_dir`` and runs a single gradient-free forward pass.

    Args:
        raw: NumPy array of input features (converted with
            ``torch.from_numpy``). Presumably shaped
            (n_slides, n_centres, feature_in) -- TODO confirm against
            ``PairCenterDataset``.
        model_dir: Directory that contains ``model_best.pth``.

    Returns:
        The ``(h, w)`` pair produced by the encoder when called with
        ``no_pooling=True, weight=True``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    encoder = AttenHashEncoder(feature_in, feature_out, depth)
    model_path = os.path.join(model_dir, 'model_best.pth')
    # map_location lets a checkpoint that was saved from a GPU be restored on
    # a CPU-only machine instead of failing during deserialisation.
    state_dict = torch.load(model_path, map_location=device)
    encoder.load_state_dict(state_dict)
    encoder = encoder.to(device)
    # Inference only: put dropout / normalisation layers (if the encoder has
    # any) into evaluation mode so the produced codes are deterministic.
    encoder.eval()

    with torch.no_grad():
        raw = torch.from_numpy(raw).to(device)
        h, w = encoder(raw, no_pooling=True, weight=True)
    return h, w

In [3]:
def hyedge_similarity(inc, alpha, beta):
    """Blend three row-normalised similarity matrices and rank each row.

    The matrices are ``generate_sh(inc)``, ``generate_sv(inc)`` and
    ``inc + inc.T``; each is passed through ``normalize`` and they are
    combined as ``sh + alpha * sv + beta * ss``. Adding them together
    requires ``inc`` to be a square NumPy array.

    Args:
        inc: Incidence matrix (NumPy array).
        alpha: Weight of the ``generate_sv`` term.
        beta: Weight of the ``inc + inc.T`` term.

    Returns:
        ``(indices, scores)``: for every row, all column indices ordered by
        descending blended score, and the matching scores (torch tensors).
    """
    horizontal = generate_sh(inc)
    vertical = generate_sv(inc)
    symmetric = inc + inc.T

    horizontal = normalize(horizontal)
    vertical = normalize(vertical)
    symmetric = normalize(symmetric)

    blended = horizontal + alpha * vertical + beta * symmetric

    # Asking top-k for as many entries as there are rows yields a full
    # descending ranking of every row.
    ranking = torch.topk(torch.from_numpy(blended), blended.shape[0], dim=1, largest=True)
    return ranking.indices, ranking.values

def generate_sh(inc):
    """Return ``inc`` multiplied by its own transpose (rows-by-rows product)."""
    transposed = inc.T
    return inc.dot(transposed)

def generate_sv(inc):
    """Return the transpose of ``inc`` multiplied by ``inc`` (columns-by-columns product)."""
    transposed = inc.T
    return transposed.dot(inc)

def normalize(mat):
    """Scale every row of ``mat`` so that it sums to one.

    Rows whose sum is exactly zero are returned unchanged rather than being
    divided by zero, which would fill them with NaN/inf and corrupt any
    ranking computed from the result.

    Args:
        mat: 2-D NumPy array.

    Returns:
        A new array with the same shape as ``mat``.
    """
    row_sums = mat.sum(axis=1)
    # Substitute 1 for zero sums so that those rows pass through untouched.
    safe_sums = np.where(row_sums == 0, 1, row_sums)
    return mat / safe_sums[:, np.newaxis]

In [4]:
class HyperGraph():
    """Collects per-slide hash codes and turns them into ranked retrieval results.

    Typical use: ``add_patches`` -> ``add_weight`` -> ``get_results`` ->
    ``save_metrics``.

    Relies on names defined elsewhere in the notebook: ``num_cluster``,
    ``QUERY_SLIDES``, ``QUERY_SUBTYPES``, ``hamming_retrieval``,
    ``generate_incidence`` and ``hyedge_similarity``.
    """

    def __init__(self, experiment):
        # Name of the experiment; used as the output sub-directory.
        self.experiment = experiment
        # Maps "<slide path>@<row index>" to one row of that slide's codes.
        self.result_dict = {}
        # Flat array with one weight per entry of result_dict (see add_weight).
        self.weight = None

    def reset(self):
        """Drop all stored codes and weights while keeping the experiment name."""
        # __init__ requires the experiment name; calling it without arguments
        # raised a TypeError.
        self.__init__(self.experiment)

    def add_patches(self, patches, paths):
        """Register the codes of several slides.

        Args:
            patches: 3-D array; ``patches[i]`` holds the code rows of the
                slide identified by ``paths[i]``.
            paths: One identifier per slide, same length as ``patches``.
        """
        assert len(patches.shape) == 3
        assert patches.shape[0] == len(paths)
        for i in range(patches.shape[0]):
            self.add_patch(patches[i], paths[i])

    def add_patch(self, patch, path):
        """Store each row of the 2-D ``patch`` under the key ``"<path>@<row>"``."""
        assert len(patch.shape) == 2
        for j in range(patch.shape[0]):
            self.result_dict[path + '@' + str(j)] = patch[j]

    def add_weight(self, weight):
        """Store ``weight`` flattened; it must hold one value per stored code row."""
        weight = weight.reshape(-1)
        assert len(self.result_dict) == weight.shape[0]
        self.weight = weight

    def get_results(self, K=10, alpha=1, beta=1):
        """Run the retrieval on the codes stored in this instance.

        Args:
            K: Forwarded to ``generate_incidence``.
            alpha: Forwarded to ``hyedge_similarity``.
            beta: Forwarded to ``hyedge_similarity``.

        Returns:
            ``(list_slide_id, slide_top_idx, simi_tops)``: the slide
            identifiers from ``generate_incidence`` and the ranked indices
            and scores from ``hyedge_similarity``.
        """
        # Use this instance's own state rather than a notebook global that
        # happens to be called `hypergraph`.
        key_list, top_idx, _ = hamming_retrieval(self.result_dict)
        # num_cluster comes from the notebook configuration cell.
        inc, list_slide_id = generate_incidence(key_list, top_idx, num_cluster, K, self.weight)
        slide_top_idx, simi_tops = hyedge_similarity(inc, alpha, beta)
        return list_slide_id, slide_top_idx, simi_tops

    def save_metrics(self, list_slide_id, slide_top_idx, simi_tops, extension="svs", site_retrieval=False):
        """Write the top retrieved database slides of every query slide to CSV.

        Slides whose identifier starts with ``"TEST_DATA"`` are treated as
        queries; all other slides are retrieval candidates. The file goes to
        ``TEST_DATA_RESULTS/<experiment>/`` and is named ``site.csv`` (top 10)
        when ``site_retrieval`` is true, otherwise ``sub_type.csv`` (top 5).

        Args:
            list_slide_id: Slide identifiers, indexable by the values in
                ``slide_top_idx``.
            slide_top_idx: Per slide, candidate indices in ranked order.
            simi_tops: Per slide, the scores matching ``slide_top_idx``.
            extension: File extension appended to query slide names before
                they are looked up in ``QUERY_SLIDES`` / ``QUERY_SUBTYPES``.
            site_retrieval: Selects the number of hits and the file name.
        """
        save_dir = join("TEST_DATA_RESULTS", self.experiment)
        makedirs(save_dir, exist_ok=True)

        k = 10 if site_retrieval else 5
        fields_per_hit = 4

        records = []
        for query_idx, slide_path in enumerate(list_slide_id):
            if not slide_path.startswith("TEST_DATA"):
                continue

            query_name = f"{slide_path.split('/')[-1]}.{extension}"
            record = [query_name, QUERY_SLIDES[query_name], QUERY_SUBTYPES[query_name]]

            hits = 0
            for r, sim in zip(slide_top_idx[query_idx], simi_tops[query_idx]):
                rr = list_slide_id[r]
                if rr.startswith("TEST_DATA"):
                    # Never report another query slide as a retrieval result.
                    continue
                parts = rr.split('/')
                # Identifier layout is .../<site>/<diagnosis>/<slide name>.
                record.extend([f"{parts[-1]}.svs", parts[-3], parts[-2], sim.item()])
                hits += 1
                if hits == k:
                    break

            # Pad short result lists so that every record matches the column
            # layout even when fewer than k candidates are available.
            record.extend([None] * (fields_per_hit * (k - hits)))
            records.append(record)

        columns = ["query_name", "query_site", "query_diagnosis"]
        for i in range(1, k + 1):
            # NOTE: the "ret_<i>_dist" column stores the similarity score.
            columns.extend([f"ret_{i}_name", f"ret_{i}_site", f"ret_{i}_diagnosis", f"ret_{i}_dist"])
        df = pd.DataFrame.from_records(records, columns=columns)

        file_name = "site.csv" if site_retrieval else "sub_type.csv"
        df.to_csv(join(save_dir, file_name), index=False)
        

# Inputs

In [5]:
# Cluster count handed to generate_incidence when the hypergraph is built.
num_cluster = 20
# Constructor arguments of AttenHashEncoder (see fine_tune).
feature_in = 512
feature_out = 1024
depth = 1

# Directory in which fine_tune looks for 'model_best.pth'.
# MODEL_DIR = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/DATABASE/model/ssl_att"
MODEL_DIR = "checkpoints"
# First two arguments of PairCenterDataset for the database cohorts.
# NOTE(review): absolute, machine-specific paths -- adjust before running elsewhere.
RESULT_DIR = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/DATABASE"
TMP = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/DATABASE/TEMP"

# Sites that are searched together during site retrieval.
DATASETS = ["brain", "breast", "colon", "liver", "lung"]

# Site -> list of cohort names passed to PairCenterDataset for that site
# (the names look like TCGA project codes -- TODO confirm).
EXPERIMENTS = {
    "brain": ["LGG", "GBM"],
    "breast": ["BRCA"],
    "lung": ["LUAD", "LUSC"],
    "colon": ["COAD"],
    "liver": ["LIHC", "CHOL"],
}

# Site Retrieval

In [6]:
# Query experiments for site retrieval and, position by position, the file
# extension of their slides (used to rebuild the query file names).
experiments = ["UCLA", "READER_STUDY", "BRCA_HER2", "BRCA_TRASTUZUMAB", "GBM_MICROSCOPE_CPTAC", "GBM_MICROSCOPE_UPENN"]
extensions = ["svs", "svs", "svs", "svs", "svs", "ndpi"]

# First two arguments of PairCenterDataset for the query experiments.
# NOTE(review): absolute, machine-specific paths.
TEST_RESULT_DIR = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES"
TEST_TMP = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/TEST_DATA/TEMP"

In [7]:
# Site retrieval: each query experiment is searched against the database
# cohorts of every site in DATASETS at once.
for experiment, extension in zip(experiments, extensions):
    # Query metadata; HyperGraph.save_metrics reads these module-level names.
    with open(f"FEATURES/TEST_DATA/{experiment}/query_slides.yaml", 'r') as f:
        QUERY_SLIDES = yaml.safe_load(f)
    with open(f"FEATURES/TEST_DATA/{experiment}/query_subtypes.yaml", 'r') as f:
        QUERY_SUBTYPES = yaml.safe_load(f)

    # Database cohorts first, the query experiment last.
    sources = [(RESULT_DIR, TMP, EXPERIMENTS[site_name]) for site_name in DATASETS]
    sources.append((TEST_RESULT_DIR, TEST_TMP, [experiment]))

    cfs, cf_paths = [], []
    for root_dir, tmp_dir, cohorts in sources:
        valid_dataset = PairCenterDataset(root_dir, tmp_dir, False, cohorts)
        # Only the first element of every centre triple is encoded.
        for center, _, center_path in valid_dataset.centers:
            cfs.append(center)
            cf_paths.append(center_path)
    cfs = np.array(cfs)

    hypergraph = HyperGraph(experiment)

    h, w = fine_tune(cfs, MODEL_DIR)
    codes = h.cpu().detach().numpy()
    weights = w.cpu().detach().numpy()
    hypergraph.add_patches(codes, cf_paths)
    hypergraph.add_weight(weights)

    list_slide_id, slide_top_idx, simi_tops = hypergraph.get_results(K=10, alpha=1, beta=1)
    hypergraph.save_metrics(
        list_slide_id, slide_top_idx, simi_tops, extension, site_retrieval=True
    )

# Subtype Retrieval

In [8]:
# Query experiments for subtype retrieval with, position by position, the
# slide file extension and the site whose database cohorts are searched
# (the site is used as a key into EXPERIMENTS).
experiments = ["BRCA_HER2", "BRCA_TRASTUZUMAB", "GBM_MICROSCOPE_CPTAC", "GBM_MICROSCOPE_UPENN"]
extensions = ["svs", "svs", "svs", "ndpi"]
sites = ["breast", "breast", "brain", "brain"]

# First two arguments of PairCenterDataset for the query experiments.
# NOTE(review): absolute, machine-specific paths.
TEST_RESULT_DIR = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES"
TEST_TMP = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/TEST_DATA/TEMP"

In [9]:
# Subtype retrieval: each query experiment is searched only against the
# database cohorts of its own site.
for experiment, extension, site in zip(experiments, extensions, sites):
    # Query metadata; HyperGraph.save_metrics reads these module-level names.
    with open(f"FEATURES/TEST_DATA/{experiment}/query_slides.yaml", 'r') as f:
        QUERY_SLIDES = yaml.safe_load(f)
    with open(f"FEATURES/TEST_DATA/{experiment}/query_subtypes.yaml", 'r') as f:
        QUERY_SUBTYPES = yaml.safe_load(f)

    # Database cohorts of the matching site first, the query experiment last.
    sources = [
        (RESULT_DIR, TMP, EXPERIMENTS[site]),
        (TEST_RESULT_DIR, TEST_TMP, [experiment]),
    ]

    cfs, cf_paths = [], []
    for root_dir, tmp_dir, cohorts in sources:
        valid_dataset = PairCenterDataset(root_dir, tmp_dir, False, cohorts)
        # Only the first element of every centre triple is encoded.
        for center, _, center_path in valid_dataset.centers:
            cfs.append(center)
            cf_paths.append(center_path)
    cfs = np.array(cfs)

    hypergraph = HyperGraph(experiment)

    h, w = fine_tune(cfs, MODEL_DIR)
    codes = h.cpu().detach().numpy()
    weights = w.cpu().detach().numpy()
    hypergraph.add_patches(codes, cf_paths)
    hypergraph.add_weight(weights)

    list_slide_id, slide_top_idx, simi_tops = hypergraph.get_results(K=10, alpha=1, beta=1)
    hypergraph.save_metrics(
        list_slide_id, slide_top_idx, simi_tops, extension, site_retrieval=False
    )

In [26]:
# Experiments that are processed one query slide at a time, with the file
# extension of their slides.
experiments = ["UCLA", "READER_STUDY"]
extensions = ["svs", "svs"]

# Unlike the earlier cells, TEST_RESULT_DIR points at the TEST_DATA directory
# itself: the per-slide loop joins "<experiment>/<slide>/clu_*.npy" onto it.
# NOTE(review): absolute, machine-specific paths; TEST_TMP is not referenced
# by the per-slide loop in the following cell.
TEST_RESULT_DIR = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/TEST_DATA"
TEST_TMP = "/home/mxn2498/projects/new_search_comp/hshr/FEATURES/TEST_DATA/TEMP"

In [27]:
# Per-slide subtype retrieval: every query slide is searched on its own
# against the database cohorts of the site recorded for it in QUERY_SLIDES.
#
# The loop deliberately iterates over (experiment, extension) only. Zipping
# in `sites` from an earlier cell made this cell depend on leftover kernel
# state, and the value was overwritten by the inner loop anyway.
for experiment, extension in zip(experiments, extensions):
    # Query metadata; HyperGraph.save_metrics reads these module-level names.
    with open(f"FEATURES/TEST_DATA/{experiment}/query_slides.yaml", 'r') as f:
        QUERY_SLIDES = yaml.safe_load(f)
    with open(f"FEATURES/TEST_DATA/{experiment}/query_subtypes.yaml", 'r') as f:
        QUERY_SUBTYPES = yaml.safe_load(f)

    # QUERY_SLIDES maps a slide file name to the site it belongs to.
    for slide, site in QUERY_SLIDES.items():
        slide_stem = splitext(slide)[0]

        cfs = []
        cf_paths = []
        valid_dataset = PairCenterDataset(RESULT_DIR, TMP, False, EXPERIMENTS[site])
        # Only the first element of every centre triple is encoded.
        for c1, _, path in valid_dataset.centers:
            cfs.append(c1)
            cf_paths.append(path)

        # Append the query slide itself. Only "clu_0.npy" is encoded; the
        # previous code also loaded "clu_1.npy" but never used it.
        c0 = np.load(join(TEST_RESULT_DIR, experiment, slide_stem, "clu_0.npy"))
        cfs.append(c0)
        cf_paths.append(f"TEST_DATA/{experiment}/{slide_stem}")
        cfs = np.array(cfs)

        # One output directory per query slide.
        new_exp = f"{experiment}_{slide_stem}"
        hypergraph = HyperGraph(new_exp)

        h, w = fine_tune(cfs, MODEL_DIR)
        hypergraph.add_patches(h.cpu().detach().numpy(), cf_paths)
        hypergraph.add_weight(w.cpu().detach().numpy())

        list_slide_id, slide_top_idx, simi_tops = hypergraph.get_results(K=10, alpha=1, beta=1)
        hypergraph.save_metrics(list_slide_id, slide_top_idx, simi_tops, extension, site_retrieval=False)