In [None]:
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import make_grid

In [None]:
def rotate_image(x, angles, rotate_fn=None):
    """Stack ``x`` with rotated copies of itself along dimension 0.

    Args:
        x: image batch tensor; every rotation is applied to the whole of ``x``.
        angles: iterable of angles, each passed to ``rotate_fn`` (degrees for
            torchvision's ``rotate``).
        rotate_fn: callable ``(tensor, angle) -> tensor``. Defaults to
            ``torchvision.transforms.functional.rotate``.

    Returns:
        ``torch.cat([x, rotate_fn(x, a_0), rotate_fn(x, a_1), ...], 0)``.
    """
    if rotate_fn is None:
        # Local import: the function no longer relies on a later notebook cell
        # having already imported ``rotate`` into the global namespace.
        from torchvision.transforms.functional import rotate as rotate_fn
    # Build every piece first and concatenate once, instead of re-copying the
    # growing tensor on each iteration.
    return torch.cat([x] + [rotate_fn(x, a) for a in angles], 0)

In [None]:
def equivariance_error(x0, angle, chebconv, laplacian=None, graph=None, shift=0):
    """Relative rotation-equivariance error of ``chebconv`` on one image.

    Computes ``||f(R x0) - R(roll(f(x0)))||^2 / ||R(roll(f(x0)))||^2``, where ``f`` is
    the convolution (output viewed as ``(graph.nsym, 28, 28)``), ``R`` is
    ``rotate(., angle)`` and ``roll`` shifts dim 0 of that view by ``shift``
    (presumably the orientation axis -- confirm against gechebnet).

    Args:
        x0: one image batch as yielded by ``dataloader`` (assumed (1, 1, 28, 28)).
        angle: rotation angle passed to ``rotate`` (degrees for torchvision).
        chebconv: the convolution layer under test.
        laplacian: optional extra argument forwarded to ``chebconv``. When None the
            layer is called with the input only.
        graph: graph used by ``prepare_batch`` and for ``nsym``. Defaults to the
            notebook's ``se2_graph``.
        shift: roll applied to dim 0 of the unrotated output before rotating it.
            0 reproduces the previous formula.

    Returns:
        0-dim tensor with the relative squared error.
    """
    # The previous version read ``graph.nsym`` from an undefined global ``graph``.
    if graph is None:
        graph = se2_graph

    def _forward(x):
        # Map the image onto the graph layout / device; the label is a dummy
        # because only the input is used.
        x, _ = prepare_batch((x, torch.empty(1)), graph, device)
        out = chebconv(x) if laplacian is None else chebconv(x, laplacian)
        return out.contiguous().view(graph.nsym, 28, 28)

    x1_hat = _forward(rotate(x0, angle))
    x0_hat = rotate(torch.roll(_forward(x0), shift, 0), angle)
    return (x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()

In [None]:
def imshow(img):
    """Show a (C, H, W) image tensor, e.g. the output of ``make_grid``, as one large figure.

    If the image has a single channel, or channels 0 and 1 are numerically equal,
    only channel 0 is drawn, so matplotlib applies its default colormap. Otherwise
    the tensor is drawn as an RGB image.
    """
    # permute to (H, W, C) for matplotlib; detach/cpu so CUDA tensors can be plotted.
    img = img.detach().cpu().permute(1, 2, 0)
    fig, ax = plt.subplots(figsize=(20., 20.))

    # The single-channel check avoids an IndexError on img[:, :, 1].
    if img.shape[2] == 1 or torch.allclose(img[:, :, 0], img[:, :, 1]):
        ax.imshow(img[:, :, 0])
    else:
        ax.imshow(img)

    ax.axis("off")
    plt.show()

In [None]:
# Configuration: compute device plus a fixed torch seed. The seed makes the
# randomly sampled MNIST digits and the randomly initialised layers below
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST 

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler

In [None]:
# MNIST training split, converted to tensors. Samples are drawn in random order,
# one image per batch.
dataset = MNIST("data", train=True, download=True, transform=Compose([ToTensor()]))
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=1)

# Chebyshev convolutional layer

In [None]:
from gechebnet.graph.graph import SE2GEGraph
from gechebnet.model.convolution import ChebConv

In [None]:
# SE(2) graph on the 28x28 MNIST grid with 6 orientation bins:
# theta = -pi/2, -pi/3, -pi/6, 0, pi/6, pi/3.
# ntheta was previously 1, which contradicted this list of orientations.
# With xi = eps = 1, all three sigmas equal 1.0.
xi = 1.
eps = 1.
se2_graph = SE2GEGraph(
    nx=28,
    ny=28,
    ntheta=6,
    sigmas=(xi / eps, xi, 1.0),
    knn=16,
)

In [None]:
# Chebyshev graph convolution on se2_graph, 1 input -> 1 output channel.
# NOTE(review): the trailing 6 is presumably the polynomial order / kernel size;
# confirm against gechebnet.model.convolution.ChebConv.
cheb_conv = ChebConv(se2_graph, 1, 1, 6)
cheb_conv = cheb_conv.to(device)

# Equivariance error

In [None]:
from torchvision.transforms.functional import rotate

In [None]:
from gechebnet.engine.utils import prepare_batch

In [None]:
# Relative equivariance error of cheb_conv over 10 random digits at 92 degrees.
# Compares conv(rotate(x)) with rotate(roll(conv(x))), where dim 0 of the
# (nsym, 28, 28) output (presumably the orientation bins) is rolled by 3.
# NOTE(review): a roll of 3 matches 90 degrees with 30-degree bins; 92 looks like
# a deliberate off-grid probe -- confirm.
angle_deg, orientation_shift, n_samples = 92, 3, 10
err = []
with torch.no_grad():
    # One pass over the DataLoader instead of building a new iterator per sample.
    for x0, y0 in islice(dataloader, n_samples):
        x_rot, _ = prepare_batch((rotate(x0, angle_deg), y0), se2_graph, device)
        x_rot_hat = cheb_conv(x_rot).contiguous().view(se2_graph.nsym, 28, 28)

        x0, _ = prepare_batch((x0, y0), se2_graph, device)
        x0_hat = cheb_conv(x0).contiguous().view(se2_graph.nsym, 28, 28)
        x0_hat = rotate(torch.roll(x0_hat, orientation_shift, 0), angle_deg)

        err.append(((x_rot_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
# Same relative equivariance error as the 92-degree probe, now at 95 degrees,
# again rolling dim 0 of the (nsym, 28, 28) output by 3 before rotating.
angle_deg, orientation_shift, n_samples = 95, 3, 10
err = []
with torch.no_grad():
    # One pass over the DataLoader instead of building a new iterator per sample.
    for x0, y0 in islice(dataloader, n_samples):
        x_rot, _ = prepare_batch((rotate(x0, angle_deg), y0), se2_graph, device)
        x_rot_hat = cheb_conv(x_rot).contiguous().view(se2_graph.nsym, 28, 28)

        x0, _ = prepare_batch((x0, y0), se2_graph, device)
        x0_hat = cheb_conv(x0).contiguous().view(se2_graph.nsym, 28, 28)
        x0_hat = rotate(torch.roll(x0_hat, orientation_shift, 0), angle_deg)

        err.append(((x_rot_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()).item())

np.mean(err), np.std(err)

In [None]:
# Relative equivariance error of cheb_conv at 60 degrees (an on-grid angle) over
# 10 random digits. Dim 0 of the (nsym, 28, 28) output is rolled by 4 before rotating.
# NOTE(review): with 6 bins of 30 degrees, 60 degrees is two bins; a roll of 4 equals
# -2 (mod 6), i.e. the opposite direction to +2. The 90-degree cells use 3, which is
# direction-agnostic (3 == -3 mod 6). Confirm the roll direction convention.
angle_deg, orientation_shift, n_samples = 60, 4, 10
eq_errs = []
with torch.no_grad():
    # One pass over the DataLoader instead of building a new iterator per sample.
    for x0, y0 in islice(dataloader, n_samples):
        x_rot, _ = prepare_batch((rotate(x0, angle_deg), y0), se2_graph, device)
        x_rot_hat = cheb_conv(x_rot).contiguous().view(se2_graph.nsym, 28, 28)

        x0, _ = prepare_batch((x0, y0), se2_graph, device)
        x0_hat = cheb_conv(x0).contiguous().view(se2_graph.nsym, 28, 28)
        x0_hat = rotate(torch.roll(x0_hat, orientation_shift, 0), angle_deg)

        err = (x_rot_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()
        eq_errs.append(err.item())

np.mean(eq_errs), np.std(eq_errs)

In [None]:
# Baseline: a randomly initialised 3x3 Euclidean convolution, 1 -> 1 channel, no padding.
conv2d = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)

In [None]:
# Same relative equivariance error for the plain Conv2d baseline at 95 degrees over
# 100 random digits: conv(rotate(x)) versus rotate(conv(x)).
angle_deg, n_samples = 95, 100
eq_errs = []
with torch.no_grad():
    # One pass over the DataLoader instead of building a new iterator per sample.
    for x0, y0 in islice(dataloader, n_samples):
        x_rot_hat = conv2d(rotate(x0, angle_deg))
        x0_hat = rotate(conv2d(x0), angle_deg)
        err = (x_rot_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()
        eq_errs.append(err.item())

np.mean(eq_errs), np.std(eq_errs)

# Layer outputs on rotated copies of one digit

In [None]:
# One random digit followed by 10 rotated copies (30, 60, ..., 300 degrees), stacked
# along dim 0. y0 holds only the original digit's label; it is discarded downstream.
x0, y0 = next(iter(dataloader))
x0 = rotate_image(x0, list(range(30, 301, 30)))
x, _ = prepare_batch((x0, y0), se2_graph, device)

In [None]:
# Apply the layer to the whole stack, then reorder the outputs orientation-major:
# all images for output orientation 0, then orientation 1, and so on.
# The batch size and orientation count are derived (-1 and se2_graph.nsym) rather
# than hard-coded, so the cell follows the angle list and ntheta.
# NOTE(review): assumes the flattened output is laid out as (batch, 1, nsym, 28, 28).
with torch.no_grad():
    x_hat = cheb_conv(x).cpu()
    x_hat = (
        x_hat.reshape(-1, 1, se2_graph.nsym, 28, 28)
        .permute(2, 0, 1, 3, 4)
        .reshape(-1, 1, 28, 28)
    )

In [None]:
# Input digit and its rotated copies, side by side (one grid column per image).
imshow(make_grid(x0, nrow=len(x0), normalize=True))

In [None]:
# Layer outputs with len(x0) images per grid row; since x_hat is ordered
# orientation-major, each row corresponds to one output orientation.
imshow(make_grid(x_hat, nrow=len(x0), normalize=True))