In [None]:
import math
import os

import torch
import wandb
from gechebnet.data.dataloader import get_train_val_data_loaders
from gechebnet.engine.engine import create_supervised_evaluator, create_supervised_trainer
from gechebnet.engine.utils import prepare_batch, wandb_log
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.model.chebnet import ChebNet
from gechebnet.model.optimizer import get_optimizer
from gechebnet.utils import random_choice
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events
from ignite.metrics import Accuracy, Loss

from torch.nn import NLLLoss
from torch.nn.functional import nll_loss

# Directory handed to get_train_val_data_loaders as `data_path`.
DATA_PATH = "data"
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset name and validation ratio forwarded to get_train_val_data_loaders.
DATASET = "MNIST"
VAL_RATIO = 0.2
# Spatial grid size passed as `grid_size` for the finest graph in get_model.
NX1, NX2 = (28, 28)

# Channel counts passed positionally to ChebNet in get_model.
IN_CHANNELS = 1
OUT_CHANNELS = 10
HIDDEN_CHANNELS = 20

# Value passed as `max_epochs` to trainer.run.
EPOCHS = 20


def get_model(nx3, knn, eps, xi, weight_sigma, weight_kernel, K, pooling):
    """Build a ChebNet on three HyperCubeGraphs of decreasing spatial grid size.

    The graphs share every setting and differ only in `grid_size`:
    (NX1, NX2), then both sides halved once, then halved twice.

    Args:
        nx3: forwarded to each HyperCubeGraph as `nx3`.
        knn: forwarded to each HyperCubeGraph as `knn`.
        eps: combined with `xi` into `sigmas=(xi / eps, xi, 1.0)`.
        xi: see `eps`.
        weight_sigma: forwarded to each HyperCubeGraph as `weight_sigma`.
        weight_kernel: forwarded to each HyperCubeGraph as `weight_kernel`.
        K: second positional argument of ChebNet.
        pooling: forwarded to ChebNet as `pooling`.

    Returns:
        The ChebNet instance after `.to(DEVICE)`.
    """
    # One grid per resolution level, finest first.
    grid_sizes = [
        (NX1, NX2),
        (NX1 // 2, NX2 // 2),
        (NX1 // 2 // 2, NX2 // 2 // 2),
    ]
    metric_sigmas = (xi / eps, xi, 1.0)

    graphs = [
        HyperCubeGraph(
            grid_size=grid_size,
            nx3=nx3,
            weight_kernel=weight_kernel,
            weight_sigma=weight_sigma,
            knn=knn,
            sigmas=metric_sigmas,
            weight_comp_device=DEVICE,
        )
        for grid_size in grid_sizes
    ]

    net = ChebNet(graphs, K, IN_CHANNELS, OUT_CHANNELS, HIDDEN_CHANNELS, laplacian_device=DEVICE, pooling=pooling)

    # Report the model's `capacity` attribute before handing the model back.
    print(net.capacity)

    return net.to(DEVICE)


In [None]:
import numpy as np
# Scratch check: standard deviation (np.std) of an alternating 0/1 list.
lst = [0, 1, 0, 1]
np.std(lst)

In [None]:
# Smallest neighbourhood size; each exponent below doubles it.
MIN_KNN = 2

# For every base knn, print it together with its x4 and x16 multiples.
for knn_exp in range(5):
    base_knn = MIN_KNN * 2**knn_exp
    for multiplier in (1, 4, 16):
        print(base_knn * multiplier)
    print("\n")

In [None]:
# Scratch arithmetic: 24 divided by 7 * 7 * 4 (= 196).
24/(7*7*4)

In [None]:
import torch
# Scratch check: display the dtype torch.arange gives for an integer range.
node_index = torch.arange(10000)
node_index.dtype

In [None]:
# Hyper-parameters for a single hand-configured training run.
batch_size = 4
xi = 0.01
eps = 0.1
K = 20
knn = 27
learning_rate = 1e-3
nx3 = 6
optimizer_name = "adam"
weight_sigma = 1.
weight_decay = 0
weight_kernel = "gaussian"
pooling = "max"

train_loader, val_loader = get_train_val_data_loaders(DATASET, batch_size=batch_size, val_ratio=VAL_RATIO, data_path=DATA_PATH)

model = get_model(nx3, knn, eps, xi, weight_sigma, weight_kernel, K, pooling)

# The optimizer's name and the optimizer object live under separate names, so
# `optimizer` never holds a string in one place and an object in another.
optimizer = get_optimizer(model, optimizer_name, learning_rate, weight_decay)

loss_fn = nll_loss
metrics = {"val_mnist_acc": Accuracy(), "val_mnist_loss": Loss(loss_fn)}

# create ignite's engines
trainer = create_supervised_trainer(
    L=nx3, model=model, optimizer=optimizer, loss_fn=loss_fn, device=DEVICE, prepare_batch=prepare_batch
)
# Attach the training progress bar exactly once; a second identical attach
# only registers a duplicate bar on the same engine.
ProgressBar(persist=False, desc="Training").attach(trainer)

evaluator = create_supervised_evaluator(L=nx3, model=model, metrics=metrics, device=DEVICE, prepare_batch=prepare_batch)

# track training with wandb
# NOTE(review): no wandb.init() call is visible before this run -- confirm
# that wandb_log does not require an active wandb run.
_ = trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, val_loader)

# TODO: save best model (not implemented)

trainer.run(train_loader, max_epochs=EPOCHS)

In [None]:
import torch.nn as nn

# Minimal NLLLoss example: log-softmax over the class axis, then NLL loss.
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# logits are of size N x C = 3 x 5; the tensor is named `logits` rather than
# `input` so the built-in input() is not shadowed for the rest of the session
logits = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(logits), target)

In [None]:
# Scratch check: display the dtype of a tensor created by torch.rand.
a = torch.rand(10)
a.dtype

In [None]:
from gechebnet.model.convolution import cheb_conv, ChebConv
from gechebnet.model.chebnet import ChebNet

from gechebnet.graph.plot import visualize_graph, visualize_neighborhood


In [None]:
# Single small graph (grid_size (2, 2), nx3=2), used only to read off the
# capacity of a ChebNet built on it. `xi` and `eps` come from an earlier cell.
tiny_graph = HyperCubeGraph(
    grid_size=(2, 2),
    nx3=2,
    weight_kernel="gaussian",
    weight_sigma=1.,
    knn=2,
    sigmas=(xi / eps, xi, 1.0),
    weight_comp_device=DEVICE,
)

# Same positional order as the ChebNet call in get_model:
# graphs, K, then the three channel counts.
model = ChebNet([tiny_graph], 10, 1, 10, 20)

model.capacity

In [None]:
# Values combined below into sigmas=(xi / eps, xi, 1.0).
xi = 0.01
eps = 0.1

# Graph on the 28 x 28 grid with nx3=6 and knn=26, Gaussian weight kernel.
graph_1 = HyperCubeGraph(
            grid_size=(28, 28),
            nx3=6,
            weight_kernel="gaussian",
            weight_sigma=1.,
            knn=26,
            sigmas=(xi / eps, xi, 1.0),
            weight_comp_device=DEVICE,
        )

In [None]:
# Display the edge count of the 28 x 28 graph.
graph_1.num_edges

In [None]:
# Graph on the 14 x 14 grid; every other argument matches the 28 x 28 graph.
graph_2 = HyperCubeGraph(
            grid_size=(14, 14),
            nx3=6,
            weight_kernel="gaussian",
            weight_sigma=1.,
            knn=26,
            sigmas=(xi/ eps, xi, 1.0),  # NOTE(review): intended to adapt the metric kernel to the size of the graph, but this expression does not depend on grid_size -- confirm
            weight_comp_device=DEVICE,
        )

In [None]:
# Display the edge count of the 14 x 14 graph.
graph_2.num_edges

In [None]:
# Graph on the 7 x 7 grid; every other argument matches the larger graphs.
graph_3 = HyperCubeGraph(
            grid_size=(7, 7),
            nx3=6,
            weight_kernel="gaussian",
            weight_sigma=1.,
            knn=26,
            sigmas=(xi/ eps, xi, 1.0),  # NOTE(review): intended to adapt the metric kernel to the size of the graph, but this expression does not depend on grid_size -- confirm
            weight_comp_device=DEVICE,
        )

In [None]:
# Plot the neighbourhood for index 0 of graph_1 (presumably a node index -- TODO confirm).
fig = visualize_neighborhood(graph_1, 0)

In [None]:
# Plot the neighbourhood for index 0 of graph_2 (presumably a node index -- TODO confirm).
fig = visualize_neighborhood(graph_2, 0)

In [None]:
# Plot the neighbourhood for index 0 of graph_3 (presumably a node index -- TODO confirm).
fig = visualize_neighborhood(graph_3, 0)

In [None]:
import math
from typing import Optional, Tuple, TypeVar

import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor, RandomRotation, RandomHorizontalFlip, RandomVerticalFlip

# Single-channel mean and std, passed to torchvision's Normalize transform.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)

In [None]:
# Two alternative augmentation pipelines. Each gets its own name so that the
# unused one is not silently overwritten by a second assignment.
rotation_transformation = [RandomRotation(degrees=180), ToTensor(), Normalize(MNIST_MEAN, MNIST_STD)]
flip_transformation = [RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5), ToTensor(), Normalize(MNIST_MEAN, MNIST_STD)]

# Pipeline actually applied to the dataset. Bind rotation_transformation here
# instead to inspect the rotation augmentation.
transformation = flip_transformation

# MNIST split selected by train=False, downloaded into "data" if missing.
dataset = MNIST(
    "data", train=False, download=True, transform=Compose(transformation)
)

dataloader = DataLoader(dataset, batch_size=8)

In [None]:
# Pull the first batch from a fresh iterator over the augmented loader.
batch = next(iter(dataloader))

In [None]:
# Show element 0 of batch[0] (presumably the image tensor -- TODO confirm) after squeezing singleton dims.
plt.imshow(batch[0][0].squeeze())

In [None]:
# Show element 1 of batch[0] (presumably the image tensor -- TODO confirm) after squeezing singleton dims.
plt.imshow(batch[0][1].squeeze())

In [None]:
# Show element 2 of batch[0] (presumably the image tensor -- TODO confirm) after squeezing singleton dims.
plt.imshow(batch[0][2].squeeze())

In [None]:
from gechebnet.data.dataloader import get_test_equivariance_data_loader

import matplotlib.pyplot as plt


In [None]:
# The helper returns three objects that are iterated like data loaders below.
# NOTE(review): how the three differ is not visible here -- see get_test_equivariance_data_loader.
d1, d2, d3 = get_test_equivariance_data_loader("MNIST")

In [None]:
# First batch of d1; show element 1 of batch[0] after squeezing singleton dims.
batch = next(iter(d1))
plt.imshow(batch[0][1].squeeze())

In [None]:
# First batch of d2; show element 1 of batch[0] after squeezing singleton dims.
batch = next(iter(d2))
plt.imshow(batch[0][1].squeeze())

In [None]:
# First batch of d3; show element 1 of batch[0] after squeezing singleton dims.
batch = next(iter(d3))
plt.imshow(batch[0][1].squeeze())

In [None]:
# Draft of a sweep search space. The entries are wrapped in a named dict
# because bare `"key": value,` lines are not valid Python and made this cell
# fail with a SyntaxError.
sweep_parameters = {
    "batch_size": {"distribution": "q_log_uniform", "min": math.log(8), "max": math.log(256)},
    "eps": {"distribution": "log_uniform", "min": math.log(0.1), "max": math.log(1.0)},
    "K": {"distribution": "q_log_uniform", "min": math.log(2), "max": math.log(64)},
    "knn": {"distribution": "categorical", "values": [2, 4, 8, 16, 32]},
    "learning_rate": {"distribution": "log_uniform", "min": math.log(1e-5), "max": math.log(0.1)},
    "nx3": {"distribution": "int_uniform", "min": 3, "max": 9},
    "pooling": {"distribution": "categorical", "values": ["max", "avg"]},
}


In [None]:
import random
import math

In [None]:
def _log_uniform(low, high):
    """Draw one sample whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(random.uniform(math.log(low), math.log(high)))


# graph
nx3 = random.randint(3, 12)
eps = _log_uniform(0.1, 1.0)
xi = _log_uniform(1e-2, 1.0)
knn = random.choice((2, 4, 8, 16, 32))
weight_kernel = random.choice(("cauchy", "gaussian", "laplacian"))
weight_sigma = _log_uniform(0.25, 10)

# network
K = random.choice((2, 4, 8, 16, 32, 64))
pooling = random.choice(("max", "avg"))

# training
batch_size = random.choice((8, 16, 32, 64, 128, 256))
learning_rate = _log_uniform(1e-5, 0.1)
weight_decay = _log_uniform(1e-7, 1e-2)

In [None]:
# Display the sampled graph hyper-parameters as a tuple.
nx3, eps, xi, knn, weight_kernel, weight_sigma

In [None]:
import math
import random

import numpy as np
import torch
from gechebnet.graph.graph import HyperCubeGraph

# Spatial grid of the finest benchmarked graph and the nx3 value used for all of them.
NX1 = 28
NX2 = 28
NX3 = 8

# Each spatial side is integer-divided by this factor between successive graphs.
POOLING_SIZE = 2

# Fall back to the CPU when CUDA is unavailable, so the cell also runs on a
# machine without a GPU (same rule as the DEVICE defined at the top of the notebook).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import time

# Number of timed constructions per graph configuration.
NUM_ITER = 1


def _time_graph_build(grid_size, knn, sigmas):
    """Build one HyperCubeGraph configuration NUM_ITER times and print timing statistics."""
    print(f"KNN = {knn} and V = {grid_size[0] * grid_size[1] * NX3}")

    durations = []
    for _ in range(NUM_ITER):
        # perf_counter is a monotonic, high-resolution clock, so the measured
        # interval is not affected by system clock adjustments.
        start = time.perf_counter()
        graph = HyperCubeGraph(
            grid_size=grid_size,
            nx3=NX3,
            knn=knn,
            weight_comp_device=DEVICE,
            sigmas=sigmas,
        )
        elapsed = time.perf_counter() - start

        # Fewer edges than nodes is treated as a failed construction; it is
        # only reported, the benchmark keeps going.
        if graph.num_nodes > graph.num_edges:
            print("Value Error: an error occured during the construction of the graph")
        durations.append(elapsed)

    print(f"time: mean {np.mean(durations)} std {np.std(durations)}")


def build_graphs(knn):
    """Time the construction of the three graphs of a pooling hierarchy.

    `eps` and `xi` are drawn log-uniformly at random on every call. Three
    HyperCubeGraph configurations are then built and timed: the full
    (NX1, NX2) grid, the grid with both sides divided by POOLING_SIZE, and
    the grid divided twice. Their knn values are `knn * POOLING_SIZE ** 4`,
    `knn * POOLING_SIZE ** 2` and `knn` respectively.

    Args:
        knn: neighbourhood size used for the coarsest graph.

    Returns:
        None. Sampled sigmas, graph sizes and timings are printed.
    """
    eps = math.exp(random.uniform(math.log(0.1), math.log(1.0)))
    xi = math.exp(random.uniform(math.log(1e-2), math.log(1.0)))

    print((xi, xi * eps, 1.0))

    half_grid = (NX1 // POOLING_SIZE, NX2 // POOLING_SIZE)
    quarter_grid = (NX1 // POOLING_SIZE // POOLING_SIZE, NX2 // POOLING_SIZE // POOLING_SIZE)

    # NOTE(review): the finest graph uses sigmas (xi, xi * eps, 1.0) while the
    # two coarser graphs use (xi / eps, xi, 1.0); kept as found -- confirm
    # whether the difference is intentional.
    _time_graph_build((NX1, NX2), int(knn * POOLING_SIZE ** 4), (xi, xi * eps, 1.0))
    _time_graph_build(half_grid, int(knn * POOLING_SIZE ** 2), (xi / eps, xi, 1.0))
    _time_graph_build(quarter_grid, int(knn), (xi / eps, xi, 1.0))


# Run the construction benchmark once for each base neighbourhood size.
for knn in [2, 4, 8, 16, 32]:
    build_graphs(knn)


In [None]:
import math
import os
import random

import pykeops
import torch
import wandb
from gechebnet.data.dataloader import get_train_val_data_loaders
from gechebnet.engine.engine import create_supervised_evaluator, create_supervised_trainer
from gechebnet.engine.utils import prepare_batch, wandb_log
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.model.chebnet import ChebNet
from gechebnet.model.optimizer import get_optimizer
from gechebnet.utils import random_choice
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events
from ignite.metrics import Accuracy, Loss
from torch.nn import NLLLoss
from torch.nn.functional import nll_loss

# Fall back to the CPU when CUDA is unavailable, so the cell also runs on a
# machine without a GPU (same rule as the DEVICE defined at the top of the notebook).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATASET_NAME = "MNIST"  # STL10
VAL_RATIO = 0.2
# Spatial grid size passed as `grid_size` for the finest graph in get_model.
NX1, NX2 = (28, 28)

# Channel counts passed positionally to ChebNet in get_model.
IN_CHANNELS = 1
OUT_CHANNELS = 10
HIDDEN_CHANNELS = 20
# Each spatial side is integer-divided by this factor between successive graphs.
POOLING_SIZE = 2

EPOCHS = 20
OPTIMIZER = "adam"

NUM_ITER = 10


def build_sweep_config():
    """Return the sweep configuration dict that is passed to wandb.sweep.

    The sweep uses the "bayes" method and maximizes "validation_accuracy".
    Bounds of the (q_)log_uniform parameters are given as natural logarithms
    of the intended range.

    Returns:
        A new dict with the keys "method", "metric" and "parameters".
    """

    def log_range(distribution, low, high):
        # Bounds are stored in log space.
        return {"distribution": distribution, "min": math.log(low), "max": math.log(high)}

    def categorical(*values):
        return {"distribution": "categorical", "values": list(values)}

    return {
        "method": "bayes",
        "metric": {"name": "validation_accuracy", "goal": "maximize"},
        "parameters": {
            "batch_size": log_range("q_log_uniform", 8, 256),
            "eps": log_range("log_uniform", 0.1, 1.0),
            "K": log_range("q_log_uniform", 2, 64),
            "knn": categorical(2, 4, 8, 16, 32),
            "learning_rate": log_range("log_uniform", 1e-5, 0.1),
            "nx3": {"distribution": "int_uniform", "min": 3, "max": 9},
            "pooling": categorical("max", "avg"),
            "weight_sigma": {"distribution": "uniform", "min": 0.25, "max": 8.0},
            "weight_decay": log_range("log_uniform", 1e-6, 1e-3),
            "weight_kernel": categorical("cauchy", "gaussian", "laplacian"),
            "xi": log_range("log_uniform", 1e-2, 1.0),
        },
    }


def get_model(nx3, knn, eps, xi, weight_sigma, weight_kernel, K, pooling):
    """Build the three-level graph hierarchy and the ChebNet that runs on it.

    Each level divides both spatial sides by POOLING_SIZE. The neighbourhood
    size shrinks with the grid: `knn * POOLING_SIZE ** 4` on the finest grid,
    `knn * POOLING_SIZE ** 2` on the middle one and `knn` on the coarsest.
    Node/edge counts of every graph and the model capacity are sent to
    wandb.log.

    Args:
        nx3: forwarded to each HyperCubeGraph as `nx3`.
        knn: neighbourhood size of the coarsest graph.
        eps: combined with `xi` into `sigmas=(xi / eps, xi, 1.0)`.
        xi: see `eps`.
        weight_sigma: forwarded to each HyperCubeGraph as `weight_sigma`.
        weight_kernel: forwarded to each HyperCubeGraph as `weight_kernel`.
        K: second positional argument of ChebNet.
        pooling: forwarded to ChebNet as `pooling`.

    Returns:
        The ChebNet instance after `.to(DEVICE)`.

    Raises:
        ValueError: if a constructed graph has more nodes than edges.
    """
    # Debug output: value and type of every hyper-parameter received.
    for label, value in (
        ("nx3", nx3),
        ("knn", knn),
        ("eps", eps),
        ("xi", xi),
        ("weight_sigma", weight_sigma),
        ("weight_kernel", weight_kernel),
        ("K", K),
        ("pooling", pooling),
    ):
        print(label, value, type(value))

    print("NX1, NX2", NX1, NX2)
    print("NX1 // POOLING_SIZE, NX2 // POOLING_SIZE", NX1 // POOLING_SIZE, NX2 // POOLING_SIZE)
    print(
        "NX1 // POOLING_SIZE // POOLING_SIZE, NX2 // POOLING_SIZE// POOLING_SIZE",
        NX1 // POOLING_SIZE // POOLING_SIZE,
        NX2 // POOLING_SIZE // POOLING_SIZE,
    )

    # Different graphs are for successive pooling layers: (grid_size, knn factor).
    # The coarsest level uses the plain `knn`; reusing the finest level's
    # POOLING_SIZE ** 4 factor there was a copy-paste slip that requested far
    # more neighbours than intended on the smallest grid.
    levels = (
        ((NX1, NX2), POOLING_SIZE ** 4),
        ((NX1 // POOLING_SIZE, NX2 // POOLING_SIZE), POOLING_SIZE ** 2),
        ((NX1 // POOLING_SIZE // POOLING_SIZE, NX2 // POOLING_SIZE // POOLING_SIZE), 1),
    )

    graphs = []
    for level, (grid_size, knn_factor) in enumerate(levels, start=1):
        graph = HyperCubeGraph(
            grid_size=grid_size,
            nx3=nx3,
            knn=int(knn * knn_factor),
            sigmas=(xi / eps, xi, 1.0),
            weight_comp_device=DEVICE,
            weight_sigma=weight_sigma,
            weight_kernel=weight_kernel,
        )
        # Fewer edges than nodes is treated as a failed graph construction.
        if graph.num_nodes > graph.num_edges:
            raise ValueError("An error occured during the computation of the graph")
        wandb.log({f"graph_{level}_nodes": graph.num_nodes, f"graph_{level}_edges": graph.num_edges})
        graphs.append(graph)

    model = ChebNet(
        tuple(graphs),
        K,
        IN_CHANNELS,
        OUT_CHANNELS,
        HIDDEN_CHANNELS,
        laplacian_device=DEVICE,
        pooling=pooling,
    )

    wandb.log({"capacity": model.capacity})

    return model.to(DEVICE)


def train(config=None):
    """Run one sweep trial: open a wandb run and build the model from its config.

    Only model construction is active. The optimizer, trainer, evaluator and
    data-loader code that follows this function in the cell is commented out.

    Args:
        config: forwarded to wandb.init. The hyper-parameters are then read
            from wandb.config inside the run.
    """
    # Initialize a new wandb run
    with wandb.init(config=config):
        run_config = wandb.config

        # Model and optimizer
        model = get_model(
            nx3=run_config.nx3,
            knn=run_config.knn,
            eps=run_config.eps,
            xi=run_config.xi,
            weight_sigma=run_config.weight_sigma,
            weight_kernel=run_config.weight_kernel,
            K=run_config.K,
            pooling=run_config.pooling,
        )

#         optimizer = get_optimizer(model, OPTIMIZER, config.learning_rate, config.weight_decay)
#         loss_fn = nll_loss

#         # Trainer and evaluator(s) engines
#         trainer = create_supervised_trainer(
#             L=config.nx3,
#             model=model,
#             optimizer=optimizer,
#             loss_fn=loss_fn,
#             device=DEVICE,
#             prepare_batch=prepare_batch,
#         )
#         ProgressBar(persist=False, desc="Training").attach(trainer)

#         metrics = {"validation_accuracy": Accuracy(), "validation_loss": Loss(loss_fn)}

#         evaluator = create_supervised_evaluator(
#             L=config.nx3, model=model, metrics=metrics, device=DEVICE, prepare_batch=prepare_batch
#         )
#         ProgressBar(persist=False, desc="Evaluation").attach(evaluator)

#         train_loader, val_loader = get_train_val_data_loaders(
#             DATASET_NAME, batch_size=config.batch_size, val_ratio=VAL_RATIO, data_path=DATA_PATH
#         )

#         # Performance tracking with wandb
#         trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, val_loader)

#         trainer.run(train_loader, max_epochs=EPOCHS)

# Register the sweep under the "gechebnet" project, then let a local agent
# call train() for up to 50 trials.
sweep_config = build_sweep_config()
sweep_id = wandb.sweep(sweep_config, project="gechebnet")
wandb.agent(sweep_id, train, count=50)

In [None]:
# Run the compiled-graphs test script in a subprocess.
# NOTE(review): the path is under the author's home directory, so this cell only works on that machine.
!python3 ~/Documents/thesis/GroupEquivariantChebNets/scripts/test_compiled_graphs.py