In [1]:
#%load_ext autoreload
#%autoreload 2
import torch

from torch import nn

from bronze_age.models.bronze_age import BronzeAgeGNN
from bronze_age.models.stone_age import StoneAgeGNN


from bronze_age.config import Config, BronzeConfig, DatasetEnum, LayerTypeBronze, LayerType, NetworkType, AggregationMode, NonLinearity

from train import get_config_for_dataset, get_dataset

In [2]:
dataset = DatasetEnum.INFECTION

config = get_config_for_dataset(dataset, layer_type=LayerType.StoneAge)
dataset = get_dataset(config)

In [3]:
def map_config(config: Config) -> BronzeConfig:
    config_dict = config.__dict__.copy()
    config_dict['aggregation_mode'] = AggregationMode.BRONZE_AGE
    if config_dict['layer_type'] == LayerType.StoneAge:
        layer_type = LayerTypeBronze.MLP if config_dict['network'] == NetworkType.MLP else LayerTypeBronze.LINEAR
        config_dict['aggregation_mode'] = AggregationMode.STONE_AGE
        non_linearity = NonLinearity.GUMBEL_SOFTMAX
    elif config_dict['layer_type'] == LayerType.BronzeAge:
        layer_type = LayerTypeBronze.MLP if config_dict['network'] == NetworkType.MLP else LayerTypeBronze.LINEAR
        config_dict['aggregation_mode'] = AggregationMode.BRONZE_AGE
        non_linearity = NonLinearity.GUMBEL_SOFTMAX
    elif config_dict['layer_type'] == LayerType.BronzeAgeConcept:
        layer_type = LayerTypeBronze.DEEP_CONCEPT_REASONER
        non_linearity = None if not config.use_one_hot_output else NonLinearity.DIFFERENTIABLE_ARGMAX
    elif config_dict['layer_type'] == LayerType.BronzeAgeGeneralConcept:
        layer_type = LayerTypeBronze.GLOBAL_DEEP_CONCEPT_REASONER
        non_linearity = None if not config.use_one_hot_output else NonLinearity.DIFFERENTIABLE_ARGMAX
    
    if config.use_one_hot_output and config_dict['aggregation_mode'] == AggregationMode.BRONZE_AGE:
        config_dict['aggregation_mode'] = AggregationMode.BRONZE_AGE_ROUNDED
    if config.one_hot_evaluation:
        eval_non_linearity = NonLinearity.DIFFERENTIABLE_ARGMAX
    else:
        eval_non_linearity = None
    config_dict['nonlinearity'] = non_linearity
    config_dict['evaluation_nonlinearity'] = eval_non_linearity
    del config_dict['use_one_hot_output']
    del config_dict['one_hot_evaluation']
    config_dict['layer_type'] = layer_type
    del config_dict['network']

    return BronzeConfig(**config_dict)

In [4]:
def get_models(config: Config):
    return StoneAgeGNN(dataset.num_node_features, dataset.num_classes, config), BronzeAgeGNN(dataset.num_node_features, dataset.num_classes, map_config(config))

In [5]:
def key_mapping_stone_to_bronze(key):
    #key = key.replace(".lin2.", ".f.lins.")
    #key = key.replace(".lin1.", ".f.lin.")
    key = key.replace("input.lin1", "input.f")
    key = key.replace("output.lin1", "output.f")
    key = key.replace("input.lin2", "input.f")
    key = key.replace("output.lin2", "output.f")
    key = key.replace("input.concept_reasoner", "input.f.concept_reasoner")
    key = key.replace("output.concept_reasoner", "output.f.concept_reasoner")
    key = key.replace("input.concept_context_generator", "input.f.concept_context_generator")
    key = key.replace("output.concept_context_generator", "output.f.concept_context_generator")
    key = key.replace(".linear_softmax.lin1.", ".layer.f.")
    key = key.replace(".reasoning_module.", ".layer.f.")
    return key


In [6]:
from itertools import product
from functools import lru_cache
layer_types = LayerType
network_types = NetworkType
skip_connection = [False, True]
use_one_hot = [False, True]
datasets = DatasetEnum

# todo add one_hot
@lru_cache(maxsize=None)
def _get_dataset(dataset_):
    return get_dataset(get_config_for_dataset(dataset_))
for dataset_, layer_type, network_type, use_skip, use_one_hot_output in product(datasets, layer_types, network_types, skip_connection, use_one_hot):
    if dataset_ in [DatasetEnum.CORA, DatasetEnum.CITESEER, DatasetEnum.PUBMED, DatasetEnum.OGBA, DatasetEnum.OGB_MOLHIV, DatasetEnum.OGB_PPA, DatasetEnum.OGB_CODE2]:
        continue
    print(f"Layer type: {layer_type}, Network type: {network_type}, Use skip connection: {use_skip}, Dataset: {dataset_}")
    use_one_hot_output = use_one_hot_output and layer_types in [LayerType.BronzeAgeConcept, LayerType.BronzeAgeGeneralConcept]
    config = get_config_for_dataset(dataset_, layer_type=layer_type, network=network_type, skip_connection=use_skip, use_one_hot_output=use_one_hot_output)
    dataset = _get_dataset(dataset_)
    stone_age, bronze_age = get_models(config)
    stone_age.eval()
    bronze_age.eval()
    bronze_age.load_state_dict({key_mapping_stone_to_bronze(key):val for (key, val) in stone_age.state_dict().items()}, strict=True)
    #print(config)
    #print(stone_age)
    #print(bronze_age)
    num_params_stone = sum(p.numel() for p in stone_age.parameters())
    num_params_bronze = sum(p.numel() for p in bronze_age.parameters())
    #print(f"Stone Age: {num_params_stone} parameters and Bronze Age: {num_params_bronze} parameters")
    assert num_params_stone == num_params_bronze
    assert torch.allclose(stone_age(dataset[0].x, dataset[0].edge_index), bronze_age(dataset[0].x, dataset[0].edge_index)[0], atol=1e-6)
    #print(f"Stone Age: {num_params_stone} parameters")
    #print(f"Bronze Age: {num_params_bronze} parameters")

Layer type: stone-age, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: stone-age, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: stone-age, Network type: linear, Use skip connection: True, Dataset: MUTAG
Layer type: stone-age, Network type: linear, Use skip connection: True, Dataset: MUTAG
Layer type: stone-age, Network type: mlp, Use skip connection: False, Dataset: MUTAG
Layer type: stone-age, Network type: mlp, Use skip connection: False, Dataset: MUTAG
Layer type: stone-age, Network type: mlp, Use skip connection: True, Dataset: MUTAG
Layer type: stone-age, Network type: mlp, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age, Network type: linear, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age, Network type: linear, 

  self.data, self.slices = torch.load(self.processed_paths[0])


Layer type: bronze-age-concept, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age-concept, Network type: linear, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age-concept, Network type: linear, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age-concept, Network type: mlp, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age-concept, Network type: mlp, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age-concept, Network type: mlp, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age-concept, Network type: mlp, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age-general-concept, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age-general-concept, Network type: linear, Use skip connection: False, Dataset: MUTAG
Layer type: bronze-age-general-concept, Network type: linear, Use skip connection: True, Dataset: MUTAG
Layer type: bronze-age-

  self.data, self.slices = torch.load(self.processed_paths[0])


Layer type: stone-age, Network type: linear, Use skip connection: False, Dataset: BA_2Motifs
Layer type: stone-age, Network type: linear, Use skip connection: True, Dataset: BA_2Motifs
Layer type: stone-age, Network type: linear, Use skip connection: True, Dataset: BA_2Motifs
Layer type: stone-age, Network type: mlp, Use skip connection: False, Dataset: BA_2Motifs
Layer type: stone-age, Network type: mlp, Use skip connection: False, Dataset: BA_2Motifs
Layer type: stone-age, Network type: mlp, Use skip connection: True, Dataset: BA_2Motifs
Layer type: stone-age, Network type: mlp, Use skip connection: True, Dataset: BA_2Motifs
Layer type: bronze-age, Network type: linear, Use skip connection: False, Dataset: BA_2Motifs
Layer type: bronze-age, Network type: linear, Use skip connection: False, Dataset: BA_2Motifs
Layer type: bronze-age, Network type: linear, Use skip connection: True, Dataset: BA_2Motifs
Layer type: bronze-age, Network type: linear, Use skip connection: True, Dataset: BA

In [7]:
#model1, model2 = get_models(get_config_for_dataset(DatasetEnum.MUTAG, layer_type=LayerType.BronzeAgeGeneralConcept, network=NetworkType.LINEAR, skip_connection=True))
model1, model2 = stone_age, bronze_age

In [8]:
model1_state_dict = {key_mapping_stone_to_bronze(k):v for (k, v) in model1.state_dict().items()}

In [9]:
model2.load_state_dict(model1_state_dict, strict=True)
model1.eval()
model2.eval()

BronzeAgeGNN(
  (input): BronzeAgeLayer(
    (f): GlobalConceptReasonerModule(
      (concept_reasoner): GlobalConceptReasoningLayer(
        (filter_nn): Embedding(3, 3)
        (sign_nn): Embedding(3, 3)
      )
    )
    (non_linearity): Identity()
    (eval_non_linearity): Identity()
  )
  (output): BronzeAgeLayer(
    (f): GlobalConceptReasonerModule(
      (concept_reasoner): GlobalConceptReasoningLayer(
        (filter_nn): Embedding(6, 2)
        (sign_nn): Embedding(6, 2)
      )
    )
    (non_linearity): Identity()
    (eval_non_linearity): Identity()
  )
  (stone_age): ModuleList(
    (0): BronzeAgeGNNLayer(3, 3)
  )
)

In [10]:
model1(dataset[0].x, dataset[0].edge_index)

tensor([[0.6704, 0.4177],
        [0.6704, 0.4177],
        [0.6704, 0.4177],
        ...,
        [0.6704, 0.4177],
        [0.6704, 0.4177],
        [0.6704, 0.4177]], grad_fn=<SqueezeBackward1>)

In [11]:
model1.stone_age[0](inp, dataset[0].edge_index)

NameError: name 'inp' is not defined

In [None]:
for key in model1.stone_age[0].state_dict():
    assert torch.allclose(model1.stone_age[0].state_dict()[key], model2.stone_age[0].state_dict()[key_mapping_stone_to_bronze(f".{key}")[1:]]) 

In [None]:
model1.stone_age[0]

In [None]:
model2.stone_age[0]

In [None]:
list(model2.stone_age[0].named_parameters())

In [None]:
model2.stone_age[0](inp, dataset[0].edge_index)


In [None]:
inp = model1.input(dataset[0].x)
model1.stone_age[0](inp, dataset[0].edge_index) - model2.stone_age[0](inp, dataset[0].edge_index)[0]

In [None]:
model2.stone_age[0]

In [None]:
model1.stone_age[0]

In [None]:
model1(dataset[0].x, dataset[0].edge_index).mean(axis=-2)

In [None]:
model2(dataset[0].x, dataset[0].edge_index)

In [24]:
import numpy as np

def linear_combo_features(input_data, state_size):
    """Calculates the pairwise differences between the features and appends them to the input data"""

    difference_features = (
        input_data[np.newaxis, :, :state_size, None]
        > input_data[:, np.newaxis, :state_size]
    )
    difference_features = (
        input_data[:, :state_size, None]
        > input_data[:, np.newaxis, :state_size]
    )
    print(difference_features)
    difference_features = difference_features.reshape(len(input_data), -1).astype(int)
    return np.concatenate((input_data, difference_features), axis=1)

states = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
state_size = 3
neighbors = np.array([[2, 1, 0], [2, 2, 1], [0, 1, 1]])
states = np.concatenate((neighbors, states), axis=1)
linear_combo_features(states, state_size)[:, 2*state_size:]


[[[False  True  True]
  [False False  True]
  [False False False]]

 [[False False  True]
  [False False  True]
  [False False False]]

 [[False False False]
  [ True False False]
  [ True False False]]]


array([[0, 1, 1, 0, 0, 1, 0, 0, 0],
       [0, 0, 1, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 1, 0, 0]])

In [25]:
neighbors[..., None, :] < neighbors[..., :, None]

array([[[False, False, False],
        [ True, False, False],
        [ True,  True, False]],

       [[False, False, False],
        [False, False, False],
        [ True,  True, False]],

       [[False,  True,  True],
        [False, False, False],
        [False, False, False]]])