# scDoRI model training and downstream analysis

In [1]:
import logging
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from pathlib import Path
from sklearn.preprocessing import OneHotEncoder

from scdori import (
    trainConfig,
    load_scdori_inputs,
    save_model_weights,
    set_seed,
    scDoRI,
    train_scdori_phases,
    train_model_grn,
    initialize_scdori_parameters,
    load_best_model,
)

logger = logging.getLogger(__name__)

#### Loading and preparing data for training and model initialisation

In [None]:
# Configure logging at the level given in the config before emitting messages.
logging.basicConfig(level=trainConfig.logging_level)
logger.info("Starting scDoRI pipeline with integrated GRN.")

# Fix the random seed through the scdori helper, using the configured value.
set_seed(trainConfig.random_seed)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

## 1. load data

Uses the paths specified in the config file to load the processed RNA and ATAC AnnData objects, as well as the precomputed in-silico ChIP-seq matrices and peak-gene distances.

In [None]:
# Load the training inputs from the locations described by trainConfig.
(
    rna_metacell,
    atac_metacell,
    gene_peak_dist,
    insilico_act,
    insilico_rep,
) = load_scdori_inputs(trainConfig)

# Mask for peak-gene links based on distance: every strictly positive entry
# becomes 1, all other entries keep their value. masked_fill returns a new
# tensor, so gene_peak_dist itself is left unchanged.
gene_peak_fixed = gene_peak_dist.masked_fill(gene_peak_dist > 0, 1)

## 2. computing indices of genes which are TFs and setting number of cells per metacell ( set to 1 for single cell data)

In [4]:
# Number of cells per observation (set to 1 for single-cell data).
rna_metacell.obs["num_cells"] = 1
# Positional index of every gene, used to look up the TF positions below.
rna_metacell.var["index_int"] = range(rna_metacell.shape[1])

# Integer positions of the genes whose gene_type is "TF".
tf_indices = rna_metacell.var.loc[
    rna_metacell.var["gene_type"] == "TF", "index_int"
].values
# Column vector (n_obs x 1) holding the number of cells per observation.
num_cells = rna_metacell.obs["num_cells"].values.reshape(-1, 1)

## 3. onehot encoding the batch column for entire dataset

In [None]:
# Copy the configured batch column into a uniform "batch" column on both
# modalities.
batch_col = trainConfig.batch_col
rna_metacell.obs["batch"] = rna_metacell.obs[batch_col].values
atac_metacell.obs["batch"] = atac_metacell.obs[batch_col].values

# One-hot encode the technical batch of every observation.
from sklearn.preprocessing import OneHotEncoder

enc = OneHotEncoder(handle_unknown="ignore")
# The encoder is fitted on, and applied to, the same (n_obs, 1) batch column;
# the sparse result is converted to a dense array.
onehot_batch = enc.fit_transform(
    rna_metacell.obs["batch"].values.reshape(-1, 1)
).toarray()
enc.categories_

## 4. making train and evaluation datasets

In [6]:
# Split the observation indices 80/20 into training and evaluation sets,
# with a fixed random_state so the split is reproducible.
n_cells = rna_metacell.n_obs
indices = np.arange(n_cells)
train_idx, eval_idx = train_test_split(indices, test_size=0.2, random_state=42)

# The datasets hold only cell indices, not the data itself.
train_dataset = TensorDataset(torch.from_numpy(train_idx))
eval_dataset = TensorDataset(torch.from_numpy(eval_idx))

# Training batches are shuffled; evaluation batches keep a fixed order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=trainConfig.batch_size_cell,
    shuffle=True,
)
eval_loader = DataLoader(
    dataset=eval_dataset,
    batch_size=trainConfig.batch_size_cell,
    shuffle=False,
)

## 5. Build scDoRI model using parameters from config file

In [7]:
# Model dimensions are read off the loaded data.
num_genes, num_peaks = rna_metacell.n_vars, atac_metacell.n_vars
num_tfs = insilico_act.shape[1]
num_batches = onehot_batch.shape[1]

# The remaining hyperparameters come from the config.
model = scDoRI(
    device=device,
    num_genes=num_genes,
    num_peaks=num_peaks,
    num_tfs=num_tfs,
    num_topics=trainConfig.num_topics,
    num_batches=num_batches,
    dim_encoder1=trainConfig.dim_encoder1,
    dim_encoder2=trainConfig.dim_encoder2,
)
model = model.to(device)

## 6. initialising scDoRI model with precomputed matrices and setting gradients 

initialising with precomputed insilico-chipseq matrices and distance dependent peak-gene links

also setting corresponding gradients for TF-gene links to False as they are not updated in Phase 1 of training

In [None]:
# Phase 1 ("warmup") initialisation: pass the model the peak-gene distance
# matrix, its binary distance mask, and the in-silico ChIP-seq activator and
# repressor matrices, each moved to the training device first.
# NOTE(review): what the helper does with them (values set, gradients frozen)
# is defined in scdori and not visible in this notebook.
initialize_scdori_parameters(
    model,
    gene_peak_dist.to(device),
    gene_peak_fixed.to(device),
    insilico_act=insilico_act.to(device),
    insilico_rep=insilico_rep.to(device),
    phase="warmup",
)

#### Train Phase 1 of scDoRI model 

here topics are inferred using reconstruction of ATAC peaks (module 1), reconstruction of RNA from predicted ATAC (module 2) and reconstruction of TF expression (module 3)

Warmup start is used where only module 1 and module 3 are trained for some initial epochs before adding module 2 

In [None]:
# Phase 1 training. The loaders yield cell indices; the AnnData objects,
# per-observation cell counts, TF gene indices and one-hot batch matrix are
# passed alongside. The last argument is the training configuration imported
# from scdori (`trainConfig`), matching the Phase 2 training call.
model = train_scdori_phases(
    model,
    device,
    train_loader,
    eval_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
    trainConfig,
)

In [None]:
# Save the model weights from the final epoch at which Phase 1 training
# stopped, into the Phase 1 weights folder given in trainConfig.
save_model_weights(model, Path(trainConfig.weights_folder_scdori), "scdori_final")

#### Train Phase 2 of scDoRI model 
here activator and repressor TF-gene links per topic are inferred (module 4)

optionally the encoder and other model parameters from module 1,2,3 are frozen for stability

## 7. Load best checkpoint from Phase 1

In [None]:
# Phase 2

# loading the best checkpoint from Phase 1
# (read from the Phase 1 weights folder configured in trainConfig; the
# returned model replaces the one currently in memory)
model = load_best_model(
    model, Path(trainConfig.weights_folder_scdori) / "best_scdori_best_eval.pth", device
)

## 8. Set gradients for Phase 2 training

TF-gene links are learnt in this step

In [None]:
# Phase 2 ("grn") initialisation. The tensors are moved to the training
# device, exactly as in the warmup initialisation, so that anything the helper
# copies into the model lives on the same device as the model itself.
# NOTE(review): whether the "grn" phase reads these tensors at all is defined
# in scdori; moving them is harmless either way.
initialize_scdori_parameters(
    model,
    gene_peak_dist.to(device),
    gene_peak_fixed.to(device),
    insilico_act=insilico_act.to(device),
    insilico_rep=insilico_rep.to(device),
    phase="grn",
)

## 9. Phase 2 training and saving model weights


In [None]:
# train Phase 2 of scDoRI model, TF-gene links are learnt in this phase and used to reconstruct gene-expression profiles
# Uses the same train/eval loaders, AnnData objects, cell counts, TF indices
# and one-hot batch matrix prepared for Phase 1; the returned model replaces
# `model`.
model = train_model_grn(
    model,
    device,
    train_loader,
    eval_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
    trainConfig,
)

In [None]:
# Saving the model weights corresponding to the final epoch at which Phase 2
# training stopped, into the GRN weights folder given in trainConfig.
save_model_weights(model, Path(trainConfig.weights_folder_grn), "scdori_final")

#### Downstream analysis

scDoRI supports a comprehensive suite of downstream analyses for single-cell multiome RNA-ATAC data. These include dimensionality reduction using latent topics, identification of gene and peak programs associated with each topic, inference of enhancer–gene interactions, and construction of topic-specific transcription factor–gene regulatory networks (GRNs).

We demonstrate these capabilities using the mouse gastrulation dataset from https://www.biorxiv.org/content/10.1101/2022.06.15.496239v1

In [None]:
from scdori.downstream import (
    load_best_model,
    compute_neighbors_umap,
    compute_topic_peak_umap,
    compute_topic_gene_matrix,
    compute_atac_grn_activator_with_significance,
    compute_atac_grn_repressor_with_significance,
    compute_significant_grn,
    visualize_downstream_targets,
    plot_topic_activation_heatmap,
    get_top_activators_per_topic,
    get_top_repressor_per_topic,
    compute_activator_tf_activity_per_cell,
    compute_repressor_tf_activity_per_cell,
    save_regulons,
)
from scdori.evaluation import get_latent_topics
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import scipy
from scdori.train_grn import get_tf_expression

## 10. Load best checkpoint model

In [None]:
# Reload the checkpoint "best_scdori_best_eval.pth" from the Phase 2 (GRN)
# weights folder for downstream analysis.
# NOTE: load_best_model here is the name imported from scdori.downstream in
# the preceding import cell, which rebinds the earlier import from scdori.
model = load_best_model(
    model, Path(trainConfig.weights_folder_grn) / "best_scdori_best_eval.pth", device
)

## 11. Computing and visualising latent topic activity per cell

In [None]:
# Data loader over every observation, in fixed order, for prediction.
n_cells = rna_metacell.n_obs
indices = np.arange(n_cells)

# The dataset holds cell indices only.
all_dataset = TensorDataset(torch.from_numpy(indices))
all_dataset_loader = DataLoader(
    dataset=all_dataset,
    batch_size=trainConfig.batch_size_cell_prediction,
    shuffle=False,
)

In [None]:
# get scDoRI latent embedding (topics)
# Computed over all observations through the unshuffled all_dataset_loader.
# NOTE(review): the rows presumably follow the order of rna_metacell.obs,
# since the result is stored in rna_metacell.obsm below; confirm in
# scdori.evaluation.
scdori_latent = get_latent_topics(
    model,
    device,
    all_dataset_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
)

In [None]:
# adding scDoRI embedding to the anndata object
# (stored under obsm key "X_scdori", which the neighbours/UMAP step reads)
rna_metacell.obsm["X_scdori"] = scdori_latent

In [None]:
# Colour palette used for the cell-type annotation.
colPalette_celltypes = [
    "#532C8A", "#c19f70", "#f9decf", "#c9a997", "#B51D8D", "#3F84AA",
    "#9e6762", "#354E23", "#F397C0", "#ff891c", "#635547", "#C72228",
    "#f79083", "#EF4E22", "#989898", "#7F6874", "#8870ad", "#647a4f",
    "#EF5A9D", "#FBBE92", "#139992", "#cc7818", "#DFCDE4", "#8EC792",
    "#C594BF", "#C3C388", "#0F4A9C", "#FACB12", "#8DB5CE", "#1A1A1A",
    "#C9EBFB", "#DABE99", "#65A83E", "#005579", "#CDE088", "#f7f79e",
    "#F6BFCB",
]

# Unique cell-type labels in sorted order; each label takes the palette colour
# at the same position (zip stops at the shorter of the two sequences).
cell_type_sorted = sorted(set(rna_metacell.obs["celltype"].values))
color_dict = dict(zip(cell_type_sorted, colPalette_celltypes))

# Colours listed in sorted label order and stored under uns["celltype_colors"].
col_sorted = [color_dict[label] for label in cell_type_sorted]
rna_metacell.uns["celltype_colors"] = col_sorted

In [None]:
# Computing the neighbourhood graph and UMAP from the scDoRI embedding stored
# in obsm["X_scdori"]; UMAP parameters can be set in the config file.
compute_neighbors_umap(rna_metacell, rep_key="X_scdori")

In [None]:
# Visualising cell types on the UMAP computed from the scDoRI embedding.
# Font scale and style are set globally through seaborn; figure size and dpi
# are applied only inside the rc_context block.
sns.set(font_scale=0.3)
sns.set_style("whitegrid")
with plt.rc_context({"figure.figsize": (7, 7), "figure.dpi": (200)}):
    sc.pl.umap(
        rna_metacell,
        color=["celltype"],
        add_outline=True,
        outline_color=("white", "black"),
        size=10,
    )

## 12. Computing average topic activation in different celltypes/groups


In [None]:
# Mean topic activation per cell type, returned as a DataFrame.
# The cells below take a per-row maximum and select rows by "Topic_<k>" labels,
# so rows are topics.
df_topic_celltype = plot_topic_activation_heatmap(
    rna_metacell, groupby_key=["celltype"], aggregation="mean"
)

In [None]:
# Distribution of each row's maximum activation across columns; used to choose
# the activity threshold applied in the next cell.
sns.histplot(df_topic_celltype.max(axis=1))

In [None]:
# Keep only the topics whose highest mean activation in any cell type exceeds
# 0.06. Names are rebuilt from row positions, so row k is taken to be Topic_k.
select_topics = [
    f"Topic_{k}"
    for k in np.flatnonzero(df_topic_celltype.max(axis=1) > 0.06)
]

In [None]:
# Number of topics that passed the activation threshold.
len(select_topics)

In [None]:
# Clustered heatmap of the selected topics only. Low activations are masked
# visually by clipping the colour scale to [0.05, 0.4]; the data itself is not
# modified.
%matplotlib inline
sns.set(font_scale=0.4)
sns.clustermap(
    df_topic_celltype.loc[select_topics],
    cmap="Greens",
    vmin=0.05,
    vmax=0.4,
    figsize=(10, 10),
)

## 13. single cell level visualisation of scDoRI Topics

In Erythroid Trajectory from Blood Progenitors to Mature Erythroids: We can see that Topic 26 is active in progenitors whereas topic 4 and 9 capture a continuum of erythroid and progenitor programs along the trajectory

In transition of bipotent progenitors NMP (Neuromesodermal progenitors) to spinal cord or mesodermal lineage: We can see that Topic 34 and 7 capture transition to mesoderm and spinal cord respectively


In [None]:
# Expose every topic activation as its own obs column ("Topic_<k>") so it can
# be used as a colour key when plotting.
for k in range(trainConfig.num_topics):
    rna_metacell.obs[f"Topic_{k}"] = scdori_latent[:, k]

In [None]:
# UMAP coloured by cell type and by the activations of topics 26, 4 and 9
# (the erythroid-trajectory topics discussed above); legends are hidden.
sns.set(font_scale=0.8)
sns.set_style("whitegrid")


with plt.rc_context({"figure.figsize": (5, 5), "figure.dpi": (150)}):
    sc.pl.umap(
        rna_metacell,
        color=["celltype", "Topic_26", "Topic_4", "Topic_9"],
        add_outline=True,
        outline_color=("white", "black"),
        size=10,
        legend_loc="none",
    )

In [None]:
# visualising NMP transitions to Mesoderm [Topic 34] and Spinal Cord [Topic 7] Trajectory
sns.set(font_scale=0.8)
sns.set_style("whitegrid")

# Cell types to include in the plot.
celltype_plot_list = [
    "Mixed_mesoderm",
    "Somitic_mesoderm",
    "Nascent_mesoderm",
    "Intermediate_mesoderm",
    "Forebrain_Midbrain_Hindbrain",
    "Spinal_cord",
    "Caudal_Mesoderm",
    "NMP",
    "ExE_mesoderm",
    "Allantois",
]
# Subset of rna_metacell restricted to the observations of those cell types.
adata_plot = rna_metacell[rna_metacell.obs.celltype.isin(celltype_plot_list), :]

with plt.rc_context({"figure.figsize": (5, 5), "figure.dpi": (100)}):
    sc.pl.umap(
        adata_plot,
        color=["celltype", "Topic_34", "Topic_7"],
        add_outline=True,
        outline_color=("white", "black"),
        size=10,
    )

## 14. Computing Top genes per topic 

this can be used for further analysis such as gene-set enrichment


In [None]:
# Topic-by-gene matrix derived from the trained model; it is wrapped below in
# an AnnData with topics as observations and genes as variables.
topic_gene_embedding = compute_topic_gene_matrix(model, device)

In [None]:
# Clustered heatmap of the topic-gene matrix; the colour scale is capped at 0.01.
sns.clustermap(topic_gene_embedding, vmax=0.01)

## 15. Performing gene-set enrichment analysis on each topic

adapted from https://decoupler-py.readthedocs.io/en/latest/notebooks/msigdb.html#MSigDB-gene-sets

users can play around with other gene sets of their choice


In [None]:
#!pip install decoupler
### caution! this can create dependency issues in the conda environment

In [None]:
#!pip install liana

In [None]:
import decoupler as dc

In [None]:
# AnnData holding the topic-gene values: one observation per topic, one
# variable per gene.
adata_gene = sc.AnnData(topic_gene_embedding)
adata_gene.obs.index = [f"Topic_{i}" for i in range(model.num_topics)]
# Dirty and incorrect hack: mouse gene names are upper-cased to stand in for
# human gene symbols.
adata_gene.var.index = list(map(str.upper, rna_metacell.var.index))
adata_gene.raw = adata_gene

In [None]:
# Fetch the MSigDB resource through decoupler.
# NOTE(review): presumably a network download; confirm before running offline.
msigdb = dc.get_resource("MSigDB")

In [None]:
# List the gene-set collections available in the resource.
msigdb["collection"].unique()

In [None]:
# Use the MSigDB cell-type signature collection. The resource is fetched again
# so that this cell can be re-run on its own.
msigdb = dc.get_resource("MSigDB")

# Keep only the cell_type_signatures collection.
msigdb = msigdb.loc[msigdb["collection"] == "cell_type_signatures"]

# Remove duplicated (geneset, genesymbol) pairs, keeping the first occurrence.
msigdb = msigdb.drop_duplicates(subset=["geneset", "genesymbol"])
msigdb

In [None]:
# Over-representation analysis of the MSigDB gene sets on the topic-gene
# AnnData. The next cell reads the result from obsm key "ora_estimate".
dc.run_ora(
    mat=adata_gene, net=msigdb, source="geneset", target="genesymbol", verbose=True
)

In [None]:
# Extract the ORA estimates as their own AnnData object.
acts = dc.get_acts(adata_gene, obsm_key="ora_estimate")

# Non-finite entries (inf, and NaN if present) are replaced by the largest
# finite value observed.
acts_v = acts.X.ravel()
finite_values = acts_v[np.isfinite(acts_v)]
max_e = np.nanmax(finite_values)
acts.X[~np.isfinite(acts.X)] = max_e

acts

In [None]:
# For every topic take its highest-scoring cell-type signature, then keep the
# union of those signatures as the columns to plot.
df_acts = acts.to_df()
top_programs_per_topic = df_acts.idxmax(axis=1)
unique_top_programs = top_programs_per_topic.unique()
df_topic_program = df_acts[unique_top_programs]

In [None]:
# Clustered heatmap of the top signature per topic, restricted to the topics
# selected as active above.
sns.set(font_scale=1)
sns.clustermap(df_topic_program.loc[select_topics], figsize=(30, 30))

###### 
We can see enrichment of the respective cell-type programs in different topics, such as cardiomyocyte in topic 25, erythroblast in topic 9 and liver hepatoblasts in topic 19.

## 16. Computing and visualising peaks associated with each topic

Peaks associated with a topic should capture co-accessibility patterns.

We visualise the average accessibility of peaks (on a UMAP) in different cell types and their association with a topic, to see whether topics have captured co-accessibility patterns. Each point on the UMAP is a peak.

In [None]:
# Returns a UMAP embedding of the peaks and the peak-topic association matrix;
# both are attached to a peak-level AnnData in the next cell.
umap_embedding_peaks, topic_peak_embedding = compute_topic_peak_umap(model, device)

In [None]:
# AnnData with peaks as observations and topics as variables; the values are
# the topic association of each peak.
adata_peak = sc.AnnData(topic_peak_embedding)
adata_peak.obs.index = atac_metacell.var.index
adata_peak.var.index = [f"Topic_{i}" for i in range(model.num_topics)]
# Peak-level UMAP coordinates used by sc.pl.umap.
adata_peak.obsm["X_umap"] = umap_embedding_peaks

In [None]:
# Copy the cell-type annotation from the RNA object to the ATAC object.
# NOTE(review): assigning a Series aligns on the obs index; this assumes both
# objects share the same observation names. Confirm.
atac_metacell.obs["celltype"] = rna_metacell.obs["celltype"].copy()

In [None]:
# Computing the average accessibility of peaks in each cell type.
# Keep an independent copy of the raw matrix: normalisation writes to X and
# may do so in place, which would otherwise also change layers["counts"].
atac_metacell.layers["counts"] = atac_metacell.X.copy()
sc.pp.normalize_total(atac_metacell)
# Mean of the normalised accessibility per cell type, then normalised and
# scaled across the aggregated groups.
aggregated_atac = sc.get.aggregate(atac_metacell, by="celltype", func=["mean"])
aggregated_atac.X = aggregated_atac.layers["mean"]
sc.pp.normalize_total(aggregated_atac)
sc.pp.scale(aggregated_atac)

# Adding the average accessibility of each peak in a cell type to the peak
# AnnData: transpose to peaks x cell types and order rows like adata_peak.
peak_celltype_df = aggregated_atac.to_df().T
peak_celltype_df = peak_celltype_df.loc[adata_peak.obs.index.values]
adata_peak.obs = pd.concat([adata_peak.obs, peak_celltype_df], axis=1)

In [None]:
# Peak UMAP coloured by the scaled accessibility in one cell type and by the
# peak association with two topics.
celltype_name = "Erythroid1"
sns.set(font_scale=1.5)
sc.pl.umap(
    adata_peak, color=[celltype_name, "Topic_9", "Topic_4"], cmap="RdBu_r"
)  # topic 9 and 4 are active in erythroids from heatmap above

In [None]:
# Peak UMAP: NMP accessibility next to the peak associations of topics 34 and 7.
celltype_name = "NMP"
sns.set(font_scale=1.5)
sc.pl.umap(adata_peak, color=[celltype_name, "Topic_34", "Topic_7"], cmap="RdBu_r")

In [None]:
# Peak UMAP: spinal cord accessibility next to the peak association of topic 7.
celltype_name = "Spinal_cord"
sns.set(font_scale=1.5)
sc.pl.umap(adata_peak, color=[celltype_name, "Topic_7"], cmap="RdBu_r")

In [None]:
# Peak UMAP: cardiomyocyte accessibility next to the peak association of topic 25.
celltype_name = "Cardiomyocytes"
sns.set(font_scale=1.5)
sc.pl.umap(adata_peak, color=[celltype_name, "Topic_25"], cmap="RdBu_r")

## 17. Visualising TF binding scores on peaks

In [None]:
## In-silico ChIP-seq binding scores per peak, one column per TF and mode

# Names of the genes annotated as TFs; column i of the in-silico ChIP-seq
# matrices is taken to belong to tf_names[i].
tf_names = rna_metacell.var.index[rna_metacell.var["gene_type"] == "TF"].values

# Activator columns first, then repressor columns (stored as absolute values).
activator_binding = {
    f"{name}_activator_binding": insilico_act[:, col].numpy()
    for col, name in enumerate(tf_names)
}
repressor_binding = {
    f"{name}_repressor_binding": np.abs(insilico_rep[:, col].numpy())
    for col, name in enumerate(tf_names)
}
tf_binding_data = {**activator_binding, **repressor_binding}

# One row per peak, indexed like adata_peak.obs.
tf_binding_data_df = pd.DataFrame(tf_binding_data, index=adata_peak.obs.index)

# Concatenate new columns with existing obs

In [None]:
# Append the TF binding scores to the existing per-peak annotations, keeping
# the columns that are already in adata_peak.obs (e.g. the per-cell-type
# accessibility added earlier). Binding columns left over from a previous run
# of this cell are dropped first so re-running does not duplicate them.
adata_peak.obs = pd.concat(
    [
        adata_peak.obs.drop(columns=tf_binding_data_df.columns, errors="ignore"),
        tf_binding_data_df,
    ],
    axis=1,
)

In [None]:
# Visualising activator binding scores of selected TFs on the peak UMAP.
# The colour scale is clipped to [0.1, 0.8].
sns.set(font_scale=1.5)

with plt.rc_context({"figure.figsize": (6, 6), "figure.dpi": (200)}):
    sc.pl.umap(
        adata_peak,
        color=[
            "Tbx5_activator_binding",
            "Gata1_activator_binding",
            "Ets1_activator_binding",
            "Sox2_activator_binding",
            "Hnf4a_activator_binding",
        ],
        cmap="Greens",
        vmin=0.1,
        vmax=0.8,
        sort_order=True,
    )

#### Computing eGRNs

## 18. Computing ATAC based GRNs with empirical significance
these GRNs do not use evidence of TF-gene co-expression 

activator GRNs here indicate whether, within a topic, peaks linked to a gene have accessible binding sites for a TF (from activator insilico-chipseq scores)

repressor GRNs here indicate whether, within a topic, peaks linked to a gene have non-accessible repressor binding sites for a TF (from repressor insilico-chipseq scores)

additionally we compute a background set of GRN values by shuffling insilico-chipseq scores, which are used to compute empirical significance 

In [None]:
# ATAC based GRN for activators, with a 0.05 cutoff; outputs go to the
# directory "grn_act_atac".
# NOTE(review): cutoff_val is presumably the empirical significance threshold
# described above; confirm in scdori.downstream.
grn_act_atac = compute_atac_grn_activator_with_significance(
    model, device, cutoff_val=0.05, outdir="grn_act_atac"
)

In [None]:
# ATAC based GRN for repressors
# NOTE: the output directory is deliberately the same "grn_act_atac" folder
# used for activators; later cells load both the activator and the repressor
# GRN files from it.
grn_rep_atac = compute_atac_grn_repressor_with_significance(
    model, device, cutoff_val=0.05, outdir="grn_act_atac"
)

## 19. Computing final GRNs

to compute these, we use the significant ATAC based GRNs derived previously and take the element-wise product with the GRNs learnt by scDoRI incorporating TF-gene co-expression

In [None]:
# calculating TF-expression per topic
# either from scdori model weights or from true data
# using true expression here
# NOTE(review): the first argument is the string "True", not the boolean;
# presumably a mode selector for "true expression". Confirm against
# scdori.train_grn.get_tf_expression.
tf_normalised = get_tf_expression(
    "True",
    model,
    device,
    all_dataset_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
    trainConfig,
)

In [None]:
# Inspect the dimensions of the TF expression tensor.
tf_normalised.shape

In [None]:
# compute final GRNs which use the significant ATAC based GRNs derived above
# tf_normalised is detached, moved to CPU and converted to NumPy before being
# passed in; outdir is the folder the ATAC based GRNs were written to.
grn_act, grn_rep = compute_significant_grn(
    model,
    device,
    cutoff_val_activator=0.05,
    cutoff_val_repressor=0.05,
    tf_normalised=tf_normalised.detach().cpu().numpy(),
    outdir="grn_act_atac",
)

In [None]:
# Save activator regulons per TF, using the TF and gene names from the RNA
# object, into the "grn_act_atac" directory.
save_regulons(
    grn_act,
    tf_names=tf_names,
    gene_names=rna_metacell.var.index.values,
    num_topics=model.num_topics,
    output_dir="grn_act_atac",
    mode="activator",
)

In [None]:
# Save repressor regulons per TF, into the same "grn_act_atac" directory as
# the activator regulons.
save_regulons(
    grn_rep,
    tf_names=tf_names,
    gene_names=rna_metacell.var.index.values,
    num_topics=model.num_topics,
    output_dir="grn_act_atac",
    mode="repressor",
)

In [None]:
# Loading the saved GRNs from disk; this replaces the in-memory grn_act and
# grn_rep. The "0.05" in the file names matches the cutoffs used above.
grn_act = np.load("grn_act_atac/grn_activator__0.05.npy")
grn_rep = np.load("grn_act_atac/grn_repressor__0.05.npy")

In [None]:
# Inspect the dimensions of the activator GRN array.
grn_act.shape  # num_topics x num_tfs x num_genes

## 20. Computing and plotting top activator TFs per topic

In [None]:
# Plotting TF activity across topics.
# Names of the genes annotated as TFs.
tf_names = rna_metacell.var.index[rna_metacell.var["gene_type"] == "TF"].values

# Top 2 activators per topic, over all topics (selected_topics=None).
df_topic_activator, top_regulators = get_top_activators_per_topic(
    grn_act,
    tf_names,
    scdori_latent,
    selected_topics=None,
    top_k=2,
    clamp_value=1e-8,
    zscore=True,
    figsize=(25, 10),
    out_fig=None,
)

In [None]:
# Display the table returned by get_top_activators_per_topic.
df_topic_activator  # matrix of Topic TF activities

## 21. Computing and plotting TF activity per cell

In [None]:
# computing TF activity per cell
# Combines the activator GRN with the per-cell topic activations, over all
# topics (selected_topics=None). The result is used below as a cells x TFs
# table with tf_names as column labels.
cell_tf_act = compute_activator_tf_activity_per_cell(
    grn_act,
    tf_names,
    scdori_latent,
    selected_topics=None,
    clamp_value=1e-8,
    zscore=True,
)

In [None]:
# Mean TF activity per cell type; missing values are replaced by 0.
df_celltype_tf = (
    pd.DataFrame(cell_tf_act, columns=tf_names)
    .assign(celltype=rna_metacell.obs["celltype"].values)
    .groupby("celltype")
    .mean()
    .fillna(0)
)
# Drop TFs whose activity is 0 (or was NaN) in every cell type.
df_celltype_tf = df_celltype_tf.loc[:, df_celltype_tf.ne(0).any(axis=0)]

In [None]:
# Print the five most active TFs for every cell type.
for k in df_celltype_tf.index:
    print(df_celltype_tf.loc[k].sort_values(ascending=False).head(5))

In [None]:
# Clustered heatmap of mean TF activity per cell type; the colour scale is
# clipped to [-4, 4].
sns.clustermap(df_celltype_tf, cmap="RdBu_r", vmin=-4, vmax=4, figsize=(20, 15))

In [None]:
# Per-cell TF activity table for plotting on the UMAP: one "<TF>_activity"
# column per TF, indexed by the observation names of rna_metacell.
df_cell_tf = pd.DataFrame(cell_tf_act, columns=[f"{s}_activity" for s in tf_names])
df_cell_tf.index = rna_metacell.obs.index.values
# Combined view of the existing obs columns and the activity columns.
obs_df = pd.concat([rna_metacell.obs, df_cell_tf], axis=1)

In [None]:
# Copy every activity column into rna_metacell.obs by position (raw values,
# no index alignment).
for k in df_cell_tf.columns:
    rna_metacell.obs[k] = df_cell_tf[k].to_numpy()

In [None]:
# UMAP coloured by cell type and by the per-cell activity of selected TFs;
# the colour scale is clipped to [-2, 4].
sc.pl.umap(
    rna_metacell,
    color=[
        "celltype",
        "Gata1_activity",
        "Sox10_activity",
        "Gata6_activity",
        "Ets1_activity",
        "Hnf4a_activity",
        "T_activity",
        "Prrx1_activity",
        "Tbx5_activity",
        "Pou5f1_activity",
    ],
    vmin=-2,
    vmax=4,
    cmap="RdBu_r",
)

## 22. visualising downstream target genes of a TF

In [None]:
# TF whose downstream targets are inspected below.
tf_plot = "Gata4"

# Position of that TF within tf_names (raises ValueError if it is absent).
tf_index = tf_names.tolist().index(tf_plot)
tf_index

In [None]:
# The five highest entries of the selected TF's column in df_topic_activator.
df_topic_activator[tf_plot].sort_values(ascending=False)[:5]

##### 
since Gata4 is active in multiple topics (and the associated cell types), we obtain different targets for it in the respective contexts

In [None]:
topic_num = [3, 19]  # endodermal topics
# Genes whose activator GRN weight for the selected TF, summed over the chosen
# topics, is positive. Raise the 0.0 threshold to obtain more stringent /
# more strongly regulated downstream targets.
target_gene_idx = np.flatnonzero(
    grn_act[topic_num, tf_index, :].sum(axis=0) > 0.0
)
genes_endoderm = rna_metacell.var_names[target_gene_idx]

In [None]:
# Display the target genes found in the endodermal topics.
genes_endoderm

In [None]:
topic_num = [25]  # cardiomyocyte specific topic
# Genes with a positive activator GRN weight for the selected TF in this topic.
target_gene_idx = np.flatnonzero(
    grn_act[topic_num, tf_index, :].sum(axis=0) > 0.0
)
genes_cardiomyocytes = rna_metacell.var_names[target_gene_idx]

In [None]:
# Display the target genes found in the cardiomyocyte topic.
genes_cardiomyocytes

In [None]:
# Plotting downstream expression of target genes.
# Start from a copy of the raw counts: normalisation and log1p write to X and
# can operate in place, which would otherwise also overwrite layers["counts"].
rna_metacell.X = rna_metacell.layers["counts"].copy()
# Normalising raw counts, then log-transforming.
sc.pp.normalize_total(rna_metacell)
sc.pp.log1p(rna_metacell)

# One score per cell for each context-specific Gata4 target set, stored in obs.
sc.tl.score_genes(rna_metacell, genes_endoderm, score_name="Gata4_endoderm_target")
sc.tl.score_genes(
    rna_metacell, genes_cardiomyocytes, score_name="Gata4_cardiomyocyte_target"
)

In [None]:
# Visualising cell types, Gata4 expression and the two Gata4 target scores on
# the scDoRI UMAP.
sns.set(font_scale=1)
sns.set_style("whitegrid")
with plt.rc_context({"figure.figsize": (7, 7), "figure.dpi": (200)}):
    sc.pl.umap(
        rna_metacell,
        color=[
            "celltype",
            "Gata4",
            "Gata4_endoderm_target",
            "Gata4_cardiomyocyte_target",
        ],
        add_outline=True,
        outline_color=("white", "black"),
        size=10,
        cmap="RdBu_r",
        legend_loc=None,
    )

##### 
we can clearly see that scDoRI finds differential downstream targets for the same TF in different contexts

we can visualise the TF binding profiles to confirm that these differences are coming from chromatin differences between states

In [None]:
# Visualising the Gata4 activator binding score on the peak UMAP.
sns.set(font_scale=1)
# topic 3 and 19 are endoderm related and topic 25 is cardiomyocyte specific
with plt.rc_context({"figure.figsize": (3, 3), "figure.dpi": (150)}):
    sc.pl.umap(
        adata_peak, color=["Gata4_activator_binding"], cmap="Greens", sort_order=True
    )

In [None]:
# Visualising, on the peak UMAP, the peak association with the topics where
# the TF is active; the colour scale is restricted to [0.8, 1].
sns.set(font_scale=1)
# topic 3 and 19 are endoderm related and topic 25 is cardiomyocyte specific
with plt.rc_context({"figure.figsize": (3, 3), "figure.dpi": (200)}):
    sc.pl.umap(
        adata_peak,
        color=["Topic_3", "Topic_19", "Topic_25"],
        cmap="Greens",
        vmin=0.8,
        vmax=1,
        sort_order=True,
    )

##### 
we can see that Gata4 binds to regulatory regions associated with different topics and can regulate different set of genes in those topics respectively 

## 23. Repressor analysis

In [None]:
# Top 20 repressor TFs per topic, over all topics (selected_topics=None);
# mirrors the activator analysis but uses the repressor GRN.
df_topic_repressor, top_regulators_repressor = get_top_repressor_per_topic(
    grn_rep,
    tf_names,
    scdori_latent,
    selected_topics=None,
    top_k=20,
    clamp_value=1e-8,
    zscore=True,
    figsize=(25, 10),
    out_fig=None,
)

In [None]:
# Repressor TF activity per cell, computed from the repressor GRN and the
# per-cell topic activations, over all topics.
cell_tf_rep = compute_repressor_tf_activity_per_cell(
    grn_rep,
    tf_names,
    scdori_latent,
    selected_topics=None,
    clamp_value=1e-8,
    zscore=True,
)

In [None]:
# Mean repressor TF activity per cell type.
df_celltype_tf_rep = (
    pd.DataFrame(cell_tf_rep, columns=tf_names)
    .assign(celltype=rna_metacell.obs["celltype"].values)
    .groupby("celltype")
    .mean()
)

## 24. visualising enhancer gene links

In [None]:
# Peak-gene links used by scDoRI: element-wise product of the learnt link
# weights and the fixed link mask, both converted to NumPy arrays on the CPU.
learnt_links = model.gene_peak_factor_learnt.detach().cpu().numpy()
fixed_links = model.gene_peak_factor_fixed.detach().cpu().numpy()
gene_peak = learnt_links * fixed_links

In [None]:
# Inspect the dimensions of the link matrix; it is indexed below as
# [gene, peak].
gene_peak.shape

In [None]:
gene_name = "Tal1"
# Position of the gene among the RNA variables (raises ValueError if absent).
gene_index = rna_metacell.var_names.tolist().index(gene_name)

# Peaks whose link weight to this gene exceeds 0.99; lower this threshold to
# obtain more links. The positions are then translated to peak names.
enhancers = np.flatnonzero(gene_peak[gene_index, :] > 0.99)
enhancers = atac_metacell.var_names[enhancers]

##### 
plotting accessibility of Tal1 enhancers across celltypes

In [None]:
# Scaled per-cell-type accessibility of the selected enhancer peaks
# (rows: peaks, columns: cell types).
peak_gene_celltype = peak_celltype_df.loc[enhancers]
peak_gene_celltype

In [None]:
# Clustered heatmap of enhancer accessibility across cell types; the colour
# scale is clipped to [-5, 5].
sns.clustermap(peak_celltype_df.loc[enhancers], cmap="RdBu_r", vmin=-5, vmax=5)

#### 
plotting Tal1 expression and net accessibility of its predicted enhancers across celltypes

In [None]:
# Start from a copy of the raw counts: normalisation writes to X and may do
# so in place, which would otherwise also change layers["counts"].
atac_metacell.X = atac_metacell.layers["counts"].copy()
# Normalising raw counts.
sc.pp.normalize_total(atac_metacell)
# One accessibility score per cell over the predicted enhancers, stored in obs.
sc.tl.score_genes(atac_metacell, enhancers, score_name="Tal1_enhancer_accesibility")

In [None]:
# Copy the enhancer score to the RNA object by position (raw values, no index
# alignment), so it can be plotted on the RNA UMAP.
# NOTE(review): assumes both objects list observations in the same order.
rna_metacell.obs["Tal1_enhancer_accesibility"] = atac_metacell.obs[
    "Tal1_enhancer_accesibility"
].values

In [None]:
# UMAP coloured by cell type, Tal1 expression and the Tal1 enhancer
# accessibility score; the colour scale is clipped to [0, 1].
sns.set(font_scale=1)
sns.set_style("whitegrid")
with plt.rc_context({"figure.figsize": (7, 7), "figure.dpi": (200)}):
    sc.pl.umap(
        rna_metacell,
        color=["celltype", "Tal1", "Tal1_enhancer_accesibility"],
        add_outline=True,
        outline_color=("white", "black"),
        size=10,
        cmap="YlGnBu",
        vmin=0,
        vmax=1,
        legend_loc=None,
    )