In [None]:
import math
import os

import torch
import wandb
from gechebnet.data.dataloader import get_train_val_data_loaders
from gechebnet.engine.engine import create_supervised_evaluator, create_supervised_trainer
from gechebnet.engine.utils import prepare_batch, wandb_log
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.model.chebnet import ChebNet
from gechebnet.model.optimizer import get_optimizer
from gechebnet.utils import random_choice
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events
from ignite.metrics import Accuracy, Loss

from torch.nn import NLLLoss
from torch.nn.functional import nll_loss

# --- Run configuration -----------------------------------------------------

# Directory handed to the data loaders as `data_path`.
DATA_PATH = "data"
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset name and the validation fraction passed to the loader factory.
DATASET = "MNIST"
VAL_RATIO = 0.2
# Grid size of the finest graph (used as `grid_size` in get_model).
NX1 = NX2 = 28

# Channel counts passed positionally to ChebNet.
IN_CHANNELS = 1
OUT_CHANNELS = 10
HIDDEN_CHANNELS = 20

# Number of epochs given to `trainer.run`.
EPOCHS = 20


def get_model(nx3, knn, eps, xi, weight_sigma, weight_kernel, K, pooling):
    """
    Build a ChebNet over three HyperCubeGraphs of decreasing grid size.

    The first graph uses a (NX1, NX2) grid; each following graph floor-halves
    both grid dimensions. All other graph arguments are shared by the three
    graphs. The model capacity is printed before the model is moved to DEVICE.

    Args:
        nx3: forwarded to every HyperCubeGraph as `nx3`.
        knn: forwarded to every HyperCubeGraph as `knn`.
        eps: divisor used for the first entry of `sigmas` (xi / eps).
        xi: numerator of the first and value of the second `sigmas` entry.
        weight_sigma: forwarded to every HyperCubeGraph as `weight_sigma`.
        weight_kernel: forwarded to every HyperCubeGraph as `weight_kernel`.
        K: second positional argument of ChebNet.
        pooling: forwarded to ChebNet as `pooling`.

    Returns:
        The result of `model.to(DEVICE)`.
    """
    # Everything except the grid size is identical across resolution levels.
    shared_graph_kwargs = dict(
        nx3=nx3,
        weight_kernel=weight_kernel,
        weight_sigma=weight_sigma,
        knn=knn,
        sigmas=(xi / eps, xi, 1.0),
        weight_comp_device=DEVICE,
    )

    graphs = []
    rows, cols = NX1, NX2
    for _ in range(3):
        graphs.append(HyperCubeGraph(grid_size=(rows, cols), **shared_graph_kwargs))
        # Next level: floor-halve both grid dimensions.
        rows, cols = rows // 2, cols // 2

    # NOTE: an earlier variant grew the hidden channel count until
    # model.capacity reached a parameter budget; that search is disabled and
    # HIDDEN_CHANNELS is used as is.
    model = ChebNet(graphs, K, IN_CHANNELS, OUT_CHANNELS, HIDDEN_CHANNELS, laplacian_device=DEVICE, pooling=pooling)

    print(model.capacity)

    return model.to(DEVICE)


In [None]:
# --- Hyper-parameters for this run ------------------------------------------
batch_size = 4
xi = 0.01
eps = 0.1
K = 20
knn = 27
learning_rate = 1e-3
nx3 = 6
# Kept separate from `optimizer` so that name only ever refers to the
# optimizer object returned by get_optimizer.
optimizer_name = "adam"
weight_sigma = 1.
weight_decay = 0
weight_kernel = "gaussian"
pooling = "max"

# --- Data, model, optimizer -------------------------------------------------
train_loader, val_loader = get_train_val_data_loaders(
    DATASET, batch_size=batch_size, val_ratio=VAL_RATIO, data_path=DATA_PATH
)

model = get_model(nx3, knn, eps, xi, weight_sigma, weight_kernel, K, pooling)

optimizer = get_optimizer(model, optimizer_name, learning_rate, weight_decay)

# The same loss function drives training and the validation loss metric.
loss_fn = nll_loss
metrics = {"val_mnist_acc": Accuracy(), "val_mnist_loss": Loss(loss_fn)}

# create ignite's engines
trainer = create_supervised_trainer(
    L=nx3, model=model, optimizer=optimizer, loss_fn=loss_fn, device=DEVICE, prepare_batch=prepare_batch
)
# Attach exactly one progress bar; attaching a second one registers a
# duplicate handler on the same engine.
ProgressBar(persist=False, desc="Training").attach(trainer)

evaluator = create_supervised_evaluator(L=nx3, model=model, metrics=metrics, device=DEVICE, prepare_batch=prepare_batch)

# track training with wandb
# NOTE(review): no wandb.init(...) call is visible in this notebook — confirm
# that wandb_log (or the environment) takes care of starting the run.
_ = trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, val_loader)

# TODO: save best model (no checkpoint handler is attached yet)

trainer.run(train_loader, max_epochs=EPOCHS)

In [None]:
import torch.nn as nn

# Random scores for N = 3 samples over C = 5 classes (N x C = 3 x 5).
input = torch.randn(3, 5, requires_grad=True)
# Class index per sample; every entry must satisfy 0 <= value < C.
target = torch.tensor([1, 0, 4])

# LogSoftmax over the class axis produces the log-probabilities fed to NLLLoss.
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
output = loss(m(input), target)

In [None]:
# Display the tensor bound to `input` by the NLLLoss example cell.
# NOTE(review): `input` shadows the Python builtin of the same name.
input

In [None]:
# Draw 10 values with torch.rand and inspect the dtype of the resulting tensor.
a = torch.rand(10)
a.dtype

In [None]:
from gechebnet.model.convolution import cheb_conv, ChebConv
from gechebnet.model.chebnet import ChebNet

In [None]:
# Capacity check on a single tiny graph (2 x 2 grid, nx3 = 2).
# NOTE: relies on `xi` and `eps` from an earlier cell and rebinds `model`,
# replacing the model built for the training run.
model = ChebNet(
    [
        HyperCubeGraph(
            grid_size=(2, 2),
            nx3=2,
            knn=2,
            weight_kernel="gaussian",
            weight_sigma=1.0,
            sigmas=(xi / eps, xi, 1.0),
            weight_comp_device=DEVICE,
        )
    ],
    10,  # same positional slot as K in get_model
    1,  # same positional slot as IN_CHANNELS in get_model
    10,  # same positional slot as OUT_CHANNELS in get_model
    20,  # same positional slot as HIDDEN_CHANNELS in get_model
)

model.capacity

In [None]:
# Values combined into the `sigmas` triple of the test graph below.
xi, eps = 0.01, 0.1

# Small graph (2 x 2 grid, nx3 = 2, knn = 2) used by the convolution checks
# in the following cells.
graph = HyperCubeGraph(
    grid_size=(2, 2),
    nx3=2,
    knn=2,
    weight_kernel="gaussian",
    weight_sigma=1.0,
    sigmas=(xi / eps, xi, 1.0),
    weight_comp_device=DEVICE,
)

In [None]:
# Display the `laplacian` attribute of the small test graph.
graph.laplacian

In [None]:
# All-ones tensor of shape (1, 2, 8), passed to cheb_conv in a later cell.
# NOTE(review): presumably (batch, channels, nodes) with 8 = 2*2*2 graph nodes — TODO confirm.
x = torch.ones(1, 2, 8)
x

In [None]:
# Random tensor of shape (2, 2, 3), passed to cheb_conv as its third argument.
# NOTE(review): the meaning of each axis is not visible here — TODO confirm against cheb_conv.
w = torch.rand(2, 2, 3)
w

In [None]:
# Apply the functional cheb_conv to `x` with weights `w` on the test graph's
# Laplacian, then reorder the result's axes with permute(1, 2, 0) for display.
cheb_conv(graph.laplacian, x, w).permute(1, 2, 0)

In [None]:
import math

In [None]:
# Evaluate exp(-1 / (2 * 0.1**2)) by hand.
# NOTE(review): looks like a Gaussian weight exp(-d^2 / (2 * sigma^2)) with d = 1, sigma = 0.1 — confirm intent.
math.exp(-1/(2*(0.1)**2))

In [None]:
# Build a ChebConv layer on the small test graph.
# NOTE(review): the meaning of the positional arguments (2, 3, 1) is not visible here — TODO confirm against ChebConv's signature.
# NOTE(review): the name `C` is rebound to an int by the later `B, C, H, W = x.shape` cell.
C = ChebConv(graph, 2, 3, 1)

In [None]:
# Display the layer's `weight` attribute.
C.weight

In [None]:
# Display the layer's `bias` attribute.
C.bias

In [None]:
# Display the layer's `laplacian` attribute.
C.laplacian

In [None]:
# Call the ChebConv layer on an all-ones tensor of shape (1, 2, 8) and display the result.
x = torch.ones(1, 2, 8)
C(x)

In [None]:
# Random 4-D tensor of shape (1, 1, 2, 2); the next cell unpacks it as (B, C, H, W).
x = torch.rand(1, 1, 2, 2)
x


In [None]:

# Unpack the 4-D input shape.
# NOTE(review): this rebinds `C`, which an earlier cell bound to a ChebConv layer.
B, C, H, W = x.shape  # (B, C, H, W)

# Insert a singleton axis after the channel axis, broadcast it to 2 copies,
# then flatten everything after the channel axis. H and W come from the input
# so the expand works for any spatial size, not only 2 x 2.
# NOTE: `x` is rebound to the 3-D result, so this cell cannot be re-run without
# first re-running the cell that creates the 4-D `x`.
x = x.unsqueeze(2).expand(B, C, 2, H, W).reshape(B, C, -1)  # (B, C, L*H*W) with L = 2
x

In [None]:
# Reshape the flattened last axis back into (-1, H, W); -1 lets torch infer
# the size of the restored axis.
x.reshape(B, C, -1, H, W)