In [1]:
# Reload edited project modules automatically before each cell executes,
# so changes to imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys

# Make the directory one level above the notebook's working directory
# importable so that `src.*` modules resolve. The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate '..' entry
# to sys.path each time.
if '..' not in sys.path:
    sys.path.append('..')  # enable import from src/

In [65]:
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from dgl.dataloading import GraphDataLoader

from src.dataset import InstanceDataset
from src.net import InstanceGCN

In [6]:
# Collect the raw instance files whose names start with "97_".
# Path.glob yields entries in arbitrary (filesystem-dependent) order, so the
# paths are sorted to make the `[:5]` subset below, and therefore every
# `ds[i]` lookup, identical on every run and machine.
instances_fps = sorted(Path('../data/raw/').glob('97_*.json'))

# Build the dataset from the first five instances.
# NOTE(review): sols_dir presumably points at the stored solutions for these
# instances; confirm against src.dataset.InstanceDataset.
ds = InstanceDataset(instances_fps[:5], sols_dir='../data/interim/old/')

# Display the first sample (a graph plus a tuple of arrays) as a sanity check.
ds[0]

(Graph(num_nodes={'con': 14040, 'soc': 291, 'var': 4656},
       num_edges={('con', 'c2s', 'soc'): 775, ('con', 'c2v', 'var'): 163371, ('soc', 's2c', 'con'): 775, ('var', 'v2c', 'con'): 163371},
       metagraph=[('con', 'soc', 'c2s'), ('con', 'var', 'c2v'), ('soc', 'con', 's2c'), ('var', 'con', 'v2c')]),
 (array([[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00, ...,
           8.67361738e-18, -4.72552566e+00, -1.70118924e+01],
         [ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00, ...,
           0.00000000e+00, -4.72552566e+00, -1.70118924e+01],
         [ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00, ...,
           0.00000000e+00, -4.72552566e+00, -1.70118924e+01],
         ...,
         [ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00, ...,
           2.35636636e-02, -4.98637250e+00, -1.79509410e+01],
         [ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00, ...,
           2.76755405e-02, -4.98637250e+00, -1.79509410e+01],
         [ 0.00000000e+00,  0.00000000e

In [14]:
# Instantiate the model without a readout operation.
# NOTE(review): the meaning of the positional argument `1` is not visible
# here; confirm against the InstanceGCN signature in src.net.
net = InstanceGCN(1, readout_op=None)

# Bare expression: show the module structure as the cell output.
net

InstanceGCN(
  (soc_emb): Sequential(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): ReLU()
  )
  (var_emb): Sequential(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): ReLU()
  )
  (con_emb): Sequential(
    (0): Linear(in_features=2, out_features=10, bias=True)
    (1): ReLU()
  )
  (convs): Sequential(
    (0): HeteroGraphConv(
      (mods): ModuleDict(
        (v2c): GraphConv(in=10, out=10, normalization=both, activation=None)
        (s2c): GraphConv(in=10, out=10, normalization=both, activation=None)
        (c2v): GraphConv(in=10, out=10, normalization=both, activation=None)
        (c2s): GraphConv(in=10, out=10, normalization=both, activation=None)
      )
    )
    (1): HeteroGraphConv(
      (mods): ModuleDict(
        (v2c): GraphConv(in=10, out=10, normalization=both, activation=None)
        (s2c): GraphConv(in=10, out=10, normalization=both, activation=None)
        (c2v): GraphConv(in=10, out=10, normalization=both, activation=N

In [64]:
# NOTE(review): scratch check; the result is not bound to any name, so no
# other cell can depend on it. Candidate for removal from the final notebook.
torch.arange(10)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [62]:
# Take one sample from the dataset: a graph and a pair of NumPy arrays.
# NOTE(review): going by the names, `sols` holds one solution per row and
# `objs` the matching objective values; confirm against src.dataset.
g, (sols, objs) = ds[3]

# Targets: the first 2522 columns of every row of `sols`.
# NOTE(review): 2522 is hard-coded; presumably the number of entries the net
# predicts for this particular instance. TODO confirm / derive from the data.
y = torch.from_numpy(sols)[:, :2522]

# Per-row weights: scale by the largest value, then softmax over rows.
# NOTE(review): the scaling assumes `objs.max()` is positive; verify.
w = torch.from_numpy(objs)
w = (w / w.max()).softmax(dim=0)

# Single forward pass; the prediction is tiled once per target row because
# BCEWithLogitsLoss needs input and target of the same size.
logits = net(g)
logits = logits.repeat(y.shape[0], 1)

# bce = torch.nn.BCEWithLogitsLoss(w)
# Keep the element-wise losses (no reduction) so the weighting can be applied
# per row after summing over columns.
bce = torch.nn.BCEWithLogitsLoss(reduction='none')
loss = bce(logits, y)

# Weighted sum of the per-row totals.
loss = w @ loss.sum(dim=-1)

loss

tensor(1913.8787, dtype=torch.float64, grad_fn=<DotBackward0>)

In [66]:
# Wrap the dataset in a DGL GraphDataLoader that yields one sample per batch.
dl = GraphDataLoader(ds, batch_size=1)

# Bare expression: show the loader object as the cell output.
dl

<dgl.dataloading.dataloader.GraphDataLoader at 0x7efc2db30fd0>

In [70]:
# Pull the first batch from the loader and unpack it into the graph and the
# two tensors that accompany it.
# NOTE(review): this rebinds `g`, `y` and `w`, which the loss-computation cell
# also assigns; the values seen downstream depend on which cell ran last.
# The saved output shows `y` with an extra leading dimension of size 1.
g, (y, w) = next(iter(dl))
g, y, w

(Graph(num_nodes={'con': 14040, 'soc': 291, 'var': 4656},
       num_edges={('con', 'c2s', 'soc'): 775, ('con', 'c2v', 'var'): 163371, ('soc', 's2c', 'con'): 775, ('var', 'v2c', 'con'): 163371},
       metagraph=[('con', 'soc', 'c2s'), ('con', 'var', 'c2v'), ('soc', 'con', 's2c'), ('var', 'con', 'v2c')]),
 tensor([[[ 1.0000e+00,  1.0000e+00,  1.0000e+00,  ...,  8.6736e-18,
           -4.7255e+00, -1.7012e+01],
          [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  ...,  0.0000e+00,
           -4.7255e+00, -1.7012e+01],
          [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  ...,  0.0000e+00,
           -4.7255e+00, -1.7012e+01],
          ...,
          [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  ...,  2.3564e-02,
           -4.9864e+00, -1.7951e+01],
          [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  ...,  2.7676e-02,
           -4.9864e+00, -1.7951e+01],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  1.7994e-02,
           -4.9188e+00, -1.7708e+01]]], dtype=torch.float64),
 tensor([[15