In [None]:
import torch

import numpy as np

from torchvision.transforms.functional import rotate

In [None]:
from gechebnet.data.dataloader import get_test_data_loader
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.graph.plot import visualize_graph, visualize_neighborhood

from gechebnet.model.convolution import ChebConv

In [None]:
# NOTE(review): NX1/NX2 are not referenced in the visible cells; the graph grid
# size is passed explicitly as grid_size in the graph cell.
NX1, NX2 = 2, 2

# Anisotropy parameters used to build the graph sigmas. These are the values the
# experiment actually runs with: the graph cell re-assigns XI/EPS to these same
# numbers right before constructing the graph (the earlier 0.1 / 1.0 were dead).
XI = 0.01
EPS = 0.1

DATASET = "MNIST"
DATA_PATH = "data"

# Prefer the GPU when present; every tensor/module below is moved to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

In [None]:
# Experiment-specific anisotropy: these re-assign the XI/EPS set in the config cell.
XI = 0.01
EPS = 0.1

# NOTE(review): the three sigmas are presumably per-axis kernel scales of the
# hypercube graph, derived from XI and EPS — confirm against HyperCubeGraph.
graph_sigmas = (XI / EPS, XI, 1.0)

# NOTE(review): grid_size presumably matches the 28x28 MNIST images — it must
# equal (H, W) of the data, since prepare_batch checks H == nx2 and W == nx1.
graph = HyperCubeGraph(
    grid_size=(28, 28),
    nx3=6,
    sigmas=graph_sigmas,
    weight_comp_device=DEVICE,
    knn=27,
)

In [None]:
# Test-set loader (per the helper's name) yielding one image per batch; the
# prepare_batch cell below fans that single image out into rotated copies.
# NOTE(review): DATA_PATH from the config cell is not passed here — confirm which
# directory get_test_data_loader reads from by default.
test_loader = get_test_data_loader(DATASET, batch_size=1)

In [None]:
# Take only the first batch; the remainder of the loader is never consumed.
test_batches = iter(test_loader)
batch = next(test_batches)

In [None]:
# NOTE(review): the positional arguments (1, 1, 5) are presumably
# (in_channels, out_channels, Chebyshev order K) — confirm against ChebConv.
cheb_conv_layer = ChebConv(graph, 1, 1, 5, laplacian_device=DEVICE)
cheb_conv_layer = cheb_conv_layer.to(DEVICE)

In [None]:
from gechebnet.engine.utils import prepare_batch

In [None]:
# Raw (un-rotated) images and labels of the first test batch. Kept under distinct
# names so they are not confused with the graph signal `x`/`y` that the
# prepare_batch call below produces. Show their shapes as a quick sanity check.
x_raw, y_raw = batch
x_raw.shape, y_raw.shape

In [None]:
from matplotlib import pyplot as plt

In [None]:
def prepare_batch(batch, graph, res_rot, device):
    """Build rotated copies of a batch and lift them onto the graph's nodes.

    Every image is rotated by ``res_rot`` equally spaced angles in [0, 360)
    degrees (``np.linspace(0, 360, res_rot, endpoint=False)``). Each rotated
    image is then copied across the ``graph.nx3`` layers and flattened, so it
    can be fed to a graph convolution.

    NOTE: this local definition shadows the ``prepare_batch`` imported from
    ``gechebnet.engine.utils`` in an earlier cell.

    Args:
        batch: ``(images, labels)`` pair, where ``images`` is ``(B, C, H, W)``
            and ``labels`` is indexed along its first dimension by sample.
        graph: object exposing ``nx1`` (must equal W), ``nx2`` (must equal H)
            and ``nx3`` (number of layers each image is copied across).
        res_rot: number of rotations per image; must be >= 1.
        device: torch device the outputs are moved to.

    Returns:
        ``(x, y)``:

        - ``x`` has shape ``(res_rot * B, C, nx3 * H * W)``, ordered
          angle-major: row ``k * B + b`` is image ``b`` rotated by the k-th
          angle.
        - ``y`` holds the labels in the same layout, i.e. ``labels`` repeated
          ``res_rot`` times along dim 0. For B == 1 this is the single label
          repeated once per rotation.

    Raises:
        ValueError: if ``res_rot < 1`` or the image size does not match the
            graph grid.
    """
    x_org, y_org = batch

    if res_rot < 1:
        raise ValueError("res_rot must be a positive integer")

    B, C, H, W = x_org.shape  # (B, C, H, W)

    # Fail fast, before doing any rotation work.
    if H != graph.nx2 or W != graph.nx1:
        raise ValueError("Dimension incompatibility between graph and data")

    # Preallocate with torch's default dtype, as before, then fill one slab of
    # B images per angle (angle-major layout).
    x = torch.zeros(res_rot * B, C, H, W)
    for k, angle in enumerate(np.linspace(0, 360, res_rot, endpoint=False)):
        x[k * B:(k + 1) * B] = rotate(x_org, float(angle))

    # Repeat labels along dim 0 only, so they line up row-for-row with x.
    y = y_org.repeat(res_rot, *([1] * (y_org.dim() - 1)))

    # Copy each image across the nx3 layers, then flatten:
    # (res_rot*B, C, H, W) -> (res_rot*B, C, nx3, H, W) -> (res_rot*B, C, nx3*H*W)
    x = x.unsqueeze(2).expand(-1, -1, graph.nx3, -1, -1).reshape(res_rot * B, C, -1)

    return x.to(device), y.to(device)

In [None]:
# Fan the single test image (batch_size=1) out into 8 rotated copies on the graph.
x, y = prepare_batch(batch, graph, res_rot=8, device=DEVICE)

In [None]:
# Inference-only forward pass: autograd is disabled so no computation graph is recorded.
with torch.no_grad():
    x_out = cheb_conv_layer(x)

In [None]:
# Input signal of rotated copy 7. With res_rot=8 and batch_size=1 this is the
# last rotation, np.linspace(0, 360, 8, endpoint=False)[7] == 315 degrees.
# The trailing `;` suppresses the return-value repr. Assigning to `_` instead
# would shadow IPython's output-history variable and stop it from auto-updating.
visualize_graph(graph, x[7].cpu());

In [None]:
# Convolution output for batch row 1 (presumably rotated copy 1, i.e. 45 degrees),
# clipped to [0, 1] for display.
# NOTE(review): the input plot shows row 7 — confirm the index mismatch is intended.
# The trailing `;` suppresses the repr without clobbering IPython's `_` variable.
visualize_graph(graph, x_out[1].clip(0., 1.).cpu());