In [None]:
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.graph.plot import visualize_graph, visualize_neighborhood
from gechebnet.graph.utils import is_undirected

import torch

import matplotlib.pyplot as plt

# Define anisotropy parameters

In [None]:
def compute_weight(sqdist, weight_kernel, weight_sigma=1.):
    """Map squared distances to edge weights with a radial kernel.

    Args:
        sqdist (torch.Tensor): squared distances (expected to be non-negative).
        weight_kernel (str): kernel name, one of "gaussian", "laplacian" or "cauchy".
        weight_sigma (float, optional): kernel bandwidth. Defaults to 1.

    Returns:
        torch.Tensor: element-wise kernel values, same shape as ``sqdist``.

    Raises:
        ValueError: if ``weight_kernel`` is not one of the supported names.
    """
    if weight_kernel == "gaussian":
        # exp(-d^2 / sigma^2)
        return torch.exp(-sqdist / weight_sigma ** 2)
    if weight_kernel == "laplacian":
        # exp(-d / sigma), with d = sqrt(d^2)
        return torch.exp(-torch.sqrt(sqdist) / weight_sigma)
    if weight_kernel == "cauchy":
        # 1 / (1 + d^2 / sigma^2)
        return 1 / (1 + sqdist / weight_sigma ** 2)
    # Fail loudly instead of hitting an unbound local for unsupported names.
    raise ValueError(
        f"Unknown weight_kernel {weight_kernel!r}; expected 'gaussian', 'laplacian' or 'cauchy'"
    )

In [None]:
# Compare the three kernels on squared distances in [0, 5) with the default bandwidth.
sqdist = torch.arange(0, 5, 0.01)
fig, ax = plt.subplots()
for kernel_name in ("gaussian", "cauchy", "laplacian"):
    ax.plot(sqdist, compute_weight(sqdist, kernel_name), label=kernel_name)
ax.set(title="Weight kernels (weight_sigma = 1)", xlabel="squared distance", ylabel="weight")
_ = ax.legend()

# Create a graph

In [None]:
# Anisotropy parameters of the graph metric; reused by later cells.
weight_sigma = 1.
eps, xi = 1., 1.

MIN_KNN = 2        # base number of neighbors on the coarsest level
MULT_KNN = 2       # knn is multiplied by this factor at each sweep step
POOLING_SIZE = 2   # spatial down-sampling factor between consecutive levels
NUM_LEVELS = 3     # graph_1 (finest), graph_2, graph_3 (coarsest) -- unpacked below, keep at 3
NUM_KNN_STEPS = 4  # sweep exponents 0, 1, 2, 3

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NX1, NX2, NX3 = 20, 20, 20


def build_level_graph(level, exp_knn):
    """Build the graph of pooling level ``level`` (0 = finest) for one knn sweep step.

    Args:
        level (int): number of pooling steps applied to the NX1 x NX2 grid.
        exp_knn (int): sweep exponent; the coarsest level uses MIN_KNN * MULT_KNN ** exp_knn neighbors.

    Returns:
        HyperCubeGraph: graph on a (NX1 // POOLING_SIZE**level, NX2 // POOLING_SIZE**level) grid.
    """
    downscale = POOLING_SIZE ** level
    # Each remaining pooling step multiplies knn by POOLING_SIZE ** 2, i.e. the
    # POOLING_SIZE ** 4 / ** 2 / ** 0 factors of the finest, middle and coarsest levels.
    knn = MIN_KNN * MULT_KNN ** exp_knn * POOLING_SIZE ** (2 * (NUM_LEVELS - 1 - level))
    return HyperCubeGraph(
        grid_size=(NX1 // downscale, NX2 // downscale),
        nx3=NX3,
        knn=int(knn),
        sigmas=(xi / eps, xi, 1.0),
        weight_comp_device=DEVICE,
    )


# `wandb` is not imported (nor initialised) anywhere in this notebook, so the graph
# sizes of all three levels are collected locally and displayed instead.
graph_stats = []
for exp_knn in range(NUM_KNN_STEPS):
    graph_1, graph_2, graph_3 = (build_level_graph(level, exp_knn) for level in range(NUM_LEVELS))
    for graph_name, level_graph in (("graph_1", graph_1), ("graph_2", graph_2), ("graph_3", graph_3)):
        graph_stats.append(
            {
                "exp_knn": exp_knn,
                "graph": graph_name,
                "nodes": level_graph.num_nodes,
                "edges": level_graph.num_edges,
            }
        )

graph_stats

In [None]:
# `graph` is only created a few cells below; inspect node 0 of the finest sweep graph instead.
fig = visualize_neighborhood(graph_1, 0)

In [None]:
# `graph` is only created in the next cell; look at graph_1 from the last sweep step instead.
graph_1.edge_weight[1]

In [None]:
# Small 5x5 graph (nx3=1) with a Cauchy weight kernel, used by the inspection cells below.
# Reuses the notebook-wide DEVICE so it also builds on CPU-only machines.
graph = HyperCubeGraph(
    grid_size=(5, 5),
    nx3=1,
    sigmas=(4 * xi / eps, 4 * xi, 1.),
    weight_sigma=1.84,
    weight_kernel="cauchy",
    knn=26,
    weight_comp_device=DEVICE,
)

In [None]:
first_edge_weight = graph.edge_weight[0]  # first entry of the graph's edge weights
first_edge_weight

In [None]:
center_node = 0  # node whose neighborhood is drawn
fig = visualize_neighborhood(graph, center_node)

In [None]:
# Symmetry check of the edge set (semantics per gechebnet.graph.utils.is_undirected).
undirected = is_undirected(graph.edge_index, graph.edge_weight, graph.num_nodes)
undirected

# Visualize graph's nodes

In [None]:
# Draw the nodes of the small 5x5 graph built above.
fig = visualize_graph(graph)

# Visualize neighborhood

In [None]:
source_node = 0  # node whose neighborhood is drawn
fig = visualize_neighborhood(graph, source_node)

In [None]:
other_node = 2  # second node to compare against node 0
fig = visualize_neighborhood(graph, other_node)

# One can choose another weight kernel

In [None]:
# Laplacian kernel
# NOTE(review): every other HyperCubeGraph call in this notebook passes the third-axis size as
# `nx3`; `equiv_axis_size` looked like a stale keyword here -- confirm against HyperCubeGraph.
graph = HyperCubeGraph(
    grid_size=(10, 10),
    nx3=10,
    sigmas=(xi / eps, xi, 1.),
    weight_kernel="laplacian",
    weight_comp_device=DEVICE,
)

fig = visualize_neighborhood(graph, graph.centroid_index)

In [None]:
# Cauchy kernel
# NOTE(review): every other HyperCubeGraph call in this notebook passes the third-axis size as
# `nx3`; `equiv_axis_size` looked like a stale keyword here -- confirm against HyperCubeGraph.
graph = HyperCubeGraph(
    grid_size=(10, 10),
    nx3=10,
    sigmas=(xi / eps, xi, 1.),
    weight_kernel="cauchy",
    weight_comp_device=DEVICE,
)

fig = visualize_neighborhood(graph, graph.centroid_index)

# Compress graph

In [None]:
# Rescaled anisotropy parameters used by the compression and STL10 cells below.
alpha = 96 / 28  # ratio of the 96-wide grid to the 28-wide grid used further down
xi_, eps_ = xi / alpha ** 2, eps

## Node compression

In [None]:
# NOTE(review): GraphData and GaussianKernel are not imported anywhere in this notebook
# (they look like an older gechebnet API) -- add the missing imports before running.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=6,
    static_compression=("node", 0.5),
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi_ / eps_, xi_, 1.),
)

In [None]:
# Draw the node-compressed graph built above.
fig = visualize_graph(graph_data)

In [None]:
# NOTE(review): visualize_weight_fields is not imported in this notebook -- add its import.
fig = visualize_weight_fields(graph_data)

## Edge compression

In [None]:
# Same construction as the node-compression cell, but on a 96x96 grid with edge compression.
# NOTE(review): GraphData and GaussianKernel are not imported in this notebook.
graph_data = GraphData(
    grid_size=(96, 96),
    num_layers=6,
    static_compression=("edge", 0.5),
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi_ / eps_, xi_, 1.),
)

In [None]:
# Draw the edge-compressed graph built above.
fig = visualize_graph(graph_data)

In [None]:
# NOTE(review): visualize_weight_fields is not imported in this notebook -- add its import.
fig = visualize_weight_fields(graph_data)

# Image embedding

In [None]:
from gechebnet.data.dataloader import get_data_list_mnist, get_data_list_rotated_mnist, get_data_list_stl10

## MNIST

In [None]:
# 28x28 graph for MNIST, no static compression.
# NOTE(review): GraphData and GaussianKernel are not imported in this notebook.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=3,
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi / eps, xi, 1.),
)

In [None]:
processed_path = "../../data/MNIST/processed"  # relative to the kernel's working directory
data_list = get_data_list_mnist(
    graph_data,
    processed_path,
    train=True,
)

In [None]:
# NOTE(review): visualize_samples is not imported in this notebook -- add its import.
fig = visualize_samples(data_list)

## Rotated MNIST

In [None]:
# Identical 28x28 graph, rebuilt for the rotated-MNIST section.
# NOTE(review): GraphData and GaussianKernel are not imported in this notebook.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=3,
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi / eps, xi, 1.),
)

In [None]:
processed_path = "../../data/RotatedMNIST/processed"  # relative to the kernel's working directory
data_list = get_data_list_rotated_mnist(
    graph_data,
    processed_path,
    train=True,
)

In [None]:
# NOTE(review): visualize_samples is not imported in this notebook -- add its import.
fig = visualize_samples(data_list)

## STL10

In [None]:
# 96x96 graph for STL10 with node compression and the rescaled (xi_, eps_) parameters.
# NOTE(review): GraphData and GaussianKernel are not imported in this notebook.
graph_data = GraphData(
    grid_size=(96, 96),
    num_layers=3,
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    static_compression=("node", 0.5),
    sigmas=(xi_ / eps_, xi_, 1.),
)

In [None]:
# NOTE(review): visualize_weight_fields is not imported in this notebook -- add its import.
fig = visualize_weight_fields(graph_data)

In [None]:
processed_path = "../../data/stl10/processed"  # relative to the kernel's working directory
data_list = get_data_list_stl10(
    graph_data,
    processed_path,
    train=True,
)

In [None]:
# NOTE(review): visualize_samples is not imported in this notebook -- add its import.
fig = visualize_samples(data_list)