In [None]:
import datasets as datasets_utils
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
from torch_geometric.utils import to_undirected, to_dense_adj, remove_self_loops

In [None]:
# Shared plotting configuration used across the figures: one colour and one
# display name per correction strategy, plus a density flag per dataset
# ("S" / "D", matching the sparse / dense grouping used in the plots).

strategy_colormap = dict(
    mu_f="firebrick",
    uniform="cornflowerblue",
    none="#DAA520",
)

strategy_namemap = dict(
    mu_f="Degree Based",
    uniform="Uniform",
    none="None",
)

# Built from two groups; insertion order is the "S" datasets followed by the
# "D" datasets.
dataset_density_flags = {
    **dict.fromkeys(
        ("cora", "citeseer", "pubmed", "actor", "texas", "cornell"), "S"),
    **dict.fromkeys(
        ("computers", "photo", "chameleon", "squirrel"), "D"),
}

### Dense vs. Sparse Spectra Comparison

In [None]:
# Maps dataset name -> 1-D tensor of eigenvalues of that graph's symmetric
# normalised Laplacian  L_sym = I - D^{-1/2} A D^{-1/2}.
spectra_dict = {}

for dataset in ["cora", "citeseer", "photo", "chameleon"]:
    data = datasets_utils.load_dataset(dataset, 'dense')
    edge_index = to_undirected(data.edge_index)
    A = to_dense_adj(edge_index).squeeze()
    n = A.shape[0]

    # NOTE(review): A was already built above from the original edge_index,
    # so this call does not change A or the spectrum computed below; any
    # self-loops present in the data remain in A. If the spectrum is meant to
    # exclude self-loops, this must run before `to_undirected` -- confirm.
    data.edge_index, data.edge_attr = remove_self_loops(data.edge_index)

    I = torch.eye(n)
    # D^{-1/2}: zero-degree rows produce inf, which is reset to 0 so isolated
    # nodes contribute nothing to the normalised adjacency.
    degree = torch.diag(A.sum(-1)**(-0.5))
    degree[torch.isinf(degree)] = 0.
    L_sym = I - degree.mm(A.mm(degree))

    # Fraction of non-zero adjacency mass, printed next to n and log(n).
    p = A.sum() / (n * n)
    print([n, p, np.log(n)])

    # L_sym is symmetric, so the symmetric solver applies; eigenvalues come
    # back in ascending order.
    e, _ = torch.linalg.eigh(L_sym, UPLO='U')
    spectra_dict[dataset] = e


In [None]:
# x positions for the 100 histogram bins used below (eigenvalue range [0, 2)).
x = np.arange(0, 2, 0.02)

fig, axs = plt.subplots(figsize=(10, 5),
                        ncols=2,
                        sharex=True,
                        sharey=True,
                        layout="constrained")

color_vector = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
# Datasets drawn on the left ("sparse") panel; everything else goes right.
sparse_panel_datasets = ('cora', 'citeseer')

for color_idx, (k, v) in enumerate(spectra_dict.items()):
    # Normalise by this spectrum's own number of eigenvalues. Using the size
    # of a Laplacian left over from another cell would scale every curve by
    # the node count of whichever dataset happened to be processed last.
    dri = torch.histc(v, 100, -0.001, 2.001) / v.numel()
    target_ax = axs[0] if k in sparse_panel_datasets else axs[1]
    target_ax.plot(x, dri.cpu().numpy(),
                   label="%s" % k.capitalize(),
                   color=color_vector[color_idx])

for ax in axs:
    ax.grid()
    ax.tick_params(labelsize=12)
    ax.set_xlabel('λ', fontsize=12)
    ax.legend(loc='upper left')

axs[0].set_ylabel('Density', fontsize=14)
axs[0].set_title(r'Sparse Graphs, $np<\log(n)$')
axs[1].set_title(r'Dense Graphs, $np\geq\log(n)$')

plt.savefig('sparse_dense_comparison.pdf', bbox_inches='tight')

### Correction Methodology Demonstration on Eigenvalue Distributions

In [None]:
def multiplicity_check(e,
                       rtol: float = 1e-5,
                       atol: float = 1e-8):
    """Count multiplicities of approximately equal eigenvalues.

    The values are sorted ascending and split into runs: a value joins the
    current run when its gap to the previous sorted value is smaller than
    ``atol + rtol * |previous value|``; otherwise it starts a new run.

    Args:
        e: 1-D tensor of eigenvalues (any order).
        rtol: relative tolerance of the closeness test.
        atol: absolute tolerance of the closeness test.

    Returns:
        dict mapping the first value of each run (a 0-d tensor taken from the
        sorted input) to the number of values in that run, in ascending
        order. Empty input yields an empty dict.
    """
    e = e.sort()[0]
    if e.numel() == 0:
        return {}

    e_diffs = torch.diff(e)
    approx_thresh = atol + rtol * e[:-1].abs()
    # diff_mask[i] is True when e[i + 1] is "the same" eigenvalue as e[i].
    diff_mask = e_diffs < approx_thresh

    e_multiplicities = {}
    mult_counter = 1      # the run always contains its first element
    fix_e_i = e[0]        # representative (first) value of the current run

    for e_i, same_as_previous in zip(e[1:], diff_mask):
        if same_as_previous:
            mult_counter += 1
        else:
            # Close the finished run and open a new one at e_i.
            e_multiplicities[fix_e_i] = mult_counter
            fix_e_i = e_i
            mult_counter = 1

    # The last run has no following "different" value to trigger the branch
    # above, so it must be recorded explicitly.
    e_multiplicities[fix_e_i] = mult_counter

    return e_multiplicities


# Blend weights. A corrected spectrum is
#     weight * original eigenvalues + (1 - weight) * evenly spaced grid,
# with `beta` used for the uniform correction and `beta_2` for the
# degree-based (bounded) correction below.
beta = 0.5
beta_2 = 0.5


for dataset in ['cora']:

    fig, ax = plt.subplots(figsize=(5, 5),
                           ncols=1,
                           sharex=True,
                           sharey=True,
                           layout="constrained")

    ax.grid()
    ax.tick_params(labelsize=12)
    ax.set_xlabel('λ', fontsize=12)
    ax.set_ylabel('Density', fontsize=14)
    # The legend is created once, after all labelled artists exist (see the
    # end of the loop body); calling it here would only emit a
    # "no artists with labels" warning.

    # --- Symmetric normalised Laplacian  L = I - D^{-1/2} A D^{-1/2} -------
    data = datasets_utils.load_dataset(dataset, 'dense')
    edge_index = to_undirected(data.edge_index)
    A = to_dense_adj(edge_index).squeeze()
    n = A.shape[0]
    p = A.sum() / (n * n)
    I = torch.eye(n)
    degree = torch.diag(A.sum(-1)**(-0.5))
    # Zero-degree nodes give inf in D^{-1/2}; zero them out.
    degree[torch.isinf(degree)] = 0.
    L = I - degree.mm(A.mm(degree))

    # Interval [1 - 2/sqrt(np), 1 + 2/sqrt(np)] inside which the degree-based
    # correction is applied. Both bounds are 1-element tensors.
    bound_value = 2/torch.sqrt(torch.Tensor([(n * p)]))

    L_mod_var = [1 - bound_value,
                 1 + bound_value]
    label_text = "Degree Based Mass Bound"
    L_mod = L
    e, _ = torch.linalg.eigh(L_mod, UPLO='U')

    x = np.arange(0, 2, 0.02)

    print("Homog. Correction Range %s" % L_mod_var)

    uncorrected_e = e.clone().detach()

    # --- Uniform correction: blend the whole spectrum with a grid on [0, 2] -
    uniform_corrected_e = torch.FloatTensor(np.linspace(0, 2, n))
    uniform_corrected_e = beta * e + (1 - beta) * uniform_corrected_e

    # --- Degree-based correction: blend only the eigenvalues inside the
    #     bound interval with a grid spanning that interval ------------------
    e_mask = (e >= L_mod_var[0]) & (e <= L_mod_var[1])
    bounded_e = e[e_mask]
    n_e_mask = len(bounded_e)
    # The bounds are 1-element tensors, so linspace returns shape
    # (n_e_mask, 1); flatten it back to 1-D.
    corrected_bounded_e = torch.FloatTensor(np.linspace(L_mod_var[0],
                                                        L_mod_var[1],
                                                        n_e_mask).reshape(-1))
    corrected_bounded_e = beta_2 * bounded_e + (1 - beta_2) * corrected_bounded_e
    degree_corrected_e = e.clone().detach()
    degree_corrected_e[e_mask] = corrected_bounded_e

    # --- Largest eigenvalue multiplicity of each spectrum -------------------
    uncorrected_multiplicities = list(multiplicity_check(uncorrected_e).values())
    uniform_EC_multiplicities = list(multiplicity_check(uniform_corrected_e).values())
    degree_corrected_EC_multiplicities = list(multiplicity_check(degree_corrected_e).values())

    uncorrected_max_mult = max(uncorrected_multiplicities)
    uniform_EC_max_mult = max(uniform_EC_multiplicities)
    degree_EC_max_mult = max(degree_corrected_EC_multiplicities)

    print("Uncorrected Max Mult: %s" % uncorrected_max_mult)
    print("Uniform EC Max Mult: %s" % uniform_EC_max_mult)
    print("Degree EC Max Mult: %s" % degree_EC_max_mult)

    # --- Empirical densities: 100 bins over [-0.001, 2.001], divided by n ---
    uncorrected_dri = torch.histc(uncorrected_e, 100, -0.001, 2.001) / n
    uniform_corrected_dri = torch.histc(uniform_corrected_e, 100, -0.001, 2.001) / n
    degree_corrected_dri = torch.histc(degree_corrected_e, 100, -0.001, 2.001) / n

    ax.plot(x, uncorrected_dri.cpu().numpy(),
            label="Uncorrected (%s)" % uncorrected_max_mult,
            color="#DAA520")
    ax.plot(x, uniform_corrected_dri.cpu().numpy(),
            linestyle='--',
            label="Uniform EC (%s)" % uniform_EC_max_mult,
            color="cornflowerblue")
    ax.plot(x, degree_corrected_dri.cpu().numpy(),
            linestyle=':',
            label="Degree Based EC (%s)" % degree_EC_max_mult,
            color="firebrick")

    xticks = np.linspace(0, 2, 11)
    ax.set_xticks(xticks)
    # Mark both ends of the bound interval; only the second line carries the
    # label so the legend shows a single entry for the pair.
    ax.axvline(L_mod_var[0],
               color='red',
               linewidth=0.4)
    ax.axvline(L_mod_var[1],
               color='red',
               linewidth=0.4,
               label=label_text)
    ax.legend(loc='upper center',
              bbox_to_anchor=(0.5, 1.15),
              ncol=2)

    fig.savefig('%s_correction_comp_orig.pdf' % dataset)

### $\beta$ Tuning Results

In [7]:
jacobi_path = "jacobi_beta_tuning_results"
# os.listdir returns entries in arbitrary order and includes sub-directories.
# Keep regular files only and sort them so the concatenated frame (and any
# later tie-breaking that depends on row order) is the same on every run.
jacobi_result_file_names = sorted(
    name for name in os.listdir(jacobi_path)
    if os.path.isfile(os.path.join(jacobi_path, name))
)

jacobi_result_files = [pd.read_csv(os.path.join(jacobi_path, x)) for x in jacobi_result_file_names]
jacobi_results = pd.concat(jacobi_result_files)
# Drop the first column of the concatenated frame (presumably an index column
# written out with the CSVs -- TODO confirm). Non in-place so re-running the
# cell cannot drop a second column from an already-trimmed frame.
jacobi_results = jacobi_results.drop(columns=jacobi_results.columns[0])

# Rows whose test accuracy equals the best test accuracy of their
# (dataset, strategy) group ...
max_acc_grouped_results = jacobi_results.groupby(['dataset', 'strategy'])['test_acc'].transform('max')
max_acc_results = jacobi_results[max_acc_grouped_results == jacobi_results['test_acc']]
# ... collapsed to one row per (dataset, strategy, test_acc); when several rows
# tie, each remaining column takes its maximum over the tied rows.
deduped_max_acc_results = max_acc_results.groupby(['dataset', 'strategy', "test_acc"]).max().reset_index()

In [None]:
# Each tuning CSV is read and its columns renamed positionally to a common
# schema. Rows with beta == 1 are then excluded (presumably the "no
# correction" setting -- TODO confirm). The filtered frames are explicit
# copies because later cells add columns to them; mutating a slice of the
# parent frame triggers pandas' SettingWithCopyWarning.
bernnet_file_name = "DEC_BernNet_default_config_tuning_results.csv"
bernnet_tuning_results = pd.read_csv(bernnet_file_name)
bernnet_tuning_results.columns = ['beta','dataset', 'test_acc', 'ci', 'val_acc']
filtered_bernnet_results = bernnet_tuning_results.loc[bernnet_tuning_results['beta'] != 1].copy()

gprgnn_file_name = "DEC_GPRGNN_default_config_tuning_results.csv"
grpgnn_tuning_results = pd.read_csv(gprgnn_file_name)
grpgnn_tuning_results.columns = ['beta','dataset', 'test_acc', 'ci', 'val_acc']
filtered_gprgnn_results = grpgnn_tuning_results.loc[grpgnn_tuning_results['beta'] != 1].copy()

# JacobiConv: keep only the degree-based ('mu_f') strategy rows.
muf_jacobi_results = jacobi_results.loc[jacobi_results['strategy']=='mu_f']
filtered_jacobi_results = muf_jacobi_results.loc[muf_jacobi_results['beta'] != 1].copy()
# Rescale JacobiConv accuracy by 100 so it shares the 0-100 "Mean Accuracy (%)"
# axis on which the three filters are plotted together.
filtered_jacobi_results['test_acc'] = filtered_jacobi_results['test_acc'] * 100

In [None]:
comp_ds_list = ['cora', 'citeseer', 'chameleon' ,'squirrel']

fig, axs = plt.subplots(figsize=(10, 10),
                        nrows=2,
                        ncols=2,
                        layout='constrained')

# One line per filter on every panel; labels and colours are the same for all
# panels, so they are defined once.
result_labels = ['DEC-Jacobi', 'DEC-BernNet', 'DEC-GPRGNN']
result_colors = ['green' , 'red', 'blue']

# Only the first panel's lines carry labels so the shared figure legend shows
# each filter exactly once.
is_legend = False

for ds, ax in zip(comp_ds_list, list(chain.from_iterable(axs))):
    jacobi_ds_results = filtered_jacobi_results.loc[filtered_jacobi_results['dataset'] == ds]
    bernnet_ds_results = filtered_bernnet_results.loc[filtered_bernnet_results['dataset'] == ds]
    gprgnn_ds_results = filtered_gprgnn_results.loc[filtered_gprgnn_results['dataset'] == ds]
    result_df_list = [jacobi_ds_results, bernnet_ds_results, gprgnn_ds_results]

    for result_df, result_label, result_color in zip(result_df_list, result_labels, result_colors):
        ax.plot(result_df['beta'],
                result_df['test_acc'],
                marker='x',
                linestyle='--',
                label=result_label if not is_legend else None,
                color=result_color)

    ax.set_xlabel(r"$\beta$", fontsize=12)
    ax.set_ylabel("Mean Accuracy (%)", fontsize=12)
    ax.grid()
    ax.tick_params(labelsize=12)
    ax.set_ylim(0, 100)
    ax.set_title(ds.capitalize(),
                 fontsize=16)

    is_legend = True

fig.legend(loc='upper center', ncol=3,
           bbox_to_anchor=(0.5, 1.04))
fig.savefig("comparison_quad_plot.pdf", bbox_inches='tight')

In [None]:
# Tag every tuning frame with the filter it came from. `assign` returns a new
# frame, so this neither writes into a slice of another frame (which raises
# SettingWithCopyWarning) nor depends on how the inputs were created.
filtered_bernnet_results = filtered_bernnet_results.assign(filter='bernnet')
filtered_gprgnn_results = filtered_gprgnn_results.assign(filter='gprgnn')
filtered_jacobi_results = filtered_jacobi_results.assign(filter='jacobi')

# Stack the three frames under a fresh 0..N-1 index (the per-file indices
# repeat, and idxmax below needs unique row labels) and drop 'strategy'.
joined_filter_results = (
    pd.concat([filtered_bernnet_results, filtered_gprgnn_results, filtered_jacobi_results])
    .reset_index(drop=True)
    .drop(columns=['strategy'])
)

# For every (dataset, filter) pair keep the row with the highest test accuracy.
all_muf_filter_maxes = joined_filter_results.loc[joined_filter_results.groupby(['dataset', 'filter'])['test_acc'].idxmax()]

### JacobiConv Results

#### Optimized $\beta$

In [None]:
fig, ax = plt.subplots(figsize=(15, 7), layout="constrained")
strategies = ['none', 'uniform', 'mu_f']

# Sort the dataset names once and derive the tick labels from that order. The
# bar data below is sorted by the same 'dataset' column, so labels and bars
# are aligned by construction rather than by sorting the formatted label
# strings separately.
datasets = sorted(all_muf_filter_maxes['dataset'].unique())
dataset_labels = ["%s (%s)" % (name.capitalize(), dataset_density_flags[name]) for name in datasets]

x = np.arange(len(datasets))
width = 0.23       # drawn bar width
width_fix = 0.25   # spacing between neighbouring bars of one dataset group

for multiplier, strat in enumerate(strategies):
    offset = width_fix * multiplier
    if strat in ('none', 'uniform'):
        # Baselines come from the per-strategy best JacobiConv rows and are
        # multiplied by 100 for the percentage axis.
        strat_subdf = deduped_max_acc_results.loc[deduped_max_acc_results['strategy'] == strat].sort_values('dataset')
        bar_heights = strat_subdf['test_acc'] * 100
    else:
        # Degree-based results come from the best-beta table and are plotted
        # as stored.
        strat_subdf = all_muf_filter_maxes.loc[all_muf_filter_maxes['filter'] == 'jacobi'].sort_values('dataset')
        bar_heights = strat_subdf['test_acc']
    ax.bar(x + offset,
           bar_heights,
           width,
           label=strategy_namemap[strat],
           color=strategy_colormap[strat])

ax.grid()
ax.set_xlabel("Dataset", fontsize=12)
ax.set_ylabel("Mean Accuracy (%)", fontsize=12)
ax.set_xticks(x + width, dataset_labels)
ax.tick_params(labelsize=12)
# `ncol` matches the other legend calls in this notebook.
ax.legend(loc='upper left', ncol=3)
fig.suptitle(r'Mean Accuracy of JacobiConv Correction Strategies on Real-World Datasets (Optimized $\beta$)')

fig.savefig("final_accuracy_jacobiconv_comparison.pdf", bbox_inches='tight')

### GPRGNN Results

In [150]:
results_file_fmt = "%s_default_config_training_results.csv"

# NOTE(review): the first prefix uses a hyphen ("DEC-GPRGNN") while every
# other results file in this notebook uses underscores -- confirm it matches
# the file name on disk.
gprgnn_file_prefixes = ["DEC-GPRGNN", "EC_GPRGNN", "GPRGNN"]
gprgnn_file_names = [results_file_fmt % x for x in gprgnn_file_prefixes]
gprgnn_files = [pd.read_csv(x) for x in gprgnn_file_names]

# Strategy represented by each file, in the same order as the prefixes.
gprgnn_file_strategies = ["mu_f", "uniform", "none"]
# One strategy label per row, sized from each file's actual row count instead
# of assuming exactly 10 rows per file.
gprgnn_strategy_name_padder = [
    strategy
    for strategy, frame in zip(gprgnn_file_strategies, gprgnn_files)
    for _ in range(len(frame))
]
gprgnn_results = pd.concat(gprgnn_files)
gprgnn_results['strategy'] = gprgnn_strategy_name_padder

In [None]:
fig, ax = plt.subplots(figsize=(15, 7), layout="constrained")
# Defined here so this cell does not depend on a variable left behind by an
# earlier plotting cell.
strategies = ['none', 'uniform', 'mu_f']

# Sort the dataset names once and derive the tick labels from that order. The
# bar data below is sorted by the same 'dataset' column, so labels and bars
# are aligned by construction rather than by sorting the formatted label
# strings separately.
datasets = sorted(gprgnn_results['dataset'].unique())
dataset_labels = ["%s (%s)" % (name.capitalize(), dataset_density_flags[name]) for name in datasets]

x = np.arange(len(datasets))
width = 0.23       # drawn bar width
width_fix = 0.25   # spacing between neighbouring bars of one dataset group

for multiplier, strat in enumerate(strategies):
    offset = width_fix * multiplier
    if strat in ('none', 'uniform'):
        # Baselines come from the default-config training results.
        strat_subdf = gprgnn_results.loc[gprgnn_results['strategy'] == strat].sort_values('dataset')
        bar_heights = strat_subdf['test_accuracy']
    else:
        # Degree-based results come from the best-beta table.
        strat_subdf = all_muf_filter_maxes.loc[all_muf_filter_maxes['filter'] == 'gprgnn'].sort_values('dataset')
        bar_heights = strat_subdf['test_acc']
    ax.bar(x + offset,
           bar_heights,
           width,
           label=strategy_namemap[strat],
           color=strategy_colormap[strat])

ax.grid()
ax.set_xlabel("Dataset", fontsize=12)
ax.set_ylabel("Mean Accuracy (%)", fontsize=12)
ax.set_xticks(x + width, dataset_labels)
ax.tick_params(labelsize=12)
# `ncol` matches the other legend calls in this notebook.
ax.legend(loc='upper left', ncol=3)
fig.suptitle(r'Mean Accuracy of GPRGNN Correction Strategies on Real-World Datasets')

fig.savefig("final_accuracy_gprgnn_comparison.pdf", bbox_inches='tight')

### BernNet Results

In [164]:
# File names are built from `results_file_fmt`, defined in the GPRGNN results
# cell above.
bernnet_file_prefixes = ["DEC_BernNet", "EC_BernNet", "BernNet"]
bernnet_file_names = [results_file_fmt % x for x in bernnet_file_prefixes]
bernnet_files = [pd.read_csv(x) for x in bernnet_file_names]

# Strategy represented by each file, in the same order as the prefixes.
bernnet_file_strategies = ["mu_f", "uniform", "none"]
# One strategy label per row, sized from each file's actual row count instead
# of assuming exactly 10 rows per file.
bernnet_strategy_name_padder = [
    strategy
    for strategy, frame in zip(bernnet_file_strategies, bernnet_files)
    for _ in range(len(frame))
]
bernnet_results = pd.concat(bernnet_files)
bernnet_results['strategy'] = bernnet_strategy_name_padder

In [None]:
fig, ax = plt.subplots(figsize=(15, 7), layout="constrained")
# Defined here so this cell does not depend on a variable left behind by an
# earlier plotting cell.
strategies = ['none', 'uniform', 'mu_f']

# Sort the dataset names once and derive the tick labels from that order. The
# bar data below is sorted by the same 'dataset' column, so labels and bars
# are aligned by construction rather than by sorting the formatted label
# strings separately.
datasets = sorted(bernnet_results['dataset'].unique())
dataset_labels = ["%s (%s)" % (name.capitalize(), dataset_density_flags[name]) for name in datasets]

x = np.arange(len(datasets))
width = 0.23       # drawn bar width
width_fix = 0.25   # spacing between neighbouring bars of one dataset group

for multiplier, strat in enumerate(strategies):
    offset = width_fix * multiplier
    if strat in ('none', 'uniform'):
        # Baselines come from the default-config training results.
        strat_subdf = bernnet_results.loc[bernnet_results['strategy'] == strat].sort_values('dataset')
        bar_heights = strat_subdf['test_accuracy']
    else:
        # Degree-based results come from the best-beta table.
        strat_subdf = all_muf_filter_maxes.loc[all_muf_filter_maxes['filter'] == 'bernnet'].sort_values('dataset')
        bar_heights = strat_subdf['test_acc']
    ax.bar(x + offset,
           bar_heights,
           width,
           label=strategy_namemap[strat],
           color=strategy_colormap[strat])

ax.grid()
ax.set_xlabel("Dataset", fontsize=12)
ax.set_ylabel("Mean Accuracy (%)", fontsize=12)
ax.set_xticks(x + width, dataset_labels)
ax.tick_params(labelsize=12)
# `ncol` matches the other legend calls in this notebook.
ax.legend(loc='upper left', ncol=3)
fig.suptitle(r'Mean Accuracy of BernNet Correction Strategies on Real-World Datasets')

fig.savefig("final_accuracy_bernnet_comparison.pdf", bbox_inches='tight')

### Final Accuracy Plots

In [None]:
# Assemble one long table holding, per dataset / filter / strategy, the test
# accuracy and its confidence interval, then pivot it into two wide tables.

# Best-beta rows: every row is labelled with the 'mu_f' strategy.
all_muf_filter_maxes_dupe = all_muf_filter_maxes.copy().assign(strategy='mu_f')

# Column names applied positionally to the BernNet / GPRGNN baseline frames.
_baseline_columns = ['dataset', 'test_acc', 'ci', 'val_acc', 'strategy']

# Baseline rows (everything except 'mu_f') for BernNet and GPRGNN, renamed to
# the shared schema and tagged with their filter.
bernnet_results_dupe = (
    bernnet_results[bernnet_results['strategy'] != 'mu_f']
    .set_axis(_baseline_columns, axis=1)
    .assign(filter='bernnet')
)
gprgnn_results_dupe = (
    gprgnn_results[gprgnn_results['strategy'] != 'mu_f']
    .set_axis(_baseline_columns, axis=1)
    .assign(filter='gprgnn')
)

# JacobiConv baselines: tagged with their filter, accuracy multiplied by 100.
jacobi_results_dupe = (
    deduped_max_acc_results[deduped_max_acc_results['strategy'] != 'mu_f']
    .assign(filter='jacobi',
            test_acc=lambda frame: frame['test_acc'] * 100)
)

final_accuracices = pd.concat([
    all_muf_filter_maxes_dupe,
    bernnet_results_dupe,
    gprgnn_results_dupe,
    jacobi_results_dupe,
])

# Rows: dataset. Columns: (filter, strategy). CI values are multiplied by 100.
_pivot_layout = dict(index=['dataset'], columns=['filter', 'strategy'])
final_acc_table = final_accuracices.pivot(values='test_acc', **_pivot_layout).round(3)
final_ci_table = (final_accuracices.pivot(values='ci', **_pivot_layout) * 100).round(3)

In [None]:
fig, axs = plt.subplots(figsize=(10, 10),
                        nrows=2,
                        ncols=2,
                        layout='constrained')

# Tick labels for the three bar groups. Rows are sorted by the 'filter' column
# below ('bernnet', 'gprgnn', 'jacobi'), which matches this order.
filters = ["BernNet", 'GPRGNN',  "JacobiConv"]

x = np.arange(len(filters))
width = 0.23       # drawn bar width
width_fix = 0.25   # spacing between neighbouring bars of one filter group

# Only the first panel's bars carry labels so the shared figure legend shows
# each strategy exactly once.
is_legend = False

flattened_axs = list(chain.from_iterable(axs))

inset_ds_list = ['cora', 'citeseer', 'chameleon', 'squirrel']

for ds, ax in zip(inset_ds_list, flattened_axs):
    # One row per (filter, strategy) with the value in a column named `ds`.
    ds_accs = final_acc_table.loc[ds, :].reset_index().sort_values('filter')
    ds_cis = final_ci_table.loc[ds, :].reset_index().sort_values('filter')

    for multiplier, strat in enumerate(['none', 'uniform', 'mu_f']):
        offset = width_fix * multiplier

        # Each frame is filtered with its own mask, so the CI lookup does not
        # rely on both frames sharing identical row labels.
        strat_accs = ds_accs.loc[ds_accs['strategy'] == strat, ds]
        strat_cis = ds_cis.loc[ds_cis['strategy'] == strat, ds]

        ax.bar(x + offset,
               strat_accs,
               width,
               label=strategy_namemap[strat] if not is_legend else None,
               edgecolor="white",
               linewidth=0,
               color=strategy_colormap[strat])
        ax.errorbar(x + offset,
                    strat_accs,
                    yerr=strat_cis,
                    linestyle='none',
                    capsize=12,
                    color='red',
                    linewidth=0.8,
                    capthick=0.8)

    is_legend = True

    ax.grid()
    ax.set_xlabel(r"Filter", fontsize=12)
    ax.set_ylabel("Mean Accuracy (%)", fontsize=12)
    ax.set_xticks(x + width, filters)
    ax.tick_params(labelsize=12)
    ax.set_ylim(0, 100)
    ax.set_title(ds.capitalize(),
                 fontsize=16)

fig.legend(loc='upper center', ncol=3,
           bbox_to_anchor=(0.5, 1.04))

fig.savefig("final_accuracy_quadplot_comparison.pdf", bbox_inches='tight')