59 changes: 50 additions & 9 deletions README.md
@@ -5,24 +5,65 @@
Recent advances in single-cell RNA sequencing (scRNA-seq) techniques have provided unprecedented insights into tissue heterogeneity. However, gene expression data alone often fails to capture changes in cellular pathways and complexes, which are more discernible at the protein level. Additionally, analyzing scRNA-seq data presents challenges due to high noise levels and zero inflation. In this study, we propose a novel approach to address these limitations by integrating scRNA-seq datasets with a protein-protein interaction (PPI) network. Our method employs a unique bi-graph architecture based on graph neural networks (GNNs), enabling the joint representation of gene expression and PPI network data. This approach models gene-to-gene relationships under specific biological contexts and refines cell-cell relations using an attention mechanism, resulting in new gene and cell embeddings. We provide comprehensive evaluations to demonstrate the effectiveness of our method.

![Overview of the scNET Method](images/scNET.jpg)
## Download via Git
## Download via pip
`pip install scnet`

## Download via git
To clone the repository, use the following command:
git clone https://github.com/madilabcode/scNET
`git clone https://github.com/madilabcode/scNET`

We recommend using the provided Conda environment located at `./Data/scNET-env.yaml`:

`cd scNET`

`conda env create -f ./Data/scNET-env.yaml`

## Tutorial
## Import scNET
`import scNET`

## API
To train scNET on scRNA-seq data, first load an AnnData object using Scanpy, then initialize training with the following command:

`scNET.run_scNET(obj, pre_processing_flag=False, human_flag=False, number_of_batches=3, split_cells=True, max_epoch=250, model_name=project_name)`

with the following arguments (a worked example follows the list):

obj (AnnData, optional): The input AnnData object containing the scRNA-seq data.

pre_processing_flag (bool, optional): If True, perform pre-processing steps.

human_flag (bool, optional): Controls gene name casing in the network.

number_of_batches (int, optional): Number of mini-batches for the training.

split_cells (bool, optional): If True, split by cells instead of edges during training.

n_neighbors (int, optional): Number of neighbors for building the adjacency graph.

### We recommend using the Google Colab framework for running scNET. Our method works with a Scanpy AnnData object and provides the following outputs:
max_epoch (int, optional): Max number of epochs for model training.

New cell embedding,
New gene embedding,
The trained model,
Pruned KNN network
model_name (str, optional): Identifier for saving the model outputs.

save_model_flag (bool, optional): If True, save the trained model.
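
A minimal end-to-end training sketch (the file path, project name, and parameter values below are illustrative choices, not package defaults):

```python
import scanpy as sc
import scNET

# Load the scRNA-seq data as an AnnData object (path is illustrative)
obj = sc.read_h5ad("data.h5ad")

project_name = "my_project"  # identifier used when saving and loading outputs

scNET.run_scNET(
    obj,
    pre_processing_flag=False,   # set True to let scNET run its pre-processing steps
    human_flag=False,            # controls gene name casing in the PPI network
    number_of_batches=3,
    split_cells=True,            # split mini-batches by cells rather than edges
    max_epoch=250,
    model_name=project_name,
)
```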


Retrieve embeddings and model outputs with:

`embedded_genes, embedded_cells, node_features, out_features = scNET.load_embeddings(project_name)`

where:
- embedded_genes (np.ndarray): Learned gene embeddings.
- embedded_cells (np.ndarray): Learned cell embeddings.
- node_features (pd.DataFrame): Original gene expression matrix.
- out_features (np.ndarray): Reconstructed gene expression matrix.
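
A quick sanity check on these outputs (a sketch; indexing the gene embeddings by gene name assumes their rows follow `node_features.index`, which is how `build_co_embeded_network` below maps embeddings back to genes):

```python
import pandas as pd

print(embedded_genes.shape)   # (number of genes, gene embedding dimension)
print(embedded_cells.shape)   # (number of cells, cell embedding dimension)
print(node_features.shape)    # original expression matrix, indexed by gene
print(out_features.shape)     # reconstructed expression matrix

# Attach gene names to the gene embeddings for downstream analysis
gene_embedding_df = pd.DataFrame(embedded_genes, index=node_features.index)
```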


Create a new AnnData object using model outputs:

`recon_obj = scNET.create_reconstructed_obj(node_features, out_features, obj)`
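
`create_reconstructed_obj` already normalizes the data and computes PCA, neighbors, clustering, and UMAP (per its docstring), so the result plugs directly into standard Scanpy plotting; the marker gene below is purely illustrative:

```python
import scanpy as sc

# Plot the UMAP computed on the reconstructed expression, colored by a gene of
# interest; "P2ry12" is an illustrative marker and must be present in
# recon_obj.var_names for this call to succeed.
sc.pl.umap(recon_obj, color="P2ry12")
```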

Construct a co-embedded network using the gene embeddings:
`scNET.build_co_embeded_network(embedded_genes, node_features)`
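
Per the implementation, this function returns a NetworkX graph over the genes together with its Louvain modularity, so standard NetworkX tooling applies; a short inspection sketch (the gene name is illustrative):

```python
graph, modularity = scNET.build_co_embeded_network(embedded_genes, node_features)

print(f"{graph.number_of_nodes()} genes, {graph.number_of_edges()} edges")
print(f"Louvain modularity: {modularity:.3f}")

# Co-embedded neighbors of a gene of interest (the gene must be a node in the graph)
print(list(graph.neighbors("P2ry12"))[:10])
```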
## Tutorial

### For a basic usage example of our framework, please refer to the following notebook:
[scNET Example Notebook](https://colab.research.google.com/github/madilabcode/scNET/blob/main/scNET.ipynb)
The provided tutorial includes instructions on how to clone the Git repository to your Google Drive, run the model, and load the outputs.

1,071 changes: 381 additions & 690 deletions scNET.ipynb

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 3 additions & 5 deletions MultyGraphModel.py → scNET/MultyGraphModel.py
@@ -7,8 +7,6 @@
from torch_geometric.utils import negative_sampling
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
import matplotlib.pyplot as plt
import seaborn as sns
import math
import numpy as np
import pandas as pd
@@ -34,9 +32,9 @@ def forward(self, z):
out = self.decoder(z)
return out

class MutaelEncoder(torch.nn.Module):
class MutualEncoder(torch.nn.Module):
def __init__(self,col_dim, row_dim,num_layers=4, drop_p = 0.25):
super(MutaelEncoder, self).__init__()
super(MutualEncoder, self).__init__()
self.col_dim = col_dim
self.row_dim = row_dim
self.num_layers = num_layers
@@ -168,7 +166,7 @@ def __init__(self,col_dim, row_dim,inter_row_dim, embd_row_dim, inter_col_dim,em
self.lambda_cols = lambda_cols


self.encoder = MutaelEncoder(col_dim, row_dim,num_layers, drop_p)
self.encoder = MutualEncoder(col_dim, row_dim,num_layers, drop_p)
self.rows_encoder = DimEncoder(row_dim, inter_row_dim, embd_row_dim,drop_p = drop_p, scale_param=None, reducer=False)

self.cols_encoder = DimEncoder(col_dim, inter_col_dim, embd_col_dim,drop_p=drop_p, reducer=True)
67 changes: 59 additions & 8 deletions Utils.py → scNET/Utils.py
@@ -6,11 +6,13 @@
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import average_precision_score
import pickle
import pkg_resources

alpha = 0.9
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epsilon = 0.0001


def save_obj(obj, name):
with open(name + '.pkl', 'wb') as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
@@ -86,7 +88,6 @@ def wilcoxon_enrcment_test(up_sig, down_sig, exp):
backround_exp = exp.loc[exp.index.isin(down_sig)]

rank = ranksums(backround_exp,gene_exp,alternative="less")[1] # rank expression of up sig higher than backround
rank = 1 if rank > 0.05 else rank
return -1 * np.log(rank)


@@ -110,10 +111,37 @@ def signature_values(exp, up_sig, down_sig=None):
return exp.apply(lambda cell: wilcoxon_enrcment_test(up_sig, down_sig, cell), axis=0)

def run_signature(obj, up_sig, down_sig=None, umap_flag = True, alpha = 0.9,prop_exp = None):
exp = obj.raw.to_adata().to_df().T
"""
Calculate and visualize a propagated signature score for cells in the given object.
Parameters
----------
obj : AnnData
The annotated data object containing gene expression matrix and graph data.
up_sig : list or set
A collection of genes used to calculate the up-regulated signature score.
down_sig : list or set, optional
A collection of genes used to calculate the down-regulated signature score.
If None, only the up-regulated signature is used. Default is None.
umap_flag : bool, optional
If True, generates a UMAP plot colored by the calculated signature score.
If False, generates a t-SNE plot. Default is True.
alpha : float, optional
A parameter controlling the smoothing or propagation factor during signature
score calculation. Default is 0.9.
prop_exp : None or other, optional
An unused parameter placeholder, reserved for future use or extended
signature propagation functionality.
Returns
-------
np.ndarray
An array of propagated signature scores, with one score per cell. The
scores are also stored in obj.obs["SigScore"].
"""

exp = obj.to_df().T
graph = obj.obsp["connectivities"].toarray()
prop_exp = propagate_all_genes(graph, exp)
sigs_scores = signature_values(prop_exp, up_sig, down_sig)
sigs_scores = signature_values(exp, up_sig, down_sig)
sigs_scores = propagation(sigs_scores, graph)
obj.obs["SigScore"] = sigs_scores
# color_map = "jet"
if umap_flag:
@@ -129,6 +157,13 @@ def calculate_roc_auc(idents, predict):
def calculate_aupr(idents, predict):
return average_precision_score(idents, predict)

def calculate_roc_auc(idents, predict):
fpr, tpr, _ = roc_curve(idents, predict, pos_label=1)
return auc(fpr, tpr)

def calculate_aupr(idents, predict):
return average_precision_score(idents, predict)

# ---------------------------
# Y - scores vector of cells
# W - Adjacency matrix
@@ -187,7 +222,23 @@ def crate_anndata(path, pcs = 15,neighbors = 30):
def save_model(path, model):
torch.save(model.state_dict(), path)

def load_model(path,node_feature,net):
model = VGAE(GAEncoder(node_feature.shape[1], 350, 100), GAEDncoder(100,350)).to(device)
state = torch.load(path)
model.load_state_dict(state)

def load_embeddings(proj_name):
'''
Loads the embeddings and gene expression data for a given project.

Args:
proj_name (str): The name of the project.

Returns:
tuple: A tuple containing:
- embedded_genes (np.ndarray): Learned gene embeddings.
- embedded_cells (np.ndarray): Learned cell embeddings.
- node_features (pd.DataFrame): Original gene expression matrix.
- out_features (np.ndarray): Reconstructed gene expression matrix.
'''
embeded_genes = load_obj(pkg_resources.resource_filename(__name__,r"./Embedding/row_embedding_" + proj_name))
embeded_cells = load_obj(pkg_resources.resource_filename(__name__,r"./Embedding/col_embedding_" + proj_name))
node_features = pd.read_csv(pkg_resources.resource_filename(__name__,r"./Embedding/node_features_" + proj_name),index_col=0)
out_features = load_obj(pkg_resources.resource_filename(__name__,r"./Embedding/out_features_" + proj_name))
return embeded_genes, embeded_cells, node_features, out_features
6 changes: 6 additions & 0 deletions scNET/__init__.py
@@ -0,0 +1,6 @@
from .main import run_scNET
from .Utils import load_embeddings, propagation, run_signature
from .coEmbeddedNetwork import build_co_embeded_network, create_reconstructed_obj, pathway_enricment, test_KEGG_prediction, plot_de_pathways
from scNET.MultyGraphModel import scNET

__all__ = ['run_scNET', 'load_embeddings', 'build_co_embeded_network', 'scNET', 'create_reconstructed_obj', "test_KEGG_prediction", "pathway_enricment", "plot_de_pathways", "propagation", "run_signature"]
62 changes: 22 additions & 40 deletions coEmbeddedNetwork.py → scNET/coEmbeddedNetwork.py
@@ -6,12 +6,11 @@
import torch
from sklearn.cluster import KMeans
#import umap.plot
import umap.umap_
import networkx as nx
from networkx.algorithms import community
import networkx.algorithms.community as nx_comm
from sklearn.metrics import precision_recall_curve, auc
import Utils as ut
import scNET.Utils as ut
import gseapy as gp
import os

@@ -20,6 +19,7 @@

import gseapy as gp
import warnings
import pkg_resources
warnings.filterwarnings('ignore')


@@ -57,27 +57,6 @@
'29': '#17becf'
}

def load_embeddings(proj_name):
'''
Loads the embeddings and gene expression data for a given project.

Args:
proj_name (str): The name of the project.

Returns:
tuple: A tuple containing:
- embedded_genes (np.ndarray): Learned gene embeddings.
- embedded_cells (np.ndarray): Learned cell embeddings.
- node_features (pd.DataFrame): Original gene expression matrix.
- out_features (np.ndarray): Reconstructed gene expression matrix.
'''
embeded_genes = ut.load_obj(r"./Embedding/row_embedding_" + proj_name)
embeded_cells = ut.load_obj(r"./Embedding/col_embedding_" + proj_name)
node_features = pd.read_csv(r"./Embedding/node_features_" + proj_name,index_col=0)
out_features = ut.load_obj(r"./Embedding/out_features_" + proj_name)
return embeded_genes, embeded_cells, node_features, out_features


def create_reconstructed_obj(node_features, out_features, orignal_obj=None):
'''
Creates an AnnData object from reconstructed gene expression data, normalizes it, and computes PCA, neighbors, clustering, and UMAP.
@@ -105,8 +84,8 @@ def create_reconstructed_obj(node_features, out_features, orignal_obj=None):
return adata


def cal_marker_gene_aupr(adata, marker_genes=['Cd4', 'Cd8a', 'Cd14',"P2ry12","Ncr1"]\
, cell_types=[['CD4 Tcells'], ["CD8 Tcells","NK"], ['Macrophages'], ['Microglia'],["NK"]]):
def calculate_marker_gene_aupr(adata, marker_genes=['Cd4','Cd14',"P2ry12","Ncr1"]\
, cell_types=[['CD4 Tcells'], ['Macrophages'], ['Microglia'],["NK"]]):
'''
Calculates the Area Under the Precision-Recall curve (AUPR) for specified marker genes in identifying specific cell types.

@@ -129,14 +108,15 @@ def cal_marker_gene_aupr(adata, marker_genes=['Cd4', 'Cd8a', 'Cd14',"P2ry12","Nc
print(f"AUPR for {marker_gene} in identifying {cell_type[0]}: {aupr:.4f}")


def pathway_enricment(adata, groupby="seurat_clusters", groups=None):
def pathway_enricment(adata, groupby="seurat_clusters", groups=None, gene_sets=None):
'''
Performs pathway enrichment analysis using KEGG pathways for differentially expressed genes in specific groups.

Args:
adata (AnnData): The annotated data matrix (AnnData object) containing gene expression data and cell clustering/grouping information.
groupby (str, optional): The key in `adata.obs` to group cells by for differential expression analysis. Defaults to "seurat_clusters".
groups (list, optional): A list of specific groups (clusters or cell types) to analyze. If None, all unique groups in `adata.obs[groupby]` are used. Defaults to None.
gene_sets (dict, optional): A dictionary of gene sets to use for pathway enrichment analysis. If None, the KEGG 2021 Human gene sets are used. Defaults to None.

Returns:
tuple: A tuple containing:
@@ -152,12 +132,13 @@ def pathway_enricment(adata, groupby="seurat_clusters", groups=None):
- Pathways with adjusted p-values below 0.05 are considered significant.
'''
adata.var.index = adata.var.index.str.upper()
kegg_gene_sets = gp.get_library('KEGG_2021_Human')
if gene_sets is None:
gene_sets = gp.get_library('KEGG_2021_Human')

filtered_kegg = {pathway: [gene for gene in genes if gene in adata.var.index]
for pathway, genes in kegg_gene_sets.items()}
filtered_gene_set = {pathway: [gene for gene in genes if gene in adata.var.index]
for pathway, genes in gene_sets.items()}

filtered_kegg = {pathway: genes + ["t1"] for pathway, genes in filtered_kegg.items() if len(genes) > 0}
filtered_gene_set = {pathway: genes for pathway, genes in filtered_gene_set.items() if len(genes) > 0}


if groups is None:
Expand All @@ -180,9 +161,9 @@ def pathway_enricment(adata, groupby="seurat_clusters", groups=None):

try:
genes = genes['names'].values
enr = gp.enrichr(gene_list=(genes.tolist() + ["t1"]),
gene_sets=filtered_kegg,
background=list(adata.var.index) + ["t1"],
enr = gp.enrichr(gene_list=(genes.tolist()),
gene_sets=filtered_gene_set,
background=list(adata.var.index),
organism='Human',
outdir=None)
except:
@@ -194,16 +175,17 @@ def pathway_enricment(adata, groupby="seurat_clusters", groups=None):
significant_pathways[group] = significant[['Term', 'Adjusted P-value']]


return de_genes_per_group, significant_pathways, filtered_kegg , enrichment_results
return de_genes_per_group, significant_pathways, filtered_gene_set , enrichment_results


def plot_de_pathways(significant_pathways,enrichment_results):
def plot_de_pathways(significant_pathways,enrichment_results, head=20):
'''
Plots a heatmap of the -log10(Adjusted P-value) for significant pathways across multiple datasets.

Args:
significant_pathways (dict): A dictionary where keys are dataset names (or groups), and values are DataFrames containing significant pathways and their adjusted p-values.
enrichment_results (dict): A dictionary where keys are dataset names (or groups), and values are DataFrames containing full pathway enrichment results, including adjusted p-values for each pathway.
head (int, optional): The number of top pathways to display in the heatmap. Defaults to 20.

Returns:
None: The function generates and displays a heatmap showing the significance (-log10(Adjusted P-value)) of the top 20 pathways across different datasets.
Expand All @@ -214,7 +196,7 @@ def plot_de_pathways(significant_pathways,enrichment_results):
combined_df = pd.DataFrame()

for _, df in enrichment_results.items():
top5_df = df.sort_values(by='Adjusted P-value').head(20)
top5_df = df.sort_values(by='Adjusted P-value').head(head)
for dataset_name, df2 in enrichment_results.items():
df2 = df2.loc[df2.Term.isin(top5_df.Term)]
df2['Dataset'] = dataset_name
@@ -249,13 +231,13 @@ def plot_gene_umap_clustring(embedded_rows):
return means_embedd.labels_


def build_co_embeded_network(embedded_rows,node_fetures, threshold=99):
def build_co_embeded_network(embedded_rows ,node_features,threshold=99):
'''
Builds a co-embedded network from the given embedded rows using a correlation-based thresholding approach and detects communities using the Louvain algorithm.

Args:
embedded_rows (np.ndarray): A matrix of embeddings (e.g., gene embeddings) where each row corresponds to an entity (e.g., gene or cell).
node_fetures (pd.DataFrame): A DataFrame containing features or identifiers for the nodes, where the index corresponds to the entities in `embedded_rows`.
node_features (pd.DataFrame): A DataFrame containing features or identifiers for the nodes, where the index corresponds to the entities in `embedded_rows`.
threshold (int, optional): The percentile threshold to use when binarizing the correlation matrix. Defaults to 99.

Returns:
@@ -278,7 +260,7 @@ def build_co_embeded_network(embedded_rows,node_fetures, threshold=99):
graph = nx.from_numpy_array(mat)
comm = nx_comm.louvain_communities(graph,resolution=1, seed=42)
mod = nx_comm.modularity(graph, comm)
map_nodes = {list(graph.nodes)[i]:node_fetures.index[i] for i in range(len(node_fetures.index))}
map_nodes = {list(graph.nodes)[i]:node_features.index[i] for i in range(len(node_features.index))}
graph = nx.relabel_nodes(graph,map_nodes)
return graph, mod

@@ -433,7 +415,7 @@ def make_term_predication(graphs, term_vec):
result_aupr.append([calculate_aupr(pred , term_vec, test_vec)])
return result_aupr

def predict_kegg(gene_embedding, ref):
def test_KEGG_prediction(gene_embedding, ref):
'''
Predicts KEGG pathway memberships using gene embeddings and reference data, and evaluates the performance using AUPR.
