In [None]:
from torchvision.transforms.functional import rotate

In [None]:
from gechebnet.engines.utils import prepare_batch

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from gechebnet.nn.layers.pools import CubicPool

In [None]:
def rotate_image(x, angles):
    """Stack ``x`` with rotated copies of itself along dimension 0.

    Args:
        x: image tensor accepted by torchvision's ``rotate`` (in this notebook,
            a single MNIST batch drawn from ``dataloader``).
        angles: iterable of rotation angles in degrees.

    Returns:
        ``torch.cat([x, rotate(x, a_1), rotate(x, a_2), ...], 0)``: the
        unrotated input first, then one rotated copy per angle, in order.
    """
    # Collect every view first and concatenate once; calling torch.cat inside
    # the loop re-copied the growing result on every angle (quadratic copying).
    # ``rotate`` returns a new tensor, so ``x`` needs no defensive clone.
    views = [x]
    views.extend(rotate(x, a) for a in angles)
    return torch.cat(views, 0)

In [None]:
def equivariance_error(x0, angle, chebconv, laplacian=None):
    """Relative squared rotation-equivariance error of a graph layer.

    Compares "rotate the input, then apply ``chebconv``" with "apply
    ``chebconv``, then rotate the output" and returns
    ``||x1_hat - x0_hat||^2 / ||x0_hat||^2``.

    Args:
        x0: input image batch whose last two dims are (H, W); in this notebook,
            a batch of size 1 drawn from ``dataloader``.
        angle: rotation angle in degrees, forwarded to torchvision ``rotate``.
        chebconv: layer under test. It is called as ``chebconv(x)``, or as
            ``chebconv(x, laplacian)`` when ``laplacian`` is given.
        laplacian: optional extra argument forwarded to ``chebconv``.

    Returns:
        0-dim tensor holding the relative error. Wrap the call in
        ``torch.no_grad()`` when no gradient is needed.

    Relies on the notebook globals ``prepare_batch``, ``se2_graph`` and ``device``.
    """
    height, width = x0.shape[-2:]

    def _apply(img):
        # Prepare the image with a dummy label, run the layer, and view the
        # result as a stack of (H, W) maps. The -1 replaces the previously
        # referenced, undefined ``graph.nsym``.
        batch, _ = prepare_batch((img, torch.empty(1)), se2_graph, device)
        out = chebconv(batch) if laplacian is None else chebconv(batch, laplacian)
        return out.contiguous().view(-1, height, width)

    x1_hat = _apply(rotate(x0, angle))
    x0_hat = rotate(_apply(x0), angle)

    return (x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()

In [None]:
def imshow(img):
    """Show a (C, H, W) image tensor, e.g. the output of ``make_grid``, on a 20x20 figure.

    Single-channel images, and images whose first two channels are equal
    (grayscale replicated across channels), are drawn from channel 0 with the
    default colormap. Anything else is drawn as a colour image.
    """
    # Move to host memory and drop autograd so matplotlib can convert it.
    img = img.detach().cpu().permute(1, 2, 0)

    fig, ax = plt.subplots(figsize=(20., 20.))

    # Check the channel count first: indexing channel 1 of a 1-channel image
    # raised IndexError in the previous version.
    is_gray = img.shape[-1] == 1 or torch.allclose(img[:, :, 0], img[:, :, 1])
    ax.imshow(img[:, :, 0] if is_gray else img)

    ax.axis("off")
    plt.show()

In [None]:
# Seed torch's global RNG so the random MNIST draws (RandomSampler without an
# explicit generator) and any random layer initialisation below are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when available; cells below pass this to prepare_batch and .to().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST 

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler

In [None]:
# MNIST training split, downloaded into ./data on first run and converted to
# tensors by ToTensor.
dataset = MNIST("data", train=True, download=True, transform=Compose([ToTensor()]))

# One digit per batch, visited in random order by the sampler.
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=1)

In [None]:
# Draw one random training digit (batch_size=1) to sanity-check the data pipeline.
x, y = next(iter(dataloader))

In [None]:
# x is presumably (1, 1, 28, 28); show the first sample's first channel.
plt.imshow(x[0,0])

# Chebyshev convolutional layer

In [None]:
from gechebnet.graphs.graphs import SE2GEGraph
from gechebnet.nn.layers.convs import ChebConv
from gechebnet.liegroups.se2 import se2_uniform_sampling

In [None]:
xi = 3e-3
eps = 0.1

se2_sampling = se2_uniform_sampling(28, 28, 6)
se2_graph = SE2GEGraph(
    se2_sampling,
    K=16,
    sigmas=(1., eps, xi),
    path_to_graph="saved_graphs"
)

In [None]:
cheb_conv = ChebConv(1, 1, 4, se2_graph).to(device)

# Equivariance of the pooling layer: visual check on rotated copies

In [None]:
x0, y0 = next(iter(dataloader))
x0 = rotate_image(x0, [30*i for i in range(1, 11)])
x, _ = prepare_batch((x0, y0), se2_graph, device)

In [None]:
with torch.no_grad():
    x_hat = pool(x).cpu()
    x_hat = x_hat.reshape(11, 1, 6, 14, 14).permute(2, 0, 1, 3, 4).reshape(-1, 1, 14, 14)

In [None]:
imshow(make_grid(x0, nrow=11, normalize=False))

In [None]:
imshow(make_grid(x_hat, nrow=11, normalize=False))

# ChebConv layer

In [None]:
from gechebnet.nn.layers.convs import ChebConv

In [None]:
# Single-channel ChebConv on the SE(2) graph; kept on the CPU (inputs below use device=None).
conv = ChebConv(1, 1, 4, se2_graph)

In [None]:
x0, y0 = next(iter(dataloader))

In [None]:
# Prepare the digit for the graph layers and show the original image.
x0_, _ = prepare_batch((x0, y0), se2_graph, None)
plt.imshow(x0[0,0])
plt.axis("off")

In [None]:
# Convolve and show one of the 6 output maps (viewed as (6, 28, 28)).
# NOTE(review): roll(4, 0) here vs roll(3, 0) in the next cell — confirm the intended shift.
with torch.no_grad():
    x0_hat = conv(x0_)
plt.imshow(x0_hat.view(6, 28, 28).roll(4, 0)[0])
plt.axis("off")

In [None]:
# Conv-then-rotate: spatially rotate the (shifted) output by 90 degrees.
plt.imshow(rotate(x0_hat.view(6, 28, 28).roll(3, 0), 90)[0])
plt.axis("off")

In [None]:
# Rotate the input digit by 90 degrees and show the first of its 6 prepared maps.
x1_, _ = prepare_batch((rotate(x0, 90), y0), se2_graph, None)
plt.imshow(x1_.view(6, 28, 28)[0])
plt.axis("off")

In [None]:
# Rotate-then-conv: compare visually with the conv-then-rotate plot above.
with torch.no_grad():
    x1_hat = conv(x1_)
plt.imshow(x1_hat.view(6, 28, 28)[0])
plt.axis("off")

In [None]:
err = []
angle = 5
n_trials = 500

# Draw successive random digits from one iterator instead of rebuilding the
# DataLoader iterator on every trial.
batches = iter(dataloader)

for _ in range(n_trials):
    # Re-instantiated every trial so the statistic is not tied to one set of weights.
    conv = ChebConv(1, 1, 4, se2_graph)

    with torch.no_grad():
        x0, y0 = next(batches)

        # rotate then convolve
        x1_, _ = prepare_batch((rotate(x0, angle), y0), se2_graph, None)
        # The conv output holds one 28 x 28 map per graph layer (6 here, see the
        # cells above), so view(1, 28, 28) raised a RuntimeError; let -1 infer it.
        x1_hat = conv(x1_).contiguous().view(-1, 28, 28)

        # convolve then rotate
        x0_, _ = prepare_batch((x0, y0), se2_graph, None)
        x0_hat = conv(x0_).contiguous().view(-1, 28, 28)
        x0_hat = rotate(x0_hat, angle)

        # relative squared equivariance error for this digit
        err.append(((x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

# Pooling layer

In [None]:
from gechebnet.nn.layers.pools import CubicPool

In [None]:
pool = CubicPool((1, 2), (6, 28, 28))

In [None]:
x0, y0 = next(iter(dataloader))

In [None]:
x0_, _ = prepare_batch((x0, y0), se2_graph, None)
plt.imshow(x0[0,0])
plt.axis("off")

In [None]:
x0_hat = pool(x0_)
plt.imshow(x0_hat.view(1, 1, 1)[0])
plt.axis("off")

In [None]:
plt.imshow(rotate(x0_hat.view(1, 1, 1), 90)[0])
plt.axis("off")

In [None]:
x1_, _ = prepare_batch((rotate(x0, 90), y0), se2_graph, None)
plt.imshow(x1_.view(6, 28, 28)[0])
plt.axis("off")

In [None]:
x1_hat = pool(x1_)
plt.imshow(x1_hat.view(1, 1, 1)[0])
plt.axis("off")

In [None]:
pool = CubicPool((1, 2), (6, 28, 28)).to(device)

err = []
angle = 45
n_trials = 100

# Draw successive random digits from one iterator instead of rebuilding the
# DataLoader iterator on every trial.
batches = iter(dataloader)

for _ in range(n_trials):
    with torch.no_grad():
        x0, y0 = next(batches)

        # transform then pool
        x1_, _ = prepare_batch((rotate(x0, angle), y0), se2_graph, device)
        x1_hat = pool(x1_).contiguous().view(6, 14, 14)

        # pool then transform
        x0_, _ = prepare_batch((x0, y0), se2_graph, device)
        x0_hat = pool(x0_).contiguous().view(6, 14, 14)
        x0_hat = rotate(x0_hat, angle)

        # relative squared equivariance error for this digit
        err.append(((x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

# Unpooling layer

In [None]:
from gechebnet.nn.layers.unpools import CubicUnpool

In [None]:
# The outputs below are viewed as (6, 56, 56): this unpool doubles the 28 x 28 grid.
unpool = CubicUnpool((1, 2), (6, 28, 28))

In [None]:
x0, y0 = next(iter(dataloader))

In [None]:
# Prepare the digit for the graph layers and show the original image.
x0_, _ = prepare_batch((x0, y0), se2_graph, None)
plt.imshow(x0[0,0])
plt.axis("off")

In [None]:
# NOTE(review): no torch.no_grad() here, unlike the conv cells; plotting fails if the output requires grad.
x0_hat = unpool(x0_)
plt.imshow(x0_hat.view(6, 56, 56)[0])
plt.axis("off")

In [None]:
# Unpool-then-rotate: spatially rotate the unpooled maps by 90 degrees.
plt.imshow(rotate(x0_hat.view(6, 56, 56), 90)[0])
plt.axis("off")

In [None]:
x1_, _ = prepare_batch((rotate(x0, 90), y0), se2_graph, None)
plt.imshow(x1_.view(6, 28, 28)[0])
plt.axis("off")

In [None]:
# Rotate-then-unpool: compare visually with the unpool-then-rotate plot above.
x1_hat = unpool(x1_)
plt.imshow(x1_hat.view(6, 56, 56)[0])
plt.axis("off")

In [None]:
# This configuration yields 12 maps of 28 x 28 (hence view(12, 28, 28) below).
unpool = CubicUnpool((2, 1), (6, 28, 28)).to(device)

err = []
angle = 30
n_trials = 1000

# Draw successive random digits from one iterator instead of rebuilding the
# DataLoader iterator on every trial.
batches = iter(dataloader)

for _ in range(n_trials):
    with torch.no_grad():
        x0, y0 = next(batches)

        # transform then unpool
        x1_, _ = prepare_batch((rotate(x0, angle), y0), se2_graph, device)
        x1_hat = unpool(x1_).contiguous().view(12, 28, 28)

        # unpool then transform
        x0_, _ = prepare_batch((x0, y0), se2_graph, device)
        x0_hat = unpool(x0_).contiguous().view(12, 28, 28)
        x0_hat = rotate(x0_hat, angle)

        # relative squared equivariance error for this digit
        err.append(((x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
# Distinct values of the graph's per-node `node_theta` attribute (presumably the sampled orientations; confirm units).
se2_graph.node_theta.unique()