In [None]:
import torch
import numpy as np

from torchvision.transforms.functional import rotate

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
from gechebnet.data.dataloader import get_test_equivariance_dataloaders
from gechebnet.graph.graph import SE2GEGraph
from gechebnet.graph.plot import visualize_graph, visualize_neighborhood
from gechebnet.engine.utils import prepare_batch
from gechebnet.model.convolution import ChebConv

In [None]:
def rotate_image(x, angles):
    """Stack a batch together with rotated copies of itself along dimension 0.

    Args:
        x: Tensor accepted by torchvision's ``rotate``. Rotated copies are
            concatenated along dim 0, so every other dimension is preserved.
        angles: Iterable of rotation angles, each forwarded as-is to ``rotate``.

    Returns:
        A new tensor holding ``x`` first, followed by one rotated copy of ``x``
        per entry of ``angles``, in the order given. ``x`` itself is not modified.
    """
    # Gather all pieces and concatenate once: calling torch.cat inside the loop
    # re-copies the ever-growing result on every iteration.
    pieces = [x]
    pieces.extend(rotate(x, a) for a in angles)
    return torch.cat(pieces, 0)

In [None]:
def equivariance_error(x0, angle, chebconv, laplacian):
    """Relative squared error between "rotate, then convolve" and "convolve, then rotate".

    Reads the notebook-level globals ``se2_graph`` and ``device``.

    Args:
        x0: Input image tensor; passed to ``rotate`` and to ``prepare_batch``.
            Assumption (to confirm): a single one-channel 28x28 image, since the
            convolution output is reshaped to ``(se2_graph.nsym, 28, 28)``.
        angle: Rotation angle forwarded to torchvision's ``rotate``.
        chebconv: Callable invoked as ``chebconv(x, laplacian)``.
        laplacian: Forwarded unchanged to ``chebconv``.

    Returns:
        Scalar tensor ``sum((x1_hat - x0_hat)**2) / sum(x0_hat**2)`` where
        ``x1_hat`` is the response to the rotated input and ``x0_hat`` is the
        rotated response to the original input.
    """
    # The original body read `graph.nsym`, but only `se2_graph` exists in the
    # notebook; use the same graph that is handed to `prepare_batch`.
    out_shape = (se2_graph.nsym, 28, 28)

    # Path 1: rotate the input, then convolve. `prepare_batch` is called with an
    # (input, target) pair; the target is a throw-away placeholder here.
    x1, _ = prepare_batch((rotate(x0, angle), torch.empty(1)), se2_graph, device)
    x1_hat = chebconv(x1, laplacian).contiguous().view(*out_shape)

    # Path 2: convolve the original input, then rotate the response.
    x0, _ = prepare_batch((x0, torch.empty(1)), se2_graph, device)
    x0_hat = chebconv(x0, laplacian).contiguous().view(*out_shape)
    # NOTE(review): the exploratory cells in this notebook also `torch.roll` the
    # first axis before rotating; this function does not — confirm which is intended.
    x0_hat = rotate(x0_hat, angle)

    return (x1_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()

In [None]:
def plot_equivariance(input):
    """Show every 2-D slice of a 5-D tensor on a single image grid.

    Args:
        input: Tensor of shape ``(B, C, L, H, W)``. The reshape to ``B * L``
            images only succeeds when ``C == 1``.

    The grid has ``L`` rows and ``B`` columns; images are laid out with the
    third axis (``L``) varying slowest. The figure is displayed with
    ``plt.show()`` and nothing is returned.
    """
    n_batch, _, n_slices, height, width = input.shape

    # Move the third axis to the front so the flattened images are slice-major.
    tiles = input.permute(2, 0, 1, 3, 4).contiguous().view(n_batch * n_slices, height, width)

    figure = plt.figure(figsize=(20., 20.))
    axes_grid = ImageGrid(figure, 111, nrows_ncols=(n_slices, n_batch), axes_pad=0.1, share_all=True)

    # Ticks are cleared on the first axes only; with share_all=True this
    # presumably propagates to the whole grid — confirm.
    first_axes = axes_grid[0]
    first_axes.get_yaxis().set_ticks([])
    first_axes.get_xaxis().set_ticks([])

    # Iterating the grid yields its Axes in order; pair each with one image.
    for axes, tile in zip(axes_grid, tiles):
        axes.imshow(tile)

    plt.show()


In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Scalars from which the `sigmas` triple below is derived.
xi = 0.01
eps = 0.1

# Graph over a 28 x 28 grid with 6 orientation samples; `knn=16` is presumably
# the number of neighbours per node — confirm against SE2GEGraph.
se2_graph = SE2GEGraph(
    nx=28,
    ny=28,
    # theta = -pi, -2pi/3, -pi/3, 0, pi/3, 2pi/3
    # NOTE(review): the values above imply 60-degree steps, yet later cells roll
    # the orientation axis by 3 for a 90-degree rotation and by 2 for 60 degrees,
    # which implies 30-degree steps — confirm the actual sampling.
    ntheta=6,
    sigmas=(xi / eps, xi, 1.0),
    knn=16,
)

# Build the normalised Laplacian on the selected device; later cells read it
# back through `se2_graph.laplacian`.
se2_graph.set_laplacian(norm=True, device=device)

In [None]:
# One ChebConv layer, moved to the selected device.
# NOTE(review): the positional arguments (1, 1, 6) are presumably input
# channels, output channels and kernel size — confirm against ChebConv.
cheb_conv = ChebConv(1, 1, 6)
cheb_conv = cheb_conv.to(device)

In [None]:
# The helper returns three values; only the first one is used in this notebook.
test_loader, _, _ = get_test_equivariance_dataloaders(
    "mnist",
    batch_size=1,
    data_path="data",
)

In [None]:
# Take the first batch from the test loader and display its first image
# (first sample, first channel).
x0, y0 = next(iter(test_loader))
plt.imshow(x0[0][0])

In [None]:
# Smallest value in the sampled batch (shown as the cell output).
torch.min(x0)

In [None]:
with torch.no_grad():
    # Draw one (image, label) batch from the test loader.
    x0, y0 = next(iter(test_loader))

    # Path A: rotate the input by 90, then convolve.
    x90, _ = prepare_batch((rotate(x0, angle=90), y0), se2_graph, device)
    x90_hat = cheb_conv(x90, se2_graph.laplacian).contiguous().view(se2_graph.nsym, 28, 28)

    # Path B: convolve the unrotated input. `x0` is overwritten by the prepared
    # batch here, so this must come after Path A has used the raw image.
    x0, _ = prepare_batch((x0, y0), se2_graph, device)
    x0_hat = cheb_conv(x0, se2_graph.laplacian).contiguous().view(se2_graph.nsym, 28, 28)

    # Preview slice 3 of the unrotated response.
    plt.imshow(x0_hat[3].cpu())
    plt.axis("off")

    # Shift the first axis by 3 positions, then rotate each slice by 90, so the
    # result can be compared with Path A.
    x0_hat = rotate(torch.roll(x0_hat, shifts=3, dims=0), angle=90)

    # Relative squared difference between the two paths.
    print(f"equivariance error : {(x90_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()}")

In [None]:
with torch.no_grad():
    # Draw one (image, label) batch from the test loader.
    x0, y0 = next(iter(test_loader))

    # Path A: rotate the input by 60, then convolve.
    x60, _ = prepare_batch((rotate(x0, angle=60), y0), se2_graph, device)
    x60_hat = cheb_conv(x60, se2_graph.laplacian).contiguous().view(se2_graph.nsym, 28, 28)

    # Preview slice 0 of the response to the rotated input.
    plt.imshow(x60_hat[0].cpu())
    plt.axis("off")

    # Path B: convolve the unrotated input. `x0` is overwritten by the prepared
    # batch here, so this must come after Path A has used the raw image.
    x0, _ = prepare_batch((x0, y0), se2_graph, device)
    x0_hat = cheb_conv(x0, se2_graph.laplacian).contiguous().view(se2_graph.nsym, 28, 28)

    # Shift the first axis by 2 positions, then rotate each slice by 60, so the
    # result can be compared with Path A.
    x0_hat = rotate(torch.roll(x0_hat, shifts=2, dims=0), angle=60)

    # Relative squared difference between the two paths.
    print(f"equivariance error : {(x60_hat - x0_hat).pow(2).sum() / x0_hat.pow(2).sum()}")

In [None]:
# Display slice 3 of `x0_hat` as left behind by the previously executed cell;
# the result is bound to `_` so the cell shows no text output.
_ = plt.imshow(x0_hat[3].to("cpu"))

In [None]:
# Take one batch and append ten rotated copies of it (angles 30, 60, ..., 300),
# then prepare the stacked batch for the graph convolution.
x, y = next(iter(test_loader))
x = rotate_image(x, list(range(30, 301, 30)))
x, _ = prepare_batch((x, y), se2_graph, device)

In [None]:
with torch.no_grad():
    # Convolve the whole stacked batch and reshape the response to
    # (samples, 1, 6, 28, 28) on the CPU, the layout `plot_equivariance` unpacks.
    # NOTE(review): 6 is hard-coded here while other cells use `se2_graph.nsym`
    # — presumably the same value; confirm.
    x_hat = (
        cheb_conv(x, se2_graph.laplacian)
        .contiguous()
        .view(-1, 1, 6, 28, 28)
        .cpu()
    )

In [None]:
# Plot every slice of the response for each image in the stacked batch.
plot_equivariance(input=x_hat)