In [1]:
import dgl
from dgl.data import TreeGridDataset
import torch
import torch_geometric
import torch_geometric.transforms as T

In [2]:
import argparse
import time
import easydict
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

In [3]:
import torch_geometric.utils
from torch_geometric.utils.convert import from_networkx
from torch_geometric.logging import log
import os
import pandas as pd
import glob
import pickle

In [4]:
# Load the Tree-Grid dataset through DGL and use its first entry (index 0)
# as the working graph `g` for the rest of the notebook.
dataset = TreeGridDataset()
g = dataset[0]

Done loading data from cached files.


In [5]:
# Download the file Tree_Grids.pkl from the dataset in https://github.com/Graph-and-Geometric-Learning/D4Explainer and place it in this notebook's working directory. Tree_Grids.pkl is required because it provides the train/val/test splits used below.

In [6]:
# Read the nine objects stored in Tree_Grids.pkl; only the split masks are
# used further down in this part of the notebook.
# NOTE(review): pickle.load can execute arbitrary code while deserialising --
# only open a Tree_Grids.pkl that was obtained from a trusted source.
with open('Tree_Grids.pkl', 'rb') as pkl_file:
    (
        adj,
        features,
        y_train,
        y_val,
        y_test,
        train_mask,
        val_mask,
        test_mask,
        edge_label_matrix,
    ) = pickle.load(pkl_file)

In [7]:
# Sanity check: length of `adj` along its first axis (1231 in the recorded
# output, which matches the number of rows of data.x shown further down).
len(adj)

1231

In [8]:
# Store the split masks loaded from Tree_Grids.pkl on the DGL graph as node
# data so that they are carried along when the graph is converted below.
# NOTE(review): assumes each mask has one entry per node of `g` -- confirm.
for mask_key, mask_values in (
    ("train_mask", train_mask),
    ("val_mask", val_mask),
    ("test_mask", test_mask),
):
    g.ndata[mask_key] = torch.tensor(mask_values)

In [9]:
# Convert the DGL graph into a PyG ``Data`` object and expose the node
# features / labels under the attribute names PyG code expects (x / y).
data = torch_geometric.utils.from_dgl(g)
data.x, data.y = data.feat, data.label

# Drop the DGL-named attributes now that x / y reference the same tensors.
for redundant_key in ('feat', '__orig__'):
    data.pop(redundant_key)
# Kept as the cell's last expression, so the popped label tensor is still
# echoed as the cell output.
data.pop('label')

tensor([1, 1, 1,  ..., 0, 0, 0])

In [10]:
# Replace the node features by the same constant vector [1., 0.] for every
# node: one copy of `x` per existing row of data.x.
x = torch.tensor([1.0, 0.0])
data.x = x.unsqueeze(0).repeat(data.x.size(0), 1)

In [11]:
# Inspect the replaced node features: every row should be the constant [1., 0.].
data.x

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        ...,
        [1., 0.],
        [1., 0.],
        [1., 0.]])

In [12]:
# Feature matrix shape: (number of nodes, number of features per node).
data.x.shape

torch.Size([1231, 2])

In [13]:
# Summary of the final PyG Data object (edge_index, split masks, x and y).
data

Data(edge_index=[2, 1705], train_mask=[1231], val_mask=[1231], test_mask=[1231], x=[1231, 2], y=[1231])

In [14]:
# Run configuration. `parser` is created but no arguments are registered on
# it in this cell; the settings themselves live in the attribute-style
# `args` dictionary.
parser = argparse.ArgumentParser()
args = easydict.EasyDict(
    dataset='TreeGrid',
    epochs=2000,
)

In [15]:
# Device on which the model is placed; everything in this notebook runs on the CPU.
device = 'cpu'

In [16]:
class GCN(torch.nn.Module):
    """Three-layer GCN that produces one score vector per node.

    Dropout is applied to the input of every convolution, and a ReLU follows
    the first two convolutions. The last convolution returns raw scores
    (no activation), which is what ``F.cross_entropy`` expects.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first and second convolution.
        out_channels: Output size of the last convolution.
        dropout: Probability passed to ``F.dropout`` before each convolution.
            Defaults to 0.3, so existing three-argument calls behave the same.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.3):
        super().__init__()
        self.dropout = dropout
        # Attribute names conv1/conv2/conv3 are part of the state_dict keys
        # and are referenced by the optimizer setup, so they must not change.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        """Return the per-node output of the last convolution.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every convolution.
            edge_weight: Optional edge weights passed to every convolution.
        """
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv3(x, edge_index, edge_weight)
        return x
# Seed PyTorch's RNG so that weight initialisation and the dropout masks drawn
# during training are repeatable under "Restart Kernel -> Run All".
torch.manual_seed(0)

# `device` is defined in the configuration cell above.
model = GCN(
    in_channels=data.x.shape[1],
    hidden_channels=32,
    out_channels=2,
).to(device)
# Keep the graph tensors on the same device as the model; this is a no-op on
# the CPU but avoids a device mismatch if `device` is ever changed.
data = data.to(device)

optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0),
    dict(params=model.conv3.parameters(), weight_decay=0)
], lr=0.01)  # Only perform weight-decay on first convolution.


def train():
    """Run one full-graph optimisation step and return the training loss.

    Uses the module-level ``model``, ``optimizer`` and ``data``. The
    cross-entropy loss is computed on the nodes selected by
    ``data.train_mask`` only.

    Returns:
        The loss value of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    mask = data.train_mask
    loss = F.cross_entropy(logits[mask], data.y[mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Evaluate the module-level ``model`` on the three node splits.

    Returns:
        A list ``[train_acc, val_acc, test_acc]``; each entry is the fraction
        of nodes in the corresponding mask whose arg-max prediction equals
        ``data.y``.
    """
    model.eval()
    predicted = model(data.x, data.edge_index).argmax(dim=-1)
    is_correct = predicted == data.y
    split_masks = (data.train_mask, data.val_mask, data.test_mask)
    return [
        int(is_correct[split].sum()) / int(split.sum())
        for split in split_masks
    ]


# Training loop with early stopping on validation accuracy.
#
# The model weights are checkpointed every time the validation accuracy
# improves; training stops once `start_patience` consecutive epochs pass
# without an improvement. `test_acc` always holds the test accuracy measured
# at the epoch with the best validation accuracy so far.
checkpoint_path = '../checkpoint/TreeGrids_gcn.pth'
# torch.save does not create missing parent directories, so make sure the
# directory that is actually written to exists before training starts.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

best_val_acc = test_acc = 0
start_patience = patience = 100
times = []
for epoch in range(1, args.epochs + 1):
    start = time.time()
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    improved = val_acc > best_val_acc
    if improved:
        test_acc = tmp_test_acc
    if epoch % 10 == 0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    # Recorded before checkpointing so disk I/O is not counted as epoch time.
    times.append(time.time() - start)

    if improved:
        print('saving....')
        patience = start_patience
        best_val_acc = val_acc
        print('best acc is', best_val_acc)
        torch.save(model.state_dict(), checkpoint_path)
    else:
        patience -= 1

    if patience <= 0:
        print('Stopping training as validation accuracy did not improve '
              f'for {start_patience} epochs')
        break
print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')

saving....
best acc is 0.6178861788617886
Epoch: 010, Loss: 0.6701, Train: 0.5752, Val: 0.6179, Test: 0.6290
Epoch: 020, Loss: 0.6497, Train: 0.5752, Val: 0.6179, Test: 0.6290
saving....
best acc is 0.6504065040650406
saving....
best acc is 0.7073170731707317
saving....
best acc is 0.7560975609756098
Epoch: 030, Loss: 0.6437, Train: 0.7398, Val: 0.7561, Test: 0.7097
Epoch: 040, Loss: 0.6372, Train: 0.5030, Val: 0.5122, Test: 0.7097
Epoch: 050, Loss: 0.6433, Train: 0.4299, Val: 0.3902, Test: 0.7097
Epoch: 060, Loss: 0.6441, Train: 0.4248, Val: 0.3821, Test: 0.7097
Epoch: 070, Loss: 0.6374, Train: 0.4695, Val: 0.4553, Test: 0.7097
Epoch: 080, Loss: 0.6281, Train: 0.4278, Val: 0.3821, Test: 0.7097
Epoch: 090, Loss: 0.6453, Train: 0.4868, Val: 0.4715, Test: 0.7097
Epoch: 100, Loss: 0.6414, Train: 0.4868, Val: 0.4715, Test: 0.7097
Epoch: 110, Loss: 0.6407, Train: 0.4837, Val: 0.4715, Test: 0.7097
Epoch: 120, Loss: 0.6275, Train: 0.4817, Val: 0.4715, Test: 0.7097
Stopping training as validat