In [None]:
import torch
import numpy as np
import math

In [None]:
from gechebnet.model.utils import ResidualBlock


In [None]:
# Short alias for the imported ResidualBlock class.
# NOTE(review): `m` is not referenced by any of the visible cells — confirm it is still needed.
m = ResidualBlock

In [None]:
# Prefer the GPU when one is usable, otherwise fall back to the CPU so the
# notebook still runs top to bottom on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def get_memory_size(input, sparse=False):
    """Return the storage footprint of a tensor, scaled by 1e-9 (i.e. in GB).

    Args:
        input: tensor to measure. When ``sparse`` is True it must expose
            ``_values()`` and ``_indices()`` (a sparse COO tensor).
        sparse: if True, count the bytes of the stored values and of the
            stored indices plus a constant 8 bytes; if False, count
            ``element_size() * nelement()`` of ``input`` itself.

    Returns:
        float: the byte count multiplied by 1e-9.
    """

    def _nbytes(tensor):
        # Bytes occupied by the elements of a single tensor.
        return tensor.element_size() * tensor.nelement()

    if not sparse:
        return _nbytes(input) * 1e-9  # in GB

    # The extra 8 bytes are a fixed overhead term carried over unchanged;
    # NOTE(review): its exact origin is not documented here — confirm.
    total_bytes = _nbytes(input._values()) + _nbytes(input._indices()) + 8
    return total_bytes * 1e-9  # in GB

def get_sparsity(input):
    """Return the fraction of entries of a sparse tensor that are not stored.

    Args:
        input: tensor exposing ``_nnz()`` (number of stored entries) and
            ``nelement()`` (total number of entries).

    Returns:
        float: ``1 - nnz / nelement``.
    """
    stored = input._nnz()
    total = input.nelement()
    return 1 - stored / total

# Graph

In [None]:
from gechebnet.graph.graph import SE2GEGraph

In [None]:
# Scalars used to build the `sigmas` tuple passed to the graph below.
xi = 1.0
eps = 1.0


def _exp_weight_kernel(sqdistc, sigmac):
    """Edge weight exp(-sqdistc / sigmac)."""
    return torch.exp(-sqdistc / sigmac)


# SE(2) graph with nx=30, ny=30, ntheta=6 and knn=16.
# NOTE(review): presumably a 30x30 spatial grid with 6 orientations and
# 16 nearest neighbours per node — confirm against SE2GEGraph.
se2_graph = SE2GEGraph(
    nx=30,
    ny=30,
    ntheta=6,
    knn=16,
    sigmas=(xi / eps, xi, 1.0),
    weight_kernel=_exp_weight_kernel,
)

In [None]:
# Ask the graph to set up its Laplacian with norm=True on `device`.
# NOTE(review): presumably this populates `se2_graph.laplacian`, which later cells read — confirm.
se2_graph.set_laplacian(norm=True, device=device)

In [None]:
# Storage of the Laplacian in GB, counted from its sparse values and indices.
get_memory_size(se2_graph.laplacian, sparse=True)

In [None]:
# Fraction of Laplacian entries that are not explicitly stored.
get_sparsity(se2_graph.laplacian)

# Data

In [None]:
# Random input batch (uniform on [0, 1)), allocated directly on `device`.
# Shape is (15, 3, se2_graph.num_nodes) — presumably (batch, channels, nodes); TODO confirm.
x = torch.rand((15, 3, se2_graph.num_nodes), device=device)

In [None]:
# Dense storage of the input batch in GB.
get_memory_size(x)

## Wide Group Equivariant ChebNet

In [None]:
from gechebnet.model.chebnet import WideGEChebNet

In [None]:
# Instantiate the wide group-equivariant ChebNet on the SE(2) graph and move it to `device`.
# in_channels is read from the input batch instead of being hard-coded: `x` is created
# with size 3 along dimension 1, which the previous literal in_channels=1 did not match.
# NOTE(review): assumes `x` is laid out as (batch, channels, nodes) — confirm.
model = WideGEChebNet(
    in_channels=x.shape[1],
    out_channels=10,
    K=2,
    graph=se2_graph,
    depth=8,
    widen_factor=2,
).to(device)

In [None]:
# Display the model's `capacity` attribute (presumably its number of trainable parameters — confirm).
model.capacity

In [None]:
# Smoke-test forward pass on the random batch; the cell displays whatever the model returns.
model(x)

## Wide Residual Group Equivariant ChebNet

In [None]:
from gechebnet.model.reschebnet import WideResGEChebNet

In [None]:
# Instantiate the wide residual group-equivariant ChebNet and move it to `device`.
# in_channels is read from the input batch instead of being hard-coded: `x` is created
# with size 3 along dimension 1, which the previous literal in_channels=1 did not match.
# NOTE(review): assumes `x` is laid out as (batch, channels, nodes) — confirm.
model = WideResGEChebNet(
    in_channels=x.shape[1],
    out_channels=10,
    K=5,
    graph=se2_graph,
    depth=14,
    widen_factor=4,
).to(device)
# Last expression of the cell: show the module's repr.
model

In [None]:
# Display the residual model's `capacity` attribute (presumably its number of trainable parameters — confirm).
model.capacity

In [None]:
# Smoke-test forward pass of the residual model on the random batch; the cell displays whatever the model returns.
model(x)