In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import sqlite3

from matplotlib import pyplot as plt

import os
import sys
# Make the sibling `../src` directory importable from this notebook.
# The absolute path is appended at most once so re-running the cell is harmless.
module_path = os.path.abspath(os.path.join('../src'))
already_registered = sys.path.count(module_path) > 0
if not already_registered:
    sys.path.append(module_path)

from IPython.display import display, HTML
# Inject a CSS rule so elements with class `.container` use the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Alignment strategies to compare. Comment entries in or out to change which
# result files are loaded and which curves are plotted.
strategy_names = [
    'GOT',
#     'got-original-l',
    'L2',
#     'L2-inv',
#     'GOT-original',
    'random',
#     'true'
    'GW',
#     'fGOT',
#     'rrmw',
]
# Line colour used for each strategy name in the plots.
colors = {'GOT': 'red', 'GOT-original': 'red', 'GW': 'green', 'L2': 'blue', 'L2-inv': 'orange', 'random': 'black', 'fGOT': 'yellow',
         'rrmw': 'purple', 'true': 'purple', 'got-original-l': 'yellow'}
results_folder = '../results/07-01/'

# x-axis values used when the error curves are plotted.
p_values = [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

w2_errors = {}
l2_errors = {}
l2_inv_errors = {}
gw_errors = {}
seeds = list(range(0, 25))

# File name prefixes of the three error measures stored per strategy and seed.
error_kinds = ('w2', 'l2', 'gw')

# A seed is usable only if *every* file that is averaged below can be read.
# Probing exactly the files that are consumed (same strategy names, same error
# kinds) keeps the filter and the loader in sync, including on case-sensitive
# file systems. The arrays are kept so each file is read only once.
loaded_errors = {}
empty_seeds = []
for seed in seeds:
    try:
        loaded_errors[seed] = {
            (kind, name): np.loadtxt(f'{results_folder}/{kind}_error_{name}#{seed}.csv')
            for name in strategy_names
            for kind in error_kinds
        }
    except OSError:  # FileNotFoundError is a subclass of OSError
        empty_seeds.append(seed)

print(f'Remove seeds {empty_seeds}')
seeds = [seed for seed in seeds if seed not in empty_seeds]

# Average every error measure over the remaining seeds (axis 0 = seed).
for name in strategy_names:
    w2_error_matrix = np.array([loaded_errors[seed][('w2', name)] for seed in seeds])
    l2_error_matrix = np.array([loaded_errors[seed][('l2', name)] for seed in seeds])
    gw_error_matrix = np.array([loaded_errors[seed][('gw', name)] for seed in seeds])

    w2_errors[name] = np.mean(w2_error_matrix, axis=0)
    l2_errors[name] = np.mean(l2_error_matrix, axis=0)
    gw_errors[name] = np.mean(gw_error_matrix, axis=0)

# Largest mean error over all strategies; used to normalize the plots.
max_w2_error = np.max(list(w2_errors.values()))
max_l2_error = np.max(list(l2_errors.values()))
max_gw_error = np.max(list(gw_errors.values()))

In [None]:
# One panel per error measure, all sharing the y-axis. Every curve is divided
# by the largest mean error of its measure before plotting.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 2.3), sharey=True)
ax1.set_ylim(0, 1.02)

# (axis, title, per-strategy mean errors, normalization constant)
panels = (
    (ax1, 'Normalized L2 error', l2_errors, max_l2_error),
    (ax2, 'Normalized GOT error', w2_errors, max_w2_error),
    (ax3, 'Normalized GW error', gw_errors, max_gw_error),
)
for axis, title, errors, largest in panels:
    axis.set_title(title)
    for name in strategy_names:
        normalized = errors[name] / largest
        axis.plot(p_values, normalized, label=name, marker='*', c=colors[name])

# Only the first panel carries the legend.
ax1.legend(loc='lower left')

# The results folder doubles as the figure title to identify the run.
fig.suptitle(results_folder, fontsize=16, y=1.1)
# plt.savefig('../plots/alignment_errors.pdf', bbox_inches='tight')
plt.show()

## Original GOT Experiments

In [None]:
# Strategies to read from the database. Optional keys 'display_name' and
# 'marker' are honoured by the plotting cell.
strategies = [
    {'name': 'GOT', 'color': 'red'},
#     {'name': 'GOT-original', 'color': 'cyan'},
    {'name': 'L2', 'color': 'blue'},
    {'name': 'random', 'color': 'black'},
    {'name': 'gw', 'color': 'green'},
#     {'name': 'fgot', 'color': 'red', 'display_name': 'fGOT', 'marker': '+'},
#     {'name': 'rrmw', 'color': 'purple'},
#     {'name': 'ipfp', 'color': 'orange'},
]

# Values of the `p` column to query; also the x-axis of the plots.
p_values = [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

w2_errors = {}
l2_errors = {}
gw_errors = {}
seeds = list(range(0, 25))

connection = sqlite3.connect('../results/got_alignment_07005.db')
cursor = connection.cursor()
try:
    for strategy in strategies:
        name = strategy['name']
        w2_errors[name] = []
        l2_errors[name] = []
        gw_errors[name] = []
        for p in p_values:
            # Bound parameters instead of string interpolation: no quoting
            # problems and no SQL injection through strategy names.
            cursor.execute(
                "SELECT W2_LOSS, L2_LOSS, GW_LOSS FROM alignment WHERE STRATEGY=? and p=?;",
                (name, p),
            )
            results = cursor.fetchall()
            assert len(results) > 0, f"No results for strategy '{name}'."
            # Fewer rows than seeds means some runs did not produce a result.
            if len(results) < len(seeds):
                print(f"Only {len(results)} successful seeds for strategy '{name}'.")
            # Column-wise mean over all rows, computed once per query.
            mean_w2, mean_l2, mean_gw = np.mean(results, axis=0)
            w2_errors[name].append(mean_w2)
            l2_errors[name].append(mean_l2)
            gw_errors[name].append(mean_gw)
finally:
    # Release the database handle even if a query or the assertion fails.
    cursor.close()
    connection.close()

# Largest mean error over all strategies; used to normalize the plots.
max_w2_error = np.max(list(w2_errors.values()))
max_l2_error = np.max(list(l2_errors.values()))
max_gw_error = np.max(list(gw_errors.values()))

In [None]:
# One panel per error measure, all sharing the y-axis. Every curve is divided
# by the largest mean error of its measure before plotting.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 2.3), sharey=True)
ax1.set_ylim(0, 1.02)

# (axis, title, per-strategy mean errors, normalization constant)
panels = (
    (ax1, 'Normalized L2 error', l2_errors, max_l2_error),
    (ax2, 'Normalized GOT error', w2_errors, max_w2_error),
    (ax3, 'Normalized GW error', gw_errors, max_gw_error),
)
for axis, title, errors, largest in panels:
    axis.set_title(title)
    axis.set_xlabel('Edge removal probability')
    for strategy in strategies:
        name = strategy['name']
        # Optional per-strategy overrides fall back to the name and a star marker.
        display_name = strategy.get('display_name', name)
        marker = strategy.get('marker', '*')
        normalized = errors[name] / largest
        axis.plot(p_values, normalized, label=display_name, marker=marker, c=strategy['color'], lw=1)

# Only the first panel carries the legend.
ax1.legend(loc='lower left')

plt.savefig('../plots/alignment_errors_got.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Fetch every alignment result row; aggregation is left to seaborn below.
connection = sqlite3.connect('../results/results_got.db')
cursor = connection.cursor()
try:
    cursor.execute("SELECT STRATEGY, SEED, P, L2_LOSS, W2_LOSS, GW_LOSS FROM alignment;")
    results = cursor.fetchall()
finally:
    # Release the database handle even if the query fails.
    cursor.close()
    connection.close()

# Not referenced by the plotting code in this cell; kept for cells that read it.
strategies = [
    {'name': 'GOT', 'color': 'red'},
#     {'name': 'GOT-original', 'color': 'cyan'},
    {'name': 'L2', 'color': 'blue'},
    {'name': 'random', 'color': 'black'},
    {'name': 'GW', 'color': 'green'},
#     {'name': 'fgot', 'color': 'red', 'display_name': 'fGOT', 'marker': '+'},
#     {'name': 'rrmw', 'color': 'purple'},
#     {'name': 'ipfp', 'color': 'orange'},
]

# Legend label for each strategy identifier stored in the database.
# Raw strings keep the LaTeX backslashes literal ('\d' is not a valid escape).
strategy_name = {
    'GOT': 'GOT',
    'random': 'random',
    'Pgot': r'$g(L) = L^{\dagger/2}$',
    'Pgw': 'GW',
    'gw': 'GW',
    'PstoH': r'$g(L) = L^{\dagger/2}$ stochastic',
    'PLsq': '$g(L) = L^2$',
    'P_nv2': '$g(L) = L^2$ stochastic',
    'L2': 'GOT L2',
}

data = pd.DataFrame(results, columns=['strategy', 'seed', 'p', 'L2 loss', 'GOT error', 'GW error'])
# Strategies without an entry in `strategy_name` keep their database
# identifier as label instead of raising a KeyError.
display_names = [strategy_name.get(name, name) for name in data['strategy']]
data['name'] = display_names
for name, rows in data.groupby('strategy'):
    # Repetitions = rows per distinct p value (rather than assuming 10 p values).
    n_p_values = max(rows['p'].nunique(), 1)
    print(f"{strategy_name.get(name, name)} : {len(rows) // n_p_values} repetitions")

# NOTE(review): the titles say "Normalized" but the raw columns are plotted
# (no division by a maximum, independent y-axes) -- confirm the intended wording.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 2.3), sharey=False)
panels = (
    (ax1, 'L2 loss', 'Normalized L2 error'),
    (ax2, 'GOT error', 'Normalized GOT error'),
    (ax3, 'GW error', 'Normalized GW error'),
)
for axis, column, title in panels:
    axis.set_title(title)
    axis.set_xlabel('Edge removal probability')
    sns.lineplot(x='p', y=column, hue='name', markers=True, dashes=False, style='name', data=data, ax=axis, errorbar=None)
ax1.legend(loc='lower left')
plt.show()

## fGOT results

In [None]:
# Fetch all alignment rows except strategy 'gw'; P is scaled by 100 in the query.
connection = sqlite3.connect('../results/results_fgot.db')
cursor = connection.cursor()
try:
    cursor.execute("SELECT STRATEGY, SEED, P*100, L2_LOSS, W2_LOSS FROM alignment where strategy!='gw';")
    results = cursor.fetchall()
finally:
    # Release the database handle even if the query fails.
    cursor.close()
    connection.close()

# Legend label for each strategy identifier stored in the database.
# Raw strings keep the LaTeX backslashes literal ('\d' is not a valid escape).
strategy_name = {
    'GOT': 'GOT',
    'random': 'random',
    'Pgot': r'$g(L) = L^{\dagger/2}$',
    'Pgw': 'GW',
    'gw': 'GW',
    'PstoH': r'$g(L) = L^{\dagger/2}$ stochastic',
    'PLsq': '$g(L) = L^2$',
    'P_nv2': '$g(L) = L^2$ stochastic',
    'L2': 'GOT L2',
}

data = pd.DataFrame(results, columns=['strategy', 'seed', 'p', 'L2 error', 'GOT error'])
# Strategies without an entry in `strategy_name` keep their database
# identifier as label instead of raising a KeyError.
display_names = [strategy_name.get(name, name) for name in data['strategy']]
data['name'] = display_names
for name, rows in data.groupby('strategy'):
    # Repetitions = rows per distinct p value (rather than assuming 10 p values).
    n_p_values = max(rows['p'].nunique(), 1)
    print(f"{strategy_name.get(name, name)} : {len(rows) // n_p_values} repetitions")

# One figure per error measure: (column, y-axis label, output file or None).
# Only the L2 figure is written to disk.
# NOTE(review): the x-axis is labelled 'Graph size' although the plotted column
# is `P*100` from the query -- confirm the label matches the experiment.
figures = (
    ('L2 error', 'L2 distance', '../plots/alignment_errors_fgot.pdf'),
    ('GOT error', 'GOT distance', None),
)
for column, ylabel, save_path in figures:
    plt.figure(figsize=(18,6))
    sns.lineplot(x='p', y=column, hue='name', markers=True, dashes=False, style='name', data=data)
    plt.ylabel(ylabel, fontsize=20)
    plt.xlabel('Graph size', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.legend(prop={"size":20})
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()

In [None]:
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import zero_one_loss

# Number of test graphs; also the number of rows of the distance matrix,
# because one prediction is made per row and compared against `y_test`.
n_graphs = 100
# Number of columns of the distance matrix.
# NOTE(review): presumably the candidate training graphs, since the argmin
# column index is used to index `y_train` -- confirm against the experiment.
n_candidates = 100

# Load graph data set once; it does not depend on the seed.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
path = "../data/ENZYMES/enzymes.pkl"
with open(path, 'rb') as file:
    graphs = pickle.load(file)
y = np.array([G.graph['label'] for G in graphs])

accuracies = []
for seed in [0, 1, 2]:
    # Load distance matrix, one CSV file per (row, column) pair
    distances = np.ones((n_graphs, n_candidates))
    for i in range(n_graphs):
        for j in range(n_candidates):
            a = np.loadtxt(f"../results/fgot_distances/{i}-{j}#{seed}.csv")
            # An empty file or a stored 0 counts as "no distance". The size
            # check comes first because the truth value of an empty array is
            # deprecated/an error in NumPy.
            if a.size > 0 and a:
                distances[i,j] = a
            else:
                # Report the missing pair and exclude it from the argmin.
                print(i, j)
                distances[i,j] = np.inf

    # Same split seed as the distance files, so rows line up with `y_test`.
    X_train, X_test, y_train, y_test = train_test_split(graphs, y, test_size=n_graphs, random_state=seed)

    # Compute nearest neighbors and accuracy (number of correct predictions)
    nearest_neighbors = np.argmin(distances, axis=1)
    y_pred = y_train[nearest_neighbors]
    accuracy = n_graphs - zero_one_loss(y_test, y_pred, normalize=False)
    accuracies.append(accuracy)

accuracies = np.array(accuracies)
print(f'Mean: {accuracies.mean()}')
print(f'Std.: {accuracies.std()}')

In [None]:
# Read the whole `classification` table.
connection = sqlite3.connect('../results/results_fgot.db')
cursor = connection.cursor()
try:
    cursor.execute("SELECT * FROM classification;")
    results = cursor.fetchall()
finally:
    # Release the database handle even if the query fails.
    cursor.close()
    connection.close()
# `SELECT *` must return exactly these four columns, in this order.
data = pd.DataFrame(results, columns=['strategy', 'data set', 'seed', 'accuracy'])

accuracies = data['accuracy']

# Mean and standard deviation of the accuracy over all rows of the table.
# Note that pandas' Series.std uses the sample standard deviation (ddof=1).
average_performance = np.mean(data['accuracy'])
print(accuracies.mean())
print(accuracies.std())