# Overview
This notebook provides a minimal, reproducible example of how to use our trained models for inference.
Here, we evaluate only the model from the *first* run, whereas the results reported in the main paper are averaged over all executed runs (e.g., 10 runs for `ogbg-molhiv`).

### Import libraries

In [1]:
import os
import numpy as np
import torch
import json

In [2]:
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.datasets import ZINC

In [3]:
from torch_geometric.data import DataLoader
from torch_geometric.transforms import RemoveIsolatedNodes

In [4]:
import sys
# Make modules one directory up importable — presumably where the `hypercomplex`
# package and the `train_*` scripts imported below live. The ".." is resolved
# relative to the kernel's current working directory, so this assumes the
# notebook is started from its own folder (NOTE(review): confirm).
sys.path.append("..")

In [5]:
from hypercomplex.undirectional.models import PHMSkipConnectAdd as UPH_SC_ADD

In [6]:
# OGB framework
from train_hiv import test_validate as test_validate_hiv
from train_pcba import test_validate as test_validate_pcba

In [7]:
# Customized
from train_zinc import test_validate as test_validate_zinc
from train_zinc import Evaluator as zinc_Evaluator

In [8]:
def print_model_infos(model):
    """Print a short summary of a trained PHM model.

    The summary covers the trainable-parameter count, the PHM dimension, and the
    message-passing / downstream layer widths. The sizes stored in
    ``model.mp_layers`` and ``model.downstream_layers`` are multiplied by
    ``model.phm_dim`` before printing.

    Args:
        model: trained model exposing ``phm_dim``, ``mp_layers``,
            ``downstream_layers`` and ``get_number_of_params_()``.

    Returns:
        None; the summary is written to stdout.
    """
    scale = model.phm_dim
    summary = (
        f"Model consists of {model.get_number_of_params_()} trainable parameters.",
        f"PHC-dim: {scale}",
        f"Message passing layers: {[units * scale for units in model.mp_layers]}.",
        f"Downstream layers: {[units * scale for units in model.downstream_layers]}.",
    )
    for line in summary:
        print(line)
    return None

In [9]:
# Use the first GPU when CUDA is available; otherwise fall back to CPU instead of
# pointing every `.to(DEVICE)` call at a CUDA device that does not exist.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
PRE_TRAFO = False # for dataset class (NOTE(review): not referenced in the cells of this notebook)

### `ogbg-molhiv`

In [10]:
# List the saved artifacts of the experiment (params.json, one folder per run, log).
# Sorted, because os.listdir returns entries in arbitrary, filesystem-dependent order.
sorted(os.listdir("hiv/experiment1/"))

['params.json',
 'run_1',
 'run_10',
 'run_2',
 'run_3',
 'run_4',
 'run_5',
 'run_6',
 'run_7',
 'run_8',
 'run_9',
 'run.log']

In [11]:
# Training configuration of the ogbg-molhiv experiment (batch size, layer sizes, ...).
with open("hiv/experiment1/params.json", mode="r") as params_fh:
    run_dict = json.load(params_fh)

In [12]:
# Display the training configuration loaded from params.json.
run_dict

{'device': 0,
 'nworkers': 0,
 'pin_memory': 'True',
 'batch_size': 128,
 'save_dir': 'hiv/experiment1',
 'n_runs': 10,
 'seed': 0,
 'pooling': 'softattention',
 'type': 'undirectional-phm-sc-add',
 'phm_dim': 4,
 'learn_phm': 'True',
 'unique_phm': 'False',
 'init': 'phm',
 'input_embed_dim': 200,
 'embed_combine': 'sum',
 'full_encoder': 'True',
 'mp_units': '200,200',
 'mp_norm': 'naive-batch-norm',
 'mlp_mp': 'True',
 'dropout_mpnn': '0.3,0.3',
 'same_dropout': 'False',
 'bias': 'True',
 'd_units': '128,32',
 'd_bn': 'naive-batch-norm',
 'dropout_dn': '0.3,0.1',
 'activation': 'relu',
 'aggr_msg': 'softmax',
 'aggr_node': 'softmax',
 'msg_scale': 'False',
 'real_trafo': 'linear',
 'epochs': 50,
 'lr': 0.001,
 'patience': 5,
 'factor': 0.75,
 'weightdecay': 0.1,
 'regularization': 2,
 'grad_clipping': 2.0,
 'log_weights': 'False',
 'msg_encoder': 'identity'}

### Load dataset and data loaders

In [13]:
dname = "ogbg-molhiv"
# OGB dataset (stored under ./dataset), its standard split indices and its official evaluator.
dataset = PygGraphPropPredDataset(name=dname, root="dataset", transform=None)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name=dname)
# Not attached to the dataset; passed to test_validate, which applies it.
transform = RemoveIsolatedNodes()

In [14]:
test_data = dataset[split_idx["test"]]
print("Test split sample size: ", len(test_data))
# Deterministic, complete pass over the test split: no shuffling, and the final
# partial batch is kept.
test_loader = DataLoader(
    test_data,
    batch_size=run_dict["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=False,
)

Test split sample size:  4113


In [15]:
# Load the full pickled model of the first run.
# map_location puts the tensors straight onto DEVICE, independent of the device
# the checkpoint was saved from.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
# On PyTorch >= 2.6, loading a full nn.Module additionally needs weights_only=False.
model = torch.load("hiv/experiment1/run_1/model.pt", map_location=DEVICE).to(DEVICE)
# Models trained with an older implementation lack the `sc_type` attribute (and
# their params.json has no "sc_type" key), so default it to "first".
model.sc_type = run_dict.get("sc_type", "first")
model = model.eval()
print_model_infos(model)

Model consists of 110909 trainable parameters.
PHC-dim: 4
Message passing layers: [200, 200].
Downstream layers: [128, 32].


In [16]:
# Evaluate the first-run model on the test split; returns a dict with the loss and ROC-AUC.
test_metrics_hiv = test_validate_hiv(model, DEVICE, transform, test_loader, evaluator)

In [17]:
# Test-set metrics of the first run only (the paper reports the average over all runs).
test_metrics_hiv

{'loss': 0.11866279581909034, 'rocauc': 0.7960370034183742}

### `ogbg-molpcba`


In [18]:
dname = "ogbg-molpcba"
# OGB dataset (stored under ./dataset), its standard split indices and its official evaluator.
dataset = PygGraphPropPredDataset(name=dname, root="dataset", transform=None)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name=dname)
# Not attached to the dataset; passed to test_validate, which applies it.
transform = RemoveIsolatedNodes()

In [19]:
# Training configuration of the ogbg-molpcba experiment.
with open("pcba/experiment1/params.json", mode="r") as params_fh:
    run_dict = json.load(params_fh)

In [20]:
test_data = dataset[split_idx["test"]]
print("Test split sample size: ", len(test_data))
# Deterministic, complete pass over the test split: no shuffling, and the final
# partial batch is kept.
test_loader = DataLoader(
    test_data,
    batch_size=run_dict["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=False,
)

Test split sample size:  43793


In [21]:
# Load the full pickled model of the first run.
# map_location puts the tensors straight onto DEVICE, independent of the device
# the checkpoint was saved from.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
# On PyTorch >= 2.6, loading a full nn.Module additionally needs weights_only=False.
model = torch.load("pcba/experiment1/run_1/model.pt", map_location=DEVICE).to(DEVICE)
# Models trained with an older implementation lack the `sc_type` attribute (and
# their params.json has no "sc_type" key), so default it to "first".
model.sc_type = run_dict.get("sc_type", "first")
model = model.eval()
print_model_infos(model)

Model consists of 1690328 trainable parameters.
PHC-dim: 2
Message passing layers: [512, 512, 512, 512, 512, 512, 512].
Downstream layers: [768, 256].


In [22]:
# Evaluate the first-run model on the test split; returns a dict with the loss and average precision.
test_metrics_pcba = test_validate_pcba(model, DEVICE, transform, test_loader, evaluator)

In [23]:
# Test-set metrics of the first run only (the paper reports the average over all runs).
test_metrics_pcba

{'loss': 0.04663223533899669, 'ap': 0.29484338917596925}

### `ZINC`


In [24]:
# Training configuration of the ZINC experiment.
with open("zinc/experiment1/params.json", mode="r") as params_fh:
    run_dict = json.load(params_fh)

In [25]:
path = "dataset/ZINC"  # root directory for the ZINC dataset, passed to ZINC(...) in the next cell

In [26]:
# Test split of the ZINC subset, plus the project's own evaluator for it.
test_data = ZINC(path, subset=True, split='test')
evaluator = zinc_Evaluator()

# Deterministic, complete pass over the test graphs.
test_loader = DataLoader(
    test_data,
    batch_size=run_dict["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=0,
)
print("Test split sample size: ", len(test_data))

Test split sample size:  1000


In [27]:
# Load the full pickled model of the first run.
# map_location puts the tensors straight onto DEVICE, independent of the device
# the checkpoint was saved from.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
# On PyTorch >= 2.6, loading a full nn.Module additionally needs weights_only=False.
model = torch.load("zinc/experiment1/run_1/model.pt", map_location=DEVICE).to(DEVICE)
# Models trained with an older implementation lack the `sc_type` attribute (and
# their params.json has no "sc_type" key), so default it to "first".
model.sc_type = run_dict.get("sc_type", "first")
model = model.eval()
print_model_infos(model)

Model consists of 106291 trainable parameters.
PHC-dim: 5
Message passing layers: [200, 200, 200, 200].
Downstream layers: [180, 80].


In [28]:
# Define the transform for the ZINC section explicitly. Previously this call
# silently reused the `transform` left over from the ogbg-molpcba cells, so the
# ZINC section could not be run without the pcba section. It is the same
# RemoveIsolatedNodes transform as before.
transform = RemoveIsolatedNodes()
# Evaluate the first-run model on the ZINC test split; returns a dict with the loss and MAE.
test_metrics_zinc = test_validate_zinc(model, DEVICE, transform, test_loader, evaluator)

In [29]:
# Test-set metrics of the first run only (the paper reports the average over all runs).
test_metrics_zinc

{'loss': 0.18928215515613556, 'mae': 0.1892821490764618}