In [1]:
from torch_geometric import datasets

# MUTAG graph-classification dataset from the TU collection, stored under ./data.
mutag = datasets.TUDataset(name='MUTAG', root='data')

In [2]:
from gnnexplain.nn.gcn import GraphGCN
from lightning import LightningModule

# Restore the trained GNN from its checkpoint, move it to CPU and switch it to
# evaluation mode. `load_from_checkpoint` does not call `.eval()` itself, so
# without it any dropout / batch-norm layers would stay in training mode and
# make the predictions used below non-deterministic.
model = GraphGCN.load_from_checkpoint('checkpoints/MUTAG-epoch=99.ckpt').cpu().eval()

# Sanity check: the restored object is a Lightning module.
isinstance(model, LightningModule)

True

In [3]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

# Number of leading graphs (in dataset order) that form the training split;
# every remaining graph goes to the test split.
N_TRAIN = 150

# A single un-shuffled batch that holds the whole dataset, so the split below
# is deterministic. `len(mutag)` replaces the former "large enough" constant.
loader = DataLoader(mutag, batch_size=len(mutag), shuffle=False)
batch = next(iter(loader)).cpu()

# Break the big batch back into individual graphs and re-batch each split.
graphs = batch.to_data_list()
train, test = Batch.from_data_list(graphs[:N_TRAIN]), Batch.from_data_list(graphs[N_TRAIN:])

In [None]:
import torch

# Fraction of training graphs whose arg-max class from the GNN matches the label.
# The forward pass is evaluation only, so skip building an autograd graph.
with torch.no_grad():
    train_pred = model(train).argmax(dim=1)
(train_pred == train.y).float().mean()

tensor([[-2.0703e+00, -1.3485e-01],
        [-2.1291e-01, -1.6515e+00],
        [-9.3181e-03, -4.6805e+00],
        [-4.7084e+00, -9.0601e-03],
        [-2.7363e-02, -3.6122e+00],
        [-3.2908e+00, -3.7935e-02],
        [-2.9723e-01, -1.3582e+00],
        [-7.2213e+00, -7.3108e-04],
        [-3.7045e-03, -5.6001e+00],
        [-5.4620e+00, -4.2542e-03],
        [-3.7743e-01, -1.1572e+00],
        [-3.7567e+00, -2.3637e-02],
        [-2.6330e+00, -7.4577e-02],
        [-7.4722e-03, -4.9003e+00],
        [-4.3226e+00, -1.3354e-02],
        [-3.9297e+00, -1.9844e-02],
        [-2.8610e-05, -1.0464e+01],
        [-7.1189e+00, -8.1006e-04],
        [-1.2061e-02, -4.4238e+00],
        [-8.5642e-01, -5.5283e-01],
        [-3.0665e+00, -4.7703e-02],
        [-3.3553e+00, -3.5522e-02],
        [-8.3312e-01, -5.7039e-01],
        [-2.0873e+00, -1.3241e-01],
        [-7.1533e+00, -7.8254e-04],
        [-5.5942e-02, -2.9113e+00],
        [-2.2095e+00, -1.1626e-01],
        [-2.1566e+00, -1.229

In [11]:
# Import the one name this cell needs instead of `import *`, so it is clear
# where `Optimizer` comes from and the notebook namespace is not polluted.
from gnnexplain.model.gtree import Optimizer
import optuna

# Only show optuna messages at WARNING level or above (hides per-trial logs).
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Hyper-parameter search for the explainer over 100 trials.
# NOTE(review): the exact meaning of `lmb` and `max_ccp_alpha` is defined by
# gnnexplain's Optimizer — confirm there before tuning these values.
opt = Optimizer(lmb=1e-4, max_ccp_alpha=1e-3, n_trials=100)

# Fit the explainer to the GNN `model` on the training split.
expl = opt.optimize(train, model)

Best trial: 0. Best value: 0.772367:   4%|▍         | 4/100 [00:00<00:13,  7.32it/s]

Best trial: 78. Best value: 0.8991: 100%|██████████| 100/100 [00:18<00:00,  5.33it/s] 


In [12]:
# Score reported by the explainer's `accuracy` on the training split.
train_accuracy = expl.accuracy(train)
train_accuracy

tensor(0.8733)

In [13]:
# Score reported by the explainer's `accuracy` on the held-out test split.
test_accuracy = expl.accuracy(test)
test_accuracy

tensor(0.8158)

In [14]:
%matplotlib inline
import matplotlib.pyplot as plt
fig, axs = plt.subplots(nrows = len(expl.layer) + 1, figsize=(10, 4 * len(expl.layer) + 4))
for k in range(len(expl.layer)):
    expl.layer[k].plot(axs[k], k)
expl.out_layer.plot(axs[-1], k+1)
plt.savefig(f'figures/MUTAG_{expl.accuracy(test):.0%}.png')
plt.close()