In [None]:
import pathlib
import pandas as pd
import seaborn as sns
from ALLCools.mcds import MCDS
import dask
import ALLCools
from ALLCools.clustering import  cluster_enriched_features

In [None]:
# Path to the QC-filtered cell metadata (edit to match your data).
metadata_path = './CellMetadata.AfterQC.pdpkl'

# Paths of the MCDS files to open (fill in with your own).
mcds_path_list = []

# Names of the dimensions used for clustering.
obs_dim = 'cell'  # observations
var_dim = 'chrom100k'  # features

# Lower / upper bounds handed to the feature mean-coverage filter.
min_cov = 250
max_cov = 3000

# Path to the ENCODE blacklist (edit to match your setup). The file can be
# downloaded from https://github.com/Boyle-Lab/Blacklist/
black_list_path = '~/refs/human/hg38/blacklist/hg38-blacklist.v2.bed.gz'
# Overlap threshold handed to the blacklist filter as `f`.
black_list_f = 0.2

# Chromosomes removed from the feature set.
exclude_chromosome = ['chrM', 'chrY']

In [None]:
# When True, the fraction data is loaded into memory before HVF selection
# (only if the dataset is small enough, see the loading cell below).
load = True
# Text file that stores the features kept by the basic filters.
feature_path = 'FeatureList.BasicFilter.txt'

# Highly variable feature (HVF) selection method:
#   'SVR'           -> regression based
#   any other value -> bin based (dispersion normalized per bin)
hvf_method = 'SVR'
mch_pattern = 'CHN'
mcg_pattern = 'CGN'
n_top_feature = 5000

# Number of cells to downsample to before HVF selection.
downsample = 20000

In [None]:
# Read the cell metadata table and report how many cells it contains.
# NOTE: pd.read_pickle unpickles the file; only use it on files you trust.
metadata = pd.read_pickle(metadata_path)
total_cells = len(metadata)
print(f'Metadata of {total_cells} cells')

In [None]:
# Open the MCDS files, restricted to the cells listed in the metadata.
# The observation dimension comes from the configured `obs_dim` instead of a
# hard-coded 'cell', matching the second MCDS.open call further down.
mcds = MCDS.open(mcds_path_list, obs_dim=obs_dim, use_obs=metadata.index)
# Feature count before any filtering; used later to report the fraction kept.
total_feature = mcds.get_index(var_dim).size

In [None]:
# Compute the per-feature mean coverage along `var_dim`
# (presumably consumed by the coverage filter that follows -- confirm in ALLCools docs).
mcds.add_feature_cov_mean(
    var_dim=var_dim,
)

In [None]:
# Keep only the features whose mean coverage lies within [min_cov, max_cov]
# as applied by filter_feature_by_cov_mean.
mcds = mcds.filter_feature_by_cov_mean(var_dim=var_dim,
                                       min_cov=min_cov,
                                       max_cov=max_cov)

In [None]:
# Drop blacklist-overlapping features first, then the excluded chromosomes.
mcds = (
    mcds
    .remove_black_list_region(
        var_dim,
        black_list_path,
        # Features having overlap > f with any black list region are removed.
        f=black_list_f,
    )
    .remove_chromosome(var_dim, exclude_chromosome)
)

In [None]:
# Report how many features survived the basic filters, as a count and as a
# percentage of the feature count recorded before filtering.
n_remaining = mcds.get_index(var_dim).size
pct_remaining = n_remaining * 100 / total_feature
print(f'{n_remaining} ({pct_remaining:.1f}%) '
      f'{var_dim} remained after all the basic filter.')

In [None]:
# Persist the surviving feature ids, one per line.
feature_ids = mcds.get_index(var_dim).astype(str)
with open(feature_path, 'w') as handle:
    handle.writelines(feature_id + '\n' for feature_id in feature_ids)

In [None]:
# Read the saved feature ids back as an index named after the feature dimension.
feature_table = pd.read_csv(feature_path, header=None, index_col=0)
use_features = feature_table.index.rename(var_dim)

In [None]:
# Re-open the MCDS with every cell in the metadata (no downsampling yet),
# then keep only the features that passed the basic filters.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    all_feature_mcds = MCDS.open(
        mcds_path_list,
        obs_dim=obs_dim,
        use_obs=metadata.index,
    )
    total_mcds = all_feature_mcds.sel({var_dim: use_features})

In [None]:
# Add the methylation rate for `var_dim`, normalized per cell with the
# normalized value clipped at 10 (per the keyword arguments below).
total_mcds.add_mc_rate(
    var_dim=var_dim,
    normalize_per_cell=True,
    clip_norm_value=10,
)

# Show the resulting dataset.
total_mcds


In [None]:
# HVF selection runs on a random subset of cells when the dataset is larger
# than `downsample`; otherwise it runs on every cell.
if not downsample or total_cells <= downsample:
    mcds = total_mcds
else:
    print(f'Downsample cells to {downsample} to calculate HVF.')
    # Fixed random_state keeps the sampled cell set reproducible.
    downsample_cell_ids = metadata.sample(downsample, random_state=0).index
    in_subset = total_mcds.get_index(obs_dim).isin(downsample_cell_ids)
    mcds = total_mcds.sel({obs_dim: in_subset})

In [None]:
# Load the fraction data into memory so the HVF computation can be faster,
# but only when the dataset holds at most `downsample` cells -- watch out
# for memory!
# `downsample` is tested first so a disabled value (None / 0) skips the load
# instead of raising a TypeError on the comparison, and the observation
# dimension comes from `obs_dim` rather than a hard-coded 'cell'.
if load and downsample and (mcds.get_index(obs_dim).size <= downsample):
    mcds[f'{var_dim}_da_frac'].load()

In [None]:
# Select highly variable features for the mCH pattern.
if hvf_method != 'SVR':
    # Bin based method.
    mch_hvf = mcds.calculate_hvf(
        var_dim=var_dim,
        mc_type=mch_pattern,
        min_mean=0,
        max_mean=5,
        n_top_feature=n_top_feature,
        bin_min_features=5,
        mean_binsize=0.05,
        cov_binsize=100,
    )
else:
    # SVR (regression) based method.
    mch_hvf = mcds.calculate_hvf_svr(
        var_dim=var_dim,
        mc_type=mch_pattern,
        n_top_feature=n_top_feature,
        plot=True,
    )

In [None]:
# Copy the mCH feature-selection coordinate computed on `mcds` (which may be
# a downsampled subset) onto the full dataset, so all cells are exported with
# the same selected features.
mch_select_key = f'{var_dim}_{mch_pattern}_feature_select'
total_mcds.coords[mch_select_key] = mcds.coords[mch_select_key]
mch_adata = total_mcds.get_adata(mc_type=mch_pattern,
                                 var_dim=var_dim,
                                 select_hvf=True)

# Plain string literal: the previous f-string had no placeholders.
mch_adata.write_h5ad('mCH.HVF.h5ad')

mch_adata

In [None]:
# Select highly variable features for the mCG pattern.
if hvf_method != 'SVR':
    # Bin based method.
    mcg_hvf = mcds.calculate_hvf(
        var_dim=var_dim,
        mc_type=mcg_pattern,
        min_mean=0,
        max_mean=5,
        n_top_feature=n_top_feature,
        bin_min_features=5,
        mean_binsize=0.02,
        cov_binsize=20,
    )
else:
    # SVR (regression) based method.
    mcg_hvf = mcds.calculate_hvf_svr(
        var_dim=var_dim,
        mc_type=mcg_pattern,
        n_top_feature=n_top_feature,
        plot=True,
    )

In [None]:
# Copy the mCG feature-selection coordinate computed on `mcds` (which may be
# a downsampled subset) onto the full dataset, so all cells are exported with
# the same selected features.
# The key must use `mcg_pattern`: copying the mCH coordinate here would leave
# the full dataset without the mCG selection that get_adata is asked to use.
mcg_select_key = f'{var_dim}_{mcg_pattern}_feature_select'
total_mcds.coords[mcg_select_key] = mcds.coords[mcg_select_key]
mcg_adata = total_mcds.get_adata(mc_type=mcg_pattern,
                                 var_dim=var_dim,
                                 select_hvf=True)

# Plain string literal: the previous f-string had no placeholders.
mcg_adata.write_h5ad('mCG.HVF.h5ad')

mcg_adata