In [None]:
! pip install pyro-ppl scanpy gseapy

Collecting gseapy
  Downloading gseapy-1.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (552 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m552.9/552.9 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: gseapy
Successfully installed gseapy-1.1.3


In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset loaded later
# in this notebook is read from a path under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder
import scanpy as sc
from collections import defaultdict
from itertools import combinations
import numpy as np

import os
import requests
import os
from pyro.infer import SVI,MCMC, NUTS,TraceMeanField_ELBO
import math
from tqdm import trange
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import linkage
import gseapy as gp

In [None]:
class Encoder(nn.Module):
    """Inference network used by the guide.

    Maps a batch of documents (rows of per-word counts) to the location and
    scale of a diagonal Normal over the unnormalised log topic proportions.

    Args:
        vocab_size: number of input features (columns of a document row).
        num_topics: number of topics, i.e. the size of each output.
        hidden: width of the two hidden layers.
        dropout: dropout probability applied to the hidden representation.
    """

    def __init__(self, vocab_size, num_topics, hidden, dropout):
        super().__init__()
        # Dropout on the hidden representation, to avoid component collapse.
        self.drop = nn.Dropout(dropout)
        # Two-layer trunk shared by both output heads.
        self.fc1 = nn.Linear(vocab_size, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        # Separate heads for the mean and the log-variance.
        self.fcmu = nn.Linear(hidden, num_topics)
        self.fclv = nn.Linear(hidden, num_topics)
        # Batch norm on each head, to avoid component collapse.
        # `affine=False` drops the learnable scale/shift and so reduces the
        # number of parameters; see
        # https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
        self.bnmu = nn.BatchNorm1d(num_topics, affine=False)
        self.bnlv = nn.BatchNorm1d(num_topics, affine=False)

    def forward(self, inputs):
        """Return `(logtheta_loc, logtheta_scale)` for a batch of documents."""
        trunk = self.drop(F.softplus(self.fc2(F.softplus(self.fc1(inputs)))))

        # Mean head.
        logtheta_loc = self.bnmu(self.fcmu(trunk))

        # Log-variance head; exp(0.5 * logvar) is the standard deviation and is
        # positive by construction.
        logtheta_scale = torch.exp(0.5 * self.bnlv(self.fclv(trunk)))

        return logtheta_loc, logtheta_scale


class Decoder(nn.Module):
    """Generative network used by the model.

    Turns topic proportions into a probability distribution over the
    vocabulary: softmax(batch_norm(beta @ theta)).

    Args:
        vocab_size: number of output features (words / genes).
        num_topics: number of topics, i.e. the size of each input row.
        dropout: dropout probability applied to the incoming proportions.
    """

    def __init__(self, vocab_size, num_topics, dropout):
        super().__init__()
        # Bias-free linear layer; its weight matrix holds the topic-word weights.
        self.beta = nn.Linear(num_topics, vocab_size, bias=False)
        self.bn = nn.BatchNorm1d(vocab_size, affine=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, inputs):
        """Return per-document word probabilities (each row sums to 1)."""
        theta = self.drop(inputs)
        word_logits = self.bn(self.beta(theta))
        return F.softmax(word_logits, dim=1)


class ProdLDA(nn.Module):
    """Pyro model/guide pair built from an `Encoder` and a `Decoder`.

    Args:
        vocab_size: number of columns in a document row.
        num_topics: number of latent topics.
        hidden: hidden width of the encoder.
        dropout: dropout probability used by both networks.
    """

    def __init__(self, vocab_size, num_topics, hidden, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_topics = num_topics
        self.encoder = Encoder(vocab_size, num_topics, hidden, dropout)
        self.decoder = Decoder(vocab_size, num_topics, dropout)

    def model(self, docs):
        """Generative process for a batch of documents (rows of counts)."""
        pyro.module("decoder", self.decoder)
        n_docs = docs.shape[0]
        with pyro.plate("documents", n_docs):
            # The Dirichlet prior p(theta | alpha) is replaced by a
            # logistic-normal: a standard Normal on logtheta followed by softmax.
            prior_shape = (n_docs, self.num_topics)
            prior = dist.Normal(docs.new_zeros(prior_shape), docs.new_ones(prior_shape))
            logtheta = pyro.sample("logtheta", prior.to_event(1))

            # w_n | beta, theta ~ Categorical(softmax(beta @ theta))
            word_probs = self.decoder(F.softmax(logtheta, -1))

            # Multinomial needs a single `total_count` for the whole batch even
            # though documents differ in length, so the largest document total
            # is used. Per the original author's note, Multinomial.log_prob does
            # not need `total_count` to evaluate the log probability, so this
            # does not affect the result.
            largest_doc = int(docs.sum(-1).max())
            pyro.sample('obs', dist.Multinomial(largest_doc, word_probs), obs=docs)

    def guide(self, docs):
        """Variational posterior: Normal whose parameters come from the encoder."""
        pyro.module("encoder", self.encoder)
        with pyro.plate("documents", docs.shape[0]):
            loc, scale = self.encoder(docs)
            pyro.sample("logtheta", dist.Normal(loc, scale).to_event(1))

    def beta(self):
        """Return the decoder's linear-layer weights, transposed, detached, on CPU."""
        weights = self.decoder.beta.weight
        return weights.cpu().detach().T

In [None]:
# Load the AnnData object from the mounted Drive.
TSP = sc.read_h5ad('/content/drive/MyDrive/CS273B/TSP_SS2.h5ad')
# Flag highly variable genes; the next line reads the resulting
# `TSP.var.highly_variable` column.
sc.pp.highly_variable_genes(TSP)
# Keep only the flagged genes (column subset of TSP).
# NOTE(review): confirm whether this also narrows `TSP.raw`, which a later cell
# reads the counts from; if it does not, the count matrix and TSP.var will
# disagree on the number of genes.
TSP = TSP[:, TSP.var.highly_variable]

In [None]:
# Next, set up the model input from the AnnData object.
# The input must be a torch tensor with one row per cell and one column per gene,
# which is exactly how TSP.X is laid out.

# The model also needs integer counts, so we take them from TSP.raw.

In [None]:
#import torch
from scipy.sparse import csr_matrix

In [None]:
# Cast the raw integer counts to a tensor, restricted to the genes kept in TSP.
#
# AnnData subsets `.raw` only along the cell axis, so after the
# highly-variable-gene filter `TSP.raw` still spans every original gene.
# Selecting the columns by name guarantees that column j of the training matrix
# is gene `TSP.var_names[j]`, which is what the later cells assume when they
# turn topic-weight column indices back into gene names.
raw_counts = TSP.raw[:, TSP.var_names].X

# `.X` may be sparse or already dense depending on how the file was written.
dense_array = raw_counts.toarray() if hasattr(raw_counts, "toarray") else np.asarray(raw_counts)
dense_tensor = torch.tensor(dense_array)

assert dense_tensor.shape[1] == TSP.n_vars, "count matrix columns must match TSP.var_names"

In [None]:
# Global configuration shared by the training and evaluation cells.

# Model and optimisation hyperparameters.
num_topics = 20
batch_size = 32
learning_rate = 1e-3
num_epochs = 50

# Seed torch and pyro so the run is repeatable.
seed = 0
torch.manual_seed(seed)
pyro.set_rng_seed(seed)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The count matrix as float32 on the chosen device.
docs = dense_tensor.float().to(device)

In [None]:
# Training: fit ProdLDA with stochastic variational inference.
# Start from an empty parameter store so re-running this cell retrains from scratch.
pyro.clear_param_store()

prodLDA = ProdLDA(
    vocab_size=docs.shape[1],
    num_topics=num_topics,
    hidden=100,
    dropout=0.2
)
prodLDA.to(device)

optimizer = pyro.optim.Adam({"lr": learning_rate})
svi = SVI(prodLDA.model, prodLDA.guide, optimizer, loss=TraceMeanField_ELBO())
num_batches = int(math.ceil(docs.shape[0] / batch_size))

bar = trange(num_epochs)
for epoch in bar:
    running_loss = 0.0
    for i in range(num_batches):
        batch_docs = docs[i * batch_size:(i + 1) * batch_size, :]
        # The encoder and decoder use BatchNorm1d, which raises in training mode
        # when a batch holds a single row. That happens for the final batch
        # whenever the number of documents is 1 modulo batch_size, so skip it.
        if batch_docs.size(0) < 2:
            continue
        loss = svi.step(batch_docs)
        # Accumulate the per-document loss of each batch.
        running_loss += loss / batch_docs.size(0)

    bar.set_postfix(epoch_loss='{:.2e}'.format(running_loss))

100%|██████████| 50/50 [07:06<00:00,  8.53s/it, epoch_loss=2.00e+09]


In [None]:
def calculate_perplexity(model, test_docs, num_topics, batch_size=256):
    """Estimate per-word perplexity of `test_docs` under a trained model.

    For each document one `logtheta` is drawn from the encoder's Normal, pushed
    through softmax and the decoder, and the document is scored with
    `Multinomial.log_prob`. The result is `exp(-sum(log_prob) / total_words)`.

    Notes:
        * The estimate is stochastic (one sample per document); seed torch for
          repeatable numbers.
        * `Multinomial.log_prob` includes the multinomial-coefficient term, so
          the value is not directly comparable with perplexities computed from
          per-token categorical log-likelihoods.
        * Documents are scored in batches. The model is put in eval mode while
          scoring, so batch norm uses its running statistics and dropout is
          off; each document's score is therefore independent of the batch it
          is in.

    Args:
        model: object exposing `encoder`, `decoder`, `eval()`, `train(mode)` and
            `training`, e.g. a `ProdLDA` instance.
        test_docs: 2-D tensor of non-negative counts, one row per document.
        num_topics: unused; kept so existing calls keep working.
        batch_size: number of documents scored per forward pass.

    Returns:
        The per-word perplexity, or NaN when `test_docs` contains no words.
    """
    num_words = test_docs.sum().item()
    if num_words <= 0:
        # No words to normalise by; avoid a division by zero.
        return float("nan")

    was_training = model.training
    model.eval()
    log_likelihood = 0.0
    try:
        with torch.no_grad():
            for start in range(0, test_docs.shape[0], batch_size):
                batch = test_docs[start:start + batch_size]

                logtheta_loc, logtheta_scale = model.encoder(batch)
                logtheta_scale = logtheta_scale + 1e-6  # keep the scale strictly positive
                logtheta = torch.distributions.Normal(logtheta_loc, logtheta_scale).sample()
                theta = F.softmax(logtheta, -1)

                count_param = model.decoder(theta)
                # One shared total_count per batch: the largest document total.
                total_count = int(batch.sum(-1).max())
                log_prob = torch.distributions.Multinomial(total_count, count_param).log_prob(batch)
                log_likelihood += log_prob.sum().item()
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    per_word_perplexity = np.exp(-log_likelihood / num_words)
    return per_word_perplexity

# NOTE: `docs` is the same matrix the model was trained on, so the number below
# is a training-set perplexity, not a held-out estimate.
test_docs = docs  # stand-in for a held-out set; replace when one is available
perplexity = calculate_perplexity(prodLDA, test_docs, num_topics)
print("Perplexity:", perplexity)

Perplexity: 35.228237258803574


In [None]:
# Download the MSigDB `c5.all` gene-set collection (release 7.5.1).
url = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/7.5.1/c5.all.v7.5.1.symbols.gmt"

# Name the local copy after the remote file so the filename matches its
# contents (it was previously saved under an unrelated `h.all...` name).
gmt_file_path = os.path.basename(url)

# Only hit the network when there is no local copy yet, so re-running the
# notebook does not re-download the file.
if not os.path.exists(gmt_file_path):
    response = requests.get(url, timeout=60)
    # Fail loudly on an HTTP error instead of saving an error page as a .gmt.
    response.raise_for_status()
    gmt_content = response.text

    with open(gmt_file_path, "w", encoding="utf-8") as gmt_file:
        gmt_file.write(gmt_content)

# Function to parse the .gmt file
def parse_gmt(file_path):
    """Read a tab-separated .gmt file and return its gene sets.

    Each non-empty line is `name <TAB> description <TAB> gene1 <TAB> gene2 ...`.
    Only the genes are kept.

    Args:
        file_path: path to the .gmt file.

    Returns:
        A list with one list of gene names per gene set, in file order. Blank
        lines and lines that carry no genes are skipped, so no entry is empty.
    """
    gene_sets = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            fields = line.rstrip("\r\n").split("\t")
            # Fields 0 and 1 are the set name and description; drop empty
            # tokens left by trailing or doubled tabs.
            genes = [gene.strip() for gene in fields[2:] if gene.strip()]
            if genes:
                gene_sets.append(genes)
    return gene_sets

# Parse the saved .gmt file into a list of gene lists (one list per gene set).
gene_sets = parse_gmt(gmt_file_path)

In [None]:
# Build a cooccurence network of the genes from the MsigDB gene sets
from collections import defaultdict

def build_cooccurrence_network(gene_sets):
    """Map every gene to the set of genes it shares at least one gene set with.

    Args:
        gene_sets: iterable of gene-name lists.

    Returns:
        A `defaultdict(set)`: `network[g]` holds every other gene that appears
        together with `g` in some gene set. A gene never lists itself, and a
        gene that only occurs alone gets no entry.
    """
    cooccurrence = defaultdict(set)
    for gene_set in gene_sets:
        members = set(gene_set)
        if len(members) < 2:
            # Nothing to pair with; do not create an entry.
            continue
        # One set union per gene replaces the Python-level loop over all pairs.
        for gene in members:
            cooccurrence[gene].update(members)

    # The unions above added each gene to its own neighbourhood; remove it.
    for gene, partners in cooccurrence.items():
        partners.discard(gene)
    return cooccurrence

# Gene -> set of genes sharing at least one MSigDB gene set with it.
cooccurrence_network = build_cooccurrence_network(gene_sets)

In [None]:
# now I need to extract our topics and top genes
def get_top_genes(beta_matrix, feature_names, top_n=10):
    """Return the names of the highest-weighted features of every topic.

    Args:
        beta_matrix: iterable of per-topic weight vectors (topics x features).
        feature_names: sequence of names; `feature_names[j]` labels column j.
        top_n: how many features to keep per topic. Values <= 0 yield empty
            lists; values larger than the number of features keep them all.

    Returns:
        One list per topic with up to `top_n` names, highest weight first.

    Raises:
        ValueError: if a topic's length differs from `len(feature_names)`,
            since the names would then label the wrong columns.
    """
    n_keep = max(int(top_n), 0)
    top_genes = []
    for topic_idx, topic in enumerate(beta_matrix):
        if len(topic) != len(feature_names):
            raise ValueError(
                "topic {} has {} weights but {} feature names were given".format(
                    topic_idx, len(topic), len(feature_names)
                )
            )
        # Reverse first, then truncate: a `[-top_n:]` slice would return every
        # feature when top_n is 0.
        ranked = np.argsort(topic)[::-1][:n_keep]
        top_genes.append([feature_names[i] for i in ranked])
    return top_genes

In [None]:
# Build the topics x genes weight matrix as a NumPy array and take the gene
# names from the AnnData object; column j of the matrix is labelled with
# feature_names[j].
beta_matrix = prodLDA.beta().cpu().detach().numpy()
feature_names = TSP.var_names.tolist()
# Top genes per topic (default: 10 per topic).
top_genes = get_top_genes(beta_matrix, feature_names)

In [None]:
from itertools import combinations

In [None]:
def calculate_gsea_coherence(top_genes, cooccurrence_network):
    """Average fraction of top-gene pairs that co-occur in a known gene set.

    For every topic, each unordered pair `(gene1, gene2)` of its top genes
    counts as a hit when `gene2` is listed under `gene1` in the network. The
    topic score is hits / pairs, or 0 for a topic with fewer than two genes.

    Args:
        top_genes: list of per-topic gene-name lists.
        cooccurrence_network: mapping gene -> collection of co-occurring genes.
            Both plain dicts and defaultdicts are accepted, and the mapping is
            never modified.

    Returns:
        The mean topic score, or NaN when `top_genes` is empty.
    """
    coherence_scores = []
    for topic in top_genes:
        pairs = list(combinations(topic, 2))
        if not pairs:
            coherence_scores.append(0)
            continue
        # `.get` instead of indexing: indexing a defaultdict would insert an
        # empty entry for every unknown gene, and a plain dict would raise.
        hits = sum(
            1 for gene1, gene2 in pairs
            if gene2 in cooccurrence_network.get(gene1, ())
        )
        coherence_scores.append(hits / len(pairs))

    if not coherence_scores:
        return float("nan")
    return np.mean(coherence_scores)

# Coherence of the topics' top genes, measured against co-membership in the
# MSigDB gene sets (mean over topics of the fraction of co-occurring pairs).
gsea_coherence_score = calculate_gsea_coherence(top_genes, cooccurrence_network)
print("GSEA Coherence Score:", gsea_coherence_score)

GSEA Coherence Score: 0.5055555555555555


In [None]:
# First row of the matrix returned by prodLDA.beta(): the weights of topic 0.
b = prodLDA.beta()[0]

In [None]:
# Sort the weights in descending order; `indices` gives the original column
# position of each sorted weight.
sorted_, indices = torch.sort(b, descending=True)

In [None]:
# Column positions of the 50 highest-weighted entries, as a one-column frame.
df = pd.DataFrame(indices[:50].numpy(), columns=['index'])

In [None]:
# Positional lookup of each index in TSP.var to get its gene symbol.
# NOTE(review): this assumes the weight columns are in the same order as the
# rows of TSP.var — confirm against how the training matrix was built.
df['gene_name'] = df['index'].apply(lambda x: TSP.var.iloc[x].gene_symbol)

In [None]:
# Plain Python list of the gene symbols, passed to Enrichr below.
gene_list = df['gene_name'].tolist()

In [None]:
# Enrichr libraries to query; only the biological-process library is enabled.
enrichr_libraries = ['GO_Biological_Process_2021']# other candidates: 'GO_Cellular_Component_2021', 'GO_Molecular_Function_2021'

# Run Enrichr on the topic's top genes.
enrich_results = gp.enrichr(gene_list=gene_list,
                            gene_sets=enrichr_libraries,
                            organism='Human'  # change to match the organism of the dataset
                           )

In [None]:
# Terms with adjusted p-value below 0.001, strongest combined score first.
enrich_results.results[enrich_results.results['Adjusted P-value'] < 0.001].sort_values(by=['Combined Score'],ascending = False)

Unnamed: 0,Gene_set,Term,Overlap,P-value,Adjusted P-value,Old P-value,Old Adjusted P-value,Odds Ratio,Combined Score,Genes
0,GO_Biological_Process_2021,muscle contraction (GO:0006936),7/129,3.126094e-08,1.7e-05,0,0,26.457491,457.209168,ACTA2;TPM2;LMOD1;MYH11;CRYAB;ACTG2;MYLK


In [None]:
# For every topic ("module"): take its highest-weighted genes, run Enrichr on
# them, and record the best significant GO term together with its genes.
TOP_N_GENES = 50      # genes per module sent to Enrichr
PADJ_CUTOFF = 0.001   # adjusted p-value threshold for a term to count

# Define the gene sets to query.
enrichr_libraries = ['GO_Biological_Process_2021']

betas = prodLDA.beta()
gene_symbols = TSP.var['gene_symbol']

# Collect one record per module and build the DataFrame once at the end,
# instead of growing it with pd.concat on every iteration.
module_records = []
for i in range(betas.shape[0]):  # loop over each module
    _, ranked_idx = torch.sort(betas[i], descending=True)
    # Positional lookup of the top column indices in TSP.var.
    module_genes = gene_symbols.iloc[ranked_idx[:TOP_N_GENES].numpy()].tolist()

    # Run Enrichr
    module_enrichment = gp.enrichr(gene_list=module_genes,
                                   gene_sets=enrichr_libraries,
                                   organism='Human')

    # Keep significant terms only, strongest combined score first.
    hits = module_enrichment.results
    hits = hits[hits['Adjusted P-value'] < PADJ_CUTOFF]
    hits = hits.sort_values(by=['Combined Score'], ascending=False)

    # Modules without any significant term are left out of the table.
    if hits.empty:
        continue

    best_hit = hits.iloc[0]
    # Deliberately not named `top_genes`: that name holds the per-topic gene
    # lists used by the coherence cell and must not be rebound to a string.
    module_records.append({
        'Module': i,
        'Top_Term': best_hit['Term'],
        'Top_Genes': best_hit['Genes'],
    })

results_df = pd.DataFrame(module_records, columns=['Module', 'Top_Term', 'Top_Genes'])

# Display the results dataframe
print(results_df)

   Module                                           Top_Term  \
0       0                    muscle contraction (GO:0006936)   
1       1                 leukocyte aggregation (GO:0070486)   
2       3         actin-myosin filament sliding (GO:0033275)   
3       4                    muscle contraction (GO:0006936)   
4       5                    retina homeostasis (GO:0001895)   
5       6                    retina homeostasis (GO:0001895)   
6       8                    muscle contraction (GO:0006936)   
7       9                    myofibril assembly (GO:0030239)   
8      10          keratinocyte differentiation (GO:0030216)   
9      12  chemokine-mediated signaling pathway (GO:0070098)   
10     17   negative regulation of fibrinolysis (GO:0051918)   
11     18  positive regulation of natural killer cell che...   

                                         Top_Genes  
0          ACTA2;TPM2;LMOD1;MYH11;CRYAB;ACTG2;MYLK  
1                               IL1B;S100A9;S100A8  
2       

In [None]:
# Several modules can share the same top term; keep only the first module for
# each distinct term so every term appears once.
unique_terms_df = results_df.drop_duplicates(subset=['Top_Term'], keep='first')

In [None]:
# Show the de-duplicated table of top terms.
unique_terms_df

Unnamed: 0,Module,Top_Term,Top_Genes
0,0,muscle contraction (GO:0006936),ACTA2;TPM2;LMOD1;MYH11;CRYAB;ACTG2;MYLK
1,1,leukocyte aggregation (GO:0070486),IL1B;S100A9;S100A8
2,3,actin-myosin filament sliding (GO:0033275),ACTA1;DES;MYL1;MYL2;TNNC2;TNNT3;TTN
4,5,retina homeostasis (GO:0001895),PRR4;OPRPN;ZG16B;LCN1;LYZ;LTF
7,9,myofibril assembly (GO:0030239),TMOD4;TCAP;CASQ1;MYOZ1
8,10,keratinocyte differentiation (GO:0030216),DSP;SPRR3;CSTA;SPRR1A;SPRR1B;S100A7
9,12,chemokine-mediated signaling pathway (GO:0070098),CCL22;CXCL8;CCL3L1;CCL4;CCL3;CXCL1;CXCL3;CXCL2
10,17,negative regulation of fibrinolysis (GO:0051918),THBD;SERPINE1;THBS1
11,18,positive regulation of natural killer cell che...,CCL5;CCL4;CCL3


In [None]:
# List the per-cell annotation columns available for grouping cells.
TSP.obs.columns

Index(['donor', 'tissue', 'anatomical_position', 'method', 'cdna_plate',
       'library_plate', 'notes', 'cdna_well', 'old_index', 'assay',
       'sample_id', 'sample', 'replicate', '10X_run', '10X_barcode',
       'ambient_removal', 'donor_method', 'donor_assay', 'donor_tissue',
       'donor_tissue_assay', 'cell_ontology_class', 'cell_ontology_id',
       'compartment', 'broad_cell_class', 'free_annotation',
       'manually_annotated', 'published_2022', 'n_genes_by_counts',
       'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ercc',
       'pct_counts_ercc', '_scvi_batch', '_scvi_labels',
       'scvi_leiden_donorassay_full', 'age', 'sex', 'ethnicity'],
      dtype='object')

In [None]:
def score_cell_type_by_genes(adata, genes, cell_type):
    """Mean `sc.tl.score_genes` score of `genes` over the cells of one type.

    Args:
        adata: AnnData object with a `cell_ontology_class` column in `.obs`.
        genes: list of gene names to score.
        cell_type: value of `cell_ontology_class` selecting the cells.

    Returns:
        The average score across the selected cells, or NaN when `genes` is
        empty or no cell has the requested type.
    """
    is_cell_type = adata.obs['cell_ontology_class'] == cell_type
    # Work on a copy: score_genes is asked to write a 'score' column, and that
    # should not land on the caller's object.
    sub_adata = adata[is_cell_type].copy()

    if len(genes) <= 0 or sub_adata.n_obs <= 0:
        return np.nan

    sc.tl.score_genes(sub_adata, gene_list=genes, score_name='score')
    return sub_adata.obs['score'].mean()

# Score every (GO term, cell type) pair: rows are the unique top terms,
# columns are the cell types found in TSP.obs.
heatmap_data = pd.DataFrame(index=unique_terms_df['Top_Term'], columns=TSP.obs['cell_ontology_class'].unique())

for _, row in unique_terms_df.iterrows():
    # Top_Genes is a ';'-separated string; keep only genes present in TSP.
    genes = row['Top_Genes'].split(';')
    genes = [gene for gene in genes if gene in TSP.var_names]
    for cell_type in heatmap_data.columns:
        heatmap_data.at[row['Top_Term'], cell_type] = score_cell_type_by_genes(TSP, genes, cell_type)

# Convert to float first, then replace NaNs with 0 for clustering. Doing it in
# this order (and without inplace=True) avoids filling an object-dtype frame.
heatmap_data = heatmap_data.astype(float).fillna(0.0)

# Create a clustermap
g = sns.clustermap(heatmap_data, method='ward', metric='euclidean', cmap='viridis',
                   linewidths=.5, figsize=(12, 8), annot=True, fmt=".2f")
plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, horizontalalignment='right')  # Rotate x labels for better visibility
# A clustermap figure has several axes, and plt.title() would label whichever
# one is current; set the title on the figure itself instead.
g.ax_heatmap.figure.suptitle('Clustered Model Score Heatmap by Cell Type and GO Term', y=1.02)
plt.show()

In [None]:
# Larger rendering of the same clustermap without per-cell annotations.
g = sns.clustermap(heatmap_data, method='ward', metric='euclidean', cmap='viridis', figsize=(25, 15), annot=False, fmt=".2f")
plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, horizontalalignment='right')  # Rotate x labels for better visibility
# A clustermap figure has several axes, and plt.title() would label whichever
# one is current; set the title on the figure itself instead.
g.ax_heatmap.figure.suptitle('Clustered Model Score Heatmap by Cell Type and GO Term', y=1.02)
plt.show()