# General parameters

In [None]:
# Default parameter values; a later cell that assigns the same names overrides them.
prefix = "test"  # run label, becomes part of the notebook / output directory name

cell_type_selection = "filter"  # forwarded to prepare_cell_sampling

n_samples = 10  # number of sampling rounds (top-k + statistics are computed per round)
sampling_scheme = "double-replacement"  # forwarded to sample_cell_subgroups
n_sampled_cells_per_celltype = 1000  # for top-k calculation

n_samples_deque = 3  # how many raw samples are retained (data in addition to stats)

topk_target = "combined"  # phenotype, function, or combined
topk_ratio = 0.001  # k is requested as (number of cells)**2 * topk_ratio

n_topk_stats_bins = 201  # number of np.linspace(-1, 1, ...) points used as histogram edges

n_max_cells_emb = 100  # for visualization

n_threads = 64  # passed to set_threads_for_external_libraries
n_jobs_topk = 64  # n_jobs of the top-k correlation search
n_jobs_fcs = 64  # NOTE(review): not referenced in the visible part of this notebook


In [None]:
# # Parameters

# prefix = "v1"

# cell_type_selection = "filter"

# n_samples = 100
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000  # for top-k calculation
# topk_ratio = 0.0001  # 0.01%

# n_max_cells_emb = 1000  # for visualization 

In [None]:
# # Parameters

# prefix = "function_v3"

# cell_type_selection = "filter"

# n_samples = 100
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000  # for top-k calculation
# topk_target = "function"
# topk_ratio = 0.0001  # 0.01%

# n_max_cells_emb = 1000  # for visualization 

In [None]:
# Parameters
# NOTE: these assignments override the defaults from the "General parameters" cell.

prefix = "function1000_v3"

cell_type_selection = "filter"

n_samples = 1000
sampling_scheme = "double-replacement"
n_sampled_cells_per_celltype = 10000  # for top-k calculation
topk_target = "function"  # phenotype, function, or combined
topk_ratio = 0.0001  # 0.01%

n_max_cells_emb = 1000  # for visualization

In [None]:
# # Parameters
# prefix = "rescue_v3"
# n_samples = 100
# sampling_scheme = "double-replacement"
# n_sampled_cells_per_celltype = 10000
# topk_target = "function"
# topk_ratio = 0.001
# cell_type_selection = "filter"
# n_max_cells_emb = 1000


In [None]:
verbose = 1  # forwarded as `verbose=` to load_cytof, prepare_cell_sampling and sample_cell_subgroups

# Preamble

In [None]:
# Encode the run parameters in one name (fields separated by "___");
# the name is printed here and reused for the output directory.
notebook_name = "___".join([
    "application",
    "singlecell",
    "parameters",
    f"prefix__{prefix}",
    f"cell_types__{cell_type_selection}",
    f"n_samples__{n_samples}",
    f"sampling_scheme__{sampling_scheme}",
    f"n_sampled_cells_per_celltype__{n_sampled_cells_per_celltype}",
    f"topk_target__{topk_target}",
    f"topk_ratio__{topk_ratio}",
    f"n_max_cells_emb__{n_max_cells_emb}",
])
print(notebook_name)

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

# set the thread count for BLAS and co. (done before numpy etc. are imported)
from corals.threads import set_threads_for_external_libraries
set_threads_for_external_libraries(n_threads=n_threads)

# general
import re
import collections
import pickle
import warnings 
import joblib
import pathlib

# data
import numpy as np
import pandas as pd
import h5py

# ml / stats
import sklearn
import scipy.stats

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# init matplotlib defaults
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'


In [None]:
from matplotlib.collections import LineCollection
import sklearn.manifold
from corals.correlation.topk._deprecated.original import topk_balltree_combined_tree_parallel_optimized as cor_topk  # TODO: eventually replace with newer implementations

In [None]:
import corals.correlation.utils
import sklearn.impute
from coralsarticle.visualization import CurvedText

In [None]:
from coralsarticle.data.process.singlecell import load_cytof, prepare_cell_sampling, sample_cell_subgroups

# Load and prepare Cytof data

## Load data

In [None]:
# input files and output location (relative paths)
cell_file = "../data/processed/immuneclock_singlecell_unstim.h5"
marker_file = "../data/raw/singlecell/markers.xlsx"
output_dir = "../data/processed"

In [None]:
# load the complete cell table; the marker file is handed to the loader as well
cytof = load_cytof(cell_file=cell_file, marker_file=marker_file, verbose=verbose)

## Prepare sampling

In [None]:
# Everything the sampling step needs; the helper returns a 7-tuple.
(
    cytof_preprocessed_phenotype,
    cytof_preprocessed_function,
    subgroups,
    subgroups_with_cell_types,
    sample_masking,
    cell_types,
    cell_type_order,
) = prepare_cell_sampling(
    cytof,
    cell_type_selection=cell_type_selection,
    marker_file=marker_file,
    verbose=verbose,
)

In [None]:
# sanity check: dimensions of the preprocessed function feature matrix
cytof_preprocessed_function.shape

In [None]:
# number of rows per (patient_id, timepoint), restricted to cell types in cell_type_order
(
    cytof[cytof["cell_type"].isin(cell_type_order)]
    .groupby(["patient_id", "timepoint"])
    .size()
)

## Cell statistics (can be skipped)

In [None]:
# cell types dropped from the count summary below
excluded_cell_types = [
    "CD235-CD61-",  # leukocytes
    "CD45+CD66-",   # mononuclear cells
    "CD66+CD45-",   # granulocytes
]

### Some cell counts

In [None]:
# per cell type: min / max / mean / std of the per-timepoint cell counts,
# without excluded_cell_types, ordered by the smallest count
(
    cytof
    .groupby(["timepoint", "cell_type"]).size()
    .groupby(["cell_type"]).agg(["min", "max", "mean", "std"])
    .sort_index()
    .drop(excluded_cell_types)
    .sort_values("min")
)

In [None]:
# same summary, restricted to (and first ordered like) cell_type_order
(
    cytof
    .groupby(["timepoint", "cell_type"]).size()
    .groupby(["cell_type"]).agg(["min", "max", "mean", "std"])
    .sort_index()
    .loc[cell_type_order, :]
    .sort_values("min")
)

In [None]:
%matplotlib inline
# keep only the cell types listed in cell_type_order
subset = cytof[cytof.cell_type.isin(cell_type_order)]
# count rows per (cell_type, timepoint); unused categories are removed from the
# grouping key and the count column (named 0 after reset_index) is renamed to "size"
subset = subset.groupby([subset.cell_type.cat.remove_unused_categories(), "timepoint"]).agg("size").reset_index().rename({0: "size"}, axis=1)
# show the counts for timepoint "T3" and plot the count distribution per timepoint
display(subset[subset.timepoint == "T3"].reset_index())
sns.histplot(subset, x="size", bins=10, hue="timepoint")

In [None]:
# check cells per timepoint: total number of cells per timepoint when every
# (timepoint, cell type) count is capped at max_n

# max_n = n_sampled_cells_per_celltype
max_n = 1000
r = cytof.groupby(["timepoint", "cell_type"]).size()
r = r.apply(lambda x: min(x, max_n))  # cap each count at max_n
r = r.loc[r.reset_index()["cell_type"].isin(cell_types).values]  # keep selected cell types only
r = r.groupby(["timepoint"]).sum()
print(r.min(), r.max())

### Check cell type overlaps

In [None]:
# # check overlap (can take long)

# patient_ids = cytof.patient_id.unique()
# timepoints = cytof.timepoint.unique()

# cell_type_overlap_stats = {} 
# cell_type_subset_stats = {} 
# for p in patient_ids:
#     for t in timepoints:
#         print("*", p,t)
#         overlap_stats = np.zeros((len(cell_type_order), len(cell_type_order)))
#         subset_stats = np.zeros((len(cell_type_order), len(cell_type_order)))
        
#         for i_c1, c1 in enumerate(cell_type_order):
            
#             cells1_id = cytof[(cytof.patient_id == p) & (cytof.timepoint == t) & (cytof.cell_type == c1)].loc[:,"Time":].sum(axis=1)
#             set1 = set(cells1_id)
            
#             duplicate = cells1_id.shape[0] - len(set1)
#             overlap_stats[i_c1, i_c1] = duplicate
            
#             assert duplicate == 0
            
#             for i_c2, c2 in enumerate(cell_type_order):
#                 if i_c1 < i_c2:
                    
#                     cells2_ids = cytof[(cytof.patient_id == p) & (cytof.timepoint == t) & (cytof.cell_type == c2)].loc[:,"Time":].sum(axis=1)
#                     set2 = set(cells2_ids)
#                     overlap = len(set1.intersection(set2))
                    
#                     assert overlap == 0

#                     overlap_stats[i_c1, i_c2] = overlap
#                     overlap_stats[i_c2, i_c1] = overlap
                    
#                     if overlap > 0:
#                         subset_stats[i_c2, i_c1] = 2
#                         subset_stats[i_c1, i_c2] = 2
#                         if len(set1 - set2) == 0:
#                             subset_stats[i_c2, i_c1] = 1
#                         if len(set2 - set1) == 0:
#                             subset_stats[i_c1, i_c2] = 1
                            
#         cell_type_overlap_stats[(p, t)] = overlap_stats              
#         cell_type_subset_stats[(p, t)] = subset_stats
        
#         fig, axes = plt.subplots(1, 2, figsize=(15 * 2,13))
#         ax = axes[0]
#         sns.heatmap(overlap_stats, mask=(overlap_stats==0), linewidths=1, linecolor="grey", ax=ax)
#         ax.set_xticks(np.arange(len(cell_type_order)) + 0.5)
#         ax.set_yticks(np.arange(len(cell_type_order)) + 0.5)
#         ax.set_xticklabels(cell_type_order, rotation=270)
#         ax.set_yticklabels(cell_type_order, rotation=0)
        
#         ax = axes[1]
#         sns.heatmap(subset_stats, linewidths=1, linecolor="grey", ax=ax)
#         ax.set_xticks(np.arange(len(cell_type_order)) + 0.5)
#         ax.set_yticks(np.arange(len(cell_type_order)) + 0.5)
#         ax.set_xticklabels(cell_type_order, rotation=270)
#         ax.set_yticklabels(cell_type_order, rotation=0)
#         ax.set(xlabel="subset", ylabel="superset")
#         fig.suptitle(f"{p}, {t}")
        
#         plt.show()
#         plt.close()
        
# #         break
# #     break
                

### Explore individual cell type overlaps

In [None]:
# patient id and timepoint label used by the overlap checks in the following cells
p = 12
t = "PP"

In [None]:
# overlap of two populations for patient p / timepoint t, matched on the Time value;
# prints: size of a, size of b, number of shared Time values
select = (cytof.patient_id == p) & (cytof.timepoint == t)
a = cytof.loc[select & (cytof.cell_type == "CD25+CD4+Tcells_naive_noTregs"), "Time"]
b = cytof.loc[select & (cytof.cell_type == "CD45RA-Tregs"), "Time"]
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
# overlap of intMCs and M-MDSC for patient p / timepoint t, matched on the Time value
select = (cytof.patient_id == p) & (cytof.timepoint == t)
a = cytof.loc[select & (cytof.cell_type == "intMCs"), "Time"]
b = cytof.loc[select & (cytof.cell_type == "M-MDSC"), "Time"]
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
# overlap of cMCs and pDCs for patient p / timepoint t, matched on the Time value
select = (cytof.patient_id == p) & (cytof.timepoint == t)
a = cytof.loc[select & (cytof.cell_type == "cMCs"), "Time"]
b = cytof.loc[select & (cytof.cell_type == "pDCs"), "Time"]
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
# same check for another pair; `select` is reused from an earlier cell
a = cytof.loc[select & (cytof.cell_type == "CD25+CD4+Tcells_naive"), "Time"]
b = cytof.loc[select & (cytof.cell_type == "CD45RA+Tregs"), "Time"]
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
# same check for another pair; `select` is reused from an earlier cell
a = cytof.loc[select & (cytof.cell_type == "CD4+Tcells_naive"), "Time"]
b = cytof.loc[select & (cytof.cell_type == "Tbet+CD4+Tcells_mem"), "Time"]
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
# same check for another pair; `select` is reused from an earlier cell
a = cytof.loc[select & (cytof.cell_type == "CD25+CD4+Tcells_naive"), "Time"]
b = cytof.loc[select & (cytof.cell_type == "Tbet+CD4+Tcells_mem"), "Time"]
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
# here cells are matched on the row sum over all columns from "Time" onwards
# instead of the Time value alone; `select` is reused from an earlier cell
a = cytof.loc[select & (cytof.cell_type == "CD8+Tcells"), "Time":].sum(axis=1)
b = cytof.loc[select & (cytof.cell_type == "CD25+CD8+Tcells_naive"), "Time":].sum(axis=1)
print(a.size, b.size, len(set(a) & set(b)))

In [None]:
select = (cytof.patient_id == 1) & (cytof.timepoint == "T3")

# Pairs of cell types whose overlap (matched on the Time value) is printed as
# "size of a, size of b, number of shared Time values".
# A blank line is printed between consecutive groups.
_pair_groups = [
    [
        ("CD4+Tcells_mem", "CD45RA-Tregs"),
        ("CD4+Tcells_naive", "CD45RA-Tregs"),
    ],
    [
        ("CD4+Tcells_mem", "CD45RA+Tregs"),
        ("CD4+Tcells_naive", "CD45RA+Tregs"),
    ],
    [
        ("CD4+Tcells_mem", "CD25+CD4+Tcells_mem"),
        ("CD4+Tcells_naive", "CD25+CD4+Tcells_naive"),
    ],
    [
        ("CD4+Tcells_mem", "CD25+CD4+Tcells_naive"),
        ("CD4+Tcells_naive", "CD25+CD4+Tcells_mem"),
    ],
    [
        ("CD45RA-Tregs", "CD25+CD4+Tcells_mem"),
        ("CD45RA-Tregs", "CD25+CD4+Tcells_naive"),
        ("CD45RA+Tregs", "CD25+CD4+Tcells_mem"),
        ("CD45RA+Tregs", "CD25+CD4+Tcells_naive"),
    ],
]

for _group_no, _group in enumerate(_pair_groups):
    if _group_no > 0:
        print()
    for _type_a, _type_b in _group:
        a = cytof[select & (cytof.cell_type == _type_a)].Time
        b = cytof[select & (cytof.cell_type == _type_b)].Time
        print(a.size, b.size, len(set(a).intersection(set(b))))

# Prepare cells

In [None]:
def sample_cells():
    """Draw one sample of cells and return its phenotype and function features.

    The sampling itself is delegated to ``sample_cell_subgroups`` using the
    notebook-level settings (``subgroups_with_cell_types``, ``sample_masking``,
    ``n_sampled_cells_per_celltype``, ``sampling_scheme``, ``verbose``).

    Returns
    -------
    tuple
        ``(phenotype, function)``; each is an ``OrderedDict`` mapping the first
        key component of the sampled index (called ``timepoint`` here) to a
        dict ``{cell_type: rows of the corresponding preprocessed matrix}``.
    """
    sampled_rows = sample_cell_subgroups(
        subgroups=subgroups_with_cell_types,
        subgroups_masking=sample_masking,
        n_sampled_cells_per_celltype=n_sampled_cells_per_celltype,
        sampling_scheme=sampling_scheme,
        verbose=verbose,
    )

    phenotype_by_group = collections.OrderedDict()
    function_by_group = collections.OrderedDict()
    for (timepoint, cell_type), rows in sampled_rows.items():
        # the same row indices select from both feature matrices
        phenotype_by_group.setdefault(timepoint, {})[cell_type] = cytof_preprocessed_phenotype[rows, :]
        function_by_group.setdefault(timepoint, {})[cell_type] = cytof_preprocessed_function[rows, :]

    return phenotype_by_group, function_by_group

In [None]:
def calculate_topk(cells):
    """Compute the top-k Spearman correlations between cells for every subgroup.

    Parameters
    ----------
    cells : mapping
        ``cells[subgroup_id][cell_type]`` is a 2-D array with one row per cell.
        For each subgroup the arrays are stacked in ``cell_types`` order.

    Returns
    -------
    topk : dict
        ``{subgroup_id: ((correlations, (src_idx, dst_idx)), (n_cells, n_cells))}``.
    topk_matrices : dict
        ``{subgroup_id: scipy.sparse.csr_matrix}`` built from the entries above.

    Notes
    -----
    The number of requested correlations is ``n_cells ** 2 * topk_ratio``.
    """
    # Imported here because the notebook only imports `scipy.stats` at the top;
    # without this the function relies on another library having loaded scipy.sparse.
    import scipy.sparse

    topk = dict()
    for subgroup_id in subgroups:

        print(subgroup_id, end=": ")
        cell_matrix = np.concatenate([cells[subgroup_id][c] for c in cell_types])
        print(cell_matrix.shape)
        n_cells = cell_matrix.shape[0]

        # the search is run on the transposed matrix (cells as columns)
        topk_cor, (topk_idx_dst, topk_idx_src) = cor_topk(
            cell_matrix.transpose(),
            k=n_cells ** 2 * topk_ratio,
            correlation_type="spearman",
            n_jobs=n_jobs_topk)

        # stored in the (data, (row, col)), shape layout accepted by csr_matrix
        topk[subgroup_id] = (topk_cor, (topk_idx_src, topk_idx_dst)), (n_cells, n_cells)

    # distinct comprehension names so that `topk` itself is not shadowed
    topk_matrices = {
        subgroup_id: scipy.sparse.csr_matrix(coo_data, shape=shape)
        for subgroup_id, (coo_data, shape) in topk.items()}

    return topk, topk_matrices

In [None]:
# histogram bin edges: n_topk_stats_bins points on [-1, 1] plus outer edges at -2 and 2
bins = np.concatenate([[-2], np.linspace(-1,1,n_topk_stats_bins), [2]])


def calculate_topk_stats(cells, topk_matrices):
    """Summarise the top-k correlations for every pair of cell types.

    Parameters
    ----------
    cells : mapping
        ``cells[subgroup][cell_type]`` is a 2-D array with one row per cell; the
        row counts (in ``cell_types`` order) locate each cell type's block in
        the correlation matrix.
    topk_matrices : mapping
        ``topk_matrices[subgroup]`` is the sparse top-k correlation matrix.

    Returns
    -------
    dict
        ``{subgroup: {"counts", "frequency", "histograms", "means", "stds",
        "medians"}}``. Every entry is symmetric in the two cell-type axes.
        ``means`` / ``stds`` / ``medians`` are NaN for pairs without any top-k
        correlation.
    """
    import scipy.sparse

    n_cell_types = cell_types.size
    n_bins = len(bins) - 1

    topk_stats = dict()
    for subgroup in subgroups:
        print(subgroup)
        counts = np.zeros((n_cell_types, n_cell_types))
        histograms = np.zeros((n_cell_types, n_cell_types, n_bins))
        means = np.zeros((n_cell_types, n_cell_types))
        medians = np.zeros((n_cell_types, n_cell_types))
        stds = np.zeros((n_cell_types, n_cell_types))

        # start/end row of every cell-type block, computed once per subgroup
        # (offsets[i]:offsets[i + 1] is the block of cell type i)
        offsets = [0]
        for c in cell_types:
            offsets.append(offsets[-1] + cells[subgroup][c].shape[0])

        mm = topk_matrices[subgroup]

        for i in range(n_cell_types):
            i_start, i_end = offsets[i], offsets[i + 1]

            # only the upper triangle is computed; results are mirrored below
            for j in range(i, n_cell_types):
                j_start, j_end = offsets[j], offsets[j + 1]

                # block of correlations between cell type i (rows) and j (columns)
                m = mm[i_start:i_end, j_start:j_end]

                # merge in the mirrored block so that a correlation stored in
                # only one of the two triangles is still counted, while entries
                # present in both are not counted twice
                mt = mm[j_start:j_end, i_start:i_end].transpose()
                msk = m.multiply(mt)
                msk.data = np.ones_like(msk.data)
                mt = mt - msk.multiply(mt)
                mt.eliminate_zeros()

                m = m + mt
                # clean the merged block (this call previously targeted `mt` again)
                m.eliminate_zeros()

                values = m.data
                histogram = np.histogram(values, bins=bins)[0]
                if values.size:
                    mean = np.mean(values)
                    std = np.std(values)
                    median = np.median(values)
                else:
                    # same NaN result as np.mean/std/median on empty input,
                    # but without the RuntimeWarnings
                    mean = std = median = np.nan

                counts[i, j] = counts[j, i] = m.nnz
                histograms[i, j, :] = histogram
                histograms[j, i, :] = histogram
                means[i, j] = means[j, i] = mean
                stds[i, j] = stds[j, i] = std
                medians[i, j] = medians[j, i] = median

        # calculate stats
        topk_stats[subgroup] = {
            "counts": counts,
            "frequency": counts / counts.sum(),
            "histograms": histograms,
            "means": means,
            "stds": stds,
            "medians": medians,
        }

    return topk_stats

In [None]:
# Repeat sampling -> top-k -> statistics n_samples times.
# The statistics of every round are kept; raw data (cells, top-k results) is held
# in bounded deques, so only the most recent n_samples_deque rounds remain.
cells_phenotype = None
cells_function = None

# one statistics dict per sampling round
topk_stats_samples = []

cells_phenotype_deque = collections.deque(maxlen=n_samples_deque)
cells_function_deque = collections.deque(maxlen=n_samples_deque)
cells_deque = collections.deque(maxlen=n_samples_deque)
topk_deque = collections.deque(maxlen=n_samples_deque)
topk_matrices_deque = collections.deque(maxlen=n_samples_deque)

for i in range(n_samples):
    print(f"##################################################################")
    print(f"### Sample {i} #####################################################")
    print(f"### sampling cells ###############################################")
    cells_phenotype, cells_function = sample_cells()

    cells_phenotype_deque.append(cells_phenotype)
    cells_function_deque.append(cells_function)

    # choose the features the correlations are computed on
    if topk_target == "combined":
        # phenotype and function features side by side, per subgroup and cell type
        cells = collections.OrderedDict(
            (
                k,
                {
                    cell_type: np.concatenate(
                        [pheno, cells_function[k][cell_type]],
                        axis=1)
                    for cell_type, pheno in cells_phenotype[k].items()
                },
            )
            for k in cells_phenotype)
    elif topk_target == "phenotype":
        cells = cells_phenotype
    elif topk_target == "function":
        cells = cells_function
    else:
        raise ValueError(f"Unknown top-k target: {topk_target}")
    cells_deque.append(cells)

    print(f"### top-k ########################################################")
    topk, topk_matrices = calculate_topk(cells)
    topk_deque.append(topk)
    topk_matrices_deque.append(topk_matrices)

    print(f"### stats ########################################################")
    topk_stats = calculate_topk_stats(cells, topk_matrices)
    topk_stats_samples.append(topk_stats)

# Visualization preparation

In [None]:
%%time

# 2-D t-SNE embeddings of the phenotype features for every retained sample.
# Both deques are bounded by n_samples_deque, like the deques of raw samples.
cells_phenotype_emb_deque = collections.deque([], n_samples_deque)
cells_phenotype_emb_idx_deque = collections.deque([], n_samples_deque)

# NOTE: the loop rebinds the notebook-level names `i`, `cells_phenotype` and `cells`
for i, cells_phenotype in enumerate(cells_phenotype_deque):

    print(f"Queue position: {i}")

    # calculate embeddings
    cells_phenotype_emb = dict()  # {subgroup: {cell_type: embedding coordinates}}
    cells_phenotype_emb_idx = dict()  # {subgroup: {cell_type: indices of the embedded cells}}

    for cell_type in cell_types:

        # collect / sample cells
        cells = []
        for s in subgroups:
            # we are using phenotype features for embedding cells
            pheno = cells_phenotype[s][cell_type]
            # at most n_max_cells_emb cells per subgroup, drawn without replacement
            # (uses the global NumPy RNG, which is not seeded here)
            idx = np.random.choice(np.arange(pheno.shape[0]), min(n_max_cells_emb, pheno.shape[0]), replace=False)
            cells_phenotype_emb_idx.setdefault(s, dict())[cell_type] = idx
            cells.append(pheno[idx,:])

        # remember how many cells each subgroup contributed, then embed all
        # subgroups of this cell type jointly
        subgroup_sizes = [c.shape[0] for c in cells]
        cells = np.concatenate(cells)

        print(" *", cell_type)
        print("  ", cells.shape)
        tsne = sklearn.manifold.TSNE(n_components=2, random_state=42)
        emb = tsne.fit_transform(cells)

        # split into subgroups
        offset = 0
        for subgroup_id, subgroup_size in zip(subgroups, subgroup_sizes):
            cells_phenotype_emb.setdefault(subgroup_id, dict())[cell_type] = emb[offset:(offset + subgroup_size),:]
            offset += subgroup_size

    cells_phenotype_emb_deque.append(cells_phenotype_emb)
    cells_phenotype_emb_idx_deque.append(cells_phenotype_emb_idx)

    print()

# Save essential results

In [None]:
# output directory of this run, named after the parameter string; created if missing
path = pathlib.Path("../_out/", notebook_name)
path.mkdir(parents=True, exist_ok=True)
path

In [None]:
# save essentials
# Each object is pickled to "<name>.pickle" inside `path`. The `with` block closes
# every file right after writing, so no handle is left open and all data is
# flushed to disk before the next file is started.
_pickle_targets = {
    "cell_types": cell_types,
    "subgroups": subgroups,
    "cells_phenotype": cells_phenotype_deque,
    "cells_function": cells_function_deque,
    "topk_stats_bins": bins,
    "topk": topk_deque,
    "topk_matrices": topk_matrices_deque,
    "topk_stats_samples": topk_stats_samples,
    "cells_phenotype_emb": cells_phenotype_emb_deque,
    "cells_phenotype_emb_idx": cells_phenotype_emb_idx_deque,
}

for _name, _obj in _pickle_targets.items():
    with open(path / f"{_name}.pickle", "wb") as _file:
        pickle.dump(_obj, _file)