# Aid2GO Heterogeneous Graph Attention Network (HAN) Training
Author: Cleverson Matiolli, Ph.D.

**Implement**
1. F-measure with precision and recall weighted by each term's information content ($w(t) = \mathrm{IC}(t)$)
2. Distinguish between the different GO edge types (relationships)
3. Skip connections to mitigate the loss of hierarchical information
4. (Tentative) Use embeddings of relationship definitions as edge features, so the model can account for how the definitions of two connected GO terms relate to each other

**Debug**
1. The model predicts the same GO term multiple times for a given protein.

In [None]:
# Standard library imports
import os
import warnings
import json
from pathlib import Path

# Third-party imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm.auto import tqdm
import wandb

# Machine Learning and Deep Learning
import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import degree
from torch_geometric.data import Data, HeteroData
from torch_geometric.nn import GATConv
from torch_geometric.utils import (
    to_undirected,
    structured_negative_sampling_feasible,
    remove_self_loops,
    add_self_loops,
    softmax,
    degree,
)
from torch_geometric.data import HeteroData
from torch_geometric.nn.inits import glorot, zeros
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import roc_auc_score, classification_report, ConfusionMatrixDisplay

# Silence every warning so notebook output stays readable
warnings.filterwarnings("ignore")

# Plot defaults: axes are drawn without a background grid
plt.rcParams["axes.grid"] = False

# wandb uses this variable to associate runs with the source notebook
os.environ["WANDB_NOTEBOOK_NAME"] = "model.ipynb"

# Choose the compute device (GPU when available) and export the torch version
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
os.environ["TORCH"] = torch.__version__

# Report the runtime configuration, one line per setting
print(
    "\n".join(
        (
            f"PyTorch version: {torch.__version__}",
            f"CUDA version: {torch.version.cuda}",
            f"Device: {device}",
        )
    )
)

# Project folders: everything is resolved relative to the working directory;
# model artifacts go to <cwd>/outputs, which is created on demand.
base_dir = Path.cwd()
output_dir = base_dir / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Base directory: {base_dir}")

## Custom Classes

### Protein-GO Dataloader Class

In [2]:
class ProteinGoDataloader:
    """
    Build a protein-GO heterogeneous graph for association (link) prediction.

    The main entry point is `process`, which:
    1. maps association IDs to node indices,
    2. samples negative protein-GO pairs,
    3. assembles a HeteroData object,
    4. splits the association edges by protein into train/val/test sets,
    5. writes diagnostic plots, and
    6. saves the graph under `ppi_folder`.
    """

    def __init__(self, protein_ids_map, go_ids_map, ppi_folder):
        """
        Initialize the ProteinGoDataloader class.

        Parameters:
        protein_ids_map (dict): A dictionary mapping protein IDs to their indices.
        go_ids_map (dict): A dictionary mapping GO IDs to their indices.
        ppi_folder (str | Path): Folder where plots and the HeteroData file are written.

        Attributes:
        protein_ids_map (dict): A dictionary mapping protein IDs to their indices.
        go_ids_map (dict): A dictionary mapping GO IDs to their indices.
        ppi_folder (Path): Folder where plots and the HeteroData file are written.
        edge_index_protgo_positive (torch.Tensor): Positive protein->GO edges, shape (2, E_pos).
        edge_index_protgo_negative (torch.Tensor): Sampled negative protein->GO edges, shape (2, E_neg).
        associations_positive_df (pandas.DataFrame): Positive associations with mapped indices.
        associations_negative_df (pandas.DataFrame): Sampled negative associations.
        associations_combined_df (pandas.DataFrame): Positive and negative associations with a `label` column.
        hetero_data (torch_geometric.data.HeteroData): The heterogeneous graph.
        train_mask, val_mask, test_mask (torch.Tensor): Boolean masks over the protein-GO edges.
        train_protein_ids, val_protein_ids, test_protein_ids (list): Protein IDs in each split.
        train_go_ids, val_go_ids, test_go_ids (list): GO IDs on the edges of each split.
        train_loader, val_loader, test_loader (DataLoader): Set by create_data_loaders.
        """
        self.protein_ids_map = protein_ids_map
        self.go_ids_map = go_ids_map
        self.ppi_folder = Path(ppi_folder)
        self.edge_index_protgo_positive = None
        self.edge_index_protgo_negative = None
        self.associations_positive_df = None
        self.associations_negative_df = None
        self.associations_combined_df = None
        self.hetero_data = None
        self.train_mask = None
        self.val_mask = None
        self.test_mask = None
        self.train_protein_ids = None
        self.val_protein_ids = None
        self.test_protein_ids = None
        self.train_go_ids = None
        self.val_go_ids = None
        self.test_go_ids = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

    def convert_associations_to_edge_index(self, associations_df):
        """
        Convert the positive associations DataFrame into a protein->GO edge index.

        The `go_idx` and `uniprot_idx` columns are added to `associations_df` in place.
        Rows whose protein or GO ID is missing from the ID maps cannot become valid
        node indices, so they are reported and excluded from the edge index and
        from `associations_positive_df`.

        Parameters:
        associations_df (pandas.DataFrame): Must contain `uniprot_id` and `go_id` columns.

        Returns:
        torch.Tensor: A long tensor of shape (2, E); row 0 holds protein indices and
        row 1 holds GO indices.
        """
        # Map GO and protein IDs to node indices (unmapped IDs become NaN)
        associations_df["go_idx"] = associations_df["go_id"].map(self.go_ids_map)
        associations_df["uniprot_idx"] = associations_df["uniprot_id"].map(
            self.protein_ids_map
        )

        # NaN indices would otherwise be cast to arbitrary integers by torch
        unmapped = associations_df[["uniprot_idx", "go_idx"]].isna().any(axis=1)
        if unmapped.any():
            print(
                f"Dropping {int(unmapped.sum())} associations with unmapped protein or GO IDs."
            )
            associations_df = associations_df.loc[~unmapped].astype(
                {"uniprot_idx": "int64", "go_idx": "int64"}
            )

        # Save the positive associations
        self.associations_positive_df = associations_df

        # Create edge index
        indices = np.stack(
            (
                associations_df["uniprot_idx"].to_numpy(dtype=np.int64),
                associations_df["go_idx"].to_numpy(dtype=np.int64),
            )
        )
        edge_index = torch.tensor(indices, dtype=torch.long)

        self.edge_index_protgo_positive = edge_index
        return edge_index

    def generate_negative_edges(
        self,
        edge_index,
        num_nodes_protein,
        num_nodes_go,
        num_neg_samples,
        weighted=False,
        random_state=42,
        max_attempts_factor=10,
    ):
        """
        Generate unique negative protein-GO edges by rejection sampling.

        A candidate pair is kept only if it is neither a positive edge nor a
        negative edge that was already sampled. Sampling stops after
        `num_neg_samples * max_attempts_factor` attempts, so fewer negatives than
        requested may be returned; a message is printed when that happens.

        Parameters:
        - edge_index (torch.Tensor): The positive protein->GO edge index, shape (2, E).
        - num_nodes_protein (int): The number of protein nodes.
        - num_nodes_go (int): The number of GO nodes.
        - num_neg_samples (int): The number of negative samples to generate.
        - weighted (bool, optional): Sample proteins in proportion to their positive
          degree. Proteins with no positive edges are then never sampled. If no
          protein has a positive edge, sampling falls back to uniform. Default is False.
        - random_state (int, optional): The random seed for reproducibility. Default is 42.
        - max_attempts_factor (int, optional): Attempts allowed per requested sample. Default is 10.

        Returns:
        torch.Tensor: The negative edge index, shape (2, E_neg).
        pandas.DataFrame: The negative associations (uniprot_id, go_id, uniprot_idx, go_idx).

        Raises:
        ValueError: If the sampled negatives overlap the stored positive edges.
        """
        rng = np.random.default_rng(random_state)
        num_neg_samples = int(num_neg_samples)

        # Pairs that must not be emitted: the positives, plus every negative
        # already drawn, so the negative set contains no duplicates.
        forbidden = set(map(tuple, edge_index.t().tolist()))

        cdf = None
        if weighted and num_nodes_protein > 0:
            protein_degrees = np.zeros(num_nodes_protein, dtype=int)
            np.add.at(protein_degrees, edge_index[0].cpu().numpy(), 1)
            if protein_degrees.sum() > 0:
                # Precompute the CDF once; each draw is then an O(log P) lookup
                # instead of renormalising the whole degree vector per sample.
                cdf = np.cumsum(protein_degrees, dtype=np.float64)
                cdf /= cdf[-1]

        negative_samples = []
        if num_nodes_protein > 0 and num_nodes_go > 0:
            max_attempts = num_neg_samples * max_attempts_factor
            attempts = 0
            while len(negative_samples) < num_neg_samples and attempts < max_attempts:
                attempts += 1
                if cdf is not None:
                    protein_node = int(
                        np.searchsorted(cdf, rng.random(), side="right")
                    )
                else:
                    protein_node = int(rng.integers(0, num_nodes_protein))
                go_node = int(rng.integers(0, num_nodes_go))

                pair = (protein_node, go_node)
                if pair not in forbidden:
                    forbidden.add(pair)
                    negative_samples.append([protein_node, go_node])

        if len(negative_samples) < num_neg_samples:
            print(
                f"Generated only {len(negative_samples)} of {num_neg_samples} requested negative edges."
            )

        neg_sample_tensor = (
            torch.tensor(negative_samples, dtype=torch.long).t()
            if negative_samples
            else torch.empty((2, 0), dtype=torch.long)
        )

        protein_mapping = {v: k for k, v in self.protein_ids_map.items()}
        go_mapping = {v: k for k, v in self.go_ids_map.items()}

        mapped_negative_samples = [
            {
                "uniprot_id": protein_mapping.get(u_idx, None),
                "go_id": go_mapping.get(g_idx, None),
                "uniprot_idx": u_idx,
                "go_idx": g_idx,
            }
            for u_idx, g_idx in negative_samples
        ]

        # Explicit columns keep the schema stable even when no negatives were drawn
        self.associations_negative_df = pd.DataFrame(
            mapped_negative_samples,
            columns=["uniprot_id", "go_id", "uniprot_idx", "go_idx"],
        )
        self.edge_index_protgo_negative = neg_sample_tensor

        # Ensure no overlap between positive and negative edges
        if not self.check_edge_overlap():
            raise ValueError("Overlap found between positive and negative edges.")

        return neg_sample_tensor, self.associations_negative_df

    def combine_associations(self):
        """
        Combine the positive and negative associations DataFrames.

        A `label` column is added to both inputs, with 1 for positive and 0 for negative.

        Returns:
        pandas.DataFrame: A DataFrame containing both positive and negative associations.

        Raises:
        ValueError: If either set of associations has not been generated yet.
        """
        if (
            self.associations_positive_df is None
            or self.associations_negative_df is None
        ):
            raise ValueError(
                "Both positive and negative associations must be generated before combining."
            )

        # Add a label column to differentiate positive and negative associations
        self.associations_positive_df["label"] = 1
        self.associations_negative_df["label"] = 0

        # Combine the DataFrames
        self.associations_combined_df = pd.concat(
            [self.associations_positive_df, self.associations_negative_df],
            ignore_index=True,
        )

        return self.associations_combined_df

    @staticmethod
    def _count_degrees(edge_index, num_nodes, row):
        """Count how often each node index occurs in `edge_index[row]`."""
        counts = np.zeros(num_nodes, dtype=int)
        np.add.at(counts, edge_index[row].cpu().numpy(), 1)
        return counts

    def check_balance(
        self, positive_edge_index, negative_edge_index, num_nodes_protein, num_nodes_go
    ):
        """
        Compare the degree distributions of the positive and negative edge sets.

        The function draws KDE plots for protein degrees and GO-term degrees,
        saves the figure as `balance_check.png` in `ppi_folder`, and shows it.

        Parameters:
        - positive_edge_index (torch.Tensor): Positive protein->GO edges, shape (2, E_pos).
        - negative_edge_index (torch.Tensor): Negative protein->GO edges, shape (2, E_neg).
        - num_nodes_protein (int): The number of protein nodes in the graph.
        - num_nodes_go (int): The number of GO terms in the graph.

        Returns:
        - None. The function only plots and saves the degree distribution plot.
        """
        panels = [
            (
                self._count_degrees(positive_edge_index, num_nodes_protein, 0),
                self._count_degrees(negative_edge_index, num_nodes_protein, 0),
                "Protein Degree Distribution",
            ),
            (
                self._count_degrees(positive_edge_index, num_nodes_go, 1),
                self._count_degrees(negative_edge_index, num_nodes_go, 1),
                "GO Terms Degree Distribution",
            ),
        ]

        plt.figure(figsize=(12, 6))
        for panel_idx, (pos_degrees, neg_degrees, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, panel_idx)
            sns.kdeplot(
                pos_degrees, color="blue", label="Positive", fill=True, alpha=0.5
            )
            sns.kdeplot(
                neg_degrees, color="red", label="Negative", fill=True, alpha=0.5
            )
            plt.title(title)
            plt.xlabel("Degree")
            plt.ylabel("Density")
            plt.legend()

        plt.tight_layout()
        plt.savefig(self.ppi_folder / "balance_check.png", dpi=300)
        plt.show()

    ########################################################################
    def create_heterodata(
        self,
        protein_feats,
        go_feats,
        protein_ids,
        go_ids,
        edge_index_ppi,
        go_edges_df,
        edge_index_protgo_positive,
    ):
        """
        Create the HeteroData object and store it in `self.hetero_data`.

        The object contains:
        - protein and GO node features and IDs;
        - the PPI graph, made undirected;
        - one GO->GO edge type per distinct value in `go_edges_df["relationship"]`;
        - the labelled protein-GO edges (positives first, then negatives).

        Parameters:
        protein_feats (torch.Tensor): A tensor containing protein features.
        go_feats (torch.Tensor): A tensor containing GO term features.
        protein_ids (list): A list of protein IDs.
        go_ids (list): A list of GO term IDs; the list position is the GO node index.
        edge_index_ppi (torch.Tensor): A tensor containing edge indices for PPI.
        go_edges_df (pd.DataFrame): GO relationships with `source_go_id`,
            `target_go_id` and `relationship` columns.
        edge_index_protgo_positive (torch.Tensor): Positive protein-GO edge indices.

        Returns:
        None

        Raises:
        ValueError: If negative edges have not been generated yet.
        """
        if self.edge_index_protgo_negative is None:
            raise ValueError(
                "Negative edges must be generated before creating the HeteroData object."
            )

        # Make PPI graph undirected
        edge_index_ppi = to_undirected(edge_index_ppi)

        # Concatenate positive and negative edge indices
        edge_index_protgo = torch.cat(
            (edge_index_protgo_positive, self.edge_index_protgo_negative), dim=1
        )

        # Labels follow the concatenation order: positives (1) then negatives (0)
        positive_labels = torch.ones(
            edge_index_protgo_positive.shape[1], dtype=torch.float32
        )
        negative_labels = torch.zeros(
            self.edge_index_protgo_negative.shape[1], dtype=torch.float32
        )
        edge_labels = torch.cat((positive_labels, negative_labels), dim=0)

        # Create a HeteroData object
        hetero_data = HeteroData()
        hetero_data["protein"].x = protein_feats
        hetero_data["go"].x = go_feats
        hetero_data["protein"].id = protein_ids
        hetero_data["go"].id = go_ids
        hetero_data["protein", "ppi", "protein"].edge_index = edge_index_ppi

        # One GO->GO edge type per relationship value
        go_ids_map = {go_id: idx for idx, go_id in enumerate(go_ids)}
        relationship_types = go_edges_df["relationship"].unique()

        for relationship in relationship_types:
            edges = go_edges_df[go_edges_df["relationship"] == relationship]
            source_nodes = [go_ids_map[go_id] for go_id in edges["source_go_id"]]
            target_nodes = [go_ids_map[go_id] for go_id in edges["target_go_id"]]
            edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
            hetero_data["go", relationship, "go"].edge_index = edge_index

        # Add protein-GO associations
        hetero_data["protein", "associates_with", "go"].edge_index = edge_index_protgo
        hetero_data["protein", "associates_with", "go"].edge_label = edge_labels

        # Assign the HeteroData object to the class attribute
        self.hetero_data = hetero_data

        # Output information about the GO graph structure
        print("Heterogeneous Graph Information:")
        print(f"Number of GO nodes: {hetero_data['go'].x.size(0)}")
        for relationship in relationship_types:
            num_edges = hetero_data["go", relationship, "go"].edge_index.size(1)
            print(f"GO Relationship: {relationship}, Number of edges: {num_edges}")

    def split_dataset(self, train_ratio=0.6, seed=None):
        """
        Split the protein-GO association edges into train/val/test sets by protein.

        Each protein that appears on an association edge is assigned to exactly one
        split. All of its edges, positive and negative, therefore land in the same
        split. The proteins left after the training share are divided equally
        between validation and test.

        Parameters:
        train_ratio (float, optional): Fraction of proteins used for training. Default is 0.6.
        seed (int, optional): Seed for the protein permutation. The default None uses
            torch's global RNG.

        Returns:
        None
        """
        assert isinstance(train_ratio, float), "Train ratio must be a float."
        assert 0 < train_ratio < 1, "Train ratio must be between 0 and 1."
        val_ratio = (1 - train_ratio) / 2

        edge_store = self.hetero_data["protein", "associates_with", "go"]
        edge_index = edge_store.edge_index

        # unique_proteins holds protein node indices. edge_positions[i] is the
        # position within unique_proteins of the protein on edge i.
        unique_proteins, edge_positions = torch.unique(
            edge_index[0], return_inverse=True
        )
        num_proteins = unique_proteins.size(0)

        generator = torch.Generator().manual_seed(seed) if seed is not None else None
        permuted_positions = torch.randperm(num_proteins, generator=generator)
        train_end = int(num_proteins * train_ratio)
        val_end = int(num_proteins * (train_ratio + val_ratio))
        split_positions = {
            "train": permuted_positions[:train_end],
            "val": permuted_positions[train_end:val_end],
            "test": permuted_positions[val_end:],
        }

        # Label every unique protein with its split code, then look the code up
        # per edge. This costs O(E) instead of a (proteins x edges) comparison.
        split_codes = {"train": 0, "val": 1, "test": 2}
        split_of_protein = torch.empty(num_proteins, dtype=torch.long)
        for split_name, positions in split_positions.items():
            split_of_protein[positions] = split_codes[split_name]
        split_of_edge = split_of_protein.to(edge_positions.device)[edge_positions]

        train_mask = split_of_edge == split_codes["train"]
        val_mask = split_of_edge == split_codes["val"]
        test_mask = split_of_edge == split_codes["test"]

        edge_store["train_mask"] = train_mask
        edge_store["val_mask"] = val_mask
        edge_store["test_mask"] = test_mask

        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask

        protein_mapping = {v: k for k, v in self.protein_ids_map.items()}
        go_mapping = {v: k for k, v in self.go_ids_map.items()}

        # Positions index into unique_proteins; translate them back to protein
        # node indices before looking up IDs.
        protein_nodes = unique_proteins.cpu()
        self.train_protein_ids = [
            protein_mapping[idx]
            for idx in protein_nodes[split_positions["train"]].tolist()
        ]
        self.val_protein_ids = [
            protein_mapping[idx]
            for idx in protein_nodes[split_positions["val"]].tolist()
        ]
        self.test_protein_ids = [
            protein_mapping[idx]
            for idx in protein_nodes[split_positions["test"]].tolist()
        ]

        self.train_go_ids = [go_mapping[idx] for idx in edge_index[1, train_mask].tolist()]
        self.val_go_ids = [go_mapping[idx] for idx in edge_index[1, val_mask].tolist()]
        self.test_go_ids = [go_mapping[idx] for idx in edge_index[1, test_mask].tolist()]

    def plot_go_edge_type_ratios(self):
        """
        Plot the share of each GO->GO edge type in the heterodata.

        The plot is saved as `go_edge_type_ratios.png` in `ppi_folder`, and the
        ratios are printed.
        """
        go_edge_types = []
        edge_counts = []

        for edge_type in self.hetero_data.edge_types:
            if edge_type[0] == "go" and edge_type[2] == "go":
                go_edge_types.append(edge_type[1])
                edge_store = self.hetero_data[edge_type]
                edge_counts.append(edge_store.edge_index.size(1))

        if not go_edge_types:
            print("No GO-GO edges found in the heterodata.")
            return

        total_edges = sum(edge_counts)
        if total_edges == 0:
            # Edge types exist but are all empty; ratios would divide by zero
            print("No GO-GO edges found in the heterodata.")
            return
        edge_ratios = [count / total_edges for count in edge_counts]

        plt.figure(figsize=(10, 6))
        bars = plt.bar(go_edge_types, edge_ratios)
        plt.title("GO Edge Type Ratios")
        plt.xlabel("Edge Types")
        plt.ylabel("Ratio")
        plt.xticks(rotation=45, ha="right")

        # Add value labels on top of each bar
        for bar in bars:
            height = bar.get_height()
            plt.text(
                bar.get_x() + bar.get_width() / 2.0,
                height,
                f"{height:.2f}",
                ha="center",
                va="bottom",
            )

        # Add total edge count as text
        plt.text(
            0.05, 0.95, f"Total edges: {total_edges}", transform=plt.gca().transAxes
        )

        plt.tight_layout()
        plt.savefig(
            self.ppi_folder / "go_edge_type_ratios.png", dpi=300, bbox_inches="tight"
        )
        plt.close()

        print("GO Edge Type Ratios:")
        for edge_type, ratio, count in zip(go_edge_types, edge_ratios, edge_counts):
            print(f"{edge_type}: {ratio:.2f} ({count} edges)")

    def to_device(self, device):
        """
        Move the HeteroData object and the split masks to `device`.

        Parameters:
        device (torch.device): The device to move the data to.

        Returns:
        None
        """
        self.hetero_data = self.hetero_data.to(device)
        self.train_mask = self.train_mask.to(device)
        self.val_mask = self.val_mask.to(device)
        self.test_mask = self.test_mask.to(device)

    def create_data_loaders(self, batch_size=32):
        """
        Create DataLoaders of (edge, label) pairs for the train/val/test splits.

        Each dataset item is a (2,) edge [protein_idx, go_idx] with its float label.
        Only the training loader is shuffled.

        Parameters:
        batch_size (int, optional): The batch size for the data loaders. Default is 32.

        Returns:
        None
        """
        edge_store = self.hetero_data["protein", "associates_with", "go"]
        masks = {"train": self.train_mask, "val": self.val_mask, "test": self.test_mask}

        datasets = {
            split: TensorDataset(
                edge_store.edge_index[:, mask].t(), edge_store.edge_label[mask]
            )
            for split, mask in masks.items()
        }

        self.train_loader = DataLoader(
            datasets["train"], batch_size=batch_size, shuffle=True
        )
        self.val_loader = DataLoader(
            datasets["val"], batch_size=batch_size, shuffle=False
        )
        self.test_loader = DataLoader(
            datasets["test"], batch_size=batch_size, shuffle=False
        )

    def plot_label_ratios(self):
        """
        Plot and print the positive-label ratio in the train/val/test splits.

        An empty split is reported with a ratio of 0.0 instead of raising a
        division-by-zero error. The plot is saved as `label_ratios.png` in `ppi_folder`.

        Parameters:
        None

        Returns:
        None
        """
        edge_labels = self.hetero_data["protein", "associates_with", "go"].edge_label

        train_labels = edge_labels[self.train_mask]
        val_labels = edge_labels[self.val_mask]
        test_labels = edge_labels[self.test_mask]

        def _positive_ratio(labels):
            # Labels are 1/0, so their mean is the positive fraction
            return labels.sum().item() / labels.size(0) if labels.size(0) else 0.0

        train_label_ratio = _positive_ratio(train_labels)
        val_label_ratio = _positive_ratio(val_labels)
        test_label_ratio = _positive_ratio(test_labels)

        ratios = [train_label_ratio, val_label_ratio, test_label_ratio]
        labels = ["Train", "Val", "Test"]

        plt.figure(figsize=(8, 6))
        bars = plt.bar(labels, ratios, color=["#1f77b4", "#ff7f0e", "#2ca02c"])
        plt.ylabel("Positive Label Ratio")
        plt.title("Positive Label Ratios in Train, Validation, and Test Sets")

        # Add value labels on top of each bar
        for bar in bars:
            height = bar.get_height()
            plt.text(
                bar.get_x() + bar.get_width() / 2.0,
                height,
                f"{height:.2f}",
                ha="center",
                va="bottom",
            )

        # Add total counts as text
        plt.text(
            0.05, 0.95, f"Train: {train_labels.size(0)}", transform=plt.gca().transAxes
        )
        plt.text(
            0.05, 0.90, f"Val: {val_labels.size(0)}", transform=plt.gca().transAxes
        )
        plt.text(
            0.05, 0.85, f"Test: {test_labels.size(0)}", transform=plt.gca().transAxes
        )

        plt.ylim(0, 1)  # Set y-axis limit from 0 to 1
        plt.savefig(self.ppi_folder / "label_ratios.png", dpi=300, bbox_inches="tight")
        plt.close()  # Close the plot to free up memory

        print(f"Train positive ratio: {train_label_ratio:.2f}")
        print(f"Validation positive ratio: {val_label_ratio:.2f}")
        print(f"Test positive ratio: {test_label_ratio:.2f}")

    def save_heterodata(self):
        """
        Save the HeteroData object to `ppi_folder / "heterodata.pt"`.

        This is best-effort: a failure is reported with a printed message, not raised.

        Returns:
        None
        """
        save_file_path = self.ppi_folder / "heterodata.pt"
        try:
            torch.save(self.hetero_data, save_file_path)
            print(f"HeteroData object saved to {save_file_path}.")
        except Exception as e:
            print(f"An error occurred while saving the HeteroData object: {e}")

    def check_edge_overlap(self):
        """
        Ensure that there is no overlap between positive and negative edges.

        Returns:
        bool: True if no overlap is found, False otherwise.
        """
        positive_edges_set = set(
            map(tuple, self.edge_index_protgo_positive.t().tolist())
        )
        negative_edges_set = set(
            map(tuple, self.edge_index_protgo_negative.t().tolist())
        )

        # Check for overlap
        overlap = positive_edges_set & negative_edges_set
        if overlap:
            print(f"Found overlap between positive and negative edges: {overlap}")
            return False
        else:
            print("No overlap between positive and negative edges.")
            return True

    def check_data_overlap(self):
        """
        Ensure that there is no leakage of proteins between train, val, and test sets.

        Returns:
        bool: True if no leakage is found, False otherwise.
        """
        train_set = set(self.train_protein_ids)
        val_set = set(self.val_protein_ids)
        test_set = set(self.test_protein_ids)

        # Check for overlaps
        train_val_overlap = train_set & val_set
        train_test_overlap = train_set & test_set
        val_test_overlap = val_set & test_set

        if train_val_overlap or train_test_overlap or val_test_overlap:
            if train_val_overlap:
                print(f"Found overlap between train and val sets: {train_val_overlap}")
            if train_test_overlap:
                print(
                    f"Found overlap between train and test sets: {train_test_overlap}"
                )
            if val_test_overlap:
                print(f"Found overlap between val and test sets: {val_test_overlap}")
            return False
        else:
            print("No leakage of proteins between train, val, and test sets.")
            return True

    def process(
        self,
        associations_df,
        protein_feats,
        go_feats,
        protein_ids,
        go_ids,
        edge_index_ppi,
        go_edges_df,
        train_ratio=0.6,
        ratio=1,
    ):
        """
        Run the full dataset pipeline.

        The pipeline maps the associations, samples negative edges, checks degree
        balance, builds the HeteroData object, splits it by protein, checks for
        leakage, writes the diagnostic plots, and saves the graph.

        Parameters:
        associations_df (pandas.DataFrame): Positive associations (`uniprot_id`, `go_id`).
        protein_feats (torch.Tensor): Protein features.
        go_feats (torch.Tensor): GO term features.
        protein_ids (list): Protein IDs.
        go_ids (list): GO term IDs.
        edge_index_ppi (torch.Tensor): Edge indices for protein-protein interactions.
        go_edges_df (pandas.DataFrame): DataFrame containing GO term relationships and edge types.
        train_ratio (float, optional): Fraction of proteins used for training. Default is 0.6.
        ratio (int | float, optional): Negative-to-positive edge ratio. Default is 1.

        Returns:
        None

        Raises:
        ValueError: If protein leakage is detected between the splits.
        """
        edge_index_protgo_positive = self.convert_associations_to_edge_index(
            associations_df
        )
        number_pos_edges = edge_index_protgo_positive.shape[1]
        num_neg_samples = int(number_pos_edges * ratio)

        if self.structured_negative_sampling_feasible(
            edge_index_protgo_positive, num_nodes_go=len(go_ids)
        ):
            print(f"Generating {num_neg_samples} negative edges...")

            self.generate_negative_edges(
                edge_index=edge_index_protgo_positive,
                num_nodes_protein=len(protein_ids),
                num_nodes_go=len(go_ids),
                num_neg_samples=num_neg_samples,
                weighted=True,
                random_state=42,
            )

            print(
                f"Protein negative associates with GO edge index shape: {self.edge_index_protgo_negative.shape}"
            )
            self.check_balance(
                edge_index_protgo_positive,
                self.edge_index_protgo_negative,
                len(protein_ids),
                len(go_ids),
            )

            self.combine_associations()

            self.create_heterodata(
                protein_feats,
                go_feats,
                protein_ids,
                go_ids,
                edge_index_ppi,
                go_edges_df,
                edge_index_protgo_positive,
            )

            self.split_dataset(train_ratio)

            if not self.check_data_overlap():
                raise ValueError(
                    "Protein leakage detected between train, val, and test sets."
                )

            self.plot_label_ratios()
            self.plot_go_edge_type_ratios()
            self.save_heterodata()

        else:
            print("Negative edges generation isn't feasible.")

    def structured_negative_sampling_feasible(
        self, edge_index_protgo_positive, num_nodes_go=None
    ):
        """
        Check whether every protein can be paired with at least one negative GO term.

        Proteins and GO terms live in separate index spaces. When `num_nodes_go`
        is given, the check is therefore bipartite: each protein's number of
        distinct positive GO partners must be below `num_nodes_go`. Without it,
        the call falls back to PyG's homogeneous-graph check.

        Parameters:
        edge_index_protgo_positive (torch.Tensor): Positive protein->GO edges, shape (2, E).
        num_nodes_go (int, optional): The number of GO nodes. Enables the bipartite check.

        Returns:
        bool: True if structured negative sampling is feasible, False otherwise.
        """
        if num_nodes_go is None:
            return structured_negative_sampling_feasible(edge_index_protgo_positive)

        if edge_index_protgo_positive.numel() == 0:
            return True

        # Count distinct GO partners per protein (duplicate edges counted once)
        unique_edges = torch.unique(edge_index_protgo_positive, dim=1)
        protein_degrees = torch.bincount(unique_edges[0])
        return bool((protein_degrees < num_nodes_go).all())

### HAN Model Class

#### Baseline

In [3]:
class HGNNGAT(torch.nn.Module):
    """
    Baseline heterogeneous GAT that scores protein-GO associations.

    Protein nodes are embedded by two GATConv layers over the
    ("protein", "ppi", "protein") edges, and GO nodes by two GATConv layers over
    the ("go", "relation", "go") edges. The two node types do not exchange
    messages. Each candidate ("protein", "associates_with", "go") edge is scored
    by a linear layer applied to the concatenated protein and GO embeddings.

    Parameters:
    - protein_features (int): Input feature size of protein nodes.
    - go_features (int): Input feature size of GO nodes.
    - out_features (int): Requested embedding size of each GAT layer.
    - dropout (float): Attention dropout passed to every GATConv.
    - heads (int): Number of attention heads per GATConv.
    - concat (bool): If True, head outputs are concatenated; otherwise averaged.
    """

    def __init__(
        self,
        protein_features,
        go_features,
        out_features,
        dropout=0.6,
        heads=8,
        concat=True,
    ):
        super(HGNNGAT, self).__init__()
        per_head_features = out_features // heads if concat else out_features
        # Width actually produced by each GATConv. With concat=True this is
        # per_head_features * heads, which equals out_features only when `heads`
        # divides it. Sizing layer 2 and the scorer from the real width avoids a
        # shape mismatch otherwise, and is identical when it does divide.
        hidden_features = per_head_features * heads if concat else per_head_features

        self.conv1_protein = GATConv(
            protein_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
        )
        self.conv1_go = GATConv(
            go_features, per_head_features, heads=heads, concat=concat, dropout=dropout
        )
        self.conv2_protein = GATConv(
            hidden_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
        )
        self.conv2_go = GATConv(
            hidden_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
        )

        # One logit per protein-GO pair, computed from [protein_emb || go_emb].
        self.protein_to_go = Linear(hidden_features * 2, 1)

    def forward(self, x_dict, edge_index_dict, mask=None):
        """
        Score candidate protein-GO edges.

        Parameters:
        - x_dict (dict): Node features keyed by "protein" and "go".
        - edge_index_dict (dict): Edge indices keyed by ("protein", "ppi", "protein"),
          ("go", "relation", "go") and ("protein", "associates_with", "go").
        - mask (Tensor, optional): Column selector applied to the
          ("protein", "associates_with", "go") edge index before scoring.

        Returns:
        - Tensor: One raw logit per scored edge, shape (num_scored_edges,).
        """
        ppi_edges = edge_index_dict[("protein", "ppi", "protein")]
        go_edges = edge_index_dict[("go", "relation", "go")]

        x_protein = F.elu(self.conv1_protein(x_dict["protein"], ppi_edges))
        x_go = F.elu(self.conv1_go(x_dict["go"], go_edges))

        x_protein = F.elu(self.conv2_protein(x_protein, ppi_edges))
        x_go = F.elu(self.conv2_go(x_go, go_edges))

        edge_index_protein_go = edge_index_dict[("protein", "associates_with", "go")]
        if mask is not None:
            edge_index_protein_go = edge_index_protein_go[:, mask]
        protein_features = x_protein[edge_index_protein_go[0]]
        go_features = x_go[edge_index_protein_go[1]]

        association_features = torch.cat([protein_features, go_features], dim=1)
        # squeeze(-1) (not squeeze()) keeps a 1-D result even for a single edge.
        return self.protein_to_go(association_features).squeeze(-1)

    def get_attention_weights(self, x_dict, edge_index_dict):
        """
        Return the attention outputs of both GAT layers for each node type.

        Switches the model to eval mode (and leaves it there) and runs without
        gradients. Layer-2 attention is computed on ELU-activated layer-1 outputs,
        exactly as in ``forward``.

        Returns:
        - tuple: (attention_protein_1, attention_go_1, attention_protein_2,
          attention_go_2). Each entry is what GATConv returns as its attention
          output when ``return_attention_weights=True``, or None if the layer
          returned a bare tensor.
        """
        self.eval()
        with torch.no_grad():

            def safe_unpack(result):
                # GATConv returns (out, attention) when attention is requested;
                # anything else is treated as a bare output with no attention.
                if isinstance(result, tuple) and len(result) == 2:
                    return result
                return result, None

            ppi_edges = edge_index_dict[("protein", "ppi", "protein")]
            go_edges = edge_index_dict[("go", "relation", "go")]

            x_protein, attention_protein_1 = safe_unpack(
                self.conv1_protein(
                    x_dict["protein"], ppi_edges, return_attention_weights=True
                )
            )
            # Same ELU as forward(); without it layer 2 would attend over features
            # the trained model never sees.
            _, attention_protein_2 = safe_unpack(
                self.conv2_protein(
                    F.elu(x_protein), ppi_edges, return_attention_weights=True
                )
            )

            x_go, attention_go_1 = safe_unpack(
                self.conv1_go(x_dict["go"], go_edges, return_attention_weights=True)
            )
            _, attention_go_2 = safe_unpack(
                self.conv2_go(F.elu(x_go), go_edges, return_attention_weights=True)
            )

            return (
                attention_protein_1,
                attention_go_1,
                attention_protein_2,
                attention_go_2,
            )

#### FocusedGATConv

In [4]:
# Custom GAT layer definition

class FocusedGATConv(GATConv):
    """
    GATConv variant that caps the number of edges around high-degree nodes.

    On every forward pass, node degree is counted on ``edge_index[0]``. Edges
    whose source node has degree <= ``high_degree_threshold`` are kept as-is.
    For each node above the threshold, at most ``sample_size`` randomly chosen
    neighbours are kept. The sub-sampled edge set is then passed to the standard
    GATConv forward.

    Extra parameters (all others are forwarded to GATConv):
    - sample_size (int): Maximum number of neighbours kept per high-degree node.
    - high_degree_threshold (int): Degree above which a node is sub-sampled.
    - directed (bool): If True, only outgoing edges (node -> neighbour) are
      sampled. Otherwise neighbours are collected from both edge directions and
      each sampled neighbour is re-added in both directions.
    """

    def __init__(
        self, *args, sample_size=8, high_degree_threshold=10, directed=False, **kwargs
    ):
        super(FocusedGATConv, self).__init__(*args, **kwargs)
        self.sample_size = sample_size
        self.high_degree_threshold = high_degree_threshold
        self.directed = directed

    def forward(self, x, edge_index, return_attention_weights=False):
        """
        Sub-sample edges around high-degree nodes, then apply GAT attention.

        Parameters:
        - x (Tensor): Node features; ``x.size(0)`` is used as the node count.
        - edge_index (Tensor): Edge index of shape (2, num_edges).
        - return_attention_weights (bool): If True, also return the edges used and
          their attention coefficients.

        Returns:
        - Tensor, or (Tensor, (edge_index, alpha)) when attention is requested.
          ``alpha`` is returned exactly as GATConv computes it, i.e. already
          softmax-normalised per target node and head.
        """
        row, col = edge_index
        deg = degree(row, x.size(0), dtype=x.dtype)

        high_deg_mask = deg > self.high_degree_threshold

        # Keep every edge whose source is a low-degree node.
        low_deg_edges = edge_index[:, ~high_deg_mask[row]]

        high_deg_edges = []
        for node in high_deg_mask.nonzero(as_tuple=True)[0].tolist():
            if self.directed:
                # Directed graphs: only outgoing edges of the hub are candidates.
                neighbors = col[row == node]
            else:
                # Undirected graphs: neighbours seen in either edge direction.
                neighbors = torch.cat([col[row == node], row[col == node]])

            if neighbors.numel() > self.sample_size:
                # Draw the permutation on the same device as the neighbours.
                perm = torch.randperm(neighbors.numel(), device=neighbors.device)
                neighbors = neighbors[perm[: self.sample_size]]

            hub = torch.full_like(neighbors, node)
            if self.directed:
                edges = torch.stack([hub, neighbors], dim=0)
            else:
                edges = torch.stack(
                    [torch.cat([hub, neighbors]), torch.cat([neighbors, hub])], dim=0
                )
            high_deg_edges.append(edges)

        if high_deg_edges:
            sampled_edges = torch.cat([low_deg_edges, *high_deg_edges], dim=1)
        else:
            sampled_edges = low_deg_edges

        if not return_attention_weights:
            return super().forward(x, sampled_edges)

        # GATConv already normalises alpha with a softmax over each target node's
        # incoming edges. Do NOT softmax again: a softmax over dim=1 would mix the
        # attention heads and distort every coefficient.
        out, (used_edges, alpha) = super().forward(
            x, sampled_edges, return_attention_weights=True
        )
        return out, (used_edges, alpha)


class FocusedHGNNGATdirect(torch.nn.Module):
    """
    Two-layer heterogeneous GAT built on FocusedGATConv (neighbour-capped attention).

    Protein layers run on the PPI graph and treat it as undirected. GO layers run
    on the GO relation graph and treat it as directed. Candidate protein-GO edges
    are scored by a linear layer on the concatenated protein and GO embeddings.

    Parameters:
    - protein_features / go_features (int): Input feature sizes per node type.
    - out_features (int): Requested embedding size of each GAT layer.
    - dropout (float), heads (int), concat (bool): Forwarded to every layer.
    - protein_sample_size / go_sample_size (int): Neighbour cap for hub nodes.
    - protein_high_degree_threshold / go_high_degree_threshold (int): Degree above
      which a node's neighbourhood is sub-sampled.
    """

    def __init__(
        self,
        protein_features,
        go_features,
        out_features,
        dropout=0.6,
        heads=8,
        concat=True,
        protein_sample_size=8,
        protein_high_degree_threshold=10,
        go_sample_size=8,
        go_high_degree_threshold=20,
    ):
        super(FocusedHGNNGATdirect, self).__init__()
        per_head_features = out_features // heads if concat else out_features
        # Real output width of each layer. It equals out_features whenever `heads`
        # divides it, and stays correct when it does not.
        hidden_features = per_head_features * heads if concat else per_head_features

        self.conv1_protein = FocusedGATConv(
            protein_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
            sample_size=protein_sample_size,
            high_degree_threshold=protein_high_degree_threshold,
            directed=False,
        )
        self.conv1_go = FocusedGATConv(
            go_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
            sample_size=go_sample_size,
            high_degree_threshold=go_high_degree_threshold,
            directed=True,
        )
        self.conv2_protein = FocusedGATConv(
            hidden_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
            sample_size=protein_sample_size,
            high_degree_threshold=protein_high_degree_threshold,
            directed=False,
        )
        self.conv2_go = FocusedGATConv(
            hidden_features,
            per_head_features,
            heads=heads,
            concat=concat,
            dropout=dropout,
            sample_size=go_sample_size,
            high_degree_threshold=go_high_degree_threshold,
            directed=True,
        )

        # One logit per protein-GO pair, computed from [protein_emb || go_emb].
        self.protein_to_go = Linear(hidden_features * 2, 1)

    def forward(self, x_dict, edge_index_dict, mask=None):
        """
        Score candidate protein-GO edges.

        Parameters:
        - x_dict (dict): Node features keyed by "protein" and "go".
        - edge_index_dict (dict): Edge indices keyed by ("protein", "ppi", "protein"),
          ("go", "relation", "go") and ("protein", "associates_with", "go").
        - mask (Tensor, optional): Column selector applied to the
          ("protein", "associates_with", "go") edge index before scoring.

        Returns:
        - Tensor: One raw logit per scored edge, shape (num_scored_edges,).
        """
        ppi_edges = edge_index_dict[("protein", "ppi", "protein")]
        go_edges = edge_index_dict[("go", "relation", "go")]

        x_protein = F.elu(self.conv1_protein(x_dict["protein"], ppi_edges))
        x_go = F.elu(self.conv1_go(x_dict["go"], go_edges))

        x_protein = F.elu(self.conv2_protein(x_protein, ppi_edges))
        x_go = F.elu(self.conv2_go(x_go, go_edges))

        edge_index_protein_go = edge_index_dict[("protein", "associates_with", "go")]
        if mask is not None:
            edge_index_protein_go = edge_index_protein_go[:, mask]
        protein_features = x_protein[edge_index_protein_go[0]]
        go_features = x_go[edge_index_protein_go[1]]

        association_features = torch.cat([protein_features, go_features], dim=1)
        # squeeze(-1) keeps a 1-D result even when only one edge is scored.
        return self.protein_to_go(association_features).squeeze(-1)

    def get_attention_weights(self, x_dict, edge_index_dict):
        """
        Return (edge_index, alpha) attention pairs for both layers of each node type.

        Switches the model to eval mode (and leaves it there) and runs without
        gradients. Each first layer is run once, and its ELU-activated output
        feeds the second layer. Because FocusedGATConv samples neighbours at
        random, reusing one pass keeps layer-1 and layer-2 attention consistent
        with each other.

        Returns:
        - list: [attention_protein_1, attention_go_1, attention_protein_2,
          attention_go_2], each an (edge_index, alpha) tuple.
        """
        self.eval()
        with torch.no_grad():
            ppi_edges = edge_index_dict[("protein", "ppi", "protein")]
            go_edges = edge_index_dict[("go", "relation", "go")]

            x_protein, attention_protein_1 = self.conv1_protein(
                x_dict["protein"], ppi_edges, return_attention_weights=True
            )
            _, attention_protein_2 = self.conv2_protein(
                F.elu(x_protein), ppi_edges, return_attention_weights=True
            )

            x_go, attention_go_1 = self.conv1_go(
                x_dict["go"], go_edges, return_attention_weights=True
            )
            _, attention_go_2 = self.conv2_go(
                F.elu(x_go), go_edges, return_attention_weights=True
            )

            return [
                attention_protein_1,
                attention_go_1,
                attention_protein_2,
                attention_go_2,
            ]

### Evaluation Functions

In [5]:
def evaluate(
    model,
    x_dict,
    edge_index_dict,
    edge_labels,
    val_mask,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Compute validation loss and AUROC for the edges selected by ``val_mask``.

    The model and every tensor in ``x_dict`` / ``edge_index_dict`` are moved to
    ``device``. Note that the two dicts are updated in place.

    Parameters:
    - model (torch.nn.Module): Called as model(x_dict, edge_index_dict, val_mask);
      returns one logit per selected edge.
    - x_dict (dict): Node features per node type.
    - edge_index_dict (dict): Edge indices per edge type.
    - edge_labels (Tensor): Binary labels for all association edges.
    - val_mask (Tensor): Selector of the validation edges.
    - device (str): Target device. Defaults to CUDA only when it is available,
      matching the other helpers in this file.

    Returns:
    - (float, float): Mean BCE-with-logits loss and ROC AUC on the selected edges.
    """
    model.to(device)
    for key in x_dict:
        x_dict[key] = x_dict[key].to(device)
    for key in edge_index_dict:
        edge_index_dict[key] = edge_index_dict[key].to(device)
    edge_labels = edge_labels.to(device)
    val_mask = val_mask.to(device)

    model.eval()
    with torch.no_grad():
        out = model(x_dict, edge_index_dict, val_mask)
        val_labels = edge_labels[val_mask]
        val_loss = F.binary_cross_entropy_with_logits(out, val_labels)
        predictions = out.sigmoid()
        val_auroc = roc_auc_score(
            val_labels.cpu().numpy(),
            predictions.cpu().numpy(),
        )
    return val_loss.item(), val_auroc


def predict(
    model,
    x_dict,
    edge_index_dict,
    edge_labels,
    mask,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Run the model on the edges selected by ``mask`` and threshold at 0.5.

    The model and every tensor in ``x_dict`` / ``edge_index_dict`` are moved to
    ``device``. Note that the two dicts are updated in place.

    Parameters:
    - model (torch.nn.Module): Called as model(x_dict, edge_index_dict, mask).
    - x_dict (dict): Node features per node type.
    - edge_index_dict (dict): Edge indices per edge type.
    - edge_labels (Tensor): Binary labels for all association edges.
    - mask (Tensor): Selector of the edges to predict.
    - device (str): Target device. Defaults to CUDA only when it is available,
      matching the other helpers in this file.

    Returns:
    - (Tensor, Tensor, float, Tensor): predicted 0/1 labels (float), true labels
      of the selected edges, ROC AUC, and sigmoid probabilities.
    """
    model.to(device)
    for key in x_dict:
        x_dict[key] = x_dict[key].to(device)
    for key in edge_index_dict:
        edge_index_dict[key] = edge_index_dict[key].to(device)
    edge_labels = edge_labels.to(device)
    mask = mask.to(device)

    model.eval()
    with torch.no_grad():
        probabilities = model(x_dict, edge_index_dict, mask).sigmoid()
        true_labels = edge_labels[mask]

        auroc = roc_auc_score(true_labels.cpu().numpy(), probabilities.cpu().numpy())
        predicted_labels = (probabilities > 0.5).float()

    return predicted_labels, true_labels, auroc, probabilities


def evaluate_model_performance(
    model,
    hetero_data,
    associations_df,
    go_id_map,
    save_path,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate a trained model on the test protein-GO edges and save the reports.

    Parameters:
    - model: The trained model to evaluate.
    - hetero_data: HeteroData holding protein/GO features and the
      ("protein", "associates_with", "go") edges with ``edge_label`` and
      ``test_mask``. It is moved to ``device`` in place.
    - associations_df: DataFrame with "uniprot_idx" and "uniprot_id" columns, used
      to map protein node indices back to UniProt IDs.
    - go_id_map: Mapping of GO term IDs to node indices. It is inverted here to
      label the results; the caller's dict is not modified.
    - save_path: Directory for test_results.csv, classification_report.txt and
      confusion_matrix.png.
    - device: Device for model and data (CUDA if available, otherwise CPU).

    Returns:
    - results_df: One row per test edge, with the columns "Protein ID", "GO ID",
      "Ground Truth Label", "Predicted Label" and "Probability".
    """
    save_path = Path(save_path)

    # HeteroData.x_dict / edge_index_dict build a fresh dict on each access, so
    # assigning into them moves nothing. HeteroData.to() moves every stored tensor
    # (features, edge indices, labels, masks) in place.
    model.to(device)
    hetero_data = hetero_data.to(device)
    association_store = hetero_data["protein", "associates_with", "go"]
    test_mask = association_store.test_mask

    predicted_labels, test_labels, test_auroc, probabilities = predict(
        model,
        hetero_data.x_dict,
        hetero_data.edge_index_dict,
        association_store.edge_label,
        test_mask,
        device=device,
    )

    # Protein / GO node indices of the test edges
    test_edge_indices = association_store.edge_index[:, test_mask]
    protein_indices = test_edge_indices[0].cpu().numpy()
    go_indices = test_edge_indices[1].cpu().numpy()

    # Map node indices back to external IDs
    protein_id_map = dict(
        zip(
            associations_df["uniprot_idx"].astype(int),
            associations_df["uniprot_id"],
        )
    )
    go_index_to_id = {idx: go_id for go_id, idx in go_id_map.items()}

    # Bulk tolist() avoids one device sync per element.
    results_df = pd.DataFrame(
        {
            "Protein ID": [protein_id_map.get(idx, "Unknown") for idx in protein_indices],
            "GO ID": [go_index_to_id.get(idx, "Unknown") for idx in go_indices],
            "Ground Truth Label": test_labels.cpu().tolist(),
            "Predicted Label": predicted_labels.cpu().tolist(),
            "Probability": probabilities.cpu().tolist(),
        }
    )
    results_df.to_csv(save_path / "test_results.csv", index=False)

    print(f"AUROC test: {test_auroc:.2f}")

    y_true = test_labels.cpu().numpy()
    y_pred = predicted_labels.cpu().numpy()

    # Save classification report and print
    class_report = classification_report(y_true, y_pred, target_names=["0", "1"])
    with open(save_path / "classification_report.txt", "w") as file:
        file.write(class_report)

    print("\nClassification Report:")
    print(class_report)

    ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        normalize="true",
        cmap="Blues",
        values_format=".1%",
    )
    plt.savefig(save_path / "confusion_matrix.png", dpi=300)
    plt.show()

    return results_df

### Training Function

In [6]:
def train(
    model,
    dataset,
    config,
    run,
    save_path=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_scheduler=True,
    accumulation_steps=1,
):
    """
    Train a protein-GO association model with early stopping on validation loss.

    Parameters:
    - model (torch.nn.Module): Called as model(x_dict, edge_index_dict); returns
      one logit per ("protein", "associates_with", "go") edge in the batch.
    - dataset: Object providing to_device(), create_data_loaders(batch_size=...),
      train_loader / val_loader yielding (edge_index, edge_label) batches, and
      hetero_data.
    - config (dict): Needs "lr", "batch_size", "epochs" and
      "early_stopping_patience", plus "scheduler_patience" when use_scheduler.
    - run: wandb run handle. It is accepted for API compatibility; metrics are
      logged via wandb.log.
    - save_path (str | Path | None): Directory where best_model_weights.pt is
      written whenever validation loss improves. If None, no checkpoint is saved.
    - device (str): Device for model and data.
    - use_scheduler (bool): Cut the LR by 10x when validation loss plateaus.
    - accumulation_steps (int): Number of batches accumulated per optimizer step.

    Returns:
    - (model, best_val_loss): the model after the last epoch run (NOT reloaded
      with the best checkpoint) and the lowest mean validation loss seen.
    """
    model.to(device)
    dataset.to_device(device)
    dataset.create_data_loaders(batch_size=config["batch_size"])

    checkpoint_path = (
        Path(save_path) / "best_model_weights.pt" if save_path is not None else None
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    scheduler = (
        torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            "min",
            patience=config["scheduler_patience"],
            factor=0.1,
            verbose=True,
        )
        if use_scheduler
        else None
    )

    # Node features and the PPI / GO-relation graphs are the same for every
    # batch; only the scored protein-GO edges change.
    hetero_data = dataset.hetero_data
    x_dict = {"protein": hetero_data["protein"].x, "go": hetero_data["go"].x}
    static_edge_index_dict = {
        ("protein", "ppi", "protein"): hetero_data["protein", "ppi", "protein"].edge_index,
        ("go", "relation", "go"): hetero_data["go", "relation", "go"].edge_index,
    }

    def batch_edge_index_dict(edge_index):
        # Loader batches are transposed into the edge_index layout the model
        # indexes as [0] / [1] (presumably (num_edges, 2) -> (2, num_edges)).
        return {
            **static_edge_index_dict,
            ("protein", "associates_with", "go"): edge_index.t(),
        }

    best_val_loss = float("inf")
    early_stopping_counter = 0

    for epoch in tqdm(
        range(1, config["epochs"] + 1), desc="Training...", total=config["epochs"]
    ):
        model.train()
        train_loss = 0
        train_predictions, train_labels = [], []
        num_train_batches = len(dataset.train_loader)
        optimizer.zero_grad()

        for i, (edge_index, edge_label) in enumerate(dataset.train_loader):
            edge_index, edge_label = edge_index.to(device), edge_label.to(device)
            out = model(x_dict, batch_edge_index_dict(edge_index))
            loss = (
                F.binary_cross_entropy_with_logits(out, edge_label) / accumulation_steps
            )
            loss.backward()

            # Step every `accumulation_steps` batches AND after the last batch, so
            # a partial accumulation window never leaks into the next epoch.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_train_batches:
                optimizer.step()
                optimizer.zero_grad()

            train_loss += loss.item() * accumulation_steps
            train_predictions.append(out.sigmoid().detach().cpu().numpy())
            train_labels.append(edge_label.cpu().numpy())

        train_loss /= num_train_batches
        train_auroc = roc_auc_score(
            np.concatenate(train_labels), np.concatenate(train_predictions)
        )
        wandb.log({"train_loss": train_loss, "train_auroc": train_auroc})

        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_predictions, val_labels = [], []

            for edge_index, edge_label in dataset.val_loader:
                edge_index, edge_label = edge_index.to(device), edge_label.to(device)
                out = model(x_dict, batch_edge_index_dict(edge_index))
                loss = F.binary_cross_entropy_with_logits(out, edge_label)
                val_loss += loss.item()
                val_predictions.append(out.sigmoid().cpu().numpy())
                val_labels.append(edge_label.cpu().numpy())

            val_loss /= len(dataset.val_loader)
            val_auroc = roc_auc_score(
                np.concatenate(val_labels), np.concatenate(val_predictions)
            )
            wandb.log({"val_loss": val_loss, "val_auroc": val_auroc})

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                if checkpoint_path is not None:
                    torch.save(model.state_dict(), checkpoint_path)
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1

            if early_stopping_counter >= config["early_stopping_patience"]:
                print(f"Early stopping triggered after {epoch} epochs.")
                break

        if scheduler:
            scheduler.step(val_loss)

        if epoch % 10 == 0:
            print(
                f"Epoch {epoch}: Train Loss: {train_loss:.3f}, Train AUROC: {train_auroc:.3f}, "
                f"Val Loss: {val_loss:.3f}, Val AUROC: {val_auroc:.3f}"
            )
    return model, best_val_loss

### Explainability Class

In [7]:
class Explainability:
    """
    Attention-based explainability helpers for the protein-GO GAT models.

    Wraps a trained model and its HeteroData graph. It ranks nodes by aggregated
    attention, compares their degrees, and plots how top GO terms relate to the
    rest of the ontology. Several analysis methods are still placeholders.
    """

    def __init__(self, model, hetero_data, save_path):
        """
        Initialize the ExplainabilityModule with the given model, heterogeneous data, and save path.

        Parameters:
        - model (torch.nn.Module): The graph neural network model to be explained.
        - hetero_data (torch_geometric.data.HeteroData): The heterogeneous graph data containing nodes and edges.
        - save_path (str): The path where the generated visualizations will be saved (created if missing).

        Returns:
        None
        """
        self.model = model
        self.hetero_data = hetero_data
        self.save_path = Path(save_path)
        self.save_path.mkdir(parents=True, exist_ok=True)
        # Node index -> external ID per node type (avoids shadowing builtin `id`).
        self.id_mappings = {
            "protein": dict(enumerate(hetero_data["protein"].id)),
            "go": dict(enumerate(hetero_data["go"].id)),
        }

    def aggregate_attention_weights(self, attn_weights, edge_index, num_nodes):
        """
        Sum per-edge attention onto both endpoints of each edge, then halve.

        Parameters:
        - attn_weights (Tensor): Per-edge attention, shape (num_edges, heads).
        - edge_index (Tensor): (2, num_edges) source/target indices.
        - num_nodes (int): Number of nodes of this type.

        Returns:
        - Tensor: (num_nodes, heads) aggregated attention per node.
        """
        node_attention = torch.zeros(
            (num_nodes, attn_weights.size(1)), device=attn_weights.device
        )
        src, dst = edge_index
        node_attention.index_add_(0, src, attn_weights)
        node_attention.index_add_(0, dst, attn_weights)
        return node_attention / 2

    def visualize_node_importance(
        self, attn_weights, node_type, id_mapping, conv_layer, top_n=10
    ):
        """
        Return the IDs and head-averaged scores of the ``top_n`` highest-attention nodes.

        ``node_type`` and ``conv_layer`` are accepted for API compatibility but are
        not used. Non-finite scores are treated as 0. Returns ([], None) when the
        input is empty or every score is zero.
        """
        if attn_weights is None or not attn_weights.numel():
            print("Attention weights are empty or invalid.")
            return [], None

        attn_weights = attn_weights.mean(dim=1).cpu().detach().numpy()
        valid_weights = np.where(np.isfinite(attn_weights), attn_weights, 0)

        if not np.any(valid_weights):
            print("All attention weights are zero or invalid.")
            return [], None

        top_indices = np.argsort(valid_weights)[-top_n:][::-1]
        top_ids = [id_mapping[idx] for idx in top_indices]

        return top_ids, valid_weights[top_indices]

    def plot_node_importance(self, top_n=25):
        """
        Plot the top-``top_n`` nodes by aggregated attention for every layer.

        Expects ``model.get_attention_weights`` to return per-layer
        (edge_index, alpha) pairs ordered [protein_1, go_1, protein_2, go_2, ...].
        Writes node_importance_all_layers.png to ``save_path``.

        Returns:
        - (dict, list): {"protein_<layer>"/"go_<layer>": top IDs} and [plotly figure].
        """
        attention_weights = self.model.get_attention_weights(
            self.hetero_data.x_dict, self.hetero_data.edge_index_dict
        )
        results = {}
        plots = []

        # Display labels; str.capitalize() would turn "go" into "Go".
        display_names = {"protein": "Protein", "go": "GO"}

        num_layers = len(attention_weights) // 2
        fig = make_subplots(
            rows=2,
            cols=num_layers,
            subplot_titles=[
                f"{label} Layer {i+1}"
                for label in display_names.values()
                for i in range(num_layers)
            ],
        )

        for node_type, attn_weights_list, edge_indices_list in [
            (
                "protein",
                [aw[1] for aw in attention_weights[::2]],
                [aw[0] for aw in attention_weights[::2]],
            ),
            (
                "go",
                [aw[1] for aw in attention_weights[1::2]],
                [aw[0] for aw in attention_weights[1::2]],
            ),
        ]:
            label = display_names[node_type]
            for i, (attn, edges) in enumerate(
                zip(attn_weights_list, edge_indices_list), 1
            ):
                attention = self.aggregate_attention_weights(
                    attn, edges, self.hetero_data[node_type].x.size(0)
                )
                top_ids, top_weights = self.visualize_node_importance(
                    attention,
                    label,
                    self.id_mappings[node_type],
                    conv_layer=i,
                    top_n=top_n,
                )

                if top_ids and top_weights is not None:
                    row = 1 if node_type == "protein" else 2
                    fig.add_trace(go.Bar(x=top_ids, y=top_weights), row=row, col=i)
                    fig.update_xaxes(
                        title_text=f"Top {top_n} {label} Node IDs",
                        row=row,
                        col=i,
                    )
                    fig.update_yaxes(title_text="Importance", row=row, col=i)

                results[f"{node_type}_{i}"] = top_ids

        fig.update_layout(
            height=800,
            width=300 * num_layers,
            title_text="Node Importance Across Layers",
        )
        plots.append(fig)

        fig.write_image(str(self.save_path / "node_importance_all_layers.png"))

        return results, plots

    def compare_degrees(self, important_nodes, node_type, edge_index):
        """
        Bar-plot the degree (counted on ``edge_index[0]``) of the given important nodes.

        Parameters:
        - important_nodes (list): External node IDs of ``node_type``.
        - node_type (str): "protein" or "go" (key of ``id_mappings``).
        - edge_index (Tensor): (2, num_edges) edges of that node type.

        Returns:
        - (DataFrame, Figure): Node ID / Degree table and the matplotlib figure.
        """
        index_mapping = {
            node_id: idx for idx, node_id in self.id_mappings[node_type].items()
        }
        important_indices = [index_mapping[node_id] for node_id in important_nodes]
        row, _ = edge_index
        # Size degrees by this type's node count rather than row.max() + 1, so nodes
        # with no outgoing edges get degree 0 instead of raising IndexError.
        num_nodes = len(self.id_mappings[node_type])
        if row.numel():
            num_nodes = max(num_nodes, int(row.max()) + 1)
        node_degrees = degree(row, num_nodes).cpu().numpy()
        important_degrees = [node_degrees[idx] for idx in important_indices]

        comparison_df = pd.DataFrame(
            {"Node ID": important_nodes, "Degree": important_degrees}
        )

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.bar(comparison_df["Node ID"], comparison_df["Degree"])
        ax.set_xlabel(f"{node_type} Node IDs")
        ax.set_ylabel("Degree")
        ax.set_title(f"Degree Comparison of Important {node_type} Nodes")
        plt.xticks(rotation=90)

        plt.tight_layout()
        plt.savefig(self.save_path / f"degree_comparison_attention_{node_type}.png")
        plt.show()
        plt.close()

        return comparison_df, fig

    @staticmethod
    def plot_tfidf(df, text_col, label, top_n=25):
        """
        Return the ``top_n`` terms (uni/bi-grams) with the highest max TF-IDF in ``df[text_col]``.

        NOTE(review): the bar chart is built and then closed immediately; it is
        neither shown nor saved.
        """
        vectorizer = TfidfVectorizer(max_df=0.85, min_df=2, ngram_range=(1, 2))
        tf_idf_matrix = vectorizer.fit_transform(df[text_col])
        max_scores = np.max(tf_idf_matrix, axis=0).toarray().flatten()
        sorted_indices = np.argsort(max_scores)[::-1][:top_n]
        sorted_features = vectorizer.get_feature_names_out()[sorted_indices]
        sorted_scores = max_scores[sorted_indices]

        plt.figure(figsize=(12, 8))
        plt.barh(sorted_features, sorted_scores, color="green")
        plt.xlabel("Max TF-IDF Score")
        plt.title(f"Top {top_n} Terms with Highest TF-IDF Scores in {label}")
        plt.gca().invert_yaxis()
        plt.close()

        return sorted_features.tolist()

    @staticmethod
    def plot_kde_with_overlay(df, top_go_df, information_accretion_col, label):
        """Show a KDE of the IA column over all GO terms, with the top terms as a rug."""
        plt.figure(figsize=(12, 6))
        sns.kdeplot(
            df[information_accretion_col], fill=True, alpha=0.5, label="All GO Terms"
        )
        sns.rugplot(
            top_go_df[information_accretion_col],
            height=0.1,
            color="red",
            label="Top GO Terms",
        )
        plt.title(
            f"Distribution of 'ia' across all GO terms with overlay of top GO terms - {label}"
        )
        plt.xlabel("ia")
        plt.ylabel("Density")
        plt.legend()
        plt.show()
        plt.close()

    @staticmethod
    def plot_heatmap(
        df,
        top_go_ids_conv1,
        top_go_ids_conv2,
        information_accretion_col,
        top_n=25,
        random_state=None,
    ):
        """
        Build an IA heatmap for random GO terms vs. the Conv1 / Conv2 top GO terms.

        The random sample has as many rows as the Conv1 top set (capped at the
        number of available non-top terms). Each frame is labelled with its own
        category, so sets of different sizes are handled correctly.

        NOTE(review): the heatmap is closed immediately; it is neither shown nor saved.
        """
        top_go_df_conv1 = df[df["go_id"].isin(top_go_ids_conv1)]
        top_go_df_conv2 = df[df["go_id"].isin(top_go_ids_conv2)]
        num_top_terms = len(top_go_df_conv1)

        random_pool = df[~df["go_id"].isin(list(top_go_ids_conv1) + list(top_go_ids_conv2))]
        random_go_df = random_pool.sample(
            n=min(num_top_terms, len(random_pool)), random_state=random_state
        )
        combined_df = pd.concat(
            [
                random_go_df.assign(Category="Random GO Terms"),
                top_go_df_conv1.assign(Category="Top GO Terms Conv1"),
                top_go_df_conv2.assign(Category="Top GO Terms Conv2"),
            ]
        )

        plt.figure(figsize=(12, 0.8 * top_n))
        ia_matrix = combined_df.pivot_table(
            index="go_id", columns="Category", values=information_accretion_col
        )
        sns.heatmap(
            ia_matrix, annot=True, cmap="viridis", cbar_kws={"label": "IA Value"}
        )
        plt.title(
            "Heatmap of IA Values for Random GO Terms, Conv1, and Conv2 Top GO Terms"
        )
        plt.xlabel("Category")
        plt.ylabel("GO Term ID")
        plt.close()

    def analyze_top_go(
        self,
        df,
        results,
        information_accretion_col="ia",
        text_col="definition",
        top_n=25,
        random_state=None,
    ):
        """
        Compare the Conv1 / Conv2 top GO terms (from ``results["go_1"]`` / ``results["go_2"]``)
        with all GO terms in ``df``, via IA distributions, TF-IDF terms and a heatmap.

        Returns:
        - (list, list, DataFrame, DataFrame): top TF-IDF tokens for Conv1 and Conv2,
          and the corresponding rows of ``df``.
        """
        top_go_df_conv1 = df[df["go_id"].isin(results["go_1"])]
        self.plot_kde_with_overlay(
            df, top_go_df_conv1, information_accretion_col, label="Conv 1"
        )
        top_tokens_conv1 = self.plot_tfidf(
            top_go_df_conv1, text_col, label="Top GO Terms Conv1", top_n=top_n
        )

        top_go_df_conv2 = df[df["go_id"].isin(results["go_2"])]
        self.plot_kde_with_overlay(
            df, top_go_df_conv2, information_accretion_col, label="Conv 2"
        )
        top_tokens_conv2 = self.plot_tfidf(
            top_go_df_conv2, text_col, label="Top GO Terms Conv2", top_n=top_n
        )

        self.plot_heatmap(
            df,
            results["go_1"],
            results["go_2"],
            information_accretion_col,
            top_n=top_n,
            random_state=random_state,
        )

        return top_tokens_conv1, top_tokens_conv2, top_go_df_conv1, top_go_df_conv2

    def analyze_feature_importance(self):
        # TODO: implement SHAP analysis
        pass

    def visualize_attention_flow(self):
        # TODO: implement attention flow visualization
        pass

    def explain_prediction(self, protein_id):
        # TODO: implement individual prediction explanation
        pass

    def analyze_errors(self):
        # TODO: implement error analysis
        pass

    def test_robustness(self):
        # TODO: implement robustness analysis
        pass

    def layer_wise_relevance_propagation(self):
        # TODO: implement LRP
        pass

    def analyze_graph_structure(self):
        # TODO: implement graph structure analysis
        pass

    def ablation_study(self):
        # TODO: implement ablation study functionality
        pass

    def concept_activation_vectors(self):
        # TODO: implement TCAV
        pass

    def generate_counterfactuals(self):
        # TODO: implement counterfactual generation
        pass

## Load Aid2GO Training Data

In [8]:
# Load Prot2GO Dataset.
# The file holds a pickled dataset object (not just tensors), so weights_only=False
# is required: torch>=2.6 defaults to weights_only=True and would refuse to
# unpickle it. Unpickling can execute arbitrary code -- only load trusted files.
filepath = base_dir / "data/ppi/protgo_dataset.pt"
protgo_dataset = torch.load(filepath, weights_only=False)
protgo_dataset.to_device(device)
protgo_dataset

# Define output dim for models
out_features = 128

## Baseline

### Train Model

In [None]:
# Hyperparameters for the baseline GAT run
config = dict(
    lr=0.001,
    epochs=200,
    batch_size=100_000,
    dropout_rate=0.5,
    heads=8,
    out_features=out_features,
    scheduler_patience=10,
    early_stopping_patience=20,
    balance="1:2",
)

# Baseline model, with input widths taken from the dataset's node features
baseline_graph = protgo_dataset.hetero_data
model = HGNNGAT(
    protein_features=baseline_graph["protein"].x.size(1),
    go_features=baseline_graph["go"].x.size(1),
    out_features=config["out_features"],
    dropout=config["dropout_rate"],
    heads=config["heads"],
)
model

In [None]:
# Initialize wandb logging
run = wandb.init(
    project="prot2go",
    notes="Baseline GAT Model",
    tags=[
        "PPI data: human",
        "balance 1:2",
    ],
    config=config,
    # mode="disabled"
)
wandb.watch(model, log="all", log_freq=10)

# Train and save best model. Single quotes inside the f-string keep it valid on
# Python < 3.12, where reusing the enclosing quote character is a SyntaxError.
save_path = output_dir / f"models/baseline_{config['out_features']}_{config['balance']}"
save_path.mkdir(parents=True, exist_ok=True)

try:
    train(
        model=model,
        dataset=protgo_dataset,
        config=config,
        run=run,
        save_path=save_path,
        use_scheduler=True,
    )
finally:
    # Finish the run even if training raises or is interrupted, so it is not
    # left open.
    wandb.finish()

### Evaluate Model Performance

In [None]:
# Load model

# Instantiate the model
model = HGNNGAT(
    protein_features=protgo_dataset.hetero_data["protein"].x.size(1),
    go_features=protgo_dataset.hetero_data["go"].x.size(1),
    out_features=config["out_features"],
    dropout=config["dropout_rate"],
    heads=config["heads"],
)

# Load best model's state_dict.
# - The f-string uses single quotes inside so it stays valid before Python 3.12.
# - `map_location` lets weights saved from a GPU run load on any machine.
# - The model is then moved to `device`, the same device `protgo_dataset` was
#   moved to when it was loaded.
save_path = output_dir / f"models/baseline_{config['out_features']}_{config['balance']}"
filepath = save_path / "best_model_weights.pt"
state_dict = torch.load(filepath, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model

In [None]:
# Evaluate the loaded baseline model on the hetero graph. The call also passes
# the dataset's `associations_combined_df`, its `go_ids_map`, and the model
# directory in `save_path`.
# NOTE(review): confirm what evaluate_model_performance writes under save_path.
# The returned DataFrame is read by the next cell for its "Protein ID",
# "GO ID" and "Probability" columns.
results_df = evaluate_model_performance(
    model,
    protgo_dataset.hetero_data,
    protgo_dataset.associations_combined_df,
    protgo_dataset.go_ids_map,
    save_path,
)

In [None]:
# Get predictions in CAFA format for evaluation

# Create folder to save predictions. A dedicated name is used instead of
# reassigning `save_path`. That variable still points at the model directory
# used by the evaluation above and handed to the explainability cell below.
predictions_dir = output_dir / "predictions"
predictions_dir.mkdir(parents=True, exist_ok=True)
filename = f"predictions_baseline_{out_features}.tsv"

# Filter columns and round scores (cafaeval accepts only 3 decimals).
# `.copy()` makes an independent frame, so the assignment below is unambiguous
# and does not trigger pandas' SettingWithCopyWarning. The notebook-wide
# `warnings.filterwarnings("ignore")` would otherwise hide that warning.
predictions = results_df[["Protein ID", "GO ID", "Probability"]].copy()
predictions["Probability"] = predictions["Probability"].round(3)

# Sort by protein identifiers (optional) with a fresh 0..n-1 index, then save
# as a headerless, index-free TSV.
predictions = predictions.sort_values(by="Protein ID", ignore_index=True)
predictions.to_csv(
    predictions_dir / filename,
    sep="\t",
    header=False,
    index=False,
    encoding="utf-8",
)
predictions

### Explain Model Predictions

In [13]:
explainer = Explainability(
    model=model, hetero_data=protgo_dataset.hetero_data, save_path=save_path
)

results, plots = explainer.plot_node_importance()

# Render the returned figures, as the focused-attention section does with the
# output of the same method; otherwise they are computed but never displayed.
for plot in plots:
    plot.show()


In [None]:
# Compare degrees of the protein nodes in results["protein_1"] and
# results["protein_2"] over the protein-protein ("ppi") edges. Distinct names
# keep both comparison tables; previously the first one was overwritten
# immediately. This matches the naming used in the focused GO comparison cell.
comparison_df_1, plot = explainer.compare_degrees(
    results["protein_1"],
    "protein",
    protgo_dataset.hetero_data["protein", "ppi", "protein"].edge_index,
)

comparison_df_2, plot = explainer.compare_degrees(
    results["protein_2"],
    "protein",
    protgo_dataset.hetero_data["protein", "ppi", "protein"].edge_index,
)

In [None]:
# Compare degrees of the GO nodes in results["go_1"] and results["go_2"] over
# the GO-GO ("relation") edges. Distinct names keep both comparison tables;
# previously the first one was overwritten immediately.
comparison_df_1, plot = explainer.compare_degrees(
    results["go_1"],
    "go",
    protgo_dataset.hetero_data["go", "relation", "go"].edge_index,
)

comparison_df_2, plot = explainer.compare_degrees(
    results["go_2"],
    "go",
    protgo_dataset.hetero_data["go", "relation", "go"].edge_index,
)

In [None]:
# Load the GO term table (read for its "ia" and "definition" columns below)
go_df = pd.read_csv(base_dir.joinpath("data", "go", "go-basic.csv"))

# Analyze the top-25 GO terms from `results`, weighting by the "ia" column and
# mining the "definition" text
(
    top_tokens_conv1,
    top_tokens_conv2,
    top_go_df_conv1,
    top_go_df_conv2,
) = explainer.analyze_top_go(
    go_df,
    results,
    text_col="definition",
    information_accretion_col="ia",
    random_state=None,
    top_n=25,
)

## Focused Attention with Static Node Sampling

### Train Model

In [17]:
# Hyperparameters for the focused-attention run, including the per-node-type
# sample sizes and high-degree thresholds passed to the model below
config = dict(
    lr=0.001,
    epochs=200,
    batch_size=100000,
    dropout_rate=0.5,
    heads=8,
    out_features=out_features,
    scheduler_patience=10,
    early_stopping_patience=25,
    protein_sample_size=10,
    protein_high_degree_threshold=10,
    go_sample_size=10,
    go_high_degree_threshold=20,
    balance="1:2",
)

# Model instance
model = FocusedHGNNGATdirect(
    go_high_degree_threshold=config["go_high_degree_threshold"],
    go_sample_size=config["go_sample_size"],
    protein_high_degree_threshold=config["protein_high_degree_threshold"],
    protein_sample_size=config["protein_sample_size"],
    concat=True,
    heads=config["heads"],
    dropout=config["dropout_rate"],
    out_features=config["out_features"],
    go_features=protgo_dataset.hetero_data["go"].x.size(1),
    protein_features=protgo_dataset.hetero_data["protein"].x.size(1),
)

In [None]:
# Initialize wandb logging
run = wandb.init(
    project="prot2go",
    notes="Focused GAT Model",
    tags=[
        "human PPI",
        "balance 1:2",
        "go def embeddings",
        "biobert model",
    ],
    config=config,
    # mode="disabled"
)
wandb.watch(model, log="all", log_freq=10)

# Train and save best model. Single quotes inside the f-string keep it valid on
# Python < 3.12, where reusing the enclosing quote character is a SyntaxError.
save_path = output_dir / f"models/focused_{config['out_features']}_{config['balance']}"
save_path.mkdir(parents=True, exist_ok=True)

try:
    train(
        model=model,
        dataset=protgo_dataset,
        config=config,
        run=run,
        save_path=save_path,
        use_scheduler=True,
    )
finally:
    # Finish the run even if training raises or is interrupted, so it is not
    # left open.
    wandb.finish()

### Evaluate Model Performance

In [None]:
# Load model

# Model instance
model = FocusedHGNNGATdirect(
    protein_features=protgo_dataset.hetero_data["protein"].x.size(1),
    go_features=protgo_dataset.hetero_data["go"].x.size(1),
    out_features=config["out_features"],
    dropout=config["dropout_rate"],
    heads=config["heads"],
    concat=True,
    protein_sample_size=config["protein_sample_size"],
    protein_high_degree_threshold=config["protein_high_degree_threshold"],
    go_sample_size=config["go_sample_size"],
    go_high_degree_threshold=config["go_high_degree_threshold"],
)

# Load best model's state_dict.
# - The f-string uses single quotes inside so it stays valid before Python 3.12.
# - `map_location` lets weights saved from a GPU run load on any machine.
# - The model is then moved to `device`, the same device `protgo_dataset` was
#   moved to when it was loaded.
save_path = output_dir / f"models/focused_{config['out_features']}_{config['balance']}"
filepath = save_path / "best_model_weights.pt"
state_dict = torch.load(filepath, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model

In [None]:
# Evaluate the loaded focused-attention model on the hetero graph. The call
# also passes the dataset's `associations_combined_df`, its `go_ids_map`, and
# the model directory in `save_path`.
# NOTE(review): confirm what evaluate_model_performance writes under save_path.
# The returned DataFrame is read by the next cell for its "Protein ID",
# "GO ID" and "Probability" columns.
results_df = evaluate_model_performance(
    model,
    protgo_dataset.hetero_data,
    protgo_dataset.associations_combined_df,
    protgo_dataset.go_ids_map,
    save_path,
)

In [None]:
# Get predictions in CAFA format for evaluation

# Create folder to save predictions. A dedicated name is used instead of
# reassigning `save_path`. That variable still points at the model directory
# used by the evaluation above and handed to the explainability cell below.
predictions_dir = output_dir / "predictions"
predictions_dir.mkdir(parents=True, exist_ok=True)
filename = f"predictions_focused-attn_{out_features}.tsv"

# Filter columns and round scores (cafaeval accepts only 3 decimals).
# `.copy()` makes an independent frame, so the assignment below is unambiguous
# and does not trigger pandas' SettingWithCopyWarning. The notebook-wide
# `warnings.filterwarnings("ignore")` would otherwise hide that warning.
predictions = results_df[["Protein ID", "GO ID", "Probability"]].copy()
predictions["Probability"] = predictions["Probability"].round(3)

# Sort by protein identifiers (optional) with a fresh 0..n-1 index, then save
# as a headerless, index-free TSV.
predictions = predictions.sort_values(by="Protein ID", ignore_index=True)
predictions.to_csv(
    predictions_dir / filename,
    sep="\t",
    header=False,
    index=False,
    encoding="utf-8",
)
predictions

### Explain Model Predictions

In [None]:
# Build the explainability helper for the focused model and display it
explainer = Explainability(
    save_path=save_path,
    hetero_data=protgo_dataset.hetero_data,
    model=model,
)
explainer

In [None]:
# Compute node-importance results with top_n=25. `results` is reused by the
# degree-comparison and GO-analysis cells below; every returned figure is
# rendered explicitly.
results, plots = explainer.plot_node_importance(top_n=25)

for plot in plots:
    plot.show()

In [None]:
# Compare degrees of the protein nodes in results["protein_1"] and
# results["protein_2"] over the protein-protein ("ppi") edges. Distinct names
# keep both comparison tables; previously the first one was overwritten
# immediately. This matches the GO comparison cell that follows.
comparison_df_1, plot = explainer.compare_degrees(
    results["protein_1"],
    "protein",
    protgo_dataset.hetero_data["protein", "ppi", "protein"].edge_index,
)

comparison_df_2, plot = explainer.compare_degrees(
    results["protein_2"],
    "protein",
    protgo_dataset.hetero_data["protein", "ppi", "protein"].edge_index,
)

In [None]:
# Compare degrees of the GO nodes in results["go_1"] and results["go_2"] over
# the GO-GO ("relation") edges, keeping each comparison table under its own
# name (`plot` only retains the second figure).
comparison_df_1, plot = explainer.compare_degrees(
    results["go_1"],
    "go",
    protgo_dataset.hetero_data["go", "relation", "go"].edge_index,
)

comparison_df_2, plot = explainer.compare_degrees(
    results["go_2"],
    "go",
    protgo_dataset.hetero_data["go", "relation", "go"].edge_index,
)

In [None]:
# Reload the GO term table so this section does not depend on `go_df` having
# been created by the baseline section's explainability cell (running only
# this section would otherwise raise NameError). `pd` is already imported at
# the top of the notebook, so no local re-import is needed.
go_df = pd.read_csv(base_dir / "data/go/go-basic.csv")

# Analyze the top-25 GO terms from `results`, weighting by the "ia" column and
# mining the "definition" text
top_tokens_conv1, top_tokens_conv2, top_go_df_conv1, top_go_df_conv2 = (
    explainer.analyze_top_go(
        go_df,
        results,
        information_accretion_col="ia",
        text_col="definition",
        top_n=25,
        random_state=None,
    )
)