In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import scipy
import scipy.sparse.linalg  # explicit: `import scipy` alone does not guarantee eigsh is loaded
import torch
from torch_geometric.data import Batch
from torch_geometric.utils import to_networkx

from fs_grl.data.io_utils import load_data

# Input graph

In [None]:
# NOTE(review): `data_list` is never defined in this notebook (presumably it came from
# `load_data`, imported above, in a cell that has since been deleted — TODO restore it).
# Guarded so Restart & Run All does not stop here with a NameError; the next cell
# replaces `graph` with a built-in example graph anyway.
if "data_list" in globals():
    graph = to_networkx(data_list[0], to_undirected=True)

In [None]:
# Use networkx's built-in Tutte graph as the example input (this replaces any
# `graph` built in the previous cell) and take a quick look at it.
graph = nx.tutte_graph()
nx.draw(graph)

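# Laplacian eigendecomposition

Compute the two smallest eigenpairs of the graph Laplacian and keep the eigenvector belonging to the second-smallest eigenvalue.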

In [None]:
# Graph Laplacian as a floating-point sparse matrix (eigsh needs a float dtype).
# `.astype(float)` replaces `.asfptype()`, which is deprecated for scipy sparse arrays
# (returned by recent networkx versions) and removed in newer SciPy releases.
laplacian = nx.laplacian_matrix(graph).astype(float)

# Two eigenpairs with the smallest-magnitude eigenvalues. eigsh does not document the
# order of its output, so sort by eigenvalue before selecting by position.
eigenvals, eigenvecs = scipy.sparse.linalg.eigsh(laplacian, k=2, which="SM")
order = np.argsort(eigenvals)
eigenvals, eigenvecs = eigenvals[order], eigenvecs[:, order]

# Eigenvector of the second-smallest eigenvalue (one entry per node). Its sign is
# arbitrary: eigsh may return either v or -v.
principal_eigenvec = eigenvecs[:, 1]

In [None]:
# Sanity check: show the selected eigenvector and its Euclidean norm.
print(principal_eigenvec)
eigvec_norm = np.linalg.norm(principal_eigenvec)
print(f'norm: {eigvec_norm}')

In [None]:
# Colour each node by its entry in the eigenvector. nx.draw maps numeric node_color
# values through `cmap`, scaled linearly between vmin and vmax. This cell previously
# relied on `value2color` and `pos`, which are only defined in later cells and so
# raised NameError on Restart & Run All.
pos = nx.spring_layout(graph, seed=0)  # fixed seed -> reproducible layout
nx.draw(
    graph,
    pos=pos,
    node_color=principal_eigenvec,
    cmap=plt.cm.RdBu,
    vmin=principal_eigenvec.min(),
    vmax=principal_eigenvec.max(),
)

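# Positional encodings

For each frequency $k = 1, \dots, K$, node $i$ gets the feature $\cos(2\pi k\, v_i)$, where $v$ is the eigenvector computed above. Because $\cos$ is even, the encodings do not depend on the arbitrary sign of $v$.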

In [None]:
K = 5  # number of cosine frequencies (k = 1..K) used for the positional encodings

In [None]:
# Frequencies k = 1..K. Broadcasting the (n_nodes, 1) eigenvector column against the
# (K,) frequency vector yields an (n_nodes, K) matrix whose column k-1 is cos(2*pi*k*v).
# cos is even, so the result is unaffected by the arbitrary sign of the eigenvector.
frequencies = np.arange(1, K + 1)
positional_encodings = np.cos(principal_eigenvec[:, np.newaxis] * 2 * np.pi * frequencies)

In [None]:
# Stacked along axis=1: one row per eigenvector entry, column k-1 holds cos(2*pi*k*v).
positional_encodings

# Visualization

In [None]:
# Short alias for the encoding matrix, used by the visualization cells below.
pos_enc = positional_encodings

In [None]:
def value2color(values, min_value, max_value, cmap=None):
    """
    Min-max normalize values to [0, 1] and convert them to colormap colours.

    Args:
        values: array-like of numbers to colour.
        min_value: value mapped to 0 (bottom of the colormap).
        max_value: value mapped to 1 (top of the colormap).
        cmap: callable colormap accepting values in [0, 1]. Defaults to ``plt.cm.RdBu``.

    Returns:
        The result of ``cmap`` applied to the normalized values (RGBA colours for a
        matplotlib colormap).
    """
    if cmap is None:
        cmap = plt.cm.RdBu
    # Float copy: avoids mutating the caller's array and avoids integer-division errors.
    values = np.asarray(values, dtype=float)
    value_range = max_value - min_value
    if value_range == 0:
        # Constant input: map everything to the colormap midpoint instead of dividing by zero.
        values_norm = np.full_like(values, 0.5)
    else:
        # Divide by the range (not by max_value) so min_value -> 0 and max_value -> 1.
        values_norm = (values - min_value) / value_range
    return cmap(values_norm)

In [None]:
# Shared colour scale across the K encoding columns that are plotted below, so the
# figures are directly comparable. Vectorized numpy reductions replace the Python loop.
global_min = pos_enc[:, :K].min()
global_max = pos_enc[:, :K].max()

print(global_min, global_max)

In [None]:
# One array of node colours per encoding column, all on the shared [global_min, global_max] scale.
node_colors = {}
for col in range(K):
    node_colors[col] = value2color(pos_enc[:, col], global_min, global_max)

In [None]:
# Fixed seed so the spring layout (and therefore every figure) is reproducible.
pos = nx.spring_layout(graph, seed=0)
for k in range(K):
    fig, ax = plt.subplots()
    nx.draw(graph, node_color=node_colors[k], pos=pos, ax=ax)
    # Column k of pos_enc was built with frequency k + 1.
    ax.set_title(f"Positional encoding, frequency k={k + 1}")
    plt.show()
    plt.close(fig)  # release the figure; this loop creates K of them
