In [None]:
%load_ext autoreload
%autoreload 2
    
import scprinter as scp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import time
import pandas as pd
import numpy as np
import os
import pickle
import torch
import random
from tqdm.auto import *

import matplotlib as mpl

# Write PDF text as Type 42 (TrueType) fonts instead of matplotlib's default font type.
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
def strided_axis0(a, L):
    """Build overlapping windows of length ``L`` along the first axis of ``a``.

    Window ``i`` of the result corresponds to ``a[i:i + L]``, so the output has
    shape ``(a.shape[0] - L + 1, L) + a.shape[1:]``. The result is produced by
    ``np.lib.stride_tricks.as_strided``, i.e. by re-describing the shape and
    strides of ``a`` instead of slicing in a loop.
    """
    # One window per valid start position along axis 0.
    n_windows = a.shape[0] - L + 1
    window_shape = (n_windows, L) + a.shape[1:]

    # Moving to the next window and moving inside a window both advance by the
    # original axis-0 stride; the trailing axes keep their own strides.
    window_strides = (a.strides[0],) + a.strides

    return np.lib.stride_tricks.as_strided(
        a, shape=window_shape, strides=window_strides
    )


In [None]:
# Shared metadata: one row per barcode with its pseudobulk group, and a
# headerless table whose first column holds one cluster label per group.
barcodegroups = pd.read_table('/data/rzhang/PRINT_rev/mouse_HSC/barcodeGrouping.txt')
groupinfo = pd.read_table("/data/rzhang/PRINT_rev/mouse_HSC/pbulkClusters.txt", header=None)


def _collect_barcodes_by_label(barcodegroups, groupinfo, label_map):
    """Pool the barcodes of every group under the label its cluster maps to.

    Parameters
    ----------
    barcodegroups : pandas.DataFrame
        Table with a 'group' and a 'barcode' column.
    groupinfo : pandas.DataFrame
        Table whose first column gives the cluster name of each group. Row ``i``
        is paired with the ``i``-th entry of ``barcodegroups['group'].unique()``.
        NOTE(review): this positional pairing is taken over from the original
        code; confirm both files list the groups in the same order.
    label_map : dict
        Maps a cluster name (e.g. "Old_1") to the label to pool under.

    Returns
    -------
    dict
        ``label -> list of unique barcodes``. Keys follow the order in which the
        labels first appear in ``label_map``; the order of barcodes inside a
        list is arbitrary because duplicates are removed through a ``set``.
    """
    label2bc = {label: [] for label in label_map.values()}
    for i, group in enumerate(barcodegroups['group'].unique()):
        in_group = barcodegroups['group'] == group
        label = label_map[groupinfo.iloc[i, 0]]
        label2bc[label] += list(barcodegroups.loc[in_group, 'barcode'])
    return {label: list(set(bcs)) for label, bcs in label2bc.items()}


# --- Grouping 1: lineage-bias subtype within each age -----------------------
map1 = {
  "Old_1":"Old_Mk-biased", 
  "Old_2":"Old_intermediate",
  "Old_3":"Old_Mk-biased",
  "Old_4":"Old_multi-lineage",
  "Young_1":"Young_multi-lineage",
  "Young_2":"Young_multi-lineage",
  "Young_3":"Young_Mk-biased"}

groups = barcodegroups['group'].unique()
ct2bc = _collect_barcodes_by_label(barcodegroups, groupinfo, map1)
barcodes = [ct2bc[ct] for ct in ct2bc]
cell_types = [ct for ct in ct2bc]

# --- Grouping 2: age only ----------------------------------------------------
# NOTE(review): this reuses the names map1 / ct2bc / barcodes / cell_types, so
# the subtype grouping built above is overwritten and only the Old/Young
# grouping survives this cell (same as in the original copy-pasted version).
map1 = {
  "Old_1":"Old", 
  "Old_2":"Old",
  "Old_3":"Old",
  "Old_4":"Old",
  "Young_1":"Young",
  "Young_2":"Young",
  "Young_3":"Young"}

ct2bc = _collect_barcodes_by_label(barcodegroups, groupinfo, map1)
barcodes = [ct2bc[ct] for ct in ct2bc]
cell_types = [ct for ct in ct2bc]

In [None]:
# Import the filtered fragments file into an scPrinter object saved at
# `savename`, restricted to the unique barcodes listed in barcodeGrouping.txt.
# min_num_fragments=0 and min_tsse=0 are passed, presumably so that no barcode
# is removed by QC thresholds — TODO confirm against scp.pp.import_fragments.
# NOTE(review): plus_shift=4 / minus_shift=-5 look like a Tn5 offset
# correction; confirm they match how the fragments file was produced.
printer = scp.pp.import_fragments(
                        path_to_frags= '/data/rzhang/PRINT_rev/mouse_HSC/all.frags.filt.tsv.gz',
                        barcodes=np.unique(barcodegroups['barcode']),
                        savename='/data/rzhang/PRINT_rev/mouse_HSC/scprinter.h5ad',
                        genome=scp.genome.mm10,
                        plus_shift=4,
                        minus_shift=-5,
                        min_num_fragments=0, min_tsse=0,
                        sorted_by_barcode=False, 
                        low_memory=False,
                        )

In [None]:
# Headerless, tab-separated region table; column 1 (presumably the BED start
# coordinate) is cast to integer.
allCREs = pd.read_table(
    "/data/rzhang/PRINT_rev/mouse_HSC/regionRanges.bed", header=None
).astype({1: 'int'})

In [None]:
# Build the count matrix over all candidate regions (regions end up on the
# column axis, since the filter below indexes columns).
peak_adata = scp.pp.make_peak_matrix(printer, allCREs)
# Column-wise totals of the matrix: one summed count per region.
max_cov = np.sum(peak_adata.X, axis=0)
# Number of regions whose total is below 10. NOTE(review): this is not the last
# expression of the cell, so the value is computed but never displayed.
np.sum(max_cov < 10)
# Keep only the regions with a total count of at least 10.
peak_adata = peak_adata[:, max_cov >= 10].copy()

In [None]:
# Apply the same ">= 10" coverage filter to the region table, so the rows of
# CREs line up with the columns kept in peak_adata. max_cov is flattened to 1-D
# first so it can be used as a row mask.
CREs = allCREs.loc[np.array(max_cov).reshape((-1)) >= 10]

In [None]:
import os
# Enumerate CUDA devices in PCI bus order and expose only device "0" to this process.
# NOTE(review): torch is already imported in the first cell; these variables
# presumably only take effect if CUDA has not been initialised yet — confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
import scanpy as sc
import anndata
import cupy as cp
import cupyx as cpx
import time
# import rapids_singlecell as rsc
# from rapids_singlecell.cunnData import cunnData

import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
import rmm
from rmm.allocators.cupy import rmm_cupy_allocator
# Re-initialise the RAPIDS memory manager with a pooled, managed-memory setup.
rmm.reinitialize(
    managed_memory=True, # Allows oversubscription
    pool_allocator=True, # default is False
    devices=0, # GPU device IDs to register. By default registers only GPU 0.
)

# Make CuPy allocate its device memory through the RMM allocator configured above.
cp.cuda.set_allocator(rmm_cupy_allocator)

In [None]:
# Sample chromVAR background peaks for peak_adata (niterations=250).
# The genome has to match the assembly of the data: the fragments were imported
# with scp.genome.mm10 and the motifs are scanned against mm10, so mm10 is used
# here as well (the original passed the human hg38 genome for these mouse peaks).
scp.chromvar.sample_bg_peaks(peak_adata, 
                             genome=scp.genome.mm10, 
                             method='chromvar', 
                             niterations=250)

## chromvar for cisBP motifs

In [None]:
# Load the FigR mouse motif collection for mm10, passing the background
# frequencies stored in peak_adata.uns['bg_freq'].
motifs = scp.motifs.FigR_Mouse_Motifs(scp.genome.mm10, 
                           list(peak_adata.uns['bg_freq']),
                           n_jobs=16,pvalue=5e-5, 
                           mode='motifmatchr' # mode='motifmatchr' replicates the behavior and results of motifmatchr
                          )
# Prepare the scanner with the same p-value cutoff; tf_genes=None presumably
# means "do not restrict to a TF subset" — TODO confirm.
motifs.prep_scanner(tf_genes=None,pvalue=5e-5)
# Scan the peaks for motif matches. Later cells read the matches from
# peak_adata.varm['motif_match'], which this call presumably fills in.
motifs.chromvar_scan(peak_adata)

In [None]:
# Compute chromVAR deviations for the scanned motifs on the GPU (device='cuda'),
# with chunk_size=50000.
chromvar_existing = scp.chromvar.compute_deviations(peak_adata, 
                                        chunk_size=50000, device='cuda')

In [None]:
chromvar_existing

In [None]:
from scipy.stats import ttest_ind

# The age label is the second "-"-separated field of every observation name.
chromvar_existing.obs['age'] = [xx.split("-")[1] for xx in chromvar_existing.obs.index]

# Make the motif names unique *before* anything is exported, so the 'cluster'
# column of the TSV and the var names stored in the h5ad agree with each other.
chromvar_existing.var_names_make_unique()

# Two-sample t-test of 'Old' observations against all others, for every motif
# (column) at once.
mask = (chromvar_existing.obs['age'] == 'Old').to_numpy()
deviations = np.asarray(chromvar_existing.X)
test = ttest_ind(deviations[mask], deviations[~mask], axis=0)

results = pd.DataFrame({
    # NOTE: despite its name this column stores -log10(p). The 1e-200 offset
    # keeps the logarithm finite when a p-value underflows to 0.
    'pval': -np.log10(test.pvalue + 1e-200),
    'stats': test.statistic,
    'cluster': chromvar_existing.var.index,
    # Number of matched peaks per motif; flattened so a matrix-like column sum
    # (e.g. from a sparse motif_match) still yields a 1-D column.
    'nhits': np.asarray(peak_adata.varm['motif_match'].sum(axis=0)).ravel(),
})

chromvar_existing.write(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/chromvar_exo.h5ad')
results.to_csv(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/diff_test_exo.tsv', sep='\t', index=False)

## chromvar for de-novo motifs

In [None]:
# Names of the four de-novo motif runs; strs1 and strs2 are paired by position.
# strs1: result directories that each contain a hits.tsv.
strs1 = [
    'finemo.results.young.count',
'finemo.results.old.count',
'finemo.results.young.just_sum',
'finemo.results.old.just_sum',
]
# strs2: report directories that contain the per-motif tomtom tables; the same
# strings are also used as suffixes for motif names and output files.
strs2 = [
    'report_count_y',
    'report_count_o',
    'report_just_sum_y',
    'report_just_sum_o',
]

In [None]:
from scipy.stats import ttest_ind

# For every de-novo motif run: load its motif hits, build a peak x motif match
# matrix, compute chromVAR deviations, test Old vs. the rest per motif, and
# save the deviations (h5ad) and the test table (tsv).
metamotif2tomtom = {}
for str1, str2 in zip(strs1, strs2):
    # NOTE(review): a later cell of this notebook reads the same table from
    # .../final_model/finemo/{str1}/hits.tsv — confirm which location is current.
    hits_path = f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/{str1}/hits.tsv'
    hits = pd.read_csv(hits_path, sep='\t')
    motif_uniq = np.sort(hits['motif_name'].unique())

    # TOMTOM match table of every de-novo motif of this run, kept for the
    # annotation step further down.
    motif2tomtom = {}
    for motif in motif_uniq:
        res = pd.read_csv(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/{str2}/tomtom/{motif}.tomtom.tsv', sep='\t')
        motif2tomtom[motif] = res.copy()
    metamotif2tomtom[str2] = motif2tomtom

    # Binary peak x motif indicator. The original used "+= 1" with fancy
    # indexing, which does not accumulate repeated (peak, motif) pairs, so the
    # result was already 0/1; "= 1" states that explicitly.
    motif2id = {m: i for i, m in enumerate(motif_uniq)}
    ids = [motif2id[m] for m in hits['motif_name']]
    match_mm = np.zeros((len(CREs), len(motif_uniq)))
    match_mm[hits['peak_id'], ids] = 1

    # The original indexed match_mm with `peak_mask`, a name that is not defined
    # anywhere in the visible notebook. CREs and peak_adata are filtered with
    # the same ">= 10" coverage mask, so match_mm (one row per CREs row) should
    # already be aligned with the columns of peak_adata; verify that instead.
    # NOTE(review): this assumes hits['peak_id'] indexes rows of CREs — confirm.
    if match_mm.shape[0] != peak_adata.shape[1]:
        raise ValueError(
            f"motif match matrix has {match_mm.shape[0]} rows but peak_adata "
            f"has {peak_adata.shape[1]} peaks"
        )
    peak_adata.varm['motif_match'] = match_mm

    # Suffix motif names with the run name so names from different runs differ.
    motif_uniq = [f'{xx}_{str2}' for xx in motif_uniq]
    peak_adata.uns['motif_name'] = motif_uniq

    # Same call as used for the cisBP motifs earlier in this notebook; the
    # original called `compute_deviations_gpu`, which is not defined here.
    chromvar_denovo = scp.chromvar.compute_deviations(peak_adata,
                                                      chunk_size=50000, device='cuda')
    # The age label is the second "-"-separated field of every observation name.
    chromvar_denovo.obs['age'] = [xx.split("-")[1] for xx in chromvar_denovo.obs.index]
    # Uniquify names before exporting so the TSV and the h5ad agree.
    chromvar_denovo.var_names_make_unique()

    # Two-sample t-test of 'Old' observations against all others, per motif.
    mask = (chromvar_denovo.obs['age'] == 'Old').to_numpy()
    deviations = np.asarray(chromvar_denovo.X)
    test = ttest_ind(deviations[mask], deviations[~mask], axis=0)

    results = pd.DataFrame({
        # NOTE: despite its name this column stores -log10(p); the 1e-200 offset
        # keeps the logarithm finite when a p-value underflows to 0.
        'pval': -np.log10(test.pvalue + 1e-200),
        'stats': test.statistic,
        'cluster': chromvar_denovo.var.index,
        'nhits': np.asarray(peak_adata.varm['motif_match'].sum(axis=0)).ravel(),
    })

    chromvar_denovo.write(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/chromvar_{str2}.h5ad')
    results.to_csv(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/diff_test_{str2}.tsv', sep='\t', index=False)

In [None]:

# NOTE(review): same four entries as the strs2 list defined earlier in this
# notebook. The next cell also needs strs1, which is only defined in that
# earlier cell.
strs2 = [
    'report_count_y',
    'report_count_o',
    'report_just_sum_y',
    'report_just_sum_o'
]

In [None]:
# Rebuild the run -> {motif -> tomtom table} lookup without recomputing chromVAR.
metamotif2tomtom = {}
for str1, str2 in zip(strs1, strs2):
    hits_path = f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/finemo/{str1}/hits.tsv'
    hits = pd.read_csv(hits_path, sep='\t')
    # Sorted, de-duplicated motif names that have at least one hit in this run.
    motif_uniq = np.sort(hits['motif_name'].unique())
    new_names = []  # never filled in this cell
    motif2tomtom = {
        motif: pd.read_csv(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/{str2}/tomtom/{motif}.tomtom.tsv', sep='\t').copy()
        for motif in motif_uniq
    }
    metamotif2tomtom[str2] = motif2tomtom

In [None]:
# Annotate every de-novo motif with the known TFs it resembles according to
# TOMTOM, then pool the differential-test tables of all runs.
res_all = []
for str2 in strs2:
    # The de-novo chromVAR loop in this notebook writes its test table as
    # diff_test_{str2}.tsv, so read that file name (the original read
    # {str2}.tsv, which no cell of this notebook produces).
    res = pd.read_csv(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/diff_test_{str2}.tsv', sep='\t')
    motif2tomtom = metamotif2tomtom[str2]
    tf_list = []
    for cluster in tqdm(res['cluster']):
        # Keep the first three "_"-separated fields, i.e. drop the run suffix
        # that was appended to the motif name.
        motif = "_".join(cluster.split("_")[:3])
        tomtom = motif2tomtom[motif]
        # All targets with q-value < 0.05 count as confident matches.
        survival_tf = list(tomtom.loc[tomtom['q-value'] < 0.05, 'Target_ID'])
        if not survival_tf:
            # No confident match: fall back to the first row of the table and
            # encode it as 'unknown_maybe<TF>_<q-value>'. Positional .iloc is
            # used so this does not depend on the index labels.
            tf = tomtom['Target_ID'].iloc[0]
            qval = tomtom['q-value'].iloc[0]
            survival_tf = [f'unknown_maybe{tf}_{qval:.2f}']

        tf_list.append(survival_tf)
    res['TF'] = tf_list
    res_all.append(res)

# One table for all runs, ordered from the largest to the smallest t-statistic.
res = pd.concat(res_all, axis=0).reset_index(drop=True)
res = res.sort_values('stats', ascending=False)

# Lookup from motif (cluster) name to its list of TF annotations.
cluster2tf = dict(zip(res['cluster'], res['TF']))

In [None]:
# NOTE(review): `kept` is not defined in any visible cell of this notebook, so
# this line raises NameError on a fresh kernel (Restart & Run All). The row
# selection needs to be defined explicitly before it is used here.
res = res.loc[kept]

In [None]:
# Load the per-run de-novo deviation matrices and join them along the motif
# (var) axis. The de-novo chromVAR loop in this notebook saves them as
# chromvar_{str2}.h5ad, so read that file name (the original read {str2}.h5ad,
# which no cell of this notebook produces).
chromvar_denovo = anndata.concat(
    [
        anndata.read_h5ad(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/chromvar_{str2}.h5ad')
        for str2 in strs2
    ],
    axis=1,
)
chromvar_denovo

In [None]:
# Subset and reorder the motif columns to the clusters listed in `res`, in the
# row order of `res`.
# NOTE(review): this overwrites chromvar_denovo in place, so the cell is only
# re-runnable while every name in res['cluster'] is still present.
chromvar_denovo = chromvar_denovo[:, res['cluster']].copy()
chromvar_denovo

In [None]:
# NOTE(review): the per-motif standard deviation computed on the first line is
# immediately overwritten by the second, so `std_denovo` ends up holding the
# 'stats' column of kept_denovos rather than a standard deviation.
# NOTE(review): `kept_denovos` is not defined in any visible cell of this
# notebook — confirm where it comes from.
std_denovo = np.std(chromvar_denovo.X, axis=0)
std_denovo = kept_denovos['stats']

In [None]:
import h5py

# One padded 4-row matrix per kept de-novo motif, in the order of kept_denovos.
mms = []
for cluster_name in kept_denovos['cluster']:
    # Name layout: '<group>.<pattern part 1>_<pattern part 2>_<report name>'.
    parts = cluster_name.split(".")
    group_key = parts[0]
    tokens = parts[1].split("_")
    pattern_key = "_".join(tokens[:2])
    # The report name maps onto the modisco results file of the same run.
    class_ = "_".join(tokens[2:]).replace('report', 'modisco_results')
    with h5py.File(f'/data/rzhang/PRINT_rev/mouse_HSC/final_model/{class_}.h5', 'r') as h5:
        # Read the whole 'sequence' dataset into memory and transpose it.
        seq = h5[group_key][pattern_key]['sequence'][:].T
    # Flank the motif with five all-ones columns on both sides.
    flank = np.ones((4, 5))
    mms.append(np.concatenate([flank, seq, flank], axis=-1))


In [None]:
# Load the motif collection from the local PFM file against mm10.
motifs = scp.motifs.Motifs(
    "./mouse_pfms_v4.txt",
    scp.genome.mm10.fetch_fa(),
    scp.genome.mm10.bg,
)

# Map the third "_"-separated field of each motif name to its 4 x length count
# matrix with rows ordered A, C, G, T. A later motif with the same key replaces
# an earlier one.
motif2matrix = {}
for motif in motifs.all_motifs:
    tf_name = motif.name.split("_")[2]
    motif2matrix[tf_name] = np.array([motif.counts[base] for base in "ACGT"])

In [None]:
def normalize_pfm(pfm):
    """Convert a position frequency matrix to a position probability matrix.

    Every column of ``pfm`` is divided by its own sum. No pseudocount is added,
    so a column that sums to zero produces non-finite values.
    """
    return pfm / pfm.sum(axis=0)


def pearson_correlation(col1, col2):
    """Compute the Pearson correlation coefficient of two arrays using NumPy.

    Both inputs are flattened first, so two equally shaped matrices are
    compared element by element. Returns 0 when either input has zero variance
    (the denominator would be zero).
    """
    col1 = col1.reshape((-1))
    col2 = col2.reshape((-1))
    centered1 = col1 - np.mean(col1)
    centered2 = col2 - np.mean(col2)
    numerator = np.sum(centered1 * centered2)
    denominator = np.sqrt(np.sum(centered1 ** 2) * np.sum(centered2 ** 2))
    if denominator == 0:
        return 0  # Avoid division by zero
    return numerator / denominator


def cross_correlation(pfm1, pfm2):
    """Score every ungapped alignment of the shorter motif inside the longer one.

    Both matrices are column-normalised and shifted by -0.25. The shorter motif
    is then slid along the longer one, and each placement is scored with the
    Pearson correlation of the overlapping window. This is done once for the
    longer motif as given and once with both of its axes reversed (its reverse
    complement when the rows are ordered A, C, G, T).

    Parameters
    ----------
    pfm1, pfm2 : numpy.ndarray
        Matrices with positions on axis 1; they may differ in length.

    Returns
    -------
    list
        ``2 * (|len1 - len2| + 1)`` scores: all offsets of the forward
        orientation first, then all offsets of the reversed orientation.
    """
    ppm1 = normalize_pfm(np.copy(pfm1)) - 0.25
    ppm2 = normalize_pfm(np.copy(pfm2)) - 0.25

    if ppm1.shape[1] >= ppm2.shape[1]:
        larger_ppm, smaller_ppm = ppm1, ppm2
    else:
        larger_ppm, smaller_ppm = ppm2, ppm1
    smaller_len = smaller_ppm.shape[1]

    # A window of length s fits into a sequence of length l at l - s + 1 start
    # positions. The original range stopped one short: it skipped the last
    # offset and returned no scores at all for two motifs of equal length.
    n_offsets = larger_ppm.shape[1] - smaller_len + 1

    scores = []
    for oriented in (larger_ppm, larger_ppm[::-1][:, ::-1]):
        for offset in range(n_offsets):
            window = oriented[:, offset:offset + smaller_len]
            scores.append(pearson_correlation(window, smaller_ppm))
    return scores

In [None]:
# For every kept de-novo motif: best correlation against its annotated TF
# motif(s), a short display name, and whether the annotation was confident.
sims, shorted_name, kind = [], [], []
for denovo_pfm, tf_repr in zip(mms, kept_denovos['TF']):
    # The TF column holds the text form of a list ("['A', 'B']"): drop the
    # brackets and quotes, then split on ", ".
    candidates = tf_repr[1:-1].replace('\'', '').split(', ')

    if "unknown" in candidates[0]:
        # 'unknown_maybe<TF>_<qval>' -> '<TF>': compare against the best guess only.
        guess = candidates[0].split("_")[1].replace('maybe', '')
        best = np.max(cross_correlation(denovo_pfm, motif2matrix[guess]))
        candidates = [guess]
        label = 'unknown'
    else:
        # Confident annotation(s): keep the best score over all listed TFs.
        best = np.max([
            np.max(cross_correlation(denovo_pfm, motif2matrix[name]))
            for name in candidates
        ])
        label = 'known'

    sims.append(best)
    kind.append(label)
    shorted_name.append(", ".join(candidates))

In [None]:
# Per-motif table behind the scatter plot. The 'chromvar_std' column carries
# whatever `std_denovo` holds at this point.
data = pd.DataFrame({
    'chromvar_std': std_denovo,
    'max_correlation': sims,
    'cluster': kept_denovos['cluster'],
    'names': shorted_name,
    'kind': kind,
})

In [None]:
import pandas as pd
# Import only the plotnine names this cell uses instead of `from plotnine import *`,
# so the notebook namespace is not flooded with plotting names.
from plotnine import aes, geom_point, geom_text, ggplot, theme

# Only motifs with |chromvar_std| > 25 get a text label.
label_data = data[np.abs(data['chromvar_std']) > 25]

# Scatter of best correlation to the annotated TF motif (x) against the
# 'chromvar_std' column (y), coloured by annotation kind.
fig = (
    ggplot(data, aes(x='max_correlation', y='chromvar_std', label='names', color='kind'))
    + geom_point(size=3)
    + geom_text(
        aes(x='max_correlation', y='chromvar_std', label='names'),
        data=label_data, format_string='',
        ha='center', va='center', color='black',
        # Arrow style handed to the label-adjustment step.
        adjust_text={
            'arrowprops': {
                'arrowstyle': '->',
                'color': 'black',
            }},
    )
    + theme(legend_position='right')  # Positioning the legend
)
fig.save('/data/rzhang/PRINT_rev/mouse_HSC/chromvar_stats_denovo_corr_cisbp.pdf', height=5, width=6)
data.to_csv('/data/rzhang/PRINT_rev/mouse_HSC/chromvar_stats_denovo_corr_cisbp.tsv', sep='\t', index=False)
fig.draw()