In [14]:
import sys
sys.path.append("/nfs/homedirs/reiffers/consistency-based-sheaf-diffusion")


import os
from typing import Optional

import torch
import torch_geometric.transforms as T
import tqdm
from omegaconf import OmegaConf
from sklearn.decomposition import PCA
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms.laplacian_lambda_max import LaplacianLambdaMax
from cbsd.config import instantiate_datamodule



from cbsd.nn.builders import ConsistencyBasedLaplacianBuilder
from cbsd.utils.logging import print_config


In [15]:
root = os.path.dirname(os.getcwd())
path = os.path.join(root, "config/model/snn.yaml")
conf = OmegaConf.load(path)
print_config(conf)

In [23]:
 # Load data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataconfig = OmegaConf.load(
    "/nfs/homedirs/reiffers/consistency-based-sheaf-diffusion/config/data/texas.yaml"
)

datamodule = instantiate_datamodule(dataconfig)
datamodule.prepare_data()
datamodule.setup()

lambda_max = LaplacianLambdaMax(normalization="sym", is_undirected=True)
datamodule.edge_attr = None
lambda_max(datamodule)
print(f'Laplacian Lambda max: {datamodule.lambda_max}')

# Build Laplacian
laplacianconfig = OmegaConf.load(
    "/nfs/homedirs/reiffers/consistency-based-sheaf-diffusion/config/model/snn.yaml"
).sheaf_laplacian

builder = ConsistencyBasedLaplacianBuilder(
            edge_index=datamodule.edge_index.to(device), config=laplacianconfig
        )
x = builder.dim_reduction(x=datamodule.x, d=laplacianconfig.d) # TODO: Have the same dimensionality reduction before passing x to the SNN model

print(device)
builder = builder.to(device)

builder.train(x, 0.1, 60, log_every=1, reg="matrix", lambda_reg=0.1, normalize=True)
sheaf_laplacian = builder.build_from_maps()

Laplacian Lambda max: 1.9376229047775269
cuda


Training Progress:   0%|          | 0/60 [00:00<?, ?it/s]

Epoch 1/60, Loss Total: -2.8215, Loss: 1.3655, Loss Reg: 40.5049
Epoch 2/60, Loss Total: -3.5046, Loss: 1.2200, Loss Reg: 46.0253
Epoch 3/60, Loss Total: -4.1981, Loss: 1.1044, Loss Reg: 51.9205
Epoch 4/60, Loss Total: -4.8841, Loss: 1.0282, Loss Reg: 58.0946
Epoch 5/60, Loss Total: -5.5593, Loss: 0.9866, Loss Reg: 64.4725
Epoch 6/60, Loss Total: -6.2325, Loss: 0.9635, Loss Reg: 70.9960
Epoch 7/60, Loss Total: -6.9101, Loss: 0.9477, Loss Reg: 77.6302
Epoch 8/60, Loss Total: -7.5936, Loss: 0.9355, Loss Reg: 84.3549
Epoch 9/60, Loss Total: -8.2826, Loss: 0.9257, Loss Reg: 91.1570
Epoch 10/60, Loss Total: -8.9767, Loss: 0.9177, Loss Reg: 98.0266
Epoch 11/60, Loss Total: -9.6759, Loss: 0.9109, Loss Reg: 104.9563
Epoch 12/60, Loss Total: -10.3799, Loss: 0.9044, Loss Reg: 111.9390
Epoch 13/60, Loss Total: -11.0887, Loss: 0.8979, Loss Reg: 118.9679
Epoch 14/60, Loss Total: -11.8015, Loss: 0.8915, Loss Reg: 126.0376
Epoch 15/60, Loss Total: -12.5173, Loss: 0.8856, Loss Reg: 133.1432
Epoch 16/6

In [26]:
torch.norm(builder.restriction_maps)

tensor(466.5425, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [27]:
builder.restriction_maps

Parameter containing:
tensor([[[ 0.1290,  2.9599,  3.0699],
         [ 0.2805,  6.2656,  6.3858],
         [ 0.2732,  6.0859,  6.2135]],

        [[ 0.2116,  5.7300,  5.6388],
         [ 0.2406,  6.4314,  6.3623],
         [ 0.2450,  5.9558,  6.0650]],

        [[-4.2210,  6.7986,  6.6499],
         [-4.4442,  7.0115,  6.9490],
         [-4.5615,  7.2073,  6.9247]],

        ...,

        [[ 6.7115,  6.7746,  6.6561],
         [ 6.4921,  6.7083,  6.7597],
         [ 6.9384,  6.8358,  6.8788]],

        [[ 4.9577,  6.6583,  6.5325],
         [ 5.2582,  7.1173,  6.8028],
         [ 5.2588,  7.1564,  6.6976]],

        [[ 5.6797,  1.7797,  6.4920],
         [ 5.3583,  1.5888,  5.8168],
         [ 5.6211,  1.7715,  6.5033]]], device='cuda:0', requires_grad=True)

In [28]:
# check if any of the restriction maps are all zeros
print(torch.any(torch.all(builder.restriction_maps == 0, dim=1)))
# check if any of the restriction maps are negative
print(torch.any(builder.restriction_maps < 0))
# print the negative values
print(builder.restriction_maps[builder.restriction_maps < 0].shape)
# print the positive values
print(builder.restriction_maps[builder.restriction_maps > 0].shape)
# check if any of the restriction maps are infinite
print(torch.any(torch.all(torch.isfinite(builder.restriction_maps))))


tensor(False, device='cuda:0')
tensor(True, device='cuda:0')
torch.Size([141])
torch.Size([4881])
tensor(True, device='cuda:0')
