In [1]:
import torch

import lightning as pl
from train import LightningModel, get_config_for_dataset
from bronze_age.config import LayerType

from bronze_age.datasets import DatasetEnum, get_dataset

dataset_enum = DatasetEnum.SIMPLE_SATURATION

In [2]:
config = get_config_for_dataset(dataset_enum)
config.layer_type = LayerType.BronzeAgeGeneralConcept

In [3]:
dataset = get_dataset(config)

In [4]:
model = LightningModel(num_classes=dataset.num_classes, num_node_features=dataset.num_node_features, config=config)
model

LightningModel(
  (model): StoneAgeGNN(
    (input): InputLayer(
      (lin1): Linear(in_features=3, out_features=3, bias=True)
    )
    (output): PoolingLayer(
      (lin2): MLP(
        (lins): ModuleList(
          (0): Linear(in_features=6, out_features=16, bias=True)
          (1): Linear(in_features=16, out_features=2, bias=True)
        )
        (bns): ModuleList(
          (0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (stone_age): ModuleList(
      (0): BronzeAgeGNNLayerConceptReasoner(3, 3)
    )
  )
  (train_accuracy): MulticlassAccuracy()
  (val_accuracy): MulticlassAccuracy()
  (test_accuracy): MulticlassAccuracy()
)

In [5]:
dataset[0].x

tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        ...,
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]])

In [6]:
with torch.no_grad():
    model.model.input.lin1.weight[:] = torch.eye(model.model.input.lin1.weight.shape[0])
    model.model.input.lin1.bias[:] = 0.0

model.eval()
model.model.input.forward(dataset[0].x)

tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        ...,
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]], grad_fn=<ArgMaxBackward>)

In [7]:
model.model.stone_age[0]

BronzeAgeGNNLayerConceptReasoner(3, 3)

In [8]:
dataset[0].y == 0

tensor([True, True, True,  ..., True, True, True])

In [9]:
bounding_parameter = 10
threshold = 7
magical_idx = 3 + 1 * bounding_parameter + threshold  
with torch.no_grad():
    model.model.stone_age[0].concept_reasoner.filter_nn.weight[:] = -10000.0
    model.model.stone_age[0].concept_reasoner.filter_nn.weight[magical_idx, 0] = 10000.0
    model.model.stone_age[0].concept_reasoner.filter_nn.weight[magical_idx, 1] = 10000.0
    model.model.stone_age[0].concept_reasoner.filter_nn.weight[1, 1] = 10000.0
    model.model.stone_age[0].concept_reasoner.filter_nn.weight[1, 2] = 10000.0
    model.model.stone_age[0].concept_reasoner.sign_nn.weight[:] = 0
    model.model.stone_age[0].concept_reasoner.sign_nn.weight[magical_idx, 0] = 10000.0
    model.model.stone_age[0].concept_reasoner.sign_nn.weight[magical_idx, 1] = -10000.0
    model.model.stone_age[0].concept_reasoner.sign_nn.weight[1, 1] = -10000.0
    model.model.stone_age[0].concept_reasoner.sign_nn.weight[1, 2] = 10000.0
    

In [10]:
model.model.stone_age[0].concept_reasoner.filter_nn.weight

Parameter containing:
tensor([[-10000., -10000., -10000.],
        [-10000.,  10000.,  10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [ 10000.,  10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000., -10000., -10000.],
        [-10000.

In [11]:
from bronze_age.models.concept_reasoner import softselect
softselect(model.model.stone_age[0].concept_reasoner.filter_nn.weight, 0.1)

tensor([[0.2712, 0.2712, 0.2712],
        [0.0000, 1.0000, 1.0000],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [1.0000, 1.0000, 0.0000],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.2712, 0.2712, 0.2712],
        [0.271

In [12]:
X = torch.randn(10, 33)
X[:, magical_idx] = 0.0
X[::4, magical_idx] = 1.0
X[::3, 1] = 1.0

model.model.stone_age[0].concept_reasoner(X)

tensor([[1.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000],
        [1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000],
        [0.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.0000],
        [1.0000, 0.0000, 0.4169],
        [0.0000, 0.0000, 1.0000]], grad_fn=<SqueezeBackward1>)

In [13]:
model.model.input.forward(dataset[0].x) == dataset[0].x

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        ...,
        [True, True, True],
        [True, True, True],
        [True, True, True]])

In [14]:
model.model.stone_age[0](dataset[0].x, dataset[0].edge_index)[:, 0], 1 - dataset[0].y

(tensor([3.0870e-07, 3.0870e-07, 3.0870e-07,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00], grad_fn=<SelectBackward0>),
 tensor([1, 1, 1,  ..., 1, 1, 1]))

In [15]:
with torch.no_grad():
    model.model.output.lin2.lins[0].weight[:] = 0
    model.model.output.lin2.lins[0].bias[:] = 0
    model.model.output.lin2.lins[0].weight[:2, -3:-1] = torch.eye(2)
    model.model.output.lin2.lins[0].weight[0, -1] = 1

model.model.output.lin2.lins[0].weight

Parameter containing:
tensor([[0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], requires_grad=True)

In [16]:
with torch.no_grad():
    model.model.output.lin2.lins[1].weight[:] = 0
    model.model.output.lin2.lins[1].bias[:] = 0
    model.model.output.lin2.lins[1].weight[:2, :2] = torch.eye(2)

model.model.output.lin2.lins[1].weight


Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       requires_grad=True)

In [17]:
x = dataset[0].x
edge_index = dataset[0].edge_index

x = model.model.input(x.float())
xs = [x]
for layer in model.model.stone_age:
    x = layer(x, edge_index, explain=False)
    xs.append(x)

x_prime = torch.cat(xs, dim=1)
x_prime = model.model.output(x_prime)
x_prime

tensor([[-0.3133, -1.3133],
        [-0.3133, -1.3133],
        [-0.3133, -1.3133],
        ...,
        [-0.3133, -1.3133],
        [-0.3133, -1.3133],
        [-0.3133, -1.3133]], grad_fn=<LogSoftmaxBackward0>)

In [18]:
(xs[0] == dataset[0].x)

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        ...,
        [True, True, True],
        [True, True, True],
        [True, True, True]])

In [19]:
xs[1].round()

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        ...,
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]], grad_fn=<RoundBackward0>)

In [20]:
torch.argmax(model.model(dataset[0].x, dataset[0].edge_index), dim=1), dataset[0].y

(tensor([0, 0, 0,  ..., 0, 0, 0]), tensor([0, 0, 0,  ..., 0, 0, 0]))

In [21]:
import numpy
y_pred = torch.argmax(model.model(dataset[0].x, dataset[0].edge_index), dim=1)
y_true = dataset[0].y
numpy.mean((y_pred == y_true).numpy())

np.float64(1.0)

In [22]:
from torch_geometric.loader import DataLoader
trainer = pl.Trainer(max_epochs=100, accelerator='mps')

test_dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
trainer.test(model, dataloaders=test_dataloader)

Testing: |          | 0/? [00:00<?, ?it/s]

Layer  0
[{'class': 'y_0', 'explanation': 's_1_count>7', 'count': 594}, {'class': 'y_1', 'explanation': '~s_1 & ~s_1_count>7', 'count': 1396}, {'class': 'y_2', 'explanation': 's_1', 'count': 10}]
Layer  0
[{'class': 'y_0', 'explanation': 's_1_count>7', 'count': 594}, {'class': 'y_1', 'explanation': '~s_1 & ~s_1_count>7', 'count': 1396}, {'class': 'y_2', 'explanation': 's_1', 'count': 10}]
Layer  0
[{'class': 'y_0', 'explanation': 's_1_count>7', 'count': 594}, {'class': 'y_1', 'explanation': '~s_1 & ~s_1_count>7', 'count': 1396}, {'class': 'y_2', 'explanation': 's_1', 'count': 10}]
Layer  0
[{'class': 'y_0', 'explanation': 's_1_count>7', 'count': 594}, {'class': 'y_1', 'explanation': '~s_1 & ~s_1_count>7', 'count': 1396}, {'class': 'y_2', 'explanation': 's_1', 'count': 10}]
Layer  0
[{'class': 'y_0', 'explanation': 's_1_count>7', 'count': 594}, {'class': 'y_1', 'explanation': '~s_1 & ~s_1_count>7', 'count': 1396}, {'class': 'y_2', 'explanation': 's_1', 'count': 10}]
Layer  0
[{'class': 

[{'test_loss': 0.3139800727367401, 'test_acc': 1.0}]