# Neural network layers

In this tutorial, we introduce the three kinds of neural network layers used in this thesis. The Chebyshev convolutional layer is a spectral method and has a diffusion effect on the original signal. The pooling and unpooling layers are used to change the resolution of an image, either by down-sampling and reducing it (pooling) or by up-sampling and expanding it (unpooling).

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import numpy as np
import torch

In [None]:
from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph, R2GEGraph, S2GEGraph
from gechebnet.nn.layers.convs import ChebConv
from gechebnet.nn.layers.pools import SE2SpatialPool, SO3SpatialPool
from gechebnet.nn.layers.unpools import SE2SpatialUnpool, SO3SpatialUnpool
from gechebnet.utils.utils import delta_kronecker

In [None]:
def plot_signal(graph, signal, size, cmap=cm.PiYG):
    """Scatter-plot a graph signal in 3D, with one subplot per layer of nodes.

    Args:
        graph: object whose ``cartesian_pos()`` returns the node coordinates
            ``(X, Y, Z)``, each sliceable by node index.
        signal: torch tensor whose last axis indexes the graph nodes. The
            signal is detached and moved to the CPU before plotting, so
            tensors that require grad or live on a GPU can be passed directly.
        size: ``(M, L)``, the number of nodes per layer and the number of
            layers. Layer ``l`` is drawn from nodes ``l*M`` to ``(l+1)*M - 1``.
        cmap: matplotlib colormap. Defaults to the diverging ``cm.PiYG``.

    The colour scale is symmetric around zero, ``[-max|signal|, max|signal|]``,
    and it is shared by every subplot so the layers can be compared directly.
    """
    M, L = size
    # matplotlib converts its inputs with numpy. That conversion fails on tensors
    # that require grad or sit on a GPU, so plot a detached CPU copy instead.
    signal = signal.detach().cpu()

    fig = plt.figure(figsize=(5 * L, 5))

    X, Y, Z = graph.cartesian_pos()
    vm = float(signal.abs().max())

    for l in range(L):
        nodes = slice(l * M, (l + 1) * M)  # node range of layer l
        ax = fig.add_subplot(1, L, l + 1, projection="3d")
        ax.scatter(X[nodes], Y[nodes], Z[nodes], c=signal[..., nodes], cmap=cmap, vmin=-vm, vmax=vm)
        ax.axis("off")

    fig.tight_layout()

## SE(2) Group Manifold Graph

### Convolutional layers

In [None]:
# SE(2) graph on a [28, 28, 6] grid: 28 * 28 = 784 nodes per layer and 6 layers,
# which matches the (784, 6) size used when plotting it below.
se2_graph = SE2GEGraph([28, 28, 6], K=16, sigmas=(1.0, 0.1, 0.0026), path_to_graph="saved_graphs")

In [None]:
# Single-channel Chebyshev convolution (kernel size 4) on the SE(2) graph.
in_channels, out_channels, kernel_size = 1, 1, 4
conv = ChebConv(in_channels, out_channels, kernel_size, se2_graph)

In [None]:
# Impulse response: pass a delta signal of shape (1, 1, 28 * 28 * 6) through the
# convolution without tracking gradients. The delta is presumably 1 at index
# (0, 0, 406) and 0 elsewhere; confirm against delta_kronecker.
with torch.no_grad():
    input = delta_kronecker((1, 1, 784 * 6), (0, 0, 406))
    output = conv(input)

In [None]:
plot_signal(se2_graph, output, size=(28 * 28, 6))  # one subplot per layer of 784 nodes

### Pooling and unpooling layers

In [None]:
# R2 graphs at three resolutions, built in this order: 40x40 (target of unpooling),
# 20x20 (input) and 10x10 (target of pooling).
# NOTE(review): despite the ``se2`` names these are R2GEGraph objects. ``se2_graph``
# is also rebound here, which replaces the SE(2) graph used by the cells above.
up_se2_graph = R2GEGraph([40, 40, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")
se2_graph = R2GEGraph([20, 20, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")
down_se2_graph = R2GEGraph([10, 10, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")

In [None]:
input = torch.rand(400)  # uniform random signal on the 20 x 20 x 1 grid

In [None]:
# Pooling / unpooling on the (20, 20, 1) grid, both in "rand" mode. The first
# argument is presumably the resampling factor: the plots below show the 20x20
# signal going to 10x10 and 40x40.
pool, unpool = SE2SpatialPool(2, (20, 20, 1), "rand"), SE2SpatialUnpool(2, (20, 20, 1), "rand")

In [None]:
plot_signal(se2_graph, input, size=(400, 1))  # the 20x20 input, a single layer

In [None]:
# Pooled signal, shown on the 10x10 graph.
with torch.no_grad():
    plot_signal(down_se2_graph, pool(input), size=(100, 1))

In [None]:
# Unpooled signal, shown on the 40x40 graph.
with torch.no_grad():
    plot_signal(up_se2_graph, unpool(input), size=(1600, 1))

## SO(3) Group Manifold Graph

### Convolutional layer

In [None]:
# SO(3) graph with 642 nodes per layer and 6 layers (plotted below with size (642, 6)).
so3_graph = SO3GEGraph(size=[642, 6], K=32, sigmas=(1.0, 0.1, 10.0 / 642), path_to_graph="saved_graphs")

In [None]:
# Single-channel Chebyshev convolution (kernel size 4) on the SO(3) graph.
in_channels, out_channels, kernel_size = 1, 1, 4
conv = ChebConv(in_channels, out_channels, kernel_size, so3_graph)

In [None]:
# Impulse response on the SO(3) graph: a delta signal of shape (1, 1, 642 * 6),
# presumably 1 at index (0, 0, 143), passed through the convolution without gradients.
with torch.no_grad():
    input = delta_kronecker((1, 1, 6 * 642), (0, 0, 143))
    output = conv(input)

In [None]:
plot_signal(so3_graph, output, size=(642, 6))  # one subplot per layer of 642 nodes

### Pooling and unpooling layers

In [None]:
# S2 graphs at three resolutions, built in this order: 2562 nodes (target of
# unpooling), 642 nodes (input) and 162 nodes (target of pooling).
# NOTE(review): despite the ``so3`` names these are S2GEGraph objects. ``so3_graph``
# is also rebound here, which replaces the SO(3) graph used by the cells above.
up_so3_graph = S2GEGraph(size=[2562, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")
so3_graph = S2GEGraph([642, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")
down_so3_graph = S2GEGraph([162, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")

In [None]:
input = torch.rand(642)  # uniform random signal on the 642-node graph

In [None]:
# Pooling in "max" mode and unpooling in "avg" mode on the 642-node graph. The
# plots below show the signal going to 162 and 2562 nodes.
pool, unpool = SO3SpatialPool(2, (642, 1), "max"), SO3SpatialUnpool(2, (642, 1), "avg")

In [None]:
plot_signal(so3_graph, input, size=(642, 1))  # the 642-node input, a single layer

In [None]:
# Pooled signal, shown on the 162-node graph.
with torch.no_grad():
    plot_signal(down_so3_graph, pool(input), size=(162, 1))

In [None]:
# Unpooled signal, shown on the 2562-node graph.
with torch.no_grad():
    plot_signal(up_so3_graph, unpool(input), size=(2562, 1))

In [None]:
from gechebnet.graphs.graphs import R2GEGraph, RandomSubGraph, SE2GEGraph
from gechebnet.nn.models.chebnets import WideResSE2GEChebNet

import torch

import matplotlib.pyplot as plt

In [None]:
# Every ordered pair (i, j) with i != j on 3 nodes, listed as (row, col) pairs and
# transposed into the 2 x 6 COO layout (row indices, column indices).
edge_index = torch.tensor([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]).t().contiguous()
# Weight of each edge above. The matrix is not symmetric, e.g. (0, 1) -> 0.5 but (1, 0) -> 0.2.
edge_weight = torch.tensor([0.5, 0.4, 0.2, 0.3, 0.4, 0.3])

In [None]:
# Build the dense 3 x 3 weighted adjacency matrix from the COO edge list.
# torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
edge_matrix = torch.sparse_coo_tensor(edge_index, edge_weight, (3, 3)).to_dense()

In [None]:
# Boolean mask of the entries equal to their transposed counterpart (the symmetric part).
mask = edge_matrix == edge_matrix.t()
mask

In [None]:
# Entries that are both below 0.3 and symmetric.
mask1, mask2 = edge_matrix < 0.3, edge_matrix.t() == edge_matrix
torch.logical_and(mask1, mask2)

In [None]:
# Keep only the symmetric entries selected by ``mask``, zero the rest, and convert to sparse COO.
matrix = torch.where(mask, edge_matrix, torch.zeros_like(edge_matrix)).to_sparse()

In [None]:
# ``matrix`` is already sparse (``to_sparse()`` was applied in the previous cell), so
# a second ``to_sparse()`` is redundant and may not be accepted for sparse inputs on
# every PyTorch version. ``coalesce()`` is not in-place, so keep its result.
matrix = matrix.coalesce()
matrix

In [None]:
# ``indices()`` only works on coalesced sparse COO tensors. Coalescing first keeps
# this cell valid however ``matrix`` was produced.
matrix.coalesce().indices()

In [None]:
from gechebnet.graphs.graphs import R2GEGraph, RandomSubGraph, SE2GEGraph
from gechebnet.nn.models.chebnets import WideResSE2GEChebNet

import torch

import matplotlib.pyplot as plt

# Prefer the GPU, but fall back to the CPU so that a later ``.to(device)`` also works
# on machines without CUDA. (None of the cells shown here use ``device``.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SE(2) graph on a [28, 28, 6] grid (28 x 28 spatial nodes times 6 layers).
graph = SE2GEGraph([28, 28, 6], K=16, sigmas=(1.0, 0.1, 0.02), path_to_graph="saved_graphs")

#graph = R2GEGraph(
#        [28, 28, 1],
#        K=8,
#        sigmas=(1.0, 1., 1.),
#        path_to_graph="saved_graphs",
#    )

# Random sub-graphs are used to evaluate the effect of sampling edges and nodes.
sub_graph = RandomSubGraph(graph)

# Group-equivariant wide residual ChebNet built on the sub-graph.
model = WideResSE2GEChebNet(
    in_channels=1, out_channels=10, kernel_size=4,
    graph_lvl0=sub_graph, res_depth=2, widen_factor=8,
)

In [None]:
# Distinct values of the graph's ``edge_sqdist`` (presumably squared edge distances).
graph.edge_sqdist.unique()

In [None]:
# Number of edges and number of nodes of the graph.
graph.num_edges, graph.num_nodes

In [None]:
# Random 28x28 test image and its rotated / flipped variants.
image = torch.rand(28, 28)
rot_image = image.rot90()          # 90-degree rotation in the (dim 0, dim 1) plane
vflip_image = image.flip(0)        # rows reversed
hflip_image = image.flip(1)        # columns reversed
hvflip_image = image.flip((0, 1))  # both reversed, i.e. a 180-degree rotation

In [None]:
# Original random image.
plt.imshow(image)

In [None]:
# Image rotated by 90 degrees (torch.rot90).
plt.imshow(rot_image)

In [None]:
# Image flipped along both axes.
plt.imshow(hvflip_image)

In [None]:
# NOTE(review): this draws a NEW random image, so the model cells below receive
# different inputs from the images plotted above. Remove this cell if they should match.
image = torch.rand(28, 28)
rot_image, vflip_image, hflip_image, hvflip_image = (
    torch.rot90(image),
    torch.flip(image, [0]),
    torch.flip(image, [1]),
    torch.flip(image, [0, 1]),
)

In [None]:
# Lift the 2D image to a graph signal by repeating its flattened pixels 6 times,
# giving shape (1, 1, 6 * 28 * 28).
# NOTE(review): this assumes nodes are ordered layer by layer, as plot_signal slices them.
input = image.flatten().repeat(6).reshape(1, 1, -1)
model(input)

In [None]:
# Same lifting for the rotated image, to compare its output with the original's.
rot_input = rot_image.flatten().repeat(6).reshape(1, 1, -1)
model(rot_input)

In [None]:
# Same lifting for the column-reversed (horizontally flipped) image.
hflip_input = hflip_image.flatten().repeat(6).reshape(1, 1, -1)
model(hflip_input)

In [None]:
# Same lifting for the row-reversed (vertically flipped) image.
vflip_input = vflip_image.flatten().repeat(6).reshape(1, 1, -1)
model(vflip_input)

In [None]:
def permutation_matrix(graph, shifts=3):
    """Build the dense node-permutation matrix of a 90-degree rotation of a grid graph.

    The node indices are reshaped to ``graph.size[::-1]`` (so the last entry of
    ``graph.size`` becomes axis 0) and rotated by 90 degrees in the plane of
    axes 2 and 1. They are then rolled by ``shifts`` positions along axis 0.
    For every flat position ``k``, ``P[indices[k], graph.node_index[k]] = 1``.

    Args:
        graph: object exposing ``num_nodes``, ``size`` and ``node_index``, a
            tensor reshapeable to ``graph.size[::-1]``.
        shifts: roll applied along axis 0 (the last entry of ``graph.size``).
            The default 3 reproduces the original behaviour. ``shifts=0`` gives
            the purely spatial rotation.
            NOTE(review): 3 presumably encodes a 90-degree rotation for the
            6-layer graphs used here; confirm for other layer counts.

    Returns:
        ``(num_nodes, num_nodes)`` float tensor of zeros and ones.

    NOTE(review): the rotation only maps the grid onto itself when the two
    rotated axes have equal length (square spatial grid).
    """
    grid = graph.node_index.reshape(*graph.size[::-1])
    rotated = torch.rot90(grid, dims=(2, 1))
    indices = torch.roll(rotated, shifts=shifts, dims=0).flatten()

    P = torch.zeros(graph.num_nodes, graph.num_nodes)
    P[indices, graph.node_index] = 1
    return P

In [None]:
# Small 10 x 10 x 6 SE(2) graph. It is replaced by the R2 graph built in the next cell.
graph = SE2GEGraph([10, 10, 6], K=16, sigmas=(1.0, 0.1, 0.02), path_to_graph="saved_graphs")


In [None]:
# Tiny 4 x 4 R2 graph (16 nodes), small enough to inspect its matrices densely.
graph = R2GEGraph([4, 4, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph="saved_graphs")

In [None]:
# Edges whose first index (row 0 of edge_index) is node 0: their edge_sqdist values and endpoints.
mask = graph.edge_index[0].eq(0)
graph.edge_sqdist[mask], graph.edge_index[:, mask]

In [None]:
# All per-edge ``edge_sqdist`` values together with the edge index.
graph.edge_sqdist, graph.edge_index

In [None]:
# Median of ``edge_sqdist`` over all edges.
graph.edge_sqdist.median()

In [None]:
# (node_x, node_y) coordinates of the second endpoint of every edge selected by ``mask``.
graph.node_x[graph.edge_index[1, mask]], graph.node_y[graph.edge_index[1, mask]]

In [None]:
# ``P`` is otherwise only created in a later cell, so running the notebook top to
# bottom raised a NameError here. Build it now; the later cell recomputes the same matrix.
P = permutation_matrix(graph)
graph.node_index.float() @ P

In [None]:
# Weighted adjacency matrix of the graph as a sparse COO tensor.
# torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
# dtype=float32 keeps the float semantics that the old constructor implied.
W = torch.sparse_coo_tensor(
    graph.edge_index,
    graph.edge_weight,
    (graph.num_nodes, graph.num_nodes),
    dtype=torch.float32,
)

In [None]:
# Permutation matrix of the rotated grid (see permutation_matrix above).
P = permutation_matrix(graph)

In [None]:
# Check that the weight matrix is unchanged by the node permutation: P^T W P == W.
# torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
# dtype=float32 keeps the float semantics that the old constructor implied.
W = torch.sparse_coo_tensor(
    graph.edge_index,
    graph.edge_weight,
    (graph.num_nodes, graph.num_nodes),
    dtype=torch.float32,
)
torch.allclose(P.t() @ W.to_dense() @ P, W.to_dense())

In [None]:
# Same invariance check for the graph Laplacian: P^T L P should equal L.
L = graph.get_laplacian()
torch.allclose(P.T.mm(L.to_dense()).mm(P), L.to_dense())

In [None]:
# Dense view of the graph Laplacian.
graph.get_laplacian().to_dense()

In [None]:
# Sanity check: a permutation matrix satisfies P^T P = I.
P.t()@P

In [None]:
# Graph Laplacian, densified with to_dense() in the two cells below.
L = graph.get_laplacian()

# Earlier variant, kept for reference:
#P.t()@graph.laplacian@P

In [None]:
# Permuted Laplacian P^T L P, to compare by eye with L in the next cell.
P.t() @ L.to_dense() @ P

In [None]:
# Unpermuted Laplacian, for comparison with P^T L P above.
L.to_dense()

In [None]:
from gechebnet.geometry.se import se2_riemannian_sqdist, se2_matrix
import math

In [None]:
# Squared Riemannian SE(2) distance between the pose built from (0, 0, 0) and the one
# built from (5, 0, 0), under the metric diag(1, 0.1, 0.02). The arguments of
# se2_matrix are presumably (x, y, theta).
Gg = se2_matrix(torch.zeros(1), torch.zeros(1), torch.zeros(1))
Gh = se2_matrix(torch.full((1,), 5.0), torch.zeros(1), torch.zeros(1))
Re = torch.tensor([1.0, 0.1, 0.02]).diag()
se2_riemannian_sqdist(Gg, Gh, Re)

In [None]:
# Same distance with both poses at angle pi/2 and the offset of 5 on the second
# coordinate instead of the first (presumably (x, y, theta) arguments).
Gg = se2_matrix(torch.zeros(1), torch.zeros(1), torch.full((1,), math.pi / 2))
Gh = se2_matrix(torch.zeros(1), torch.full((1,), 5.0), torch.full((1,), math.pi / 2))
Re = torch.tensor([1.0, 0.1, 0.02]).diag()
se2_riemannian_sqdist(Gg, Gh, Re)