Notebook for the analysis and visualisation of sc-GRIP predictions

In [None]:
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.cluster import KMeans
from sklearn.cluster import SpectralClustering
from sklearn.cluster import DBSCAN
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import ElasticNetCV
from sklearn.metrics import normalized_mutual_info_score
import re
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    adjusted_rand_score,
    normalized_mutual_info_score,
    roc_curve,
    auc
)

# Creating necessary files from published datasets #

## read in adjacency matrix ##

In [None]:
# Dataset selectors reused by the following cells to build file paths.
species = "mouse"
celltype = "muscle"

# The TRRUST table is read without a header row; the four columns are named here.
df = pd.read_csv(
    f"data/trrust_rawdata.{species}.tsv",
    sep="\t",
    header=None,
    names=["gene_ids", "target_genes", "polarity", "val"],
)

# Discard interactions of "Unknown" polarity, then remove every
# (gene_ids, target_genes) pair that occurs more than once
# (keep=False drops all copies, not just the repeats).
filtered_subset_df = df.loc[df["polarity"] != "Unknown"].drop_duplicates(
    subset=["gene_ids", "target_genes"],
    keep=False,
)
subset_df = filtered_subset_df[["gene_ids", "target_genes", "polarity"]]

# Occurrence counts are taken from the unfiltered table.
gene_counts = df["target_genes"].value_counts()
tf_counts = df["gene_ids"].value_counts()

# Binary adjacency matrix: rows are gene_ids, columns are target_genes,
# 1 where the pair is present in subset_df.
adj_matrix = (
    pd.crosstab(subset_df["gene_ids"], subset_df["target_genes"])
    .gt(0)
    .astype(int)
)

## read in single-cell matrix ##

In [None]:
adata = sc.read_h5ad("data/h5ad_files/"+species+"_"+celltype+".h5ad")

# Depending on the number of genes you might have to filter for highly variable genes.
# NOTE(review): scanpy's 'seurat' flavor is documented to expect log-transformed
# data - confirm the h5ad file was normalised accordingly.
sc.pp.highly_variable_genes(adata, flavor='seurat')
adata = adata[:, adata.var['highly_variable']].copy()

# adata.X can be a sparse matrix (which exposes .toarray()) or an already dense
# array. Handle both cases automatically instead of (un)commenting a line by hand.
data = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)

# Expression table: one row per cell (obs), one column per gene (var).
expr_df = pd.DataFrame(data=data, columns=adata.var_names,index=adata.obs_names.to_list())

In [None]:
edge_df = adj_matrix

# Rows of the adjacency matrix are regulators, columns are their targets.
tf_list = edge_df.index.to_list()
target_gene_list = edge_df.columns.to_list()
all_interacting_genes = tf_list + target_gene_list

# Keep only genes that are present in the expression table.
# The intersections are sorted because the iteration order of a set of strings
# is not stable between interpreter runs (string hashing is randomised), which
# would otherwise change the column/row order of the CSV files written later.
expr_gene_set = set(expr_df.columns)
common_genes = sorted(expr_gene_set & set(all_interacting_genes))
common_target_genes = sorted(expr_gene_set & set(target_gene_list))
common_tf_genes = sorted(expr_gene_set & set(tf_list))
# May contain a gene twice when it acts both as regulator and as target.
all_common_genes = common_target_genes + common_tf_genes

# Restrict both tables to the shared genes.
new_expr_df = expr_df[common_genes]
new_edge_df = edge_df.loc[common_tf_genes, common_target_genes]

expr_df = new_expr_df
edge_df = new_edge_df

## save files for later analyses ##

In [None]:
# File-name infix derived from the cell type (with a trailing underscore).
target_cluster = f"{celltype}_"

# Persist the gene-restricted adjacency matrix and expression table.
new_edge_df.to_csv(f"data/{species}_{target_cluster}tf_interaction_common.csv")
new_expr_df.to_csv(f"data/{species}_{target_cluster}gex_common.csv")

# Comparison analyses #

## Pearson Correlation ##

In [None]:
def compute_gene_correlations(df_expr, df_adj):
    """Compute the Pearson correlation for every edge marked in an adjacency matrix.

    Args:
        df_expr: Expression table with one column per gene. Every gene that
            takes part in an edge of ``df_adj`` must be a column here.
        df_adj: Adjacency matrix whose index holds the regulators and whose
            columns hold the targets; an entry equal to 1 marks an edge.

    Returns:
        DataFrame with the columns ``TF``, ``target`` and ``Correlation``, one
        row per edge in row-major order of ``df_adj``. The three columns are
        present even when there are no edges, so callers can always index them.
        ``Correlation`` is NaN when ``np.corrcoef`` cannot be evaluated
        (e.g. a constant expression column).
    """
    columns = ['TF', 'target', 'Correlation']
    regulators = df_adj.index
    targets = df_adj.columns

    # Locate all edges at once instead of doing a .loc lookup per matrix cell.
    # np.nonzero walks the matrix row by row, matching the original loop order.
    row_positions, col_positions = np.nonzero(df_adj.to_numpy() == 1)

    data = []
    for row_pos, col_pos in zip(row_positions, col_positions):
        gene1 = regulators[row_pos]
        gene2 = targets[col_pos]
        corr = np.corrcoef(df_expr[gene1], df_expr[gene2])[0, 1]
        data.append({
            'TF': gene1,
            'target': gene2,
            'Correlation': corr
        })

    # Passing `columns` guarantees the schema for an empty result.
    return pd.DataFrame(data, columns=columns)

# Correlate every regulator/target pair of the gene-restricted adjacency matrix.
df_gene_correlations = compute_gene_correlations(df_expr=new_expr_df, df_adj=new_edge_df)

# Edge identifier in the form "<TF>-><target>".
df_gene_correlations = df_gene_correlations.assign(
    edge=lambda frame: frame['TF'] + "->" + frame['target']
)

# Copy without rows containing NaN; note that the CSV below is written from
# the unfiltered frame, not from df_clean.
df_clean = df_gene_correlations.dropna()
df_gene_correlations.to_csv(f"{species}_{celltype}_correlations.csv")

## LASSO edge weights ##

In [None]:
from sklearn.linear_model import Lasso
from sklearn.preprocessing import StandardScaler

def compute_lasso_scores(df_expr, df_adj, alpha=0.01):
    """Fit one LASSO regression per target gene and collect the non-zero TF weights.

    The expression table is standardised column-wise with ``StandardScaler``.
    For every column (target) of ``df_adj`` the rows (TFs) with an entry equal
    to 1 are used as predictors of the target's expression.

    Args:
        df_expr: Expression table with one column per gene; must contain every
            TF and target that takes part in an edge of ``df_adj``.
        df_adj: Adjacency matrix (index: TFs, columns: targets, 1 marks an edge).
        alpha: Regularisation strength passed to ``Lasso``.

    Returns:
        DataFrame with the columns ``TF``, ``target`` and ``Lasso_Coefficient``,
        one row per edge whose coefficient is non-zero. The columns are present
        even when no coefficient survives, so callers can always index them.
    """
    scaler = StandardScaler()
    df_expr_scaled = pd.DataFrame(
        scaler.fit_transform(df_expr),
        columns=df_expr.columns,
        index=df_expr.index,
    )

    data = []
    for target_gene in df_adj.columns:
        tf_genes = df_adj.index[df_adj[target_gene] == 1].tolist()
        if not tf_genes:
            continue  # No TFs regulating this target gene

        # NOTE(review): a gene listed as its own regulator is not excluded
        # here, so it would be used to predict itself - confirm this is wanted.
        lasso = Lasso(alpha=alpha, max_iter=10000)
        lasso.fit(df_expr_scaled[tf_genes], df_expr_scaled[target_gene])

        # Coefficients shrunk exactly to zero are treated as "no edge".
        data.extend(
            {'TF': tf, 'target': target_gene, 'Lasso_Coefficient': coef}
            for tf, coef in zip(tf_genes, lasso.coef_)
            if coef != 0
        )

    # Passing `columns` guarantees the schema for an empty result.
    return pd.DataFrame(data, columns=['TF', 'target', 'Lasso_Coefficient'])


# Visualising sc-GRIP predictions #

## Quantitative analysis ##

In [None]:
species = "mouse"
celltype = "lung"

# Work on a copy: subset_df is a column slice of a filtered frame, so adding a
# column to it directly triggers pandas' chained-assignment warning and would
# also modify the frame produced by the adjacency-matrix cell.
true_df = subset_df.copy()
predicted_df = pd.read_csv(f"predictions/{species}_{celltype}_corr.csv")

# Edge identifier in the form "<TF>-><target>", used as the join key.
true_df['edge'] = true_df['gene_ids'] + "->" + true_df['target_genes']

# Keep only edges that have both a reference polarity and a prediction.
df = pd.merge(true_df, predicted_df, on='edge', how='inner')

# Column of predicted_df that holds the score evaluated below.
col_label = "activation_score_mean"

In [None]:
labels = ["Activation", "Repression"]

label_map = {"Activation": 1, "Repression": 0,"Unknown":0.5}

# roc_curve only accepts binary ground truth: restrict the evaluation to the
# two polarity classes (a 0.5 label or an unmapped NaN would raise) and to
# rows that actually carry a score.
roc_df = df[df['polarity'].isin(labels)].dropna(subset=[col_label])
y_true_binary = roc_df['polarity'].map(label_map)
y_scores = roc_df[col_label]

fpr, tpr, thresholds = roc_curve(y_true_binary, y_scores)
roc_auc = auc(fpr, tpr)
print(roc_auc)

plt.figure()
plt.plot(fpr, tpr, color='navy', lw=2, linestyle="solid", label=f"ROC curve (AUC = {roc_auc:.2f})")
# Diagonal reference line of a random classifier.
plt.plot([0, 1], [0, 1], color='red', lw=2, linestyle='--', label="Random chance (AUC=0.5)")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
# The original title was missing its closing parenthesis.
plt.title(f"ROC curve ({species} dataset ({celltype}))")
plt.legend(loc="lower right")
plt.tight_layout()
plt.show()

## Qualitative analysis ##

In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors,cm


def load_data(cluster_of_interest):
    """Read the prediction table from ``/species_predictions.csv`` and report its size.

    ``cluster_of_interest`` is accepted but not used by the body. No rows are
    removed between the two printed counts, so both lines show the same number.

    Returns:
        The DataFrame exactly as read from the CSV file.
    """
    predictions = pd.read_csv("/species_predictions.csv")
    n_edges = predictions.shape[0]
    for stage in ("raw", "filtered"):
        print(f"number of edges ({stage})", n_edges)
    return predictions

def plot_in_out_degree_dist(G, cluster_of_interest=None, target_gene=None):
    """Save overlaid histograms of the in- and out-degree of every node of ``G``.

    Args:
        G: Directed graph exposing ``nodes()``, ``in_degree(n)`` and
            ``out_degree(n)``.
        cluster_of_interest: Label used in the output file name. When omitted,
            the module-level variable of the same name is used if it is set,
            otherwise ``"all"``.
        target_gene: Label used in the output file name. When omitted, the
            module-level variable of the same name is used if it is set,
            otherwise ``"full"``.

    Returns:
        None. The figure is written to
        ``plots/node_degree_distr_<cluster_of_interest>_<target_gene>.png``;
        nothing is written for a graph without nodes.
    """
    in_degrees = [G.in_degree(n) for n in G.nodes()]
    out_degrees = [G.out_degree(n) for n in G.nodes()]

    # max() below would raise on an empty graph; bail out like plot_graph does.
    if not in_degrees:
        print("No nodes to plot.")
        return

    # The original body read these two names as undeclared globals; keep that
    # as a fallback so existing single-argument calls behave the same.
    if cluster_of_interest is None:
        cluster_of_interest = globals().get("cluster_of_interest") or "all"
    if target_gene is None:
        target_gene = globals().get("target_gene") or "full"

    max_degree = max(in_degrees + out_degrees)
    # One unit-wide bin centred on every integer degree from 0 to max_degree.
    bins = np.arange(0, max_degree + 2) - 0.5

    # Start from a fresh figure instead of drawing on whatever is current.
    plt.figure()

    # Plot histogram with transparency and separate colors
    plt.hist(in_degrees, bins=bins, alpha=0.6, color='blue', label='In-Degree')
    plt.hist(out_degrees, bins=bins, alpha=0.6, color='orange', label='Out-Degree')

    # Tick every 5 units on both axes.
    plt.xticks(np.arange(0, max_degree + 1, 5))
    plt.yticks(np.arange(0, plt.gca().get_ylim()[1]+1, 5))

    plt.xlabel("Degree")
    plt.ylabel("Number of Nodes")
    plt.title("In-Degree and Out-Degree Distribution")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"plots/node_degree_distr_{cluster_of_interest}_{target_gene}.png", dpi=300, bbox_inches='tight')
    plt.close()

def plot_activation_score_distribution(df, target_gene, cluster_of_interest, col_name="Correlation"):
    """Save a histogram of the scores of all edges that mention ``target_gene``.

    Args:
        df: Edge table with an ``edge`` column and the score column
            ``col_name``.
        target_gene: Text searched for in the ``edge`` column. It is matched
            as a literal substring, on either side of the edge.
        cluster_of_interest: Label used in the title and the output file name.
        col_name: Name of the score column to plot (default ``"Correlation"``).

    Returns:
        None. The figure is written to
        ``plots/act_score_<cluster_of_interest>_<target_gene>.png``.
    """
    print(df.head(10))

    # regex=False: gene identifiers may contain regex metacharacters such as
    # "." which would otherwise match any character. na=False: rows without an
    # edge string are skipped instead of producing an invalid boolean mask.
    mentions_gene = df['edge'].str.contains(target_gene, regex=False, na=False)
    subset_df = df[mentions_gene]

    plt.figure(figsize=(8,6))
    subset_df[col_name].hist(bins=20, color='skyblue', edgecolor='black')
    plt.xlabel('Activation Score')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.title(f"Activation score distribution cluster {cluster_of_interest}: {target_gene}")
    plt.savefig(f"plots/act_score_{cluster_of_interest}_{target_gene}.png", dpi=300, bbox_inches='tight')
    plt.close()
    print("Activation score distribution fig saved.")

def build_graph(df, threshold, focus_genes=None, col_name="Correlation"):
    """Build a directed graph from an edge table and keep only edges leaving hub nodes.

    Every ``edge`` string is split at ``"->"``; the graph edge runs from the
    token after the arrow to the token before it.
    NOTE(review): other cells of this notebook build edge strings as
    ``TF + "->" + target``, so this orientation reverses them - confirm the
    swap is intended.

    Args:
        df: Edge table with an ``edge`` column and the score column
            ``col_name``.
        threshold: Minimum out-degree a node needs for its outgoing edges to
            be kept.
        focus_genes: Optional collection of gene names. When non-empty, only
            edges with both endpoints in the collection are used and
            self-loops are skipped.
        col_name: Name of the score column (default ``"Correlation"``).

    Returns:
        ``nx.DiGraph`` whose edges carry ``weight`` (the score) and ``label``
        (the score formatted with two decimals).
    """
    G = nx.DiGraph()
    if df.empty:
        return G

    # A set gives O(1) membership tests inside the loop.
    focus = set(focus_genes) if focus_genes else None

    for edge, score in zip(df['edge'], df[col_name]):
        target, tf = edge.split("->")  # Swap order here
        if focus is not None:
            if tf not in focus or target not in focus:
                continue
            if target == tf:
                continue
        G.add_edge(tf, target, weight=score, label=f"{score:.2f}")

    hub_nodes = {n for n in G.nodes if G.out_degree(n) >= threshold}

    final_G = nx.DiGraph()
    final_G.add_edges_from(
        (u, v, attrs) for u, v, attrs in G.edges(data=True) if u in hub_nodes
    )
    return final_G


def map_genes(G, annotation_path="gene_symbol_annotation.xlsx"):
    """Map every node of ``G`` to a display label.

    If your species of interest does not have a full annotation, you can use
    an Excel sheet with the columns ``Gene ID`` and ``Gene Symbol`` to
    annotate your plot.

    Label lookup per node:
      * node listed in the sheet with a symbol -> that symbol,
      * node listed in the sheet without a symbol -> the gene ID itself,
      * node not listed -> the part of the node name after the last ".".

    Args:
        G: Graph whose node names are strings.
        annotation_path: Path of the Excel translation table
            (default ``"gene_symbol_annotation.xlsx"``).

    Returns:
        Dict mapping each node of ``G`` to its label.
    """
    # Load ortholog translation table
    translation_df = pd.read_excel(annotation_path)

    # Missing symbols fall back to the gene ID.
    translation_dict = {
        gene_id: (gene_id if pd.isna(symbol) else symbol)
        for gene_id, symbol in zip(translation_df['Gene ID'], translation_df['Gene Symbol'])
    }

    return {
        node: translation_dict.get(node, node.split(".")[-1])
        for node in G.nodes
    }
def plot_graph(G, node_ortholog_dict, cluster_of_interest, target_gene=None, save_fig=None):
    """Draw the interaction graph and save it as a PNG under ``plots/``.

    Nodes with at least one outgoing edge are drawn large and blue, all other
    nodes small and orange. Edges are coloured by their ``weight`` attribute
    on the coolwarm colormap over the fixed range 0..1.

    Args:
        G: Directed graph whose edges carry a ``weight`` attribute.
        node_ortholog_dict: Mapping from node to the label drawn for it.
        cluster_of_interest: Label used in the output file name.
        target_gene: When truthy, used in the file name (unless ``save_fig``
            is given); when falsy, the file is named ``..._full.png``.
        save_fig: Optional file-name suffix. It is only considered when
            ``target_gene`` is truthy.

    Returns:
        None. Prints a message and returns early when ``G`` has no nodes.
    """
    if len(G.nodes) == 0:
        print("No nodes to plot.")
        return

    plt.figure(figsize=(8, 6))
    # NOTE(review): no seed is passed, so the layout presumably differs
    # between runs - confirm whether reproducible figures are needed.
    pos = nx.spring_layout(G)

    # Sources are nodes with at least one outgoing edge; the rest are sinks.
    source_nodes = {u for u, v in G.out_edges()}
    sink_nodes = set(G.nodes) - source_nodes
    edges = G.edges()

    weights = [G[u][v]['weight'] for u, v in edges]
    # NOTE(review): the colour range is fixed to 0..1; weights outside that
    # range (e.g. negative correlations) would be clipped - confirm the score
    # range of the input.
    norm = colors.Normalize(vmin=0, vmax=1)
    cmap = plt.cm.coolwarm

    nx.draw_networkx_nodes(
        G, pos,
        nodelist=source_nodes,
        node_size=600,
        node_color="#4c6ef5"
    )
    nx.draw_networkx_nodes(
        G, pos,
        nodelist=sink_nodes,
        node_size=200,
        node_color="#f6a800"
    )
    nx.draw_networkx_labels(G, pos, labels=node_ortholog_dict, font_size=10)

    nx.draw_networkx_edges(
        G, pos, edgelist=edges,
        edge_color=weights,
        edge_cmap=cmap,
        edge_vmin=0, edge_vmax=1,
        width=1,
        arrows=True,
        arrowsize=10,
        arrowstyle='-|>'
    )
    # Add colorbar (legend for edge weights)
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # Workaround to satisfy colorbar requirement
    cbar = plt.colorbar(sm, ax=plt.gca(), shrink=0.6)
    cbar.set_label("Predicted polarity score", fontsize=10)
    cbar.ax.tick_params(labelsize=10)

    # An empty mapping is passed, so no edge label text is drawn.
    nx.draw_networkx_edge_labels(G, pos, edge_labels={}, font_size=10)

    # Output file name: save_fig wins over target_gene, but only when a
    # target gene was given; otherwise the "full" name is used.
    if target_gene:
        if save_fig:
            filename = f"plots/interaction_graph_{cluster_of_interest}_{save_fig}.png"
        else:
            filename = f"plots/interaction_graph_{cluster_of_interest}_{target_gene}.png"
    else:
        filename = f"plots/interaction_graph_{cluster_of_interest}_full.png"

    plt.axis("off")
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved graph to {filename}")


# === RUNNING SECTION ===

cluster_of_interest = "all"
threshold = 1

# NOTE(review): `df`, `target_genes` and `n_rows` are not assigned in this
# cell; they must already be defined before it is run.
if target_genes:
    # Keep only edges where BOTH source and target are in target_genes
    df[['source', 'target']] = df['edge'].str.split('->', expand=True)
    df_filtered = df[df['source'].isin(target_genes) & df['target'].isin(target_genes)]
    df_filtered = df_filtered.head(n_rows)
    print("number of edges (between target genes only):", df_filtered.shape[0])
    G = build_graph(df_filtered, threshold=threshold, focus_genes=target_genes)
    target_gene = "_".join(target_genes)  # For plot label
else:
    df = df.head(n_rows)
    G = build_graph(df, threshold=threshold)
    # Without target genes there is no label; plot_graph then uses its
    # "full" file name. Previously this branch left target_gene undefined,
    # which raised a NameError at the plot_graph call below.
    target_gene = None


# If your species of interest doesn't have a full annotation, you can use an
# Excel sheet with gene_id, predicted_gene_symbol to annotate your plot.
# Otherwise, the map_genes function is not necessary.
node_ortholog_dict = map_genes(G)
title = "sc-grip_plot"


plot_graph(G, node_ortholog_dict, cluster_of_interest, target_gene=target_gene,save_fig = title)