# Experiment 1

Train a graph-level GGNN (`GraphLevelGGNN`) to classify whether a heap-like graph is actually a heap.

In [1]:
# Move from the notebook's directory to the repository root so that the
# `experiments` and `ggnns` packages are importable. Guarded so that
# re-running this cell does not keep climbing the directory tree
# (a bare `%cd ..` moves up one more level on every run).
import os

if not os.path.isdir("experiments"):
    os.chdir("..")

import torch
from torch import nn

import wandb
from experiments.heaps import make_heap_test_gnn_datapoints

/home/tekne/Oxford/ATML/Project/ggs-nn-model


In [2]:
import random

import numpy as np
import torch

# All hyperparameters live on the W&B config so each run logs what it used.
# NOTE(review): recent wandb releases may raise when wandb.config is written
# before wandb.init() has been called — confirm the installed version, or
# call wandb.init(...) before this cell.
config = wandb.config
config.n_train = 1000    # number of graphs in the training split
config.n_test = 1000     # number of graphs in the test split
config.epochs = 200
config.lr = 0.01
config.optimizer = 'adam'
config.min_len = 1       # forwarded as `min_len` to make_heap_test_gnn_datapoints
config.max_len = 32      # forwarded as `max_len` to make_heap_test_gnn_datapoints
config.batch_size = 500
config.p_heap = 0.5      # forwarded as `p_heap` (presumably the fraction of true heaps — confirm)
config.seed = 0

# Seed every RNG the data generator and model initialisation might draw from,
# so that Restart & Run All reproduces the same dataset and initial weights.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

In [3]:
try:
    from torch_geometric.loader import DataLoader  # PyG >= 2.0
except ImportError:
    from torch_geometric.data import DataLoader  # older PyG releases

# Generate the train and test graphs in a single call, then split by position.
dataset = make_heap_test_gnn_datapoints(
    n=config.n_train + config.n_test,
    p_heap=config.p_heap,
    min_len=config.min_len,
    max_len=config.max_len,
)
assert len(dataset) == config.n_train + config.n_test, "unexpected number of generated graphs"

# Shuffle only the training split so that batch composition changes between
# epochs. The test split keeps a fixed order so evaluation is repeatable.
training_data = DataLoader(
    dataset[:config.n_train], batch_size=config.batch_size, shuffle=True
)
testing_data = DataLoader(dataset[config.n_train:], batch_size=config.batch_size)

In [4]:
# Print a summary of every batch in each split. The loop variable is `batch`
# (not `data`) so that this cell does not overwrite any dataset variable
# defined earlier in the notebook.
for split_name, loader in (("Training", training_data), ("Test", testing_data)):
    print(f"{split_name} batches:")
    for batch in loader:
        print(batch)

Training batches:
Batch(batch=[8069], edge_index=[2, 7569], x=[8069, 1], y=[500])
Batch(batch=[8237], edge_index=[2, 7737], x=[8237, 1], y=[500])
Test batches:
Batch(batch=[7828], edge_index=[2, 7328], x=[7828, 1], y=[500])
Batch(batch=[8333], edge_index=[2, 7833], x=[8333, 1], y=[500])


In [5]:
from ggnns.graph_level_ggnn import GraphLevelGGNN

NUM_CLASSES = 2      # two output logits per graph; the batches above carry one label `y` per graph
ANNOTATION_SIZE = 1  # node-feature width; the batches above show x=[num_nodes, 1]

config.num_layers = 3
config.hidden_size = 20

# The gate network and the final layer both take inputs of width
# 2 * ANNOTATION_SIZE + hidden_size (the original `2 * 1 + hidden_size`).
# NOTE(review): confirm this width against GraphLevelGGNN's readout.
readout_width = 2 * ANNOTATION_SIZE + config.hidden_size

# NOTE(review): `.cuda()` hard-requires a GPU. train() presumably moves
# batches to the same device — confirm before running on CPU.
model = GraphLevelGGNN(
    annotation_size=ANNOTATION_SIZE,
    num_layers=config.num_layers,
    gate_nn=nn.Linear(readout_width, 1),
    hidden_size=config.hidden_size,
    final_layer=nn.Linear(readout_width, NUM_CLASSES),
).cuda()

In [6]:
# Build the optimizer named by config.optimizer, so that the value logged to
# W&B is the optimizer that actually runs.
OPTIMIZERS = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD,
}
if config.optimizer not in OPTIMIZERS:
    raise ValueError(
        f"Unknown optimizer {config.optimizer!r}; expected one of {sorted(OPTIMIZERS)}"
    )
opt = OPTIMIZERS[config.optimizer](model.parameters(), lr=config.lr)

In [7]:
from experiments.utils import train

loss = torch.nn.CrossEntropyLoss()


def criterion(out, y):
    """Cross-entropy between the model's per-graph logits `out` and labels `y`."""
    return loss(out, y)


def checker(out, y):
    """Return (number of argmax predictions equal to `y`, number of labels in `y`)."""
    predicted = torch.argmax(out, dim=-1)
    n_correct = (predicted == y).sum()
    return n_correct, y.shape[0]


# NOTE(review): the recorded progress bar below shows 1/1 iterations even
# though epochs=200 — confirm that train() actually loops over `epochs`.
results = train(
    model=model,
    opt=opt,
    training_data=training_data,
    testing_data=testing_data,
    criterion=criterion,
    checker=checker,
    epochs=config.epochs,
)

100%|██████████| 1/1 [00:00<00:00,  3.47it/s]
