In [16]:
import torch
import torch.nn as nn
import lightning as pl



from train_bronze import LightningModel, get_dataset, get_config_for_dataset
from bronze_age.config import Config, DatasetEnum
from torch_geometric.loader import DataLoader

In [26]:
dataset = DatasetEnum.BA_2MOTIFS
config = get_config_for_dataset(dataset)
dataset = get_dataset(config)
train_loader = DataLoader(dataset, batch_size=128, shuffle=True)
model = LightningModel.load_from_checkpoint("lightning_logs/12/03/25 16:15 BA_2Motifs/version_6/checkpoints/epoch=180-step=1267.ckpt")
model.eval()

LightningModel(
  (model): BronzeAgeGNN(
    (input): BronzeAgeLayer(
      (f): Linear(in_features=1, out_features=6, bias=True)
      (non_linearity): GumbelSoftmax()
      (eval_non_linearity): GumbelSoftmax()
    )
    (output): BronzeAgeLayer(
      (f): MLP(
        (lins): ModuleList(
          (0): Linear(in_features=30, out_features=16, bias=True)
          (1): Linear(in_features=16, out_features=2, bias=True)
        )
        (bns): ModuleList(
          (0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (non_linearity): LogSoftmax(dim=-1)
      (eval_non_linearity): LogSoftmax(dim=-1)
    )
    (stone_age): ModuleList(
      (0-3): 4 x BronzeAgeGNNLayer(6, 6)
    )
  )
  (train_accuracy): MulticlassAccuracy()
  (val_accuracy): MulticlassAccuracy()
  (test_accuracy): MulticlassAccuracy()
)

In [27]:
it = next(iter(train_loader))
torch.argmax(model.model(it.x, it.edge_index, batch = it.batch)[0], dim=-1)

tensor([0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,
        0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1,
        1, 1, 0, 0, 1, 1, 0, 1])

In [28]:
it.y

tensor([0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,
        0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1,
        1, 1, 0, 0, 1, 1, 0, 1])

In [29]:
dt = model.model.to_decision_tree(train_loader)

In [36]:
torch.allclose(model.model.input(it.x)[0], dt.input(it.x)[0])

True

In [38]:
x1_1 = model.model.input(it.x)[0]
x1_2 = dt.input(it.x)[0]

In [57]:
x2_1 = model.model.stone_age[0](x1_1, it.edge_index)[0]
x2_2 = dt.stone_age[0](x1_2, it.edge_index)[0]
(x2_1 - x2_2).abs().max()

tensor(5.9605e-08, grad_fn=<MaxBackward1>)

In [67]:
x2_1 

tensor([[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SubBackward0>)

In [68]:
x2_2

tensor([[0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        ...,
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.]])

In [None]:
x3_1 = model.model.stone_age[1](x2_1.round().float(), it.edge_index)[0]
x3_2 = dt.stone_age[1](x2_1, it.edge_index)[0]
(x3_1 - x3_2).abs().max()

tensor(1., grad_fn=<MaxBackward1>)

In [55]:
(x3_1 - x3_2).abs().max()

tensor(1., grad_fn=<MaxBackward1>)

In [30]:
dt(it.x, it.edge_index, batch = it.batch)

(tensor([[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1.

In [None]:
decision_tree = model.model.to_decision_tree()

tensor([0])