In [None]:
import torch
import numpy as np
import math

#import pykeops
#pykeops.clean_pykeops()

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from gechebnet.graph.graph import SE2GEGraph

In [None]:
# xi and eps parameterize the `sigmas` triple passed to the graph below.
xi, eps = 1., 1.
se2_graph = SE2GEGraph(
    nx=30,
    ny=30,
    ntheta=10,
    knn=4,
    sigmas=(xi / eps, xi, 1.0),
    # Exponential kernel: weight = exp(-squared_distance / sigma).
    weight_kernel=lambda sqdistc, sigmac: torch.exp(-sqdistc / sigmac),
    kappa=0.,
    # Reuse the notebook-level device instead of hard-coding CUDA a second time.
    device=device
)

In [None]:
# Laplacian of the graph built above; it is read through the sparse-tensor API
# (`_indices`, `_values`, `_nnz`) by `sparsify_laplacian` later in the notebook.
laplacian = se2_graph.laplacian()

In [None]:
from gechebnet.graph.signal_processing import get_laplacian
from gechebnet.graph.utils import remove_directed_edges

In [None]:
from gechebnet.graph.utils import code_edges

In [None]:
def sparsify_laplacian(laplacian, kappa, on="edges"):
    """Randomly sparsify a sparse Laplacian by dropping edges or nodes.

    The off-diagonal entries of `laplacian` are negated to obtain edge weights
    and handed to `sparsify_edges` or `sparsify_nodes`, which rebuild a
    Laplacian from the retained edges.

    Args:
        laplacian: square sparse tensor exposing `_indices()` and `_values()`.
        kappa: sparsification rate forwarded unchanged to the helper.
        on: "edges" to sample edges, "nodes" to sample nodes.

    Returns:
        Whatever the selected helper returns (the result of `get_laplacian`).

    Raises:
        ValueError: if `on` is neither "edges" nor "nodes".
    """
    if on not in {"edges", "nodes"}:
        raise ValueError(f"{on} is not a valid value for on: must be 'edges' or 'nodes'.")

    num_nodes = laplacian.size(0)

    edge_index = laplacian._indices()
    edge_weight = -laplacian._values()
    mask = edge_index[0] != edge_index[1]  # mask corresponding to non-diagonal elements

    # Count the off-diagonal entries directly rather than assuming that every
    # diagonal element is stored (nnz - num_nodes), so that `num_edges` always
    # matches the masked arrays passed to the helpers.
    num_edges = int(mask.sum())

    if on == "edges":
        return sparsify_edges(edge_index[:, mask], edge_weight[mask], num_nodes, num_edges, kappa)

    # Node ids are taken from the row indices of all stored entries (diagonal included).
    node_index = edge_index[0].unique()
    return sparsify_nodes(edge_index[:, mask], edge_weight[mask], node_index, num_nodes, num_edges, kappa)

def sparsify_edges(edge_index, edge_weight, num_nodes, num_edges, kappa):
    """Keep a random subset of the unique edge codes and rebuild the Laplacian.

    Edges are grouped by the code returned by `code_edges`; all edges sharing a
    code are kept or dropped together. The fractional part of each unique code
    is used as its sampling weight (presumably `code_edges` stores the edge
    weight there - confirm against its implementation).

    Args:
        edge_index: (2, E) tensor of edge endpoints.
        edge_weight: (E,) tensor of edge weights.
        num_nodes: number of nodes, forwarded to `code_edges` and `get_laplacian`.
        num_edges: unused; kept for backward compatibility (the mask is now
            derived from the edge codes themselves).
        kappa: fraction of unique codes to drop; ceil((1 - kappa) * n_unique)
            codes are kept.

    Returns:
        The result of `get_laplacian` on the retained edges.
    """
    edge_code = code_edges(edge_index, edge_weight, num_nodes)
    unique, inverse = edge_code.unique(return_inverse=True)

    num_samples = math.ceil((1 - kappa) * unique.size(0))  # num unique codes to keep

    # Boolean flag per unique code, allocated on the same device as the codes.
    keep_unique = torch.zeros(unique.size(0), dtype=torch.bool, device=edge_code.device)
    if num_samples > 0:  # skip sampling when nothing is to be kept
        probabilities = unique - unique.floor()
        keep_unique[torch.multinomial(probabilities, num_samples)] = True

    # `inverse` maps every edge to its unique code, so one gather replaces the
    # per-sample Python loop over `edge_code == code` comparisons.
    mask = keep_unique[inverse]

    return get_laplacian(edge_index[:, mask], edge_weight[mask], num_nodes)

def sparsify_nodes(edge_index, edge_weight, node_index, num_nodes, num_edges, kappa):
    """Drop a random subset of nodes and rebuild the Laplacian without their edges.

    floor(kappa * num_nodes) nodes (capped at the number of entries in
    `node_index`) are drawn uniformly without replacement; every edge that has
    a drawn node as either endpoint is removed.

    Args:
        edge_index: (2, E) tensor of edge endpoints, with ids below `num_nodes`.
        edge_weight: (E,) tensor of edge weights.
        node_index: 1-D tensor of candidate node ids to sample from.
        num_nodes: number of nodes, forwarded to `get_laplacian`.
        num_edges: unused; kept for backward compatibility (the mask is now
            derived from `edge_index`).
        kappa: fraction of nodes to drop.

    Returns:
        The result of `get_laplacian` on the retained edges.
    """
    device = edge_index.device
    num_candidates = node_index.size(0)
    # Never request more samples than there are candidate nodes.
    num_samples = min(math.floor(kappa * num_nodes), num_candidates)  # num nodes to drop

    # A random permutation prefix is a uniform sample without replacement and,
    # unlike torch.multinomial, accepts a sample count of zero.
    sampled = torch.randperm(num_candidates, device=device)[:num_samples]

    # Per-node "dropped" flag, then a single gather per endpoint instead of a
    # Python loop over the sampled nodes.
    dropped = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    dropped[node_index[sampled]] = True
    mask = dropped[edge_index[0]] | dropped[edge_index[1]]

    return get_laplacian(edge_index[:, ~mask], edge_weight[~mask], num_nodes)

In [None]:
# Edge-level sparsification with kappa=0.75: `sparsify_edges` keeps
# ceil((1 - kappa) * n_unique) of the unique edge codes. The trailing
# expression displays the resulting tensor.
sparse_laplacian = sparsify_laplacian(laplacian, 0.75, on="edges")
sparse_laplacian

In [None]:
# NOTE(review): elsewhere in this notebook `get_laplacian` is called with
# (edge_index, edge_weight, num_nodes); here it receives a single Laplacian
# tensor - confirm that this call signature is supported.
get_laplacian(se2_graph.laplacian())

In [None]:
# NOTE(review): scratch cell that cannot run as written. `return` appears at
# cell top level (outside any function), the first `W_norm` assignment is
# immediately overwritten by the second, and `sparse_index`, `sparse_weight`,
# `divide`, `edge_index`, `edge_weight`, `num_nodes` and `sparse_tensor_diag`
# are not defined earlier in the visible notebook.
W_norm = torch.sparse.FloatTensor(sparse_index, sparse_weight / divide[sparse_index[1]], torch.Size((8, 8)))
W_norm = torch.sparse.FloatTensor(edge_index, edge_weight, torch.Size((num_nodes, num_nodes)))
I = sparse_tensor_diag(num_nodes)
return (I - W_norm).to(device)

In [None]:
# NOTE(review): `sparse_weight` and `sparse_index` are not defined in the visible
# notebook, and `divide` is only assigned in a later cell.
sparse_weight / divide[sparse_index[1]]

In [None]:
# NOTE(review): `sparse_` looks like a truncated identifier; neither it nor
# `sparse_weight` is defined in the visible notebook.
sparse_weight[sparse_]

In [None]:
# NOTE(review): `sparse_L` is not defined in the visible notebook.
# Sum the sparse matrix over dim 0.
S = torch.sparse.sum(sparse_L, dim=0)

# Scatter the stored sums into a dense length-8 vector (zeros where nothing is stored).
divide = torch.zeros(8)
divide[S.coalesce().indices()] = S.coalesce().values()
# NOTE(review): divides a sparse tensor by a dense 1-D tensor - confirm this is
# supported and broadcasts along the intended dimension.
sparse_L / divide

In [None]:
# Display the sparse matrix next to its dim-0 sums (`sparse_L` is not defined in the visible notebook).
sparse_L, S

In [None]:
# Second row of the coalesced index matrix (index [1]).
# NOTE(review): for a 2-D COO tensor that row holds column indices, despite the name `index_row` - confirm.
index_row = sparse_L.coalesce().indices()[1]


In [None]:
# Synthetic 100x100 sparse operators with 1000 and 500 stored entries, used only to
# time sparse-dense products below. Built with torch.sparse_coo_tensor instead of the
# legacy torch.sparse.FloatTensor constructor.
# NOTE: this rebinds `laplacian`, discarding the graph Laplacian computed earlier.
laplacian = torch.sparse_coo_tensor(torch.randint(0, 100, (2, 1000)), torch.rand(1000), (100, 100))
sparse_laplacian = torch.sparse_coo_tensor(torch.randint(0, 100, (2, 500)), torch.rand(500), (100, 100))

In [None]:
# Random dense (100, 1) column vector, matching the 100 columns of the operators above.
x0 = torch.rand(100,1)

In [None]:
# Time the sparse-dense product with the operator holding 1000 stored entries.
%timeit torch.mm(laplacian, x0)

In [None]:
# Time the same product with the operator holding 500 stored entries.
%timeit torch.mm(sparse_laplacian, x0)

In [None]:
def get_sparse_laplacian(laplacian, on="edges"):
    """Placeholder for a sparsified-Laplacian helper; not implemented yet.

    The original cell contained only the signature, which is a syntax error.
    An explicit exception keeps the cell runnable while making the missing
    implementation obvious to callers.

    Args:
        laplacian: Laplacian to sparsify.
        on: sparsification mode; defaults to "edges".

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("get_sparse_laplacian is not implemented yet.")

## ChebNet

In [None]:
from gechebnet.model.chebnet import GEChebNet

In [None]:
# Hyperparameters passed to GEChebNet below.
K = 5
in_channels = 3
out_channels = 10
hidden_channels = [16, 17, 19, 23]  # one entry per hidden layer - presumably; confirm against GEChebNet

In [None]:
# Build the model on the SE(2) graph; the trailing expression displays the module.
model = GEChebNet(se2_graph, K, in_channels, hidden_channels, out_channels, device=device)
model

In [None]:
# Move the model to `device` (the constructor above was already given `device` - possibly redundant; confirm).
model = model.to(device)

In [None]:
# Random input batch of shape (16, 3, 600); the second dimension matches in_channels = 3.
# NOTE(review): the last dimension is presumably the number of graph nodes, yet the graph
# above is built with nx * ny * ntheta = 30 * 30 * 10 = 9000 - verify that 600 is intended.
x = torch.rand(16, 3, 600, device=device)
x.shape

In [None]:
# Forward pass; display the output shape as a sanity check.
y = model(x)
y.shape

## ResChebNet

In [None]:
from gechebnet.model.reschebnet import ResGEChebNet

In [None]:
# Hyperparameters passed to ResGEChebNet below (rebinding the names used for GEChebNet above).
K = 10
in_channels = 3
out_channels = 10
hidden_channels = [[16, 16, 16], [16, 32, 32], [32, 64, 64]]  # nested lists - presumably one per residual block; confirm

In [None]:
# Build the residual model on the same graph; this rebinds `model`, replacing the GEChebNet instance.
model = ResGEChebNet(se2_graph, K, in_channels, hidden_channels, out_channels, device=device)
model

In [None]:
# Display the model's `capacity` attribute (presumably its parameter count - confirm).
model.capacity

In [None]:
# Move the model to `device` (the constructor above was already given `device` - possibly redundant; confirm).
model = model.to(device)

In [None]:
# Random input batch of shape (16, 3, 600); the second dimension matches in_channels = 3.
# NOTE(review): verify that the last dimension (600) matches the node count of `se2_graph`.
x = torch.rand(16, 3, 600, device=device)
x.shape

In [None]:
# Forward pass; display the output shape as a sanity check.
y = model(x)
y.shape

In [None]:
from torch import nn
# Probe how nn.Identity behaves when called with an extra positional tensor.
# NOTE(review): nn.Identity's forward is documented as taking a single input, so this
# call likely raises TypeError - confirm whether that is the intended demonstration.
m = nn.Identity()
m(x, torch.rand(2))

In [None]:
import torch

In [None]:
# Random edge list: 5400 nodes and 5400 * 32 edges with endpoints drawn uniformly.
num_nodes = 5400
num_edges = 5400*32
edge_index = torch.randint(0, num_nodes, (2, num_edges))
# Column vector of node ids, shape (num_nodes, 1).
node_index = torch.arange(num_nodes).unsqueeze(1)

In [None]:
# Check that every node id appears at least once in the first row of `edge_index`.
# bincount yields the per-node occurrence counts directly, avoiding the dense
# (num_nodes, num_edges) comparison matrix built by repeat + ==.
torch.bincount(edge_index[0], minlength=num_nodes).min() > 0

In [None]:
# TODO: random sparse nodes