In [None]:
import numpy as np
import numpy.linalg as lg
import networkx as nx
from matplotlib import pyplot as plt

import os
import sys
module_path = os.path.abspath(os.path.join('../src'))
if module_path not in sys.path:
    sys.path.append(module_path)
from utils.help_functions import remove_edges
from utils import check_permutation_matrix

In [None]:
# Stochastic block model (SBM) parameters for the synthetic graphs.
# NOTE(review): in the visible notebook these are only used by the commented-out
# sanity check in the loading cell; the plots themselves read precomputed results.
n = 40  # total number of nodes
n_blocks = 4  # number of equally sized communities
if n % n_blocks != 0:
    # Integer division would otherwise silently drop nodes (sum(blocks) != n).
    raise ValueError(f'n={n} must be divisible by n_blocks={n_blocks}')
block_size = n // n_blocks
blocks = [block_size] * n_blocks
p_within = 0.70  # edge probability between two nodes of the same community
p_between = 0.05  # edge probability between nodes of different communities
# n_blocks x n_blocks edge-probability matrix: p_within on the diagonal, p_between elsewhere.
probs = [[p_within if i == j else p_between for j in range(n_blocks)]
         for i in range(n_blocks)]

In [None]:
# Alignment strategies to evaluate and plot.
# Currently disabled (re-add to the list to include them): 'GOT', 'random'.
strategy_names = ['L2', 'L2-inv', 'GW', 'fGOT', 'rrmw']

# Plot colour per strategy. Disabled strategies are listed too, so re-enabling
# one needs no change here.
colors = dict(zip(
    ['GOT', 'GW', 'L2', 'L2-inv', 'random', 'fGOT', 'rrmw'],
    ['red', 'green', 'blue', 'orange', 'black', 'yellow', 'purple'],
))

# Directory containing the per-strategy, per-seed result CSV files.
results_folder = '../results/'

# Values of p used by the experiments (passed to `remove_edges` as `between_probability`
# in the commented-out check below). Every per-seed result file must hold one value per p;
# the plots use this list as the x-axis.
p_values = [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

# Mean-over-seeds error curve per strategy, one dict per metric (strategy name -> array over p).
w2_errors = {}
l2_errors = {}
l2_inv_errors = {}
gw_errors = {}
# Only populated by the commented-out sanity check below.
permutation_matrices = {}
seeds = range(0, 25)


def _mean_error_over_seeds(metric, name):
    """Load ``<metric>_<name>#<seed>.csv`` for every seed and average over seeds.

    Parameters
    ----------
    metric : str
        Result-file prefix, e.g. ``'w2_error'``.
    name : str
        Strategy name, e.g. ``'L2'``.

    Returns
    -------
    np.ndarray
        Mean error for each entry of ``p_values``, shape ``(len(p_values),)``.

    Raises
    ------
    ValueError
        If the stacked results are not ``(len(seeds), len(p_values))``, i.e. some result
        file does not contain exactly one value per entry of ``p_values``.
    """
    # os.path.join avoids the doubled '/' that f'{results_folder}/...' produced.
    paths = [os.path.join(results_folder, f'{metric}_{name}#{seed}.csv') for seed in seeds]
    # ndmin=1 keeps a single-value file as a 1-element array instead of a 0-d scalar.
    per_seed = np.array([np.loadtxt(path, ndmin=1) for path in paths])
    expected_shape = (len(paths), len(p_values))
    if per_seed.shape != expected_shape:
        raise ValueError(
            f'{metric} results for {name!r} have shape {per_seed.shape}, '
            f'expected {expected_shape} (seeds x p_values)'
        )
    return per_seed.mean(axis=0)


for name in strategy_names:
    w2_errors[name] = _mean_error_over_seeds('w2_error', name)
    l2_errors[name] = _mean_error_over_seeds('l2_error', name)
    l2_inv_errors[name] = _mean_error_over_seeds('l2_inv_error', name)
    gw_errors[name] = _mean_error_over_seeds('gw_error', name)

    # Disabled sanity check: regenerates each graph pair and validates the stored permutations.
    # NOTE(review): if re-enabled, `permutation_matrices[seed] = {}` is reset for every strategy,
    # so only the last strategy's matrices survive (key by (name, seed) instead), and
    # `n = len(G1)` rebinds the global `n`.
#     for seed in seeds:
#         permutation_matrices[seed] = {}
        
#         rng = np.random.default_rng(seed=seed)
        
#         # Original graph
#         G1 = nx.stochastic_block_model(blocks, probs, seed=seed)
#         assert nx.is_connected(G1), 'G1 is not connected.'
#         communities = {}
#         for node in G1.nodes:
#             communities[node] = np.floor(node / block_size)
#         n = len(G1)
#         L1 = nx.laplacian_matrix(G1, range(n))
#         L1 = np.double(np.array(L1.todense()))

#         # Generate permutation matrix
#         idx = rng.permutation(n)
#         P_true = np.eye(n)
#         P_true = P_true[idx, :]
#         for p in p_values:
#             # Changed graph
#             G1_reduced = remove_edges(G1, communities, between_probability=p, within_probability=0.5, seed=rng)
#             L1_reduced = nx.laplacian_matrix(G1_reduced, range(n))
#             L1_reduced = np.double(np.array(L1_reduced.todense()))
#             L2 = P_true @ L1_reduced @ P_true.T
            
#             P = np.loadtxt(f'{results_folder}/permutation_{name}_{p}#{seed}.csv')
#             try:
#                 check_permutation_matrix(P)
#             except ValueError:
#                 print(f'{results_folder}/permutation_{name}_{p}#{seed}.csv')
#                 print(P.sum(axis=0))
#                 raise ValueError()
            
#             permutation_matrices[seed][p] = P
#             print(name, ':', lg.norm(P.T @ L2 @ P - L1, 'fro')**2, 'should be correct')
#             print(name, ':', np.loadtxt(f'{results_folder}/l2_error_{name}#{seed}.csv'))
#             print(name, ':', lg.norm(P @ L2 @ P.T - L1, 'fro'))
    
# Largest mean error of each metric over all strategies and p values; every curve of that
# metric is divided by it so the strategies can be compared on a common [0, 1] scale.
max_w2_error = np.max(list(w2_errors.values()))
max_l2_error = np.max(list(l2_errors.values()))
max_l2_inv_error = np.max(list(l2_inv_errors.values()))
max_gw_error = np.max(list(gw_errors.values()))


def _normalized(errors, max_error):
    """Return ``errors / max_error``; if ``max_error`` is 0 return ``errors`` unchanged (avoids 0/0 -> NaN)."""
    return errors / max_error if max_error != 0 else errors


fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(16, 2.3), sharey=True)
# (axis, title, per-strategy mean errors, normalisation constant), in display order.
panels = [
    (ax1, 'Normalized L2 error', l2_errors, max_l2_error),
    (ax2, 'Normalized L2 error (inv)', l2_inv_errors, max_l2_inv_error),
    (ax3, 'Normalized GOT error', w2_errors, max_w2_error),
    (ax4, 'Normalized GW error', gw_errors, max_gw_error),
]
ax1.set_ylim(0, 1.02)  # applies to all panels because of sharey=True
for ax, title, errors, max_error in panels:
    ax.set_title(title)
    ax.set_xlabel('p')
    for name in strategy_names:
        ax.plot(p_values, _normalized(errors[name], max_error),
                label=name, marker='*', c=colors[name])
ax1.legend(loc='lower left')

plots_folder = '../plots'
os.makedirs(plots_folder, exist_ok=True)  # savefig fails if the directory is missing
# bbox_inches='tight' keeps the x-labels from being clipped at this short figure height.
fig.savefig(os.path.join(plots_folder, 'alignment_errors.pdf'), bbox_inches='tight')
plt.show()