In [None]:
import torch
import numpy as np
import math

In [None]:
# Use the GPU when present, otherwise fall back to CPU so the notebook still runs end to end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG early so torch.rand inputs below are reproducible
# (and presumably the random sub-graph sampling too, if it draws from torch's RNG — confirm).
SEED = 0
torch.manual_seed(SEED)

In [None]:
def get_memory_size(input, sparse=False):
    """Estimate the storage footprint of a tensor in (decimal) gigabytes.

    Args:
        input: tensor to measure. When ``sparse`` is true it must be a sparse
            COO tensor (its ``_values()`` and ``_indices()`` are read).
        sparse: if true, count the bytes of the stored values and indices
            instead of the dense ``element_size * nelement``.

    Returns:
        float: size in GB (bytes * 1e-9).
    """
    if not sparse:
        return input.element_size() * input.nelement() * 1e-9  # in GB
    values = input._values()
    indices = input._indices()
    # NOTE(review): the extra 8 bytes is an undocumented fixed overhead term — confirm its origin.
    n_bytes = (
        values.element_size() * values.nelement()
        + indices.element_size() * indices.nelement()
        + 8
    )
    return n_bytes * 1e-9

def get_sparsity(input):
    """Return the fraction of positions of a sparse COO tensor that are not stored.

    Args:
        input: sparse COO tensor.

    Returns:
        float: ``1 - nnz / numel``, where nnz counts each stored position once.
    """
    # _nnz() counts stored (index, value) pairs, which includes duplicate indices
    # when the tensor is uncoalesced; coalescing merges them so sparsity isn't understated.
    tensor = input if input.is_coalesced() else input.coalesce()
    return 1 - tensor._nnz() / tensor.nelement()

# Cheb Conv

In [None]:
from gechebnet.graph.graph import SE2GEGraph, RandomSubGraph
from gechebnet.model.convolution import ChebConv

In [None]:
# Scale parameters for the graph's edge-weight kernel; with xi == eps == 1.0
# all three entries of `sigmas` below equal 1.0.
xi = 1.0
eps = 1.0
graph = SE2GEGraph(
    nx=96, ny=96, ntheta=6,  # presumably 96 x 96 positions x 6 orientations — confirm
    knn=16,  # presumably neighbours kept per node — confirm
    sigmas=(xi / eps, xi, 1.0),
    # Edge weight exp(-sqdistc / sigmac); argument names suggest a squared distance and a scale.
    weight_kernel=lambda sqdistc, sigmac: torch.exp(-sqdistc / sigmac),
)

In [None]:
# Wrap the full graph in a RandomSubGraph; the cells below resample its edges/nodes
# (edge_sampling / node_sampling) and map signals onto it with project().
sub_graph = RandomSubGraph(graph)

In [None]:
# Chebyshev graph convolution over the sub-graph, moved to `device`.
# NOTE(review): positional args presumably (in_channels=1, out_channels=1, K=2) — confirm against ChebConv.
cheb_conv = ChebConv(sub_graph, 1, 1, 2).to(device)

In [None]:
# Batch of 8 single-channel random signals with 96 * 96 * 6 entries each, matching the
# graph grid above (presumably layout (batch, channels, nodes) — confirm).
# Allocated directly on `device` rather than created on CPU and copied over.
x = torch.rand(8, 1, 96 * 96 * 6, device=device)

In [None]:
# Drop any cached Laplacian on the sub-graph before the first forward pass
# (presumably it is rebuilt lazily when next needed — confirm).
if hasattr(sub_graph, "laplacian"):
    delattr(sub_graph, "laplacian")

In [None]:
# Project onto the current sub-graph under a new name: overwriting `x` made this cell
# non-idempotent (re-running it would project an already-projected signal).
x_sub = sub_graph.project(x)
y = cheb_conv(x_sub)

In [None]:
# Invalidate the cached Laplacian, then resample the sub-graph's edges with rate 0.5
# (whether `rate` is a keep or drop fraction is not visible here — confirm).
if hasattr(sub_graph, "laplacian"):
    delattr(sub_graph, "laplacian")

sub_graph.edge_sampling(0.5)

In [None]:
# Project onto the edge-sampled sub-graph without overwriting `x`, so this cell can be
# re-run and `x` keeps holding the original full-size signal.
x_sub = sub_graph.project(x)
y = cheb_conv(x_sub)

In [None]:
# Clear the cached Laplacian and node projection, then resample the nodes with rate 0.5
# (both presumably get rebuilt for the new node set — confirm).
if hasattr(sub_graph, "laplacian"):
    delattr(sub_graph, "laplacian")
if hasattr(sub_graph, "node_proj"):
    delattr(sub_graph, "node_proj")

sub_graph.node_sampling(0.5)

In [None]:
# Project onto the node-sampled sub-graph under a new name. Overwriting `x` here would
# make the cell non-idempotent and leave a reduced-node `x` for the full-graph models below.
x_sub = sub_graph.project(x)
y = cheb_conv(x_sub)

## Wide Group Equivariant ChebNet

In [None]:
from gechebnet.model.chebnet import WideGEChebNet

In [None]:
# Build on `graph` (the SE2GEGraph defined above): `se2_graph` is never defined in this
# notebook, so the original cell raised NameError on a fresh kernel.
model = WideGEChebNet(in_channels=1, out_channels=10, K=2, graph=graph, depth=8, widen_factor=2)
model = model.to(device)

In [None]:
# Display the model's `capacity` attribute (defined by gechebnet; presumably a parameter count — confirm).
model.capacity

In [None]:
# Forward pass on the input batch `x`.
# NOTE(review): assumes the node dimension of `x` matches the number of nodes in model.graph — confirm.
model(x)

In [None]:
# Inspect the Laplacian currently held by the model's graph.
model.graph.laplacian

In [None]:
# `se2_graph` is never defined in this notebook (NameError on a fresh kernel); use the
# SE2GEGraph built above. NOTE(review): confirm SE2GEGraph still provides
# set_sparse_laplacian — the sampling API used earlier is RandomSubGraph.edge_sampling.
graph.set_sparse_laplacian(on="edges", rate=0.4, norm=True, device=device)

In [None]:
# Inspect the Laplacian again after the sparsification call above, to check the change
# is visible through model.graph.
model.graph.laplacian

In [None]:
# Display the model's repr (presumably a torch.nn.Module, so this lists its layers).
model

## Wide Residual Group Equivariant ChebNet

In [None]:
from gechebnet.model.reschebnet import WideResGEChebNet

In [None]:
# Build on `graph` (the SE2GEGraph defined above): `se2_graph` is never defined in this
# notebook, so the original cell raised NameError on a fresh kernel.
model = WideResGEChebNet(in_channels=1, out_channels=10, K=5, graph=graph, depth=26, widen_factor=2)
model = model.to(device)
model

In [None]:
# Display the residual model's `capacity` attribute (defined by gechebnet; presumably a parameter count — confirm).
model.capacity

In [None]:
# Forward pass of the residual model on the input batch `x`.
# NOTE(review): assumes the node dimension of `x` matches the number of nodes in model.graph — confirm.
model(x)