# Neural network layers

In [None]:
import matplotlib.cm as cm
import numpy as np
import torch
from torchvision.transforms.functional import affine
from torchvision.transforms.functional import rotate


In [None]:
from gechebnet.graphs.graphs import SE2GEGraph, R2GEGraph

In [None]:
def get_signal(B, C, H, W, L=1, cheb=False):
    """Draw a uniform random test signal.

    Args:
        B: batch size.
        C: number of channels.
        H: grid height.
        W: grid width.
        L: number of times the flattened ``H * W`` plane is repeated along
            the last axis; only used when ``cheb`` is true.
        cheb: if true, return the flattened ``(B, C, L * H * W)`` layout;
            otherwise return a ``(B, C, H, W)`` image batch.

    Returns:
        A tensor filled by ``torch.rand``. In the ``cheb`` layout the same
        ``H * W`` random values appear ``L`` times back to back for every
        ``(batch, channel)`` pair.
    """
    if not cheb:
        return torch.rand(B, C, H, W)

    # One random plane per (batch, channel); expand repeats it L times
    # without drawing new values, reshape then flattens (L, H * W).
    plane = torch.rand(B, C, 1, H * W)
    return plane.expand(-1, -1, L, -1).reshape(B, C, -1)


# The cells further down call this helper as ``get_x``; bind both names so
# those calls resolve instead of raising NameError.
get_x = get_signal

In [None]:
import matplotlib.pyplot as plt

In [None]:
# 3-d scatter of a random signal over the nodes of `se2_graph`.
# NOTE(review): `se2_graph` is only created in a later cell (under
# "Conv layers"), so that cell has to be run before this one.
fig = plt.figure(figsize=(10, 10))

X, Y, Z = se2_graph.cartesian_pos()
# cheb=True layout: shape (16, 3, 6 * 28 * 28), the same 28 * 28 random
# values repeated 6 times per (batch, channel) pair.
signal = get_signal(16, 3, 28, 28, 6, True)

ax = fig.add_subplot(1, 1, 1, projection="3d")
# First sample of the batch, transposed so each node has a row of 3 channel values.
# NOTE(review): a 3-column `c` is presumably read as RGB by matplotlib, in
# which case `cmap` has no effect -- confirm.
ax.scatter(X, Y, Z, c=signal[0].permute(1,0), cmap=cm.PiYG)
ax.axis("off")
fig.tight_layout()

## Conv layers

In [None]:
from gechebnet.nn.layers.convs import ChebConv

In [None]:
xi = 0.0026
eps = 0.1

se2_graph = SE2GEGraph(
    [28, 28, 6],
    K=8,
    sigmas=(1., eps, xi),
    path_to_graph="saved_graphs"
)

In [None]:
xi = 1.
eps = 1.

r2_graph = R2GEGraph(
    [28, 28, 1],
    K=8,
    sigmas=(1., eps, xi),
    path_to_graph="saved_graphs"
)

In [None]:
x = get_x(16, 3, 28, 28, 6, True)
conv = ChebConv(3, 32, 4, se2_graph)
conv(x).shape

### Equivariance

In [None]:
err = []
tx, ty, theta = 0, 0, 60

for _ in range(100):
    conv = ChebConv(3, 16, 4, se2_graph)
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then conv
        y1 = conv(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # conv then transform 
        y2 = affine(conv(x).view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 16, 6, 28, 28).roll(2, 2).view(16, 16, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
err = []
tx, ty, theta = 0, 0, 60

for _ in range(100):
    conv = ChebConv(3, 16, 4, r2_graph)
    with torch.no_grad():
        #x0, y0 = next(iter(dataloader))
        x = get_x(16, 3, 28, 28, 1, True)
        # transform then conv
        y1 = conv(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # conv then transform 
        y2 = affine(conv(x).view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 16, 1, 28, 28).view(16, 16, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
err = []
tx, ty, theta = 3, 3, 0

for _ in range(100):
    conv = torch.nn.Conv2d(3, 16, 3)
    with torch.no_grad():
        x = get_x(16, 3, 28, 28)
        # transform then conv
        y1 = conv(affine(x, theta, (tx, ty), 1., (0., 0.)))
        # conv then transform 
        y2 = affine(conv(x), theta, (tx, ty), 1., (0., 0.))
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

## Filters

In [None]:
from torchvision.utils import make_grid

In [None]:
#conv = ChebConv(1, 1, 4, se2_graph)
conv = model.conv

In [None]:
conv.weight

In [None]:
def show(img):
    """Display a channel-first image tensor with the axes hidden.

    Args:
        img: 3-d tensor; its axes are reordered with ``permute(1, 2, 0)``
            before it is passed to ``plt.imshow`` with the viridis colormap.
    """
    channels_last = img.permute(1, 2, 0)
    plt.imshow(channels_last, cmap=cm.viridis)
    plt.axis("off")

In [None]:
# Impulse responses of `conv`: sample `layer` of the batch is zero everywhere
# except at flat index 406 + layer * 784, i.e. offset 406 (= 14 * 28 + 14)
# inside the `layer`-th block of 784 (= 28 * 28) entries.
# Named `impulses` rather than `input` so the builtin `input` stays usable
# for the rest of the session.
impulses = torch.zeros(6, 1, 6 * 784)
for layer in range(6):
    impulses[layer, :, 406 + layer * 784] = 1.

with torch.no_grad():
    # Reshape to (6, 16, 6, 28, 28) and keep only index 14 of the second axis.
    output = conv(impulses).reshape(6, 16, 6, 28, 28)[:, 14, :, :, :]
    # Flatten the two leading axes into single-channel 28 x 28 images.
    output = output.reshape(-1, 1, 28, 28)

# Min-max rescale to [0, 1] for display.
output = (output - output.min()) / (output.max() - output.min())
output.shape

In [None]:
show(make_grid(output, padding=1, nrow=6))

In [None]:
import torch

from gechebnet.graphs.graphs import SE2GEGraph, RandomSubGraph
from gechebnet.nn.models.chebnets import WideResSE2GEChebNet

device = torch.device("cpu")

graph = SE2GEGraph(
    [28, 28, 6],
    K=8,
    sigmas=(1.0, 0.1, 2.048 / (28 ** 2)),
    path_to_graph="saved_graphs",
)
# we use random sub graphs to evaluate the effect of edges and nodes' sampling
sub_graph = RandomSubGraph(graph)

# Loads group equivariant Chebnet
model = WideResSE2GEChebNet(
    in_channels=1,
    out_channels=10,
    kernel_size=4,
    graph_lvl0=sub_graph,
    depth=8,
    widen_factor=2,
).to(device)

In [None]:
model.load_state_dict(torch.load("models/model_19.pt"))

In [None]:
conv = model.conv

In [None]:
plot_filters(graph, conv)

In [None]:
def plot_filters(graph, filter, node=406):
    """Plot the response of ``filter`` to a unit impulse placed in each layer.

    With ``L = graph.size[-1]`` layers of ``M = prod(graph.size[:2])`` nodes,
    a batch of ``L`` one-channel signals is built; signal ``l`` is zero except
    for a single 1 at node ``node`` of layer ``l`` (flat index
    ``node + l * M``). The batch is passed through ``filter`` without
    gradients and channel 0 of the result is drawn on an ``L x L`` grid of
    scatter plots: row ``i`` shows the nodes of layer ``i``, column ``j`` the
    response to the impulse placed in layer ``j``.

    Args:
        graph: object exposing ``size`` (indexable, last entry = number of
            layers) and ``cartesian_pos()`` returning three coordinate arrays.
        filter: callable applied to the ``(L, 1, L * M)`` impulse batch.
        node: index of the excited node inside each layer. The default 406
            equals ``14 * 28 + 14``.
            NOTE(review): presumably the centre of a 28 x 28 grid -- confirm.

    Raises:
        ValueError: if ``node`` does not lie inside a single layer.
    """
    L = graph.size[-1]
    M = int(np.prod(graph.size[:2]))

    if not 0 <= node < M:
        raise ValueError(f"node must be in [0, {M}), got {node}")

    # The offset has to advance by a whole layer (M nodes) per sample;
    # multiplying the node index itself would land on a different node in
    # every layer after the first.
    impulses = torch.zeros(L, 1, L * M)
    for layer in range(L):
        impulses[layer, :, node + layer * M] = 1.

    with torch.no_grad():
        output = filter(impulses)

    print(output.shape)

    fig = plt.figure(figsize=(4*L, 4*L))

    # Only the first two coordinate arrays are needed for the 2-d scatter plots.
    X, Y, _ = graph.cartesian_pos()

    for i in range(L):
        layer_nodes = slice(i * M, (i + 1) * M)
        for j in range(L):
            ax = fig.add_subplot(L, L, i * L + j + 1)
            ax.scatter(X[layer_nodes], Y[layer_nodes], c=output[j, 0, layer_nodes], cmap=cm.viridis)
            ax.axis("off")

    fig.tight_layout()

In [None]:
conv = ChebConv(1, 1, 4, se2_graph)

In [None]:
plot_filters(se2_graph, conv)

# Pooling layers

In [None]:
from gechebnet.nn.layers.pools import SE2SpatialPool, SE2OrientationPool, GlobalPool

In [None]:
x = get_x(16, 3, 28, 28, 6, True)

In [None]:
err = []
tx, ty, theta = 2, 2, 90

for _ in range(100):
    pool = SE2SpatialPool(2, (28,28,6), "max")
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then pool
        y1 = pool(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # pool then transform
        y2 = affine(pool(x).view(-1, 14, 14), theta, (tx//2, ty//2), 1., (0., 0.)).view(16, 3, 6, 14, 14).roll(3, 2).view(16, 3, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
err = []
tx, ty, theta = 0, 2, 0

for _ in range(100):
    pool = SE2SpatialPool(2, (28,28,6), "rand")
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then pool
        y1 = pool(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # pool then transform
        y2 = affine(pool(x).view(-1, 14, 14), theta, (tx//2, ty//2), 1., (0., 0.)).view(16, 3, 6, 14, 14).roll(0, 2).view(16, 3, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
err = []
tx, ty, theta = 0, 0, 90

for _ in range(100):
    pool = SE2SpatialPool(2, (28,28,6), "avg")
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then pool
        y1 = pool(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # pool then transform
        y2 = affine(pool(x).view(-1, 14, 14), theta, (tx//2, ty//2), 1., (0., 0.)).view(16, 3, 6, 14, 14).roll(3, 2).view(16, 3, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
err = []
tx, ty, theta = 0, 0, 90

for _ in range(100):
    pool = SE2OrientationPool(6, (28,28,6), "max")
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then pool
        y1 = pool(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # pool then transform
        y2 = affine(pool(x).view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
err = []
tx, ty, theta = 0, 0, 90

for _ in range(100):
    pool = SE2OrientationPool(6, (28,28,6), "rand")
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then pool
        y1 = pool(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # pool then transform
        y2 = affine(pool(x).view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
x = torch.rand(1, 1, 1, 2, 2).expand(-1, -1, 2, -1, -1).reshape(1, 1, -1)
x

In [None]:
pool = SE2OrientationPool(2, (2,2,2), "rand")

In [None]:
pool(x)

In [None]:
err = []
tx, ty, theta = 0, 0, 90

for _ in range(100):
    pool = SE2OrientationPool(6, (28,28,6), "avg")
    with torch.no_grad():
        x = get_x(16, 3, 28, 28, 6, True)
        # transform then pool
        y1 = pool(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1))
        # pool then transform
        y2 = affine(pool(x).view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1)
        err.append(((y1 - y2).pow(2).sum() / y2.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
pool = SE2SpatialPool(2, (28,28,6), "max")

In [None]:
pool(x)

In [None]:
from gechebnet.engines.utils import prepare_batch

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

import numpy as np

In [None]:
def rotate_image(x, angles):
    """Stack ``x`` with one rotated copy of itself per entry of ``angles``.

    Args:
        x: tensor to rotate; the pieces are joined along dimension 0.
        angles: iterable of angles, each passed as the second argument of
            ``rotate``.

    Returns:
        ``x`` followed by ``rotate(x, a)`` for every ``a`` in ``angles``,
        concatenated along dimension 0.
    """
    # Collect the pieces first and concatenate once: calling torch.cat inside
    # the loop copies the ever-growing result again on every iteration.
    pieces = [x]
    pieces.extend(rotate(x, a) for a in angles)
    return torch.cat(pieces, 0)

In [None]:
def equivariance_error(x0, angle, chebconv, laplacian):
    """Relative squared error between "rotate then convolve" and "convolve then rotate".

    Args:
        x0: input passed to ``rotate`` and ``prepare_batch``.
        angle: rotation angle forwarded to ``rotate``.
        chebconv: callable invoked as ``chebconv(signal, laplacian)``.
        laplacian: second argument forwarded to ``chebconv``.

    Returns:
        ``sum((a - b) ** 2) / sum(b ** 2)`` where ``a`` is the convolution of
        the rotated input and ``b`` is the rotated convolution of the input.

    Reads the module-level names ``se2_graph``, ``graph`` and ``device``.
    NOTE(review): the batch is prepared with ``se2_graph`` but the output is
    reshaped with ``graph.nsym`` -- confirm both refer to the same graph.
    """
    def _convolve(image):
        # Lay the image out on the graph, convolve, and view the result as
        # a stack of 28 x 28 maps (one per ``graph.nsym``).
        signal, _ = prepare_batch((image, torch.empty(1)), se2_graph, device)
        response = chebconv(signal, laplacian)
        return response.contiguous().view(graph.nsym, 28, 28)

    rotated_then_conv = _convolve(rotate(x0, angle))
    conv_then_rotated = rotate(_convolve(x0), angle)

    residual = rotated_then_conv - conv_then_rotated
    return residual.pow(2).sum() / conv_then_rotated.pow(2).sum()

In [None]:
def imshow(img):
    """Show a channel-first image tensor on a 20 x 20 inch figure, axes hidden.

    The tensor is reordered with ``permute(1, 2, 0)``. If it has a single
    channel, or if channels 0 and 1 are numerically close
    (``torch.allclose``), only channel 0 is drawn as a 2-d map; otherwise the
    permuted tensor is handed to ``plt.imshow`` unchanged.

    Args:
        img: 3-d tensor with the channel axis first.
    """
    plt.figure(figsize=(20., 20.))

    img = img.permute(1, 2, 0)

    # A single-channel image has no channel 1 to compare against; testing the
    # channel count first avoids an IndexError on that input.
    if img.shape[-1] == 1 or torch.allclose(img[:, :, 0], img[:, :, 1]):
        plt.imshow(img[:, :, 0])
    else:
        plt.imshow(img)

    plt.axis("off")
    plt.show()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST 

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler

In [None]:
dataset = MNIST(
            "data",
            train=True,
            download=True,
            transform=Compose([ToTensor()]),
        )
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler)

In [None]:
x, y = next(iter(dataloader))

In [None]:
plt.imshow(x[0,0])

In [None]:
x_ = affine(x, 45, (10, 10), 1., (0., 0.))

In [None]:
plt.imshow(x_[0,0])

In [None]:
conv = torch.nn.Conv2d(3, 16, 3)

In [None]:
x = get_x()
x.min(), x.max()

In [None]:
plt.imshow(x[0])

In [None]:
with torch.no_grad():
    y = conv(affine(x.view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(16, 3, -1)).view(-1, 6, 28, 28)

In [None]:
plt.imshow(y[0,0])

In [None]:
with torch.no_grad():
    y = affine(conv(x).view(-1, 28, 28), theta, (tx, ty), 1., (0., 0.)).view(-1, 6, 28, 28)

In [None]:
plt.imshow(y[0,3])

# Equivariance layers

In [None]:
def prepare_batch(batch, graph=None, device=None):
    """Flatten an image batch into the ``(B, C, L * N)`` layout.

    The trailing dimensions of ``x`` are flattened into ``N`` values per
    ``(batch, channel)`` pair and that plane is repeated ``L`` times along the
    last axis, the same layout ``get_signal(..., cheb=True)`` produces.

    Args:
        batch: ``(x, y)`` pair; ``x`` needs at least two dimensions.
        graph: optional. When given, ``L = graph.size[-1]``; when omitted,
            ``L = 1`` and the call behaves exactly like the original
            one-argument form.
            NOTE(review): assumes ``graph.size[-1]`` is the number of layers,
            as ``plot_filters`` in this notebook does -- confirm.
        device: optional. When given, ``x`` (and ``y`` if it is a tensor) is
            moved to it.

    Returns:
        ``(x, y)`` with ``x`` reshaped as described and ``y`` passed through.
    """
    x, y = batch

    B, C, *_ = x.shape
    L = 1 if graph is None else int(graph.size[-1])

    # (B, C, 1, N) -> repeat the plane L times -> (B, C, L * N)
    x = x.reshape(B, C, 1, -1).expand(-1, -1, L, -1).reshape(B, C, -1)

    if device is not None:
        x = x.to(device)
        if torch.is_tensor(y):
            y = y.to(device)
    return x, y

In [None]:
x0, y0 = next(iter(dataloader))
x0 = rotate_image(x0, [30*i for i in range(1, 11)])
x, _ = prepare_batch((x0, y0), se2_graph, device)

In [None]:
with torch.no_grad():
    x_hat = pool(x).cpu()
    x_hat = x_hat.reshape(11, 1, 6, 14, 14).permute(2, 0, 1, 3, 4).reshape(-1, 1, 14, 14)

In [None]:
imshow(make_grid(x0, nrow=11, normalize=False))

In [None]:
imshow(make_grid(x_hat, nrow=11, normalize=False))

In [None]:
from torch.nn import Conv2d

In [None]:
conv = Conv2d(1, 1, 3, padding=1)

# ChebConv layer

In [None]:
from gechebnet.nn.layers.convs import ChebConv

In [None]:
conv = ChebConv(1, 1, 4, se2_graph)
conv = Conv2d(1, 1, 3, padding=1)

In [None]:
x, y, theta = 3, 0, 0

In [None]:
# original image
x0, y0 = next(iter(dataloader))
plt.imshow(x0[0,0])
plt.axis("off")

In [None]:
# transforms ...
x1_, _ = prepare_batch((affine(x0, theta, (x, y), 1., (0., 0.)), y0))
plt.imshow(x1_.contiguous().view(1, 28, 28)[0])
plt.axis("off")

In [None]:
# ... convolves
with torch.no_grad():
    x1_hat = conv(x1_)
plt.imshow(x1_hat.contiguous().view(1, 28, 28)[0])
plt.axis("off")

In [None]:
# convolves ...
x0_, _ = prepare_batch((x0, y0))
with torch.no_grad():
    x0_hat = conv(x0_).contiguous().view(1, 28, 28)
    
plt.imshow(x0_hat[0])
plt.axis("off")

In [None]:
# ... transforms
plt.imshow(affine(x0_hat, theta, (x, y), 1., (0., 0.)).roll(2, 0)[0])
plt.axis("off")

In [None]:
def prepare_batch(batch):
    """Identity stand-in: hand ``batch`` back exactly as it was received."""
    passthrough = batch
    return passthrough

In [None]:
err = []
x, y, theta = 0, 0, 60

for _ in range(100):
    conv = ChebConv(1, 1, 4, se2_graph)
    #conv = Conv2d(1, 1, 3, padding=1)
    with torch.no_grad():
        x0, y0 = next(iter(dataloader))

        # transform then conv
        x1_, _ = prepare_batch((affine(x0, theta, (x, y), 1., (0., 0.)), y0))
        x1_hat = conv(x1_).contiguous().view(1, 28, 28)

        # conv then transform 
        x0_, _ = prepare_batch((x0, y0))
        x0_hat = conv(x0_).contiguous().view(1, 28, 28)
        x0_hat = affine(x0_hat, theta, (x, y), 1., (0., 0.)).roll(0, 0)
        
        err.append(((x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
se2_graph.node_theta

In [None]:
x = torch.rand(28, 28)
y = torch.rand(28, 28)
(x - y).pow(2).sum() / y.pow(2).sum()

# Pooling layer

In [None]:
from gechebnet.nn.layers.pools import CubicPool

In [None]:
pool = CubicPool((1, 2), (6, 28, 28))

In [None]:
x0, y0 = next(iter(dataloader))

In [None]:
x0_, _ = prepare_batch((x0, y0), se2_graph, None)
plt.imshow(x0[0,0])
plt.axis("off")

In [None]:
x0_hat = pool(x0_)
plt.imshow(x0_hat.view(1, 1, 1)[0])
plt.axis("off")

In [None]:
plt.imshow(rotate(x0_hat.view(1, 1, 1), 90)[0])
plt.axis("off")

In [None]:
x1_, _ = prepare_batch((rotate(x0, 90), y0), se2_graph, None)
plt.imshow(x1_.view(6, 28, 28)[0])
plt.axis("off")

In [None]:
x1_hat = pool(x1_)
plt.imshow(x1_hat.view(1, 1, 1)[0])
plt.axis("off")

In [None]:
pool = CubicPool((1, 2), (6, 28, 28)).to(device)

err = []
angle = 45

for _ in range(100):
    with torch.no_grad():
        x0, y0 = next(iter(dataloader))

        # transform then pool
        x1_, _ = prepare_batch((rotate(x0, angle), y0), se2_graph, device)
        x1_hat = pool(x1_).contiguous().view(6, 14, 14)

        # pool then transform
        x0_, _ = prepare_batch((x0, y0), se2_graph, device)   
        x0_hat = pool(x0_).contiguous().view(6, 14, 14)
        x0_hat = rotate(x0_hat, angle)
        
        err.append(((x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

# Unpooling layer

In [None]:
from gechebnet.nn.layers.unpools import CubicUnpool

In [None]:
unpool = CubicUnpool((1, 2), (6, 28, 28))

In [None]:
x0, y0 = next(iter(dataloader))

In [None]:
x0_, _ = prepare_batch((x0, y0), se2_graph, None)
plt.imshow(x0[0,0])
plt.axis("off")

In [None]:
x0_hat = unpool(x0_)
plt.imshow(x0_hat.view(6, 56, 56)[0])
plt.axis("off")

In [None]:
plt.imshow(rotate(x0_hat.view(6, 56, 56), 90)[0])
plt.axis("off")

In [None]:
x1_, _ = prepare_batch((rotate(x0, 90), y0), se2_graph, None)
plt.imshow(x1_.view(6, 28, 28)[0])
plt.axis("off")

In [None]:
x1_hat = unpool(x1_)
plt.imshow(x1_hat.view(6, 56, 56)[0])
plt.axis("off")

In [None]:
unpool = CubicUnpool((2, 1), (6, 28, 28)).to(device)

err = []
angle = 30

for _ in range(1000):
    with torch.no_grad():
        x0, y0 = next(iter(dataloader))

        # transform then unpool
        x1_, _ = prepare_batch((rotate(x0, angle), y0), se2_graph, device)
        x1_hat = unpool(x1_).contiguous().view(12, 28, 28)

        # unpool then transform
        x0_, _ = prepare_batch((x0, y0), se2_graph, device)   
        x0_hat = unpool(x0_).contiguous().view(12, 28, 28)
        x0_hat = rotate(x0_hat, angle)
        
        err.append(((x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
se2_graph.node_theta.unique()