In [None]:
import pykeops
# NOTE(review): clean_pykeops() presumably deletes PyKeOps' cache of compiled kernels,
# forcing them to be rebuilt on first use (slower start-up on every run of this cell).
# TODO confirm it is still needed (e.g. only after a PyKeOps/compiler upgrade).
pykeops.clean_pykeops() 

In [None]:
import torch
import numpy as np

from torchvision.transforms.functional import rotate

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
from gechebnet.data.dataloader import get_test_equivariance_data_loader
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.graph.plot import visualize_graph, visualize_neighborhood

from gechebnet.model.convolution import ChebConv

In [None]:
def prepare_batch(batch, graph, res_rot, device):
    """Rotate a batch of images ``res_rot`` times and lift the result onto ``graph``'s vertices.

    Args:
        batch: ``(x_org, y)`` pair from the data loader. ``x_org`` is an image tensor
            indexed as ``(N, C, H, W)``.
        graph: graph object exposing ``nx1`` (must equal ``W``), ``nx2`` (must equal ``H``)
            and ``nx3`` (number of copies of the image stacked along a new third axis).
        res_rot: number of rotation angles, evenly spaced over ``[0, 360)`` degrees.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``. ``x`` has shape ``(res_rot * N, C, nx3 * H * W)`` and the default
        float dtype. Rows ``[i * N, (i + 1) * N)`` hold the batch rotated by the
        ``i``-th angle, and all ``nx3`` slices along the third axis hold the same image.
        ``y`` is returned unchanged; it is NOT repeated per rotation.

    Raises:
        ValueError: if the image size does not match ``(graph.nx2, graph.nx1)``.
    """
    x_org, y = batch
    num_samples = x_org.shape[0]
    height, width = x_org.shape[-2:]

    # Fail fast, before doing any rotation work (rotation preserves H and W).
    if height != graph.nx2 or width != graph.nx1:
        raise ValueError("Dimension incompatibility between graph and data")

    # Preallocate with the default dtype and write each rotated copy of the whole batch
    # into its own slice. The previous `x[i] = rotate(x_org, angle)` only worked for N == 1.
    x = torch.zeros(res_rot * num_samples, *x_org.shape[1:])
    angles = np.linspace(0, 360, res_rot, endpoint=False)
    for i, angle in enumerate(angles):
        x[i * num_samples:(i + 1) * num_samples] = rotate(x_org, float(angle))

    B, C, H, W = x.shape  # (B, C, H, W)

    # Repeat each image along a new third axis (one copy per index of graph.nx3), then
    # flatten the (nx3, H, W) grid into a single vertex dimension.
    x = x.unsqueeze(2).expand(B, C, graph.nx3, H, W).reshape(B, C, -1)  # (B, C, L*H*W)

    return x.to(device), y.to(device)

In [None]:
def plot_equivariance(output):
    """Show every (rotation, layer) slice of a single-channel output in one image grid.

    Args:
        output: tensor indexed as ``(B, C, L, H, W)`` with ``C == 1``. It is presumably
            expected on the CPU (the caller moves it there) so matplotlib can convert it.

    The grid has ``L`` rows and ``B`` columns: row ``l``, column ``b`` shows
    ``output[b, 0, l]``.

    Raises:
        ValueError: if ``output`` does not have exactly one channel.
    """
    B, C, L, H, W = output.shape
    if C != 1:
        # The reshape below drops the channel axis, so it is only valid for C == 1.
        raise ValueError(f"plot_equivariance expects a single-channel output, got C={C}")

    # (B, 1, L, H, W) -> (L, B, 1, H, W) -> (L * B, H, W). ImageGrid is filled row by row
    # (its default direction), so each run of B consecutive images is one row (fixed l).
    images = output.permute(2, 0, 1, 3, 4).contiguous().view(L * B, H, W)

    fig = plt.figure(figsize=(20.0, 20.0))
    grid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=(L, B),  # L rows x B columns of axes
        axes_pad=0.1,  # pad between axes in inches
    )

    # Iterating over the grid yields its Axes in row-major order.
    for ax, im in zip(grid, images):
        ax.imshow(im)

    plt.show()

In [None]:
# Graph hyperparameters XI / EPS are set together with the other model
# hyperparameters in the next cell; the values previously assigned here were
# always overwritten before first use.

DATASET = "MNIST"
# NOTE(review): DATA_PATH is not referenced by any visible cell — presumably it was
# meant to be passed to the data loader; confirm or remove.
DATA_PATH = "data"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Graph hyperparameters, combined into sigmas=(XI / EPS, XI, 1.0) when the graph is built.
XI = 0.01
EPS = 0.1

B = 12          # number of rotated copies of the input sample (res_rot)
H, W = 28, 28   # spatial grid size of the graph; expected to match the input images
C = 1           # number of channels
L = 6           # size of the graph's third axis (HyperCubeGraph nx3)
K = 15          # order of the Chebyshev polynomial approximation

In [None]:
# Graph on an (L, H, W) grid; sigmas and knn control the edge weights and neighbourhoods.
graph = HyperCubeGraph(
    grid_size=(H, W),
    nx3=L,
    sigmas=(XI / EPS, XI, 1.0),
    weight_comp_device=DEVICE,
    knn=27,
)

# Single graph-convolution layer mapping C input channels to C output channels.
# Assumes ChebConv's positional args are (graph, in_channels, out_channels, K) — TODO confirm.
# The output channel count must equal C for the later `.view(B, C, L, H, W)`.
cheb_conv_layer = ChebConv(graph, C, C, K, laplacian_device=DEVICE).to(DEVICE)

In [None]:
# Only the first returned loader is used; take a single sample from it.
test_loader, _, _ = get_test_equivariance_data_loader(DATASET, batch_size=1)
batch = next(iter(test_loader))

# Rotate the sample B times and lift it onto the graph's vertices.
x, y = prepare_batch(batch, graph, res_rot=B, device=DEVICE)

In [None]:
# Forward pass only: no autograd graph is needed for this visual check.
with torch.no_grad():
    X_out = (
        cheb_conv_layer(x)        # flat graph signal, presumably (B, C, L*H*W)
        .contiguous()
        .view(B, C, L, H, W)      # unflatten the vertex axis back to (L, H, W)
        .cpu()                    # move to host memory for plotting
    )

In [None]:
# Rows: the L slices along the graph's third axis; columns: the B rotated inputs.
plot_equivariance(output=X_out);