In [None]:
import os
from collections import defaultdict

import anndata
import faiss
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.stats
import seaborn as sns
import statsmodels.api as sm
import yaml
from matplotlib import rcParams
from sklearn.linear_model import LinearRegression
from statsmodels.nonparametric.smoothers_lowess import lowess
from tqdm.notebook import tqdm

import scglue

In [None]:
# Output directory for the figures written by this notebook.
PATH = "s04_reg_contrib"
os.makedirs(PATH, exist_ok=True)

# Plotting defaults (publication style, then a square default figure size).
scglue.plot.set_publication_params()
rcParams["figure.figsize"] = (4, 4)

# Read data

In [None]:
# Outputs of the GLUE step (s02): one AnnData per modality plus the joint object
# whose obsm["X_glue"] embedding is clustered into pseudocells below.
combined = anndata.read_h5ad("s02_glue/combined.h5ad")
rna, met, atac = (
    anndata.read_h5ad(path)
    for path in ("s02_glue/rna.h5ad", "s02_glue/met.h5ad", "s02_glue/atac2rna.h5ad")
)

In [None]:
# Attach the `common_cell_type` labels produced by the marker step (s03).
# Only `.obs` is needed, so each file is opened in backed mode and its HDF5 handle
# is closed right after copying the column (previously the handles stayed open).
for adata, path in (
    (rna, "s03_markers/rna_filtered.h5ad"),
    (met, "s03_markers/met_filtered.h5ad"),
    (atac, "s03_markers/atac_filtered.h5ad"),
):
    filtered = anndata.read_h5ad(path, backed="r")
    try:
        # Series assignment aligns on obs_names; cells missing from `filtered` become NaN.
        adata.obs["common_cell_type"] = filtered.obs["common_cell_type"]
    finally:
        filtered.file.close()

In [None]:
# Manual colour palette, later passed as `palette=` to seaborn
# (presumably a mapping of cell-type label -> colour; verify against the YAML file).
# `safe_load` is sufficient for plain data and never constructs arbitrary Python objects.
with open("manual_colors.yaml", "r") as f:
    MANUAL_COLORS = yaml.safe_load(f)

# Clustering cells into pseudocells

In [None]:
# Cluster cells from all modalities (they share the GLUE embedding in `combined`)
# into pseudocells with k-means; every cell is assigned to its nearest centroid.
n_pseudocells = 200  # number of k-means centroids (= candidate pseudocells)
# faiss only accepts C-contiguous float32 input; make that explicit (no-op if already so).
X_glue = np.ascontiguousarray(combined.obsm["X_glue"], dtype=np.float32)
kmeans = faiss.Kmeans(X_glue.shape[1], n_pseudocells, gpu=False, seed=0)
kmeans.train(X_glue)
_, nearest_centroid = kmeans.index.search(X_glue, 1)  # labels of shape (n_obs, 1)
combined.obs["pseudocell"] = nearest_centroid.ravel()

In [None]:
# Turn integer cluster ids into a categorical with readable names "pseudocell-<id>"
# (these names presumably become the obs_names of the aggregated objects below).
# `rename_categories(..., inplace=True)` was removed in pandas 2.0, so assign the result.
combined.obs["pseudocell"] = pd.Categorical(combined.obs["pseudocell"]).rename_categories(
    lambda x: f"pseudocell-{x}"
)

In [None]:
# Copy each cell's pseudocell label onto its modality-specific object (dropping
# pseudocells that modality never hits) and add a per-cell counter to sum later.
for adata in (rna, met, atac):
    labels = combined.obs["pseudocell"].loc[adata.obs_names]
    adata.obs["pseudocell"] = labels.cat.remove_unused_categories()
    adata.obs["n_cells"] = 1

In [None]:
# Split methylation features into one AnnData per context, with var_names reduced
# to the bare feature name so they match the RNA/ATAC feature names.
# Slices are copied explicitly: renaming var_names on a view would trigger an
# implicit copy (with a warning) anyway.
mCH = met[:, met.var_names.str.endswith("_mCH")].copy()
mCG = met[:, met.var_names.str.endswith("_mCG")].copy()
# Strip exactly the trailing context suffix (str.replace would also hit mid-name matches).
mCH.var_names = mCH.var_names.str[:-len("_mCH")]
mCG.var_names = mCG.var_names.str[:-len("_mCG")]

In [None]:
# Pool RNA cells per pseudocell: counts are summed, obs columns are aggregated
# as listed ("majority" label / summed cell count), embeddings are averaged.
rna_agg = scglue.data.aggregate_obs(
    rna,
    by="pseudocell",
    X_agg="sum",
    obs_agg={
        "common_cell_type": "majority",
        "domain": "majority",
        "n_cells": "sum",
    },
    obsm_agg={
        "X_glue": "mean",
        "X_glue_umap": "mean",
    },
)
rna_agg

In [None]:
# Pool mCH cells per pseudocell: methylation values are averaged (not summed),
# obs columns are aggregated as listed, embeddings are averaged.
mCH_agg = scglue.data.aggregate_obs(
    mCH,
    by="pseudocell",
    X_agg="mean",
    obs_agg={
        "common_cell_type": "majority",
        "domain": "majority",
        "n_cells": "sum",
    },
    obsm_agg={
        "X_glue": "mean",
        "X_glue_umap": "mean",
    },
)
mCH_agg

In [None]:
# Pool mCG cells per pseudocell: methylation values are averaged (not summed),
# obs columns are aggregated as listed, embeddings are averaged.
mCG_agg = scglue.data.aggregate_obs(
    mCG,
    by="pseudocell",
    X_agg="mean",
    obs_agg={
        "common_cell_type": "majority",
        "domain": "majority",
        "n_cells": "sum",
    },
    obsm_agg={
        "X_glue": "mean",
        "X_glue_umap": "mean",
    },
)
mCG_agg

In [None]:
# Pool ATAC cells per pseudocell: counts are summed, obs columns are aggregated
# as listed, embeddings are averaged.
atac_agg = scglue.data.aggregate_obs(
    atac,
    by="pseudocell",
    X_agg="sum",
    obs_agg={
        "common_cell_type": "majority",
        "domain": "majority",
        "n_cells": "sum",
    },
    obsm_agg={
        "X_glue": "mean",
        "X_glue_umap": "mean",
    },
)
atac_agg

In [None]:
# Keep only pseudocells present in every modality. All four objects are indexed
# with the SAME ordered list, so row i is the same pseudocell everywhere. Later
# cells pair rows across modalities purely by position (e.g. rna_X[:, i] vs
# mCH_X[:, i]), so per-object boolean masks, which keep each object's own order,
# are not enough.
shared_pseudocells = (
    set(rna_agg.obs_names)
    & set(mCH_agg.obs_names)
    & set(mCG_agg.obs_names)
    & set(atac_agg.obs_names)
)
common_pseudocells = [name for name in rna_agg.obs_names if name in shared_pseudocells]
rna_agg = rna_agg[common_pseudocells, :].copy()
mCH_agg = mCH_agg[common_pseudocells, :].copy()
mCG_agg = mCG_agg[common_pseudocells, :].copy()
atac_agg = atac_agg[common_pseudocells, :].copy()
assert all(
    (other.obs_names == rna_agg.obs_names).all()
    for other in (mCH_agg, mCG_agg, atac_agg)
), "pseudocell rows are not aligned across modalities"

# Normalization and filtering

In [None]:
# Count modalities (RNA, ATAC; summed per pseudocell): depth-normalise, then log1p.
for adata in (rna_agg, atac_agg):
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Methylation modalities were averaged per pseudocell rather than summed, so
# depth normalisation is deliberately skipped; they are only log1p-transformed.
for adata in (mCH_agg, mCG_agg):
    sc.pp.log1p(adata)

In [None]:
def _count_layer_stat(adata):
    """Per-feature mean and std of a sparse (count-derived) expression matrix."""
    return pd.DataFrame(
        {
            # sparse .mean returns a (1, n) matrix; flatten it to 1-D
            "mean": np.asarray(adata.X.mean(axis=0)).ravel(),
            "std": np.sqrt(scglue.num.col_var(adata.X)),
        },
        index=adata.var_names,
    )


def _methylation_layer_stat(adata):
    """Per-feature mean and std of a dense methylation matrix."""
    values = adata.X
    return pd.DataFrame(
        {"mean": values.mean(axis=0), "std": values.std(axis=0)},
        index=adata.var_names,
    )


rna_agg_stat = _count_layer_stat(rna_agg)
mCH_agg_stat = _methylation_layer_stat(mCH_agg)
mCG_agg_stat = _methylation_layer_stat(mCG_agg)
atac_agg_stat = _count_layer_stat(atac_agg)

In [None]:
stats = [rna_agg_stat, mCH_agg_stat, mCG_agg_stat, atac_agg_stat]
for stat in stats:
    # Fit the mean-std trend per layer; "std_remain" is each feature's std
    # minus that trend, i.e. variability beyond what its mean level predicts.
    stat["std_lowess"] = lowess(stat["std"], stat["mean"], frac=0.3, return_sorted=False)
    stat["std_remain"] = stat["std"] - stat["std_lowess"]

In [None]:
# One panel per layer: observed std and the fitted lowess trend against the mean.
fig, axes = plt.subplots(figsize=(16, 3), ncols=4, gridspec_kw=dict(wspace=0.5))
for ax, stat in zip(axes, stats):
    for y_col in ("std", "std_lowess"):
        ax = sns.scatterplot(x="mean", y=y_col, data=stat, edgecolor=None, s=3, ax=ax)

In [None]:
# Feature filters per layer, in the order RNA, mCH, mCG, ATAC:
# minimum mean level, and minimum residual std ("std_remain") above the lowess trend.
mean_cutoffs = [0.5, 0.1, 0.1, 0.5]
std_cutoffs = [-0.02] * 4

In [None]:
# Residual std vs mean per layer, with the chosen cutoffs drawn as dashed lines.
fig, axes = plt.subplots(figsize=(16, 3), ncols=4, gridspec_kw=dict(wspace=0.5))
for ax, stat, x_cut, y_cut in zip(axes, stats, mean_cutoffs, std_cutoffs):
    ax = sns.scatterplot(x="mean", y="std_remain", data=stat, edgecolor=None, s=3, ax=ax)
    ax.axvline(x=x_cut, c="darkred", ls="--")
    ax.axhline(y=y_cut, c="darkred", ls="--")

In [None]:
def _select_features(adata, stat, mean_cutoff, std_cutoff):
    """View of `adata` restricted to features passing both the mean and residual-std cutoffs."""
    keep = (stat["mean"] >= mean_cutoff) & (stat["std_remain"] >= std_cutoff)
    return adata[:, keep]


rna_agg_use, mCH_agg_use, mCG_agg_use, atac_agg_use = [
    _select_features(adata, stat, mean_cut, std_cut)
    for adata, stat, mean_cut, std_cut in zip(
        (rna_agg, mCH_agg, mCG_agg, atac_agg),
        (rna_agg_stat, mCH_agg_stat, mCG_agg_stat, atac_agg_stat),
        mean_cutoffs,
        std_cutoffs,
    )
]

In [None]:
# Features that pass the filters in all four layers. Sorted so that the gene order
# (and therefore every downstream per-gene table and figure) is reproducible:
# iteration order of a set of strings changes between interpreter runs because
# of string hash randomisation.
common_genes = sorted(
    set(rna_agg_use.var_names)
    & set(mCH_agg_use.var_names)
    & set(mCG_agg_use.var_names)
    & set(atac_agg_use.var_names)
)
len(common_genes)

In [None]:
# Restrict every layer to the shared genes, in the same column order.
rna_agg_use, mCH_agg_use, mCG_agg_use, atac_agg_use = [
    adata[:, common_genes].copy()
    for adata in (rna_agg_use, mCH_agg_use, mCG_agg_use, atac_agg_use)
]

In [None]:
# Matrices for per-gene column access: RNA/ATAC are densified from sparse,
# the methylation matrices are used as-is.
rna_X, atac_X = (adata.X.toarray() for adata in (rna_agg_use, atac_agg_use))
mCH_X, mCG_X = mCH_agg_use.X, mCG_agg_use.X

# Global (all cell types pooled)

## Spearman correlation with gene expression

In [None]:
# Per gene: Spearman correlation across pseudocells between expression and each layer.
layer_X = {"mCH": mCH_X, "mCG": mCG_X, "ATAC": atac_X}
corr = pd.DataFrame(
    [
        [scipy.stats.spearmanr(rna_X[:, i], X[:, i]).correlation for X in layer_X.values()]
        for i in tqdm(range(rna_X.shape[1]))
    ],
    index=common_genes,
    columns=list(layer_X),
)
corr.head()

In [None]:
def offdiag_func(x, y, ax=None, **kwargs):
    """Draw dashed zero reference lines on one off-diagonal panel.

    Signature matches the `PairGrid.map_offdiag` callback; `x`, `y` and
    `kwargs` are accepted but unused. Falls back to the current Axes when
    `ax` is not given.
    """
    target = plt.gca() if ax is None else ax
    target.axvline(x=0, c="darkred", ls="--")
    target.axhline(y=0, c="darkred", ls="--")

g = sns.pairplot(
    corr,
    diag_kind="kde",
    height=2,
    plot_kws=dict(s=3, edgecolor=None, alpha=0.5, rasterized=True),
)
g.map_offdiag(offdiag_func)
g.savefig(f"{PATH}/corr_cmp.pdf")

In [None]:
# Expression statistics of the shared genes plus log10 of the genomic span
# (chromEnd - chromStart) taken from the RNA feature table.
gene_coords = rna.var.loc[common_genes, ["chromStart", "chromEnd"]]
gene_span = gene_coords["chromEnd"] - gene_coords["chromStart"]
gene_stat = rna_agg_stat.loc[common_genes, :].assign(gene_length=np.log10(gene_span))
gene_stat.head()

In [None]:
# Absolute association (via scglue.num.spr_mat) between each gene statistic and
# each layer's expression correlation, reshaped to long format for plotting.
stat_columns = ["gene_length", "mean", "std_remain"]
stat_labels = pd.Index(["Length", "Expr mean", "Expr variability"], name="Gene stat")
association = scglue.num.spr_mat(gene_stat.loc[:, stat_columns], corr)
gene_stat_corr = (
    pd.DataFrame(association, index=stat_labels, columns=corr.columns)
    .abs()
    .reset_index()
    .melt(id_vars=["Gene stat"], var_name="Omics layer", value_name="Association")
)
gene_stat_corr

In [None]:
# Line + points: association of each gene statistic across omics layers.
encoding = dict(x="Omics layer", y="Association", hue="Gene stat", data=gene_stat_corr)
ax = sns.lineplot(**encoding, lw=2, legend=False)
ax = sns.scatterplot(**encoding, edgecolor=None, ax=ax)
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.legend(title="Gene stat", loc="center left", bbox_to_anchor=(1.05, 0.5), frameon=False)
ax.get_figure().savefig(f"{PATH}/gene_stat_corr.pdf")

## $R^2$ of gene expression

In [None]:
def _fit_r2(X, y):
    """In-sample R^2 of an ordinary least-squares fit of `y` on the columns of `X`."""
    return LinearRegression().fit(X, y).score(X, y)


# Per gene: R^2 of expression explained by all three layers jointly ("Combined")
# and by each layer alone.
rsquare = defaultdict(list)
for i in range(rna_X.shape[1]):
    y_ = rna_X[:, i]
    mCH_i, mCG_i, atac_i = mCH_X[:, i], mCG_X[:, i], atac_X[:, i]
    rsquare["Combined"].append(_fit_r2(np.stack([mCH_i, mCG_i, atac_i], axis=1), y_))
    for layer, x_i in (("mCH", mCH_i), ("mCG", mCG_i), ("ATAC", atac_i)):
        rsquare[layer].append(_fit_r2(x_i[:, np.newaxis], y_))

rsquare = pd.DataFrame(rsquare, index=common_genes)
rsquare_melt = rsquare.melt(var_name="Omics layer", value_name="rsquare")

In [None]:
# Distribution of per-gene R^2 for each layer (triangle = mean).
box_style = dict(
    saturation=1.0,
    width=0.6,
    showmeans=True,
    meanprops=dict(marker="^", markerfacecolor="white", markeredgecolor="black"),
    boxprops=dict(edgecolor="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    flierprops=dict(marker=".", markerfacecolor="black", markeredgecolor="none", markersize=3),
)
ax = sns.boxplot(x="Omics layer", y="rsquare", data=rsquare_melt, **box_style)
ax.set_ylabel("Gene expression $R^2$")
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.get_figure().savefig(f"{PATH}/rsquare.pdf")

In [None]:
rsquare.mean()  # mean R^2 over genes, one value per column

# Per cell type

In [None]:
# Majority cell-type label of every pseudocell in each layer, plus the number of
# distinct labels it received (n_ct == 1 means all layers agree).
layer_objects = {
    "rna": rna_agg_use,
    "mCH": mCH_agg_use,
    "mCG": mCG_agg_use,
    "atac": atac_agg_use,
}
common_cell_types = pd.DataFrame(
    {layer: adata.obs["common_cell_type"] for layer, adata in layer_objects.items()}
)
common_cell_types["n_ct"] = common_cell_types.apply(lambda row: len(set(row)), axis=1)
common_cell_types.head()

In [None]:
# Keep only pseudocells whose majority cell type agrees across all layers;
# all four objects are indexed with the same list, keeping rows aligned.
is_consistent = (common_cell_types["n_ct"] == 1).to_numpy()
consistent_pseudocells = common_cell_types.index[is_consistent]
rna_agg_use, mCH_agg_use, mCG_agg_use, atac_agg_use = [
    adata[consistent_pseudocells, :]
    for adata in (rna_agg_use, mCH_agg_use, mCG_agg_use, atac_agg_use)
]

In [None]:
# Re-extract matrices for the consistent pseudocells only
# (RNA/ATAC densified from sparse, methylation used as-is).
rna_X = rna_agg_use.X.toarray()
atac_X = atac_agg_use.X.toarray()
mCH_X, mCG_X = mCH_agg_use.X, mCG_agg_use.X

In [None]:
# Number of consistent pseudocells per cell type.
pseudocell_cell_types = rna_agg_use.obs["common_cell_type"]
ct_sizes = pseudocell_cell_types.value_counts()
ct_sizes

In [None]:
# Cell types with at least `min_size` pseudocells; `min_size` is also the
# subsample size drawn per repetition in the per-cell-type R^2 below.
min_size = 10
used_cts = ct_sizes[ct_sizes >= min_size].index.to_numpy()
used_cts

## $R^2$ of gene expression (subsampled per cell type)

In [None]:
# Per-cell-type R^2. To make cell types with different numbers of pseudocells
# comparable, each is subsampled `n_subsample` times to exactly `min_size`
# pseudocells, and the per-gene in-sample R^2 of each single layer is averaged
# over the draws.
CT_ORDER = ["mL2/3", "mL4", "mL5-1", "mDL-2", "mL6-2"]  # ordered categories for plot & trend
n_subsample = 20
rs = np.random.RandomState(0)

# A cell type missing from CT_ORDER would silently become NaN in the ordered
# categorical below (code -1 in the trend regression); fail loudly instead.
missing_cts = set(used_cts) - set(CT_ORDER)
assert not missing_cts, f"cell types missing from CT_ORDER: {sorted(missing_cts)}"


def _univariate_r2(x, y):
    """In-sample R^2 of an ordinary least-squares fit of `y` on the 1-D predictor `x`."""
    x_col = x[:, np.newaxis]
    return LinearRegression().fit(x_col, y).score(x_col, y)


layer_mats = {"mCH": mCH_X, "mCG": mCG_X, "ATAC": atac_X}
pseudocell_ct_labels = rna_agg_use.obs["common_cell_type"].to_numpy()

rsquare = {}
for ct in used_cts:
    ct_idx_all = np.where(pseudocell_ct_labels == ct)[0]
    draws = []
    for _ in tqdm(range(n_subsample), desc=ct):
        # One RNG call per draw, in the same order as before, so results are unchanged.
        ct_idx = rs.choice(ct_idx_all, min_size, replace=False)
        rna_sub = rna_X[ct_idx]
        draws.append({
            layer: [_univariate_r2(X[ct_idx, i], rna_sub[:, i]) for i in range(rna_X.shape[1])]
            for layer, X in layer_mats.items()
        })
    # Average each gene's R^2 over the subsampling draws.
    rsquare[ct] = {layer: np.mean([d[layer] for d in draws], axis=0) for layer in layer_mats}

rsquare = pd.concat({
    ct: pd.DataFrame(per_layer, index=common_genes)
    for ct, per_layer in rsquare.items()
})
rsquare.index.names = ["Cell type", "Gene"]
rsquare = rsquare.reset_index()
rsquare["Cell type"] = pd.Categorical(rsquare["Cell type"], categories=CT_ORDER, ordered=True)
rsquare_melt = rsquare.melt(
    id_vars=["Cell type", "Gene"],
    var_name="Omics layer", value_name="rsquare"
)

In [None]:
# For each layer, regress per-gene R^2 on the rank of the cell type in the ordered
# "Cell type" categories (OLS with intercept); keep the slope and its p-value.
coefs = {}
pvals = {}
for k in ("mCH", "mCG", "ATAC"):
    regress_data = rsquare_melt[rsquare_melt["Omics layer"] == k]
    # Name the predictor explicitly: previously the slope was read as params.loc[0],
    # which only worked because the unnamed `cat.codes` Series became column 0.
    ct_rank = regress_data["Cell type"].cat.codes.rename("ct_rank")
    # Code -1 marks a value outside the ordered categories; it would enter the
    # regression as a (bogus) lowest rank.
    assert (ct_rank >= 0).all(), f"{k}: rows with cell types outside the category order"
    results = sm.OLS(regress_data["rsquare"], sm.add_constant(ct_rank)).fit()
    coefs[k] = results.params["ct_rank"]
    pvals[k] = results.pvalues["ct_rank"]

In [None]:
fig, ax = plt.subplots(figsize=(6, 5))
ct_box_style = dict(
    saturation=1.0,
    width=0.8,
    showmeans=True,
    meanprops=dict(marker="^", markerfacecolor="white", markeredgecolor="black"),
    boxprops=dict(edgecolor="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    flierprops=dict(marker=".", markerfacecolor="black", markeredgecolor="none", markersize=3),
)
ax = sns.boxplot(
    x="Omics layer", y="rsquare", hue="Cell type", data=rsquare_melt,
    palette=MANUAL_COLORS, ax=ax, **ct_box_style
)
text_kws = dict(
    size=12, ha="center", va="center",
    bbox=dict(facecolor="white", alpha=0.8, edgecolor="lightgrey")
)
# Annotate each layer's box group (x slots 0, 1, 2 in mCH, mCG, ATAC order)
# with the slope and p-value of its cell-type trend.
for x_slot, layer in enumerate(("mCH", "mCG", "ATAC")):
    ax.text(
        float(x_slot), 0.95,
        f"$\\beta$ = {coefs[layer]:.2e}\n$P$ = {pvals[layer]:.2e}",
        **text_kws
    )
ax.set_ylabel("Gene expression $R^2$")
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.legend(frameon=False, loc="center left", bbox_to_anchor=(1.05, 0.5), title="Cell type")
fig.savefig(f"{PATH}/rsquare_ct.pdf")