In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import gseapy as gp

In [12]:
# plotting functions for heatmap

def scale_values(dataframe):
    """Attach a signed significance score and return the rows ordered by it.

    The score is ``-ln(qval)`` multiplied by the sign of ``coef`` and is
    stored in a ``scaled_value`` column.

    Note: the column is written onto the frame that was passed in, so the
    caller's object gains ``scaled_value`` as a side effect. The returned
    frame is the same data sorted by ``scaled_value`` in ascending order.
    """
    direction = np.sign(dataframe['coef'])
    magnitude = -np.log(dataframe['qval'])  # natural log; qval == 0 yields inf
    dataframe['scaled_value'] = magnitude * direction

    return dataframe.sort_values(by='scaled_value')


def create_heatmap(dataframe, title, filename):
    """Draw an annotated heatmap of ``scaled_value`` and save it to ``filename``.

    ``dataframe`` must contain the columns ``feature`` (heatmap rows),
    ``value`` (heatmap columns) and ``scaled_value`` (cell contents). The
    figure is written at 600 dpi and then shown.
    """
    # Long -> wide: one row per feature, one column per value.
    matrix = dataframe.pivot(index='feature', columns='value', values='scaled_value')

    fig, ax = plt.subplots(figsize=(10, 8))
    # Diverging palette centred on 0 so the sign of the association is visible.
    sns.heatmap(matrix, annot=True, cmap='coolwarm', fmt='.2f', center=0, ax=ax)

    ax.set_title(title)
    ax.set_ylabel('')
    ax.set_xlabel('')

    fig.tight_layout()
    fig.savefig(filename, dpi=600, bbox_inches='tight')
    plt.show()

def create_dotplot(dataframe, title, filename):
    """Show a dot plot of ``feature`` against ``value``.

    Dot size follows ``scaled_value`` and dot colour follows ``coef``; no
    legend is drawn. ``filename`` is accepted but currently unused because
    saving is disabled below.
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    sns.scatterplot(
        data=dataframe,
        x='value',
        y='feature',
        size='scaled_value',
        hue='coef',
        sizes=(50, 500),
        palette='coolwarm_r',
        edgecolor="w",
        legend=None,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_ylabel('DLE Clusters')
    ax.set_xlabel('')

    fig.tight_layout()
    # plt.savefig(filename, dpi=600, bbox_inches='tight')
    plt.show()

def create_clustermap(dataframe, title, filename):
    """Show a hierarchically clustered heatmap of ``scaled_value``.

    ``dataframe`` must contain the columns ``feature`` (rows), ``value``
    (columns) and ``scaled_value`` (cell contents). Missing feature/value
    pairs and infinite scores are replaced by 0 before clustering.

    ``filename`` is accepted for signature parity with the other plotting
    helpers but is not used: this function only displays the figure.
    """
    # Long -> wide matrix for clustering.
    heatmap_data = dataframe.pivot(index='feature', columns='value', values='scaled_value')

    # The clustering needs a fully finite matrix, so map +/-inf and NaN to 0.
    heatmap_data = heatmap_data.replace([np.inf, -np.inf], np.nan).fillna(0)

    grid = sns.clustermap(heatmap_data, cmap='coolwarm_r', annot=True, center=0, figsize=(12, 10))

    # clustermap builds a figure with several axes (dendrograms, heatmap,
    # colour bar). plt.title() would label whichever axes happens to be
    # current, so set a figure-level title instead. clustermap also manages
    # its own layout, so no extra tight_layout() call is made here.
    grid.ax_heatmap.figure.suptitle(title, y=1.02)

    plt.show()

## Visualize PC correlation with inflammatory markers

In [None]:
# Load the tab-separated all_results table for the rectum PCA run.
all_results = pd.read_csv(
    "/Volumes/PGH-Backup/ibd_data/shotgun_rnaseq_maaslin/rectum_pca/all_results.tsv",
    sep='\t',
)
display(all_results)

In [4]:
# Principal components and marker genes used to subset the results table.
pcs = [f"PC{number}" for number in (6, 8, 9, 15, 25)]
genes = [
    "CCL2",
    "CXCL10",
    "IL17A",
    "NLRP3",
]

In [None]:
# Keep only rows whose feature is one of the selected PCs and whose value is
# one of the selected genes. The explicit .copy() makes this an independent
# frame: scale_values() later adds a column to it, which on a boolean-mask
# slice would raise pandas' SettingWithCopyWarning.
all_results_filt = all_results[
    all_results['feature'].isin(pcs) & all_results['value'].isin(genes)
].copy()

display(all_results_filt)

In [None]:
# Load the tab-separated significant_results table for the rectum PCA run.
RnaSeqShotgunPCA = pd.read_csv(
    "/Volumes/PGH-Backup/ibd_data/shotgun_rnaseq_maaslin/rectum_pca/significant_results.tsv",
    sep='\t',
)

display(RnaSeqShotgunPCA)

In [None]:
# Keep associations with q-value <= 0.05. The .copy() detaches the result from
# the frame it was sliced from, so the column that scale_values() adds later
# does not trigger pandas' SettingWithCopyWarning.
RnaSeqShotgunPCA = RnaSeqShotgunPCA.loc[RnaSeqShotgunPCA['qval'] <= 0.05].copy()

display(RnaSeqShotgunPCA)

In [None]:
# Compute the signed -log(qval) score for both tables. scale_values() also
# writes the scaled_value column onto the frames passed in.
RnaSeqShotgunPCA_scaled = scale_values(dataframe=RnaSeqShotgunPCA)
all_results_scaled = scale_values(dataframe=all_results_filt)

In [None]:
# Plot the frame returned by scale_values() instead of relying on
# all_results_filt having been modified in place. Use a file name of its own:
# the next heatmap cell writes "PCAmgx_rnaseq_inflammatory_corrs.png" and
# would otherwise overwrite this figure.
create_heatmap(all_results_scaled, "", "PCAmgx_rnaseq_inflammatory_corrs_selected.png")

In [None]:
# Heatmap of the significant PC / marker associations (saved to disk).
create_heatmap(
    dataframe=RnaSeqShotgunPCA_scaled,
    title="Correlation of PCAd mgx and inflammatory markers in the rectum",
    filename="PCAmgx_rnaseq_inflammatory_corrs.png",
)

In [None]:
# Dot-plot view of the same associations (create_dotplot does not save).
create_dotplot(
    dataframe=RnaSeqShotgunPCA_scaled,
    title="Correlation of PCAd mgx and inflammatory markers in the rectum",
    filename="PCAmgx_rnaseq_inflammatory_corrs.png",
)

In [None]:
# Clustered view of the same associations (create_clustermap does not save).
create_clustermap(
    dataframe=RnaSeqShotgunPCA_scaled,
    title="Correlation of PCAd mgx and inflammatory markers in the rectum",
    filename="PCAmgx_rnaseq_inflammatory_corrs.png",
)

## Visualize individual PGH correlation with inflammatory markers

In [None]:
# Load the tab-separated all_results table for the rectum (non-PCA) run.
all_results_pgh = pd.read_csv(
    '/Volumes/PGH-Backup/ibd_data/shotgun_rnaseq_maaslin/rectum/all_results.tsv',
    sep='\t',
)

display(all_results_pgh)

In [None]:
# all_results_pgh = all_results_pgh[all_results_pgh['qval'] < 0.05]

# display(all_results_pgh)

In [None]:
# Signed -log(qval) score per row, sorted ascending. scale_values() also adds
# the scaled_value column to all_results_pgh itself.
all_results_pgh_scaled = scale_values(dataframe=all_results_pgh)

display(all_results_pgh_scaled)

In [19]:
# Drop two features from the table before plotting.
all_results_pgh_scaled = all_results_pgh_scaled.loc[
    ~all_results_pgh_scaled['feature'].isin([
        'Amidase.A0A1W7ABN8',
        "DL.endopeptidase.A0A6N3BHG0",
    ])
]

In [None]:
# Heatmap of the remaining features against the markers.
# NOTE(review): the file name has no extension, unlike the other plot cells;
# presumably matplotlib falls back to its default format - confirm the
# resulting file name is the intended one.
create_heatmap(
    dataframe=all_results_pgh_scaled,
    title="PGH cluster associations with inflammatory markers in the rectum",
    filename="mgx_rnaseq_inflammatory",
)


## GSEA analysis

In [41]:
def parse_gmt(file_path):
    """
    Parse a GMT file and return a dictionary of gene sets.

    Each non-empty line is tab-separated: the first field is the gene set
    name, the second a description (ignored), and the remaining fields are
    the genes.

    Parameters
    ----------
    file_path : str or path-like
        Location of the GMT file.

    Returns
    -------
    dict
        Maps each gene set name to its list of genes, in file order. If a
        name appears more than once, the last occurrence wins.

    Blank or whitespace-only lines are skipped rather than producing a gene
    set named ``''``, and empty gene fields (e.g. from doubled tabs) are
    dropped.
    """
    gene_sets = {}
    with open(file_path, 'r') as gmt_file:
        for line in gmt_file:
            fields = line.strip().split('\t')

            gene_set_name = fields[0]
            if not gene_set_name:
                # Blank line (or a line with no set name): nothing to record.
                continue

            # Genes start at the third column; ignore empty placeholders.
            stripped = (field.strip() for field in fields[2:])
            gene_sets[gene_set_name] = [gene for gene in stripped if gene]

    return gene_sets

In [None]:
# GMT files: the NOD2 gene list and the Hallmark inflammatory response list.
nod2_path = '/Volumes/PGH-Backup/ibd_data/rnaseq/GSE22611_NOD2_VS_CTRL_TRANSDUCED_HEK293T_CELL_UP.v2024.1.Hs.gmt'
inflam_path = "/Volumes/PGH-Backup/ibd_data/rnaseq/HALLMARK_INFLAMMATORY_RESPONSE.v2024.1.Hs.gmt"

nod2_genes = parse_gmt(nod2_path)
inflam_genes = parse_gmt(inflam_path)

# Merge both into one dict; on a name clash the inflammatory entry wins.
combined_gene_set = {**nod2_genes, **inflam_genes}

print(combined_gene_set)


In [None]:
# Sanity check: container type, then the type and a 5-gene sample per set.
print(type(combined_gene_set))  # Should be <class 'dict'>

for key in combined_gene_set:
    value = combined_gene_set[key]
    print(f"Key: {key} (Type: {type(key)})")
    print(f"Value Sample: {value[:5]} (Type: {type(value)})")

In [None]:
# Rank rows by coefficient, largest first, keeping only the name and score.
ranked_genes = all_results_pgh.loc[:, ['value', 'coef']].sort_values(
    by='coef',
    ascending=False,
)

print(ranked_genes)

# Name -> coefficient mapping (a repeated name keeps its last coefficient).
ranked_genes_list = (
    ranked_genes
    .set_index('value')['coef']
    .to_dict()
)

In [47]:
# Upper-case every gene name so the gene sets and the ranking use one casing.
combined_gene_set = {
    set_name: list(map(str.upper, members))
    for set_name, members in combined_gene_set.items()
}

ranked_genes['value'] = ranked_genes['value'].str.upper()

In [None]:
# Pre-ranked GSEA against the combined (NOD2 + Hallmark inflammatory) gene
# sets. The original call passed `inflam_path`, which tested only the Hallmark
# file and left the merged, upper-cased `combined_gene_set` unused.
# NOTE(review): confirm that testing both gene sets is the intended analysis.
gsea_results = gp.prerank(
    rnk=ranked_genes,  # Ranked table: name in 'value', score in 'coef'
    gene_sets=combined_gene_set,  # Dict of gene set name -> list of genes
    min_size=5,
    max_size=1000,
    outdir=None,  # None: presumably nothing is written to disk - confirm
    permutation_num=1000,  # Number of permutations (can adjust for significance)
    seed=42
)

# View the results
gsea_results.res2d.head()

In [None]:
# Show the container types of the ranking mapping and the gene sets.
print(type(ranked_genes_list), type(combined_gene_set), sep="\n")