In [1]:
import pandas as pd
import numpy as np
import networkx as nx
import random
import re

from graphtoolbox import GraphHelper, OgbDataHelper, RandomWalker

from tqdm import tqdm

from ogb.graphproppred import GraphPropPredDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Name of the OGB graph-property-prediction dataset; also reused further down
# as the stem of the output text file.
d_name = 'ogbg-molfreesolv'
dataset = GraphPropPredDataset(name=d_name)
split_idx = dataset.get_idx_split()

# Every dataset item is indexable: position 0 goes to X_raw, position 1 to y.
X_raw = [sample[0] for sample in dataset]
y = [sample[1] for sample in dataset]

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/freesolv.zip


Downloaded 0.00 GB: 100%|██████████| 2/2 [00:00<00:00, 12.20it/s]


Extracting dataset/freesolv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 642/642 [00:00<00:00, 95980.87it/s]


Saving...


In [3]:
# First raw dataset entry. NOTE: the name `graph` is rebound by later cells.
graph = X_raw[0]

In [4]:
# Separator inserted between consecutive walks when walks are serialized to text.
SENTENCE_END_SYMBOL = ' END '
# Separator inserted between consecutive node tokens inside a single walk.
WORD_END_SYMBOL = ' '

In [5]:
def walks_to_string(walks):
    """Serialize a collection of walks into one placeholder string.

    Each element of a walk is rendered as the token ``_<element>_``. Tokens of
    the same walk are joined with ``WORD_END_SYMBOL`` and the resulting walk
    strings are joined with ``SENTENCE_END_SYMBOL``.

    Args:
        walks: iterable of walks, each walk an iterable of node identifiers.

    Returns:
        The joined string (empty when ``walks`` is empty).
    """
    sentences = []
    for walk in walks:
        tokens = ['_' + str(node_id) + '_' for node_id in walk]
        sentences.append(WORD_END_SYMBOL.join(tokens))

    return SENTENCE_END_SYMBOL.join(sentences)

def get_replace_dict(graph):
    """Build the placeholder-token -> feature-text mapping for one graph.

    Nodes are addressed by the integers ``0 .. graph.number_of_nodes() - 1``.
    For node ``i`` the key is ``'_i_'`` and the value is every entry of
    ``graph.nodes[i]['feature']`` converted with ``str`` and joined by commas.

    Args:
        graph: object exposing ``number_of_nodes()`` and ``nodes[i]['feature']``.

    Returns:
        dict mapping token strings to comma-separated feature strings.
    """
    # Disabled alternative carried over from the original cell: encode only the
    # first feature entry, i.e. str(graph.nodes[node_id]['feature'][0]).
    return {
        f'_{node_id}_': ','.join(map(str, graph.nodes[node_id]['feature']))
        for node_id in range(graph.number_of_nodes())
    }

In [6]:
def get_paragraph(graph):
    """Render one graph as a text "paragraph" of walks over node features.

    Walks are requested from ``RandomWalker.random_walks`` with ``num_walks``
    equal to the number of nodes, serialized with ``walks_to_string`` into
    ``_<node>_`` placeholder tokens, and every placeholder that has an entry in
    ``get_replace_dict(graph)`` is replaced by that node's feature string.

    Args:
        graph: graph accepted by ``RandomWalker.random_walks`` that also
            exposes ``number_of_nodes()`` and ``nodes[i]['feature']``.

    Returns:
        The walk string with placeholders substituted by feature text.
    """
    random_walker = RandomWalker()
    walks = random_walker.random_walks(graph, num_walks=graph.number_of_nodes())
    the_string = walks_to_string(walks)
    replace_dict = get_replace_dict(graph)

    # No nodes means nothing to substitute. Without this guard the pattern
    # would be the empty string, which matches at every position.
    if not replace_dict:
        return the_string

    pattern = '|'.join(sorted(re.escape(k) for k in replace_dict))

    # The pattern is built only from the dict's keys and matching is
    # case-sensitive, so each match is exactly one key. Index the dict
    # directly: a missing key would raise instead of yielding None, which
    # re.sub would silently treat as an empty replacement.
    return re.sub(pattern, lambda m: replace_dict[m.group(0)], the_string)

In [8]:
# Write one paragraph per graph to "<d_name>.txt": paragraphs are separated by
# a newline and there is no trailing newline after the last one.
ogb_data_helper = OgbDataHelper()

# The context manager guarantees the file is closed even if converting or
# walking one of the graphs raises part-way through the loop.
with open(f"{d_name}.txt", "w") as text_file:
    for i, x in enumerate(tqdm(X_raw)):
        graph = ogb_data_helper.get_nx_graph(x)
        paragraph = get_paragraph(graph)
        text_file.write(paragraph)
        if i != len(X_raw) - 1:
            text_file.write('\n')

100%|██████████| 642/642 [00:00<00:00, 3766.71it/s]


In [9]:
# Spot-check the last graph of the dataset. It is taken from X_raw explicitly
# rather than from a loop variable left over from an earlier cell, so this
# cell does not depend on leaked loop state.
random_walker = RandomWalker()
graph = ogb_data_helper.get_nx_graph(X_raw[-1])
walks = random_walker.random_walks(graph, num_walks=graph.number_of_nodes())

# Display the attribute dict of node 0 as the cell output.
graph.nodes[0]

{'feature': array([5, 0, 4, 5, 2, 0, 2, 0, 1])}