In [None]:
import torch
import numpy as np
import math

In [None]:
# Compute device for the experiments below; fall back to CPU so the notebook still
# runs (more slowly) on machines without a CUDA GPU instead of failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def get_memory_size(input, sparse=False):
    """Return the memory footprint of a tensor in gigabytes (1 GB = 1e9 bytes).

    Args:
        input: the tensor to measure.
        sparse: if True, treat ``input`` as a sparse COO tensor and count the bytes of
            its stored values and indices (plus a fixed 8 bytes) instead of the dense size.

    Returns:
        Size in GB as a float.
    """
    if not sparse:
        # Dense layout: bytes per element times number of elements.
        return input.element_size() * input.nelement() * 1e-9  # in GB

    values, indices = input._values(), input._indices()
    # NOTE(review): the extra 8 bytes is presumably a fixed metadata overhead — confirm.
    n_bytes = (values.element_size() * values.nelement()
               + indices.element_size() * indices.nelement()
               + 8)
    return n_bytes * 1e-9

def get_sparsity(input):
    """Return the fraction of entries of a sparse COO tensor that are not stored.

    Computed as ``1 - nnz / numel``, where ``numel`` is the element count of the full
    (dense) shape. Explicitly stored zeros still count as stored entries.

    Args:
        input: a sparse COO tensor (``_nnz`` only exists for sparse tensors).

    Returns:
        Sparsity as a float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if ``input`` has no elements.
    """
    # An uncoalesced tensor may store the same index several times and ``_nnz`` counts
    # every stored entry, which over-counts non-zeros (sparsity can even go negative).
    # Coalescing sums duplicates into a single entry first.
    if not input.is_coalesced():
        input = input.coalesce()
    return 1 - input._nnz() / input.nelement()

# Chebyshev convolution

In [None]:
from gechebnet.graphs.graphs import SE2GEGraph, RandomSubGraph
from gechebnet.nn.layers.convs import ChebConv
from gechebnet.liegroups.se2 import se2_uniform_sampling
from gechebnet.nn.models.chebnets import WideResGEChebNet

In [None]:
# Graph / filter hyper-parameters.
# NOTE(review): K is presumably the graph's neighbourhood size and kernel_size the
# Chebyshev filter order — confirm against SE2GEGraph / WideResGEChebNet.
K, kernel_size = 8, 4
# Anisotropy parameters of the SE(2) metric (used in `sigmas` below).
eps = 0.1
xi = 2.048 / 28 ** 2
# Number of sampled orientations.
ntheta = 6

In [None]:
# Sample the SE(2) group (presumably a 28 x 28 spatial grid with `ntheta` orientations —
# confirm se2_uniform_sampling's argument order) and build the group-equivariant graph.
# NOTE(review): sigmas here is ordered (1.0, eps, xi) while the multi-level graphs
# further down use (xi / eps, xi, 1.0) — confirm which convention SE2GEGraph expects.
# NOTE(review): path_to_graph presumably caches the graph on disk — confirm.
uniform_sampling = se2_uniform_sampling(28, 28, ntheta)
graph = SE2GEGraph(
    uniform_sampling,
    K=K,
    sigmas=(1.0, eps, xi),
    path_to_graph="saved_graphs",
)
sub_graph = RandomSubGraph(graph)

# Builds the group equivariant Chebnet on the sub-graph (only the first graph level is
# given; pool and the lower levels are disabled). Weights are loaded in a later cell.
model = WideResGEChebNet(
    in_channels=1,
    out_channels=10,
    kernel_size=kernel_size,
    pool=None,
    graph_lvl0=sub_graph,
    graph_lvl1=None,
    graph_lvl2=None,
    depth=8,
    widen_factor=2,
)


In [None]:
# Load trained weights into the model built above.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine;
# load_state_dict then copies the tensors into the model's own (CPU) parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent torch versions accept weights_only=True for plain state dicts).
# NOTE(review): the doubled "models/models/" directory looks suspicious — confirm the path.
model.load_state_dict(torch.load("models/models/model_17.pt", map_location="cpu"))

In [None]:
# Display the first convolution of the first layer of block0 (e.g. to check its channel sizes).
model.block0.layers[0].conv1

In [None]:
# One-hot signal on node 100 for each of 64 channels over 6 * 28 * 28 nodes.
# NOTE(review): `impulse` is rebound by the delta_kronecker cell below, so this value
# is never used — delete one of the two definitions.
impulse = torch.zeros(1, 64, 6 * 28 * 28)
impulse[..., 100] = 1.0

In [None]:
from gechebnet.utils.utils import delta_kronecker
import matplotlib.cm as cm
import matplotlib.pyplot as plt

In [None]:
# Single-channel impulse of shape (1, 1, 6*28*28) at index (0, 0, 3*28*28)
# (presumably a one-hot tensor — confirm delta_kronecker's semantics). This rebinds
# the `impulse` defined two cells above.
# NOTE(review): if nodes are stored orientation-major (28*28 nodes per orientation),
# this is the first spatial node of orientation 3 — confirm.
impulse = delta_kronecker((1, 1, 6*28*28), (0,0, 3 * 28 * 28))

In [None]:
# Response of the stem convolution to the impulse. Inference only: no_grad avoids
# building an autograd graph, matching how plot_filters evaluates filters.
with torch.no_grad():
    out = model.conv(impulse)

In [None]:
def _impulse_response(filter, impulse):
    """Apply ``filter`` to ``impulse`` in eval mode without autograd; return the output on CPU."""
    modules = list(filter.modules()) if isinstance(filter, torch.nn.Module) else []
    saved_modes = [module.training for module in modules]
    if modules:
        # Run on the filter's device so GPU-resident models work too.
        param = next(filter.parameters(), None)
        if param is not None:
            impulse = impulse.to(param.device)
        # In training mode, layers such as batch norm would use batch statistics and
        # update their running buffers as a side effect of plotting.
        filter.eval()
    try:
        with torch.no_grad():
            return filter(impulse).cpu()
    finally:
        # Restore every submodule's own train/eval flag, not just the root's.
        for module, mode in zip(modules, saved_modes):
            module.training = mode


def plot_filters(graph, filter, indices, in_channels, impulse_index=100):
    """Plot the response of ``filter`` to a single-node impulse on ``graph``.

    A unit impulse is placed on node ``impulse_index`` of every input channel and the
    filter is applied once. The output channels listed in ``indices`` are drawn as
    scatter plots over the graph's (X, Y) node positions, with one column per output
    channel. There is one row per slice ``l`` of ``graph.size[0]``; each slice is the
    block of ``prod(graph.size[1:])`` consecutive nodes.

    Args:
        graph: graph object exposing ``size``, ``num_nodes`` and ``cartesian_pos()``.
        filter: callable (typically an ``nn.Module``) applied to a
            ``(1, in_channels, graph.num_nodes)`` tensor. Its output is indexed as
            ``output[0, k, node]``. It is evaluated in eval mode; the original
            train/eval flags are restored afterwards.
        indices: output-channel indices to plot.
        in_channels: number of input channels expected by ``filter``.
        impulse_index: node receiving the unit impulse (default 100, as before).
    """
    nrows = graph.size[0]
    nodes_per_row = int(np.prod(graph.size[1:]))
    ncols = len(indices)

    impulse = torch.zeros(1, in_channels, graph.num_nodes)
    impulse[:, :, impulse_index] = 1.0
    output = _impulse_response(filter, impulse)

    fig = plt.figure(figsize=(4 * ncols, 4 * nrows))
    X, Y, _ = graph.cartesian_pos()

    for col, k in enumerate(indices):
        response = output[0, k]
        # PiYG is a diverging colormap: use limits symmetric around 0 so a zero
        # response maps to the neutral colour. The limits are shared across all rows
        # of this channel so its slices are directly comparable.
        vmax = float(response.abs().max())
        for row in range(nrows):
            nodes = slice(row * nodes_per_row, (row + 1) * nodes_per_row)
            ax = fig.add_subplot(nrows, ncols, row * ncols + col + 1)
            ax.scatter(X[nodes], Y[nodes], c=response[nodes], cmap=cm.PiYG, vmin=-vmax, vmax=vmax)
            ax.axis("off")

    fig.tight_layout()


In [None]:
# Impulse responses of the first 12 output channels of the stem convolution (1 input channel).
plot_filters(sub_graph, model.conv, np.arange(12), 1)

In [None]:
# NOTE(review): in a standard Wide-ResNet layout block0 receives the fewest input
# channels and block2 the most; the values used here (block2=16 ... block0=64) look
# reversed — confirm against model.block*.layers[0].conv1.
plot_filters(sub_graph, model.block2, np.arange(12), 16)

In [None]:
# Impulse responses of block1 (32 input channels — see the note on block2 above).
plot_filters(sub_graph, model.block1, np.arange(12), 32)

In [None]:
# NOTE(review): see the channel-order note on block2 — 64 may belong to block2 instead.
plot_filters(sub_graph, model.block0, np.arange(12), 64)

In [None]:
# NOTE(review): dx, dy and dz are not defined anywhere in this notebook, so this and the
# next three scratch cells raise NameError on a fresh kernel — define them or delete.
min(dx ** 2, dy ** 2)/dz**2

In [None]:
# NOTE(review): rebinds the global eps (same value, 0.1, as the config cell).
eps = 0.1
dz ** 2 / (9 * (dx ** 2)), dz ** 2 / (dx ** 2)

In [None]:
# Uses the undefined dx, dy, dz (see above).
dz ** 2 / (9 * (dx ** 2) + 10 * dy ** 2)

In [None]:
# Uses the undefined dx, dy, dz (see above).
dz ** 2 /  max(dx**2, dy**2)

In [None]:
eps, xi = .1, 20.

# Anisotropic metric and edge-weight kernel shared by every graph level below
# (previously repeated verbatim in each call).
se2_sigmas = (xi / eps, xi, 1.0)


def exp_weight_kernel(sqdistc, sigmac):
    """Edge weight ``exp(-sqdistc / sigmac)`` computed from squared distances."""
    return torch.exp(-sqdistc / sigmac)


uniform_sampling = se2_uniform_sampling(3, 3, 6)
# The original positional call was missing a comma after `uniform_sampling`
# (SyntaxError, so this whole cell never ran). Keywords follow the SE2GEGraph call used
# for the 28x28 graph earlier in the notebook; the positional order is assumed to be
# K, sigmas, path_to_graph — TODO confirm.
graph_lvl0 = SE2GEGraph(
    uniform_sampling,
    K=16,
    sigmas=se2_sigmas,
    path_to_graph="saved_graphs",
)
# NOTE(review): the graphs below use a different keyword API (nx/ny/ntheta/
# weight_kernel) from the call above — confirm SE2GEGraph still accepts it.
graph_lvl4 = SE2GEGraph(
    nx=4,
    ny=4,
    ntheta=3,
    K=8,
    sigmas=se2_sigmas,
    weight_kernel=exp_weight_kernel,
)
graph_lvl3 = SE2GEGraph(
    nx=2,
    ny=2,
    ntheta=3,
    K=8,
    sigmas=se2_sigmas,
    weight_kernel=exp_weight_kernel,
)
graph_lvl2 = SE2GEGraph(
    nx=2,
    ny=2,
    ntheta=3,
    K=8,
    sigmas=se2_sigmas,
    weight_kernel=exp_weight_kernel,
)
graph_lvl1 = SE2GEGraph(
    nx=2,
    ny=2,
    ntheta=3,
    K=8,
    sigmas=se2_sigmas,
    weight_kernel=exp_weight_kernel,
)
# NOTE(review): this rebinds graph_lvl0 and discards the graph built from
# `uniform_sampling` above — keep only the intended one.
graph_lvl0 = SE2GEGraph(
    nx=3,
    ny=3,
    ntheta=6,
    K=16,
    sigmas=se2_sigmas,
    weight_kernel=exp_weight_kernel,
)

In [None]:
# Wrap the level-1..3 graphs in RandomSubGraph (presumably to allow random
# sub-sampling, as with `sub_graph.node_sampling` further down — confirm).
sub_graph_lvl1 = RandomSubGraph(graph_lvl1)
sub_graph_lvl2 = RandomSubGraph(graph_lvl2)
sub_graph_lvl3 = RandomSubGraph(graph_lvl3)

In [None]:
# Chebyshev convolution on the level-1 sub-graph, moved to the compute device.
# NOTE(review): positional args are presumably in_channels=1, out_channels=1,
# kernel_size=2 — confirm against ChebConv's signature.
cheb_conv = ChebConv(sub_graph_lvl1, 1, 1, 2)
cheb_conv = cheb_conv.to(device)

In [None]:
# Random test signal for cheb_conv: 1 sample, 1 channel, one value per node of the
# graph cheb_conv was built on. The previous hard-coded 8*8*3 nodes did not match
# graph_lvl1, which is presumably nx*ny*ntheta = 2*2*3 nodes. The tensor is allocated
# directly on `device` instead of on the CPU followed by a copy.
x = torch.rand(1, 1, sub_graph_lvl1.num_nodes, device=device)

In [None]:
# Sanity check: shape of the Chebyshev convolution's output on x.
cheb_conv(x).shape

In [None]:
# Drop attributes cached on the sub-graph (if present) before re-sampling its nodes,
# so they are not reused for the new sampling.
for cached_attr in ("laplacian", "node_proj"):
    if hasattr(sub_graph, cached_attr):
        delattr(sub_graph, cached_attr)

sub_graph.node_sampling(0.5)

In [None]:
# Project the signal onto the (re-sampled) sub-graph and convolve it. The result goes to
# a new name so the raw `x` is kept and the cell can be re-run. The original
# `x = sub_graph.project(x)` re-projected an already projected tensor on each re-run.
# NOTE(review): cheb_conv was built on sub_graph_lvl1 but x is projected with
# sub_graph (the 28x28 graph) — confirm the intended graph.
x_proj = sub_graph.project(x)
y = cheb_conv(x_proj)

## Wide Residual Group Equivariant ChebNet

In [None]:
from gechebnet.model.reschebnet import WideResGEChebNet

In [None]:
# NOTE(review): se2_graph is never defined in this notebook, and the keywords R= / graph=
# differ from the earlier WideResGEChebNet(kernel_size=..., graph_lvl0=...) call. This
# cell also overwrites the model whose weights were loaded with load_state_dict above.
model = WideResGEChebNet(in_channels=1, out_channels=10, R=5, graph=se2_graph, depth=26, widen_factor=2)
model = model.to(device)
model

In [None]:
# NOTE(review): `capacity` is not defined or imported anywhere in this notebook.
capacity(model)

In [None]:
# NOTE(review): x is whatever earlier cells left in it — confirm its node count matches
# the graph this model was built on.
model(x)