In [None]:
"""
Graph: 
The ogbn-proteins dataset is an undirected, weighted, and typed 
(according to species) graph. Nodes represent proteins, and edges 
indicate different types of biologically meaningful associations 
between proteins, e.g., physical interactions, co-expression or 
homology [1,2]. All edges come with 8-dimensional features, where 
each dimension represents the approximate confidence of a single 
association type and takes values between 0 and 1 (the larger the 
value is, the more confident we are about the association). The 
proteins come from 8 species.

Prediction task: 
The task is to predict the presence of protein functions in a 
multi-label binary classification setup, where there are 112 kinds 
of labels to predict in total. The performance is measured by the 
average of ROC-AUC scores across the 112 tasks.

Dataset splitting: 
We split the protein nodes into training/validation/test sets 
according to the species which the proteins come from. This enables 
the evaluation of the generalization performance of the model across 
different species.
"""

"""
edges:
undirected
weighted
typed 
8-dimensional features
"""

In [None]:
import torch

from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import Data

In [None]:
# Load the ogbn-proteins node-property-prediction dataset.
proteins_dataset = PygNodePropPredDataset(name="ogbn-proteins")

# Node-index splits shipped with the dataset, keyed by "train" / "valid" / "test".
split_idx = proteins_dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
graph = proteins_dataset[0]  # pyg graph object

# Jupyter renders only the LAST expression of a cell, so the intermediate
# inspections are printed explicitly. The comment under each statement records
# the output observed when the notebook was written.
print(graph)
# Data(num_nodes=132534, edge_index=[2, 79122504], edge_attr=[79122504, 8], node_species=[132534, 1], y=[132534, 112])

# The graph ships without node features.
print(graph.x)
# None

# One 8-dimensional feature vector per edge.
print(graph.edge_attr[0])
# tensor([0.5010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010])

# One 112-dimensional binary label vector per node.
print(graph.y[0])
# tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1,
#         1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1,
#         1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0,
#         0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

# Restrict the graph to the training nodes; per the recorded outputs this keeps
# 86619 of the 132534 nodes and 40846716 of the 79122504 edges.
train_graph = graph.subgraph(train_idx)
train_graph
# Data(num_nodes=86619, edge_index=[2, 40846716], edge_attr=[40846716, 8], node_species=[86619, 1], y=[86619, 112])

In [None]:
# Load the node embeddings from disk (132534 x 32 per the recorded output).
# NOTE(review): torch.load unpickles the file, so only load an embedding.pt
# that comes from a trusted source.
# map_location="cpu" lets the file load on machines without the device it was
# saved from, and keeps it on the same device as the graph loaded earlier
# (presumably CPU, since nothing in this notebook moves the graph; confirm).
node_embedding = torch.load("embedding.pt", map_location="cpu")

# Rows are selected by node index below, so the embedding must have exactly one
# row per node of the full graph; otherwise features would be misaligned.
assert node_embedding.shape[0] == graph.num_nodes, (
    "embedding.pt must contain one row per node of the full graph"
)

# Printed explicitly: only the last expression of a cell is rendered.
print(node_embedding.shape)
# torch.Size([132534, 32])

# Keep only the rows belonging to the training nodes, in train_idx order.
train_node_embedding = node_embedding[train_idx]
train_node_embedding.shape
# torch.Size([86619, 32])

In [None]:
# Assemble a new graph object for the training split: structure, edge features,
# species and labels are taken from train_graph, while the node features `x`
# come from the embedding rows selected for the training nodes.
enriched_dataset = Data(
    x=train_node_embedding,
    edge_index=train_graph.edge_index,
    edge_attr=train_graph.edge_attr,
    y=train_graph.y,
    num_nodes=train_graph.num_nodes,
    node_species=train_graph.node_species,
)
enriched_dataset
# Data(x=[86619, 32], edge_index=[2, 40846716], edge_attr=[40846716, 8], y=[86619, 112], num_nodes=86619, node_species=[86619, 1])