In [None]:
import torch
import numpy as np
import math

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still survives Restart & Run All on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so the random inputs (torch.rand below) are reproducible.
# numpy is seeded too in case the random graph sampling draws from it (TODO confirm).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def get_memory_size(input, sparse=False):
    """Return the storage footprint of a tensor in gigabytes (1 GB = 1e9 bytes).

    Args:
        input: the tensor to measure.
        sparse: when True, ``input`` is treated as a sparse tensor exposing
            ``_values()`` / ``_indices()``. Its size is the bytes of the stored
            values plus the bytes of the indices plus a fixed 8-byte term.

    Returns:
        The size in GB as a float.
    """
    if not sparse:
        return input.element_size() * input.nelement() * 1e-9  # in GB

    stored_values = input._values()
    stored_indices = input._indices()
    total_bytes = (
        stored_values.element_size() * stored_values.nelement()
        + stored_indices.element_size() * stored_indices.nelement()
        + 8
    )
    return total_bytes * 1e-9

def get_sparsity(input):
    """Return the fraction of entries of ``input`` that are zero / not stored.

    For any sparse layout, the count of stored entries comes from ``_nnz()``
    (as before). Dense (strided) tensors have no ``_nnz``, so their nonzero
    elements are counted with ``torch.count_nonzero`` instead.

    Args:
        input: a dense or sparse tensor with at least one element.

    Returns:
        ``1 - nonzero / nelement`` as a float.
    """
    if input.layout == torch.strided:
        nonzero = int(torch.count_nonzero(input))
    else:
        nonzero = input._nnz()
    return 1 - nonzero / input.nelement()

# Chebyshev Convolution (`ChebConv`)

In [None]:
from gechebnet.graph.graph import SE2GEGraph, RandomSubGraph
from gechebnet.model.convolution import ChebConv

In [None]:
xi, eps = .1, 20.


def make_se2_graph(n, ntheta=3, K=8):
    """Build an n x n x ntheta SE2GEGraph with the kernel settings shared by every level.

    All levels use the same ``sigmas=(xi / eps, xi, 1.0)`` and an exponential
    weight kernel ``exp(-sqdistc / sigmac)``; only the spatial resolution changes.
    """
    return SE2GEGraph(
        nx=n,
        ny=n,
        ntheta=ntheta,
        K=K,
        sigmas=(xi / eps, xi, 1.0),
        weight_kernel=lambda sqdistc, sigmac: torch.exp(-sqdistc / sigmac),
    )


# Three resolution levels: 8x8, 4x4 and 2x2 spatial grids, 3 orientations each.
graph_lvl1 = make_se2_graph(8)
graph_lvl2 = make_se2_graph(4)
graph_lvl3 = make_se2_graph(2)

In [None]:
# Wrap each level's graph in a RandomSubGraph; these wrappers are resampled
# later (node_sampling / edge_sampling) and passed to the models.
sub_graph_lvl1, sub_graph_lvl2, sub_graph_lvl3 = (
    RandomSubGraph(level_graph) for level_graph in (graph_lvl1, graph_lvl2, graph_lvl3)
)

In [None]:
# Single Chebyshev convolution on the level-1 subgraph. The positional args
# (1, 1, 2) are presumably (in_channels, out_channels, polynomial order) — TODO confirm.
cheb_conv = ChebConv(sub_graph_lvl1, 1, 1, 2)
cheb_conv = cheb_conv.to(device)

In [None]:
# Random input signal; the last dimension 8 * 8 * 3 equals graph_lvl1's nx * ny * ntheta.
x = torch.rand((1, 1, 8 * 8 * 3)).to(device=device)

In [None]:
# Sanity check: output shape of one ChebConv forward pass on the full-graph input.
cheb_conv(x).shape

In [None]:
# Resample half of the nodes of the level-1 subgraph, the graph `cheb_conv` was built on.
# The cached `laplacian` / `node_proj` attributes are dropped first so they are not
# reused from the previous sampling (presumably recomputed lazily — TODO confirm).
for cached_attr in ("laplacian", "node_proj"):
    if hasattr(sub_graph_lvl1, cached_attr):
        delattr(sub_graph_lvl1, cached_attr)

sub_graph_lvl1.node_sampling(0.5)

In [None]:
# Project the full-graph signal onto the resampled level-1 subgraph (presumably keeping
# only the sampled nodes — TODO confirm), then run the convolution on it.
# A new name is used instead of overwriting `x`, so re-running this cell does not
# project an already-projected signal.
x_proj = sub_graph_lvl1.project(x)
y = cheb_conv(x_proj)

## Wide Group Equivariant ChebNet

In [None]:
from gechebnet.model.chebnet import WideGEChebNet

In [None]:
# Wide group-equivariant ChebNet over the three resolution levels, moved to `device`.
model = WideGEChebNet(
    in_channels=1,
    out_channels=10,
    R=2,
    graph_lvl1=sub_graph_lvl1,
    graph_lvl2=sub_graph_lvl2,
    graph_lvl3=sub_graph_lvl3,
    depth=8,
    widen_factor=2,
).to(device)

In [None]:
# Display the model's `capacity` attribute (presumably its trainable-parameter count — TODO confirm).
model.capacity

In [None]:
# Fresh full-resolution input (8 * 8 * 3 = graph_lvl1's nx * ny * ntheta) and a forward pass.
x = torch.rand((1, 1, 8 * 8 * 3)).to(device=device)
model(x)

In [None]:
# Drop the cached `laplacian` / `node_proj` attributes of the level-1 subgraph,
# then resample half of its edges.
for cached_attr in ("laplacian", "node_proj"):
    if hasattr(sub_graph_lvl1, cached_attr):
        delattr(sub_graph_lvl1, cached_attr)

sub_graph_lvl1.edge_sampling(0.5)

In [None]:
# Forward pass after edge-resampling sub_graph_lvl1. The model was built with
# sub_graph_lvl1, so it presumably sees the resampled graph — TODO confirm.
model(x)

In [None]:
# Inspect the Laplacian used by the model.
# NOTE(review): WideGEChebNet is constructed with graph_lvl1/2/3 keyword args;
# confirm it actually exposes a single `.graph` attribute, otherwise this raises AttributeError.
model.graph.laplacian

In [None]:
# Display the module structure of the WideGEChebNet model.
model

## Wide Residual Group Equivariant ChebNet

In [None]:
from gechebnet.model.reschebnet import WideResGEChebNet

In [None]:
# Wide residual group-equivariant ChebNet. `se2_graph` is not defined anywhere above,
# so the level-1 SE2GEGraph is used: its 8 * 8 * 3 nodes match the input `x`.
# NOTE(review): confirm graph_lvl1 (rather than a RandomSubGraph wrapper) is the intended graph.
model = WideResGEChebNet(in_channels=1, out_channels=10, R=5, graph=graph_lvl1, depth=26, widen_factor=2)
model = model.to(device)
model

In [None]:
# Display the residual model's `capacity` attribute (presumably its trainable-parameter count — TODO confirm).
model.capacity

In [None]:
# Forward pass of the residual model on the full-resolution input defined earlier.
model(x)