In [1]:
import dgl
from dgl.data import BAShapeDataset
import torch
import torch_geometric
import torch_geometric.transforms as T

In [2]:
import argparse
import time
import torch
import easydict
import torch_geometric
import random
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
import networkx as nx
import matplotlib.pyplot as plt

In [3]:
# This code works in Python 3.10.6
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.datasets.dblp import DBLP
from torch_geometric.nn import GCNConv
import time
from torch_geometric.logging import log
import os
from collections import Counter

In [4]:
# Load DGL's BA-Shapes dataset (the output below shows it being read from cached files).
dataset = BAShapeDataset()

Done loading data from cached files.


In [5]:
# Take the graph stored at index 0 of the dataset.
g = dataset[0]

In [6]:
# Show the graph summary: node/edge counts plus the node- and edge-data schemes.
g

Graph(num_nodes=700, num_edges=2055,
      ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'__orig__': Scheme(shape=(), dtype=torch.int64)})

In [7]:
# Per-node class labels stored in the graph's node data; there is one entry per node.
label = g.ndata['label']
len(label)

700

In [8]:
# Number of nodes in the graph (700 for this dataset, see the output below).
n_nodes = g.number_of_nodes()
n_nodes

700

In [9]:
# Download BA_shapes.pkl from the D4Explainer dataset (https://github.com/Graph-and-Geometric-Learning/D4Explainer) and place it in the working directory. The file is required for the train/val/test splits.

In [10]:
import pickle
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only open a
# BA_shapes.pkl that was obtained from a source you trust.
# The file unpacks into nine objects; the cells visible in this notebook only use the
# three split masks (train_mask, val_mask, test_mask).
with open('BA_shapes.pkl', 'rb') as fin:
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pickle.load(fin)

In [11]:
# Attach the train/val/test split masks to the DGL graph as node data, so they
# travel with the graph when it is converted later on.
for mask_name, mask_values in (
    ("train_mask", train_mask),
    ("val_mask", val_mask),
    ("test_mask", test_mask),
):
    g.ndata[mask_name] = torch.tensor(mask_values)

In [12]:
# Re-display the graph to confirm the three split masks are now part of its node data.
g

Graph(num_nodes=700, num_edges=2055,
      ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'__orig__': Scheme(shape=(), dtype=torch.int64)})

In [13]:
# Convert the DGL graph into a PyTorch Geometric Data object. The node/edge data keep
# their DGL names (feat, label, the masks, __orig__), as the repr below shows.
data = torch_geometric.utils.from_dgl(g)
data

Data(edge_index=[2, 2055], feat=[700, 1], label=[700], train_mask=[700], val_mask=[700], test_mask=[700], __orig__=[2055])

In [14]:
#https://stackoverflow.com/questions/4406501/change-the-name-of-a-key-in-dictionary
# Rename the attributes to the names the training code in this notebook reads:
# 'feat' -> x (node features) and 'label' -> y (node labels).
data.x = data.pop('feat')
data.y = data.pop('label')

In [15]:
# Remove the '__orig__' edge attribute carried over from the DGL graph; the popped
# tensor is what gets displayed below.
data.pop('__orig__')

tensor([2008, 1840, 1812,  ..., 1542, 1538, 1543])

In [16]:
# Overwrite the 1-dimensional node features with the same 4-dimensional vector
# [1, 0, 0, 0] for every node (one row per existing node).
x = torch.tensor([1.0, 0.0, 0.0, 0.0])
data.x = x.repeat(data.x.size(0), 1)

In [17]:
# Inspect the replaced node features: every row should be [1, 0, 0, 0].
data.x

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        ...,
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [18]:
# Sanity check: one 4-dimensional feature row per node.
data.x.shape

torch.Size([700, 4])

In [19]:
# Final PyG Data object used for training: edge_index, x, y and the three split masks.
data

Data(edge_index=[2, 2055], train_mask=[700], val_mask=[700], test_mask=[700], x=[700, 4], y=[700])

In [20]:
# NOTE(review): `parser` is created but not used by any cell shown in this notebook.
parser = argparse.ArgumentParser()
# Run configuration with attribute-style access (args.dataset, args.epochs).
# The commented-out keys are disabled settings.
args = easydict.EasyDict({
    "dataset": 'BAShapes',
    #"batch_size": 128,
    #"hidden_channels": 64,
    #"lr": 0.0005,
    "epochs": 2000,
})

In [21]:
# Device the model is moved to; everything in this notebook runs on the CPU.
device = 'cpu'

In [22]:
class GCN(torch.nn.Module):
    """Three-layer GCN that returns one score vector per node.

    Dropout is applied to the input of every convolution and a ReLU follows
    the first two convolutions. The output of the third convolution is
    returned unchanged (no softmax), which is what ``F.cross_entropy`` expects.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first two convolutions.
        out_channels: Output size of the last convolution.
        dropout: Dropout probability used before each convolution.
            Defaults to 0.3.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.3):
        super().__init__()
        # Plain float attribute: it is not a parameter or buffer, so the
        # state_dict keys stay conv1.*, conv2.*, conv3.*.
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        """Compute per-node output scores.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every convolution.
            edge_weight: Optional edge weights passed to every convolution.

        Returns:
            The output of the last convolution, one row per node.
        """
        # Hidden layers: dropout -> convolution -> ReLU.
        for conv in (self.conv1, self.conv2):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, edge_index, edge_weight).relu()
        # Output layer: dropout -> convolution, no activation.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv3(x, edge_index, edge_weight)

device = 'cpu'

# Build the network: the input width follows the node feature matrix, with
# 32 hidden channels and 4 output channels.
model = GCN(in_channels=data.x.size(1), hidden_channels=32, out_channels=4).to(device)

# Adam with one parameter group per layer so that weight decay is applied
# to the first convolution only.
param_groups = [
    {'params': layer.parameters(), 'weight_decay': decay}
    for layer, decay in ((model.conv1, 5e-4), (model.conv2, 0), (model.conv3, 0))
]
optimizer = torch.optim.Adam(param_groups, lr=0.01)


def train():
    """Run one full-graph optimisation step and return the training loss as a float.

    The forward pass covers all nodes; the cross-entropy loss is computed on
    the nodes selected by ``data.train_mask`` only.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    train_nodes = data.train_mask
    loss = F.cross_entropy(logits[train_nodes], data.y[train_nodes])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Evaluate the model and return ``[train_acc, val_acc, test_acc]``.

    Predictions are the argmax over the model outputs; each accuracy is the
    fraction of correctly predicted nodes within the corresponding mask.
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    pred = logits.argmax(dim=-1)

    def _accuracy(mask):
        # Correct predictions inside the mask divided by the mask size.
        hits = int((pred[mask] == data.y[mask]).sum())
        return hits / int(mask.sum())

    return [_accuracy(mask) for mask in (data.train_mask, data.val_mask, data.test_mask)]


# Training loop with early stopping on validation accuracy.
# Every time the validation accuracy improves, the model weights are saved and
# the patience counter is reset; training stops once `start_patience`
# consecutive epochs pass without an improvement.
checkpoint_path = '../checkpoint/BA_shapes_gcn.pth'
best_val_acc = test_acc = 0
start_patience = patience = 100
times = []
for epoch in range(1, args.epochs + 1):  # epoch budget comes from the run configuration
    start = time.time()
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    improved = val_acc > best_val_acc
    if improved:
        # Report the test accuracy of the epoch with the best validation accuracy.
        test_acc = tmp_test_acc
    if epoch % 10 == 0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    # Timing covers the train + eval step only, not the checkpoint write below.
    times.append(time.time() - start)

    if improved:
        print('saving....')
        patience = start_patience
        best_val_acc = val_acc
        print('best acc is', best_val_acc)

        # Create the directory the checkpoint is actually written to, so that
        # torch.save does not fail when it is missing.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)
    else:
        patience -= 1

    if patience <= 0:
        print('Stopping training as validation accuracy did not improve '
              f'for {start_patience} epochs')
        break

print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')

saving....
best acc is 0.14285714285714285
saving....
best acc is 0.5428571428571428
Epoch: 010, Loss: 1.3050, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 020, Loss: 1.2379, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 030, Loss: 1.2111, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 040, Loss: 1.2198, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 050, Loss: 1.2029, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 060, Loss: 1.1347, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 070, Loss: 1.1344, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 080, Loss: 1.1384, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 090, Loss: 1.0370, Train: 0.4143, Val: 0.5429, Test: 0.4286
Epoch: 100, Loss: 0.9655, Train: 0.4143, Val: 0.5429, Test: 0.4286
Stopping training as validation accuracy did not improve for 100 epochs
Median time per epoch: 0.0028s
