In [None]:
import matplotlib.pyplot as plt

import torch

In [None]:
from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph
from gechebnet.graphs.viz import visualize_graph_signal 


In [None]:
from gechebnet.liegroups.so3 import so3_uniform_sampling

In [None]:
from gechebnet.liegroups.se2 import se2_uniform_sampling
import matplotlib.cm as cm
import numpy as np

In [None]:
def plot_eigenspace(graph, indices):
    """Plot selected eigenvectors of ``graph`` as scatter plots, one row per layer.

    Nodes are assumed to be stored layer by layer: ``graph.size[0]`` layers of
    ``prod(graph.size[1:])`` consecutive nodes each. Column ``i`` shows eigenvector
    ``indices[i]``; row ``l`` shows its values on layer ``l``, placed at the x/y
    coordinates returned by ``graph.cartesian_pos()``.

    Args:
        graph: object exposing ``size``, ``get_eigen_space()`` (second return value:
            eigenvector matrix indexed as ``[node, k]``) and ``cartesian_pos()``.
        indices: iterable of eigenvector (column) indices to plot.
    """
    indices = list(indices)
    if not indices:
        # nothing to draw; plt.subplots rejects a zero-column grid
        return

    num_layers = graph.size[0]
    layer_size = int(np.prod(graph.size[1:]))
    num_vecs = len(indices)

    _, eigenvec = graph.get_eigen_space()
    # only the planar coordinates are used; the third coordinate is ignored
    X, Y, _ = graph.cartesian_pos()

    fig, axes = plt.subplots(num_layers, num_vecs, squeeze=False, figsize=(4 * num_vecs, 4 * num_layers))

    for col, k in enumerate(indices):
        for layer in range(num_layers):
            nodes = slice(layer * layer_size, (layer + 1) * layer_size)
            ax = axes[layer, col]
            ax.scatter(X[nodes], Y[nodes], c=eigenvec[nodes, k], cmap=cm.PiYG)
            ax.axis("off")

    fig.tight_layout()


In [None]:
def plot_eigenvalues(graph, indices):
    """Scatter-plot the eigenvalues ``lambda_k`` of ``graph`` for every ``k`` in ``indices``.

    Args:
        graph: object exposing ``get_eigen_space()``; its first return value is the
            eigenvalue array, indexed here with ``indices``.
        indices: eigenvalue indices to show (must support ``min``/``max`` and be usable
            as an index into the eigenvalue array, e.g. a list or ``np.arange``).
    """
    values, _vectors = graph.get_eigen_space()

    fig, ax = plt.subplots(figsize=(10, 2))
    ax.scatter(indices, values[indices], s=20, c="firebrick")
    ax.set_xlabel(r"$k$")
    ax.set_ylabel(r"$\lambda_k$")
    # pad the x-range by one index on each side so border points are not clipped
    ax.set_xlim(min(indices) - 1, max(indices) + 1)

    fig.tight_layout()

# Projective line bundle of the SE(2) group

In [None]:
# Last two entries of `sigmas` for the SE(2) graph below.
# Previously tried alternative: eps = 0.1, xi = 2.048 / (28 ** 2)
eps = 1.
xi = 1.

In [None]:
# Uniform SE(2) sampling with arguments (28, 28, 1).
# NOTE(review): presumably a 28 x 28 spatial grid with a single orientation — confirm against se2_uniform_sampling.
se2_sampling = se2_uniform_sampling(28,28,1)

In [None]:
# SE(2) graph built on `se2_sampling`; `sigmas` combines a fixed first value of 1. with `eps` and `xi`.
# NOTE(review): `path_to_graph` presumably caches the computed graph on disk — confirm in SE2GEGraph.
se2_graph = SE2GEGraph(
    se2_sampling,
    K=8,  # presumably neighbours per node — TODO confirm
    sigmas=(1., eps, xi),
    path_to_graph="saved_graphs"
)

In [None]:
# Eigenvectors 1 through 12 of the SE(2) graph.
plot_eigenspace(se2_graph, indices=list(range(1, 13)))

In [None]:
# First 30 eigenvalues of the SE(2) graph.
plot_eigenvalues(se2_graph, indices=np.arange(30))

## Eigenspace of the SE(2) graph

In [None]:
# Eigen-decomposition of the SE(2) graph as torch tensors (torch.from_numpy shares memory with the arrays).
eigenval, eigenvec = (torch.from_numpy(array) for array in se2_graph.get_eigen_space())

In [None]:
# First 50 eigenvalues of the SE(2) graph against their index.
fig, ax = plt.subplots(figsize=(40, 20))
ax.scatter(torch.arange(50), eigenval[:50])
ax.set_xlabel("index")
ax.set_ylabel("eigenvalue");

In [None]:
# Eigenvectors 1..num_vecs of the SE(2) graph: one column per eigenvector, one row per layer.
# The layer layout follows the same convention as `plot_eigenspace` (graph.size[0] layers of
# prod(graph.size[1:]) consecutive nodes) instead of the hard-coded 6 x 784.
num_vecs = 8
num_layers = se2_graph.size[0]
layer_size = int(np.prod(se2_graph.size[1:]))

# `projection` is not a plt.subplots keyword (it would have to go through `subplot_kw`) and the
# scatter below is planar, so plain 2D axes are used; no extra empty figure is created.
fig, axes = plt.subplots(num_layers, num_vecs, squeeze=False, figsize=(num_vecs*4, num_layers*4))

for k in range(num_vecs):
    for l in range(num_layers):
        nodes = slice(l * layer_size, (l + 1) * layer_size)
        # eigenvector index k + 1, matching the eigenvalues printed below (eigenval[1:9])
        axes[l][k].scatter(se2_graph.node_x[nodes], se2_graph.node_y[nodes], c=eigenvec[nodes, k + 1], cmap=cm.PiYG)
        axes[l][k].axis("off")

plt.show()
print(eigenval[1:9])

In [None]:
# Eigenvector 1 of the SE(2) graph, titled with its eigenvalue like the following cells.
visualize_graph_signal(se2_graph, eigenvec[:,1], title=fr"$\lambda_{1}={eigenval[1].item()}$")

In [None]:
# Eigenvector 2 of the SE(2) graph.
visualize_graph_signal(se2_graph, eigenvec[:, 2], title=rf"$\lambda_2={eigenval[2].item()}$")

In [None]:
# Eigenvector 3 of the SE(2) graph.
visualize_graph_signal(se2_graph, eigenvec[:, 3], title=rf"$\lambda_3={eigenval[3].item()}$")

In [None]:
# Eigenvector 4 of the SE(2) graph.
visualize_graph_signal(se2_graph, eigenvec[:, 4], title=rf"$\lambda_4={eigenval[4].item()}$")

In [None]:
# Eigenvector 5 of the SE(2) graph.
visualize_graph_signal(se2_graph, eigenvec[:, 5], title=rf"$\lambda_5={eigenval[5].item()}$")

# Projective line bundle of the SO(3) group

In [None]:
# Last two entries of `sigmas` for the SO(3) graph below.
eps = 1.
xi = 1.

In [None]:
import importlib
import os
from pathlib import Path

# The SO(3) sampling files live inside the package (gechebnet/liegroups/sampling), so resolve
# them from the installed package rather than a machine-specific absolute path.
# Set GECHEBNET_SAMPLING_DIR to use a different directory.
_liegroups_dir = Path(importlib.import_module("gechebnet.liegroups").__path__[0])
path_to_sampling = os.environ.get("GECHEBNET_SAMPLING_DIR", str(_liegroups_dir / "sampling"))
so3_sampling = so3_uniform_sampling(path_to_sampling, 3, 1)

In [None]:
# SO(3) graph built on `so3_sampling`; `sigmas` combines a fixed first value of 1. with `eps` and `xi`.
# NOTE(review): `path_to_graph` presumably caches the computed graph on disk — confirm in SO3GEGraph.
so3_graph = SO3GEGraph(
    so3_sampling,
    K=16,  # presumably neighbours per node — TODO confirm
    sigmas=(1., eps, xi),
    path_to_graph="saved_graphs"
)

In [None]:
# Eigen-decomposition of the SO(3) graph as torch tensors (torch.from_numpy shares memory with the arrays).
eigenval, eigenvec = (torch.from_numpy(array) for array in so3_graph.get_eigen_space())

In [None]:
# Inspect the edge weights of the SO(3) graph.
# NOTE(review): this displays the whole edge-weight container; a summary (shape, min/max) would keep the output small.
so3_graph.edge_weight

In [None]:
# First 30 eigenvalues of the SO(3) graph against their index (drawn on the current axes).
ax = plt.gca()
ax.scatter(torch.arange(30), eigenval[:30])
ax.set_xlabel("index")
ax.set_ylabel("eigenvalue");

In [None]:
# Eigenvectors 1..num_vecs of the SO(3) graph: one column per eigenvector, one row per layer.
# The per-layer node count was hard-coded to 784 (= 28 * 28, the SE(2) grid), which does not
# describe this graph; the layout now follows the same convention as `plot_eigenspace`.
# NOTE(review): assumes SO3GEGraph exposes `size` like the SE(2) graph — confirm.
num_vecs = 8
num_layers = so3_graph.size[0]
layer_size = int(np.prod(so3_graph.size[1:]))

fig, axes = plt.subplots(num_layers, num_vecs, squeeze=False, figsize=(num_vecs*4, num_layers*4))

for k in range(num_vecs):
    for l in range(num_layers):
        nodes = slice(l * layer_size, (l + 1) * layer_size)
        # eigenvector index k + 1, i.e. eigenvector 0 is not shown
        axes[l][k].scatter(so3_graph.node_x[nodes], so3_graph.node_y[nodes], c=eigenvec[nodes, k + 1], cmap=cm.PiYG)
        axes[l][k].axis("off")

plt.show()

In [None]:
# Display `num_nodes` of the SO(3) graph (presumably its node count).
so3_graph.num_nodes

## Eigenspace of the SO(3) graph

In [None]:
# Use the same accessor as every other cell (`get_eigen_space()`) instead of an `eigen_space` attribute.
eigenval, eigenvec = so3_graph.get_eigen_space()

In [None]:
# Convert to torch tensors. torch.as_tensor shares memory with numpy input like torch.from_numpy,
# but returns tensors unchanged, so re-running this cell does not fail.
eigenval = torch.as_tensor(eigenval)
eigenvec = torch.as_tensor(eigenvec)

### Frequencies

In [None]:
# Full spectrum of the SO(3) graph. The x-axis is derived from the eigenvalues themselves, so
# x and y always have the same length even if fewer than num_nodes eigenvalues are returned.
_ = plt.scatter(torch.arange(len(eigenval)), eigenval, s=1)
_ = plt.xlabel("index")
_ = plt.ylabel("eigenvalue")

### Fourier basis

In [None]:
# First 20 eigenvalues of the SO(3) graph against their index (drawn on the current axes).
ax = plt.gca()
ax.scatter(torch.arange(20), eigenval[:20])
ax.set_xlabel("index")
ax.set_ylabel("eigenvalue");

In [None]:
# Eigenvector 0 of the SO(3) graph.
visualize_graph_signal(so3_graph, eigenvec[:, 0], title=rf"$\lambda_0={eigenval[0].item()}$")

In [None]:
# Eigenvector 1 of the SO(3) graph.
visualize_graph_signal(so3_graph, eigenvec[:, 1], title=rf"$\lambda_1={eigenval[1].item()}$")

In [None]:
# Eigenvector 2 of the SO(3) graph.
visualize_graph_signal(so3_graph, eigenvec[:, 2], title=rf"$\lambda_2={eigenval[2].item()}$")

In [None]:
# Eigenvector 3 of the SO(3) graph.
visualize_graph_signal(so3_graph, eigenvec[:, 3], title=rf"$\lambda_3={eigenval[3].item()}$")

In [None]:
# Eigenvector 4 of the SO(3) graph.
visualize_graph_signal(so3_graph, eigenvec[:, 4], title=rf"$\lambda_4={eigenval[4].item()}$")

In [None]:
# Eigenvector 5 of the SO(3) graph.
visualize_graph_signal(so3_graph, eigenvec[:, 5], title=rf"$\lambda_5={eigenval[5].item()}$")