### Programming for Biomedical Informatics
#### Week 7 - Network Construction Techniques

Constructing networks that are useful representations of the underlying biological data is a complex task. In this notebook we will explore some key concepts that are used to incorporate data into networks and then refine those networks using a selection of methodologies.
Quantifying the impact of the assumptions and decisions made in the network construction and refinement process is a key part of the experimental analysis of networks. This is often confounded by the lack of ground-truth data upon which to make decisions.

Thanks to Sebestyen Kamp who developed parts of these scripts for a workshop on networks presented at ISMB2024 in Montreal, Canada.

Files used in this analysis

- ISMB_TCGA_GE.pkl - contains gene expression data for TCGA samples
- correlation_matrices.pkl - contains correlation matrices for TCGA sample gene expression data
- correlation_matrices_figure.png - figure showing the correlation heatmaps
- gene_coexpression_network_pearson.gml - base network for the gene coexpression network
- full_gene_coexpression_network.png - figure showing the full gene coexpression network

These files can be downloaded from [here](https://datasync.ed.ac.uk/index.php/s/0DDNSGC4YHv0NMi) with password: 'pbi2024'

In [1]:
'''Biomedical Networks'''
# standard libraries
import os
import pickle

# scientific and data manipulation libraries
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.feature_selection import mutual_info_regression
# astropy is a library for astronomy and astrophysics! but has some very nice statistical tools
import astropy
from astropy.stats import median_absolute_deviation
# mygene is a library for querying gene information (though you could use eUtils etc.)
import mygene

# graph and network libraries
import networkx as nx

# visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.io as pio
from IPython.display import Image
from IPython.display import display

# some deprecation warnings that are safe to ignore can be silenced using the warnings library
import warnings
warnings.filterwarnings('ignore')

We're going to be looking at some gene expression data from The Cancer Genome Atlas (TCGA).

Note that it has multi-modal data - we will look at this in later lectures. We're going to concentrate on gene expression in this notebook.

- **Title**: The Cancer Genome Atlas Lung Adenocarcinoma (TCGA-LUAD)
- **Main Focus**: Study of lung adenocarcinoma (a common type of lung cancer)
- **Data Collected**: Genomic, epigenomic, transcriptomic, and proteomic data from lung adenocarcinoma samples
- **Disease Types**:
  - Acinar Cell Neoplasms
  - Adenomas and Adenocarcinomas
  - Cystic, Mucinous, and Serous Neoplasms
- **Number of Cases**: 585 (498 with transcriptomic data)
- **Data Accessibility**: Available on the NIH-GDC Data Portal

- **Link**: [TCGA-LUAD Project Page](https://portal.gdc.cancer.gov/projects/TCGA-LUAD)

In [2]:
# I will use these data locations for the session but you should download the files
# using the link above and change the paths below to the correct location on your machine
raw_data_dir = './data/data/raw'
intermediate_data_dir = './data/data/intermediate'

In [None]:
# load the gene expression data from a pickle file
'''a pickle file is a serialized python object that can be saved to disk and loaded back into memory
these can be very useful for sharing python objects. In this case we have a dictionary with the gene 
expression data'''

with open(os.path.join(raw_data_dir,"ISMB_TCGA_GE.pkl"), 'rb') as file:
    data = pickle.load(file)

# print the keys of the dictionary
print(data.keys())
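As a minimal sketch of the pickle round trip described above: the toy dictionary and the file name `toy.pkl` are invented for illustration and are not the course data, although the keys mirror the `datExpr` and `datMeta` entries used later in this notebook.

```python
import os
import pickle
import tempfile

import pandas as pd

# Hypothetical stand-in for the course dictionary: two small DataFrames under named keys
toy = {
    "datExpr": pd.DataFrame({"geneA": [1.0, 2.0], "geneB": [3.0, 4.0]}, index=["s1", "s2"]),
    "datMeta": pd.DataFrame({"patient": ["p1", "p2"]}, index=["s1", "s2"]),
}

# Serialise the object to disk, then load it back into memory
with tempfile.TemporaryDirectory() as tmp_dir:
    toy_path = os.path.join(tmp_dir, "toy.pkl")
    with open(toy_path, "wb") as fh:
        pickle.dump(toy, fh)
    with open(toy_path, "rb") as fh:
        restored = pickle.load(fh)

print(restored.keys())
print(restored["datExpr"].equals(toy["datExpr"]))
```

Because `pickle.load` can execute arbitrary code while deserialising, only load pickle files from sources you trust.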

In order to construct a biological network, we are going to first:
- examine the TCGA metadata 
- come up with useful strategies to tackle the large data size 
- create the basis of a biological network

We're going to first familiarise ourselves with the data by looking at the meta-data that comes with the gene expression data

In [None]:
# show the first few rows of the gene expression meta-data
data["datMeta"]

In [None]:
# Count the number of unique patient identifiers in the 'patient' column of the 'datMeta' DataFrame
data["datMeta"]["patient"].unique().size

In [None]:
# Count the occurrences of each unique value in the 'sample_type' column of the 'datMeta' DataFrame
data["datMeta"]['sample_type'].value_counts()

We are going to visualise various metadata attributes such as race, gender, sample type, cigarettes per day, and smoking status by gender.

In [None]:
# Set up the figure and axes for a 2-column layout
fig, axes = plt.subplots(3, 2, figsize=(18, 18))
fig.suptitle('Metadata Distributions', fontsize=20, y=0)

# Plot 1: Distribution of Race
sns.countplot(ax=axes[0, 0], x='race', data=data['datMeta'], palette='viridis')
axes[0, 0].set_title('Distribution of Race')
axes[0, 0].set_xlabel('Race')
axes[0, 0].set_ylabel('Count')
axes[0, 0].tick_params(axis='x', rotation=45)

# Plot 2: Gender Distribution
sns.countplot(ax=axes[0, 1], x='gender', data=data['datMeta'], palette='magma')
axes[0, 1].set_title('Gender Distribution')
axes[0, 1].set_xlabel('Gender')
axes[0, 1].set_ylabel('Count')

# Plot 3: Sample Type Distribution
sns.countplot(ax=axes[1, 0], x='sample_type', data=data['datMeta'], palette='plasma')
axes[1, 0].set_title('Sample Type Distribution')
axes[1, 0].set_xlabel('Sample Type')
axes[1, 0].set_ylabel('Count')
axes[1, 0].tick_params(axis='x', rotation=45)

# Plot 4: Distribution of Cigarettes Per Day
sns.histplot(ax=axes[1, 1], data=data['datMeta']['cigarettes_per_day'], kde=True, color='blue')
axes[1, 1].set_title('Distribution of Cigarettes Per Day')
axes[1, 1].set_xlabel('Cigarettes Per Day')
axes[1, 1].set_ylabel('Frequency')

# Plot 5: Smoking Status by Gender
sns.countplot(ax=axes[2, 0], x='Smoked', hue='gender', data=data['datMeta'], palette='coolwarm')
axes[2, 0].set_title('Smoking Status by Gender')
axes[2, 0].set_xlabel('Smoking Status')
axes[2, 0].set_ylabel('Count')
axes[2, 0].legend(title='Gender')

axes[2, 1].axis('off')
plt.tight_layout()
plt.show()

This dataset contains gene expression levels for various samples, identified by their TCGA (The Cancer Genome Atlas) codes.  
Each row represents a different sample, while each column represents a different gene, identified by its Ensembl gene ID.  
The values in the table are the expression levels of the genes for each sample.

In [None]:
# let's create a new variable for the expression data alone, just for ease of use, and then inspect it
expression_data = data["datExpr"]
expression_data

At 498 rows × 22637 columns this is quite a large matrix, so we should consider reducing its size. This could be done by removing columns (genes) or by sub-setting the patient samples.

We could:
- Filter by Mean Expression: Select genes with high mean expression levels
- Filter by Variance: Select genes with high variance across samples, as low variance genes might not contribute significantly to the analysis.
- Analyse differentially expressed genes and keep ones that meet some criteria for fold-change and significance. What to compare?

We are doing this to make sure that  
- our computations are efficient,  
- the network complexity is manageable,  
- the biological signal is enhanced, so that our analysis is biologically relevant.
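The third filtering option listed above (differential expression) is not used in this notebook, but a rough sketch may help to make it concrete. Everything here is an assumption for illustration: the data are synthetic, the two groups are hypothetical "normal" and "tumour" samples, the values are treated as log2-scale, and the cut-offs of 1.0 and 0.05 are arbitrary.

```python
import numpy as np
import pandas as pd
from scipy import stats

rng = np.random.default_rng(0)

# Synthetic log2-scale expression: 20 "normal" and 20 "tumour" samples, 5 genes
genes = [f"gene_{i}" for i in range(5)]
normal = pd.DataFrame(rng.normal(8.0, 1.0, size=(20, 5)), columns=genes)
tumour = pd.DataFrame(rng.normal(8.0, 1.0, size=(20, 5)), columns=genes)
tumour["gene_0"] += 3.0  # make one gene clearly up-regulated in the tumour group

# On a log2 scale the difference of the group means is the log2 fold-change
log2_fc = tumour.mean(axis=0) - normal.mean(axis=0)

# Welch's t-test for each gene (column), not assuming equal variances
result = stats.ttest_ind(tumour.to_numpy(), normal.to_numpy(), axis=0, equal_var=False)
p_values = pd.Series(result.pvalue, index=genes)

# Keep genes that pass both (arbitrary) cut-offs
keep = (log2_fc.abs() >= 1.0) & (p_values < 0.05)
kept_genes = keep[keep].index.tolist()
print(kept_genes)
```

A real analysis would also correct the p-values for multiple testing, and would need a defensible choice of which groups to compare.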


In [9]:
# A few preliminary steps that might be useful for data cleaning and preprocessing
# Ensure all columns are numeric
expression_data = expression_data.apply(pd.to_numeric, errors='coerce')

# Drop columns that could not be converted to numeric (if any)
expression_data = expression_data.dropna(axis=1, how='all')

In [None]:
# In case we want to check the shape of the data further down the line
expression_data.shape

# note nothing has changed in the data, but we have ensured that all columns are numeric

In [None]:
# Checking for duplicate rows and columns
print(f"Number of duplicate indices: {expression_data.index.duplicated().sum()}")  
print(f"Number of duplicate columns: {expression_data.columns.duplicated().sum()}") 

In [None]:
# Plot the distribution of gene expression levels 

# Calculate the mean expression level for each gene
gene_means = expression_data.mean(axis=0)

# Plot the distribution of gene expression levels
plt.figure(figsize=(10, 6))
sns.histplot(gene_means, bins=50, kde=True)
plt.xlabel('Mean Gene Expression')
plt.ylabel('Frequency')
plt.title('Distribution of Gene Expression Levels')

# mean of the gene means
threshold = gene_means.mean()
plt.axvline(threshold, color='red', linestyle='--', label=f'Mean = {threshold:.2f}')
plt.legend()

plt.show()

The histogram shows the frequency of genes at different mean expression levels.  
The dashed red line shows the mean of the mean gene expression (8.11).  
The histogram shows a bimodal distribution - does this mean there are two groups of genes?  
Some genes are consistently expressed at lower levels (housekeeping genes?), while others are expressed at higher levels. 
  
We could use this for thresholding, however it is not informative about the variability of the genes.

Still, let's inspect how many genes would be retained in our dataset at different expression thresholds.

In [13]:
# Filter out low-expressed genes from the dataset
# As we are going to explore effects of different thresholds, we will create a function for this
def filter_low_expression_genes(data, threshold=1.0):
    """
    Filter out low-expressed genes from the dataset.

    Calculates the mean expression level for each gene and filters out
    genes whose mean expression level is below the specified threshold.

    Parameters:
    data (DataFrame): Expression data with genes as columns.
    threshold (float): Minimum mean expression level to retain a gene.
                       Default is 1.0.

    Returns:
    DataFrame: Filtered data with genes above the threshold.
    """
    # Calculate the mean expression for each gene
    gene_means = data.mean(axis=0)
    # Filter out genes with mean expression below the threshold
    mask = gene_means >= threshold
    filtered_data = data.loc[:, mask]
    return filtered_data

In [None]:
# Gene Retention at Various Thresholds
# Define a range of thresholds
thresholds = np.arange(0, 15, 0.5)

# List to store the number of genes retained at each threshold
num_genes = []

# Count the genes retained in expression_data at each mean expression threshold
for threshold in thresholds:
    df_filtered = filter_low_expression_genes(expression_data, threshold)
    num_genes.append(df_filtered.shape[1])

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(thresholds, num_genes, marker='o')
plt.xlabel('Threshold')
plt.ylabel('Number of Genes Retained')
plt.title('Number of Genes Retained at Different Thresholds')
plt.grid(True)
plt.show()

In [15]:
# Filter out genes based on their variance
# As we are going to explore effects of variance, we will create a function for this
def filter_high_variance_genes(data, threshold):
    """
    Filter out genes with variance below the specified threshold.

    Calculates the variance for each gene and filters out genes whose 
    variance is below the specified threshold.

    Parameters:
    data (DataFrame): Gene expression data with genes as columns and samples as rows.
    threshold (float): Minimum variance level to retain a gene.

    Returns:
    DataFrame: Filtered data with genes having variance above the threshold.
    """

    # Calculate the variance for each gene (column)
    gene_variances = data.var(axis=0)
    # Create a boolean mask to filter out genes with variance below the threshold
    mask = gene_variances >= threshold
    # Apply the mask to filter the DataFrame
    filtered_data = data.loc[:, mask]
    return filtered_data

Visualise the gene retention at different variance thresholds.  
 
Calculate the 75th percentile of the variance distribution, and use it as a threshold.  
It is a good balance between retaining enough data for meaningful analysis and removing low-variance noise.
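To illustrate what a 75th percentile cut-off does, here is a small sketch on synthetic data (the table `toy_expr` and its gene names are made up). With eight genes and distinct variances, keeping those at or above the 75th percentile retains the two most variable genes, i.e. the top quarter.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)

# Synthetic data: 50 samples x 8 genes, each gene drawn with a different spread
spreads = np.array([0.1, 0.2, 0.3, 0.5, 1.0, 1.5, 2.0, 3.0])
toy_expr = pd.DataFrame(
    rng.normal(0.0, 1.0, size=(50, 8)) * spreads,
    columns=[f"gene_{i}" for i in range(8)],
)

# Per-gene variance and the 75th percentile of those variances
toy_var = toy_expr.var(axis=0)
cutoff = np.percentile(toy_var, 75)

# Keep only genes whose variance is at or above the cut-off
high_var = toy_expr.loc[:, toy_var >= cutoff]
print(f"{high_var.shape[1]} of {toy_expr.shape[1]} genes retained")
```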

In [None]:
# Define a range of variance thresholds
variance_thresholds = np.arange(0, 10, 0.5)

# List to store the number of genes retained at each threshold
num_genes = []

# Count the genes retained in expression_data at each variance threshold
for threshold in variance_thresholds:
    df_filtered = filter_high_variance_genes(expression_data, threshold)
    num_genes.append(df_filtered.shape[1])

# Calculate the variance for each gene
gene_variances = expression_data.var(axis=0)
# Calculate the 75th percentile of the variance distribution
variance_threshold = np.percentile(gene_variances, 75)
print(f"Chosen Variance Threshold: {variance_threshold }")

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(variance_thresholds, num_genes, marker='o')
plt.axvline(variance_threshold, color='red', linestyle='--', label=f'75th Percentile Threshold = {variance_threshold:.2f}')
plt.xlabel('Variance')
plt.ylabel('Number of Genes Retained')
plt.legend()
plt.title('Number of Genes Retained at Different Variance Thresholds')
plt.grid(True)
plt.show()

- Let's assign this value as the threshold to focus on genes with higher variance.
- Part of the rationale here is that if genes aren't varying much between samples they are not likely to be differentially expressed based on any of the factors that vary between patients.
- Note that this is not the same as differential expression analysis, as it would retain more genes, including those that are not well estimated, i.e. those whose variance is due to technical rather than biological variation.
- This may be interesting to us because one argument we may have for making a network is that such genes would not consistently correlate with any given other set of genes (biologically) and so would not connect strongly on the network.
- Overall this is a 'softer' assumption than GXD and allows more opportunity for "discovery".

In [None]:
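# Assign the filtered data to a new variable
# The aim is to keep the top quartile of genes by variance (at or above the 75th percentile)
# NB the threshold is hard-coded to 1.2 here; pass variance_threshold to use the value calculated above
df_filtered_variance = filter_high_variance_genes(expression_data, threshold = 1.2)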
print(f"Filtered data shape: {df_filtered_variance.shape}")  # Check the new shape to confirm filtering

You may have noticed that the column headers are not gene names, so we're going to fix that by mapping (as we have done before in the course). You could use eUtils to do this or even bulk download the meta-data and use table merging, but we're going to use a nice package called "mygene" - (https://docs.mygene.info/projects/mygene-py/en/latest/)

Converting Ensembl gene IDs (ENSG) to HGNC (HUGO Gene Nomenclature Committee) gene symbols is often good practice as HGNC is an international standard.
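Two details of the mapping are easy to try offline: stripping the version suffix from Ensembl IDs, and collapsing columns that end up with the same symbol. The sketch below uses a hand-written lookup dictionary (`id_to_symbol`) as a stand-in for the mygene response. The version suffixes are illustrative and the third ID is deliberately fictitious so that it can map to a duplicate symbol.

```python
import pandas as pd

# Toy expression table with versioned Ensembl IDs as column names
toy = pd.DataFrame(
    [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]],
    columns=["ENSG00000141510.16", "ENSG00000146648.17", "ENSG00000999999.1"],
)

# Strip the version suffix so the IDs match the unversioned form used in lookups
toy.columns = toy.columns.str.split(".").str[0]

# Hypothetical lookup table standing in for the mygene response;
# the third ID is made up and maps to an already-used symbol on purpose
id_to_symbol = {
    "ENSG00000141510": "TP53",
    "ENSG00000146648": "EGFR",
    "ENSG00000999999": "EGFR",
}
toy.columns = [id_to_symbol[col] for col in toy.columns]

# Collapse duplicate symbols by averaging their columns
collapsed = toy.T.groupby(level=0).mean().T
print(collapsed)
```

Averaging is only one way to resolve duplicates; keeping the most variable or most highly expressed column are other options one might consider.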

In [28]:
# define a function to do the gene mapping
def rename_ensembl_to_gene_names(df, chunk_size=1000):
    """
    Renames Ensembl gene IDs to gene names using mygene.

    NB we chunk the requests to avoid hitting the rate limit.
    
    Parameters:
    df (pd.DataFrame): DataFrame with Ensembl gene IDs as columns.
    chunk_size (int): Number of Ensembl IDs to query at a time.
    
    Returns:
    pd.DataFrame: DataFrame with gene names as columns, excluding genes that couldn't be mapped.
    """
    
    # Make a copy of the DataFrame to avoid modifying the original
    df_copy = df.copy()

    # Remove the `.number` suffix from ENSG IDs
    df_copy.columns = df_copy.columns.str.split('.').str[0]

    # Initialize mygene client
    mg = mygene.MyGeneInfo()

    # Split ENSG IDs into smaller chunks
    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    ensg_ids = df_copy.columns.tolist()
    gene_mappings = []

    unmapped_genes = []

    # send requests in chunks
    for chunk in chunks(ensg_ids, chunk_size):
        result = mg.querymany(chunk, scopes='ensembl.gene', fields='symbol', species='human')
        gene_mappings.extend(result)

    # Create a mapping from ENSG to gene symbol, handle missing mappings
    ensg_to_gene = {item['query']: item.get('symbol', None) for item in gene_mappings}
    
    # Log the unmapped genes
    batch_unmapped_genes = [gene for gene in ensg_ids if ensg_to_gene.get(gene) is None]
    if batch_unmapped_genes:
        # Add unmapped genes to the list
        unmapped_genes.extend(batch_unmapped_genes)

    # Filter the DataFrame to only include columns that have been mapped
    df_filtered = df_copy.loc[:, df_copy.columns.isin(ensg_to_gene.keys())]

    # Further filter to ensure we have the same number of columns as mapped gene names
    df_filtered = df_filtered.loc[:, [ensg for ensg in df_filtered.columns if ensg_to_gene[ensg] is not None]]

    # Assign new column names
    df_filtered.columns = [ensg_to_gene[ensg] for ensg in df_filtered.columns]

    # Handle duplicate gene names by aggregating them (e.g., by taking the mean)
    df_final = df_filtered.T.groupby(df_filtered.columns).mean().T

    return df_final, set(unmapped_genes)

In [None]:
# convert the ensembl gene IDs to gene names
df_renamed,unmapped_genes = rename_ensembl_to_gene_names(df_filtered_variance)
print(f'{len(unmapped_genes)} were not mapped to gene names.')

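# a set has no defined order, so we just show one example without modifying the set
if unmapped_genes:
    print('An example of an unmapped gene is', next(iter(unmapped_genes)))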

In [None]:
# Let's check the new shape of the data after renaming
print("Shape of df_filtered_variance:", df_filtered_variance.shape)
print("Shape of df_renamed:", df_renamed.shape)

In [None]:
# let's inspect the first few rows of the renamed DataFrame
df_renamed

In [None]:
# Plot the distribution of gene expression levels 

# Calculate the mean expression level for each gene
gene_means = df_renamed.mean(axis=0)

# Plot the distribution of gene expression levels
plt.figure(figsize=(10, 6))
sns.histplot(gene_means, bins=50, kde=True)
plt.xlabel('Mean Gene Expression')
plt.ylabel('Frequency')
plt.title('Distribution of Gene Expression Levels')

# mean of the gene means
threshold = gene_means.mean()
plt.axvline(threshold, color='red', linestyle='--', label=f'Mean = {threshold:.2f}')
plt.legend()

plt.show()

In [None]:
# Sort and get the top 10 and bottom 10 columns based on their mean values

top_10_columns = gene_means.sort_values(ascending=False).head(10)
bottom_10_columns = gene_means.sort_values(ascending=True).head(10)

# Combine top 10 and bottom 10 columns into one DataFrame
combined = pd.concat([top_10_columns, bottom_10_columns])

# Plot the results
plt.figure(figsize=(14, 6))
combined.plot(kind='bar', color=['skyblue' if i < 10 else 'lightcoral' for i in range(20)])
plt.title('Top 10 and Bottom 10 Genes Based on Mean Expression Values')
plt.ylabel('Mean Value')
plt.xlabel('Gene')
plt.xticks(rotation=45, ha='right')
plt.grid(True)
plt.show()

In [None]:
# Select numerical columns
numerical_columns = df_renamed.select_dtypes(include=['float64', 'int64'])

# Calculate the mean of each column for the numerical columns only
means = numerical_columns.mean()

# Get the top 8 and bottom 8 columns based on mean values
top_8_columns = means.nlargest(8).index.tolist()
bottom_8_columns = means.nsmallest(8).index.tolist()

# Combine top 8 and bottom 8 columns for plotting
columns_to_plot = top_8_columns + bottom_8_columns

# Filter the numerical columns to only include those to plot
filtered_numerical_columns = numerical_columns[columns_to_plot]

# Calculate the number of rows needed for subplots based on the number of selected columns
num_plots = len(filtered_numerical_columns.columns)
num_rows = (num_plots // 4) + (num_plots % 4 > 0)  # Ensure there is an extra row if there are leftovers

# Plot histograms for each selected numerical column
fig, axes = plt.subplots(num_rows, 4, figsize=(20, 5 * num_rows))  # Adjust width and height as needed
fig.suptitle('Distribution of Top 8 & Bottom 8 Genes Mean Expression', fontsize=16)

for i, col in enumerate(filtered_numerical_columns.columns):
    ax = axes.flatten()[i]
    filtered_numerical_columns[col].hist(bins=15, ax=ax, color='skyblue' if i < 8 else 'lightcoral')
    ax.set_title(col)
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    
    # Adding mean and median lines
    mean_val = filtered_numerical_columns[col].mean()
    median_val = filtered_numerical_columns[col].median()
    # Label the lines directly so the legend refers to them and not to the histogram bars
    ax.axvline(mean_val, color='blue', linestyle='dashed', linewidth=1, label='Mean')
    ax.axvline(median_val, color='red', linestyle='dashed', linewidth=1, label='Median')
    ax.legend()

# Hide any unused axes if the number of plots isn't a perfect multiple of 4
if num_plots % 4:
    for ax in axes.flatten()[num_plots:]:
        ax.set_visible(False)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to make room for the suptitle
plt.show()

We now want to establish whether there are any strong relationships between the expression levels of the genes in the different samples. We can do this by calculating the correlation between their expression values across samples. We can use this correlation matrix as an adjacency matrix to build a network.

- nodes: genes  
- edges: highly correlated genes (above a given threshold)
- edge-weights: correlation values

There are a few correlation metrics one could consider:
- [Pearson](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)  
  - Measures linear association. Computing all gene pairs scales as O(n^2) in the number of genes n, and it is fast for large datasets
- [Spearman](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient)
  - Pearson correlation computed on ranks, so it captures monotonic relationships. The values of each gene must first be sorted to obtain ranks, so it is typically somewhat slower than Pearson
- [Absolute biweight midcorrelation](https://en.wikipedia.org/wiki/Biweight_midcorrelation)
  - Median-based and robust but slower than Pearson and Spearman, suitable for datasets with outliers
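As a small illustration of how the choice of metric matters, the sketch below compares Pearson and Spearman on two made-up "genes" whose relationship is perfectly monotonic but not linear. The data are synthetic and chosen purely to make the difference visible.

```python
import numpy as np
import pandas as pd

# Two toy genes: gene_y always increases with gene_x, but not along a straight line
x = np.arange(1, 11, dtype=float)
toy = pd.DataFrame({"gene_x": x, "gene_y": x ** 3})

# Pearson measures linear association, Spearman measures agreement of ranks
pearson_r = toy.corr(method="pearson").loc["gene_x", "gene_y"]
spearman_r = toy.corr(method="spearman").loc["gene_x", "gene_y"]

print(f"Pearson:  {pearson_r:.3f}")
print(f"Spearman: {spearman_r:.3f}")
```

Spearman reports a perfect correlation because the ranks agree exactly, while Pearson is lower because the points do not fall on a straight line. Which behaviour is preferable depends on the kind of co-expression you want an edge in the network to represent.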


In [37]:
# function to calculate absolute biweight midcorrelation 
def calc_abs_bicorr(data):
    """
    Calculate the absolute biweight midcorrelation matrix for numeric data.

    Parameters:
    data (pd.DataFrame): Input DataFrame with numeric data.

    Returns:
    pd.DataFrame: DataFrame containing the absolute biweight midcorrelation matrix.
    """

    # Select only numeric data
    data = data._get_numeric_data()
    cols = data.columns
    idx = cols.copy()
    mat = data.to_numpy(dtype=float, na_value=np.nan, copy=False)
    mat = mat.T

    K = len(cols)
    correl = np.empty((K, K), dtype=np.float32)

    # Calculate biweight midcovariance
    bicorr = astropy.stats.biweight_midcovariance(mat, modify_sample_size=True)

    for i in range(K):
        for j in range(K):
            if i == j:
                correl[i, j] = 1.0
            else:
                denominator = np.sqrt(bicorr[i, i] * bicorr[j, j])
                if denominator != 0:
                    correl[i, j] = bicorr[i, j] / denominator
                else:
                    correl[i, j] = 0  # Or handle it in another appropriate way

    return pd.DataFrame(data=np.abs(correl), index=idx, columns=cols, dtype=np.float32)

We're going to use pre-computed correlation matrices as it takes a while to calculate. We've left the code to do this commented out below in case you would like to try it later.

In [None]:
# # Dictionary to store different correlation matrices
# correlation_matrices = {}

# # Pearson correlation - fast for large datasets
# correlation_matrices['pearson'] = df_renamed.corr(method='pearson')

# # Spearman rank correlation - relatively fast but can be slower than Pearson as values must be ranked first
# correlation_matrices['spearman'] = df_renamed.corr(method='spearman')

# # Biweight midcorrelation -  Robust but slower than Pearson and Spearman, suitable for datasets with outliers
# # NB the function defined above is called calc_abs_bicorr
# correlation_matrices['biweight_midcorrelation'] = calc_abs_bicorr(df_renamed)

# # Print the keys of the correlation matrices to verify
# print("Correlation matrices calculated:")
# print(correlation_matrices.keys())


# # Save the entire dictionary of correlation matrices as a pickle file
# with open(os.path.join(intermediate_data_dir,"correlation_matrices.pkl"), 'wb') as f:
#     pickle.dump(correlation_matrices, f)


In [None]:
# Load the entire dictionary of correlation matrices from a pickle file
with open(os.path.join(intermediate_data_dir,"correlation_matrices.pkl"), 'rb') as f:
    correlation_matrices = pickle.load(f)

# Verify the loaded data where we have each of the correlation matrices stored
print(correlation_matrices.keys())

In [40]:
# Plot the correlation matrices as heatmaps
from scipy.cluster.hierarchy import linkage, leaves_list

# here we have a function to plot the correlation matrices as heatmaps
# this is time consuming so we will not run it here but will load the pre-saved images
def plot_correlation_matrices(correlation_matrices):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # Adjust to 1 row and 3 columns
    axes = axes.flatten()
    
    for i, (key, matrix) in enumerate(correlation_matrices.items()):
        # Perform hierarchical clustering
        Z = linkage(matrix, method='average')  # You can use other methods like 'single', 'complete', etc.
        idx = leaves_list(Z)
        
        # Reorder matrix
        ordered_matrix = matrix.iloc[idx, :].iloc[:, idx]
        
        # Plot heatmap
        sns.heatmap(ordered_matrix, ax=axes[i], cmap='coolwarm', cbar=True, xticklabels=False, yticklabels=False)
        axes[i].set_title(f'{key.capitalize()} Correlation Matrix')
        
        # Set square aspect ratio
        axes[i].set_aspect('equal', adjustable='box')
    
    # Hide any unused subplots
    for j in range(i + 1, len(axes)):
        axes[j].set_visible(False)
    
    plt.tight_layout()
    plt.show()


In [None]:
# use the function to plot the correlation matrices
# plot_correlation_matrices(correlation_matrices)

In [None]:
# Display the image of the correlation matrices
display(Image(filename=os.path.join(intermediate_data_dir,
                                    "figures",
                                    "correlation_matrices_figure.png")))

We are going to define a function `create_graph_from_correlation` to make networks from correlation matrices.

The function starts by creating an empty graph G. Then iterates through the columns of the correlation matrix and adds each column name as a node in the graph. This means each gene (or feature) in your dataset becomes a node in the graph.

The function iterates over the upper triangle of the correlation matrix (excluding the diagonal) to avoid redundancy and self-loops. Remember that this is an undirected graph, so the matrix is symmetric.

For each pair of nodes (i, j), it checks if the absolute value of the correlation coefficient between them is greater than or equal to the specified threshold.

If the condition is met, an edge is added between the nodes i and j with the correlation coefficient as the weight of the edge. This signifies a strong correlation (positive or negative) between the two nodes.
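The same steps can be tried on a tiny hand-made correlation matrix. The sketch below is an alternative formulation rather than the function used in this notebook: it selects the upper triangle with `np.triu_indices` instead of nested loops. The gene labels `A` to `D` and the correlation values are invented.

```python
import networkx as nx
import numpy as np
import pandas as pd

# Toy symmetric correlation matrix for four hypothetical genes
genes = ["A", "B", "C", "D"]
toy_corr = pd.DataFrame(
    [[1.00, 0.90, -0.85, 0.10],
     [0.90, 1.00, 0.20, 0.30],
     [-0.85, 0.20, 1.00, 0.05],
     [0.10, 0.30, 0.05, 1.00]],
    index=genes, columns=genes,
)
threshold = 0.8

G = nx.Graph()
G.add_nodes_from(toy_corr.columns)  # every gene becomes a node

# Row/column indices of the upper triangle; k=1 skips the diagonal
rows, cols = np.triu_indices(len(genes), k=1)
values = toy_corr.to_numpy()[rows, cols]

# Keep pairs whose absolute correlation reaches the threshold
strong = np.abs(values) >= threshold
G.add_weighted_edges_from(
    (genes[i], genes[j], float(w))
    for i, j, w in zip(rows[strong], cols[strong], values[strong])
)

print(sorted(G.edges(data="weight")))
print("Nodes without edges:", list(nx.isolates(G)))
```

Note that the strongly negative pair (A, C) becomes an edge with a negative weight, and that gene D stays in the graph as a node without any edges.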

In [30]:
# Create a graph from the correlation matrix using a specified threshold
def create_graph_from_correlation(correlation_matrix, threshold=0.8):
    """
    Creates a graph from a correlation matrix using a specified threshold.

    Parameters:
    correlation_matrix (pd.DataFrame): DataFrame containing the correlation matrix.
    threshold (float): Threshold for including edges based on correlation value.

    Returns:
    G (nx.Graph): Graph created from the correlation matrix.
    """
    G = nx.Graph()

    # Add nodes
    for node in correlation_matrix.columns:
        G.add_node(node)

    # Add edges with weights above the threshold
    for i in range(correlation_matrix.shape[0]):
        for j in range(i + 1, correlation_matrix.shape[1]):
            if i != j:  # Ignore the diagonal elements
                weight = correlation_matrix.iloc[i, j]
                if abs(weight) >= threshold:
                    G.add_edge(correlation_matrix.index[i], correlation_matrix.columns[j], weight=weight)

    return G

To save some time we will load the graph from a `.gml` file. GML is one of several commonly used data exchange file formats for graph data. NetworkX can export graphs in this and several other formats (see the NetworkX documentation for further detail).
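A minimal sketch of a GML round trip, using a made-up three-node graph and a temporary file rather than the course network:

```python
import os
import tempfile

import networkx as nx

# Toy graph: one weighted edge plus a node with no edges (the weight is invented)
toy_graph = nx.Graph()
toy_graph.add_edge("TP53", "EGFR", weight=0.91)
toy_graph.add_node("KRAS")

# Write the graph to a GML file and read it back
with tempfile.TemporaryDirectory() as tmp_dir:
    gml_path = os.path.join(tmp_dir, "toy_network.gml")
    nx.write_gml(toy_graph, gml_path)
    reloaded = nx.read_gml(gml_path)

print(sorted(reloaded.nodes))
print(list(reloaded.edges(data=True)))
```

Node names, edges and the `weight` attribute all survive the round trip, including the node that has no edges.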

In [53]:
## Create a graph from the Pearson correlation matrix with a threshold of 0.8
# pearson_graph = create_graph_from_correlation(correlation_matrices['pearson'], threshold=0.8)
# nx.write_gml(pearson_graph, os.path.join(intermediate_data_dir,'gene_coexpression_network_pearson.gml'))

In [43]:
# load the graph from the gml file
pearson_graph = nx.read_gml(os.path.join(intermediate_data_dir,'gene_coexpression_network_pearson.gml'))


Now let's go through a few useful NetworkX functions and create a `print_graph_info()` function.
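Before applying these functions to the real network it can help to see them on a graph small enough to check by hand. The toy graph below is invented for illustration: a triangle, a separate pair of nodes and one isolated node.

```python
import networkx as nx

# Toy graph: a triangle (A, B, C), a separate pair (D, E) and an isolated node F
toy_graph = nx.Graph()
toy_graph.add_edges_from([("A", "B"), ("B", "C"), ("A", "C"), ("D", "E")])
toy_graph.add_node("F")

n = toy_graph.number_of_nodes()  # 6 nodes
m = toy_graph.number_of_edges()  # 4 edges

# Density of an undirected graph: edges present / edges possible = 2m / (n(n-1))
manual_density = 2 * m / (n * (n - 1))
print(manual_density, nx.density(toy_graph))

# Three connected components: {A, B, C}, {D, E} and {F}
n_components = nx.number_connected_components(toy_graph)
print(n_components)

# A, B and C each have clustering 1.0 and the other three nodes 0.0, so the average is 0.5
avg_clustering = nx.average_clustering(toy_graph)
print(avg_clustering)
```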

In [44]:
# Print basic information about the graph
def print_graph_info(G):
    """
    Print basic information about a NetworkX graph.

    
    Parameters:
    G (nx.Graph): The NetworkX graph.
    """
    print(f"Number of nodes: {G.number_of_nodes()}")
    print(f"Number of edges: {G.number_of_edges()}")
    print("Sample nodes:", list(G.nodes)[:10])  # Print first 10 nodes as a sample
    print("Sample edges:", list(G.edges(data=True))[:10])  # Print first 10 edges as a sample
    
    info_str = "Graph type: "
    is_directed = G.is_directed()
    if is_directed:
        info_str += "directed"
    else:
        info_str += "undirected"
    print(info_str)

    # Check for self-loops
    self_loops = list(nx.selfloop_edges(G))
    if self_loops:
        print(f"Number of self-loops: {len(self_loops)}")
        print("Self-loops:", self_loops)
    else:
        print("No self-loops in the graph.")

    # density of the graph
    density = nx.density(G)
    print(f"Graph density: {density}")

    # Find and print the number of connected components
    num_connected_components = nx.number_connected_components(G)
    print(f"Number of connected components: {num_connected_components}")

    # Calculate and print the clustering coefficient of the graph
    clustering_coeff = nx.average_clustering(G)
    print(f"Average clustering coefficient: {clustering_coeff}")

In [None]:
print_graph_info(pearson_graph)

In [46]:
# Function to visualize the graph
def visualise_graph(G, title='Gene Co-expression Network'):
    """
    Visualizes the graph using Matplotlib and NetworkX.

    Parameters:
    G (nx.Graph): Graph to visualize.
    title (str): Title of the plot.
    """
    plt.figure(figsize=(10, 10))
    pos = nx.spring_layout(G, k=0.1)  # k controls the distance between nodes
    nx.draw_networkx_nodes(G, pos, node_size=50, node_color='blue', alpha=0.7)
    nx.draw_networkx_edges(G, pos, width=0.2, alpha=0.5)
    plt.title(title)
    plt.show()

As this is time consuming to generate we will load a pre-generated image of the graph

In [None]:
# # Visualize the graph
# visualise_graph(pearson_graph, title='Pearson Correlation Network (Threshold = 0.8)')

In [None]:
# Display the image of the full graph
display(Image(filename=os.path.join(intermediate_data_dir,
                                    "figures",
                                    "full_gene_coexpression_network.png")))

We now have the base gene correlation network, but we can see that there are a lot of orphans (due to the threshold filtering), so we need to clean the network up. We can use functions from NetworkX for this.
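The two main cleaning steps, removing orphan nodes and optionally keeping only the largest connected component, can be sketched on a toy graph first. The node names are invented.

```python
import networkx as nx

# Toy graph: a triangle, a separate pair and two orphan nodes
toy_graph = nx.Graph()
toy_graph.add_edges_from([("A", "B"), ("B", "C"), ("A", "C"), ("D", "E")])
toy_graph.add_nodes_from(["F", "G"])  # orphans with no edges

# Step 1: drop the orphan nodes (work on a copy so the original is untouched)
no_orphans = toy_graph.copy()
no_orphans.remove_nodes_from(list(nx.isolates(no_orphans)))

# Step 2 (optional): keep only the largest connected component
largest_cc = max(nx.connected_components(no_orphans), key=len)
largest = no_orphans.subgraph(largest_cc).copy()

print(sorted(no_orphans.nodes))
print(sorted(largest.nodes))
```

Step 2 is a much stronger filter than step 1: here it also discards the connected pair (D, E), so whether to apply it depends on whether small separate modules are of interest.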

In [48]:
# Function to clean the graph
def clean_graph(G, degree_threshold=1, keep_largest_component=True):
    """
    Cleans the graph by performing several cleaning steps:
    - Removes unconnected nodes (isolates)
    - Removes self-loops
    - Removes nodes with a degree below a specified threshold
    - Keeps only the largest connected component (optional)

    Parameters:
    G (nx.Graph): The NetworkX graph to clean.
    degree_threshold (int): Minimum degree for nodes to keep.
    keep_largest_component (bool): Whether to keep only the largest connected component.

    Returns:
    G (nx.Graph): Cleaned graph.
    """
    G = G.copy()  # Work on a copy of the graph to avoid modifying the original graph

    # Remove self-loops
    G.remove_edges_from(nx.selfloop_edges(G))

    # Remove nodes with no edges (isolates)
    G.remove_nodes_from(list(nx.isolates(G)))

    # Remove nodes with degree below the threshold
    low_degree_nodes = [node for node, degree in dict(G.degree()).items() if degree < degree_threshold]
    G.remove_nodes_from(low_degree_nodes)

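    # Keep only the largest connected component (skipped if no nodes remain,
    # because max() raises a ValueError on an empty sequence)
    if keep_largest_component and G.number_of_nodes() > 0:
        largest_cc = max(nx.connected_components(G), key=len)
        G = G.subgraph(largest_cc).copy()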

    return G

In [58]:
# Clean the graph by removing unconnected nodes
pearson_graph_cleaned = clean_graph(pearson_graph,
                                    degree_threshold=1,
                                    keep_largest_component=False)

In [None]:
# view the cleaned graph
# NB this is now quick to render as the graph is much smaller
visualise_graph(pearson_graph_cleaned, title='Pearson Correlation Network - Cleaned')

In [None]:
# we can re-use the function to print the graph information
print_graph_info(pearson_graph_cleaned)

In [60]:
# Clean the graph by keeping only the largest connected component
pearson_graph_pruned = clean_graph(pearson_graph,
                                    degree_threshold=1,
                                    keep_largest_component=True)

In [None]:
visualise_graph(pearson_graph_pruned, title='Pearson Correlation Network - Pruned')

In [None]:
# we can re-use the function to print the graph information
print_graph_info(pearson_graph_pruned)

In [79]:
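# Function to plot the degree distribution of a graph
def plot_power_law_distribution(G):
    """
    Plots the degree distribution of the graph on log-log axes,
    together with a straight-line (power-law) fit.

    Parameters:
    G (nx.Graph): Graph to analyse.
    """
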
    # Compute the degree of each node
    G_degrees = [degree for _, degree in G.degree()]

    # Calculate the degree frequency distribution
    G_degree_counts = np.bincount(G_degrees)
    G_degree_values = np.nonzero(G_degree_counts)[0]  # degrees with at least one node
    G_degree_values = G_degree_values[G_degree_values > 0]  # log(0) is undefined, so skip degree 0
    G_degree_probabilities = G_degree_counts[G_degree_values] / G_degree_counts.sum()
    
    # Plot the power law distribution
    plt.figure(figsize=(10, 10))

    # plot the scatter for G
    plt.scatter(G_degree_values, G_degree_probabilities, color="blue", edgecolor="black", s=50, alpha=0.7)

    # fit a straight line in log-log space (slope = estimated power-law exponent)
    fit = stats.linregress(np.log(G_degree_values), np.log(G_degree_probabilities))

    # plot the fitted line for G and show its slope in the legend
    plt.plot(G_degree_values, np.exp(fit.intercept) * G_degree_values ** fit.slope, label=f'G: {fit.slope:.2f}', color="blue")
    plt.legend()
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("Degree (log scale)")
    plt.ylabel("Probability (log scale)")
    plt.title("Degree Distribution in Log-Log Scale (Power Law)")
    plt.show()    

In [None]:
# Plot the degree distribution of the pruned graph
plot_power_law_distribution(pearson_graph_pruned)
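
To see what the fitted slope looks like on a graph with a known heavy-tailed degree distribution, the sketch below repeats the same log-log regression on a synthetic Barabási-Albert graph, without plotting. The graph size, `m` and the seed are arbitrary choices for illustration. A least-squares fit on log-log axes is only a rough check and does not by itself establish that a distribution follows a power law.

```python
import networkx as nx
import numpy as np
from scipy import stats

# Synthetic preferential-attachment graph; parameters are illustrative only
ba_graph = nx.barabasi_albert_graph(n=2000, m=2, seed=0)

# Degree frequency distribution, as in plot_power_law_distribution
degrees = [d for _, d in ba_graph.degree()]
counts = np.bincount(degrees)
values = np.nonzero(counts)[0]
values = values[values > 0]  # log(0) is undefined
probabilities = counts[values] / counts.sum()

# Straight-line fit in log-log space; the slope estimates the exponent
fit = stats.linregress(np.log(values), np.log(probabilities))
print(f"Fitted slope: {fit.slope:.2f}")
```

A clearly negative slope indicates that high-degree nodes are rare relative to low-degree ones, which is the pattern the plot for the pruned co-expression network can be compared against.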

In [82]:
# function to visualise the edge weight distribution
def visualise_edge_weight_distribution(G):
    """
    Visualizes the distribution of edge weights.

    Parameters:
    G (nx.Graph): Graph whose edges carry a 'weight' attribute.
    """
    plt.figure(figsize=(10, 6))
    edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
    # Histogram
    sns.histplot(edge_weights, bins=30, kde=False)
    
    plt.title('Distribution of Edge Weights')
    plt.xlabel('Edge Weight')
    plt.ylabel('Frequency')
    plt.show()

In [None]:
# Visualize the distribution of edge weights
visualise_edge_weight_distribution(pearson_graph_pruned)

With sparsification we aim to reduce the number of edges in a network while preserving important structural properties. Common approaches include:

- Edge Sampling: Randomly removes a fraction of edges.
- Thresholding: Removes edges with weights below a certain threshold.
- Degree-based Sparsification: Removes nodes whose degree is below a minimum value.
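
Edge sampling is listed above but is not among the functions defined in the next cell, so here is a minimal sketch of one way it could be done. The helper name `edge_sampling_sparsification` and its `keep_fraction` and `seed` parameters are hypothetical and are not part of the original notebook or of NetworkX.

```python
import random
import networkx as nx

def edge_sampling_sparsification(graph, keep_fraction, seed=None):
    """Hypothetical helper: keep a uniformly random fraction of the edges.

    All nodes are retained, so some may become isolates.
    """
    rng = random.Random(seed)  # local RNG so the global state is untouched
    edges = list(graph.edges(data=True))
    n_keep = int(round(len(edges) * keep_fraction))
    kept_edges = rng.sample(edges, n_keep)

    sparsified = nx.Graph()
    sparsified.add_nodes_from(graph.nodes(data=True))
    sparsified.add_edges_from(kept_edges)
    return sparsified

# Usage on a toy graph: a complete graph on 10 nodes has 45 edges
demo = nx.complete_graph(10)
sampled = edge_sampling_sparsification(demo, keep_fraction=0.2, seed=42)
print(sampled.number_of_nodes(), sampled.number_of_edges())
```

Unlike thresholding, random sampling ignores edge weights, so it is mainly useful as a baseline against which weight-aware methods can be compared.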

In [86]:
# simply remove the edges below a certain edge-weight threshold
def threshold_sparsification(graph, threshold):
    """
    Sparsifies the graph by removing edges below the specified weight threshold.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    threshold (float): The weight threshold.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph_copy.nodes(data=True))
    sparsified_graph.add_edges_from((u, v, d) for u, v, d in graph_copy.edges(data=True) if d.get('weight', 0) >= threshold)
    return sparsified_graph

# keep the specified top quantile of edges by edge-weight
def top_percentage_sparsification(graph, top_percentage):
    """
    Sparsifies the graph by keeping the top percentage of edges by weight.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    top_percentage (float): The percentage of top-weight edges to keep.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    sorted_edges = sorted(graph_copy.edges(data=True), key=lambda x: x[2].get('weight', 0), reverse=True)
    top_edges_count = max(1, int(len(sorted_edges) * (top_percentage / 100)))
    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph_copy.nodes(data=True))
    sparsified_graph.add_edges_from(sorted_edges[:top_edges_count])
    return sparsified_graph


# remove nodes with degree below a certain threshold
def remove_by_degree(graph, min_degree):
    """
    Sparsifies the graph by removing nodes with degree below the specified threshold.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    min_degree (int): The minimum degree threshold.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    nodes_to_remove = [node for node, degree in dict(graph_copy.degree()).items() if degree < min_degree]
    
    graph_copy.remove_nodes_from(nodes_to_remove)
    return graph_copy

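# use KNN sparsification to keep the top-k edges (by weight) of each node
def knn_sparsification(graph, k):
    """
    Sparsifies the graph by keeping, for each node, its k highest-weight edges.
    The result is the union of these per-node selections, so a node can retain
    more than k edges if its neighbours select it.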

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    k (int): The number of nearest neighbors to keep for each node.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph_copy.nodes(data=True))
    
    for node in graph_copy.nodes():
        edges = sorted(graph_copy.edges(node, data=True), key=lambda x: x[2].get('weight', 0), reverse=True)
        sparsified_graph.add_edges_from(edges[:k])
    
    return sparsified_graph

# create a minimum spanning tree
def spanning_tree_sparsification(graph):
    """
    Sparsifies the graph by creating a minimum spanning tree.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    return nx.minimum_spanning_tree(graph_copy, weight='weight')
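
When edge weights are similarities such as correlations, it is worth checking which edges a spanning tree actually retains. The toy sketch below (node names and weights are illustrative only) contrasts `nx.minimum_spanning_tree`, as used above, with `nx.maximum_spanning_tree`.

```python
import networkx as nx

# Toy triangle with correlation-like weights; values are illustrative only
toy = nx.Graph()
toy.add_weighted_edges_from([
    ("A", "B", 0.95),
    ("B", "C", 0.90),
    ("A", "C", 0.75),
])

mst_min = nx.minimum_spanning_tree(toy, weight="weight")
mst_max = nx.maximum_spanning_tree(toy, weight="weight")

# Normalise edge orientation so the two printouts are directly comparable
min_edges = sorted(tuple(sorted(e)) for e in mst_min.edges())
max_edges = sorted(tuple(sorted(e)) for e in mst_max.edges())
print("Minimum spanning tree edges:", min_edges)
print("Maximum spanning tree edges:", max_edges)
```

The minimum spanning tree drops the strongest edge (A-B, 0.95), whereas the maximum spanning tree drops the weakest (A-C, 0.75). Which behaviour is appropriate depends on whether the weights are read as distances or as similarities.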



In [88]:
# a function to plot the density of the graph at different thresholds
def analyse_and_plot_density(graph):
    """
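    Calculates and plots the density of the graph for a predefined series of
    thresholds (0.70 to 1.00 in steps of 0.01). Nodes left with no edge above a
    threshold are excluded from the density calculated at that threshold.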

    Parameters:
    graph (nx.Graph): The original NetworkX graph.

    Returns:
    densities (list of float): Densities of the graph at each threshold.
    """
    thresholds = [0.7 + i * 0.01 for i in range(31)]
    densities = []

    for threshold in thresholds:
        filtered_edges = [(u, v) for u, v, d in graph.edges(data=True) if d['weight'] > threshold]
        temp_graph = nx.Graph()
        temp_graph.add_edges_from(filtered_edges)
        densities.append(nx.density(temp_graph))

    # Plot the densities
    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, densities, marker='o')
    plt.xlabel('Threshold')
    plt.ylabel('Density')
    plt.title('Density vs. Threshold')
    plt.grid(True)
    plt.show()

    return densities

In [None]:
# calculate and plot the density of the graph at different thresholds
densities = analyse_and_plot_density(pearson_graph_pruned)
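
Because `analyse_and_plot_density` rebuilds each temporary graph from the surviving edges only, nodes that lose all their edges no longer count towards the density. The toy sketch below (illustrative names and weights) shows how much this choice can matter by comparing it with a version that keeps the full node set.

```python
import networkx as nx

# Toy path graph with correlation-like weights; values are illustrative only
toy = nx.Graph()
toy.add_weighted_edges_from([
    ("A", "B", 0.95),
    ("B", "C", 0.85),
    ("C", "D", 0.75),
    ("D", "E", 0.72),
])

threshold = 0.8
strong = [(u, v) for u, v, w in toy.edges(data="weight") if w > threshold]

# Variant 1: only nodes that still have an edge (as in analyse_and_plot_density)
edges_only = nx.Graph()
edges_only.add_edges_from(strong)

# Variant 2: keep every original node, including those that became isolates
fixed_nodes = nx.Graph()
fixed_nodes.add_nodes_from(toy.nodes())
fixed_nodes.add_edges_from(strong)

density_edges_only = nx.density(edges_only)    # 2 edges among 3 nodes
density_fixed_nodes = nx.density(fixed_nodes)  # 2 edges among 5 nodes
print(f"{density_edges_only:.3f} vs {density_fixed_nodes:.3f}")
```

With a fixed node set the density can only fall as the threshold rises, which makes curves easier to compare across thresholds. With the edges-only variant the denominator shrinks as well, so the curve need not be monotonic.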

In [None]:
# Initialise a dictionary to store graphs
graphs = {}
# Store the original graph for comparison
graphs['original'] = pearson_graph_pruned.copy()

# Apply sparsification methods to the original graph
graphs['threshold'] = threshold_sparsification(graphs['original'], threshold=0.82)
graphs['top_10_percent'] = top_percentage_sparsification(graphs['original'], top_percentage=10)
graphs['degree_below_3'] = remove_by_degree(graphs['original'], min_degree=3)
graphs['knn_5'] = knn_sparsification(graphs['original'], k=5)
graphs['spanning_tree'] = spanning_tree_sparsification(graphs['original'])


# Visualise the graphs after sparsification
visualise_graph(graphs['original'], 'Original Graph')
visualise_graph(graphs['threshold'], 'Thresholded Graph (weight >= 0.82)')
visualise_graph(graphs['top_10_percent'], 'Top 10% Edges by Weight')
visualise_graph(graphs['degree_below_3'], 'Nodes with Degree >= 3')
visualise_graph(graphs['knn_5'], 'K-Nearest Neighbors (k=5)')
visualise_graph(graphs['spanning_tree'], 'Minimum Spanning Tree')


In [None]:
# Let's inspect the information of the KNN sparsified graph
print_graph_info(graphs['knn_5'])
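
When reading these numbers, keep in mind that the per-node selection in `knn_sparsification` does not cap a node's degree at k. The sketch below reproduces the same selection logic inline on a toy star graph (weights are illustrative only) to show a hub keeping more than k edges.

```python
import networkx as nx

# Star graph: hub 0 connected to leaves 1..6, with made-up weights
star = nx.star_graph(6)
for i, (u, v) in enumerate(star.edges(), start=1):
    star[u][v]["weight"] = 0.8 + 0.01 * i

k = 2
knn = nx.Graph()
knn.add_nodes_from(star.nodes())

# Same per-node logic as knn_sparsification: each node adds its top-k edges
for node in star.nodes():
    ranked = sorted(star.edges(node, data=True),
                    key=lambda e: e[2]["weight"], reverse=True)
    knn.add_edges_from(ranked[:k])

hub_degree = knn.degree(0)
print(f"k = {k}, hub degree after sparsification = {hub_degree}")
```

Each leaf has only one edge and therefore always selects the hub, so the hub keeps all six edges even though k is 2. In a co-expression network the same effect means highly connected genes can remain hubs after KNN sparsification.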


In [93]:
# function to analyse the effect of different k values on the network properties
def analyse_knn_effect(graph, k_values):
    """
    Analyses the effect of different k values on the network properties.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    k_values (list): List of k values to use for sparsification.

    Returns:
    pd.DataFrame: DataFrame containing the analysis results.
    """
    results = {
        'k': [],
        'num_edges': [],
        'avg_degree': [],
        'avg_clustering': [],
        'num_connected_components': [],
    }
    
    for k in k_values:
        sparsified_graph = knn_sparsification(graph, k)
        num_edges = sparsified_graph.number_of_edges()
        avg_degree = sum(dict(sparsified_graph.degree()).values()) / sparsified_graph.number_of_nodes()
        avg_clustering = nx.average_clustering(sparsified_graph)
        num_connected_components = nx.number_connected_components(sparsified_graph)
        
        results['k'].append(k)
        results['num_edges'].append(num_edges)
        results['avg_degree'].append(avg_degree)
        results['avg_clustering'].append(avg_clustering)
        results['num_connected_components'].append(num_connected_components)
    
    return pd.DataFrame(results)

# plot the analysis of the effect of different k values on network properties
def plot_knn_analysis(df):
    """
    Plots the analysis of the effect of different k values on network properties.

    Parameters:
    df (pd.DataFrame): DataFrame containing the analysis results.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    axes[0, 0].plot(df['k'], df['num_edges'], marker='o')
    axes[0, 0].set_title('Number of Edges vs k')
    axes[0, 0].set_xlabel('k')
    axes[0, 0].set_ylabel('Number of Edges')
    
    axes[0, 1].plot(df['k'], df['avg_degree'], marker='o')
    axes[0, 1].set_title('Average Degree vs k')
    axes[0, 1].set_xlabel('k')
    axes[0, 1].set_ylabel('Average Degree')
    
    axes[1, 0].plot(df['k'], df['avg_clustering'], marker='o')
    axes[1, 0].set_title('Average Clustering Coefficient vs k')
    axes[1, 0].set_xlabel('k')
    axes[1, 0].set_ylabel('Average Clustering Coefficient')
    
    axes[1, 1].plot(df['k'], df['num_connected_components'], marker='o')
    axes[1, 1].set_title('Number of Connected Components vs k')
    axes[1, 1].set_xlabel('k')
    axes[1, 1].set_ylabel('Number of Connected Components')
    
    plt.tight_layout()
    plt.show()

In [None]:
k_values = list(range(1, 11))  # Different k values to analyze
analysis_results = analyse_knn_effect(graphs['original'], k_values)

# Plot the analysis results
plot_knn_analysis(analysis_results)

In [95]:
# function to look at the top nodes based on degree
def get_highest_degree_nodes(graph, top_n=10):
    """
    Returns the nodes with the highest degree in the graph.

    Parameters:
    graph (nx.Graph): The NetworkX graph.
    top_n (int): The number of top nodes to return.

    Returns:
    List of tuples: Each tuple contains a node and its degree.
    """
    degrees = dict(graph.degree())
    sorted_degrees = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
    return sorted_degrees[:top_n]

# gather some information about nodes using mygene
def fetch_gene_info(gene_list):
    """
    Fetches gene information from MyGene.info.

    Parameters:
    gene_list (list): List of gene symbols or Ensembl IDs.

    Returns:
    list: List of dictionaries containing gene information.
    """
    mg = mygene.MyGeneInfo()
    gene_info = mg.querymany(gene_list, scopes='symbol,ensembl.gene', 
                             fields='name,symbol,entrezgene,summary,disease,pathway', 
                             species='human')
    return gene_info

# combined function to report node information alongside gene metadata
def print_gene_info_with_degree(top_genes_with_degrees, gene_info):
    """
    Prints gene information including the degree.

    Parameters:
    top_genes_with_degrees (list): List of tuples containing gene symbols and their degrees.
    gene_info (list): List of dictionaries containing gene information.
    """
    for gene, degree in top_genes_with_degrees:
        info = next((item for item in gene_info if item['query'] == gene), None)
        # querymany marks unmatched queries with 'notfound': True
        if info and not info.get('notfound'):
            print(f"Gene Symbol: {info.get('symbol', 'N/A')}")
            print(f"Degree: {degree}")
            print(f"Gene Name: {info.get('name', 'N/A')}")
            print(f"Entrez ID: {info.get('entrezgene', 'N/A')}")
            print(f"Summary: {info.get('summary', 'N/A')}")
            if 'disease' in info:
                diseases = ', '.join([d['term'] for d in info['disease']])
                print(f"Diseases: {diseases}")
            else:
                print("Diseases: N/A")
            if 'pathway' in info:
                pathways = []
                if isinstance(info['pathway'], dict):
                    for key in info['pathway']:
                        pathway_data = info['pathway'][key]
                        if isinstance(pathway_data, list):
                            pathways.extend([p['name'] for p in pathway_data if 'name' in p])
                        elif isinstance(pathway_data, dict) and 'name' in pathway_data:
                            pathways.append(pathway_data['name'])
                        elif isinstance(pathway_data, str):
                            pathways.append(pathway_data)
                print(f"Pathways: {', '.join(pathways) if pathways else 'N/A'}")
            else:
                print("Pathways: N/A")
            print("-" * 40)
        else:
            print(f"Gene not found: {gene}")
            print(f"Degree: {degree}")
            print("-" * 40)



In [None]:
# get the top 10 genes with the highest degree in the pruned graph using get_highest_degree_nodes
top_genes_with_degrees = get_highest_degree_nodes(pearson_graph_pruned, top_n=10)
gene_symbols = [gene for gene, degree in top_genes_with_degrees]

# get gene information with fetch_gene_info
gene_info = fetch_gene_info(gene_symbols)

# print gene information including degree
print_gene_info_with_degree(top_genes_with_degrees, gene_info)


In [None]:
# Save the pruned graph as a GML file
#nx.write_gml(pearson_graph_pruned, 
#             os.path.join(intermediate_data_dir,'gene_coexpression_network_workshop.gml'))