In [1]:
import torch
import math
from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.transforms as T
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_sparse import SparseTensor

from sketch import CountSketch, TensorSketch

In [2]:
# Load ogbn-arxiv; T.ToSparseTensor() exposes the adjacency as `data.adj_t`.
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root="../dataset",
                                 transform=T.ToSparseTensor())
data = dataset[0]

# gcn_norm applied to the symmetrised adjacency (self-loops added, float32),
# plus the raw node-feature matrix.
sym_adj = data.adj_t.to_symmetric()
conv_mat = gcn_norm(sym_adj, edge_weight=None, num_nodes=data.x.size(0),
                    improved=False, add_self_loops=True, dtype=torch.float32)
nf_mat = data.x

# Keep only the first 10% of node indices (rounded up) in both matrices.
num_kept = math.ceil(data.x.size(0) * 0.1)
item = torch.arange(num_kept, dtype=torch.long)
conv_mat = conv_mat.saint_subgraph(item)[0]  # first element of the returned tuple
nf_mat = nf_mat[item, :]

# Sketch sizes: out_dim is 10% of in_dim, truncated towards zero.
in_dim = conv_mat.size(0)
out_dim = int(in_dim * 0.1)

# Apply the same CountSketch instance twice: once to conv_mat, then to the
# transposed intermediate, transposing the final result back.
count_sketch = CountSketch(in_dim, out_dim)
sketched_once = count_sketch(conv_mat)
conv_sketch = count_sketch(sketched_once.t()).t()

In [3]:
def top_k_sparsifying(conv_sketch, top_k, by_magnitude=False):
    """Keep `top_k` entries in every row of a dense matrix, returned as a SparseTensor.

    Args:
        conv_sketch: Dense 2-D tensor. Rows are sparsified independently along
            the last dimension.
        top_k: Number of entries kept per row. `torch.topk` raises if it
            exceeds `conv_sketch.size(1)`.
        by_magnitude: If False (default, original behaviour) the `top_k`
            largest *signed* values of each row are kept. If True, the
            `top_k` entries with the largest absolute value are kept, with
            their original signs.

    Returns:
        A `SparseTensor` with sparse sizes
        `(conv_sketch.size(0), conv_sketch.size(1))` holding `top_k` stored
        entries per row.

    NOTE(review): with the default signed selection, large negative entries
    are never retained. If the sketch can contain negative values (presumably
    the case for a sign-randomised count sketch -- confirm against
    `CountSketch`), `by_magnitude=True` should give a lower approximation error.
    """
    num_rows, num_cols = conv_sketch.size(0), conv_sketch.size(1)

    if by_magnitude:
        # Rank by |value| but store the signed entries.
        col = torch.topk(conv_sketch.abs(), k=top_k, dim=-1).indices
        value = conv_sketch.gather(-1, col)
    else:
        kept = torch.topk(conv_sketch, k=top_k, dim=-1)
        col, value = kept.indices, kept.values

    # Build the row indices on the input's device so row/col/value never mix
    # CPU and GPU tensors.
    row = torch.arange(num_rows, dtype=torch.long,
                       device=conv_sketch.device).repeat_interleave(top_k)

    # Columns inside a row follow topk order (by value, not by index), so the
    # entries are not sorted by (row, col); let SparseTensor sort them.
    return SparseTensor(row=row, col=col.flatten(), value=value.flatten(),
                        sparse_sizes=(num_rows, num_cols), is_sorted=False)

In [4]:
# Relative error ||conv_sketch - sparsified|| / ||conv_sketch|| (torch.norm with
# its default norm) for top_k = 1..16.
# The denominator does not depend on top_k, so compute it once up front.
conv_sketch_norm = torch.norm(conv_sketch)
for top_k in range(1, 17):
    conv_sparsified = top_k_sparsifying(conv_sketch, top_k).to_dense()
    rel_err = torch.norm(conv_sketch - conv_sparsified) / conv_sketch_norm
    print(f'{top_k}: {rel_err}')

1: 0.13234886527061462
2: 0.10862584412097931
3: 0.10012906044721603
4: 0.09636086970567703
5: 0.09451214224100113
6: 0.09352024644613266
7: 0.0929645299911499
8: 0.09264307469129562
9: 0.092458575963974
10: 0.09234696626663208
11: 0.09227889776229858
12: 0.09223224222660065
13: 0.0921991616487503
14: 0.09217651188373566
15: 0.09216027706861496
16: 0.09214799106121063


In [5]:
# Rebind conv_sparsified: in the sweep loop it held a dense tensor
# (`.to_dense()`); here it is the SparseTensor returned by top_k_sparsifying,
# keeping 4 entries per row.
conv_sparsified = top_k_sparsifying(conv_sketch, top_k=4)

In [6]:
# Multiply the sparsified sketch with the (subsetted) node features.
# NOTE(review): nf_mat has len(item) rows, while conv_sparsified takes its
# sparse sizes from conv_sketch, whose shape depends on CountSketch (not shown
# here). Confirm the inner dimensions agree -- nf_mat may need to be sketched
# first -- TODO verify.
conv_sparsified @ nf_mat

tensor([[-0.1198, -0.0719, -0.1841,  ...,  0.2934, -0.3095, -0.2924],
        [-0.3808, -0.2476, -0.9778,  ...,  0.3168, -0.9381, -0.8327],
        [-0.1992, -0.0633, -0.4452,  ...,  0.2870,  0.2768, -0.3381],
        ...,
        [-0.3385,  0.4881, -0.6637,  ...,  0.3804, -0.5179, -0.3545],
        [ 0.3610,  0.2201, -0.3786,  ...,  0.1228,  0.1520, -0.1758],
        [-0.1291, -0.1395, -0.5764,  ...,  0.0950, -0.3974, -0.3890]])