In [1]:
import os
from typing import Optional

import torch
import torch_geometric.transforms as T
import tqdm
from omegaconf import OmegaConf
from sklearn.decomposition import PCA
from torch_geometric.datasets import Planetoid

from cbsd.nn.builders import ConsistencyBasedLaplacianBuilder
from cbsd.utils.logging import print_config


In [2]:
root = os.path.dirname(os.getcwd())
path = os.path.join(root, "config/model/snn.yaml")
conf = OmegaConf.load(path)
print_config(conf)

In [10]:
dataset = Planetoid(
    root="data/Planetoid", name="pubmed", transform=T.NormalizeFeatures()
)[0]
x = dataset.x
edge_index = dataset.edge_index
num_nodes = edge_index.max().item() + 1
num_features = dataset.num_features
d = conf.sheaf_laplacian.d

# Apply PCA to x
pca = PCA(n_components=d * 32)
x = pca.fit_transform(x.detach().numpy())
x = torch.tensor(x, dtype=torch.float32).reshape(num_nodes, d, -1)

builder = ConsistencyBasedLaplacianBuilder(
    edge_index=edge_index, d=d, init="random"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
builder = builder.to(device)

builder.train(x, 0.1, 100, log_every=1, reg="matrix", lambda_reg=0.2)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/100, Loss Total: 474.5012, Loss: 679.1887, Loss Reg: 344.2486
Epoch 2/100, Loss Total: 321.7773, Loss: 479.4453, Loss Reg: 308.8946
Epoch 3/100, Loss Total: 213.3175, Loss: 335.7481, Loss Reg: 276.4048
Epoch 4/100, Loss Total: 139.4621, Loss: 236.1779, Loss Reg: 247.4014
Epoch 5/100, Loss Total: 90.4066, Loss: 168.5474, Loss Reg: 222.1562
Epoch 6/100, Loss Total: 58.5650, Loss: 123.4483, Loss Reg: 200.9679
Epoch 7/100, Loss Total: 38.7057, Loss: 94.4088, Loss Reg: 184.1065
Epoch 8/100, Loss Total: 27.1551, Loss: 76.8575, Loss Reg: 171.6542
Epoch 9/100, Loss Total: 21.1609, Loss: 67.2934, Loss Reg: 163.3694
Epoch 10/100, Loss Total: 18.6177, Loss: 62.9422, Loss Reg: 158.6803
Epoch 11/100, Loss Total: 18.0315, Loss: 61.7410, Loss Reg: 156.8061
Epoch 12/100, Loss Total: 18.3850, Loss: 62.2091, Loss Reg: 156.9115
Epoch 13/100, Loss Total: 18.9517, Loss: 63.2465, Loss Reg: 158.2274
Epoch 14/100, Loss Total: 19.1810, Loss: 64.0079, Loss Reg: 160.1263
Epoch 15/100, Loss Total: 18.7041

In [7]:
torch.norm(builder.restriction_maps)

tensor(138.5013, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [8]:
builder.restriction_maps

Parameter containing:
tensor([[[[ 0.0008,  0.0005],
          [-0.0002,  0.0009]],

         [[ 0.0006,  0.0029],
          [ 0.0004, -0.0011]]],


        [[[ 0.0003,  0.0011],
          [-0.0024, -0.0001]],

         [[ 0.0014, -0.0014],
          [-0.0022, -0.0031]]],


        [[[ 0.0002,  0.0010],
          [-0.0010,  0.0004]],

         [[ 0.0039, -0.0007],
          [-0.0050, -0.0048]]],


        ...,


        [[[-0.0033, -0.0014],
          [-0.0005,  0.0005]],

         [[-0.0003, -0.0019],
          [-0.0005, -0.0003]]],


        [[[-0.0002, -0.0009],
          [ 0.0005, -0.0026]],

         [[-0.0005, -0.0022],
          [-0.0003, -0.0014]]],


        [[[-0.0010,  0.0019],
          [-0.0014, -0.0032]],

         [[-0.0005,  0.0040],
          [-0.0023, -0.0012]]]], device='cuda:0', requires_grad=True)

In [9]:
# check if any of the restriction maps are all zeros
print(torch.any(torch.all(builder.restriction_maps == 0, dim=1)))
# check if any of the restriction maps are negative
print(torch.any(builder.restriction_maps < 0))
# print the negative values
print(builder.restriction_maps[builder.restriction_maps < 0].shape)
# print the positive values
print(builder.restriction_maps[builder.restriction_maps > 0].shape)


tensor(False, device='cuda:0')
tensor(True, device='cuda:0')
torch.Size([20341])
torch.Size([21883])
