In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch_geometric.loader import DataLoader

from loader.random_graphs import GraphonDataset
from models.mpnn import MPNN
from models.iwn import IWN2
# from models.simple_higher_order_wnn import SimpleGNN3

In [None]:
# Custom style options for paper figures: seaborn's white "paper" theme with all
# tick marks removed, then math/text fonts, saving, grid and font sizes.
sns.set_style("white")
_no_tick_marks = {
    f"{axis}tick.{kind}.size": 0
    for axis in ("x", "y")
    for kind in ("major", "minor")
}
sns.set_context("paper", font_scale=1,
                rc={"lines.linewidth": 1.2, **_no_tick_marks})

plt.rcParams.update({
    # Fonts: Computer Modern for math, STIX for text.
    "mathtext.fontset": "cm",
    "font.family": "STIXGeneral",
    # Layout / appearance.
    "savefig.bbox": "tight",
    "figure.autolayout": True,
    "axes.grid": True,
    # Font sizes.
    "font.size": 13,
    "axes.titlesize": 13,
    "axes.labelsize": 14,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.fontsize": 13,
    "figure.titlesize": 16,
})

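### Datasets

Each graphon gets a pair of datasets. The first holds `simple_graphs_per_size` random 0/1 graphs per node count. The second holds a single weighted graph per node count (`weighted=True`). All graphs carry a constant input signal.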

In [None]:
# Node counts at which graphs are sampled. The plots below start from the
# second entry (node_sizes[1:]).
node_sizes = [10, 200, 400, 600, 800, 1000]
# Number of random 0/1 graphs sampled per node count for each graphon.
simple_graphs_per_size = 100

# Seed the global RNGs so model initialisation is reproducible under
# Restart & Run All. NOTE(review): presumably GraphonDataset samples graphs
# from numpy/torch RNGs as well -- confirm it does not use another source.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def _graphon_dataset_pair(graph_type, params=None, name=None):
    """Build the ``[0/1 dataset, weighted dataset]`` pair for one graphon.

    The first dataset has ``simple_graphs_per_size`` unweighted graphs per node
    count; the second has a single weighted graph per node count. Both use a
    constant signal. ``params`` and ``name`` are forwarded only when given, so
    GraphonDataset's own defaults apply otherwise.
    """
    def build(graphs_per_size, weighted):
        optional = {}
        if params is not None:
            # Fresh copy per dataset, as each original call had its own literal.
            optional['params'] = dict(params)
        if name is not None:
            optional['name'] = name
        return GraphonDataset(
            node_sizes=node_sizes,
            graphs_per_size=graphs_per_size,
            signal='constant',
            graph_type=graph_type,
            weighted=weighted,
            **optional,
        )

    return [build(simple_graphs_per_size, False), build(1, True)]


datasets = {
    'ER': _graphon_dataset_pair(
        'sbm',
        params={'num_blocks': 1, 'intra_prob': 0.5, 'inter_prob': 0.5},
        name='er',
    ),
    'SBM': _graphon_dataset_pair(
        'sbm',
        params={'num_blocks': 5, 'intra_prob': 0.8, 'inter_prob': 0.3},
    ),
    'Triangular': _graphon_dataset_pair('triangular'),
    'Narrow': _graphon_dataset_pair('smooth_narrow'),
}

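### Graphon example plots

One sampled 0/1 graph per graphon, drawn as its adjacency matrix (edges in black). Each panel's title colour and marker identify that graphon in the plots that follow.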

In [None]:
# One marker per graphon, in the order of `datasets`: matplotlib marker codes
# for the line plots, and the matching Unicode glyphs for plot titles.
markers_mpl = list("o^s*")
markers = list("●▲■★")

In [None]:
# One unweighted (0/1) sample graph per graphon for the example figure.
# NOTE(review): 100 is a magic index -- presumably the first graph of the second
# node count if GraphonDataset stores graphs grouped by size; confirm.
EXAMPLE_GRAPH_INDEX = 100
examples = dict(
    (graphon_name, dataset_pair[0][EXAMPLE_GRAPH_INDEX])
    for graphon_name, dataset_pair in datasets.items()
)

In [None]:
def plot_multiple_graphons(examples, path='outputs/graphons.pdf'):
    """Draw up to four example graphs as adjacency-matrix images on a 2x2 grid.

    Each panel is a grayscale image of ``1 - A`` (edges dark, vmin=0, vmax=1).
    Its title is the graphon name plus its marker glyph, coloured ``C{i}``.

    Args:
        examples: mapping from graphon name to one graph exposing ``num_nodes``,
            ``edge_index`` and either ``edge_weight`` or ``weight`` (presumably a
            torch_geometric ``Data`` object). Only the first four entries are
            drawn, because the grid has four panels; unused panels are hidden.
        path: file the figure is saved to. Its directory is created if missing.
    """
    fig, axes = plt.subplots(2, 2, figsize=(2.5, 2.5), dpi=300)
    axes = axes.ravel()

    for i, (ax, (name, graphon)) in enumerate(zip(axes, examples.items())):
        adj_matrix = torch.zeros(graphon.num_nodes,
                                 graphon.num_nodes,
                                 dtype=torch.float32)
        edge_weights = (graphon.edge_weight if 'edge_weight' in graphon
                        else graphon.weight)
        # Index-put requires matching dtypes, so cast the weights to float32.
        adj_matrix[graphon.edge_index[0], graphon.edge_index[1]] = (
            torch.as_tensor(edge_weights, dtype=adj_matrix.dtype))
        ax.imshow(1 - adj_matrix, cmap='gray', interpolation='none',
                  vmin=0, vmax=1)
        ax.set_title(f'{name} {markers[i]}', fontsize=11, pad=3, color=f'C{i}',
                     fontweight='semibold')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_linewidth(0.5)

    # Hide panels left empty when fewer than four examples are given.
    for ax in axes[len(examples):]:
        ax.set_visible(False)

    fig.tight_layout(pad=0.1)
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(path)
    plt.show()

In [None]:
# Render the 2x2 figure of example graphs.
plot_multiple_graphons(examples=examples)

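### Continuity and Transferability

Each model gets two figures:

- **Continuity:** the mean absolute difference between the model's outputs on the 0/1 graphs and its output on the largest weighted graph.
- **Transferability:** the spread between an upper and a lower quantile of its outputs across the 0/1 graphs of each size.

Both curves are normalised by their value at the smallest plotted node count.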

In [None]:
# Both architectures are built with the same hyper-parameters.
_shared_model_kwargs = dict(in_channels=1, hidden_channels=16, num_layers=2)
models = {
    'MPNN': MPNN(**_shared_model_kwargs),
    '2-IWN': IWN2(**_shared_model_kwargs),
}

In [None]:
def get_outputs_by_size(model, dataset, sizes=None):
    """Run ``model`` on every graph in ``dataset`` and group outputs by node count.

    Args:
        model: a module called as ``model(x, edge_index, edge_weight, batch)``
            that returns a single-element tensor per graph.
        dataset: graphs loadable by torch_geometric's ``DataLoader``. Each graph
            carries ``edge_weight`` or, failing that, ``weight``.
        sizes: node counts to pre-populate in the result. Defaults to the
            notebook-level ``node_sizes``. A graph whose node count is not
            listed raises ``KeyError``.

    Returns:
        dict mapping each size to the list of ``out.item()`` values, in dataset
        order.

    The model is switched to eval mode for the pass, so layers such as dropout
    (if any) are deterministic. Its previous train/eval mode is restored
    afterwards.
    """
    if sizes is None:
        sizes = node_sizes
    size_outputs = {size: [] for size in sizes}

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
            for batch in dataloader:
                size = int(batch.x.size(0))
                # Same edge-weight lookup as plot_multiple_graphons.
                edge_weight = (batch.edge_weight if 'edge_weight' in batch
                               else batch.weight)
                out = model(batch.x, batch.edge_index, edge_weight, batch.batch)
                size_outputs[size].append(out.item())
    finally:
        model.train(was_training)
    return size_outputs

In [None]:
# Continuity plot, one figure per model. For each graphon: the mean absolute
# difference between outputs on the random 0/1 graphs and the output on the
# largest weighted graph (`limit`), normalised by its value at the smallest
# plotted size. NOTE(review): the y-label says "Abs. error", but the values are
# normalised ratios.
plt.rcParams['font.family'] = 'STIXGeneral'
os.makedirs('outputs', exist_ok=True)

show_ci = False  # True: add 1.96 * standard-error bars
plot_sizes = node_sizes[1:]

for model_name, model in models.items():
    fig, ax = plt.subplots(figsize=(2.5, 2.5), dpi=200)

    for i, (dataset_name, (dataset_01, dataset_weighted)) in \
            enumerate(datasets.items()):
        outputs_weighted = get_outputs_by_size(model, dataset_weighted)
        outputs_01 = get_outputs_by_size(model, dataset_01)
        limit = outputs_weighted[max(node_sizes)][0]

        abs_errors = [np.abs(np.asarray(outputs_01[size]) - limit)
                      for size in plot_sizes]
        # Explicit arrays: dividing a plain list by a scalar only works by
        # accident of numpy scalar coercion.
        means_01 = np.array([np.mean(errors) for errors in abs_errors])
        conf_intervals_01 = np.array([
            1.96 * np.std(errors, ddof=1) / np.sqrt(len(errors))
            for errors in abs_errors
        ])
        # NOTE(review): a zero start value would turn the whole curve into inf/nan.
        start_value = means_01[0]

        if show_ci:
            ax.errorbar(plot_sizes, means_01 / start_value,
                        yerr=conf_intervals_01 / start_value,
                        color=f'C{i}', marker=markers_mpl[i],
                        capsize=5, capthick=2, label=dataset_name)
        else:
            ax.plot(plot_sizes, means_01 / start_value, color=f'C{i}',
                    marker=markers_mpl[i], label=dataset_name)

    # Axis formatting applies to the whole figure, so it runs once per model.
    ax.set_box_aspect(1)
    ax.set_xticks([min(plot_sizes), max(plot_sizes)])
    ax.set_xlabel("# nodes", labelpad=-5)
    ax.set_ylabel("Abs. error", labelpad=0)
    ax.grid(True)

    fig.tight_layout(pad=0.0)
    fig.savefig(f'outputs/abs_error_{model_name}.pdf')
    plt.show()

In [None]:
# Transferability plot, one figure per model. For each graphon: the spread of
# outputs across the random 0/1 graphs of each size (upper minus lower
# quantile), normalised by its value at the smallest plotted size.
os.makedirs('outputs', exist_ok=True)

# Quantile levels of the spread. The y-label is derived from these, so the
# axis always states what was computed (it previously claimed .95/.05).
q_low, q_high = 0.1, 0.9
plot_sizes = node_sizes[1:]


def _quantile_tex(q):
    """Format a quantile level without the leading zero: 0.9 -> '.9'."""
    return f"{q:g}".lstrip("0")


spread_label = rf"$q_{{{_quantile_tex(q_high)}}} - q_{{{_quantile_tex(q_low)}}}$"

for model_name, model in models.items():
    fig, ax = plt.subplots(figsize=(2.5, 2.5), dpi=200)

    for i, (dataset_name, (dataset_01, _dataset_weighted)) in \
            enumerate(datasets.items()):
        # Only the 0/1 graphs are used here, so no forward pass is spent on the
        # weighted dataset.
        outputs_01 = get_outputs_by_size(model, dataset_01)

        spreads = np.array([
            np.quantile(outputs_01[size], q_high)
            - np.quantile(outputs_01[size], q_low)
            for size in plot_sizes
        ])
        # NOTE(review): a zero spread at the first size would give inf/nan.
        ax.plot(plot_sizes, spreads / spreads[0], color=f'C{i}',
                marker=markers_mpl[i], label=dataset_name)

    # Axis formatting applies to the whole figure, so it runs once per model.
    ax.set_box_aspect(1)
    ax.set_xticks([min(plot_sizes), max(plot_sizes)])
    ax.set_xlabel("# nodes", labelpad=-5)
    ax.set_ylabel(spread_label, labelpad=0)
    ax.grid(True)

    fig.tight_layout(pad=0.0)
    fig.savefig(f'outputs/transferability_{model_name}.pdf')
    plt.show()