# Gene Ontology Graph Preprocessing
Author: Cleverson Matiolli, Ph.D.

**Objective:** This notebook preprocesses the **Gene Ontology Directed Acyclic Graph (GO DAG)** for use in the Aid2GO Heterogeneous Graph Attention Network (Aid2GO-HAN).

**Key Steps:**
1. Download and parse the Gene Ontology (GO) directed acyclic graph (DAG)
2. Calculate Information Content (IC)
3. Extract relationships between GO terms
4. Extract nodes (GO terms) and nodes' features (***not implemented***)
5. Embed GO term definitions
6. Evaluate the quality of GO terms' embeddings
7. Build the GO Graph Data

In [None]:
# Standard libraries
import random
import pickle
from collections import Counter, defaultdict
from pathlib import Path

# Third-party libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Bioinformatics
import networkx as nx
from obonet import read_obo

# Machine learning 
import torch
from sklearn.utils import resample
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
from hdbscan import HDBSCAN

# Custom libraries
import aid2go.go as aidgo
import aid2go.ia as aidia
import aid2go.utils as aidutils
 
# Configuration
pd.options.mode.copy_on_write = True

## 1. Download and parse the Gene Ontology (GO) directed acyclic graph (DAG)

Permanent link: http://purl.obolibrary.org/obo/go/go-basic.obo.

In [None]:
go_dag, go_df = aidgo.get_godag("./data/go/", timeout=5)

## 2. Calculate Information Content (IC)

Information accretion (IA), introduced in [Clark and Radivojac, 2013], measures how much information node $v$ adds to an ontology annotation given that its parents $Pa(v)$ are already annotated. This notebook uses IA as the information content (IC) of each GO term. Specifically,

$$
ia(v) = \log_2 \frac{1}{\Pr(v | Pa(v))} = \log_2 \frac{\Pr(Pa(v))}{\Pr(Pa(v) | v) \Pr(v)}
$$

> De Paolis Kaluza, C. (2024). Information Accretion (Version 1.0.0) [Computer software]. GitHub. https://github.com/claradepaolis/InformationAccretion
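
As a worked example of the formula above, the sketch below estimates $ia(v)$ from annotation counts on a toy ontology. The term names, the proteins, and the `information_accretion` helper are hypothetical and are not part of `aid2go`. No pseudocounts are used, so every term must be annotated at least once.

```python
import math

# Toy ontology (hypothetical term names): each term maps to its direct parents.
parents = {
    "root": [],
    "A": ["root"],
    "B": ["root"],
    "C": ["A", "B"],
}

# Propagated annotations: each protein carries a term together with all its ancestors.
annotations = {
    "p1": {"root", "A"},
    "p2": {"root", "A", "B", "C"},
    "p3": {"root", "B"},
    "p4": {"root", "A", "B"},
}


def information_accretion(term):
    """Return ia(term) = log2(1 / Pr(term | all parents of term annotated))."""
    # Proteins annotated with every parent of the term (all proteins for the root)
    with_parents = [
        terms
        for terms in annotations.values()
        if all(parent in terms for parent in parents[term])
    ]
    # Among those, proteins that are also annotated with the term itself
    with_term = [terms for terms in with_parents if term in terms]
    return math.log2(len(with_parents) / len(with_term))


ia = {term: information_accretion(term) for term in parents}
for term, value in ia.items():
    print(f"{term}: {value:.3f}")
```

The root adds no information because every protein carries it, while `C` adds one bit because only half of the proteins annotated with both of its parents are also annotated with `C`.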

### Filter GO DAG Edges ("is_a", "part_of")

- Before propagating annotations, we remove every edge outside the scope of the evaluation of predicted annotations, i.e., the *"regulates"*, *"positively_regulates"*, and *"negatively_regulates"* relations, which occur only in the **BPO subontology**. CAFA evaluators consider only *"is_a"* and *"part_of"* edges when building the complete/consistent GO DAG subgraph corresponding to an annotated GO term.

Filter out the BPO-exclusive relations, keeping only *"is_a"* and *"part_of"*.
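
The filtering step can be illustrated with a minimal sketch. The internals of `aidia.clean_ontology_edges` are not shown in this notebook, so the function below is only an assumption about the idea, not its implementation. Edges are plain `(child, parent, relation)` triples with placeholder term IDs.

```python
# Hypothetical edges as (child, parent, relation) triples with placeholder IDs.
edges = [
    ("GO:B", "GO:A", "is_a"),
    ("GO:C", "GO:A", "part_of"),
    ("GO:D", "GO:B", "regulates"),
    ("GO:E", "GO:B", "positively_regulates"),
    ("GO:F", "GO:C", "negatively_regulates"),
]

# Relations considered in the evaluation of predicted annotations
KEPT_RELATIONS = {"is_a", "part_of"}


def keep_hierarchy_edges(edges, kept=KEPT_RELATIONS):
    """Return only the edges whose relation type is in `kept`."""
    return [(child, parent, rel) for child, parent, rel in edges if rel in kept]


cleaned = keep_hierarchy_edges(edges)
print(f"# edges original: {len(edges)}")
print(f"# edges cleaned: {len(cleaned)}")
```

In the graph returned by `obonet`, the relation type is stored as the edge key, so the same filter can be applied while iterating over `graph.edges(keys=True)`.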

In [None]:
go_dag_cleaned = aidia.clean_ontology_edges(go_dag)
print(f"# edges original GO DAG: {go_dag.number_of_edges()}")
print(f"# edges cleaned GO DAG: {go_dag_cleaned.number_of_edges()}")

### Load GOA Annotation File

The GOA high-confidence (hc) annotations include only protein-GO term associations supported by experimental evidence.

In [None]:
# Load annotations file (GOA high-confidence)
annot_df = pd.read_csv("./data/goa/goa_hc_annot.tsv", sep="\t")
annot_df

### Propagate Annotations

Genes annotated with a GO term inherit the annotations of all of that term's ancestors: annotating the term at node $v$ necessarily implies annotating its parents $Pa(v)$, and recursively their parents, which guarantees a *consistent subgraph* (valid annotations).

- The propagation of annotations follows the direction of the edges: *(head --("is_a", "part_of")--> tail)*, i.e., *children --("is_a", "part_of")--> parents*
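
The propagation rule can be sketched on a toy DAG as follows. The term names and the `ancestors` and `propagate` helpers are hypothetical. `aidia.propagate_terms` operates on the annotation DataFrame and may be implemented differently.

```python
# Toy DAG (hypothetical terms): each term maps to its direct parents,
# following the child -> parent direction of "is_a"/"part_of" edges.
parents = {
    "root": [],
    "A": ["root"],
    "B": ["root"],
    "C": ["A", "B"],
    "D": ["C"],
}


def ancestors(term, parents):
    """Return every term reachable from `term` by following child -> parent edges."""
    seen = set()
    stack = list(parents.get(term, []))
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(parents.get(node, []))
    return seen


def propagate(annotations, parents):
    """Add all ancestors of each annotated term to the protein's annotation set."""
    propagated = {}
    for protein, terms in annotations.items():
        closed = set(terms)
        for term in terms:
            closed |= ancestors(term, parents)
        propagated[protein] = closed
    return propagated


annotations = {"p1": {"D"}, "p2": {"A"}}
propagated = propagate(annotations, parents)
print(propagated)
```

After propagation, each protein's term set is closed under the parent relation, which is the consistency property the IA computation relies on.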

In [None]:
# Get the three subontologies
roots = {"BPO": "GO:0008150", "CCO": "GO:0005575", "MFO": "GO:0003674"}
subontologies = {
    aspect: aidia.fetch_aspect(go_dag_cleaned, roots[aspect]) for aspect in roots
}

print("BPO -->", len(subontologies["BPO"]))
print("MFO -->", len(subontologies["MFO"]))
print("CCO -->", len(subontologies["CCO"]))

# Propagate terms
annot_df_prop = aidia.propagate_terms(annot_df, subontologies)

# Save propagated terms
annot_df_prop.to_csv(
    "./data/goa/goa_hc_annot_prop.tsv",
    # header=False,
    index=False,
    sep="\t",
)
print(f"annotations df propagated terms shape: {annot_df_prop.shape}")
annot_df_prop.head()

### Get Aspects and Terms Counts

In [None]:
# Count term instances
print("Counting Terms")
aspect_counts = dict()
aspect_terms = dict()
term_idx = dict()
for aspect, subont in subontologies.items():
    aspect_terms[aspect] = sorted(subont.nodes)  # ensure same order
    term_idx[aspect] = {t: i for i, t in enumerate(aspect_terms[aspect])}
    aspect_counts[aspect] = aidia.term_counts(
        annot_df_prop[annot_df_prop.aspect == aspect], term_idx[aspect]
    )

    assert aspect_counts[aspect].sum() == len(
        annot_df_prop[annot_df_prop.aspect == aspect]
    ) + len(aspect_terms[aspect])
    
print(f"Length: Aspect Counts -> {len(aspect_counts)}")
print(f"Length: Aspect Terms -> {len(aspect_terms)}")

In [None]:
# Since we are indexing by column to compute IA, 
# let's convert to Compressed Sparse Column format
sp_matrix = {aspect:dok.tocsc() for aspect, dok in aspect_counts.items()}

# Compute IA
print('Computing Information Accretion')
aspect_ia = {aspect: {t:0 for t in aspect_terms[aspect]} for aspect in aspect_terms.keys()}
for aspect, subontology in subontologies.items():
    for term in aspect_ia[aspect].keys():
        aspect_ia[aspect][term] = aidia.calc_ia(term, sp_matrix[aspect], subontology, term_idx[aspect])

ia_df = pd.concat([pd.DataFrame.from_dict(
    {'term':aspect_ia[aspect].keys(), 
        'ia': aspect_ia[aspect].values(), 
        'aspect': aspect}) for aspect in subontologies.keys()])

# All IA values should be non-negative
assert ia_df['ia'].min() >= 0

# Save to file
ia_df[['term','ia']].to_csv("./data/go/ia.txt",  header=None, sep='\t', index=False)
print(ia_df.shape)
ia_df.head()

### Map IC Values to GO Terms

In [None]:
# Load the GO terms table
go_df = pd.read_csv("./data/go/go-basic.csv")

# Merge GO terms with their IA (IC) values
go_df = pd.merge(go_df, ia_df, left_on="go_id", right_on="term", how="left")
go_df.drop(columns=["namespace"], inplace=True)  # Drop redundant subontology column

# Reorder columns for easy visualization
go_df = go_df[
    [
            "go_id",
            "name",
            "aspect",
            "definition",
            "def_word_count",
            "in_degree",
            "out_degree",
            "degree",
            "term",
            "ia",
    ]
]

# Save updated GO df
go_df.reset_index(drop=True, inplace=True)
go_df.to_csv("./data/go/go-basic.csv", index=False)
print(go_df.shape)
go_df.head()

## 3. Extract Relationships Between GO Terms

In the Gene Ontology (GO) graph, relations are encoded using five types: "is_a" for subclass relationships, "part_of" for part-whole relationships, "regulates," "positively_regulates," and "negatively_regulates" for various types of regulation between terms (https://wiki.geneontology.org/index.php/Relation_composition).

1. Main relationships (BPO, MFO, CCO): "is_a", "part_of" **(used to calculate IC)**
2. Other relationships (BPO only): "regulates", "positively_regulates", "negatively_regulates"

In [None]:
go_edges_df = aidgo.get_go_relations(go_dag, "./data/go")

In [None]:
go_edges_df_cleaned = aidgo.get_go_relations(go_dag_cleaned, "./data/go")

## 4. Extract Nodes/Edges Features

Planned node and edge features (this step is ***not implemented*** in this notebook):

1. Node degree
2. Graph Centrality measures
3. Information Content (IC)
4. GO term namespace (BPO, MFO, CCO)
5. Local subgraph properties (clustering coefficients, graphlets)
6. Hierarchical position (ontological structure)
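
As a rough sketch of how two of the listed features (node degree and hierarchical position) could be derived, the example below computes in-degree, out-degree, and depth from the root on a toy edge list. The edge list, the `depths_from_root` helper, and the use of shortest distance from the root as the depth are all assumptions made for illustration.

```python
from collections import defaultdict, deque

# Toy (child, parent) edges with hypothetical term names
edges = [("A", "root"), ("B", "root"), ("C", "A"), ("C", "B"), ("D", "C")]

children = defaultdict(list)
in_degree = defaultdict(int)   # number of children pointing at the term
out_degree = defaultdict(int)  # number of parents the term points at
nodes = set()
for child, parent in edges:
    children[parent].append(child)
    out_degree[child] += 1
    in_degree[parent] += 1
    nodes.update((child, parent))


def depths_from_root(root):
    """Breadth-first search from the root: shortest distance to each descendant."""
    depth = {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in children[node]:
            if child not in depth:
                depth[child] = depth[node] + 1
                queue.append(child)
    return depth


depth = depths_from_root("root")
features = {
    node: {
        "in_degree": in_degree[node],
        "out_degree": out_degree[node],
        "degree": in_degree[node] + out_degree[node],
        "depth": depth[node],
    }
    for node in sorted(nodes)
}
for node, feats in features.items():
    print(node, feats)
```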


## 5. Embed GO Terms' Definitions

In [None]:
# Generate embeddings of GO definitions

# Create directory to save embeddings
save_path = Path("./data/go")
save_path.mkdir(parents=True, exist_ok=True)

go_embed_dict, id_text_mapping = aidgo.embed_texts(
    df=go_df,
    column_id="go_id",
    column_text="definition",
    batch_size=8,
    pre_trained_model="dmis-lab/biobert-v1.1",
    save_path=save_path,
)

print(f"Length of GO embeddings dict: {len(go_embed_dict)}")
print(f"Length of ID mapping dict: {len(id_text_mapping)}")

## 6. Evaluate the quality of GO terms' embeddings

#### Visualize Embeddings of Definitions

In [None]:
# Load GO embeddings
filepath = "./data/go/go_emb_dict-definition.pkl"
with open(filepath, "rb") as file:
    go_embed_dict_def = pickle.load(file)

print(f"Length: {len(go_embed_dict_def)}")

##### Extract a small subset of GO embeddings

In [None]:
# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Extract GO terms and embeddings
go_terms = np.array(list(go_embed_dict_def.keys()))
embeddings = np.array(list(go_embed_dict_def.values()))

# Sample embeddings
ids_sample, embeddings_sample = resample(
    go_terms,
    embeddings,
    n_samples=5000,
    replace=False,
    random_state=42,
)

# Create aspect -> sorted terms mapping (same term order as in the IA computation)
aspect_terms = {aspect: sorted(subont.nodes) for aspect, subont in subontologies.items()}

print(f"sample ids shape: {ids_sample.shape}, embeddings shape: {embeddings_sample.shape}")
print(f"aspect terms dict keys: {aspect_terms.keys()}")

##### Reduce dimensionality using PCA

In [None]:
pca = PCA(n_components=8, random_state=42)
reduced_embeddings = pca.fit_transform(embeddings_sample)

print(f"Reduced embeddings shape: {reduced_embeddings.shape}")

##### Estimate Number of Clusters

In [None]:
# Neighborhood sizes to evaluate on the sampled embeddings
num_of_neighbors = [5, 10, 25, 50, 75, 100]

# Path to save plots
save_path = Path("./eda/go/hdbscan")
save_path.mkdir(parents=True, exist_ok=True)

results = aidgo.analyze_clusters(
    reduced_embeddings,
    ids_sample,
    num_of_neighbors,
    save_path,
    0.5,
    show_plot=False,
)

### Plot Clusters and Silhouette Scores

##### Inspect Namespaces Frequencies in Clusters

In [None]:
def map_go_terms_to_namespace(go_terms, aspect_terms):
    """
    Map GO terms to their corresponding namespaces using the aspect_terms dictionary.

    :param go_terms: List of GO terms.
    :param aspect_terms: Dictionary with keys as namespaces ('BPO', 'MFO', 'CCO')
                         and values as sets of GO terms corresponding to each namespace.
    :return: A numpy array where each element is the namespace ('BP', 'MF', 'CC')
             corresponding to the GO term.
    """
    namespace_data = []
    for go_term in go_terms:
        if go_term in aspect_terms["BPO"]:
            namespace_data.append("BP")  # Biological Process
        elif go_term in aspect_terms["MFO"]:
            namespace_data.append("MF")  # Molecular Function
        elif go_term in aspect_terms["CCO"]:
            namespace_data.append("CC")  # Cellular Component
        else:
            namespace_data.append("Unknown")  # For terms not found

    return np.array(namespace_data)


def plot_combined_bar_charts(
    clustered_counts, noisy_counts, stacked_data, clusters, namespaces, color_map
):
    """
    Generate a plot with three subplots: stacked bar chart for namespaces in clusters,
    and two bar charts for clustered and noisy namespace frequencies. Annotate bars with counts.

    :param clustered_counts: Counter object with frequencies of namespaces in clustered data.
    :param noisy_counts: Counter object with frequencies of namespaces in noisy data.
    :param stacked_data: Data for the stacked bar chart.
    :param clusters: Array of unique clusters.
    :param namespaces: List of namespaces ('BP', 'MF', 'CC').
    :param color_map: Dictionary mapping namespaces to specific colors.
    """
    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    ax = ax.flatten()

    # Stacked bar chart for namespaces in each cluster
    bottom = np.zeros(len(clusters))
    for i, ns in enumerate(namespaces):
        bars = ax[0].bar(
            clusters, stacked_data[:, i], bottom=bottom, label=ns, color=color_map[ns]
        )
        bottom += stacked_data[:, i]

        # # Annotate with counts
        # for j, bar in enumerate(bars):
        #     height = bar.get_height()
        #     ax[0].annotate(
        #         f"{int(height)}",
        #         xy=(bar.get_x() + bar.get_width() / 2, bottom[j] - height / 2),
        #         xytext=(0, 3),  # 3 points vertical offset
        #         textcoords="offset points",
        #         ha="center",
        #         va="bottom",
        #     )

    ax[0].set_xlabel("Cluster")
    ax[0].set_ylabel("Counts")
    ax[0].set_title("Counts of GO Namespaces in Each Cluster")
    ax[0].legend(title="Namespace")

    # Bar chart for clustered namespace frequencies
    ax[1].bar(
        clustered_counts.keys(),
        clustered_counts.values(),
        color=[color_map[ns] for ns in clustered_counts.keys()],
    )
    ax[1].set_title("Clustered Namespaces Frequencies")
    ax[1].set_ylabel("Count")

    # Annotate clustered bars with counts
    for i, (label, count) in enumerate(clustered_counts.items()):
        ax[1].text(i, count, str(count), ha="center", va="bottom")

    # Bar chart for noisy namespace frequencies
    ax[2].bar(
        noisy_counts.keys(),
        noisy_counts.values(),
        color=[color_map[ns] for ns in noisy_counts.keys()],
    )
    ax[2].set_title("Noisy Namespaces Frequencies")
    ax[2].set_ylabel("Count")

    # Annotate noisy bars with counts
    for i, (label, count) in enumerate(noisy_counts.items()):
        ax[2].text(i, count, str(count), ha="center", va="bottom")

    plt.tight_layout()
    plt.savefig("./eda/go/hdbscan/combined_clustered_namespaces_counts.png")
    plt.show()


# Fit HDBSCAN clusterer
clusterer = HDBSCAN(min_cluster_size=50)
cluster_labels = clusterer.fit_predict(reduced_embeddings)

# Filter noisy data points
filtered_indices = cluster_labels != -1
filtered_cluster_labels = cluster_labels[filtered_indices]

print(
    f"Unique clusters ({len(np.unique(cluster_labels))}): {np.unique(cluster_labels)}"
)

# Assuming go_terms (or ids_sample) is the array of GO terms corresponding to the embeddings
namespace_data = map_go_terms_to_namespace(ids_sample, aspect_terms)

# Separate clustered and noisy data
filtered_namespace_data = namespace_data[filtered_indices]
noisy_namespace_data = namespace_data[~filtered_indices]

# Count namespace frequencies for clustered and noisy data
clustered_counts = Counter(filtered_namespace_data)
noisy_counts = Counter(noisy_namespace_data)

# Combine filtered cluster labels with namespaces
cluster_namespace_pairs = [
    (cluster, ns)
    for cluster, ns in zip(filtered_cluster_labels, filtered_namespace_data)
]

# Count occurrences for each cluster and namespace
cluster_namespace_dict = defaultdict(lambda: defaultdict(int))
for cluster, ns in cluster_namespace_pairs:
    cluster_namespace_dict[cluster][ns] += 1

# Get unique clusters and namespaces
clusters = np.unique(filtered_cluster_labels)
namespaces = [
    "BP",
    "MF",
    "CC",
]  # Biological Process, Molecular Function, Cellular Component

# Define consistent colors for each namespace
color_map = {"BP": "green", "MF": "blue", "CC": "orange"}

# Prepare the data for the stacked bar plot
stacked_data = np.array(
    [[cluster_namespace_dict[cluster][ns] for ns in namespaces] for cluster in clusters]
)

# Plot the combined chart with three subplots
plot_combined_bar_charts(
    clustered_counts, noisy_counts, stacked_data, clusters, namespaces, color_map
)


In [None]:
def inspect_misclassified_go_terms_df(
    cluster_labels, go_terms, namespaces, aspect_terms
):
    """
    Inspect misclassified GO terms by comparing their cluster assignments to their true namespaces.

    :param cluster_labels: Array of assigned cluster labels for each GO term.
    :param go_terms: List of GO terms corresponding to the embeddings.
    :param namespaces: List of possible namespaces ('BP', 'MF', 'CC').
    :param aspect_terms: Dictionary mapping namespaces to their GO terms.
    :return: A DataFrame of misclassified GO terms.
    """
    misclassified_data = []

    # Map each GO term to its true namespace (reuses the helper defined above)
    true_namespaces = map_go_terms_to_namespace(go_terms, aspect_terms)

    # Loop over each unique cluster to inspect terms
    for cluster in np.unique(cluster_labels):
        if cluster == -1:
            continue  # Skip HDBSCAN noise points, which belong to no cluster

        print(f"\nInspecting Cluster {cluster}...")

        # Extract GO terms in the current cluster
        cluster_indices = np.where(cluster_labels == cluster)[0]
        cluster_go_terms = [go_terms[i] for i in cluster_indices]
        cluster_namespaces = [true_namespaces[i] for i in cluster_indices]

        # Count occurrences of each namespace in the current cluster
        namespace_counts = Counter(cluster_namespaces)
        print(f"Namespace distribution in Cluster {cluster}: {namespace_counts}")

        # Determine the majority namespace in the cluster
        majority_namespace = max(namespace_counts, key=namespace_counts.get)

        # Find misclassified GO terms (terms that do not belong to the majority namespace)
        for i, term in enumerate(cluster_go_terms):
            if cluster_namespaces[i] != majority_namespace:
                misclassified_data.append(
                    {
                        "GO_term": term,
                        "True_Namespace": cluster_namespaces[i],
                        "Assigned_Cluster": cluster,
                        "Majority_Namespace": majority_namespace,
                    }
                )

    # Convert the misclassified data to a DataFrame
    misclassified_df = pd.DataFrame(misclassified_data)
    return misclassified_df


# Call the function to inspect misclassified GO terms and get the DataFrame
misclassified_go_terms_df = inspect_misclassified_go_terms_df(
    cluster_labels=cluster_labels,
    go_terms=ids_sample,  # Assuming ids_sample corresponds to GO terms
    namespaces=["BP", "MF", "CC"],
    aspect_terms=aspect_terms,  # Dictionary mapping namespaces to GO terms
)

# Print the DataFrame of misclassified GO terms
print("\nMisclassified GO terms DataFrame:")
misclassified_go_terms_df

In [None]:
go_df[go_df["go_id"].isin(misclassified_go_terms_df["GO_term"])].sort_values(by="definition")

##### TSNE

In [None]:
# Define a range of perplexities; the learning rate is left at its default ("auto")
perplexities = [5, 10, 25, 50, 75, 100, 125, 150, 200]
# learning_rate = 10

# store embeddings
tsne_dict = {}

# Perform t-SNE for each perplexity value
for perplexity in tqdm(perplexities, "Performing t-SNE", len(perplexities)):
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=42,
        max_iter=300,
        # learning_rate=learning_rate,
    )
    tsne_dict[perplexity] = tsne.fit_transform(reduced_embeddings)

# Save plots
save_path = Path("./eda/go/tsne")
save_path.mkdir(parents=True, exist_ok=True)

# Save T-SNE embeddings
with open(save_path / "tsne2D_dict.pkl", "wb") as file:
    pickle.dump(tsne_dict, file, protocol=pickle.HIGHEST_PROTOCOL)

print(tsne_dict.keys())

In [None]:
# Encode namespace labels
le = LabelEncoder()
encoded_labels = le.fit_transform(namespace_data)

# Create subplots for each perplexity value
nrows, ncols = aidutils.square_layout(len(perplexities))
fig, ax = plt.subplots(nrows, ncols, figsize=(24, 15))

for i, perplexity in enumerate(perplexities):
    tsne_results = tsne_dict[perplexity]
    panel = ax[i // ncols, i % ncols]  # Index with the computed layout, not a fixed 3

    scatter = panel.scatter(
        tsne_results[:, 0],
        tsne_results[:, 1],
        c=encoded_labels,
        cmap="viridis",
        alpha=0.7,
    )

    panel.set_title(f"t-SNE Perplexity: {perplexity}")

# Create a legend
handles, _ = scatter.legend_elements(prop="colors")
labels = le.classes_
fig.legend(
    handles,
    labels,
    title="GO Namespaces",
    loc="upper right",
    fontsize=14,
)
plt.savefig(save_path / "tsne_namespaces.png", dpi=300)
plt.show()

In [None]:
# Encode HDBSCAN cluster labels (noise points are labeled -1)
le = LabelEncoder()
encoded_labels = le.fit_transform(cluster_labels)

# Create subplots for each perplexity value
nrows, ncols = aidutils.square_layout(len(perplexities))
fig, ax = plt.subplots(nrows, ncols, figsize=(24, 15))

for i, perplexity in enumerate(perplexities):
    tsne_results = tsne_dict[perplexity]
    panel = ax[i // ncols, i % ncols]  # Index with the computed layout, not a fixed 3

    scatter = panel.scatter(
        tsne_results[:, 0],
        tsne_results[:, 1],
        c=encoded_labels,  # Encoded cluster labels, consistent with the legend classes
        cmap="viridis",
        alpha=0.7,
    )

    panel.set_title(f"t-SNE Perplexity: {perplexity}")

# Create a legend
handles, _ = scatter.legend_elements(prop="colors")
labels = le.classes_
fig.legend(
    handles,
    labels,
    title="HDBSCAN Clusters",
    loc="upper right",
    fontsize=14,
)

plt.savefig(save_path / "tsne_cluster_labels.png", dpi=300)
plt.show()

## 7. Build the GO Graph Data

In [None]:
# Load GO node embeddings
save_path = Path("./data/go")
filepath = save_path / "go_emb_dict-definition.pkl"

with open(filepath, "rb") as file:
    go_embed_dict = pickle.load(file)

print(f"GO embeddings dict length: {len(go_embed_dict)}")
for key, value in go_embed_dict.items():
    print(
        f"GO id: {key}, embedding ({value.dtype}): {value[:10]}, shape: {value.shape}"
    )
    break

go_ids = list(go_embed_dict.keys())
print(f"Length GO ids list: {len(go_ids)}")

In [None]:
# Load edges

go_edges_df = pd.read_csv("./data/go/go_edges.tsv", sep="\t")
go_edges_df

In [None]:
save_path = Path("./data/go/")

# Heterogeneous
heterodata = aidgo.create_go_data(
    go_edges_df=go_edges_df,
    go_embed_dict=go_embed_dict,
    save_path=save_path,
    multi_edge=True,
)

# Homogeneous
data = aidgo.create_go_data(
    go_edges_df=go_edges_df,
    go_embed_dict=go_embed_dict,
    save_path=save_path,
    multi_edge=False
)