In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
from contextlib import contextmanager
@contextmanager
def ignore_warnings():
    """Context manager that silences every warning raised inside its body.

    The warning filters that were active before entering are restored on
    exit, so the suppression never leaks into the surrounding code.
    """
    with warnings.catch_warnings():
        # A catch-all "ignore" filter: no message/category/module restriction.
        warnings.filterwarnings("ignore")
        yield

from numba.core.errors import NumbaDeprecationWarning
warnings.filterwarnings('ignore', category=NumbaDeprecationWarning)
import scanpy as sc
import anndata as ad
import mudata as md
import muon as mu
import pyranges as pr

%config InlineBackend.figure_format = 'retina'

2024-12-08 16:28:58.631036: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-08 16:28:58.646811: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-08 16:28:58.665659: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-08 16:28:58.671400: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-08 16:28:58.685628: I tensorflow/core/platform/cpu_feature_guar

In [2]:
from stwg_grn_params import *

RAW_MDATA_PATH=../data/stwg-v2.h5mu
MDATA_PATH=../data/stwg-v2-filtered.h5mu
OUTPUT_DIR=../analysis/
JASPAR_PATH=../data/JASPAR2024_CORE_vertebrates_non-redundant_pfms_jaspar.txt
REFTSS_PATH=../data/reftss.pkl
HPA_PATH=../data/hpa_tfs.pkl
CHIPATLAS_PATH=../data/chipatlas_kidney_promoters.pkl
PLATFORM=batch
SAMPLE=sample
CELLTYPE=celltype
GEX=rna
ACC=atac
MIN_CELLS=30
PROXIMAL_BP=5000
RANDOM_STATE=0
NORMALIZE_TOTAL=False
NUM_CELLS=None
READS_PER_CELL=None
NUM_TOP_GENES=1000
MIN_SAMPLES=-1
NUM_TREES=20
LEARNING_RATE=0.5
MAX_DEPTH=None
EARLY_STOPPING=3
FEATURE_FRACTION=1
BAGGING_FRACTION=1
LEAVE_P_OUT=2
IMPORTANCE_THRESHOLD=0.95
CORRELATION_THRESHOLD=0.2


In [3]:
# Load the filtered multimodal dataset. The location comes from the run
# parameters (MDATA_PATH) rather than a hard-coded literal so that changing
# the parameter file actually changes which file is analysed.
mdata = mu.read(MDATA_PATH)

  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


In [4]:
# map genes to proximal regulatory loci
#
# Builds `gex2acc`: gene name -> list of accessibility features whose genomic
# interval lies within PROXIMAL_BP of the gene's interval.

from collections import defaultdict

# Accessibility features with known coordinates. The per-modality coordinate
# columns in `mdata.var` are prefixed with the modality name, so the prefix is
# derived from ACC / GEX instead of being spelled out by hand.
pr_acc = mdata.var.reset_index()[['index', f'{ACC}:Chromosome', f'{ACC}:Start', f'{ACC}:End']].dropna()
pr_acc.columns = [ACC, 'Chromosome', 'Start', 'End']
pr_acc = pr.PyRanges(pr_acc.drop_duplicates())

# Genes with known coordinates, renamed to the column names PyRanges expects.
pr_gex = mdata.var.reset_index()[['index', f'{GEX}:seqname', f'{GEX}:start', f'{GEX}:end']].dropna()
pr_gex.columns = [GEX, 'Chromosome', 'Start', 'End']
pr_gex = pr.PyRanges(pr_gex.drop_duplicates())

# Interval join with slack: pairs every gene with every locus within
# PROXIMAL_BP of it, then collects the distinct loci per gene.
with ignore_warnings():
    gex2acc = pr_gex.join(pr_acc, slack=PROXIMAL_BP).df.groupby(GEX, observed=True)[ACC].agg(set)

# Sorted lists give a reproducible ordering; defaultdict(list) makes genes
# without any proximal locus map to an empty list.
gex2acc = defaultdict(list, gex2acc.apply(sorted))

len(gex2acc)

20187

In [5]:
# Collect transcription-factor names from the JASPAR motif file.
#
# Header lines look like ">MA0004.1<TAB>Arnt"; the name is the second
# tab-separated field. Heterodimer motifs are written "A::B" and are split so
# each partner counts as its own TF. Parsing is done in Python (instead of a
# `grep -P | cut` shell pipeline) so the cell does not depend on GNU grep and
# an empty/unmatched file yields an empty set rather than {''}.
import re

_jaspar_header = re.compile(r'^>MA\d+\.\d+\t(.*)$')

tfs = set()
with open(JASPAR_PATH) as _jaspar_file:
    for _line in _jaspar_file:
        _match = _jaspar_header.match(_line.rstrip('\r\n'))
        if _match is None:
            continue
        # Keep only the first field after the motif ID (what `cut -f2` did).
        _name = _match.group(1).split('\t')[0]
        tfs.update(_name.upper().split('::'))

len(tfs)

755

In [6]:
def preprocess_data(gex, acc, obs):
    """Align, filter and transform the expression and accessibility data.

    Steps, in order:
      1. If both modalities are non-empty, restrict all three inputs to the
         cells present in both, in the order they appear in ``gex``.
      2. Work on copies so the caller's objects are not modified.
      3. Optionally downsample expression counts (``READS_PER_CELL``).
      4. Record each cell's total expression count in ``obs['total_counts']``.
      5. Drop features seen in fewer than ``MIN_CELLS`` cells.
      6. Optionally normalise, then log1p-transform expression and flag highly
         variable genes per ``SAMPLE`` batch.
      7. Binarise accessibility (value > 0).

    Parameters
    ----------
    gex, acc : AnnData-like
        Expression and accessibility modalities.
    obs : pandas.DataFrame
        Per-cell metadata; must contain the ``SAMPLE`` column.

    Returns
    -------
    gex, acc, obs, hvgs
        The processed copies plus the list of genes flagged highly variable in
        every sample. ``hvgs`` is empty when no gene survives filtering.
    """
    if len(gex) > 0 and len(acc) > 0:
        # Keep gex's cell order. The previous set-intersection produced an
        # arbitrary order that changes between interpreter runs, which made
        # downstream row order (and hence results) non-reproducible.
        shared = set(acc.obs_names)
        cells = [cell for cell in dict.fromkeys(gex.obs_names) if cell in shared]
        gex = gex[cells]
        acc = acc[cells]
        obs = obs.loc[cells]

    gex = gex.copy()
    acc = acc.copy()
    obs = obs.copy()

    if READS_PER_CELL:
        sc.pp.downsample_counts(gex, total_counts=READS_PER_CELL*len(gex))

    # A sparse matrix returns an (n, 1) np.matrix from sum(1); flatten so the
    # column always receives a plain 1-D array.
    obs['total_counts'] = np.asarray(gex.X.sum(1)).ravel()

    sc.pp.filter_genes(gex, min_cells=MIN_CELLS)
    sc.pp.filter_genes(acc, min_cells=MIN_CELLS)

    # Default to "no HVGs": the highly_variable_nbatches column only exists
    # when the HVG step below actually ran.
    hvgs = []
    if gex.shape[1] > 0:
        if NORMALIZE_TOTAL:
            sc.pp.normalize_total(gex)
        sc.pp.log1p(gex)
        with ignore_warnings():
            sc.pp.highly_variable_genes(gex, batch_key=SAMPLE, n_top_genes=NUM_TOP_GENES)

        # Keep genes that were highly variable in every sample.
        num_samples = obs[SAMPLE].nunique()
        hvgs = gex.var_names[gex.var.highly_variable_nbatches == num_samples].tolist()

    if acc.shape[1] > 0:
        acc.X[:] = acc.X > 0

    return gex, acc, obs, hvgs

def prepare_gbm_data(gene, gex, acc, *, tfs, gex2acc):
    """Assemble the feature matrix and target vector for one gene's model.

    Features are (a) expression of every gene listed in ``tfs`` except the
    target itself and (b) accessibility of the loci listed for the target in
    ``gex2acc``. Features with zero standard deviation are dropped.

    Returns
    -------
    X : pandas.DataFrame
        Cells x retained features.
    y : pandas.Series
        Expression of ``gene`` per cell.
    """
    # Expression of candidate regulators, excluding the target gene itself.
    is_regulator = gex.var_names.isin(tfs) & (gex.var_names != gene)
    tf_features = gex[:, is_regulator].to_df()

    # Accessibility of the loci mapped to this gene.
    is_mapped_locus = acc.var_names.isin(gex2acc[gene])
    locus_features = acc[:, is_mapped_locus].to_df()

    features = pd.concat([tf_features, locus_features], axis=1)
    varying = features.std(0) > 0
    X = features.loc[:, varying]

    y = gex[:, gene].to_df().squeeze()

    return X, y

def fit_gbm(X_train, y_train, X_val, y_val):
    """Fit a LightGBM regressor and return it with its feature importances.

    Parameters
    ----------
    X_train, X_val : pandas.DataFrame
        Training and validation features (same columns).
    y_train, y_val : pandas.Series
        Training and validation targets.

    Returns
    -------
    model : LGBMRegressor
        The fitted model. It is trained on bare arrays (``.values``), so it
        does not carry the DataFrame column names.
    importances : pandas.Series
        Gain-based feature importances indexed by ``X_train.columns``.

    Notes
    -----
    Early stopping is configured through ``fit(callbacks=...)``. Passing
    ``callbacks`` to the ``LGBMRegressor`` constructor (as this function used
    to) only forwards it as an unrecognised booster parameter, so no early
    stopping took place.

    NOTE(review): the caller in this file passes the held-out fold as the
    validation set, so the stopping point is chosen on the same data that is
    scored afterwards - confirm this is intended.
    """
    model = LGBMRegressor(
        random_state=RANDOM_STATE,
        verbose=-1,
        importance_type='gain',
        max_depth=MAX_DEPTH,
        n_estimators=NUM_TREES,
        learning_rate=LEARNING_RATE,
        feature_fraction=FEATURE_FRACTION,
        # Bagging is only switched on when a fraction below 1 is requested.
        bagging_freq=int(BAGGING_FRACTION < 1),
        bagging_fraction=BAGGING_FRACTION,
    )

    # verbose=False keeps the callback from logging on every one of the many
    # per-gene, per-fold fits.
    callbacks = [lgb.early_stopping(EARLY_STOPPING, verbose=False)] if EARLY_STOPPING else None

    model.fit(
        X_train.values, y_train.values,
        eval_set=[(X_val.values, y_val.values)],
        callbacks=callbacks,
    )

    importances = pd.Series(model.feature_importances_, index=X_train.columns)

    return model, importances

def cv_fit_predict_gbm(X, y, obs):
    """Cross-validate a GBM over held-out groups of samples.

    Cells are grouped by ``obs[SAMPLE]``. A positive ``LEAVE_P_OUT`` is the
    number of samples held out per split; a non-positive value is added to the
    number of distinct samples to obtain that number.

    Returns
    -------
    y_pred : pandas.DataFrame
        Predictions for *all* cells (rows) from each split's model (columns).
    importances : pandas.DataFrame
        Feature importances (rows) from each split's model (columns).

    Both frames have two column levels named ``'train'`` and ``'test'``; each
    level value is the sorted tuple of sample labels on that side of the split.
    """
    sample_of_cell = obs[SAMPLE]

    if LEAVE_P_OUT > 0:
        n_held_out = LEAVE_P_OUT
    else:
        n_held_out = sample_of_cell.nunique() + LEAVE_P_OUT
    splitter = LeavePGroupsOut(n_held_out)

    fold_predictions = {}
    fold_importances = {}

    for fit_rows, held_rows in splitter.split(X, y, groups=sample_of_cell):
        # Identify the fold by which samples were used for fitting / held out.
        fold_key = (
            tuple(sorted(set(sample_of_cell.iloc[fit_rows]))),
            tuple(sorted(set(sample_of_cell.iloc[held_rows]))),
        )

        model, fold_importances[fold_key] = fit_gbm(
            X.iloc[fit_rows], y.iloc[fit_rows],
            X.iloc[held_rows], y.iloc[held_rows],
        )

        # Predict every cell, not only the held-out ones.
        fold_predictions[fold_key] = pd.Series(model.predict(X), index=y.index)

    level_names = ['train', 'test']

    y_pred = pd.DataFrame(fold_predictions)
    y_pred.columns.names = level_names

    importances = pd.DataFrame(fold_importances)
    importances.columns.names = level_names

    return y_pred, importances

def get_scores_per_celltype(y, y_pred, obs):
    """Correlate observed and predicted values per split, test sample and cell type.

    For every (train, test) column of ``y_pred``, every ``CELLTYPE`` group in
    ``obs`` and every sample named in the split's test tuple, the correlation
    between ``y`` and that column is computed over the matching cells.
    Combinations where either side has a standard deviation of exactly zero
    are skipped.

    Returns
    -------
    pandas.Series
        Correlations keyed by ``(train, test_sample, celltype)``.
    """
    scores = {}
    for fold in y_pred.columns:
        train, test = fold
        for celltype, celltype_obs in obs.groupby(CELLTYPE, observed=True):
            for held_out_sample in test:
                in_sample = celltype_obs[SAMPLE] == held_out_sample
                cell_ids = celltype_obs[in_sample].index

                observed = y.loc[cell_ids]
                predicted = y_pred.loc[cell_ids, (train, test)].squeeze()

                # Correlation is undefined for a constant vector.
                if observed.std() == 0 or predicted.std() == 0:
                    continue

                scores[(train, held_out_sample, celltype)] = observed.corr(predicted)
    return pd.Series(scores)

def fit_gene_model(gene, gex, acc, obs, *, tfs, gex2acc):
    """Run the full per-gene pipeline: build features, cross-validate, score.

    Returns
    -------
    scores : pandas.Series
        Output of ``get_scores_per_celltype`` for this gene.
    importances : pandas.DataFrame
        Per-split feature importances from ``cv_fit_predict_gbm``.
    """
    features, target = prepare_gbm_data(gene, gex, acc, tfs=tfs, gex2acc=gex2acc)
    predictions, importances = cv_fit_predict_gbm(features, target, obs)
    scores = get_scores_per_celltype(target, predictions, obs)
    return scores, importances

def xcorr(X, Y):
    """Pearson cross-correlation between the columns of ``X`` and of ``Y``.

    Parameters
    ----------
    X : array-like, shape (n, p)
    Y : array-like, shape (n, q)
        Rows are paired observations. NumPy arrays and pandas DataFrames are
        both accepted; DataFrames are treated positionally.

    Returns
    -------
    numpy.ndarray, shape (p, q)
        Entry ``[i, j]`` is the correlation of ``X[:, i]`` with ``Y[:, j]``.
        Constant columns produce non-finite entries (division by zero std);
        callers are expected to drop them beforehand.
    """
    # Work on plain float arrays so the standard deviation is the population
    # one (ddof=0). pandas' .std() defaults to ddof=1, which combined with the
    # division by n below shrank every correlation by a factor (n - 1) / n.
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    # Standardize X and Y - note gex is already logged here
    X_std = (X - X.mean(0)) / X.std(0)
    Y_std = (Y - Y.mean(0)) / Y.std(0)

    # Number of paired observations
    n = X.shape[0]

    # Mean of the products of z-scores = Pearson correlation
    cross_corr_matrix = np.dot(X_std.T, Y_std) / n

    return cross_corr_matrix

def compute_celltype_correlations(gex, obs, *, tfs, hvgs):
    """Correlate TF expression with target-gene expression within each cell type.

    Parameters
    ----------
    gex : AnnData-like
        Expression data; only TFs present in ``gex.var_names`` are used.
    obs : pandas.DataFrame
        Per-cell metadata containing the ``CELLTYPE`` column.
    tfs : set
        Candidate transcription-factor names.
    hvgs : iterable
        Target gene names.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``(tf, target)`` with one column per cell type. Pairs that
        were not computed for a cell type (a side had zero variance there) are
        filled with 0.
    """
    tf_names = list(tfs & set(gex.var_names))
    target_names = list(hvgs)

    rho_by_celltype = {}
    for celltype, members in obs[[CELLTYPE]].groupby(CELLTYPE, observed=True):
        tf_expr = gex[members.index, tf_names].to_df()
        target_expr = gex[members.index, target_names].to_df()

        # Constant columns cannot be standardised; drop them for this cell type.
        tf_expr = tf_expr.loc[:, tf_expr.std(0) > 0]
        target_expr = target_expr.loc[:, target_expr.std(0) > 0]

        rho = pd.DataFrame(
            xcorr(tf_expr, target_expr),
            index=tf_expr.columns,
            columns=target_expr.columns,
        )

        # Wide (tf x target) -> long series keyed by (tf, target).
        long_rho = rho.melt(ignore_index=False).reset_index()
        long_rho.columns = ['tf', 'target', 'rho']
        rho_by_celltype[celltype] = long_rho.set_index(['tf', 'target'])['rho']

    return pd.concat(rho_by_celltype, axis=1).fillna(0)

In [10]:
import os

from sklearn.model_selection import LeavePGroupsOut
import lightgbm as lgb
from lightgbm import LGBMRegressor
from tqdm import tqdm

# Fail fast (and cheaply) on an unwritable output location: the results are
# only written after the whole per-gene fitting loop below has finished.
os.makedirs(OUTPUT_DIR, exist_ok=True)

grns = {}             # platform -> (train, tf, target)-indexed per-celltype correlations
celltype_scores = {}  # platform -> per-gene prediction scores

for platform, platform_obs in mdata.obs.groupby(PLATFORM,observed=True):
    print(platform)

    _mdata = mdata[platform_obs.index]

    gex, acc, obs, hvgs = preprocess_data(_mdata[GEX], _mdata[ACC], _mdata.obs)

    _celltype_scores = {}
    importances = {}

    # One cross-validated model per highly variable gene.
    for gene in tqdm(hvgs):
        _celltype_scores[gene], _importances = fit_gene_model(
            gene,
            gex, acc, obs,
            tfs=tfs, gex2acc=gex2acc
        )
        # Keep only TF features; accessibility features are not reported as edges.
        _importances = _importances[_importances.index.isin(tfs)]
        importances[gene] = _importances

    celltype_scores[platform] = pd.DataFrame(_celltype_scores)

    # Percentile-rank the importances within each CV split, pooled over all
    # (target, tf) pairs, so IMPORTANCE_THRESHOLD is a quantile cut-off.
    importances = pd.concat(importances).rank(axis=0,pct=True,method='dense')
    importances.index.names = ['target','tf']

    # Per-celltype TF/target correlations, computed on each split's training
    # samples only.
    celltype_corrs = []
    for train, test in importances.columns:
        _obs = obs[obs[SAMPLE].isin(train)]
        _celltype_corrs = compute_celltype_correlations(
            gex[_obs.index], _obs, tfs=tfs, hvgs=hvgs
        )
        _celltype_corrs['train'] = [train] * len(_celltype_corrs)
        celltype_corrs.append(_celltype_corrs)
    celltype_corrs = pd.concat(celltype_corrs)
    celltype_corrs = celltype_corrs.reset_index().set_index(['train','target','tf'])

    # Long table of ranked importances keyed by (train, tf, target).
    _grns = (
        importances.melt(ignore_index=False)
        .reset_index()
        .set_index(['train','tf','target'])
        .drop(columns='test')
    )
    # Keep only the index of edges above the threshold ([[]] drops every
    # column), then attach the correlations for those edges.
    _grns = _grns[_grns['value']>=IMPORTANCE_THRESHOLD][[]]
    _grns = _grns.join(celltype_corrs,how='inner')
    grns[platform] = _grns

celltype_scores = pd.concat(celltype_scores)
celltype_scores.index.names = [PLATFORM, 'train', 'test', CELLTYPE]
celltype_scores.to_pickle(f'{OUTPUT_DIR}/scores.pickle')

grns = pd.concat(grns)
grns.index.names = [PLATFORM, *grns.index.names[1:]]
grns.to_pickle(f'{OUTPUT_DIR}/grns.pickle')