In [None]:
import os
# Restrict this process to GPU ID 2 and set TF_FORCE_GPU_ALLOW_GROWTH
# (presumably so TensorFlow allocates GPU memory on demand - confirm it is used).
os.environ.update({
    "CUDA_VISIBLE_DEVICES": '2',
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
})

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import pickle

import scipy
import scipy.sparse

from tqdm import tqdm
from sklearn import metrics
from sklearn.model_selection import StratifiedKFold
from multiprocessing.pool import ThreadPool as Pool

from openslide_handler import OpenSlideHandler

In [None]:
def get_wsi_dimensions(wsi_path):
    '''
    Return the dimensions of a whole-slide image (WSI)
    Args:
    wsi_path: Path of the WSI to open
    Returns:
    The `wsi_dimensions` attribute of an OpenSlideHandler opened on wsi_path
    '''
    return OpenSlideHandler(wsi_path).wsi_dimensions

def load_atlas_data(pkl_files_path):
    '''
    Loads the atlas data from the pkl files in a folder
    Args:
    pkl_files_path: Folder path of the pkl files
    Returns:
    atlas_data: Dictionary of atlas data. Every value is the concatenation,
        across all .pkl files, of the corresponding per-file entry
        ('embeddings' is stacked with np.vstack). 'file_names' repeats each
        file's name once per embedding, 'selected' is an all-False boolean
        array, and 'measurements' is a 2-D array with the last (ROI) value of
        every measurement dict dropped.
    Raises:
    ValueError: if the folder contains no .pkl files
    '''
    # List the folder once and keep only .pkl files. Sorting makes the row
    # order reproducible, since os.listdir returns entries in arbitrary order.
    pkl_files = sorted(
        file_name for file_name in os.listdir(pkl_files_path)
        if file_name.endswith('.pkl')
    )
    if not pkl_files:
        # np.concatenate([]) would otherwise fail with an unhelpful message
        raise ValueError(f'No .pkl files found in {pkl_files_path}')

    # Initializing the atlas data dictionary
    atlas_data = {
        'file_names': [],
        'load_configs': [],
        'image_regions': [],
        'embeddings': [],
        'masks': [],
        'measurements': [],
        'selected': [],
    }

    for pkl_file in tqdm(pkl_files, total=len(pkl_files)):
        # NOTE(security): pickle.load can execute arbitrary code - only load
        # atlas files from a trusted location.
        with open(os.path.join(pkl_files_path, pkl_file), 'rb') as f:
            patch_data = pickle.load(f)

        n_patches = len(patch_data['embeddings'])

        # Appending the data to the atlas data dictionary
        atlas_data['file_names'].append([patch_data['file_name']] * n_patches)
        atlas_data['load_configs'].append(patch_data['load_configs'])
        atlas_data['image_regions'].append(patch_data['image_regions'])
        atlas_data['embeddings'].append(patch_data['embeddings'])
        atlas_data['masks'].append(patch_data['mask'])
        atlas_data['measurements'].append(patch_data['measurements'])
        atlas_data['selected'].append(np.full(n_patches, False))

    # Concatenating the per-file data into single arrays
    for key, parts in atlas_data.items():
        atlas_data[key] = np.vstack(parts) if key == 'embeddings' else np.concatenate(parts)

    # Reforming the measurements (dicts -> rows) and removing the last column,
    # which is the ROI and is not needed. This relies on every measurement dict
    # having the same key order.
    atlas_data['measurements'] = np.array(
        [list(x.values()) for x in atlas_data['measurements']]
    )[:, :-1]

    return atlas_data

def predict(new_cases, known_cases, known_measurements, vote_method='mean', simil_method='euclidean', top_n=5):
    '''
    Function to predict the measurements of the new cases by nearest-neighbour voting
    Args:
    new_cases: Embeddings of the new cases (C x N)
    known_cases: Embeddings of the known cases (M x N)
    known_measurements: Measurements of the known cases (M x K)
    vote_method: Method used to combine the neighbours' measurements. One of
        'mean', 'median', 'max', 'weighted_mean_distance',
        'weighted_mean_cosine' or 'squared_weighted_mean'. The legacy spelling
        'squread_weighted_mean' is still accepted.
    simil_method: 'euclidean' ranks known cases by smallest Euclidean distance;
        'cosine' ranks them by largest cosine similarity
    top_n: Number of top cases to use for voting
    Returns:
    vote_results: Predicted measurements of the new cases (C x K)
    top_n_indices: Indices into known_cases of the top cases, best first (C x top_n)
    top_n_similarities: Scores of those top cases - Euclidean distances for
        'euclidean', cosine similarities for 'cosine' (C x top_n)
    Raises:
    ValueError: for an unknown vote_method / simil_method, top_n < 1, or when
        known_cases and known_measurements have different lengths
    '''
    # Imported explicitly: `import scipy` alone may not load scipy.spatial
    # on older SciPy versions.
    from scipy.spatial.distance import cdist

    known_measurements = np.asarray(known_measurements)

    # Small offset added to distances to avoid division by zero
    zero_offset = 1e-6

    def _expand_weights(x, y):
        # np.average needs weights shaped like x (C x top_n x K); repeat each
        # neighbour's score (C x top_n) across the measurement axis.
        return np.broadcast_to(y[:, :, np.newaxis], x.shape)

    # Functions for voting: x is C x top_n x K measurements, y is C x top_n scores
    vote_fns = {
        'mean': lambda x, _: np.mean(x, axis=1),
        'median': lambda x, _: np.median(x, axis=1),
        'max': lambda x, _: np.max(x, axis=1),
        # Weighted mean with weights 1 / distance
        'weighted_mean_distance': lambda x, y: np.average(
            x, axis=1, weights=1 / (_expand_weights(x, y) + zero_offset)),
        # Weighted mean with the scores themselves (cosine similarities) as weights
        'weighted_mean_cosine': lambda x, y: np.average(
            x, axis=1, weights=_expand_weights(x, y)),
        # Weighted mean with weights 1 / distance**2
        'squared_weighted_mean': lambda x, y: np.average(
            x, axis=1, weights=1 / ((_expand_weights(x, y) + zero_offset) ** 2)),
    }
    # Keep the original misspelt key working for existing callers
    vote_fns['squread_weighted_mean'] = vote_fns['squared_weighted_mean']

    if vote_method not in vote_fns:
        raise ValueError(f'Unknown vote_method {vote_method!r}; expected one of {sorted(vote_fns)}')
    if top_n < 1:
        raise ValueError(f'top_n must be >= 1, got {top_n}')
    if len(known_cases) != len(known_measurements):
        raise ValueError('known_cases and known_measurements must have the same number of rows')

    # C is the number of new cases, M the number of known cases, N the embedding dimension
    if simil_method == 'euclidean':
        scores = cdist(new_cases, known_cases, 'euclidean')  # C x M, smaller is better
        ranked = np.argsort(scores, axis=1)
    elif simil_method == 'cosine':
        # cdist returns cosine *distance*; convert it to similarity (larger is better)
        scores = 1.0 - cdist(new_cases, known_cases, 'cosine')  # C x M
        ranked = np.argsort(scores, axis=1)[:, ::-1]
    else:
        raise ValueError(f"Unknown simil_method {simil_method!r}; expected 'euclidean' or 'cosine'")

    # Slice to top_n *before* gathering measurements to avoid building a C x M x K array
    top_n_indices = ranked[:, :top_n]
    top_n_measurements = known_measurements[top_n_indices]  # C x top_n x K
    # Scores taken with the same indices, so they always line up with top_n_indices
    top_n_scores = np.take_along_axis(scores, top_n_indices, axis=1)

    return vote_fns[vote_method](top_n_measurements, top_n_scores), top_n_indices, top_n_scores

def cross_validation(known_cases, known_measurements, vote_method, simil_method, top_n, n_splits=5, random_state=None):
    '''
    Function to perform stratified k-fold cross validation testing
    Args:
    known_cases: Embeddings of the known cases
    known_measurements: Measurements of the known cases (one row per case)
    vote_method: Method to use for voting the measurements
    simil_method: Method to use for calculating the similarity between the new and known cases
    top_n: Number of top cases to use for voting
    n_splits: Number of splits to use for cross validation
    random_state: Seed forwarded to StratifiedKFold for reproducible folds
        (None keeps the previous, unseeded behaviour)
    Returns:
    y_true: True measurements of the known cases, concatenated in fold order
    y_pred: Predicted measurements of the known cases, in the same order as y_true
    '''
    known_cases = np.asarray(known_cases)
    known_measurements = np.asarray(known_measurements)

    # Creating the stratified k-fold
    stratified_k_fold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Stratification labels: the dominant measurement (index + 1) for cases with
    # any non-zero measurement, and a separate class 0 for cases with no finding.
    # A plain argmax would lump the all-zero cases in with measurement index 0.
    has_finding = known_measurements.any(axis=1)
    case_labels = np.where(has_finding, np.argmax(known_measurements, axis=1) + 1, 0)

    # Creating the lists to store the true and predicted measurements
    y_true = []
    y_pred = []

    # Performing the cross validation
    for fold_index, (train_index, test_index) in enumerate(stratified_k_fold.split(known_cases, case_labels)):
        # Getting the train and test cases and measurements
        _train_cases = known_cases[train_index]
        _train_measurements = known_measurements[train_index]
        _test_cases = known_cases[test_index]
        _test_measurements = known_measurements[test_index]

        # Printing the fold information
        print(f'Fold {fold_index+1}/{n_splits} - Train: {_train_cases.shape[0]} - Test: {_test_cases.shape[0]}')

        # Storing the ground truth and the predictions for this fold
        y_true.append(_test_measurements)
        y_pred.append(predict(_test_cases, _train_cases, _train_measurements, vote_method, simil_method, top_n)[0])

    # Concatenating the true and predicted measurements
    y_true = np.concatenate(y_true, axis=0)
    y_pred = np.concatenate(y_pred, axis=0)

    return y_true, y_pred

def leave_one_out(known_cases, known_measurements, vote_method, simil_method, top_n, n_workers=64):
    '''
    Function to perform leave one out testing
    Args:
    known_cases: Embeddings of the known cases
    known_measurements: Measurements of the known cases (one row per case)
    vote_method: Method to use for voting the measurements
    simil_method: Method to use for calculating the similarity between the new and known cases
    top_n: Number of top cases to use for voting
    n_workers: Number of threads in the pool used to run the held-out predictions
    Returns:
    y_true: True measurements of the known cases (one row per case)
    y_pred: Predicted measurements of the known cases, with the same layout as
        y_true (and as cross_validation's output)
    Raises:
    ValueError: if known_cases and known_measurements have different lengths
    '''
    known_cases = np.asarray(known_cases)
    known_measurements = np.asarray(known_measurements)

    # Getting the number of cases
    number_of_cases = len(known_cases)
    if number_of_cases != len(known_measurements):
        raise ValueError('known_cases and known_measurements must have the same number of rows')

    def _local_job(out_index):
        # Hold out one case. Slicing keeps it 2-D (1 x N) as predict expects, and
        # np.delete returns new arrays, so the inputs are never modified.
        _test_case = known_cases[out_index:out_index + 1]
        _known_cases = np.delete(known_cases, out_index, axis=0)
        _known_measurements = np.delete(known_measurements, out_index, axis=0)

        _true = known_measurements[out_index]
        _pred, _, _ = predict(_test_case, _known_cases, _known_measurements, vote_method, simil_method, top_n)

        # predict returns one row per query case; unwrap the single row so that
        # y_pred ends up N x K like y_true (not N x 1 x K).
        return _true, _pred[0]

    # Performing the leave one out testing using a thread pool
    with Pool(n_workers) as pool:
        tested = list(tqdm(
            pool.imap(_local_job, range(number_of_cases)),
            total=number_of_cases,
            desc='Testing'
        ))

    y_true = np.array([_true for _true, _ in tested])
    y_pred = np.array([_pred for _, _pred in tested])
    return y_true, y_pred

In [1]:
# Feature network, patch size and magnification identifying which atlas to load
network, patch_size, magnification = 'kimianet', 128, 40

# Annotation labels. The trailing 'roi' entry has no measurement column once the
# atlas is loaded, which is why later cells iterate over annot_labels[:-1].
annot_labels = [
    'steatosis', 'ballooning', 'mallory',
    'inflammation', 'fibrosis', 'roi',
]

# Folder holding the atlas .pkl files for this network / patch size / magnification
atlas_path = (
    '/mayo_atlas/home/m276983/atlases/liver_estimate_atlases/'
    f'{network}_{patch_size}_{magnification}X'
)

In [None]:
# Read every .pkl file under atlas_path into one dictionary of arrays
atlas_data = load_atlas_data(pkl_files_path=atlas_path)

In [None]:
# Sanity check: print the shape of every array stored in atlas_data
for name, arr in atlas_data.items():
    print(f'{name}: {arr.shape}')

In [None]:
# Selecting the cases with measurements above a threshold
selection_threshold = 0.0
sample_number = 10000


def _sample_at_most(indices, max_count):
    '''Return `indices` unchanged if it has at most `max_count` entries,
    otherwise a random subset of `max_count` of them (without replacement).'''
    if len(indices) > max_count:
        return np.random.choice(indices, max_count, replace=False)
    return indices


# resetting the selected array
atlas_data['selected'] = np.full(atlas_data['measurements'].shape[0], False)

# Selecting up to sample_number cases above the threshold for each label
for l_index, label in enumerate(annot_labels[:-1]):
    case_indices = np.flatnonzero(atlas_data['measurements'][:, l_index] > selection_threshold)
    selected_indices = _sample_at_most(case_indices, sample_number)
    atlas_data['selected'][selected_indices] = True
    print(f'selecting {len(selected_indices)} of {len(case_indices)} cases for {label}')

# adding a number of cases with no finding (all measurements zero). These are capped
# the same way as the labels, so an atlas with fewer than sample_number normal cases
# no longer makes np.random.choice(..., replace=False) raise.
normal_indices = np.flatnonzero(~atlas_data['measurements'].any(axis=1))
selected_normal_indices = _sample_at_most(normal_indices, sample_number)
print(f'selecting {len(selected_normal_indices)} of {len(normal_indices)} cases for no finding')
atlas_data['selected'][selected_normal_indices] = True

print(f'Number of selected cases in total: {atlas_data["selected"].sum()}')

In [None]:
# Sanity check: shape of every atlas_data array restricted to the selected cases
for name, arr in atlas_data.items():
    selected_shape = arr[atlas_data["selected"]].shape
    print(f'{name}: {selected_shape}')

In [None]:
# Setting the similarity and voting methods
vote_method = 'weighted_mean_distance'
similarity_method = 'euclidean'
top_n = 100
n_splits = 5
# Evaluation protocol: 'cross_validation' (stratified k-fold with n_splits folds)
# or 'leave_one_out' (one prediction per selected case)
evaluation_method = 'cross_validation'

# Embeddings and measurements of the selected cases only
selected_embeddings = atlas_data['embeddings'][atlas_data['selected']]
selected_measurements = atlas_data['measurements'][atlas_data['selected']]

if evaluation_method == 'cross_validation':
    y_true, y_pred = cross_validation(
        selected_embeddings,
        selected_measurements,
        vote_method=vote_method,
        simil_method=similarity_method,
        top_n=top_n,
        n_splits=n_splits
    )
elif evaluation_method == 'leave_one_out':
    y_true, y_pred = leave_one_out(
        selected_embeddings,
        selected_measurements,
        vote_method=vote_method,
        simil_method=similarity_method,
        top_n=top_n
    )
else:
    raise ValueError(f"Unknown evaluation_method {evaluation_method!r}; expected 'cross_validation' or 'leave_one_out'")

# Checking the shape of y_true and y_pred
print(y_true.shape, y_pred.shape)

In [None]:
def mean_absolute_error(y_true, y_pred):
    '''Column-wise (axis 0) mean of |y_true - y_pred|.'''
    abs_err = np.abs(y_true - y_pred)
    return abs_err.mean(axis=0)

def mean_squared_error(y_true, y_pred):
    '''Column-wise (axis 0) mean of (y_true - y_pred) squared.'''
    residual = y_true - y_pred
    return (residual ** 2).mean(axis=0)

def std_absolute_error(y_true, y_pred):
    '''Column-wise (axis 0) standard deviation of |y_true - y_pred|.'''
    abs_err = np.abs(y_true - y_pred)
    return abs_err.std(axis=0)

# Printing the mean absolute error, mean squared error, standard deviation of the
# absolute error and r2 score for each measured label (annot_labels minus 'roi').
# The MSE was previously computed but never reported.
for l_index, label in enumerate(annot_labels[:-1]):
    label_true = y_true[:, l_index]
    label_pred = y_pred[:, l_index]
    mae = mean_absolute_error(label_true, label_pred)
    mse = mean_squared_error(label_true, label_pred)
    stdae = std_absolute_error(label_true, label_pred)
    r_2_score = metrics.r2_score(label_true, label_pred)
    print(
        f'{label:<15}: Mean Absolute Error: {mae:.3f}, Mean Squared Error: {mse:.5f}, '
        f'Standard Deviation of Absolute Error: {stdae:.5f}, r2 score: {r_2_score:.3f}'
    )

In [None]:
# Creating some example cases
n_samples = 5
n_similar = 5  # number of retrieved neighbours shown under each example

# Restrict the atlas arrays to the selected cases once, instead of on every use
selected = atlas_data['selected']
sel_file_names = atlas_data['file_names'][selected]
sel_load_configs = atlas_data['load_configs'][selected]
sel_image_regions = atlas_data['image_regions'][selected]
sel_masks = atlas_data['masks'][selected]
sel_embeddings = atlas_data['embeddings'][selected]
sel_measurements = atlas_data['measurements'][selected]

# Selecting the indices of the example cases
sample_indices = np.random.choice(len(sel_measurements), n_samples, replace=False)

# Reference set = selected cases minus the examples, so an example cannot retrieve itself.
# top_indices returned by predict index into these reference arrays.
ref_embeddings = np.delete(sel_embeddings, sample_indices, axis=0)
ref_measurements = np.delete(sel_measurements, sample_indices, axis=0)
ref_image_regions = np.delete(sel_image_regions, sample_indices, axis=0)
ref_file_names = np.delete(sel_file_names, sample_indices, axis=0)
ref_load_configs = np.delete(sel_load_configs, sample_indices, axis=0)

# Predicting the measurements of the example cases
sample_y_pred, top_indices, top_distances = predict(
    sel_embeddings[sample_indices],
    ref_embeddings,
    ref_measurements,
    vote_method,
    similarity_method,
    top_n
)

# Selecting the true measurements of the example cases
sample_y_true = sel_measurements[sample_indices]

print(sample_y_true.shape, sample_y_pred.shape)

# The measurement rows have no 'roi' column (it is dropped when loading the atlas),
# so every measured label - including the last one, fibrosis - is annot_labels[:-1].
measured_labels = annot_labels[:-1]
# Top row needs the image plus one mask per label; bottom row one patch per neighbour
n_cols = max(len(measured_labels) + 1, n_similar)

# Plotting the example cases
for i_index, sample_index in enumerate(sample_indices):
    _file_name = sel_file_names[sample_index]
    _load_config = sel_load_configs[sample_index]
    _image_region = sel_image_regions[sample_index]
    _masks = sel_masks[sample_index]
    _measurement = sel_measurements[sample_index]
    _y_true = sample_y_true[i_index]
    _y_pred = sample_y_pred[i_index]
    _top_indices = top_indices[i_index][:n_similar]
    _top_distances = top_distances[i_index][:n_similar]

    fig, axs = plt.subplots(2, n_cols, figsize=(4 * n_cols, 10))

    # Remove grids from subplots
    for ax in axs.flatten():
        ax.axis('off')

    axs[0, 0].imshow(_image_region)
    axs[0, 0].set_title('image_region')

    # One column per measured label
    for l_index, (label, measure) in enumerate(zip(measured_labels, _measurement)):
        axs[0, l_index + 1].imshow(_masks[label], cmap='gray', vmin=0, vmax=1)
        axs[0, l_index + 1].set_title(
            f'{label}: {measure:.2f}\n'
            f'true: {_y_true[l_index]:.2f}\n'
            f'pred: {_y_pred[l_index]:.2f}'
        )

    # Bottom row: the most similar reference patches
    for col, (patch, measurements, distance, file_name, load_config) in enumerate(zip(
            ref_image_regions[_top_indices],
            ref_measurements[_top_indices],
            _top_distances,
            ref_file_names[_top_indices],
            ref_load_configs[_top_indices])):
        axs[1, col].imshow(patch)
        axs[1, col].set_title(
            f'file: {file_name}\n'
            + f'location: {load_config["location"]}\n'
            + '\n'.join(f"{label}: {measure:.2f}" for label, measure in zip(measured_labels, measurements))
            + f'\ndistance: {distance:.2f}'
        )

    # Adjust subplot spacing
    plt.subplots_adjust(wspace=0.1, hspace=0.5)

    fig.suptitle(
        f'Image region from {_file_name}.svs with load config {_load_config}')

    # saving the figure with the name reflecting patch size, network and magnification
    plt.savefig(f'sample_patch_{patch_size}_{network}_{magnification}_{i_index}.png', bbox_inches='tight')

    # Show the plot, then release the figure so repeated runs do not accumulate figures
    plt.show()
    plt.close(fig)