# The Graphons of Different Classes of Graphs in One Dataset Are Distinctly Different

In [1]:
# Show the installed PyTorch version; as the cell's last expression, the value
# is echoed as the cell output.
import torch
torch.__version__

'1.12.1+cu116'

In [20]:
from torch_geometric.datasets import TUDataset
import os.path as osp
from src.gmixup import prepare_dataset_onehot_y
from src.utils import split_class_graphs
from src.graphon_estimator import universal_svd
from src.utils import align_graphs, stat_graph
import random
import matplotlib.pyplot as plt
import matplotlib as mpl

In [31]:
from mpl_toolkits.axes_grid1 import make_axes_locatable  # not used in this cell

# For each dataset: estimate one graphon per class label, then plot the class
# graphons side by side with a shared colorbar so they can be compared visually.
dataset_names = ['IMDB-BINARY', 'REDDIT-BINARY', 'IMDB-MULTI']
# Graphon resolution N for each dataset, paired with ``dataset_names`` by
# position. NOTE(review): 17 matches the median node count printed for
# IMDB-BINARY; presumably the other sizes were chosen the same way — confirm.
graphon_sizes = [17, 15, 12]
data_path = './'
# Only the first ``align_max_size`` (shuffled) graphs of each class are aligned
# and fed to the SVD estimator.
align_max_size = 500

# zip() below would silently drop unmatched entries; fail loudly instead.
if len(dataset_names) != len(graphon_sizes):
    raise ValueError(
        f"dataset_names ({len(dataset_names)}) and graphon_sizes "
        f"({len(graphon_sizes)}) must have the same length")

for dataset_name, graphon_size in zip(dataset_names, graphon_sizes):
    path = osp.join(data_path, dataset_name)
    dataset = TUDataset(path, name=dataset_name)
    dataset = list(dataset)

    # Flatten every label tensor to 1-D before one-hot encoding.
    for graph in dataset:
        graph.y = graph.y.view(-1)

    dataset = prepare_dataset_onehot_y(dataset)
    # Re-seed per dataset so the shuffle (and therefore which graphs fall
    # inside ``align_max_size``) is reproducible.
    random.seed(1314)
    random.shuffle(dataset)
    avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(
        dataset)
    print('Median num nodes: ', int(median_num_nodes))
    class_graphs = split_class_graphs(dataset)
    print('Finished splitting class graphs')
    graphons = []
    for label, graphs in class_graphs:
        align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(
            graphs[:align_max_size], padding=True, N=int(graphon_size))
        print('Finished aligning graphs of label ', label)
        graphon = universal_svd(align_graphs_list, threshold=0.2)
        graphons.append((label, graphon))

    # squeeze=False keeps ``ax`` a 2-D array even when there is a single class,
    # so ``ax[0]`` is always an iterable row of Axes.
    fig, ax = plt.subplots(1, len(class_graphs), facecolor='w', squeeze=False)

    # Shrink the subplots to the left 80% of the figure and place the shared
    # colorbar in the freed space on the right.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.3, 0.05, 0.4])
    for (label, graphon), axis in zip(graphons, ax[0]):
        print(f"graphon info: label:{label}; mean: {graphon.mean()}, shape: {graphon.shape}")
        # Fixed vmin/vmax so every panel and the single colorbar share one scale.
        im = axis.imshow(graphon, vmin=0, vmax=1, cmap=plt.cm.plasma)
        axis.set_title(label)
    fig.colorbar(im, cax=cbar_ax, orientation='vertical')
    fig.suptitle(dataset_name, y=0.2)
    # Write one file per dataset; a single fixed filename would be overwritten
    # on every iteration, keeping only the last dataset's figure.
    fig.savefig(f'{dataset_name}_graphons.png', facecolor='white', bbox_inches='tight')

Median num nodes:  17
Finished splitting class graphs
Finished aligning graphs of label  [0. 1.]
Doing SVD of matrix of size:  17
Finished SVD!


KeyboardInterrupt: 