In [None]:
import logging

# Log lines show time, logger name, calling function and message; only WARNING and above.
FORMAT = '%(asctime)s %(name)s %(funcName)s %(message)s'
log_level = logging.WARNING
logging.basicConfig(level=log_level, format=FORMAT, datefmt='%H:%M:%S')

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys
import h5py
import numpy as np
import subprocess
import pandas as pd
import matplotlib.pyplot as plt
import csv

# Make the local Bonsai checkout importable. This is a hard-coded cluster path; adjust it
# on other machines. The append must happen before the bonsai_scout import below.
bnsi_path = '/scicore/home/nimwegen/degroo0000/Bonsai-data-representation'
sys.path.append(bnsi_path)
from bonsai_scout.bonsai_scout_helpers import Bonvis_figure, Bonvis_settings, Bonvis_metadata

#### Dataset selection: IDs, folders and descriptions

In [None]:
# Select the datasets to analyse and derive their input and Bonsai-result folders.
# Set SINGLE_DATASET = True to work on one dataset only.
SINGLE_DATASET = False

# Root folder holding the input data (UMI counts, annotation, ground truth) of every dataset.
DATASETS_ROOT = '/scicore/home/nimwegen/GROUP/Projects/bonsai_runs/paper_figures_datasets'
# Root folder of the Bonsai pipeline output that is visualized here.
BONSAI_OUTPUT_ROOT = '/scicore/home/nimwegen/degroo0000/Bonsai-data-representation/slurm_runs_pipeline/output'

if not SINGLE_DATASET:
    dataset_ids = ['simulated_datasets/simulated_binary_10_gens_samplingnoise_seed_2462',
                   'simulated_datasets/simulated_pseudobulk_based_ncells_1024_seed_1231',
                   'simulated_datasets/simulated_binary_10_gens_samplingnoise_unbalanced_seed_2462',
                   'simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_seed_2462',
                   'simulated_datasets/simulated_binary_10_gens_samplingnoise_realcovariance_seed_2462',
                   'simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_unbalanced_seed_2462',
                   'simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_unbalanced_realcovariance_seed_2462']
    # Human-readable description of each dataset, in the same order as dataset_ids.
    dataset_descr_lst = ['Perfect binary', 'Pseudobulk', 'Unbalanced (Unb)', 'Random branch lengths (Rbl)',
                         'Real covariance (Reco)', 'Rbl + Unb', 'Rbl + Unb + Reco']
else:
    dataset_ids = ['simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_unbalanced_realcovariance_seed_2462']
    # dataset_ids = ['simulated_datasets/simulated_pseudobulk_based_ncells_1024_seed_1231']
    dataset_descr_lst = ['Rbl + Unb + Reco']
    # dataset_descr_lst = ['Pseudobulk']

input_data_folders = [os.path.join(DATASETS_ROOT, dataset_id, 'UMI_counts') for dataset_id in dataset_ids]
# Bonsai results are taken from the pipeline output. The alternative location of the
# pre-computed runs is os.path.join(DATASETS_ROOT, dataset_id, 'bonsai').
bonsai_results_folders = [os.path.join(BONSAI_OUTPUT_ROOT, dataset_id, 'bonsai') for dataset_id in dataset_ids]

# Every per-dataset list below is indexed in parallel, so the lengths must agree.
assert len(dataset_descr_lst) == len(dataset_ids), \
    'dataset_descr_lst ({}) and dataset_ids ({}) differ in length'.format(len(dataset_descr_lst), len(dataset_ids))

## Create Bonsai visualizations of the datasets

In [None]:
%%capture  
# %%capture above hides the tree figures drawn while loading.
# For every dataset, load the Bonsai visualization metadata, settings (JSON) and data (HDF)
# and build a Bonvis_figure. The HDF handles are kept open in bonvis_data_hdf_lst,
# presumably because Bonvis_figure reads from them after construction.
bonvis_metadata_lst, bonvis_settings_lst, bonvis_data_hdf_lst, bonvis_fig_lst = [], [], [], []
for bonsai_results in bonsai_results_folders:
    data_path = os.path.join(bonsai_results, 'bonsai_vis_data.hdf')
    settings_path = os.path.join(bonsai_results, 'bonsai_vis_settings.json')

    bonvis_metadata = Bonvis_metadata(data_path)
    bonvis_settings = Bonvis_settings(load_settings_path=settings_path)
    bonvis_data_hdf = h5py.File(data_path, 'r')
    bonvis_fig = Bonvis_figure(bonvis_data_hdf, bonvis_metadata,
                               bonvis_data_path=data_path, bonvis_settings=bonvis_settings)
    bonvis_fig.create_figure(figsize=(6, 6))

    bonvis_metadata_lst.append(bonvis_metadata)
    bonvis_settings_lst.append(bonvis_settings)
    bonvis_data_hdf_lst.append(bonvis_data_hdf)
    bonvis_fig_lst.append(bonvis_fig)

In [None]:
%%capture  
# Redraw every tree in the 'ly_eq_angle' layout, which gives a circular-looking tree.
for bonvis_fig in bonvis_fig_lst:
    bonvis_fig.update_figure(ly_type='ly_eq_angle')
    bonvis_fig.create_figure(figsize=(6, 6))

In [None]:
# Get a list of possible celltype-annotations
# for ind_dataset, bonvis_fig in enumerate(bonvis_fig_lst):
#     celltype_info = bonvis_fig.bonvis_settings.celltype_info
#     print(dataset_descr_lst[ind_dataset])
#     print(celltype_info.annot_alts)
#     print('\n')

In [None]:
# Pick the celltype annotation used to colour nodes: the pseudobulk dataset has its own
# annotation, all other datasets use 'Celltype3'.
node_style_lst = ['Pseudobulk' if dataset_descr == "Pseudobulk" else 'Celltype3'
                  for dataset_descr in dataset_descr_lst]

In [None]:
# Show every tree in the 'ly_eq_angle' layout with its chosen celltype annotation.
# The pseudobulk tree is drawn in flat geometry, all others in hyperbolic geometry.
for ind_dataset, bonvis_fig in enumerate(bonvis_fig_lst):
    is_pseudobulk = dataset_descr_lst[ind_dataset] == 'Pseudobulk'
    bonvis_fig.update_figure(ly_type='ly_eq_angle',
                             geometry='flat' if is_pseudobulk else 'hyperbolic',
                             node_style=node_style_lst[ind_dataset],
                             zoom=1);  # the original used zoom 1 for both geometries

## Create Bonsai visualizations of the ground-truth trees

In [None]:
# Datasets for which a ground-truth tree is available. The pseudobulk dataset has no entry
# here. Every description must also occur in dataset_descr_lst, because later cells
# look it up with dataset_descr_lst.index(...).
if not SINGLE_DATASET:
    dataset_ids_gt = ['simulated_datasets/simulated_binary_10_gens_samplingnoise_seed_2462',
                      'simulated_datasets/simulated_binary_10_gens_samplingnoise_unbalanced_seed_2462',
                      'simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_seed_2462',
                      'simulated_datasets/simulated_binary_10_gens_samplingnoise_realcovariance_seed_2462',
                      'simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_unbalanced_seed_2462',
                      'simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_unbalanced_realcovariance_seed_2462']

    dataset_descr_lst_gt = ['Perfect binary', 'Unbalanced (Unb)', 'Random branch lengths (Rbl)',
                            'Real covariance (Reco)', 'Rbl + Unb', 'Rbl + Unb + Reco']
else:
    dataset_ids_gt = ['simulated_datasets/simulated_binary_10_gens_samplingnoise_randomtimes_unbalanced_realcovariance_seed_2462']
    # dataset_ids_gt = []

    dataset_descr_lst_gt = ['Rbl + Unb + Reco']
    # dataset_descr_lst_gt = []

gt_datasets_root = '/scicore/home/nimwegen/GROUP/Projects/bonsai_runs/paper_figures_datasets'
input_data_folders_gt = [os.path.join(gt_datasets_root, dataset_id, 'UMI_counts') for dataset_id in dataset_ids_gt]
# The visualization data of each ground-truth tree lives in UMI_counts/true_tree.
bonsai_results_folders_gt = [os.path.join(gt_datasets_root, dataset_id, 'UMI_counts', 'true_tree')
                             for dataset_id in dataset_ids_gt]

# Fail early, with a clear message, instead of at a later dataset_descr_lst.index() call.
assert len(dataset_descr_lst_gt) == len(dataset_ids_gt), \
    'dataset_descr_lst_gt and dataset_ids_gt differ in length'
unmatched_gt = [descr for descr in dataset_descr_lst_gt if descr not in dataset_descr_lst]
assert not unmatched_gt, 'Ground-truth datasets missing from dataset_descr_lst: {}'.format(unmatched_gt)



In [None]:
%%capture  
# %%capture above hides the ground-truth tree figures drawn while loading.
# Same loading procedure as for the inferred trees, but reading from each dataset's
# 'true_tree' folder. The HDF handles are kept open in bonvis_data_hdf_lst_gt.
bonvis_metadata_lst_gt, bonvis_settings_lst_gt, bonvis_data_hdf_lst_gt, bonvis_fig_lst_gt = [], [], [], []
for bonsai_results in bonsai_results_folders_gt:
    data_path = os.path.join(bonsai_results, 'bonsai_vis_data.hdf')
    print(data_path)
    settings_path = os.path.join(bonsai_results, 'bonsai_vis_settings.json')

    bonvis_metadata = Bonvis_metadata(data_path)
    bonvis_settings = Bonvis_settings(load_settings_path=settings_path)
    bonvis_data_hdf = h5py.File(data_path, 'r')
    bonvis_fig = Bonvis_figure(bonvis_data_hdf, bonvis_metadata,
                               bonvis_data_path=data_path, bonvis_settings=bonvis_settings)
    bonvis_fig.create_figure(figsize=(6, 6))

    bonvis_metadata_lst_gt.append(bonvis_metadata)
    bonvis_settings_lst_gt.append(bonvis_settings)
    bonvis_data_hdf_lst_gt.append(bonvis_data_hdf)
    bonvis_fig_lst_gt.append(bonvis_fig)

In [None]:
# Celltype annotation used to colour the ground-truth trees: 'Pseudobulk' for the
# pseudobulk dataset, 'Celltype3' otherwise.
node_style_lst_gt = ['Pseudobulk' if dataset_descr == "Pseudobulk" else 'Celltype3'
                     for dataset_descr in dataset_descr_lst_gt]

In [None]:
# Show every ground-truth tree in the 'ly_eq_angle' layout, hyperbolic geometry, with its
# celltype annotation.
for ind_dataset_gt, bonvis_fig in enumerate(bonvis_fig_lst_gt):
    # The lookup raises ValueError if this ground-truth dataset is missing from dataset_descr_lst.
    ind_dataset = dataset_descr_lst.index(dataset_descr_lst_gt[ind_dataset_gt])
    bonvis_fig.update_figure(ly_type='ly_eq_angle', geometry='hyperbolic',
                             node_style=node_style_lst_gt[ind_dataset_gt]);

## Compute PCA, UMAP and PHATE embeddings of all datasets

In [None]:
from paper_figure_scripts_and_notebooks.simulating_datasets.analyzing_simulated_datasets.knn_recall_helpers import do_pca, do_logp1, fit_umap, fit_phate, get_pdists_on_tree, Dataset, \
    compare_nearest_neighbours_to_truth, compare_pdists_to_truth, get_pdists_on_tree, compare_pdists_to_truth_per_cell, compare_nearest_neighbours_to_truth
from bonsai.bonsai_helpers import find_latest_tree_folder
from scipy.spatial import distance
from bonsai.bonsai_dataprocessing import SCData, get_bonsai_euclidean_distances
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from matplotlib import colormaps
# Only show matplotlib's own log messages at WARNING level and above.
plt.set_loglevel(level='warning')

In [None]:
# Load each dataset's UMI-count table (Gene_table.txt). The header row holds the cell IDs
# and the first column is used as the row index (presumably gene names).
umi_counts_lst = []
cell_ids_lst = []
for ind_dataset, input_data_folder in enumerate(input_data_folders):
    print(ind_dataset)
    gene_table_path = os.path.join(input_data_folder, 'Gene_table.txt')
    gene_table = pd.read_csv(gene_table_path, sep='\t', header=0, index_col=0)
    cell_ids_lst.append(list(gene_table.columns))
    umi_counts_lst.append(gene_table.values)
    del gene_table  # only the raw arrays are kept

In [None]:
# Log-transform the UMI counts. Project them with PCA onto 2 components (for direct
# visualization) and onto 10 components (as input for UMAP). do_pca returns one
# projection per requested number of components.
PCA_COMPS = [2, 10]
logp1_lst = []
pca_projected_lst = []
for ind_dataset in range(len(input_data_folders)):
    print(ind_dataset)
    logp1 = do_logp1(umi_counts_lst[ind_dataset])
    logp1_lst.append(logp1)  # kept for PHATE below
    pca_projected_lst.append(do_pca(logp1, n_comps_list=PCA_COMPS))

In [None]:
# Fit a 2D UMAP on every PCA projection except the 2-component one, which is only used for
# direct plotting. Produces one dict per dataset, keyed by the number of PCA components
# the UMAP was fitted on.
umap_projected_lst = []
for ind_dataset, input_data_folder in enumerate(input_data_folders):
    print(ind_dataset)
    pca_projected = pca_projected_lst[ind_dataset]
    umap_projected = {}
    for n_comps, pca_proj in pca_projected.items():
        if n_comps == 2:
            continue
        # NOTE(review): random_state=None makes the UMAP embedding differ between runs.
        umap_projected[n_comps] = fit_umap(pca_proj, random_state=None, n_neighbors=15, min_dist=0.1,
                                           n_components=2,
                                           metric='euclidean',
                                           make_plot=False, title='')
    # Append exactly once per dataset. This used to sit inside the inner loop, which appended
    # the same dict once per non-2 PCA size and misaligned the list with the datasets.
    umap_projected_lst.append(umap_projected)

assert len(umap_projected_lst) == len(input_data_folders)

In [None]:
# Fit PHATE directly on the full log-transformed data (not on a PCA projection).
# Stored as {'all': embedding} per dataset.
phate_projected_lst = []
for ind_dataset in range(len(input_data_folders)):
    print(ind_dataset)
    phate_projected_lst.append({'all': fit_phate(logp1_lst[ind_dataset])})

## Compute pairwise distances (ground truth and all tools)

In [None]:
# Ground-truth pairwise distances: squared Euclidean distances between the columns (cells)
# of delta_true.txt, divided by its number of rows (the number of dimensions). The result
# is in condensed (pdist) form.
true_dists_lst = []
for ind_dataset, input_data_folder in enumerate(input_data_folders):
    print(ind_dataset)
    delta_true_path = os.path.join(input_data_folder, 'delta_true.txt')
    delta_gc_true = pd.read_csv(delta_true_path, sep='\t', header=None, index_col=None).values
    num_dims = delta_gc_true.shape[0]
    true_dists_lst.append(distance.pdist(delta_gc_true.T, metric='sqeuclidean') / num_dims)

In [None]:
# Optionally load Sanity's precomputed cell-cell distances to compare against the other tools.
INCLUDE_SANITY = True
if INCLUDE_SANITY:
    sanity_lst = []
    for ind_dataset, (input_data_folder, dataset_id) in enumerate(zip(input_data_folders, dataset_ids)):
        print(ind_dataset)
        sanity_path = os.path.join('/scicore/home/nimwegen/GROUP/Projects/bonsai_runs/paper_figures_datasets', dataset_id, 'Sanity')
        dist_file = os.path.join(sanity_path, 'cell_cell_distance_with_errorbar_avzscore_gt_1.txt')
        # The table is flattened into a single 1-D array of distances.
        sanity_dists = pd.read_csv(dist_file, header=None, index_col=None, sep='\t').values.flatten()
        # Alternative (disabled): recompute from Sanity's delta.txt as
        # distance.pdist(delta.T, metric='sqeuclidean') / delta.shape[0], like the ground truth.
        sanity_lst.append(sanity_dists)

In [None]:
# Pairwise distances between cells along the Bonsai tree. The tree is read from the
# 'tree.nwk' file in the folder returned by find_latest_tree_folder.
bonsai_dists_lst = []
for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
    print(ind_dataset)
    latest_tree_dir = os.path.join(bonsai_results, find_latest_tree_folder(bonsai_results))
    newick_path = os.path.join(latest_tree_dir, 'tree.nwk')
    bonsai_dists_lst.append(get_pdists_on_tree(newick_path, cell_ids_lst[ind_dataset]))

In [None]:
# Squared Euclidean distances between cells in the 2-component PCA embedding, divided by 2
# (the number of embedding dimensions). This mirrors the per-dimension normalization
# applied to the ground-truth distances.
pca_dists_lst = []
for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
    print(ind_dataset)
    # Direct lookup: a missing 2-component projection now raises KeyError. Previously the
    # dataset was silently skipped, misaligning pca_dists_lst with the other per-dataset lists.
    pca_proj = pca_projected_lst[ind_dataset][2]
    pca_dists = distance.pdist(pca_proj.T, metric='sqeuclidean') / 2
    pca_dists_lst.append(pca_dists)

In [None]:
# Squared Euclidean distances between cells in the 2D UMAP embedding, divided by 2 (the
# number of embedding dimensions).
umap_dists_lst = []
for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
    print(ind_dataset)
    # Exactly one UMAP embedding per dataset is expected. The unpacking raises ValueError
    # otherwise; previously several entries were appended for one dataset, which shifted
    # the indices of all later datasets.
    (umap_proj,) = umap_projected_lst[ind_dataset].values()
    umap_dists = distance.pdist(umap_proj.T, metric='sqeuclidean') / 2
    umap_dists_lst.append(umap_dists)

In [None]:
# Squared Euclidean distances between cells in the PHATE embedding, divided by 2 as for
# PCA and UMAP (presumably the embedding is 2-D; it is plotted from rows 0 and 1).
phate_dists_lst = []
for ind_dataset in range(len(bonsai_results_folders)):
    print(ind_dataset)
    phate_proj = phate_projected_lst[ind_dataset]['all']
    phate_dists_lst.append(distance.pdist(phate_proj.T, metric='sqeuclidean') / 2)

In [None]:
# Cache of per-dataset lists of Dataset objects (one per tool). It is filled by the
# correlation cell and reused there when RECALCULATE is False.
tool_objcts_lst = list()

## Plot everything in one overview figure per dataset (without the box plots)

In [None]:
# One overview figure per dataset, with a 2 x 5 grid:
#   row 0: ground truth, Bonsai, PCA, UMAP, PHATE visualizations
#   row 1: per-cell Pearson R histograms for Sanity, Bonsai, PCA, UMAP, PHATE
figs_lst = []
axs_lst = []
ncols = 5
for ind_dataset, dataset in enumerate(dataset_descr_lst):
    fig, axs = plt.subplots(nrows=2, ncols=ncols, figsize=(12, 6))
    # Every histogram in row 1 (columns 1 to ncols-1) shares the y-axis of column 0.
    for hist_ax in axs[1, 1:]:
        hist_ax.sharey(axs[1, 0])
    figs_lst.append(fig)
    axs_lst.append(axs)
    for ax in axs.flat:
        ax.set_box_aspect(1)
    fig.tight_layout()
    # top=0.88 keeps free space above the panels, where the suptitle is placed.
    fig.subplots_adjust(left=0, right=1.0, bottom=0.12, top=0.88)
    fig.suptitle(dataset, fontsize=20)

In [None]:
# Draw each dataset's Bonsai tree (flat geometry, axis limits fixed to [-1.01, 1.01]) into
# row 0, column 1 of its overview figure. The ground-truth slot (row 0, column 0) is
# switched off for datasets that have no ground truth.
for ind_dataset, bonvis_fig in enumerate(bonvis_fig_lst):
    dataset_fig, dataset_axs = figs_lst[ind_dataset], axs_lst[ind_dataset]
    bonvis_fig.bonvis_settings.transf_info.ax_lims = np.array([-1.01, 1.01, -1.01, 1.01])
    bonvis_fig.update_figure(geometry='flat')
    bonvis_fig.create_figure(figsize=(6, 6), fig=dataset_fig, ax=dataset_axs[0, 1])

    dataset_descr = dataset_descr_lst[ind_dataset]
    has_ground_truth = dataset_descr in dataset_descr_lst_gt
    if not has_ground_truth:
        dataset_axs[0, 0].axis('off')

In [None]:
# Draw each ground-truth tree (flat geometry, same axis limits as the Bonsai tree) into
# row 0, column 0 of the matching dataset's overview figure.
for ind_dataset_gt, bonvis_fig in enumerate(bonvis_fig_lst_gt):
    ind_dataset = dataset_descr_lst.index(dataset_descr_lst_gt[ind_dataset_gt])
    dataset_fig, dataset_axs = figs_lst[ind_dataset], axs_lst[ind_dataset]
    bonvis_fig.bonvis_settings.transf_info.ax_lims = np.array([-1.01, 1.01, -1.01, 1.01])
    bonvis_fig.update_figure(geometry='flat')
    bonvis_fig.create_figure(figsize=(6, 6), fig=dataset_fig, ax=dataset_axs[0, 0])

In [None]:
annotation_folders = [os.path.join('/scicore/home/nimwegen/GROUP/Projects/bonsai_runs/paper_figures_datasets', dataset_id, 'annotation') for dataset_id in dataset_ids]

# Scatter the 2-component PCA, the UMAP (fitted on 10 PCA components) and the PHATE
# embedding into row 0, columns 2-4 of each dataset's overview figure. Cells are coloured
# with the same category colours Bonsai uses for the tree.
for ind_dataset, input_data_folder in enumerate(input_data_folders):
    print(ind_dataset)
    bonvis_fig = bonvis_fig_lst[ind_dataset]
    cats_to_color = bonvis_fig.bonvis_settings.node_style['annot_info'].annot_to_color
    data_descr = dataset_descr_lst[ind_dataset]
    cell_ids = cell_ids_lst[ind_dataset]

    # Read in the annotation to get a colour per cell.
    scData = SCData(onlyObject=True, dataset=dataset_ids[ind_dataset])
    scData.metadata.nCells = len(cell_ids)
    scData.metadata.cellIds = cell_ids
    # Print the folder that is actually read. This used to print input_data_folder/annotation,
    # which is not where the annotation is loaded from.
    print(annotation_folders[ind_dataset])
    annotation_df, feature_matrices = scData.get_annotations(annotation_folders[ind_dataset])
    annotation_label = {'Pseudobulk': 'annot_pseudobulk'}.get(data_descr, 'annot_Celltype3')
    colors = [cats_to_color[cat] for cat in annotation_df[annotation_label]]

    fig = figs_lst[ind_dataset]
    # (column in row 0, {key: projection}, keys to plot, or None to plot every entry)
    panels = [(2, pca_projected_lst[ind_dataset], (2,)),
              # Only the 10-component UMAP is plotted, consistent with the big panel figure.
              # Plotting several keys would overplot them in the same axes.
              (3, umap_projected_lst[ind_dataset], (10,)),
              (4, phate_projected_lst[ind_dataset], None)]
    for col, projections, keys_to_plot in panels:
        ax = axs_lst[ind_dataset][0, col]
        ax.set_box_aspect(1)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        for key, proj in projections.items():
            if keys_to_plot is None or key in keys_to_plot:
                ax.scatter(proj[0, :], proj[1, :], s=10, c=colors)
    # Adjust this dataset's figure explicitly. plt.subplots_adjust acts on whichever
    # figure is current, which need not be `fig`.
    fig.subplots_adjust(left=0, right=1.0, bottom=0, top=1.0)

In [None]:
# Quick visual check for one dataset: inferred vs. true pairwise distances for each tool.
tools_lst = ['bonsai', 'pca', 'umap', 'phate']
ind_dataset = 0
true_dists = true_dists_lst[ind_dataset]
dists_lst = [bonsai_dists_lst, pca_dists_lst, umap_dists_lst, phate_dists_lst]
fig, axs = plt.subplots(ncols=len(tools_lst), figsize=(3 * len(tools_lst), 3))
for ax, tool, tool_dists_lst in zip(axs, tools_lst, dists_lst):
    # Use the selected dataset for every tool (the index 0 was previously hard-coded here).
    ax.scatter(true_dists, tool_dists_lst[ind_dataset], s=2)
    ax.set_title(tool)
    ax.set_xlabel('True distance')
axs[0].set_ylabel('Inferred distance')
# Create histogram of correlations figures

# RECALCULATE = False
# RECALCULATE = RECALCULATE or not len(tool_objcts_lst)
# if RECALCULATE:
#     tool_objcts_lst = []
    
# if INCLUDE_SANITY:
#     tools_lst.append('sanity')
# for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
#     print("\n\nTreating dataset {}\n".format(dataset_descr_lst[ind_dataset]))
#     # if dataset_descr_lst[ind_dataset] not in ['Perfect binary', 'Real covariance (Reco)', 'Rbl + Unb + Reco']:
#     #     continue
#     true_dists = true_dists_lst[ind_dataset]
#     tools_dists_dict = {'bonsai': bonsai_dists_lst[ind_dataset], 'pca': pca_dists_lst[ind_dataset], 'umap': umap_dists_lst[ind_dataset]}

#     if INCLUDE_SANITY:
#         tools_dists_dict['sanity'] =  sanity_lst[ind_dataset]
#     if RECALCULATE:
#         tool_objcts = []
#         true_objct = Dataset(distances=true_dists_lst[ind_dataset], data_type='delta_true', data_id='delta_true', color_types=tools_lst)
#         true_objct.true_dataset_ranks = None
#         for ind_tool, tool in enumerate(tools_lst):
#             data_id = tool + dataset_descr_lst[ind_dataset]
#             tool_objcts.append(
#                 Dataset(distances=tools_dists_dict[tool], data_type=tool, data_id=data_id, color_types=tools_lst))
#     else:
#         tool_objcts = tool_objcts_lst[ind_dataset]
#     for ind_tool, tool_objct in enumerate(tool_objcts):
#         # if ind_tool != 0:
#         #     continue
# #         fig, ax = plt.subplots(figsize=(3,3))
#         fig = figs_lst[ind_dataset]
#         ax = axs_lst[ind_dataset][1, ind_tool]
#         ax.set_box_aspect(1)
#         n_neighbours_list = compare_nearest_neighbours_to_truth([true_objct, tool_objct], make_fig=False, max_neighbours=600, ax=None,
    #                                              only_powers_of_2=True,
    #                                              title='')
    #     avg_rel_diffs, R_vals = compare_pdists_to_truth_per_cell([true_objct, tool_objct], make_fig=True, axs=ax, set_lims=False, return_Rvals=True, XLABEL=False, YLABEL=False, flip_axes=True, first_title=' ', loglog_corr=False)
    # if RECALCULATE:
    #     tool_objcts_lst.append(tool_objcts)

In [None]:
# Create histogram of correlations figures. For every dataset and tool, compare the inferred
# pairwise distances with the ground truth: nearest-neighbour recall (read later from the
# Dataset objects) and per-cell Pearson R histograms drawn into row 1 of the overview figure.
# ALL_TOOLS fixes the column of each tool in row 1. Sanity is dropped from tools_lst when
# its distances were not loaded (INCLUDE_SANITY False), but the other tools keep their columns.
ALL_TOOLS = ['sanity', 'bonsai', 'pca', 'umap', 'phate']
tools_lst = [tool for tool in ALL_TOOLS if INCLUDE_SANITY or tool != 'sanity']

# Set RECALCULATE = False to reuse the Dataset objects cached in tool_objcts_lst.
RECALCULATE = True
RECALCULATE = RECALCULATE or not len(tool_objcts_lst)
if RECALCULATE:
    tool_objcts_lst = []

for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
    print("\n\nTreating dataset {}\n".format(dataset_descr_lst[ind_dataset]))
    tools_dists_dict = {'bonsai': bonsai_dists_lst[ind_dataset], 'pca': pca_dists_lst[ind_dataset],
                        'umap': umap_dists_lst[ind_dataset], 'phate': phate_dists_lst[ind_dataset]}
    if INCLUDE_SANITY:
        tools_dists_dict['sanity'] = sanity_lst[ind_dataset]

    # The truth object is built for this dataset on every path. It used to be created only
    # when RECALCULATE was True, so cached runs compared against a stale truth object left
    # over from another dataset.
    true_objct = Dataset(distances=true_dists_lst[ind_dataset], data_type='delta_true', data_id='delta_true',
                         color_types=tools_lst)
    true_objct.true_dataset_ranks = None
    if RECALCULATE:
        tool_objcts = [Dataset(distances=tools_dists_dict[tool], data_type=tool,
                               data_id=tool + dataset_descr_lst[ind_dataset], color_types=tools_lst)
                       for tool in tools_lst]
    else:
        tool_objcts = tool_objcts_lst[ind_dataset]

    for tool, tool_objct in zip(tools_lst, tool_objcts):
        ax = axs_lst[ind_dataset][1, ALL_TOOLS.index(tool)]
        ax.set_box_aspect(1)
        # n_neighbours_list is reused by the kNN plot below, which also reads
        # tool_objct.correct_fractions_of_neighbours (presumably set by this call).
        n_neighbours_list = compare_nearest_neighbours_to_truth([true_objct, tool_objct], make_fig=False,
                                                                max_neighbours=600, ax=None,
                                                                only_powers_of_2=True,
                                                                title='')
        avg_rel_diffs, R_vals = compare_pdists_to_truth_per_cell([true_objct, tool_objct], make_fig=True, axs=ax,
                                                                 set_lims=False, return_Rvals=True,
                                                                 XLABEL=False, YLABEL=False, flip_axes=True,
                                                                 first_title=' ', loglog_corr=False)
    if RECALCULATE:
        tool_objcts_lst.append(tool_objcts)

In [None]:
# For each dataset, plot the fraction of correctly recovered nearest neighbours against the
# number of neighbours considered, with one line per tool. Bonsai is drawn thicker and on top.
nrows = 2
ncols = int(np.ceil(len(dataset_ids) / nrows))
fig_nn, axs_nn = plt.subplots(figsize=(3 * ncols, 3 * nrows), nrows=nrows, ncols=ncols)
# n_neighbours_list comes from the correlation cell. Show point markers only for short lists.
marker = '*-' if len(n_neighbours_list) < 40 else '-'

# Normalize whatever plt.subplots returned to a flat array of Axes.
axs_nn = np.ravel(axs_nn)
for ax in axs_nn:
    ax.axis('off')  # grid cells without a dataset stay blank

for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
    ax = axs_nn[ind_dataset]
    ax.axis('on')
    dataset_descr = dataset_descr_lst[ind_dataset]
    for ind_tool, tool in enumerate(tools_lst):
        if tool == 'sanity' and not INCLUDE_SANITY:
            continue
        tool_objct = tool_objcts_lst[ind_dataset][ind_tool]
        cf_nn = tool_objct.correct_fractions_of_neighbours
        line_kwargs = {'linewidth': 3, 'zorder': 1} if tool == 'bonsai' else {'zorder': 0}
        ax.plot(n_neighbours_list, cf_nn, marker, c=tool_objct.data_type_color, label=tool, **line_kwargs)

    if ind_dataset == 0:
        ax.legend(loc='lower right')
        ax.set_ylabel('Fraction of correct\nnearest neighbours')
    ax.set_xlabel('Number of nearest neighbours')
    ax.set_xscale('log')
    ax.set_ylim(0, 1)
    # Plain numbers instead of powers of ten on the logarithmic x-axis.
    ax.xaxis.set_major_formatter(ScalarFormatter())
    ax.ticklabel_format(style='plain', axis='x')
    ax.set_title(dataset_descr)
fig_nn.tight_layout()

In [None]:
# Store the nearest-neighbour recall figure as PNG (300 dpi) and SVG.
knn_fig_base = os.path.join('/scicore/home/nimwegen/degroo0000/sc_datasets/simulated_datasets', 'knn_figure')
for ext, save_kwargs in (('.png', {'dpi': 300}), ('.svg', {})):
    fig_nn.savefig(knn_fig_base + ext, **save_kwargs)

In [None]:
# Create boxplots in a separate figure as well, grouped by dataset. Each box shows the
# distribution of a tool's per-cell Pearson R values, with whiskers at the 5th/95th percentile.
boxprops = dict(linewidth=0.05)
flierprops = dict(markersize=2, markeredgewidth=0.5)
medianprops = dict(color='black', linewidth=1)
nrows = 2
ncols = int(np.ceil(len(dataset_ids)/nrows))
fig_bp, axs_bp = plt.subplots(figsize=(3*ncols,3*nrows), nrows=nrows, ncols=ncols)

# Normalize whatever plt.subplots returned to a flat array of Axes.
axs_bp = np.ravel(axs_bp)
for ax in axs_bp:
    ax.axis('off')  # grid cells without a dataset stay blank

for ind_dataset, bonsai_results in enumerate(bonsai_results_folders):
    ax = axs_bp[ind_dataset]
    ax.axis('on')
    dataset_descr = dataset_descr_lst[ind_dataset]
    pearsonRs = []
    tool_names = []
    box_colors = []
    for ind_tool, tool in enumerate(tools_lst):
        if (not INCLUDE_SANITY) and (tool == 'sanity'):
            continue
        tool_objct = tool_objcts_lst[ind_dataset][ind_tool]
        tool_names.append(tool)
        pearsonRs.append(tool_objct.pearsonRs)
        # Record the colour together with the data so each box keeps its tool's colour even
        # when a tool is skipped. Indexing tool_objcts_lst by box position did not.
        box_colors.append(tool_objct.data_type_color)
    # NOTE(review): matplotlib >= 3.9 renames boxplot's `labels` to `tick_labels`.
    bplot = ax.boxplot(pearsonRs, whis=(5, 95), labels=tool_names, patch_artist=True,
                       flierprops=flierprops, medianprops=medianprops, boxprops=boxprops)
    for patch, box_color in zip(bplot['boxes'], box_colors):
        patch.set_facecolor(color=box_color)
    ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_ylim(-0.05, 1.05)
    ax.set_title(dataset_descr)
    if ind_dataset == 0:
        ax.set_ylabel('Pearson R')
fig_bp.tight_layout()

In [None]:
# Store the Pearson R boxplot figure as PNG (300 dpi) and SVG.
bp_fig_base = os.path.join('/scicore/home/nimwegen/degroo0000/sc_datasets/simulated_datasets', 'box_plots_figure')
for ext, save_kwargs in (('.png', {'dpi': 300}), ('.svg', {})):
    fig_bp.savefig(bp_fig_base + ext, **save_kwargs)

In [None]:
# Final titles, labels and panel letters for each per-dataset overview figure.
top_row_titles = ['Bonsai', 'PCA', 'UMAP', 'PHATE']               # row 0, columns 1-4
bottom_row_titles = ['Sanity', 'Bonsai', 'PCA', 'UMAP', 'PHATE']  # row 1, columns 0-4
for ind_fig, fig_obj in enumerate(figs_lst):
    axs = axs_lst[ind_fig]
    # Only datasets with a ground-truth tree get a title on the top-left panel.
    if dataset_descr_lst[ind_fig] in dataset_descr_lst_gt:
        axs[0, 0].set_title('Ground truth', fontsize=16)
    for col, title in enumerate(top_row_titles, start=1):
        axs[0, col].set_title(title, fontsize=16)
    for col, title in enumerate(bottom_row_titles):
        axs[1, col].set_title(title, fontsize=16)
        axs[1, col].set_xlabel("Pearson R", fontsize=12)

    # Top-left panel: axes on, but all spines and ticks hidden so only its title remains.
    ax = axs[0, 0]
    ax.axis('on')
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.tick_params(bottom=False, labelbottom=False, top=False, right=False,
                   left=False, labelleft=False)

    axs[1, 0].set_ylabel('Number of cells', fontsize=12)
    fig_obj.subplots_adjust(bottom=0.05, top=0.88, left=0.1, right=.9)
    # Panel letters: A) for the visualization row, B) for the histogram row.
    fig_obj.text(0.03, 0.9, 'A)', fontsize=14, fontweight='bold', va='top', ha='left')
    fig_obj.text(0.03, 0.45, 'B)', fontsize=14, fontweight='bold', va='top', ha='left')


In [None]:
# Save each overview figure (PNG at 300 dpi and SVG) next to the dataset's UMI counts.
# Spaces in the description are replaced by underscores in the file name.
for ind_fig, fig_obj in enumerate(figs_lst):
    label = dataset_descr_lst[ind_fig].replace(' ', '_')
    out_base = os.path.join(input_data_folders[ind_fig], '{}_overview_figure_wo_boxplots'.format(label))
    fig_obj.savefig(out_base + '.png', dpi=300)
    fig_obj.savefig(out_base + '.svg')
    print(out_base + '.svg')

### Show figures

In [None]:
ind = 0
# Show the overview figure, or nothing if figs_lst is too short (avoids an IndexError).
figs_lst[ind] if ind < len(figs_lst) else None

In [None]:
ind = 1
# figs_lst is shorter than this in SINGLE_DATASET mode; show nothing instead of raising IndexError.
figs_lst[ind] if ind < len(figs_lst) else None

In [None]:
ind = 2
# figs_lst is shorter than this in SINGLE_DATASET mode; show nothing instead of raising IndexError.
figs_lst[ind] if ind < len(figs_lst) else None

In [None]:
ind = 3
# figs_lst is shorter than this in SINGLE_DATASET mode; show nothing instead of raising IndexError.
figs_lst[ind] if ind < len(figs_lst) else None

In [None]:
ind = 4
# figs_lst is shorter than this in SINGLE_DATASET mode; show nothing instead of raising IndexError.
figs_lst[ind] if ind < len(figs_lst) else None

In [None]:
ind = 5
# figs_lst is shorter than this in SINGLE_DATASET mode; show nothing instead of raising IndexError.
figs_lst[ind] if ind < len(figs_lst) else None

In [None]:
ind = 6
# figs_lst is shorter than this in SINGLE_DATASET mode; show nothing instead of raising IndexError.
figs_lst[ind] if ind < len(figs_lst) else None

## Make one large panel figure with all visualizations

In [None]:
%%capture
# Big panel figure: one row per dataset, five columns
# (ground truth, Bonsai, PCA, UMAP, PHATE).
ncols = 5
nrows = len(dataset_descr_lst)
# 3 inches of height per dataset row. This equals the previous fixed 12 x 21 inches for the
# seven default datasets and keeps the rows square-ish for other dataset counts.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 3 * nrows))
if nrows == 1:
    # A single row comes back as a 1-D array; restore (row, column) indexing.
    axs = axs[None, :]

for ind_dataset, dataset_descr in enumerate(dataset_descr_lst):
    axs[ind_dataset, 0].set_ylabel(dataset_descr)

for col, title in enumerate(['Ground truth', 'Bonsai', 'PCA', 'UMAP', 'PHATE']):
    axs[0, col].set_title(title, fontsize=16)

In [None]:
# Draw each dataset's Bonsai tree into column 1 of the big panel figure.
# plt.subplots_adjust acts on the current figure; the values here are overridden by the
# later cells that call it again. Hiding the decoration of column 0 is done in the
# cell that loops over all first-column axes.
for ind_dataset, bonvis_fig in enumerate(bonvis_fig_lst):
    dataset_descr = dataset_descr_lst[ind_dataset]
    bonvis_fig.update_figure()
    bonvis_fig.create_figure(figsize=(6, 6), fig=fig, ax=axs[ind_dataset, 1])
    plt.subplots_adjust(left=-0.10, right=1.1, bottom=-0.1, top=1.1)

In [None]:
# Draw each ground-truth tree into column 0 of its dataset's row in the big panel figure.
for ind_dataset_gt, bonvis_fig in enumerate(bonvis_fig_lst_gt):
    row = dataset_descr_lst.index(dataset_descr_lst_gt[ind_dataset_gt])
    ind_dataset = row
    ax = axs[row, 0]
    bonvis_fig.update_figure()
    bonvis_fig.create_figure(figsize=(6, 6), fig=fig, ax=ax)
    plt.subplots_adjust(left=0, right=1.0, bottom=0, top=1.0)

In [None]:
# Column 0 of the panel: keep the axes on (so the dataset name shows as y-label) but hide
# all spines and ticks.
for ind_dataset, dataset_descr in enumerate(dataset_descr_lst):
    ax = axs[ind_dataset, 0]
    ax.axis('on')
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.tick_params(bottom=False, labelbottom=False, top=False, right=False,
                   left=False, labelleft=False)
    ax.set_ylabel(dataset_descr, fontsize=12)

In [None]:
# Fill columns 2-4 of the big panel figure with the 2-component PCA, the UMAP (fitted on
# 10 PCA components) and the PHATE embedding of each dataset, coloured by cell annotation.
# Uses annotation_folders from the per-dataset overview section.
for ind_dataset, input_data_folder in enumerate(input_data_folders):
    print(ind_dataset)
    bonvis_fig = bonvis_fig_lst[ind_dataset]
    cats_to_color = bonvis_fig.bonvis_settings.node_style['annot_info'].annot_to_color
    data_descr = dataset_descr_lst[ind_dataset]
    cell_ids = cell_ids_lst[ind_dataset]

    # Read in the annotation to get a colour per cell.
    scData = SCData(onlyObject=True, dataset=dataset_ids[ind_dataset])
    scData.metadata.nCells = len(cell_ids)
    scData.metadata.cellIds = cell_ids
    annotation_df, feature_matrices = scData.get_annotations(annotation_folders[ind_dataset])
    annotation_label = {'Pseudobulk': 'annot_pseudobulk'}.get(data_descr, 'annot_Celltype3')
    colors = [cats_to_color[cat] for cat in annotation_df[annotation_label]]

    # (column, {key: projection}, keys to plot, or None to plot every entry)
    panels = [(2, pca_projected_lst[ind_dataset], (2,)),
              (3, umap_projected_lst[ind_dataset], (10,)),
              (4, phate_projected_lst[ind_dataset], None)]
    for col, projections, keys_to_plot in panels:
        ax = axs[ind_dataset, col]
        ax.set_box_aspect(1)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        for key, proj in projections.items():
            if keys_to_plot is None or key in keys_to_plot:
                ax.scatter(proj[0, :], proj[1, :], s=10, c=colors)

# Adjust the panel figure explicitly. plt.subplots_adjust would act on whichever figure
# happens to be current, which need not be `fig`.
fig.subplots_adjust(left=0, right=1.0, bottom=0, top=1.0)

In [None]:
# Store the big panel figure as PNG (300 dpi) and SVG.
panel_fig_base = os.path.join('/scicore/home/nimwegen/degroo0000/sc_datasets/simulated_datasets', 'all_visualizations')
for ext, save_kwargs in (('.png', {'dpi': 300}), ('.svg', {})):
    fig.savefig(panel_fig_base + ext, **save_kwargs)

In [None]:
# Display the big panel figure inline (the last expression of the cell is rendered).
fig