In [None]:
from ogb.graphproppred import GraphPropPredDataset

import networkx as nx
import numpy as np

from tqdm import tqdm

In [None]:
dataset = GraphPropPredDataset(name='ogbg-molfreesolv')

# The dataset's own split: a mapping keyed by "train" / "valid" / "test",
# each value being the collection of graph indices in that split.
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
valid_idx = split_idx["valid"]
test_idx = split_idx["test"]

In [None]:
# Attribute names for the columns of each per-node feature row
# (graph_dict['node_feat']), listed in column order.
# NOTE(review): presumably mirrors OGB's atom feature layout
# (ogb.utils.features). OGB stores category *indices* rather than raw values,
# so e.g. 'atomic_num' may not be the literal atomic number. Confirm before
# interpreting these values chemically.
feature_names = [
    'atomic_num', 'chirality', 'degree',
    'formal_charge', 'num_h', 'num_rad_e',
    'hybridization', 'is_aromatic', 'is_in_ring',
]

In [None]:
def graph_dict_to_nx_graph(graph_dict: dict, node_feature_names=None):
    """
    Constructs an undirected NetworkX graph from the given graph dictionary.

    Args:
        graph_dict: A dictionary representing a graph, with the following keys:
            - 'num_nodes': The number of nodes in the graph.
            - 'node_feat': Per-node feature rows. Row i is attached to node i
              as attributes named by ``node_feature_names``.
            - 'edge_index': An edge index array of shape (2, E), where E is the
              number of edges. Column j is the (source, target) pair of edge j.
            - 'edge_feat': Per-edge feature rows aligned with the columns of
              'edge_index'. Optional: if the key is missing or its value is
              None, edges are added without a 'feature' attribute.
        node_feature_names: Attribute names for the columns of each
            'node_feat' row, in column order. Defaults to the module-level
            ``feature_names`` list.

    Returns:
        An ``nx.Graph`` with the following contents:
            - Plain-int nodes 0..num_nodes-1, plus any node that 'node_feat'
              or 'edge_index' references beyond that range.
            - Node attributes named after ``node_feature_names``.
            - Each edge's feature row under the ``'feature'`` attribute.

    Note:
        The graph is undirected. An edge listed in both directions in
        'edge_index' therefore becomes a single edge. Its 'feature' is the row
        of whichever occurrence was added last.
        NOTE(review): OGB molecule graphs presumably list each bond in both
        directions. Confirm that the two feature rows agree.
    """
    names = feature_names if node_feature_names is None else node_feature_names

    graph = nx.Graph()
    # Add every node first so that nodes without any edge are still present.
    graph.add_nodes_from(range(graph_dict['num_nodes']))

    # add_nodes_from accepts (node, attr_dict) pairs. It merges the attributes
    # into existing nodes and creates any node that does not exist yet.
    graph.add_nodes_from(
        (node_id, dict(zip(names, row)))
        for node_id, row in enumerate(graph_dict['node_feat'])
    )

    edge_feat = graph_dict.get('edge_feat')
    # edge_index has shape (2, E); transposing it yields one (source, target)
    # pair per edge. int() converts numpy scalars to plain ints, so edge
    # endpoints use the same key type as the nodes added above.
    edge_pairs = np.transpose(graph_dict['edge_index'])
    for edge_id, (source, target) in enumerate(edge_pairs):
        if edge_feat is None:
            graph.add_edge(int(source), int(target))
        else:
            graph.add_edge(int(source), int(target), feature=edge_feat[edge_id])

    return graph

In [None]:
def get_shortest_paths(G):
    """Collect one shortest path for every ordered pair of distinct, connected nodes.

    Paths come from ``nx.all_pairs_shortest_path``. Each path is a list of
    nodes running from its source to its target. Trivial paths (source ==
    target) are dropped. Pairs are ordered, so u->v and v->u are both
    included whenever the graph connects them in both directions.
    """
    shortest_paths = []
    for source, paths_from_source in nx.all_pairs_shortest_path(G):
        for target, path in paths_from_source.items():
            if source == target:
                continue  # skip the single-node path from a node to itself
            shortest_paths.append(path)
    return shortest_paths

In [None]:
def get_sentences(graph_dict, feature_key='atomic_num'):
    """Encode a graph as "sentences" of node-feature tokens, one per shortest path.

    The dictionary is first converted with ``graph_dict_to_nx_graph``. Then,
    for every path returned by ``get_shortest_paths`` (one per ordered pair of
    distinct, connected nodes), this function:
        - reads the ``feature_key`` attribute of each node on the path,
        - converts each value to a string,
        - joins the tokens with single spaces.

    Args:
        graph_dict: A graph dictionary accepted by ``graph_dict_to_nx_graph``.
        feature_key: The node attribute used as the token. Defaults to
            ``'atomic_num'``, which was the original hard-coded behaviour. Any
            other attribute set by ``graph_dict_to_nx_graph`` works as well.

    Returns:
        A list of space-separated token strings, in the order in which
        ``get_shortest_paths`` yields the paths.
    """
    graph = graph_dict_to_nx_graph(graph_dict)
    return [
        ' '.join(str(graph.nodes[node][feature_key]) for node in path)
        for path in get_shortest_paths(graph)
    ]

In [None]:
# Materialise the dataset as parallel lists: graphs[i] goes with labels[i].
graphs = [g for g, _ in dataset]
labels = [y for _, y in dataset]


def _take(items, indices):
    """Return ``[items[i] for i in indices]``, i.e. the elements selected by one split."""
    return [items[i] for i in indices]


train_graphs, train_labels = _take(graphs, train_idx), _take(labels, train_idx)
test_graphs, test_labels = _take(graphs, test_idx), _take(labels, test_idx)
valid_graphs, valid_labels = _take(graphs, valid_idx), _take(labels, valid_idx)

In [None]:
# One list of path "sentences" per training graph; tqdm reports progress.
my_sentences = list(map(get_sentences, tqdm(train_graphs)))

In [None]:
# Peek at the first ten sentences of the first training graph.
first_graph_preview = slice(0, 10)
my_sentences[0][first_graph_preview]

In [None]:
def stub(x):
    """Identity worker: hand back ``x`` untouched (used to smoke-test multiprocessing.Pool)."""
    result = x
    return result

: 

In [2]:
from multiprocessing import Pool

# Pool.map already returns a list, in input order, so its results are appended
# in one go rather than one at a time.
# NOTE(review): Pool pickles `stub` by reference to __main__. Under the
# "spawn"/"forkserver" start methods (spawn is the default on macOS/Windows),
# worker processes cannot import functions defined in a notebook. If this
# cell fails there, move the worker function into an importable .py module.
list_of_sentences = []
with Pool() as pool:
    list_of_sentences.extend(pool.map(stub, range(1000)))

list_of_sentences[:10]

In [None]:
# `stub` returns its input unchanged, so list_of_sentences holds the ints
# 0..999. An int has no len(), so taking the length of the first element
# raises TypeError. Report how many results the pool returned instead.
# NOTE(review): if the pool is changed to map `get_sentences` over graphs,
# `len(list_of_sentences[0])` (the sentence count of the first graph) becomes
# the meaningful check.
len(list_of_sentences)