In [None]:
from gechebnet.graph.graph import GraphData
from gechebnet.graph.utils import GaussianKernel, CauchyKernel
from gechebnet.graph.plot import visualize_graph, visualize_weight_fields, visualize_samples, visualize_weight_field

# Define anisotropy parameters

`xi` and `eps` are passed to `GraphData` through `sigmas=(xi/eps, xi, 1.)`.

In [None]:
# Anisotropy parameters, consumed by the GraphData cells as
# sigmas=(xi / eps, xi, 1.).
xi, eps = 0.05, 0.1

# Create a graph

In [None]:
# Baseline graph: 28x28 grid, 9 layers, self-loops, no explicit weight_kernel.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=9,
    sigmas=(xi / eps, xi, 1.0),
    self_loop=True,
)

# Visualize the graph's nodes

In [None]:
# Plot the nodes of `graph_data`; the returned figure is kept in `fig`.
fig = visualize_graph(graph_data)

# Plot weight fields

In [None]:
# Plot the weight fields of `graph_data`; the returned figure is kept in `fig`.
fig = visualize_weight_fields(graph_data)

# Plot a specific weight field

In [None]:
# Select the node at the spatial centre of layer `layer` (0-based) and plot
# its weight field.
# NOTE(review): assumes GraphData flattens node indices with x fastest, then y,
# then layer, i.e. index = x + y * nx1 + layer * nx1 * nx2 — confirm against GraphData.
layer = 2
center_x = graph_data.nx1 // 2  # exact integer halving (no float round-trip)
center_y = graph_data.nx2 // 2
node_index = center_x + center_y * graph_data.nx1 + layer * graph_data.nx1 * graph_data.nx2
fig = visualize_weight_field(graph_data, node_index)

# Use a different weight kernel

## Gaussian Kernel

In [None]:
# 28x28 grid, 9 layers, self-loops, edges weighted with a Gaussian kernel.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=9,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=GaussianKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Plot the weight fields of the Gaussian-kernel graph.
fig = visualize_weight_fields(graph_data)

## Cauchy Kernel

In [None]:
# 28x28 grid, 9 layers, self-loops, edges weighted with a Cauchy kernel.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=9,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=CauchyKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Plot the weight fields of the Cauchy-kernel graph.
fig = visualize_weight_fields(graph_data)

# Compress the graph

In [None]:
# Rescaled anisotropy parameters for larger grids. `alpha` is the ratio between
# the 96x96 grid (used for STL10 below) and the 28x28 grid used for MNIST;
# xi is divided by alpha**2 while eps is kept unchanged.
REFERENCE_GRID_SIZE = 28
LARGE_GRID_SIZE = 96
alpha = LARGE_GRID_SIZE / REFERENCE_GRID_SIZE
xi_ = xi / alpha**2
eps_ = eps

## Node compression

In [None]:
# Node compression: static_compression=("node", 0.5) on a 6-layer graph with a
# Gaussian kernel and the rescaled parameters xi_ / eps_.
# NOTE(review): xi_ was rescaled for a 96x96 grid (alpha = 96/28), yet this graph
# uses a 28x28 grid — confirm whether grid_size should be (96, 96) as in the
# edge-compression and STL10 cells.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=6,
    sigmas=(xi_ / eps_, xi_, 1.0),
    static_compression=("node", 0.5),
    weight_kernel=GaussianKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Plot the nodes of the node-compressed graph.
fig = visualize_graph(graph_data)

In [None]:
# Plot the weight fields of the node-compressed graph.
fig = visualize_weight_fields(graph_data)

## Edge compression

In [None]:
# Edge compression: static_compression=("edge", 0.5) on a 96x96, 6-layer graph
# with a Gaussian kernel and the rescaled parameters xi_ / eps_.
graph_data = GraphData(
    grid_size=(96, 96),
    num_layers=6,
    sigmas=(xi_ / eps_, xi_, 1.0),
    static_compression=("edge", 0.5),
    weight_kernel=GaussianKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Plot the nodes of the edge-compressed graph.
fig = visualize_graph(graph_data)

In [None]:
# Plot the weight fields of the edge-compressed graph.
fig = visualize_weight_fields(graph_data)

# Image embedding

In [None]:
from gechebnet.data.dataloader import get_data_list_mnist, get_data_list_rotated_mnist, get_data_list_stl10

## MNIST

In [None]:
# MNIST graph: 28x28 grid, 3 layers, Gaussian kernel, self-loops.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=3,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=GaussianKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Build the MNIST training data list for `graph_data` from the processed directory.
# NOTE(review): the path is relative to the notebook's working directory.
processed_path = "../../data/MNIST/processed"
data_list = get_data_list_mnist(graph_data, processed_path, train=True)

In [None]:
# Show samples from the MNIST data list.
fig = visualize_samples(data_list)

## Rotated MNIST

In [None]:
# Rotated-MNIST graph: 28x28 grid, 3 layers, Gaussian kernel, self-loops.
graph_data = GraphData(
    grid_size=(28, 28),
    num_layers=3,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=GaussianKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Build the rotated-MNIST training data list for `graph_data` from the processed directory.
# NOTE(review): the path is relative to the notebook's working directory.
processed_path = "../../data/RotatedMNIST/processed"
data_list = get_data_list_rotated_mnist(graph_data, processed_path, train=True)

In [None]:
# Show samples from the rotated-MNIST data list.
fig = visualize_samples(data_list)

## STL10

In [None]:
# STL10 graph: 96x96 grid, 3 layers, Gaussian kernel, node compression
# (static_compression=("node", 0.5)) and the rescaled parameters xi_ / eps_.
graph_data = GraphData(
    grid_size=(96, 96),
    num_layers=3,
    sigmas=(xi_ / eps_, xi_, 1.0),
    static_compression=("node", 0.5),
    weight_kernel=GaussianKernel(0.3, 1.0),
    self_loop=True,
)

In [None]:
# Plot the weight fields of the STL10 graph.
fig = visualize_weight_fields(graph_data)

In [None]:
# Build the STL10 training data list for `graph_data` from the processed directory.
# NOTE(review): the path is relative to the notebook's working directory.
processed_path = "../../data/stl10/processed"
data_list = get_data_list_stl10(graph_data, processed_path, train=True)

In [None]:
# Show samples from the STL10 data list.
fig = visualize_samples(data_list)