In [1]:

import lightning
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from torch_geometric.data import DataLoader

from bronze_age.config import Config, NetworkType
from bronze_age.datasets import DatasetEnum, get_dataset
from bronze_age.models.stone_age import StoneAgeGNN as BronzeAgeGNN

from stone_age.models.stone_age_gnn import StoneAgeGNN

In [2]:
# Build the reference Stone Age GNN (model1) and its Bronze Age port (model2)
# from the same hyper-parameters, check they have the same number of trainable
# parameters, and copy model1's weights into model2 so later outputs can be
# compared directly.

alpha = 1.0
beta = 1.0
in_channels = 1  # node-feature width; dataset cell reports (1, 2) for REDDIT_BINARY (was 1433)
out_channels = 2  # number of classes (was 7)
bounding_parameter = 1000
state_size = 5
num_layers = 2
gumbel = True
softmax_temp = 1.0
network = 'mlp'  # must stay in sync with network=NetworkType.MLP in the Config below
use_pooling = True
skip_connection = True
use_batch_norm = True
hidden_units = 16
dropout = 0.0
batch_size = 32
seed = 0

torch.manual_seed(seed)
model1 = StoneAgeGNN(
    in_channels,
    out_channels,
    bounding_parameter,
    state_size,
    num_layers=num_layers,
    gumbel=gumbel,
    softmax_temp=softmax_temp,
    network=network,
    use_pooling=use_pooling,
    skip_connection=skip_connection,
    use_batch_norm=use_batch_norm,
    hidden_units=hidden_units,
    dropout=dropout,
)
config = Config(
    dataset=DatasetEnum.REDDIT_BINARY,
    data_dir='downloads',
    temperature=softmax_temp,
    alpha=alpha,
    beta=beta,
    dropout=dropout,
    use_batch_norm=use_batch_norm,
    network=NetworkType.MLP,
    hidden_units=hidden_units,
    state_size=state_size,
    num_layers=num_layers,
    skip_connection=skip_connection,
    use_pooling=use_pooling,
    bounding_parameter=bounding_parameter,
    batch_size=batch_size,
)
torch.manual_seed(seed)
model2 = BronzeAgeGNN(in_channels, out_channels, config)

num_trainable_params1 = sum(p.numel() for p in model1.parameters() if p.requires_grad)
num_trainable_params2 = sum(p.numel() for p in model2.parameters() if p.requires_grad)

print(num_trainable_params1, num_trainable_params2)
# The port is supposed to be architecturally identical, so fail loudly instead
# of relying on someone eyeballing the printed numbers.
assert num_trainable_params1 == num_trainable_params2, (
    f"trainable parameter counts differ: {num_trainable_params1} vs {num_trainable_params2}"
)

# Give model2 exactly model1's weights. load_state_dict is strict by default,
# so any missing/unexpected key raises here rather than silently diverging.
model2.load_state_dict(model1.state_dict())
model1.set_beta(beta)
# NOTE: neither model is switched to eval(), so both stay in their default
# training mode for the output comparison below.

918 918


In [7]:
# Module tree of the reference Stone Age model (its nn.Module repr).
print(str(model1))

StoneAgeGNN(
  (input): InputLayer(
    (lin1): Linear(in_features=1, out_features=5, bias=True)
  )
  (output): PoolingLayer(
    (lin2): MLP(
      (lins): ModuleList(
        (0): Linear(in_features=15, out_features=16, bias=True)
        (1): Linear(in_features=16, out_features=2, bias=True)
      )
      (bns): ModuleList(
        (0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (stone_age): ModuleList(
    (0-1): 2 x StoneAgeGNNLayer()
  )
)


In [None]:
# Module tree of the Bronze Age port, for side-by-side comparison with model1.
print(str(model2))

In [3]:
# Load the dataset selected by `config` and wrap it in a shuffled mini-batch
# loader using the batch size stored in that same config.
# NOTE(review): newer PyTorch Geometric releases expose this loader as
# torch_geometric.loader.DataLoader; the torch_geometric.data alias imported
# above is deprecated there — confirm against the installed version.
dataset = get_dataset(config)
data_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=config.batch_size,
)



In [4]:
# Sanity check: the channel counts hard-coded when building both models must
# match the loaded dataset, otherwise the comparison below runs on models of
# the wrong shape.
assert (dataset.num_node_features, dataset.num_classes) == (in_channels, out_channels), (
    f"models built for (in, out)={(in_channels, out_channels)} but dataset has "
    f"{(dataset.num_node_features, dataset.num_classes)}"
)
dataset.num_node_features, dataset.num_classes

(1, 2)

In [5]:
# Run every batch through both models, reseeding before each forward pass so
# any stochastic ops (e.g. the gumbel=True sampling) see the same random
# stream, and require bit-identical outputs.
num_batches_checked = 0
with torch.no_grad():  # only forward values are compared; no graph needed
    for batch in data_loader:
        torch.manual_seed(0)
        out1 = model1(batch.x, batch.edge_index, batch.batch)
        torch.manual_seed(0)
        out2 = model2(batch.x, batch.edge_index, batch.batch)
        # torch.equal also requires equal shapes; `(out1 == out2).all()` would
        # broadcast and can report a match for mismatched shapes.
        assert torch.equal(out1, out2), (
            f"batch {num_batches_checked}: outputs differ "
            f"(shapes {tuple(out1.shape)} vs {tuple(out2.shape)})"
        )
        num_batches_checked += 1

# Guard against a vacuous pass when the loader is empty.
assert num_batches_checked > 0, "data_loader yielded no batches; nothing was compared"

In [6]:
# Number of samples (graphs) in the loaded dataset.
num_graphs = len(dataset)
num_graphs

2000