In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scanpy as sc
import anndata

from ALLCools.clustering import tsne, balanced_pca, significant_pc_test, log_scale
from ALLCools.plot import *

import scanorama as scrm

from harmonypy import run_harmony

In [None]:
# --- input files ---------------------------------------------------------
# pickled cell metadata table (the file name suggests the post-QC table)
metadata_path = './CellMetadata.AfterQC.pdpkl'

# AnnData files holding the HVF mC fraction matrices (mCH and mCG)
ch_adata_path = 'mCH.HVF.h5ad'
cg_adata_path = 'mCG.HVF.h5ad'

# --- batch correction ----------------------------------------------------
# metadata column that defines the batches; None skips the correction cells
correct_batch_col = 'Donor'
# 'harmony' runs the harmonypy cells, any other value runs the scanorama cells
correct_method = 'harmony'

# --- feature selection ---------------------------------------------------
# 'HVF': all highly variable features
# 'CEF': cluster enriched features (of the `pre_cluster_name` clustering)
feature_type = 'CEF'
pre_cluster_name = 'leiden'

# --- number of PCs -------------------------------------------------------
# 'auto': Kolmogorov-Smirnov test on adjacent PCs, cut when P > p_cutoff
n_components = 'auto'
# KS test p-value cutoff, only applies when n_components == 'auto'
p_cutoff = 0.1

# --- downsampling / plotting ---------------------------------------------
# downsample large clusters
max_cell_prop = 0.05

# the interactive plot shows at most this many cells, and only when enabled
interactive_downsample = 2000
interactive_plot = False

min_cluster_size = 25


In [None]:
# NOTE(review): unpickling can execute arbitrary code - only load trusted files.
metadata = pd.read_pickle(metadata_path)

# Read the mCH and mCG AnnData objects.
ch_adata, cg_adata = (anndata.read_h5ad(h5ad_path)
                      for h5ad_path in (ch_adata_path, cg_adata_path))

if feature_type == 'CEF':
    print('Using Cluster Enriched Features')
    # .var column used as the feature selector for the chosen pre-clustering
    cef_col = f'{pre_cluster_name}_enriched_features'
    ch_adata = ch_adata[:, ch_adata.var[cef_col]].copy()
    cg_adata = cg_adata[:, cg_adata.var[cef_col]].copy()

In [None]:
# Apply ALLCools' log_scale to the mCH matrix. The return value is not captured,
# so the call is presumably relied on to modify the AnnData in place - TODO confirm.
log_scale(ch_adata)

# Same treatment for the mCG matrix.
log_scale(cg_adata)

 Scanorama correction (mCH): known not to work in some cases

In [None]:
# Scanorama batch correction of the mCH matrix (only when harmony is not selected).
if correct_method != 'harmony':
    if correct_batch_col is not None:
        # one AnnData per batch, holding the cells listed for that batch in the metadata
        ch_adatas = [
            ch_adata[ch_adata.obs_names.isin(batch_meta.index)].copy()
            for _, batch_meta in metadata.groupby(correct_batch_col)
        ]
        # corrected = scrm.correct([x.obsm['X_pca'] for x in ch_adatas], [np.arange(ch_n_components) for x in ch_adatas], dimred=-1)
        for batch_adata in ch_adatas:
            sc.pp.scale(batch_adata)

        # every batch shares the same features, so integer positions serve as gene ids
        feature_ids = np.arange(ch_adata.shape[1])
        corrected, _ = scrm.correct([batch_adata.X for batch_adata in ch_adatas],
                                    [feature_ids for _ in ch_adatas])

        for batch_adata, corrected_x in zip(ch_adatas, corrected):
            batch_adata.X = corrected_x.todense()

        merged = ch_adatas[0].concatenate(ch_adatas[1:], index_unique=None)
        # Restore the original cell order across ALL batches. Indexing with
        # ch_adatas[0].obs_names (as before) kept only the first batch's cells.
        kept_cells = ch_adata.obs_names[ch_adata.obs_names.isin(merged.obs_names)]
        ch_adata = merged[kept_cells].copy()

In [None]:
# PCA of the (scanorama-corrected) mCH matrix and selection of significant PCs.
if correct_method != 'harmony':
    # start from the largest allowed number of components: the cell count rounded
    # down to a multiple of min_cluster_size, capped at cells - 2 and at 200
    n_comps = min(len(ch_adata) // min_cluster_size * min_cluster_size,
                  len(ch_adata) - 2,
                  200)
    last_error = None
    # Retry with two fewer components after each failure. The loop is bounded so a
    # persistent error (e.g. a missing cluster column) cannot spin forever, and
    # only Exception is caught so Ctrl-C still interrupts the cell.
    while n_comps > 0:
        try:
            balanced_pca(ch_adata, groups=pre_cluster_name, n_comps=n_comps)
            sc.pl.pca_variance_ratio(ch_adata)
            ch_n_components = significant_pc_test(ch_adata, p_cutoff=p_cutoff)
            break
        except Exception as err:
            last_error = err
            n_comps -= 2
    else:
        raise RuntimeError('mCH PCA failed for every tested number of components') from last_error
        

 Harmonypy correction (mCH)

In [None]:
# Harmony batch correction of the existing mCH PCs in .obsm['X_pca'].
if correct_method == 'harmony':
    if correct_batch_col is not None:
        # pandas aligns the metadata column to the cells on the index
        ch_adata.obs[correct_batch_col] = metadata[correct_batch_col]

        # arguments shared by the first attempt and the fallback
        harmony_kws = dict(meta_data=ch_adata.obs,
                           vars_use=correct_batch_col,
                           random_state=0,
                           max_iter_harmony=50)
        try:
            ho = run_harmony(ch_adata.obsm['X_pca'], nclust=15, **harmony_kws)
        except Exception as err:
            # Retry once with fewer clusters. Only Exception is caught so that
            # KeyboardInterrupt still stops the cell; a second failure propagates.
            print(f'harmony with nclust=15 failed ({err!r}); retrying with nclust=10')
            ho = run_harmony(ch_adata.obsm['X_pca'], nclust=10, **harmony_kws)

        # Z_corr is transposed before it is stored back as the PC matrix
        ch_adata.obsm['X_pca'] = ho.Z_corr.T

In [None]:
# Harmony path: choose the number of mCH PCs to keep from the PCs in the AnnData.
if correct_method == 'harmony':
    ch_n_components = significant_pc_test(
        ch_adata,
        p_cutoff=p_cutoff,
    )

In [None]:
# Scatter plots of the retained mCH PCs, coloured by the mCHFrac metadata column
# (skipped when the metadata has no such column).
hue = 'mCHFrac'
if hue in metadata.columns:
    # bring the metadata values into the cell order of the AnnData
    ch_adata.obs[hue] = metadata[hue].reindex(ch_adata.obs_names)
    fig, axes = plot_decomp_scatters(
        ch_adata,
        n_components=ch_n_components,
        hue=hue,
        hue_quantile=(0.25, 0.75),
        nrows=5,
        ncols=5,
    )

 Scanorama correction (mCG): known not to work in some cases

In [None]:
# Scanorama batch correction of the mCG matrix (only when harmony is not selected).
if correct_method != 'harmony':
    if correct_batch_col is not None:
        # one AnnData per batch, holding the cells listed for that batch in the metadata
        cg_adatas = [
            cg_adata[cg_adata.obs_names.isin(batch_meta.index)].copy()
            for _, batch_meta in metadata.groupby(correct_batch_col)
        ]
        # corrected = scrm.correct([x.obsm['X_pca'] for x in cg_adatas], [np.arange(cg_n_components) for x in cg_adatas], dimred=-1)
        for batch_adata in cg_adatas:
            sc.pp.scale(batch_adata)

        # every batch shares the same features, so integer positions serve as gene ids
        feature_ids = np.arange(cg_adata.shape[1])
        corrected, _ = scrm.correct([batch_adata.X for batch_adata in cg_adatas],
                                    [feature_ids for _ in cg_adatas])

        for batch_adata, corrected_x in zip(cg_adatas, corrected):
            batch_adata.X = corrected_x.todense()

        merged = cg_adatas[0].concatenate(cg_adatas[1:], index_unique=None)
        # Restore the original cell order across ALL batches. Indexing with
        # cg_adatas[0].obs_names (as before) kept only the first batch's cells.
        kept_cells = cg_adata.obs_names[cg_adata.obs_names.isin(merged.obs_names)]
        cg_adata = merged[kept_cells].copy()

In [None]:
# PCA of the (scanorama-corrected) mCG matrix and selection of significant PCs.
if correct_method != 'harmony':
    # start from the largest allowed number of components: the cell count rounded
    # down to a multiple of min_cluster_size, capped at cells - 2 and at 200
    n_comps = min(len(cg_adata) // min_cluster_size * min_cluster_size,
                  len(cg_adata) - 2,
                  200)
    last_error = None
    # Retry with two fewer components after each failure. The loop is bounded so a
    # persistent error (e.g. a missing cluster column) cannot spin forever, and
    # only Exception is caught so Ctrl-C still interrupts the cell.
    while n_comps > 0:
        try:
            balanced_pca(cg_adata, groups=pre_cluster_name, n_comps=n_comps)
            sc.pl.pca_variance_ratio(cg_adata)
            cg_n_components = significant_pc_test(cg_adata, p_cutoff=p_cutoff)
            break
        except Exception as err:
            last_error = err
            n_comps -= 2
    else:
        raise RuntimeError('mCG PCA failed for every tested number of components') from last_error

 Harmonypy correction (mCG)

In [None]:
# Harmony batch correction of the existing mCG PCs in .obsm['X_pca'].
if correct_method == 'harmony':
    if correct_batch_col is not None:
        # pandas aligns the metadata column to the cells on the index
        cg_adata.obs[correct_batch_col] = metadata[correct_batch_col]

        # arguments shared by the first attempt and the fallback
        harmony_kws = dict(meta_data=cg_adata.obs,
                           vars_use=correct_batch_col,
                           random_state=0,
                           max_iter_harmony=50)
        try:
            ho = run_harmony(cg_adata.obsm['X_pca'], nclust=15, **harmony_kws)
        except Exception as err:
            # Retry once with fewer clusters. Only Exception is caught so that
            # KeyboardInterrupt still stops the cell; a second failure propagates.
            print(f'harmony with nclust=15 failed ({err!r}); retrying with nclust=10')
            ho = run_harmony(cg_adata.obsm['X_pca'], nclust=10, **harmony_kws)

        # Z_corr is transposed before it is stored back as the PC matrix
        cg_adata.obsm['X_pca'] = ho.Z_corr.T

In [None]:
# Harmony path: choose the number of mCG PCs to keep from the PCs in the AnnData.
if correct_method == 'harmony':
    cg_n_components = significant_pc_test(
        cg_adata,
        p_cutoff=p_cutoff,
    )

In [None]:
# Scatter plots of the retained mCG PCs, coloured by the mCGFrac metadata column
# (skipped when the metadata has no such column).
hue = 'mCGFrac'
if hue in metadata.columns:
    # bring the metadata values into the cell order of the AnnData
    cg_adata.obs[hue] = metadata[hue].reindex(cg_adata.obs_names)
    fig, axes = plot_decomp_scatters(
        cg_adata,
        n_components=cg_n_components,
        hue=hue,
        hue_quantile=(0.25, 0.75),
        nrows=5,
        ncols=5,
    )

In [None]:
# keep only the significant PCs of each modality
ch_pcs = ch_adata.obsm['X_pca'][:, :ch_n_components]
cg_pcs = cg_adata.obsm['X_pca'][:, :cg_n_components]

# divide each block by its overall standard deviation so that the mCH and mCG
# PCs are on a comparable scale before they are joined
ch_pcs = ch_pcs / ch_pcs.std()
cg_pcs = cg_pcs / cg_pcs.std()

# joint matrix: mCH PCs first, then mCG PCs, side by side
total_pcs = np.concatenate([ch_pcs, cg_pcs], axis=1)

In [None]:
# Work on a copy of the mCH AnnData and store the joint PCs on it.
# This is suboptimal; to be changed once AnnData can combine layers and X.
adata = ch_adata.copy()
adata.obsm['X_pca'] = total_pcs

# drop the PCA entries that no longer match the joint PCs
# (.varm['PCs'] is only removed on the non-harmony path)
del adata.uns['pca']
if correct_method != 'harmony':
    del adata.varm['PCs']

In [None]:
def dump_embedding(adata, name, n_dim=2):
    """Copy manifold coordinates from ``adata.obsm`` into ``adata.obs``.

    For each of the first ``n_dim`` columns of ``adata.obsm['X_<name>']`` a
    column ``'<name>_<i>'`` is written to ``adata.obs``. The object is changed
    in place and also returned.
    """
    for axis in range(n_dim):
        column = f'{name}_{axis}'
        adata.obs[column] = adata.obsm[f'X_{name}'][:, axis]
    return adata

In [None]:
# t-SNE on the joint PCs; the coordinates are then copied into adata.obs
tsne_kws = dict(
    obsm='X_pca',
    metric='euclidean',
    exaggeration=-1,  # auto determined
    perplexity=30,
    n_jobs=-1,
)
tsne(adata, **tsne_kws)
dump_embedding(adata, 'tsne')

In [None]:
# t-SNE scatter coloured by the pre-clustering labels
fig, ax = plt.subplots(figsize=(4, 4), dpi=250)
_ = categorical_scatter(
    data=adata.obs,
    ax=ax,
    coord_base='tsne',
    hue=pre_cluster_name,
    show_legend=True,
)

In [None]:
# t-SNE scatter coloured by batch, to inspect how well the batches mix
if correct_batch_col is not None:
    # pandas aligns the metadata column to the cells on the index
    adata.obs[correct_batch_col] = metadata[correct_batch_col]
    fig, ax = plt.subplots(figsize=(4, 4), dpi=250)
    _ = categorical_scatter(
        data=adata.obs,
        ax=ax,
        coord_base='tsne',
        hue=correct_batch_col,
        show_legend=True,
    )

In [None]:
# Build the kNN graph explicitly on the joint PCs. Without use_rep, scanpy uses
# adata.X instead of X_pca when the AnnData has 50 or fewer features, which
# would ignore the combined (and batch-corrected) PCs.
sc.pp.neighbors(adata, use_rep='X_pca')
try:
    # PAGA on the pre-clusters provides the initial positions for UMAP
    sc.tl.paga(adata, groups=pre_cluster_name)
    sc.pl.paga(adata, plot=False)
    sc.tl.umap(adata, init_pos='paga')
except Exception as err:
    # Best-effort fallback to the default UMAP initialisation. Only Exception is
    # caught so that KeyboardInterrupt still stops the cell.
    print(f'PAGA-initialised UMAP failed ({err!r}); using the default UMAP initialisation')
    sc.tl.umap(adata)
dump_embedding(adata, 'umap')

In [None]:
# fig, ax = plt.subplots(figsize=(4, 4), dpi=250)
# _ = categorical_scatter(data=adata.obs, ax=ax, coord_base='umap', hue=pre_cluster_name, show_legend=True)

In [None]:
# Optional interactive UMAP scatter, limited to interactive_downsample cells.
if interactive_plot:
    if len(adata) > interactive_downsample:
        # fixed seed so that re-running the notebook shows the same subset of cells
        obs = adata.obs.sample(interactive_downsample, random_state=0)
    else:
        obs = adata.obs

    interactive_scatter(data=obs,
                        hue=pre_cluster_name,
                        coord_base='umap')

In [None]:
# Save the AnnData with the joint PCs and the embedding coordinates, then
# display its summary as the cell output.
adata.write_h5ad('adata.with_coords.h5ad')
adata