In [1]:
import argparse

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.nn import functional as F
from torchvision.utils import make_grid
from torchvision.utils import save_image

import os

import pickle
from collections import Counter
import numpy as np

In [2]:
# Path of the pickled graph dataset to load.
# Alternative (graphs WITHOUT node features): '../../data/collab.graph'
ds_name = '../../data/mutag.graph'  # graphs WITH node features

In [3]:
def load_data(ds_name):
    """Load a pickled graph dataset and its graph labels.

    Parameters
    ----------
    ds_name : str or path-like
        Path to a pickle file holding a dict with the keys "graph" and "labels".

    Returns
    -------
    graph_data : object
        The value stored under "graph" (in this notebook it is iterated with
        ``.items()`` and indexed by graph index, i.e. a dict of graphs).
    labels : numpy.ndarray
        The "labels" entry converted to a float64 array.
    """
    # Context manager so the file handle is always closed (the original leaked it).
    with open(ds_name, "rb") as f:
        print("Found dataset:", ds_name)
        # NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
        # encoding="latin1" is the documented way to read Python 2 pickles containing
        # numpy arrays; presumably these files were written by Python 2 - confirm.
        data = pickle.load(f, encoding="latin1")
    graph_data = data["graph"]
    # `np.float` was an alias of the builtin float and was removed in NumPy 1.24.
    labels = np.array(data["labels"], dtype=float)
    return graph_data, labels

In [4]:
# Load every graph (indexed by graph index) together with its class label.
graphs, labels = load_data(ds_name=ds_name)

Found dataset: ../../data/mutag.graph


In [5]:
def _print_stats(title, values):
    """Print the mean, median, max, min and total of ``values`` labelled with ``title``.

    Lines have the form ``<Stat> #<title>: <value>``, preceded by a blank line.
    ``values`` must be a non-empty sequence of numbers (``max``/``min`` raise
    ValueError otherwise).
    """
    print("\nMean #%(t)s: %(mean)s \nMedian #%(t)s: %(median)s \nMax #%(t)s: %(max)s "
          "\nMin #%(t)s: %(min)s \nTotal #%(t)s: %(total)s"
          % {"t": title, "mean": np.mean(values), "median": np.median(values),
             "max": max(values), "min": min(values), "total": sum(values)})


print("Dataset: %s \nNumber of Graphs: %s \nLabel distribution: %s" % (ds_name, len(graphs), Counter(labels)))

avg_edges = []     # neighbour count of EVERY node (its degree), pooled over all graphs
avg_nodes = []     # node count of every graph
n_features = 0     # number of nodes whose 'label' is not the empty string
avg_features = []  # length of every node's 'label' (0 when empty)
features = []      # every node's 'label' value (None when empty)
for gidxs, nodes in graphs.items():
    for n in nodes:
        avg_edges.append(len(nodes[n]['neighbors']))
        node_label = nodes[n]['label']
        if node_label != '':
            n_features += 1
            avg_features.append(len(node_label))
            features.append(node_label)
        else:
            avg_features.append(0)
            features.append(None)
    avg_nodes.append(len(nodes))

_print_stats("nodes", avg_nodes)
# These are per-NODE neighbour counts, not edges per graph (the old "#edges" label
# was misleading). The total counts neighbour-list entries: if adjacency is stored
# symmetrically (presumably, for undirected molecule graphs - TODO confirm) every
# edge is counted twice, so it is not the number of edges.
_print_stats("neighbors per node", avg_edges)
_print_stats("features_len", avg_features)
print("\nNumber of nodes with features: %s" % (n_features))
print("Features distribution: %s" % (Counter(features)))

Dataset: ../../data/mutag.graph 
Number of Graphs: 188 
Label distribution: Counter({1.0: 125, -1.0: 63})

Mean #nodes: 17.930851063829788 
Median #nodes: 17.5 
Max #nodes: 28 
Min #nodes: 10 
Total #nodes: 3371

Mean #edges: 2.2076535152773658 
Median #edges: 2.0 
Max #edges: 4 
Min #edges: 1 
Total #edges: 7442

Mean #features_len: 1.0 
Median #features_len: 1.0 
Max #features_len: 1 
Min #features_len: 1 
Total #features_len: 3371

Number of nodes with features: 3371
Features distribution: Counter({(3,): 2395, (7,): 593, (6,): 345, (2,): 23, (4,): 12, (1,): 2, (5,): 1})


In [6]:
# Pick one graph to inspect; it is keyed by node id, so its keys are the node ids.
graph = graphs[5]
nodes = graph.keys()
print(
    "Example of Graph keys (e.g. node_ids):\n",
    nodes,
)

Example of Graph keys (e.g. node_ids):
 dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])


In [7]:
# Show the record stored for node id 0 of the example graph.
node = graph[0]
print(
    "Example of Node:\n",
    node,
)

Example of Node:
 {'neighbors': array([ 1, 13], dtype=uint8), 'label': (3,)}
