In [None]:
# PyTorch plus the sparse/scatter extension packages and torchvision.
import torch
import torch_sparse
import torch_scatter
import torchvision

# Last expression of the cell: display the installed versions of the three
# packages the graph code relies on.
torch.__version__, torch_sparse.__version__, torch_scatter.__version__

In [None]:
from gechebnet.graph.graph import HyperCubeGraph
from gechebnet.graph.plot import visualize_graph, visualize_neighborhood
from gechebnet.graph.utils import is_undirected

In [None]:
from torch_sparse import coalesce, transpose

# Define anisotropy parameters

In [None]:
# Anisotropy parameters; the `sigmas` tuples of the graphs below are built
# from these two values (directly or through the rescaled xi_ / eps_).
xi, eps = 0.01, 0.1

# Create a graph

In [None]:
# Build the first demo graph on a 5x5 grid, with `sigmas` derived from the
# anisotropy parameters xi and eps.
# NOTE(review): this call passes `nx3`, `weight_sigma` and `knn`, whereas the
# later HyperCubeGraph cells pass `equiv_axis_size` and `weight_kernel` --
# confirm which keywords the installed gechebnet version accepts.
graph = HyperCubeGraph(
    grid_size=(5, 5),
    nx3=5,
    sigmas=(xi / eps, xi, 1.0),
    weight_sigma=1.0,
    knn=26,
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # cell still runs on machines without CUDA.
    weight_comp_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

In [None]:
# Display the result of is_undirected for the graph's edge index, edge weights
# and node count (presumably True when every edge has a matching reverse edge
# -- confirm against gechebnet.graph.utils).
is_undirected(graph.edge_index, graph.edge_weight, graph.num_nodes)

# Visualize graph's nodes

In [None]:
# Draw the graph built above; the returned object is kept in `fig`.
fig = visualize_graph(graph)

# Visualize neighborhood

In [None]:
# Draw the neighborhood for the second argument, 211 (presumably a node index).
# NOTE(review): with grid_size=(5, 5) and nx3=5 the graph presumably has
# 5*5*5 = 125 nodes, so index 211 may be out of range -- verify.
fig = visualize_neighborhood(graph, 211)

In [None]:
# Same plot for the second argument 2 (presumably a node index); `fig` is
# overwritten with the new result.
fig = visualize_neighborhood(graph, 2)

# One can choose another weight kernel

In [None]:
# Laplacian kernel: rebuild `graph` on a 10x10 grid with
# weight_kernel="laplacian".
graph = HyperCubeGraph(
    grid_size=(10, 10),
    equiv_axis_size=10,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel="laplacian",
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # cell still runs on machines without CUDA.
    weight_comp_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Draw the neighborhood for the index stored in graph.centroid_index.
fig = visualize_neighborhood(graph, graph.centroid_index)

In [None]:
# Cauchy kernel: rebuild `graph` on a 10x10 grid with weight_kernel="cauchy".
graph = HyperCubeGraph(
    grid_size=(10, 10),
    equiv_axis_size=10,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel="cauchy",
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # cell still runs on machines without CUDA.
    weight_comp_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Draw the neighborhood for the index stored in graph.centroid_index.
fig = visualize_neighborhood(graph, graph.centroid_index)

# Compress graph

In [None]:
# Ratio 96/28; 96 and 28 are the two grid sizes used by the cells below.
alpha = 96 / 28
# xi divided by the squared ratio; eps is carried over unchanged.
xi_ = xi / alpha ** 2
eps_ = eps

## Node compression

In [None]:
# Graph data on a 28x28 grid with static "node" compression at 0.5.
# NOTE(review): `GraphData` and `GaussianKernel` are not imported in any cell
# shown in this notebook -- confirm an import exists, otherwise a fresh kernel
# raises NameError.
# NOTE(review): xi_ is xi rescaled by (96/28)**2, yet this grid is 28x28 --
# confirm that the rescaled values are intended here.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=6,
    static_compression=("node", 0.5),
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi_ / eps_, xi_, 1.0),
)

In [None]:
# Draw the node-compressed graph data built above; the result is kept in `fig`.
fig = visualize_graph(graph_data)

In [None]:
# Call visualize_weight_fields on the current graph_data and keep the result.
# NOTE(review): `visualize_weight_fields` is not imported in any cell shown in
# this notebook -- confirm an import exists, otherwise a fresh kernel raises
# NameError.
fig = visualize_weight_fields(graph_data)

## Edge compression

In [None]:
# Graph data on a 96x96 grid with static "edge" compression at 0.5, using the
# rescaled anisotropy parameters xi_ / eps_.
# NOTE(review): `GraphData` and `GaussianKernel` are not imported in any cell
# shown in this notebook -- confirm an import exists, otherwise a fresh kernel
# raises NameError.
graph_data = GraphData(
    grid_size=(96, 96),
    num_layers=6,
    static_compression=("edge", 0.5),
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi_ / eps_, xi_, 1.0),
)

In [None]:
# Draw the edge-compressed graph data built above; the result is kept in `fig`.
fig = visualize_graph(graph_data)

In [None]:
# Call visualize_weight_fields on the current graph_data and keep the result.
# NOTE(review): `visualize_weight_fields` is not imported in any cell shown in
# this notebook -- confirm an import exists, otherwise a fresh kernel raises
# NameError.
fig = visualize_weight_fields(graph_data)

# Image embedding

In [None]:
from gechebnet.data.dataloader import get_data_list_mnist, get_data_list_rotated_mnist, get_data_list_stl10

## MNIST

In [None]:
# Graph data on a 28x28 grid for the MNIST section, built from the
# unscaled xi / eps and without static compression.
# NOTE(review): `GraphData` and `GaussianKernel` are not imported in any cell
# shown in this notebook -- confirm an import exists, otherwise a fresh kernel
# raises NameError.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=3,
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi / eps, xi, 1.0),
)

In [None]:
# Directory handed to the MNIST loader, relative to the notebook's working
# directory.
processed_path = "../../data/MNIST/processed"
# Build the data list from the current graph_data with train=True.
data_list = get_data_list_mnist(
    graph_data,
    processed_path,
    train=True,
)

In [None]:
# Call visualize_samples on the MNIST data list and keep the result in `fig`.
# NOTE(review): `visualize_samples` is not imported in any cell shown in this
# notebook -- confirm an import exists, otherwise a fresh kernel raises
# NameError.
fig = visualize_samples(data_list)

## Rotated MNIST

In [None]:
# Graph data on a 28x28 grid for the Rotated MNIST section, built from the
# unscaled xi / eps and without static compression.
# NOTE(review): `GraphData` and `GaussianKernel` are not imported in any cell
# shown in this notebook -- confirm an import exists, otherwise a fresh kernel
# raises NameError.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=3,
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    sigmas=(xi / eps, xi, 1.0),
)

In [None]:
# Directory handed to the Rotated MNIST loader, relative to the notebook's
# working directory.
processed_path = "../../data/RotatedMNIST/processed"
# Build the data list from the current graph_data with train=True.
data_list = get_data_list_rotated_mnist(
    graph_data,
    processed_path,
    train=True,
)

In [None]:
# Call visualize_samples on the Rotated MNIST data list and keep the result.
# NOTE(review): `visualize_samples` is not imported in any cell shown in this
# notebook -- confirm an import exists, otherwise a fresh kernel raises
# NameError.
fig = visualize_samples(data_list)

## STL10

In [None]:
# Graph data on a 96x96 grid for the STL10 section, with static "node"
# compression at 0.5 and the rescaled anisotropy parameters xi_ / eps_.
# NOTE(review): `GraphData` and `GaussianKernel` are not imported in any cell
# shown in this notebook -- confirm an import exists, otherwise a fresh kernel
# raises NameError.
graph_data = GraphData(
    grid_size=(96, 96),
    num_layers=3,
    self_loop=True,
    weight_kernel=GaussianKernel(0.3, 1.0),
    static_compression=("node", 0.5),
    sigmas=(xi_ / eps_, xi_, 1.0),
)

In [None]:
# Call visualize_weight_fields on the STL10 graph_data and keep the result.
# NOTE(review): `visualize_weight_fields` is not imported in any cell shown in
# this notebook -- confirm an import exists, otherwise a fresh kernel raises
# NameError.
fig = visualize_weight_fields(graph_data)

In [None]:
# Directory handed to the STL10 loader, relative to the notebook's working
# directory.
processed_path = "../../data/stl10/processed"
# Build the data list from the current graph_data with train=True.
data_list = get_data_list_stl10(
    graph_data,
    processed_path,
    train=True,
)

In [None]:
# Call visualize_samples on the STL10 data list and keep the result in `fig`.
# NOTE(review): `visualize_samples` is not imported in any cell shown in this
# notebook -- confirm an import exists, otherwise a fresh kernel raises
# NameError.
fig = visualize_samples(data_list)