# Preamble

This notebook produces the figure for the single-cell study in the main text, which illustrates an application of our method.

In [None]:
from __future__ import division

# General parameters

In [None]:
# Default parameter values. A later "# Parameters" cell in this notebook assigns
# several of these names again, so the later values win when all cells are run.

# Free-form run label; becomes part of `notebook_name` (and thus of the output paths).
prefix = "test"

# Key into `CELL_TYPE_ORDERS` that selects which cell types are shown, and in which order.
cell_type_selection = "filter"

n_samples = 10
sampling_scheme = "double-replacement"
n_sampled_cells_per_celltype = 1000  # for top-k calculation

topk_target = "function"  # phenotype or function
topk_ratio = 0.001

n_max_cells_emb = 100  # for visualization 

# Thread limit handed to `set_threads_for_external_libraries` in the import cell.
n_threads = 64
# NOTE(review): the two job counts below are not read anywhere in the visible part of
# this notebook — presumably left over from the computation notebook; confirm before removing.
n_jobs_topk = 64
n_jobs_fcs = 64

In [None]:
# # Parameters

# prefix = "pheno_v1"

# cell_type_selection = "filter"

# n_samples = 100
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000  # for top-k calculation
# topk_target = "phenotype"  # 0.01%
# topk_ratio = 0.0001  # 0.01%

# n_max_cells_emb = 1000  # for visualization 

In [None]:
# # Parameters

# prefix = "function300_v1"

# cell_type_selection = "filter"

# n_samples = 300
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000  # for top-k calculation
# topk_target = "function"  # 0.01%
# topk_ratio = 0.0001  # 0.01%

# n_max_cells_emb = 3000  # for visualization 

In [None]:
# Parameters
# Active parameter set: these assignments replace the defaults from the
# "General parameters" cell and determine `notebook_name` below.

prefix = "function1000_v1"

cell_type_selection = "filter"

n_samples = 1000
sampling_scheme = "double-replacement"
n_sampled_cells_per_celltype = 10000  # for top-k calculation
topk_target = "function"  # phenotype or function
topk_ratio = 0.0001  # 0.01%

n_max_cells_emb = 1000  # for visualization 

In [None]:
# # Parameters
# prefix = "rescue_v3"
# n_samples = 100
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000
# topk_target = "function"
# topk_ratio = 0.001
# cell_type_selection = "filter"
# n_max_cells_emb = 1000


In [None]:
# # Parameters

# prefix = "combined_v1"

# cell_type_selection = "filter"

# n_samples = 100
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000  # for top-k calculation
# topk_target = "combined"  # 0.01%
# topk_ratio = 0.0001  # 0.01%

# n_max_cells_emb = 1000  # for visualization 

In [None]:
verbose = 1

# Preamble

In [None]:
# Run identifier built from the parameter values; sections are separated by "___"
# and each parameter is encoded as "<name>__<value>".
_name_parts = [
    "application___singlecell",
    "parameters",
    f"prefix__{prefix}",
    f"cell_types__{cell_type_selection}",
    f"n_samples__{n_samples}",
    f"sampling_scheme__{sampling_scheme}",
    f"n_sampled_cells_per_celltype__{n_sampled_cells_per_celltype}",
    f"topk_target__{topk_target}",
    f"topk_ratio__{topk_ratio}",
    f"n_max_cells_emb__{n_max_cells_emb}",
]
notebook_name = "___".join(_name_parts)
print(notebook_name)

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

# disable parallelization for BLAS and co.
from corals.threads import set_threads_for_external_libraries
set_threads_for_external_libraries(n_threads=n_threads)

# general
import re
import collections
import pickle
import warnings 
import joblib
import pathlib

# data
import numpy as np
import pandas as pd
import h5py

# ml / stats
import sklearn
import scipy.stats

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# init matplotlib defaults
import matplotlib
# matplotlib.rcParams['figure.facecolor'] = (1,1,1,1)


In [None]:
from matplotlib.collections import LineCollection
import sklearn.manifold

In [None]:
import corals.correlation.utils
import sklearn.impute
from coralsarticle.visualization import CurvedText

In [None]:
from coralsarticle.data.process.singlecell import load_cytof, prepare_cell_sampling, sample_cell_subgroups
import coralsarticle.data.process.singlecell

# Load results

In [None]:
# Directory that holds the pickled results for this parameter combination.
path = pathlib.Path("../_out/" + notebook_name)
# NOTE(review): the directory is created if it does not exist, so a parameter set without
# stored results only fails later, when the pickle files are opened — confirm this is intended.
path.mkdir(parents=True, exist_ok=True)
path

In [None]:
# load essentials
print(path)


def _load_pickle(name):
    """Load ``<path>/<name>.pickle`` and close the file handle afterwards.

    NOTE(review): unpickling can execute arbitrary code; only use this on result
    files written by our own pipeline.
    """
    with open(path / f"{name}.pickle", "rb") as file:
        return pickle.load(file)


cell_types = _load_pickle("cell_types")
subgroups = _load_pickle("subgroups")

cells_phenotype_loaded = _load_pickle("cells_phenotype")
cells_function_loaded = _load_pickle("cells_function")

topk_loaded = _load_pickle("topk")
topk_matrices_loaded = _load_pickle("topk_matrices")

topk_stats_samples = _load_pickle("topk_stats_samples")

cells_phenotype_emb_loaded = _load_pickle("cells_phenotype_emb")
cells_phenotype_emb_idx_loaded = _load_pickle("cells_phenotype_emb_idx")

In [None]:
# Results are stored either as a deque with one entry per run or, in the legacy
# format, as the bare object. The type of the phenotype data decides for all files.
idx_deque = -1
_stored_as_deque = isinstance(cells_phenotype_loaded, collections.deque)


def _unwrap(loaded):
    """Return entry `idx_deque` of deque-style results; pass legacy results through."""
    return loaded[idx_deque] if _stored_as_deque else loaded


cells_phenotype = _unwrap(cells_phenotype_loaded)
cells_function = _unwrap(cells_function_loaded)
topk = _unwrap(topk_loaded)
topk_matrices = _unwrap(topk_matrices_loaded)
cells_phenotype_emb = _unwrap(cells_phenotype_emb_loaded)
cells_phenotype_emb_idx = _unwrap(cells_phenotype_emb_idx_loaded)

In [None]:
cell_type_order = coralsarticle.data.process.singlecell.CELL_TYPE_ORDERS[cell_type_selection]

In [None]:
# TODO: Should load this from files (it is stored as `topk_stats_bins.pickle`)
bins = np.concatenate([[-2], np.linspace(-1,1,201), [2]])

# Visualization

In [None]:
cell_keys = pd.read_excel("../data/raw/singlecell/cell_keys.xlsx")
# Lookup table: raw cell key -> short display name.
rename = dict(zip(cell_keys["Cell Key"], cell_keys["Short Name"]))
cell_keys

In [None]:
print(cell_keys[["Short Name", "Long Name"]].to_latex(index=False))

In [None]:
# %matplotlib inline
# n = len(cell_type_order)
# fig, axes = plt.subplots(n, n, figsize=(4 * n, 4 * n))
# for i1, c1 in enumerate(cell_type_order):
#     for i2, c2 in enumerate(cell_type_order):
#         pair = (c1, c2)
#         ax = axes[i1, i2]
#         skip = False
#         for t in df.timepoint.unique():
#             v = df.loc[df.timepoint == t][pair]
#             if len(np.unique(v)) == 1:
#                 skip =True
        
#         if not skip:
#             sns.kdeplot(df[pair], hue=df.timepoint, ax=ax)

In [None]:
# %matplotlib inline
# n = len(cell_type_order)
# fig, axes = plt.subplots(n, n, figsize=(4 * n, 4 * n))
# for i1, c1 in enumerate(cell_type_order):
#     for i2, c2 in enumerate(cell_type_order):
        
#         pair = (c1, c2)
#         ax = axes[i1, i2]
#         skip = False
#         for t in df.timepoint.unique():
#             v = df.loc[df.timepoint == t][pair]
#             if len(np.unique(v)) == 1:
#                 skip =True
        
#         if not skip:
#             diff = df[df.timepoint == "T3"][pair].values - df[df.timepoint == "PP"][pair].values

#             q_left = np.quantile(diff, 0.05)
#             q_right = np.quantile(diff, 0.95)
# #             q_left = np.quantile(diff, 0.01)
# #             q_right = np.quantile(diff, 0.99)
            
#             if q_right < 0:
#                 color = "orange"
#             elif q_left > 0:
#                 color = "blue"
#             else:
#                 color = "grey"
            
#             sns.kdeplot(
#                 diff, 
#                 ax=ax, 
#                 color=color,
#                 linewidth=1 if q_left < 0 and q_right > 0 else 5)
#             ax.axvline(0, color="black", linewidth=3)
#             ax.axvline(q_left, linestyle="--", color=color)
#             ax.axvline(q_right, linestyle="--", color=color)
#             ax.set_title(f"{c1.replace('cells', '')} / {c2.replace('cells', '')}")
#             if pair in differential_cell_pairs_effect_size_map:
#                 ax.annotate(f"es={differential_cell_pairs_effect_size_map[pair]:.02}f", xy=(10,10), xycoords="axes points")

In [None]:
# def effect_size_original(x, y, quantile_left=0.05, quantile_right=0.95):
def effect_size_original(x, y, quantile_left=0.05, quantile_right=0.95):
    """Quantile-based effect size of the element-wise difference ``x - y``.

    The effect size is ``median / (median - q)``, where ``q`` is the quantile of the
    differences that lies closest to zero (the right quantile if all mass is negative,
    the left quantile if it is positive).

    Parameters
    ----------
    x, y : array-like of equal length; differences are taken position by position.
    quantile_left, quantile_right : quantile levels bounding the central interval.

    Returns
    -------
    float
        The effect size, or ``np.nan`` when the two quantiles do not share a strict sign.
    """
    deltas = x - y
    center = np.median(deltas)
    lower = np.quantile(deltas, quantile_left)
    upper = np.quantile(deltas, quantile_right)

    strictly_positive = lower > 0 and upper > 0
    strictly_negative = lower < 0 and upper < 0
    if not (strictly_positive or strictly_negative):
        return np.nan

    # use the interval bound that is nearest to zero as the reference point
    reference = upper if upper < 0 else lower
    return center / (center - reference)

In [None]:
def effect_size_cohen(x, y, threshold=0.8):
    """Cohen's d between `x` and `y`, optionally suppressed below a magnitude threshold.

    See https://en.wikipedia.org/wiki/Effect_size#Cohen's_d

    The pooled standard deviation is the root of the average of both variances
    (population variances, i.e. ``np.std`` with its default ``ddof=0``).

    Returns ``np.nan`` if `threshold` is not None and ``|d| < threshold``, else ``d``.
    """
    mean_gap = np.mean(x) - np.mean(y)
    pooled_sd = np.sqrt((np.std(x)**2 + np.std(y)**2) / 2)
    es = mean_gap / pooled_sd
    if threshold is not None and np.abs(es) < threshold:
        return np.nan
    return es

In [None]:
# def effect_size_cohen_sample(x, y, threshold=1.2, quantile_left=0.05, quantile_right=0.95):  
def effect_size_cohen_sample(x, y, threshold=1.2, quantile_left=0.05, quantile_right=0.95):  
    """Bootstrapped Cohen's d with a sign-consistency check.

    Draws 100 bootstrap resamples (with replacement, via the global NumPy random state)
    of `x` and `y`, computes Cohen's d for each and returns the mean of these values if

    * its magnitude is at least `threshold`, and
    * the `quantile_left` / `quantile_right` quantiles of the bootstrap distribution
      are both strictly positive or both strictly negative.

    Otherwise ``np.nan`` is returned. `x` and `y` must support integer-array indexing.
    """
    n_x, n_y = len(x), len(y)

    bootstrap = []
    for _ in range(100):
        rows_x = np.random.choice(np.arange(n_x), n_x, replace=True)
        rows_y = np.random.choice(np.arange(n_y), n_y, replace=True)
        bootstrap.append(effect_size_cohen(x[rows_x], y[rows_y], threshold=None))

    center = np.mean(bootstrap)
    lower = np.quantile(bootstrap, quantile_left)  # two sided 0.025/0.975 threshold!? not sure
    upper = np.quantile(bootstrap, quantile_right)

    if not (np.abs(center) >= threshold):
        return np.nan

    same_sign = (lower > 0 and upper > 0) or (lower < 0 and upper < 0)
    return center if same_sign else np.nan

In [None]:
# def effect_size_cohen_median(x, y, threshold=1.2):
def effect_size_cohen_median(x, y, threshold=1.2):
    """Median/MAD analogue of Cohen's d, optionally suppressed below a threshold.

    The difference of medians is divided by the root of the average squared median
    absolute deviation (unscaled MAD, SciPy's default) of both samples.

    Parameters
    ----------
    x, y : array-like samples.
    threshold : float or None
        Minimum magnitude to report. ``None`` disables the filter, matching the
        behaviour of `effect_size_cohen`.

    Returns
    -------
    float
        The effect size, or ``np.nan`` if its magnitude is below `threshold`.
    """
    # TODO: 
    # would be better to use PERCENTAGE BEND MIDVARIANCE probably: https://garstats.wordpress.com/2018/04/04/dbias/
    # though still not good probably: https://aakinshin.net/posts/nonparametric-effect-size2/
    # note: totally breaks down for non-normal it seems, with huuuuge median_abs_deviations
    mad_x = scipy.stats.median_abs_deviation(x)
    mad_y = scipy.stats.median_abs_deviation(y)
    effect_size = (np.median(x) - np.median(y)) / np.sqrt((mad_x**2 + mad_y**2) / 2)
    if threshold is not None and np.abs(effect_size) < threshold:
        return np.nan
    return effect_size

In [None]:
def effect_size_median_relative(x, y, threshold=0.5, quantile_left=0.05, quantile_right=0.95, n_bootstrap=1000):
    """Bootstrapped relative difference of medians with a sign-consistency check.

    For each of `n_bootstrap` resamples (with replacement, via the global NumPy random
    state) the medians of `x` and `y` are compared in two ways:

    * absolute: ``median_x - median_y``
    * relative: ``sign(median_x - median_y) * (max / min - 1)`` of the two medians;
      ``0`` if both medians are zero, and a fixed magnitude of ``10`` if exactly one is
      zero (the ratio would be undefined).
      NOTE(review): the ratio presumably assumes non-negative medians — confirm.

    Parameters
    ----------
    x, y : array-like samples supporting integer-array indexing.
    threshold : float
        Minimum magnitude of the median relative effect size to report.
    quantile_left, quantile_right : float
        Quantile levels of the absolute bootstrap differences that must share a strict sign.
    n_bootstrap : int
        Number of bootstrap resamples.

    Returns
    -------
    float
        Median of the relative effect sizes, or ``np.nan`` if it is smaller than
        `threshold` in magnitude or the quantiles of the absolute differences straddle zero.
    """
    effect_sizes = []
    effect_sizes_relative = []

    for _ in range(n_bootstrap):
        idx_x = np.random.choice(np.arange(len(x)), len(x), replace=True)
        idx_y = np.random.choice(np.arange(len(y)), len(y), replace=True)

        median_x = np.median(x[idx_x])
        median_y = np.median(y[idx_y])

        effect_sizes.append(median_x - median_y)

        if median_x == 0 and median_y == 0:
            effect_size_relative = 0
        elif median_x == 0 or median_y == 0:
            effect_size_relative = np.sign(median_x - median_y) * 10
        else:
            effect_size_relative = np.sign(median_x - median_y) * (max(median_x, median_y) / min(median_x, median_y) - 1)
        effect_sizes_relative.append(effect_size_relative)

    left = np.quantile(effect_sizes, quantile_left)
    right = np.quantile(effect_sizes, quantile_right)

    esr_median = np.median(effect_sizes_relative)

    # honour the caller's `threshold` (previously a hard-coded 0.5 was used here)
    if np.abs(esr_median) >= threshold:
        if (left > 0 and right > 0) or (left < 0 and right < 0):
            return esr_median
    return np.nan

In [None]:
# Cliff's delta counterpart of a "very large" Cohen's d (1.2); see the table in the docstring below.
cliffs_delta_threshold = 0.622
def effect_size_cliffs_delta(x, y, p_threshold=0.05, effect_size_threshold=cliffs_delta_threshold):
    """
    Cliff's delta between `x` and `y`, filtered by significance and magnitude.

    Parameters
    ----------
    x, y : array-like samples.
    p_threshold : float or None
        Maximum two-sided Mann-Whitney U p-value; ``None`` skips the test entirely.
    effect_size_threshold : float or None
        Minimum magnitude of Cliff's delta; ``None`` disables the filter.

    Returns
    -------
    float
        Cliff's delta, or ``np.nan`` if either sample is empty (a warning is issued)
        or one of the two filters rejects the result.

    Notes
    -----
    With a large grain of salt, for interpretation we could convert Cohen's d thresholds defined by 
    
    > Sawilowsky, S (2009). "New effect size rules of thumb" (Wikipedia)
    
    to Cliff's delta by assuming underlying normal distributions:
    
    > Appropriate statistics for ordinal level data: Should we really be using t-test and cohen’s d for evaluating group differences on the NSSE and other surveys? 2006, Romano et al.

    For this we can for example use `cohd2delta` from R library `orddom`.

    This results in:
      * Very small cohen=0.01, cliffs=0.008  (NOTE(review): previously listed as 0.077,
        which is inconsistent with the other rows — confirm with `cohd2delta`)
      * Small      cohen=0.20, cliffs=0.147
      * Medium     cohen=0.50, cliffs=0.330
      * Large      cohen=0.80, cliffs=0.474
      * -          cohen=1.00, cliffs=0.554
      * Very large cohen=1.20, cliffs=0.622
      * Huge       cohen=2.00, cliffs=0.811
    """
    # accept plain sequences as well as arrays
    x, y = np.asarray(x), np.asarray(y)
    if x.size == 0 or y.size == 0:
        warnings.warn("No data given for `x` or `y`.")
        return np.nan

    if p_threshold is not None:
        # Request the two-sided test explicitly: older SciPy releases default to a
        # different alternative, which would change the p-value.
        _, p = scipy.stats.mannwhitneyu(x, y, alternative="two-sided")
        if p > p_threshold:
            return np.nan

    es = cliffsDelta(x, y)
    if effect_size_threshold is not None and abs(es) < effect_size_threshold:
        return np.nan
    return es
    
# the following code is based on: https://github.com/neilernst/cliffsDelta/blob/master/cliffsDelta.py

def cliffsDelta(lst1, lst2):
    """Cliff's delta of `lst1` relative to `lst2`.

    Computes ``(#(a > b) - #(a < b)) / (m * n)`` over all pairs ``(a, b)`` with ``a``
    from `lst1` (size ``m``) and ``b`` from `lst2` (size ``n``); ties count for neither side.

    Parameters
    ----------
    lst1, lst2 : array-like of orderable values (flattened before use).

    Returns
    -------
    float
        A value in ``[-1, 1]``; ``np.nan`` if either input is empty (the statistic is
        undefined without pairs).
    """
    first = np.asarray(lst1).ravel()
    second = np.sort(np.asarray(lst2).ravel())
    m, n = first.size, second.size
    if m == 0 or n == 0:
        return np.nan

    # For every a: number of b strictly smaller (insertion point from the left) ...
    more = int(np.searchsorted(second, first, side="left").sum())
    # ... and number of b strictly larger (everything right of the last tie).
    less = int((n - np.searchsorted(second, first, side="right")).sum())
    return (more - less) / (m * n)

def runs(lst):
    """Iterator, chunks repeated values.

    Yields ``(count, value)`` for every run of consecutive equal items in `lst`, where
    ``value`` is the last item of the run. Yields nothing for an empty input.
    """
    iterator = iter(lst)
    try:
        current = next(iterator)
    except StopIteration:
        # empty input: there are no runs to report
        return

    count = 1
    for item in iterator:
        if item != current:
            yield count, current
            count = 0
        current = item
        count += 1
    yield count, current

In [None]:
# calculate statistically significant differences

# Build a table with one row per (bootstrap sample, timepoint); each row holds that
# sample's flattened "counts" matrix.
sample = []
timepoint = []
matrices = []
for i_sample, stats in enumerate(topk_stats_samples):
    for k,v in stats.items():
        sample.append(i_sample)
        timepoint.append(k)
        matrices.append(v["counts"].reshape(1, -1))
#         matrices.append(v["means"].reshape(1, -1))
matrices = np.concatenate(matrices)

# Column labels are ordered cell-type pairs, in the row-major order of the flattening above.
# assumes v["counts"] is a square matrix indexed by `cell_types` on both axes — TODO confirm
cc = [(c1,c2) for c1 in cell_types for c2 in cell_types]
df = pd.DataFrame(matrices, columns=cc)
df.insert(0, "sample", sample)
df.insert(1, "timepoint", timepoint)

# pairs passing the rank-sum test, with their p-values
differential_cell_pairs = []
differential_cell_pairs_p = []

# pairs passing the effect-size filter, with their effect sizes
differential_cell_pairs_effect_size = []
differential_cell_pairs_effect_size_stats = []

for c in cc:
    msk_x = df["timepoint"] == "T3"
    msk_y = df["timepoint"] == "PP"

    x = df.loc[msk_x,[c]].values.flatten()
    y = df.loc[msk_y,[c]].values.flatten()

    # Skip pairs whose T3 and PP values are identical position by position.
    # The subtraction requires the same number of T3 and PP rows.
    if not np.all((y - x) == 0):
#         print(c, x, y)

        # p-value
        w,p = scipy.stats.ranksums(x,y)
        # Significance level 0.05 divided by the squared number of ordered cell types.
        # NOTE(review): the loop runs over pairs of `cell_types`, the divisor uses
        # `cell_type_order` — confirm both have the same size.
        if p < 0.05 / cell_type_order.size**2:
            differential_cell_pairs.append(c)
            differential_cell_pairs_p.append(p)

        # Effect-size measure in use (same for every pair; reassigned on each iteration).
#         effect_size_mode = "original"
#         effect_size_mode = "cohen"
#         effect_size_mode = "cohen_median"
#         effect_size_mode = "cohen_sample"
#         effect_size_mode = "median_relative"
        effect_size_mode = "cliffs_delta"

        if effect_size_mode == "original":
            effect_size = effect_size_original(x, y)

        elif effect_size_mode == "median_relative":
            effect_size = effect_size_median_relative(x, y)

        elif effect_size_mode == "cohen":
            effect_size = effect_size_cohen(x, y)

        elif effect_size_mode == "cohen_sample":
            effect_size = effect_size_cohen_sample(x, y)

        elif effect_size_mode == "cohen_median":
            effect_size = effect_size_cohen_median(x, y)

        elif effect_size_mode == "cliffs_delta":
            effect_size = effect_size_cliffs_delta(x, y, p_threshold=0.05, effect_size_threshold=cliffs_delta_threshold)

        else:
            raise ValueError(f"Unknown effect size mode: {effect_size_mode}")

        # every effect-size function signals "not relevant" by returning NaN
        if not np.isnan(effect_size):  
            differential_cell_pairs_effect_size.append(c)
            differential_cell_pairs_effect_size_stats.append(effect_size)

# cell-type pair -> p-value
differential_cell_pairs_map = {
    k:p for k,p in zip(differential_cell_pairs, differential_cell_pairs_p)}

# cell-type pair -> effect size
differential_cell_pairs_effect_size_map = {
    k:p for k,p in zip(differential_cell_pairs_effect_size, differential_cell_pairs_effect_size_stats) 
}

print(len(differential_cell_pairs_map))
print(len(differential_cell_pairs_effect_size_map))

In [None]:
("mDCs_noMDSC", "ncMCs_noMDSC") in cc

In [None]:
%matplotlib inline

n = len(cell_type_order)
fig, axes = plt.subplots(n, n, figsize=(4 * n, 4 * n))

for i1, c1 in enumerate(cell_type_order):
    for i2, c2 in enumerate(cell_type_order):
        
        pair = (c1, c2)
        ax = axes[i1, i2]
        skip = False
        for t in df.timepoint.unique():
            v = df.loc[df.timepoint == t][pair]
            if len(np.unique(v)) == 1:
                skip =True
        
        if not skip:
            sns.kdeplot(
                df[pair], 
                hue=df.timepoint, 
                ax=ax, 
                linewidth=5 if (c1,c2) in differential_cell_pairs_effect_size_map or (c2,c1) in differential_cell_pairs_effect_size_map else 1)

In [None]:
# test
# pair = ("mDCs_noMDSC", "ncMCs_noMDSC")
pair = ("Bcells", "CD16+CD56-NKcells")
# pair = list(differential_cell_pairs_effect_size_map.keys())[0]
pair_reversed = (pair[1], pair[0])

v1 = df.loc[df.timepoint == "T3"][pair].values
v2 = df.loc[df.timepoint == "PP"][pair].values
# es = effect_size_cohen_median(v1, v2)
es = effect_size_cliffs_delta(v1, v2)

# Look the stored effect size up in either orientation. Indexing with `pair` alone
# raised a KeyError when only the reversed pair was present in the map.
pair_is_differential = (
    pair in differential_cell_pairs_effect_size_map
    or pair_reversed in differential_cell_pairs_effect_size_map)
es_stored = differential_cell_pairs_effect_size_map.get(
    pair, differential_cell_pairs_effect_size_map.get(pair_reversed, "NA"))

print(
    es, 
    es_stored,
    scipy.stats.shapiro(v1)[1],  # p-values of the Shapiro-Wilk normality test
    scipy.stats.shapiro(v2)[1]
    )

sns.kdeplot(df[pair], hue=df.timepoint, linewidth=5 if pair_is_differential else 1)

In [None]:
# import corals.correlation.fast

# a = cells_function["T3"][pair[0]]
# b = cells_function["T3"][pair[1]]
# c1 = corals.correlation.fast.cor_matrix_symmetrical(a.transpose(), b.transpose())
# c1 = c1[c1 > 0.8]

# a = cells_function["PP"][pair[0]]
# b = cells_function["PP"][pair[1]]
# c2 = corals.correlation.fast.cor_matrix_symmetrical(a.transpose(), b.transpose())
# c2 = c2[c2 > 0.8]

# sns.histplot(x=np.concatenate([c1.flatten(), c2.flatten()]), hue=np.repeat(["T3", "PP"], (c1.size, c2.size)))
# # plt.xlim((-1,1)

In [None]:
# Split the ordered cell types into an innate and an adaptive group by position.
# NOTE(review): the slice bounds depend on the ordering stored in `CELL_TYPE_ORDERS`
# for the chosen `cell_type_selection` — verify them with the printout below.
cell_types_innate = cell_type_order[2:7]
cell_types_adaptive = cell_type_order[9:]
print(cell_type_order)
print(cell_types_innate)
print(cell_types_adaptive)

In [None]:
df

In [None]:
path_figures = pathlib.Path(f"../_out/figures/{notebook_name}")
path_figures.mkdir(parents=True, exist_ok=True)

In [None]:
contrast_colors = [sns.color_palette('deep')[1],sns.color_palette()[0]]

In [None]:
%matplotlib inline

# source: https://stackoverflow.com/questions/67188162/is-there-a-way-to-add-hatch-marks-on-a-seaborn-displot-using-a-kernal-density-es

bcells_label = cell_keys.set_index('Cell Key').loc['Bcells', 'Short Name']
nkcells_label = cell_keys.set_index('Cell Key').loc['CD16+CD56-NKcells', 'Short Name']
plots = [
        (f"innate / {bcells_label}", "Bcells", cell_types_innate),
        (f"innate / {nkcells_label}", "CD16+CD56-NKcells", cell_types_innate),
        (f"adaptive / {bcells_label}", "Bcells", cell_types_adaptive),
        (f"adaptive / {nkcells_label}", "CD16+CD56-NKcells", cell_types_adaptive)
]

fig, axes = plt.subplots(4, 1, figsize=(4.2, 2* 4), sharex=False, sharey=False, dpi=300)

for i, (name, source, targets) in enumerate(plots):

    ax = axes[i]

    v1 = np.array([df.loc[df.timepoint == "T3"][(source, t)].values for t in targets]).sum(axis=0)
    v2 = np.array([df.loc[df.timepoint == "PP"][(source, t)].values for t in targets]).sum(axis=0)

    # es = effect_size_cohen_median(v1, v2)
    es = effect_size_cliffs_delta(v1, v2, None, None)
    print(es)
    # print(
    #     es, 
    #     differential_cell_pairs_effect_size_map[pair] if pair in differential_cell_pairs_effect_size_map or (pair[1],pair[0]) in differential_cell_pairs_effect_size_map else "NA",
    #     scipy.stats.shapiro(v1)[1],
    #     scipy.stats.shapiro(v2)[1]
    #     )

    pal1 = ["grey", contrast_colors[0]]
    pal2 = ["grey", contrast_colors[1]]
    if es > 0:
        pal = pal2
    else:
        pal = pal1
    
    sns.kdeplot(
        x=v2, 
        fill=True, common_norm=False, palette=pal,
        alpha=.8, linewidth=1,
        color=pal[1],
        label="PP",
        ax=ax,
    )
    
    sns.kdeplot(
        x=v1, 
        fill=True, common_norm=False, palette=pal,
        alpha=.3, linewidth=1,
        color=pal[1],
        label="T3",
        ax=ax,
    )
    
    hatches = ['///', '']
    for collection, hatch in zip(ax.collections[::-1], hatches):
        collection.set_hatch(hatch)
    
    ax.legend(facecolor=(1,1,1,0.3), fancybox=False, edgecolor='None', loc="lower right", borderaxespad=0.08)
    
#     ax.set_axis_off()
#     ax.legend().remove()
    ax.set_ylabel("")
    ax.set_yticks([])
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    
    ax.set_xticklabels([f"{t / 1000:.0f}" for t in ax.get_xticks()])
    ax.annotate(name, (0,.075), xycoords="axes fraction", bbox=dict(facecolor='white', edgecolor='None', alpha=0.3))

    # sns.violinplot(
    #     y=np.concatenate([v1, v2]), 
    #     x=np.repeat(["T3", "PP"], (v1.size, v2.size)))
    if i + 1 == len(plots):
        ax.set_xlabel("number of top-k correlation ($10^3$)")

fig.savefig(path_figures / "innate_adaptive.pdf", bbox_inches="tight", facecolor='none') 

In [None]:
# # getting necessary libraries
# import numpy as np
# import pandas as pd
# import seaborn as sns
# import matplotlib.pyplot as plt
# sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# # we generate a color palette with Seaborn.color_palette()
# pal = sns.color_palette(palette='coolwarm', n_colors=12)


# # in the sns.FacetGrid class, the 'hue' argument is the one that is the one that will be represented by colors with 'palette'
# g = sns.FacetGrid(
#     pd.DataFrame(dict(count=np.concatenate([v1, v2]), row=np.repeat(["T3", "PP"], (v1.size, v2.size)))), 
#     row="row", 
#     hue="row", 
#     aspect=5, 
#     height=3, 
#     palette=pal)

# # then we add the densities kdeplots for each month
# g.map(sns.kdeplot, 'count',
#       bw_adjust=1, clip_on=False,
#       fill=True, alpha=1, linewidth=1.5)

# # here we add a white line that represents the contour of each kdeplot
# g.map(sns.kdeplot, 'count', 
#       bw_adjust=1, clip_on=False, 
#       color="w", lw=2)

# # here we add a horizontal line for each plot
# g.map(plt.axhline, y=0,
#       lw=2, clip_on=False)

# # # we loop over the FacetGrid figure axes (g.axes.flat) and add the month as text with the right color
# # # notice how ax.lines[-1].get_color() enables you to access the last line's color in each matplotlib.Axes
# # for i, ax in enumerate(g.axes.flat):
# #     ax.text(-15, 0.02, month_dict[i+1],
# #             fontweight='bold', fontsize=15,
# #             color=ax.lines[-1].get_color())
    
# # we use matplotlib.Figure.subplots_adjust() function to get the subplots to overlap
# g.fig.subplots_adjust(hspace=-0.3)

# # # eventually we remove axes titles, yticks and spines
# g.set_titles("")
# g.set(yticks=[])
# g.despine(bottom=True, left=True)

# plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold')
# # plt.xlabel('Temperature in degree Celsius', fontweight='bold', fontsize=15)
# # g.fig.suptitle('Daily average temperature in Seattle per month',
# #                ha='right',
# #                fontsize=20,
# #                fontweight=20)

# plt.show()

In [None]:
# Average of every cell-type-pair column per timepoint (the sample index is dropped).
mean = df.groupby("timepoint").mean().drop(columns="sample")
# Difference of the timepoint means, T3 minus PP, per cell-type pair.
mean_diff = mean.loc["T3",:] - mean.loc["PP",:]

In [None]:
mean_grouped = mean.copy()
# turn the tuple labels into a two-level column index (first / second cell type)
mean_grouped.columns = pd.MultiIndex.from_tuples(mean.columns)
# Move the second cell type into the rows and sum it out again: one total per
# timepoint and first cell type.
mean_grouped = mean_grouped.stack().groupby("timepoint").sum()
# scale every timepoint row by its largest total, so that the maximum becomes 1
mean_grouped = mean_grouped.div(mean_grouped.max(axis=1).values, axis=0)

In [None]:
# Population standard deviation (ddof=0) of every cell-type-pair column per timepoint.
std = df.groupby("timepoint").std(ddof=0).drop(columns="sample")
# Pooled standard deviation of T3 and PP: root of the average of both variances.
std_diff = np.sqrt((std.loc["T3",:]**2 + std.loc["PP",:]**2) / 2)

In [None]:
# Mean difference over pooled standard deviation (Cohen's d) for every cell-type pair ...
effect_size_matrix = mean_diff.div(std_diff)
# ... reshaped from a series indexed by pairs into a cell type x cell type matrix.
effect_size_matrix.index = pd.MultiIndex.from_tuples(effect_size_matrix.index)
effect_size_matrix= effect_size_matrix.unstack()

In [None]:
(np.abs(effect_size_matrix) >= 1.2).sum().sum()

In [None]:
# Spot check of a single column: histograms and Shapiro-Wilk normality tests per timepoint.
# `i` is a column position in `df`, counted including the leading "sample" and
# "timepoint" columns. NOTE(review): `tp` is not used in this cell.
i = 17; tp="T3"
df.iloc[(df["timepoint"] == "T3").values,i].hist(alpha=0.5)
df.iloc[(df["timepoint"] == "PP").values,i].hist(alpha=0.5)
print(scipy.stats.shapiro(df.iloc[(df["timepoint"] == "T3").values,i]))
print(scipy.stats.shapiro(df.iloc[(df["timepoint"] == "PP").values,i]))

In [None]:
# Mask for the heatmaps below: 1 hides a cell, 0 shows it. Only pairs that passed the
# effect-size filter are shown.
# NOTE(review): only the (c1, c2) orientation is looked up here, whereas the KDE grid
# above also accepts (c2, c1) — confirm that this asymmetry is intended.
msk = np.ones((len(cell_type_order), len(cell_type_order)))
for i1, c1 in enumerate(cell_type_order):
    for i2, c2 in enumerate(cell_type_order):
        if (c1, c2) in differential_cell_pairs_effect_size_map:
            msk[i1,i2] = 0

In [None]:
# symmetric colour range around zero, spanning the largest absolute effect size
max_value = np.nanmax(np.abs(effect_size_matrix.values))
# rows and columns in display order
m = effect_size_matrix.loc[cell_type_order, cell_type_order]
# only cells with a 0 in `msk` are drawn
sns.heatmap(
    m, 
    center=0, vmin=-max_value, vmax=max_value, 
    mask=msk)

In [None]:
# symmetric colour range around zero, spanning the largest absolute effect size
max_value = np.nanmax(np.abs(effect_size_matrix.values))
# Rows and columns in display order with missing effect sizes replaced by 0
# (presumably because the clustering needs finite values — confirm).
m = effect_size_matrix.loc[cell_type_order, cell_type_order].fillna(0)
sns.clustermap(
    m,
    mask=msk,
    center=0,
    vmin=-max_value,
    vmax=max_value)

In [None]:
%matplotlib inline
# Distribution of the counts of one cell-type pair per timepoint ...
# pair = ("Bcells", "Bcells")
pair = ("Bcells", "CD16+CD56-NKcells")
sns.kdeplot(df[pair], hue=df.timepoint)

# ... and distribution of the T3 - PP differences. Rows are paired by position, which
# requires the same number of T3 and PP rows.
plt.figure()
diff = df[df.timepoint == "T3"][pair].values - df[df.timepoint == "PP"][pair].values
sns.kdeplot(diff, linewidth=1)
plt.axvline(0, color="black")
# dashed lines: 1%/99% (red), 5%/95% (default colour) and 10%/90% (green) quantiles
plt.axvline(np.quantile(diff, 0.01), linestyle="--", color="red")
plt.axvline(np.quantile(diff, 0.99), linestyle="--", color="red")
plt.axvline(np.quantile(diff, 0.05), linestyle="--")
plt.axvline(np.quantile(diff, 0.95), linestyle="--")
plt.axvline(np.quantile(diff, 0.10), linestyle="--", color="green")
plt.axvline(np.quantile(diff, 0.90), linestyle="--", color="green")

In [None]:
# Loop-invariant settings.
n_bins = 10  # number of trailing histogram bins to show
sample = -2  # index into `topk_stats_samples` of the sample that is plotted

for c,v in differential_cell_pairs_effect_size_map.items():
    print(c, v, mean_diff[c], mean[c])
    print()

    i, j = np.argwhere(cell_types == c[0]).squeeze(), np.argwhere(cell_types == c[1]).squeeze()
    # If this orientation holds no counts, use the transposed entry. The check uses the
    # same sample that is plotted below (it previously inspected a different sample).
    if np.sum(topk_stats_samples[sample]["T3"]["histograms"][i,j][-n_bins:]) == 0:
        i, j = j, i
    
    fig, axes = plt.subplots(1,2,figsize=(8,4), sharey=True)

    values_t3 = topk_stats_samples[sample]["T3"]["histograms"][i,j][-n_bins:]
    values_pp = topk_stats_samples[sample]["PP"]["histograms"][i,j][-n_bins:]
    
    # blue: the plotted sample agrees in sign with the mean difference; red: it does not
    if np.sign(sum(values_t3) - sum(values_pp)) == np.sign(mean_diff[c]):
        color = "blue"
    else:
        color = "red"

    # bars are placed at the left edges of the last `n_bins` bins
    ax = axes[0]
    ax.bar(bins[:-1][-n_bins:], values_t3, width=0.01, color=color)
    ax.set_title(f"T3: {values_t3.sum()}")

    ax = axes[1]
    ax.bar(bins[:-1][-n_bins:], values_pp, width=0.01, color=color)
    ax.set_title(f"PP: {values_pp.sum()}")
    fig.suptitle(f"{cell_types[i]} / {cell_types[j]}")

    plt.show()
    # release the figure; one is created per differential pair
    plt.close(fig)

In [None]:
cell_type_order

In [None]:
# Base palette: 32 diverging colours (alternative tried before: "Spectral").
_base_palette = sns.diverging_palette(145, 300, s=60, n=32)
neutral = [(.7,.7,.7)]
# Layout: 2 neutral, the first 6 base colours, 1 neutral, the last 11 base colours.
color_palette = neutral * 2 + _base_palette[0:6] + neutral + _base_palette[-11:]
sns.palplot(color_palette)

In [None]:
# color_palette = sns.color_palette("deep", n_colors=26)
# sns.palplot(color_palette)

In [None]:
# derive circle coordinates
import math

# one anchor point per cell type, evenly spaced on the unit circle, starting at angle 0
n = len(cell_type_order)
circle_coordinates = np.stack([
    (math.cos(2 * math.pi * i / n), math.sin(2 * math.pi * i / n))
    for i in range(n)])

In [None]:
# derive cell coordinates
# Result: cell_coordinates[subgroup][cell_type] -> array of plot coordinates.
cell_coordinates = dict()
for i_cell_type, cell_type in enumerate(cell_type_order):

    # Scale the embeddings of all subgroups jointly, so that subgroups of one cell type
    # share a coordinate frame.
    cell_type_subgroup_sizes = [cells_phenotype_emb[s][cell_type].shape[0] for s in subgroups]
    cell_type_coordinates = np.concatenate([cells_phenotype_emb[s][cell_type] for s in subgroups])
    # rescale every embedding dimension to [-0.1, 0.1] ...
    cell_type_coordinates = sklearn.preprocessing.MinMaxScaler(feature_range=(-1,1)).fit_transform(cell_type_coordinates) / 10
    # ... and centre the cloud on this cell type's anchor point on the circle
    cell_type_coordinates += circle_coordinates[i_cell_type]
    
    # split the stacked coordinates back into per-subgroup blocks
    offset = 0
    for i_subgroup, s in enumerate(subgroups):
        end = offset + cell_type_subgroup_sizes[i_subgroup]
        cell_coordinates.setdefault(s, dict())[cell_type] = cell_type_coordinates[offset:end,:]
        offset = end

In [None]:
# transparent background
# matplotlib.rcParams['figure.facecolor'] = 'white'
# matplotlib.rcParams['figure.facecolor'] = 'none'

In [None]:
import sklearn.pipeline

def get_cell_colors(marker_category="function", marker=0, color_map="inferno"):
    """Map one marker's values to colours for all embedded cells.

    Parameters
    ----------
    marker_category : {"pheno", "function"}
        Selects `cells_phenotype` or `cells_function` as the data source.
    marker : int
        Column index of the marker within the selected data.
    color_map : str or callable
        Colormap name, or a colormap object mapping values in [0, 1] to colours.

    Returns
    -------
    dict
        ``colors[subgroup][cell_type]`` -> array of colours, one per cell selected by
        `cells_phenotype_emb_idx`.

    Raises
    ------
    ValueError
        If `marker_category` is unknown.
    """
    if marker_category == "pheno":
        cells = cells_phenotype
    elif marker_category == "function":
        cells = cells_function
    else:
        raise ValueError(f"Unknown marker category: {marker_category}")

    if isinstance(color_map, str):
        # `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and later removed;
        # `plt.get_cmap` resolves names on old and new versions alike.
        color_map = plt.get_cmap(color_map)

    # Normalise globally: the min/max are taken over all subgroups and cell types,
    # so colours are comparable between panels.
    all_values = np.concatenate([
        cells[subgroup][cell_type][:,marker].flatten()
        for subgroup in cells.keys()
        for cell_type in cells[subgroup].keys()])
    norm = sklearn.preprocessing.MinMaxScaler().fit(all_values.reshape(-1,1))

    colors = {}
    for cell_type in cell_type_order:
        for s in subgroups:
            # restrict to the cells that were embedded for visualization
            shown_values = cells[s][cell_type][cells_phenotype_emb_idx[s][cell_type],marker].flatten()
            shown_colors = color_map(norm.transform(shown_values.reshape(-1,1)).flatten())
            colors.setdefault(s, {}).setdefault(cell_type, shown_colors)

    return colors

cell_colors = get_cell_colors()

In [None]:
from matplotlib.path import Path
import matplotlib

def plot_cells(subgroup, ax=None, title=None, cell_colors="grey"):
    """Scatter the embedded cells of one subgroup, one point cloud per cell type.

    Parameters
    ----------
    subgroup : key into `cell_coordinates`.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes.
    title : str, optional
        Axes title.
    cell_colors : str, list or dict
        * str: one colour for all cells,
        * list: one colour per cell type, in `cell_type_order`,
        * dict: per-cell colours as ``cell_colors[subgroup][cell_type]``.

    Raises
    ------
    ValueError
        If `cell_colors` has an unsupported type.
    """
    if ax is None:
        ax = plt.gca()

    # nodes
    for i_order, cell_type in enumerate(cell_type_order):

        emb = cell_coordinates[subgroup][cell_type]

        if isinstance(cell_colors, str):
            colors = cell_colors

        elif isinstance(cell_colors, list):
            # wrap in an array so that a single RGB(A) tuple is not read as one value per point
            colors = np.array([cell_colors[i_order]])

        elif isinstance(cell_colors, dict):
            colors = cell_colors[subgroup][cell_type]

        else:
            raise ValueError(
                f"`cell_colors` must be a str, list or dict, got {type(cell_colors).__name__}")

        ax.scatter(emb[:,0], emb[:,1], s=3, zorder=100, color=colors, alpha=0.99, linewidths=0)

    if title is not None:
        ax.set_title(title) 

In [None]:
def plot_circle(sizes, cell_colors=None, ax=None, divider_with=0.5, divider_thickness=None, thickness=5, radius=1):
    """Draw a ring of arc-shaped bars, one angular segment per entry of ``sizes``.

    The circle is split into ``len(sizes)`` equal segments. Each positive entry
    gets a short "dimgrey" divider tick centred on the segment's start angle and
    a colored bar that ends at that angle and extends backwards over up to half
    a segment (bar length is proportional to ``size / max(sizes)``).
    Non-positive entries draw nothing.

    Parameters
    ----------
    sizes : sequence of numbers, one per segment.
    cell_colors : sequence of colors indexed like ``sizes``; ``None`` draws all
        bars in "grey".
    ax : matplotlib axes; a new 6x6 figure with hidden axis is created when ``None``.
    divider_with : angular width (degrees) of the divider tick.
    divider_thickness : line width of the divider tick; defaults to ``thickness``.
    thickness : line width of the bars.
    radius : radius of the ring in data coordinates.
    """
    if divider_thickness is None:
        divider_thickness = thickness

    if ax is None:
        fig, ax = plt.subplots(figsize=(6,6))
        ax.axis("off")

    sizes = np.asarray(sizes, dtype=float)
    if sizes.size == 0:
        return

    largest = np.max(sizes)
    if largest == 0:
        # nothing could be drawn; also avoids the 0/0 division warnings
        return
    sizes_freq = sizes / largest

    segment_angle = 360 / len(sizes)
    diameter = radius * 2
    offset = 0
    for i, s in enumerate(sizes_freq):

        if s > 0:
            bar_color = "grey" if cell_colors is None else cell_colors[i]

            divider = matplotlib.patches.Arc(
                (0,0), diameter, diameter,
                theta1=offset - divider_with * 0.5, theta2=offset + divider_with * 0.5,
                linewidth=divider_thickness, color="dimgrey")
            ax.add_patch(divider)

            bar = matplotlib.patches.Arc(
                (0,0), diameter, diameter,
                theta1=offset - segment_angle * s / 2, theta2=offset,
                linewidth=thickness, color=bar_color)
            ax.add_patch(bar)

        offset += segment_angle
        

In [None]:
def plot_circle_size(cell_colors=None, ax=None, divider_with=0.5, divider_thickness=None, thickness=5, radius=1, subgroup_include=None):
    """Ring plot of the number of cells per cell type.

    For every entry of ``cell_type_order`` the row counts of
    ``cells_phenotype[s][cell_type]`` are summed over all ``s`` in ``subgroups``
    that pass ``subgroup_include`` (a predicate on the subgroup key; ``None``
    keeps every subgroup). The totals are handed to ``plot_circle`` together
    with the unchanged styling arguments.
    """
    def _keep(group):
        return subgroup_include is None or subgroup_include(group)

    totals = []
    for cell_type in cell_type_order:
        counts = [cells_phenotype[group][cell_type].shape[0] for group in subgroups if _keep(group)]
        totals.append(np.sum(counts))
    sizes = np.array(totals)

    plot_circle(
        sizes,
        cell_colors=cell_colors,
        ax=ax,
        divider_with=divider_with,
        divider_thickness=divider_thickness,
        thickness=thickness,
        radius=radius)

In [None]:
def plot_circle_n_edges(subgroup, cell_colors=None, ax=None, divider_with=0.5, divider_thickness=None, thickness=5, radius=1):
    """Ring plot of the ``mean_grouped`` row of ``subgroup``.

    The values are read from ``mean_grouped.loc[subgroup, cell_type_order]``
    (so they follow the display order of the cell types) and drawn with
    ``plot_circle`` using the given styling arguments.
    """
    ring_style = dict(
        cell_colors=cell_colors,
        ax=ax,
        divider_with=divider_with,
        divider_thickness=divider_thickness,
        thickness=thickness,
        radius=radius)
    per_type_values = mean_grouped.loc[subgroup, cell_type_order].values
    plot_circle(per_type_values, **ring_style)

In [None]:
def plot_circle_intra_edges(subgroup, cell_colors=None, ax=None, divider_with=0.5, divider_thickness=None, thickness=5, radius=1):
    """Ring plot of the within-cell-type values of one subgroup.

    For every cell type ``c`` in ``cell_type_order`` the bar size is
    ``mean.loc[subgroup, :][(c, c)]`` when the pair ``(c, c)`` is a key of
    ``differential_cell_pairs_effect_size_map`` and ``0`` otherwise. The sizes
    are drawn with ``plot_circle`` using the given styling arguments.

    The sizes are no longer printed; the function has no console output.
    """
    subgroup_means = mean.loc[subgroup, :]
    sizes = np.array([
        subgroup_means[(c, c)] if (c, c) in differential_cell_pairs_effect_size_map else 0
        for c in cell_type_order])
    plot_circle(sizes, cell_colors=cell_colors, ax=ax, divider_with=divider_with, divider_thickness=divider_thickness, thickness=thickness, radius=radius)

In [None]:
def plot_labels(subgroup, ax=None, title=None, cell_colors=None, divider_with=0.5, divider_thickness=None, thickness=5, radius=1, stagger=None, rename=None):
    """Write the cell-type names as curved text around a circle.

    The circle is divided into ``len(cell_type_order)`` equal angular slots.
    Labels are placed in order, moving in the direction of decreasing angle with
    y mirrored (``y = -sin``). Each label follows a 100-point arc that starts
    slightly before its slot angle (shifted by ``len(label) / 2 * 0.025`` rad)
    and spans two slots; the text itself is laid out by ``CurvedText``, which
    adds itself to ``ax``.

    Parameters
    ----------
    subgroup, title, divider_with, divider_thickness, thickness :
        accepted for signature compatibility with the other plotting helpers;
        they do not influence the labels.
    ax : matplotlib axes; a new 24x24 figure with hidden axis is created when ``None``.
    cell_colors : sequence of colors. Only the entry at index 2 is used, for
        every label, so it must contain at least three entries.
    radius : radius of the label circle in data coordinates.
    stagger : optional ``(frequency, offset)`` pair. Label ``i`` is pushed
        outwards by the factor ``1 + (i % frequency) * offset`` so that
        neighbouring labels do not overlap.
    rename : optional mapping from cell-type name to the text to display.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(24,24))
        ax.axis("off")

    n_labels = len(cell_type_order)
    if n_labels == 0:
        return

    segment_angle = 2 * np.pi / n_labels
    offset = 0.0
    for i, cell_type in enumerate(cell_type_order):

        label = cell_type if rename is None else rename[cell_type]

        # start a little before the slot angle, proportionally to the label length
        angles = np.linspace(offset - (len(label) / 2 * 0.025), offset + segment_angle * 2, 100)
        xx = np.cos(angles) * radius
        yy = -np.sin(angles) * radius

        if stagger:
            stagger_freq, stagger_offset = stagger
            scale = 1 + ((i % stagger_freq) * stagger_offset)
            xx *= scale
            yy *= scale

        CurvedText(
            x = xx,
            y = yy,
            text=label,
            va = 'bottom',
            ha="center",
            axes = ax, ##calls ax.add_artist in __init__
            color=cell_colors[2],
            fontsize=13
        )

        offset -= segment_angle
        
        

In [None]:
def plot_summary_circle_intracell(ax=None, colors=["red", "blue"], divider_with=0.5, divider_thickness=None, thickness=5, radius=1, stagger=None):
    """Ring of bars showing the within-cell-type effect sizes.

    For every cell type ``c`` in ``cell_type_order`` the effect size is
    ``differential_cell_pairs_effect_size_map[(c, c)]`` (``0`` when the pair is
    not a key). Bars are scaled by the largest absolute effect size and drawn as
    arcs ending at the start angle of the cell type's segment, covering at most
    half a segment. Nothing is drawn when every effect size is zero.

    Parameters
    ----------
    ax : matplotlib axes; a new 6x6 figure with hidden axis is created when ``None``.
    colors : two colors; ``colors[0]`` is used when ``mean_diff[(c, c)] < 0``,
        ``colors[1]`` otherwise.
    divider_with, divider_thickness : accepted for signature compatibility; not used.
    thickness : line width of the bars.
    radius : radius of the ring in data coordinates.
    stagger : optional ``(frequency, offset)`` pair; bar ``i`` is drawn at
        radius ``radius * (1 + (i % frequency) * offset)``.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6,6))
        ax.axis("off")

    effect_size = np.array([
        differential_cell_pairs_effect_size_map[(c,c)]
        if (c,c) in differential_cell_pairs_effect_size_map else 0
        for c in cell_type_order])

    magnitudes = np.abs(effect_size)
    # Effect sizes are signed, so test the magnitudes: a plain sum can cancel
    # to zero (e.g. +0.5 and -0.5) although there is something to draw.
    if not np.any(magnitudes > 0):
        return

    bar_sizes = magnitudes / np.max(magnitudes)

    segment_angle = 360 / len(bar_sizes)
    offset = 0
    for i, s in enumerate(bar_sizes):

        start = offset - segment_angle * s / 2
        end = offset

        cell_type = cell_type_order[i]
        weight = mean_diff[(cell_type, cell_type)] # sign selects the color
        edgecolors = colors[0] if weight < 0 else colors[1]

        if stagger:
            stagger_freq, stagger_offset = stagger
            stagger_step = i % stagger_freq
            rr = radius * (1 + (stagger_step * stagger_offset))
        else:
            rr = radius

        a1 = matplotlib.patches.Arc((0,0), rr*2, rr*2, theta1=start,theta2=end, linewidth=thickness,color=edgecolors, alpha=0.7)
        ax.add_patch(a1)

        offset += segment_angle

In [None]:
from matplotlib.path import Path

def plot_summary(s, colors=["red", "blue"], edges_kwargs=None, cell_type_pair_include=None, ax=None, linewidth_scaling=None):
    """Draw one curved summary edge per selected pair of cell types.

    Every unordered pair of distinct cell types (taken in ``cell_type_order``
    order) that passes ``cell_type_pair_include`` is connected by a quadratic
    Bezier curve. The curve runs between the pair's positions in
    ``circle_coordinates`` (both shrunk by the factor ``1 - 0.15``) and uses the
    origin as control point, so it bends through the centre of the circle.

    Parameters
    ----------
    s : accepted for signature compatibility; not used by the drawing code.
    colors : two colors used by the default edge style (see below).
    edges_kwargs : callable ``(cell_type_1, cell_type_2) -> dict`` returning the
        keyword arguments of the ``PathCollection`` for that pair. When ``None``
        a default style is used: the line width grows with
        ``abs(effect size) - cliffs_delta_threshold`` and color and line style
        follow the sign of ``mean_diff[(cell_type_1, cell_type_2)]``
        (``colors[0]`` and solid when negative, ``colors[1]`` and dotted otherwise).
    cell_type_pair_include : optional predicate ``(cell_type_1, cell_type_2) -> bool``.
    ax : matplotlib axes; a new 6x6 figure with hidden axis is created when ``None``.
    linewidth_scaling : accepted for signature compatibility; not used.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6,6))
        ax.axis("off")

    if edges_kwargs is None:
        def edges_kwargs(ct1, ct2):
            # the effect-size map may store the pair in either order
            if (ct1, ct2) in differential_cell_pairs_effect_size_map:
                p = differential_cell_pairs_effect_size_map[(ct1, ct2)]
            else:
                p = differential_cell_pairs_effect_size_map[(ct2, ct1)]

            linewidth = (abs(p) - cliffs_delta_threshold) * 60 + 0.5

            negative = mean_diff[(ct1, ct2)] < 0 # for direction (more or less)
            edgecolors = colors[0] if negative else colors[1]
            linestyle = "-" if negative else (0, (1,0.5))

            return dict(linewidths=linewidth, facecolor='none', edgecolors=edgecolors, alpha=0.7, linestyle=linestyle)

    # A quadratic Bezier segment is MOVETO (start) followed by two CURVE3
    # vertices (control point, end point).
    codes = [
        Path.MOVETO,
        Path.CURVE3,
        Path.CURVE3,
    ]

    for i_order, cell_type_i in enumerate(cell_type_order):

        for j_order, cell_type_j in enumerate(cell_type_order):

            # each unordered pair once, no self-pairs
            if i_order >= j_order:
                continue

            if (cell_type_pair_include is not None) and (not cell_type_pair_include(cell_type_i, cell_type_j)):
                continue

            src = circle_coordinates[i_order] * (1 - 0.15)
            dst = circle_coordinates[j_order] * (1 - 0.15)

            paths = [Path([src, (0,0), dst], codes)]

            c = matplotlib.collections.PathCollection(paths, **edges_kwargs(cell_type_i, cell_type_j))
            ax.add_collection(c)

In [None]:
# color_palette = sns.color_palette("Spectral", n_colors=32)
# color_palette = sns.diverging_palette(145, 300, s=60, n=32)

# innate_colors =  color_palette[:6]
# adaptive_colors = color_palette[-11:]

# innate_colors =  list(reversed(sns.color_palette("ch:s=-.2,r=.6", n_colors=6))) 
# adaptive_colors = list(reversed(sns.cubehelix_palette(start=.5, rot=-.75, n_colors=11)))

# innate_colors =  list(reversed(sns.color_palette("ch:s=-.2,r=.6", n_colors=6)))[3:5] * 6
# adaptive_colors = list(reversed(sns.cubehelix_palette(start=.5, rot=-.75, n_colors=11)))[6:8] * 11
# innate_colors =  [sns.color_palette()[0]] * 6
# adaptive_colors = [sns.color_palette()[1]] * 11

# Palette layout (20 entries): 2 neutral, 6 innate, 1 neutral, 11 adaptive.
# presumably the entries are indexed like cell_type_order - TODO confirm.

# first 6 colors of an 18-step diverging palette
innate_colors =  sns.diverging_palette(250, 30, l=75, center="dark", n=18)[:6]
# last 11 colors of a 32-step diverging palette, in reversed order
adaptive_colors = list(reversed(sns.diverging_palette(250, 30, l=70, center="dark", n=32)[-11:]))

# single-element lists so they can be repeated / concatenated below
neutral = [(.6,.6,.6)]
neutral_text = [(.55,.55,.55)]
# neutral = [np.array([199,187,201]) / 255]
# neutral_text = [np.array([199/1.5,187/1.5,201/1.5]) / 255]
# neutral = [np.array([191,181,178]) / 255]
# neutral = [(0.2,0.7,0.7)]
# neutral = [(0.0,0.5,0.5)]
# NOTE(review): both names are bound to the same list here and are reassigned
# by the cells that follow, so this mainly serves as a palette preview.
color_palette_labels = color_palette_cells = neutral * 2 + innate_colors + neutral + adaptive_colors
sns.palplot(color_palette_labels)

In [None]:
# Marker colors passed to plot_cells: 2 neutral, 6 near-white, 1 neutral, 11 dark grey
# (same 2/6/1/11 layout as the palette built from innate_colors / adaptive_colors).
color_palette_cells = neutral * 2 + [(0.95,0.95,0.95)] * 6 + neutral + [(0.4,0.4,0.4)] * 11
sns.palplot(color_palette_cells)

In [None]:
# Label palette variant based on contrast_colors (defined elsewhere in the notebook).
# NOTE(review): color_palette_labels is reassigned by the following cells, so on a
# top-to-bottom run only the palplot preview of this cell has a lasting effect.
color_palette_labels = neutral * 2 + [contrast_colors[1]] * 6 + neutral + [contrast_colors[0]] * 11
sns.palplot(color_palette_labels)

In [None]:
# Label palette variant using the same greys as color_palette_cells.
# NOTE(review): color_palette_labels is reassigned again by the next cell, so on a
# top-to-bottom run only the palplot preview of this cell has a lasting effect.
color_palette_labels = neutral * 2 + [(0.95,0.95,0.95)] * 6 + neutral + [(0.4,0.4,0.4)] * 11
sns.palplot(color_palette_labels)

In [None]:
# Last assignment of color_palette_labels before the figures are drawn:
# 2 neutral_text, 6 dark grey, 1 neutral_text, 11 light grey.
# NOTE(review): plot_labels only reads the entry at index 2 of the palette it receives.
color_palette_labels = neutral_text * 2 + [(0.2,0.2,0.2)] * 6 + neutral_text + [(0.75,0.75,0.75)] * 11
sns.palplot(color_palette_labels)

In [None]:
def plot_difference(cell_type_pair_include, edge_style, do_plot_size=True, linewidth_scaling_edges=None, linewidth_scaling_summary=None, ax=None, rename=None, subgroup=None):
    """Compose the "difference" figure on a single axes.

    Draws, in this order: the embedded cells (``plot_cells``), one summary edge
    per selected cell-type pair (``plot_summary``), the ring of
    within-cell-type effect sizes (``plot_summary_circle_intracell``) and the
    curved cell-type labels (``plot_labels``). Edge and ring colors come from
    the notebook-level ``contrast_colors``; marker and label colors from
    ``color_palette_cells`` and ``color_palette_labels``.

    Parameters
    ----------
    cell_type_pair_include : predicate ``(cell_type_1, cell_type_2) -> bool``
        selecting the pairs that get a summary edge.
    edge_style, do_plot_size, linewidth_scaling_edges :
        accepted for signature compatibility; not used.
    linewidth_scaling_summary : forwarded to ``plot_summary``.
    ax : matplotlib axes shared by all layers; a new 6x6 figure with hidden
        axis is created when ``None`` so that the layers end up on one axes.
    rename : optional mapping cell-type name -> label text, forwarded to ``plot_labels``.
    subgroup : subgroup whose cell coordinates are drawn. When ``None`` the
        notebook-level variable ``s`` is used, which is what this function
        always did; pass it explicitly to avoid depending on that hidden state.
    """
    if subgroup is None:
        # backward compatibility: fall back to the global `s` the original code relied on
        subgroup = s

    if ax is None:
        fig, ax = plt.subplots(figsize=(6,6))
        ax.axis("off")

    colors = contrast_colors

    plot_cells(subgroup, title=None, cell_colors=color_palette_cells, ax=ax)
    plot_summary(subgroup, colors=colors, ax=ax, cell_type_pair_include=cell_type_pair_include, linewidth_scaling=linewidth_scaling_summary)
    plot_summary_circle_intracell(ax=ax, colors=colors, radius=1+0.16)
    plot_labels("T3", ax=ax, radius=1+0.19, stagger=(2,0.05), rename=rename, cell_colors=color_palette_labels)

In [None]:
# color_palette = sns.color_palette("deep", n_colors=26)
# sns.palplot(color_palette)

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt 
from matplotlib.lines import Line2D


# Stand-alone legend for the edge colors / line styles, saved as its own PDF.
# The handles are proxy lines (markersize=0, so only the line is shown).
patchList = []
# solid line in contrast_colors[0]
data_key = Line2D(
    [0], [0], marker="o", color=contrast_colors[0], linestyle="-", alpha=0.7,
    linewidth=8,
    label="increasing", 
    markersize=0)
patchList.append(data_key)

# dotted line in contrast_colors[1]
data_key = Line2D(
    [0], [0], marker="o", color=contrast_colors[1], linestyle=(0, (1,0.5)), alpha=0.7,
    linewidth=8,
    label=f"decreasing", 
    markersize=0)
patchList.append(data_key)

fig, axes = plt.subplots(1,1, dpi=300, figsize=(2,2))
ax = axes
leg = ax.legend(handles=patchList, loc="upper left", frameon=False, title="relative number of top correlations", facecolor=(1,1,1,0.3), fancybox=False, edgecolor='None', borderaxespad=0.08)
# NOTE(review): _legend_box is a private matplotlib attribute; used here to left-align the title
leg._legend_box.align = "left"
leg.get_title().set_position((-5, 0))

ax.axis("off")
# NOTE(review): looks like a debugging leftover - the axis is switched off and the
# figure is saved with facecolor='none'; confirm and remove
ax.set_facecolor("green")
# NOTE(review): hard-coded relative path; the main figures are saved via path_figures
fig.savefig('../_out/figures/legend_singlecell_change.pdf', bbox_inches='tight', facecolor='none')

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt 
from matplotlib.lines import Line2D

# Stand-alone legend for the two cell-type groups, saved as its own PDF.
# The handles are proxy lines (markersize=0, so only the line is shown).
patchList = []

# NOTE(review): the legend swatch is (0.95,0.95,0.95) while the light band drawn on
# the figures uses (0.9,0.9,0.9) - confirm the small mismatch is intended
data_key = Line2D(
    [0], [0], marker="o", color=(0.95,0.95,0.95), linestyle="-", alpha=0.7,
    linewidth=8,
    label="adaptive", 
    markersize=0)
patchList.append(data_key)

data_key = Line2D(
    [0], [0], marker="o", color=(0.1,0.1,0.1), linestyle="-", alpha=0.7,
    linewidth=8,
    label=f"innate", 
    markersize=0)
patchList.append(data_key)

fig, axes = plt.subplots(1,1, dpi=300, figsize=(2,2))
ax = axes
leg = ax.legend(handles=patchList, loc="upper left", frameon=False, title="cell types", facecolor=(1,1,1,0.3), fancybox=False, edgecolor='None', borderaxespad=0.08)
# NOTE(review): _legend_box is a private matplotlib attribute; used here to left-align the title
leg._legend_box.align = "left"
leg.get_title().set_position((-5, 0))

ax.axis("off")
# NOTE(review): hard-coded relative path; the main figures are saved via path_figures
fig.savefig('../_out/figures/legend_singlecell_cells.pdf', bbox_inches='tight', facecolor='none')

In [None]:
%matplotlib inline
# Difference figure: cells, summary edges, intra-cell-type ring and labels on one axes.
# NOTE(review): plot_difference accepts edge_style but does not use it.
edge_style = dict(facecolor='none', edgecolors="grey", alpha=0.5)
# linewidth_scaling_edges = lambda n_edges: max(min(1 / (n_edges / 20 + 1), 1), 0.01) # gotta play with this
# linewidth_scaling_summary = lambda p: max(3, np.log10(p) / np.log10(0.05) / 4)
# keep a pair when it is a key of the effect-size map in either order
cell_type_pair_include = lambda x,y: ((x,y) in differential_cell_pairs_effect_size_map.keys()) or ((y,x) in differential_cell_pairs_effect_size_map.keys())

fig, ax = plt.subplots(1,1,figsize=(13,13), dpi=300); ax.axis("off")
plot_difference(cell_type_pair_include, edge_style, ax=ax, rename=rename)

# Extend the data limits to the axes' window extent (converted to data
# coordinates) and rescale; a later cell labels the same steps "show all labels".
bbox = ax.get_window_extent()
bbox_data = bbox.transformed(ax.transData.inverted())
ax.update_datalim(bbox_data.corners())
ax.autoscale_view()

# Two thick background bands on the unit circle: dark from 28 to 135 degrees,
# light from 153 to 350 degrees.
# presumably these mark the innate / adaptive groups shown in the cell-type legend - TODO confirm
radius = 1
a1 = matplotlib.patches.Arc((0,0), radius*2, radius*2, 
                            theta1=28, theta2=135, 
#                             theta1=-8, theta2=135, 
                            linewidth=65, 
#                             color=contrast_colors[1], 
                            color=(0.1,0.1,0.1), 
                            alpha=0.7)
ax.add_patch(a1)
a1 = matplotlib.patches.Arc((0,0), radius*2, radius*2, 
                            theta1=153, theta2=350, 
#                             theta1=137, theta2=350, 
                            linewidth=65,
#                             color=contrast_colors[0], 
                            color=(0.9,0.9,0.9), 
                            alpha=0.7)
ax.add_patch(a1)

fig.savefig(path_figures / "cell_correlations_difference.pdf", bbox_inches="tight", facecolor='none')

## Edges

In [None]:
def plot_subgroup(subgroup, edge_style=None, cell_type_pair_include=None, ax=None, linewidth_scaling=None):
    """Compose the per-subgroup figure on ``ax`` and return the edge line widths.

    Layers, in drawing order: the cell-to-cell edges (``plot_edges``), the
    embedded cells (``plot_cells``), a "silver" ring of within-cell-type values
    (``plot_circle_intra_edges``) and the curved labels (``plot_labels``, always
    called with "T3").

    Returns the dict produced by ``plot_edges`` (line width per cell-type pair),
    which can be passed back in as ``linewidth_scaling`` for another subgroup.
    """
    # edges go first so that the later layers are added after them
    edge_linewidths = plot_edges(
        subgroup,
        ax=ax,
        edges_kwargs=edge_style,
        cell_type_pair_include=cell_type_pair_include,
        linewidth_scaling=linewidth_scaling)

    plot_cells(subgroup, title=None, cell_colors=color_palette_cells, ax=ax)

    # one identical color per cell type for the intra-cell-type ring
    ring_colors = ["silver"] * len(cell_type_order)
    plot_circle_intra_edges(subgroup, cell_colors=ring_colors, ax=ax, radius=1+0.17)

    plot_labels("T3", ax=ax, radius=1+0.19, stagger=(2,0.05), rename=rename, cell_colors=color_palette_labels)

    return edge_linewidths

In [None]:

def plot_edges(subgroup, ax=None, edges_kwargs=None, cell_type_pair_include=None, linewidth_scaling=None):
    """Draw the top-k edges between the displayed cells of one subgroup.

    For every unordered pair of distinct cell types (in ``cell_type_order``
    order) that passes ``cell_type_pair_include``, the matching block of
    ``topk_matrices[subgroup]`` is combined with its transposed mirror block and
    restricted to the cells selected by ``cells_phenotype_emb_idx``. Every
    non-zero entry becomes a quadratic Bezier curve from the first cell to the
    second one, with the origin as control point. Coordinates come from
    ``cell_coordinates[subgroup]``.

    Parameters
    ----------
    subgroup : key into the notebook-level per-subgroup containers.
    ax : matplotlib axes; a new 6x6 figure with hidden axis is created when ``None``.
    edges_kwargs : dict of ``PathCollection`` keyword arguments (must not
        contain ``linewidths``); defaults to grey, half-transparent, unfilled.
    cell_type_pair_include : optional predicate ``(cell_type_1, cell_type_2) -> bool``.
    linewidth_scaling : ``None`` (constant width 0.01), a dict
        ``{(cell_type_1, cell_type_2): width}``, or a callable mapping the
        number of edges of a pair to a width.

    Returns
    -------
    dict
        ``{(cell_type_1, cell_type_2): line width used}`` for every drawn pair.
        This is the only output; nothing is printed.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.axis("off")

    if edges_kwargs is None:
        edges_kwargs = dict(facecolor='none', edgecolors="grey", alpha=0.5)

    # A quadratic Bezier segment is MOVETO (start) followed by two CURVE3
    # vertices (control point, end point).
    codes = [
        Path.MOVETO,
        Path.CURVE3,
        Path.CURVE3,
    ]

    def _block_bounds(cell_type):
        # Start/end offset of `cell_type` in the matrix; the offsets are the
        # cumulative cell counts taken in `cell_types` order.
        position = np.argwhere(cell_types == cell_type).squeeze()
        start = sum([cells_phenotype[subgroup][c].shape[0] for c in cell_types[:position]])
        end = sum([cells_phenotype[subgroup][c].shape[0] for c in cell_types[:(position + 1)]])
        return start, end

    linewidth_log = dict()

    for i_order, cell_type_i in enumerate(cell_type_order):

        for j_order, cell_type_j in enumerate(cell_type_order):

            if (cell_type_pair_include is not None) and (not cell_type_pair_include(cell_type_i, cell_type_j)):
                continue

            # each unordered pair once, no self-pairs
            if i_order >= j_order:
                continue

            i_start, i_end = _block_bounds(cell_type_i)
            j_start, j_end = _block_bounds(cell_type_j)

            mm = topk_matrices[subgroup]
            # merge the (i, j) block with the transposed (j, i) block
            m = mm[i_start:i_end, j_start:j_end] + mm[j_start:j_end,i_start:i_end].transpose()
            # keep only the cells that are shown in the embedding
            m = m[cells_phenotype_emb_idx[subgroup][cell_type_i],:][:,cells_phenotype_emb_idx[subgroup][cell_type_j]]
            rows, cols = m.nonzero()

            src = cell_coordinates[subgroup][cell_type_i][rows]
            dst = cell_coordinates[subgroup][cell_type_j][cols]

            paths = [Path([src[k], (0,0), dst[k]], codes) for k in range(src.shape[0])]

            if linewidth_scaling is None:
                linewidths = 0.01
            elif isinstance(linewidth_scaling, dict):
                linewidths = linewidth_scaling[(cell_type_i, cell_type_j)]
            else:
                linewidths = linewidth_scaling(src.shape[0])
            linewidth_log[(cell_type_i, cell_type_j)] = linewidths

            c = matplotlib.collections.PathCollection(paths, **edges_kwargs, linewidths=linewidths)
            ax.add_collection(c)

    return linewidth_log

In [None]:
%matplotlib inline

#     cell_colors = get_cell_colors(marker_category=marker_category, marker=marker)
# Per-subgroup figures ("T3" first, then "PP"), each saved as its own PDF.
# NOTE(review): color_palette is only assigned in commented-out lines in the visible
# part of the notebook - confirm it is defined earlier, otherwise this raises a
# NameError on a fresh run. cell_colors is not read again in this cell.
cell_colors = color_palette

# edge_style = dict(facecolor='none', edgecolors="black", alpha=0.1)
# linewidth_scaling_edges = lambda n_edges: max(min(1 / (n_edges / 20 + 1), 1)**2 * 3, 0.01) # gotta play with this

edge_style = dict(facecolor='none', edgecolors="black", alpha=0.1)
# line width shrinks as the number of edges of a pair grows: at most 3, at least 0.025
linewidth_scaling_edges = lambda n_edges: max(min(1 / (n_edges / 20 + 1), 1)**2 * 3, 0.025) # gotta play with this
# linewidth_scaling_edges = lambda n_edges: 0.06 # gotta play with this

# keep a pair when it is a key of the effect-size map in either order
cell_type_pair_include = lambda x,y: \
    ((x,y) in differential_cell_pairs_effect_size_map.keys()) \
    or ((y,x) in differential_cell_pairs_effect_size_map.keys())

# fig, ax = plt.subplots(1,1,figsize=(13,13), dpi=300); ax.axis("off")
# plot_difference(cell_type_pair_include, edge_style, ax=ax)

# # show all labels
# bbox = ax.get_window_extent()
# bbox_data = bbox.transformed(ax.transData.inverted())
# ax.update_datalim(bbox_data.corners())
# ax.autoscale_view()

# subgroup 1

fig, ax = plt.subplots(1,1,figsize=(13,13), dpi=300); ax.axis("off")
# the returned dict holds the line width used for every cell-type pair; it is
# reused for the second subgroup below so both figures share the same widths
linewidths = plot_subgroup(
    "T3", cell_type_pair_include=cell_type_pair_include, edge_style=edge_style, ax=ax, linewidth_scaling=linewidth_scaling_edges)
# display(linewidths)
# linewidths = {a:0 for a,b in linewidths.items()}

# extend the data limits to the axes' window extent (see "show all labels" above)
bbox = ax.get_window_extent()
bbox_data = bbox.transformed(ax.transData.inverted())
ax.update_datalim(bbox_data.corners())
ax.autoscale_view()
#     ax.set_title(f"T3: {marker_category} - {marker_name}")

# Two thick background bands on the unit circle: dark from 28 to 135 degrees,
# light from 153 to 350 degrees.
# NOTE(review): this band code is repeated verbatim for the second subgroup below
# and in the difference-figure cell; a shared helper would keep them in sync.
radius = 1
a1 = matplotlib.patches.Arc((0,0), radius*2, radius*2, 
                            theta1=28, theta2=135, 
#                             theta1=-8, theta2=135, 
                            linewidth=65, 
#                             color=contrast_colors[1], 
                            color=(0.1,0.1,0.1), 
                            alpha=0.7)
ax.add_patch(a1)
a1 = matplotlib.patches.Arc((0,0), radius*2, radius*2, 
                            theta1=153, theta2=350, 
#                             theta1=137, theta2=350, 
                            linewidth=65,
#                             color=contrast_colors[0], 
                            color=(0.9,0.9,0.9), 
                            alpha=0.7)
ax.add_patch(a1)

fig.savefig(path_figures / f"cell_correlations_t3.pdf", bbox_inches="tight", facecolor='none')
plt.show()
plt.close()

# subgroup 2

fig, ax = plt.subplots(1,1,figsize=(13,13), dpi=300); ax.axis("off")
# linewidth_scaling is the dict returned for "T3": plot_edges looks the width of
# each pair up in it instead of recomputing it from the edge count
plot_subgroup(
    "PP", cell_type_pair_include=cell_type_pair_include, edge_style=edge_style, ax=ax, linewidth_scaling=linewidths)

bbox = ax.get_window_extent()
bbox_data = bbox.transformed(ax.transData.inverted())
ax.update_datalim(bbox_data.corners())
ax.autoscale_view()
#     ax.set_title(f"PP: {marker_category} - {marker_name}")


# same background bands as for the first subgroup
radius = 1
a1 = matplotlib.patches.Arc((0,0), radius*2, radius*2, 
                            theta1=28, theta2=135, 
#                             theta1=-8, theta2=135, 
                            linewidth=65, 
#                             color=contrast_colors[1], 
                            color=(0.1,0.1,0.1), 
                            alpha=0.7)
ax.add_patch(a1)
a1 = matplotlib.patches.Arc((0,0), radius*2, radius*2, 
                            theta1=153, theta2=350, 
#                             theta1=137, theta2=350, 
                            linewidth=65,
#                             color=contrast_colors[0], 
                            color=(0.9,0.9,0.9), 
                            alpha=0.7)
ax.add_patch(a1)

fig.savefig(path_figures / f"cell_correlations_pp.pdf", bbox_inches="tight", facecolor='none')
plt.show()
plt.close()