# Group manifold graphs


In this tutorial, we introduce the notion of a group manifold graph, a discretization of a Riemannian manifold. At the moment, four manifolds are available: the translation group $\mathbb{R}^2$, the roto-translation group $SE(2)$, the 3D rotation group $SO(3)$ and the 2-sphere $S^2$.

We define such a graph as follows:
- the vertices correspond to **uniformly sampled** elements of the manifold,
- the edges connect each vertex to its **K nearest neighbors** w.r.t. an **anisotropic Riemannian distance**,
- the edge weights are computed by applying a **Gaussian weight kernel** to the Riemannian distance between vertices.

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm

## Create a group manifold graph

In [None]:
from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph, S2GEGraph, R2GEGraph, RandomSubGraph

In [None]:
# Translation group R^2: size=[28, 28, 1] (presumably a 28x28 grid with a single layer),
# each vertex linked to its K=8 nearest neighbors; all three metric coefficients set to 1.
r2_graph = R2GEGraph(size=[28, 28, 1], K=8, sigmas=(1.0,) * 3, path_to_graph="saved_graphs")

In [None]:
# Roto-translation group SE(2): size=[28, 28, 6] (presumably a 28x28 grid with 6 orientation
# layers), K=32 neighbors, anisotropic metric coefficients given by `sigmas`.
se2_graph = SE2GEGraph(size=[28, 28, 6], K=32, sigmas=(1.0, 0.1, 2.048 / (28 * 28)),
                       path_to_graph="saved_graphs")

In [None]:
# 2-sphere S^2: size=[642, 1] (presumably 642 sampled points in a single layer),
# each vertex linked to its K=8 nearest neighbors; all three metric coefficients set to 1.
s2_graph = S2GEGraph(size=[642, 1], K=8, sigmas=(1.0,) * 3, path_to_graph="saved_graphs")

In [None]:
# Rotation group SO(3): size=[642, 6] (presumably 642 sphere points x 6 layers),
# K=32 neighbors, anisotropic metric coefficients given by `sigmas`.
so3_graph = SO3GEGraph(size=[642, 6], K=32, sigmas=(1.0, 0.1, 10 / 642),
                       path_to_graph="saved_graphs")

## Get information

In [None]:
# Inspect whether the S2 graph is connected.
s2_graph.is_connected

In [None]:
# Inspect whether the S2 graph is undirected.
s2_graph.is_undirected

In [None]:
# The manifold this graph discretizes.
s2_graph.manifold

In [None]:
# Number of vertices, i.e. of sampled manifold elements.
s2_graph.num_nodes

In [None]:
s2_graph.num_edges  # number of directed edges (an undirected edge presumably counts twice)

In [None]:
# Indices of the first 10 nodes.
s2_graph.node_index[:10]

In [None]:
# Per-node attributes stored on the graph (presumably their names, e.g. beta/gamma below).
s2_graph.node_attributes

In [None]:
# First 10 values of the `beta` and `gamma` node attributes
# (presumably the nodes' coordinates on the sphere — confirm against the graph class).
s2_graph.node_beta[:10], s2_graph.node_gamma[:10]

In [None]:
# edge_index is laid out with sources along dim 0 row 0 and targets in row 1, i.e. shape (2, E):
# slice the second dimension to show the first 10 edges (`[:10]` would print all of them).
s2_graph.edge_index[:, :10]  # row 0: source nodes, row 1: target nodes

In [None]:
s2_graph.edge_weight[:10]  # weights of the first 10 edges (Gaussian kernel of the Riemannian distance)

In [None]:
s2_graph.edge_sqdist[:10]  # squared Riemannian distances of the first 10 edges

In [None]:
s2_graph.neighborhood(9)  # node 9's neighbor indices, edge weights and squared Riemannian distances

### Static visualization

In [None]:
def plot_graph(graph, size):
    """Scatter-plot the nodes of ``graph``, one 3D subplot per layer, side by side.

    Args:
        graph: graph exposing ``cartesian_pos()``, which must return three
            sliceable sequences (X, Y, Z) of node coordinates.
        size: pair ``(M, L)``; the nodes are split into ``L`` consecutive chunks
            of ``M`` nodes, each drawn in its own subplot.
    """
    per_layer, num_layers = size

    figure = plt.figure(figsize=(5 * num_layers, 5))

    xs, ys, zs = graph.cartesian_pos()

    for layer in range(num_layers):
        # Assumes nodes are stored layer by layer (M consecutive nodes per layer) — TODO confirm.
        chunk = slice(layer * per_layer, (layer + 1) * per_layer)
        axes = figure.add_subplot(1, num_layers, layer + 1, projection="3d")
        axes.scatter(xs[chunk], ys[chunk], zs[chunk], c="firebrick")
        axes.axis("off")

    figure.tight_layout()

def plot_graph_neighborhood(graph, index, size):
    """Plot the neighborhood of node ``index``, coloring each node by its edge weight.

    Nodes that are not neighbors of ``index`` get weight 0. One 3D subplot is
    drawn per layer, side by side, matching the layout of ``plot_graph``.

    Args:
        graph: graph exposing ``num_nodes``, ``cartesian_pos()`` (three sliceable
            coordinate sequences) and ``neighborhood(index)`` (neighbor indices,
            edge weights, squared distances).
        index: index of the node whose neighborhood is shown.
        size: pair ``(M, L)``; the nodes are split into ``L`` consecutive chunks
            of ``M`` nodes, one per subplot.
    """
    M, L = size

    # Size the figure with the number of subplots; a fixed 5x5 figure squashes
    # every subplot as soon as L > 1.
    fig = plt.figure(figsize=(5 * L, 5))

    X, Y, Z = graph.cartesian_pos()

    neighbors_indices, neighbors_weights, _ = graph.neighborhood(index)
    # Dense per-node weight vector: 0 everywhere except on the neighbors of `index`.
    weights = torch.zeros(graph.num_nodes)
    # Cast to the destination dtype: torch indexed assignment rejects a source
    # whose dtype differs from the destination (e.g. float64 weights).
    weights[neighbors_indices] = torch.as_tensor(neighbors_weights, dtype=weights.dtype)

    for l in range(L):
        # Assumes nodes are stored layer by layer (M consecutive nodes per layer) — TODO confirm.
        layer = slice(l * M, (l + 1) * M)
        ax = fig.add_subplot(1, L, l + 1, projection="3d")
        ax.scatter(X[layer], Y[layer], Z[layer], c=weights[layer], cmap=cm.PuRd)
        ax.axis("off")

    fig.tight_layout()

In [None]:
# Draw the S2 graph's nodes: size matches its constructor, 642 nodes in a single layer.
plot_graph(s2_graph, [642, 1])

In [None]:
# Color the S2 graph's nodes by their edge weight to node 406 (0 for non-neighbors).
plot_graph_neighborhood(s2_graph, 406, [642, 1])

### Dynamic visualization

In [None]:
from gechebnet.graphs.viz import visualize_graph, visualize_graph_neighborhood, visualize_graph_signal

In [None]:
# Dynamic visualization of the SO(3) graph built above (K=32).
visualize_graph(so3_graph)

In [None]:
# Rebuild the SO(3) graph with fewer neighbors (K=16 instead of 32) for the neighborhood
# plot below. Note: this replaces the K=32 `so3_graph` created earlier in the notebook.
so3_graph = SO3GEGraph(size=[642, 6], K=16, sigmas=(1.0, 0.1, 10 / 642),
                       path_to_graph="saved_graphs")

In [None]:
# Dynamic visualization of node 0's neighborhood in the (K=16) SO(3) graph.
visualize_graph_neighborhood(so3_graph, 0)

In [None]:
# Random signal on the S2 graph: one uniform sample in [0, 1) per node.
signal = torch.rand((s2_graph.num_nodes,))
visualize_graph_signal(s2_graph, signal)

## Random sub graph

In [None]:
# Wrap the S2 graph so random subgraphs can be sampled from it; show its initial size.
random_sub_graph = RandomSubGraph(s2_graph)
random_sub_graph.num_nodes, random_sub_graph.num_edges

In [None]:
# reinit() presumably restores the full graph so sampling starts from scratch.
random_sub_graph.reinit()
# Edge sampling with rate 0.5 (presumably keeps about half of the edges).
random_sub_graph.edge_sampling(0.5)
random_sub_graph.num_nodes, random_sub_graph.num_edges

In [None]:
# reinit() presumably restores the full graph so sampling starts from scratch.
random_sub_graph.reinit()
# Node sampling with rate 0.5 (presumably keeps about half of the nodes and their edges).
random_sub_graph.node_sampling(0.5)
random_sub_graph.num_nodes, random_sub_graph.num_edges