# Group Equivariant Graph Tutorial

## Introduction

In this tutorial, we present the functionalities of our implementation of the group equivariant graph. Such a graph is a k-NN graph representing the structure of a group, i.e. a set equipped with a binary operation called the group product. Each vertex of the graph represents an element of the group, and each edge weight is computed from the Riemannian distance between the two elements it connects.

In [None]:
import torch
import numpy as np
import math
import matplotlib.pyplot as plt

#import pykeops
#pykeops.clean_pykeops()

In [None]:
from gechebnet.graph.visualization import visualize_graph_neighborhood, visualize_graph

## Package

The `gechebnet` package implements such graphs and provides the basic functions to interact with them. For the most part, it uses PyTorch for tensor operations and PyKeOps for memory- and time-efficient graph construction.

## Graph class

The main class of the package is the `Graph` class. It stores a graph compactly using the `node_index`, `edge_index` and `edge_weight` attributes. Basic properties of a graph are available through `num_nodes` and `num_edges`. Other methods, properties and attributes are available; we will see them later. This class only provides the general structure of a graph; the group equivariant graphs inherit from it.
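
As a rough illustration of this storage scheme, the sketch below mimics the attribute names in plain Python lists. The class `TinyGraph`, its `neighborhood` method, and the choice to count each direction of an edge separately are assumptions made for the example; they are not the `gechebnet` implementation.

```python
from dataclasses import dataclass


@dataclass
class TinyGraph:
    """Hypothetical edge-list container, for illustration only."""

    node_index: list   # one identifier per vertex
    edge_index: tuple  # (sources, targets): two lists of equal length
    edge_weight: list  # one weight per directed edge

    @property
    def num_nodes(self):
        return len(self.node_index)

    @property
    def num_edges(self):
        # Assumption: each direction of a symmetric edge is counted once.
        return len(self.edge_weight)

    def neighborhood(self, vertex):
        """Return (neighbor, weight) pairs for the edges leaving `vertex`."""
        sources, targets = self.edge_index
        return [
            (dst, w)
            for src, dst, w in zip(sources, targets, self.edge_weight)
            if src == vertex
        ]


# A symmetric path graph 0 - 1 - 2, stored as four directed edges.
graph = TinyGraph(
    node_index=[0, 1, 2],
    edge_index=([0, 1, 1, 2], [1, 0, 2, 1]),
    edge_weight=[0.5, 0.5, 0.2, 0.2],
)
```

Storing only the existing edges keeps the memory cost proportional to the number of edges rather than to the square of the number of vertices.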

Note that our graphs are necessarily symmetric, which is why we refer to the `knn` parameter as the **maximum** number of neighbors of a vertex.
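
To see how symmetry can turn `knn` into an upper bound, here is a minimal sketch in plain Python. It assumes that an edge is kept only when both endpoints list each other among their k nearest neighbors; this mutual rule and the helper `mutual_knn_edges` are illustrative assumptions, not necessarily what the package does.

```python
import math


def mutual_knn_edges(points, k):
    """Return {(i, j): distance}, keeping only mutual k-nearest-neighbor pairs."""
    n = len(points)
    neighbors = []
    for i in range(n):
        # Sort the other points by Euclidean distance and keep the k closest.
        ranked = sorted(
            (math.dist(points[i], points[j]), j) for j in range(n) if j != i
        )
        neighbors.append({j for _, j in ranked[:k]})

    edges = {}
    for i in range(n):
        for j in neighbors[i]:
            if i in neighbors[j]:  # keep the edge only if the relation is mutual
                edges[(i, j)] = math.dist(points[i], points[j])
    return edges


# Three clustered points and one outlier.
points = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (10.0, 0.0)]
edges = mutual_knn_edges(points, k=2)
degree = [sum(1 for (i, _) in edges if i == v) for v in range(len(points))]
```

In this toy example the outlier selects two neighbors, but neither selects it back, so it ends up with fewer than `k` edges while the graph stays symmetric.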

In [None]:
from gechebnet.graph.graph import SE2GEGraph, SO3GEGraph

## SE(2) Group Equivariant Graph

The SE(2) group equivariant graph is the first group equivariant graph we have implemented. To create such a graph, several parameters have to be specified:
- `nx`: the discretization on the x axis, i.e. the number of points to uniformly sample in the x direction.
- `ny`: the discretization on the y axis, i.e. the number of points to uniformly sample in the y direction.
- `nsym`: the discretization on the symmetry axis, i.e. the number of points to uniformly sample in the orientation domain (passed as `ntheta` in the example below).
- `knn`: the maximum number of neighbors of each vertex (passed as `K` in the example below).
- `sigmas`: the anisotropy parameters used to compute the Riemannian distance between vertices.
- `weight_kernel`: the weight kernel, i.e. a function that takes the squared Riemannian distance as input and returns an edge weight (in the example below it also receives a second argument, `tc`).
- `kappa`: the edge compression rate, i.e. the fraction of edges to drop during random compression.
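
The roles of `sigmas` and `weight_kernel` can be sketched in plain Python. The example below assumes that the squared distance between two elements g and h is a `sigmas`-weighted sum of the squared log-map coordinates of g⁻¹h; this is a common left-invariant approximation, not necessarily the formula used by `gechebnet`. The helpers `se2_log`, `sq_riemannian_dist` and `heat_kernel` are hypothetical names.

```python
import math


def se2_log(g, h):
    """Log-map coordinates (c1, c2, c3) of g^{-1} h, with g, h = (x, y, theta)."""
    dx, dy = h[0] - g[0], h[1] - g[1]
    # Express the translation in the frame attached to g.
    cos_g, sin_g = math.cos(g[2]), math.sin(g[2])
    rx = cos_g * dx + sin_g * dy
    ry = -sin_g * dx + cos_g * dy
    # Wrap the relative orientation to [-pi, pi).
    dtheta = (h[2] - g[2] + math.pi) % (2 * math.pi) - math.pi
    if abs(dtheta) < 1e-8:
        return rx, ry, dtheta  # pure translation: the log map is the identity
    half = 0.5 * dtheta
    cot = math.cos(half) / math.sin(half)
    return half * (rx * cot + ry), half * (-rx + ry * cot), dtheta


def sq_riemannian_dist(g, h, sigmas):
    """Assumed anisotropic squared distance: sum_i sigmas[i] * c_i ** 2."""
    return sum(s * c ** 2 for s, c in zip(sigmas, se2_log(g, h)))


def heat_kernel(sqdist, t):
    """Same form as the lambda used in this tutorial: exp(-d^2 / (4 t))."""
    return math.exp(-sqdist / (4 * t))


eps, xi = 0.1, 10.0
sigmas = (xi / eps, xi, 1.0)
origin = (0.0, 0.0, 0.0)
d_first = sq_riemannian_dist(origin, (0.1, 0.0, 0.0), sigmas)   # step along axis 1
d_second = sq_riemannian_dist(origin, (0.0, 0.1, 0.0), sigmas)  # step along axis 2
w_first, w_second = heat_kernel(d_first, 1.0), heat_kernel(d_second, 1.0)
```

Under these assumptions, a larger entry in `sigmas` makes displacements along the corresponding direction more expensive, so the kernel assigns them a smaller weight.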

In [None]:
# Anisotropy parameters used to build `sigmas` below.
eps, xi = 0.1, 10.0

In [None]:
se2_graph = SE2GEGraph(
    nx=30,
    ny=30,
    ntheta=10,
    K=16,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=lambda sqdistc, tc: torch.exp(-sqdistc / (4*tc)),
)

In [None]:
visualize_graph(se2_graph)

In [None]:
visualize_graph_neighborhood(se2_graph, 3000)

In [None]:
visualize_graph_neighborhood(se2_graph, 455)

In [None]:
se2_graph.node_x[455], se2_graph.node_y[455], se2_graph.node_theta[455]

In [None]:
f"The SE(2) graph has {se2_graph.num_nodes} vertices and {se2_graph.num_edges} edges"

## SO(3) Group Equivariant Graph

The SO(3) group equivariant graph is the second group equivariant graph we have implemented. To create such a graph, several parameters have to be specified:
- `nsamples`: the discretization on the sphere, i.e. the number of points to uniformly sample on the sphere (the example below instead specifies a `polyhedron` and a subdivision `level`).
- `nalpha`: the discretization on the alpha axis, i.e. the number of points to uniformly sample in the alpha direction.
- `knn`: the maximum number of neighbors of each vertex (passed as `K` in the example below).
- `sigmas`: the anisotropy parameters used to compute the Riemannian distance between vertices.
- `weight_kernel`: the weight kernel, i.e. a function that takes the squared Riemannian distance as input and returns an edge weight (in the example below it also receives a second argument, `tc`).
- `kappa`: the edge compression rate, i.e. the fraction of edges to drop during random compression.
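
Random compression with rate `kappa` can be sketched as follows. The example assumes that edges are dropped as undirected pairs so that the graph stays symmetric; the helper `compress_edges`, its `seed` argument and the dictionary representation are hypothetical and only serve the illustration.

```python
import random


def compress_edges(edges, kappa, seed=0):
    """Drop a fraction `kappa` of the undirected edges.

    `edges` maps (i, j) -> weight and contains both directions of every edge.
    """
    rng = random.Random(seed)
    # Work on undirected pairs so that both directions are dropped together.
    undirected = sorted({(min(i, j), max(i, j)) for (i, j) in edges})
    num_keep = len(undirected) - int(kappa * len(undirected))
    compressed = {}
    for i, j in rng.sample(undirected, num_keep):
        compressed[(i, j)] = edges[(i, j)]
        compressed[(j, i)] = edges[(j, i)]
    return compressed


# Ring graph with 10 vertices: 10 undirected edges, 20 directed entries.
ring = {}
for i in range(10):
    j = (i + 1) % 10
    ring[(i, j)] = ring[(j, i)] = 1.0

half_ring = compress_edges(ring, kappa=0.5)
```

With `kappa=0.5`, half of the undirected edges of the ring are removed, and every remaining edge still appears in both directions.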

In [None]:
eps, xi = 0.1, 0.5

In [None]:
so3_graph = SO3GEGraph(
    polyhedron="icosahedron",
    level=2,
    nalpha=6,
    K=16,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=lambda sqdistc, tc: torch.exp(-sqdistc / (4*tc)),
)

In [None]:
f"The SO(3) graph has {so3_graph.num_nodes} vertices and {so3_graph.num_edges} edges"

In [None]:
visualize_graph(so3_graph)

In [None]:
visualize_graph_neighborhood(so3_graph, 162*4)

In [None]:
# Dividing by nalpha = 6 gives the number of sphere samples per alpha layer.
972 / 6

In [None]:
so3_graph.node_alpha[162*3], so3_graph.node_beta[162*3], so3_graph.node_gamma[162*3]
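
The value 162 used in the indices above can be recovered by hand. The sketch below assumes that `level=2` means two rounds of splitting every triangle of the icosahedron into four, and that nodes are ordered one alpha layer after another; both are assumptions suggested by the numbers in this tutorial, not documented behavior. The function `icosphere_num_vertices` is a hypothetical helper.

```python
def icosphere_num_vertices(level):
    """Vertices of an icosahedron after `level` rounds of 4-way triangle splits."""
    return 10 * 4 ** level + 2


nalpha = 6
num_sphere_samples = icosphere_num_vertices(2)
num_nodes = num_sphere_samples * nalpha

# Assumed layout: node index = alpha_layer * num_sphere_samples + sphere_index.
alpha_layer, sphere_index = divmod(162 * 3, num_sphere_samples)
```

Under this layout, index `162*3` would be the first sphere sample of the fourth alpha layer, which is one way to read the indices chosen in the cells above.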