In [1]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import collections
import csv
import random
import pickle
from torch.utils.data import DataLoader
import dgl
import pandas as pd

class Subgraphs(Dataset):
    """Dataset over the subgraphs listed in one split file.

    The split is read from ``<root>/<mode>.csv``. The ``name`` value of each
    row is used as the key into the three lookups given to the constructor.
    """

    def __init__(self, root, mode, subgraph_list, subgraph2label, subgraph2center_node):
        """Store the lookups and read the split table.

        Args:
            root: Directory that contains the split CSV files.
            mode: Split name; ``mode + '.csv'`` is the file that gets read.
            subgraph_list: Lookup from a row's ``name`` to its subgraph.
            subgraph2label: Lookup from a row's ``name`` to its label.
            subgraph2center_node: Lookup from a row's ``name`` to its center node.
        """
        self.subgraph_list = subgraph_list
        self.subgraph2label = subgraph2label
        self.subgraph2center_node = subgraph2center_node

        split_file = os.path.join(root, mode + '.csv')
        self.data = pd.read_csv(split_file)

    def __getitem__(self, index):
        """Return ``(subgraph, label, center_node)`` for the row at position ``index``."""
        name = self.data.iloc[index]['name']
        return (
            self.subgraph_list[name],
            self.subgraph2label[name],
            self.subgraph2center_node[name],
        )

    def __len__(self):
        """Return the number of rows in the split table."""
        return len(self.data)
    
def collate(samples):
    """Combine dataset samples into a single mini-batch.

    Args:
        samples: List of ``(graph, label, center_node)`` triples.

    Returns:
        A tuple ``(batched_graph, labels, center_nodes)``: the graphs merged
        by ``dgl.batch``, and the labels and center nodes each packed into a
        ``torch.LongTensor``.
    """
    graphs, labels, center_nodes = (list(column) for column in zip(*samples))
    return dgl.batch(graphs), torch.LongTensor(labels), torch.LongTensor(center_nodes)


In [2]:
import dgl.function as fn
import torch
import torch.nn as nn


# Built-in DGL message function: sends the source node's feature 'h' along
# each edge as message field 'm' (the field `reduce` reads from the mailbox).
msg = fn.copy_src(src='h', out='m')

def reduce(nodes):
    """Aggregate incoming messages by averaging them.

    Reads the ``'m'`` entry of ``nodes.mailbox``, takes its mean along
    dimension 1 and returns the result as the new ``'h'`` feature.
    """
    return {'h': nodes.mailbox['m'].mean(dim=1)}

class NodeApplyModule(nn.Module):
    """Node-wise update: a linear layer followed by the supplied activation."""

    def __init__(self, in_feats, out_feats, activation):
        """Create the linear map ``in_feats -> out_feats`` and keep ``activation``."""
        super().__init__()
        self.activation = activation
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, node):
        """Return ``{'h': activation(linear(node.data['h']))}``."""
        projected = self.linear(node.data['h'])
        return {'h': self.activation(projected)}

class GCN(nn.Module):
    """One message-passing layer built from `msg`, `reduce` and `NodeApplyModule`."""

    def __init__(self, in_feats, out_feats, activation):
        """Create the per-node update applied after aggregation."""
        super().__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Run one round of message passing on ``g``.

        ``feature`` is stored as node data ``'h'``, messages are sent and
        reduced, the node update is applied, and the resulting ``'h'`` is
        popped from ``g.ndata`` and returned.
        """
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        updated = g.ndata.pop('h')
        return updated

In [34]:
import torch.nn.functional as F


class Classifier(nn.Module):
    """Two stacked GCN layers followed by a linear classifier on one node per graph.

    The initial node feature is the node's in-degree. After the GCN layers,
    the hidden state of each graph's center node is selected from the batched
    graph and passed to the linear classification layer.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        """Build the layers.

        Args:
            in_dim: Size of the input node feature (1 for the in-degree feature).
            hidden_dim: Width of both GCN layers.
            n_classes: Number of output classes.
        """
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, to_fetch, features):
        """Classify the center node of every graph in the batch.

        Args:
            g: Batched graph.
            to_fetch: Per-graph index of the center node, local to each graph.
            features: Unused; kept so existing callers keep working.

        Returns:
            The output of the classification layer, one row per graph.
        """
        # Follow the model's own parameters instead of a module-level `device`
        # global, so the forward pass works wherever the weights currently live.
        device = self.classify.weight.device

        # For undirected graphs, in_degree is the same as out_degree.
        h = g.in_degrees().view(-1, 1).float().to(device)
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h

        # `batch_num_nodes` is a plain attribute in some DGL releases and a
        # method in others; accept both. Build a fresh tensor rather than
        # editing the object in place, which would corrupt the graph's own
        # bookkeeping.
        counts = g.batch_num_nodes
        if callable(counts):
            counts = counts()
        counts = torch.as_tensor(counts, dtype=torch.long, device=device)

        # Exclusive prefix sum: index of each graph's first node in the batch.
        offset = torch.cumsum(counts, dim=0) - counts

        # Callers may hand over CPU indices; align them with the hidden states.
        to_fetch = torch.as_tensor(to_fetch, dtype=torch.long, device=device)
        hg = h[to_fetch + offset]
        return self.classify(hg)

In [82]:
from torch.utils.data import DataLoader
import torch.optim as optim
path = '../data/synthetic_setting1_test/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _read_pickle(file_name):
    """Unpickle `file_name` from the data directory (only use on trusted files)."""
    with open(path + file_name, 'rb') as handle:
        return pickle.load(handle)


# Subgraphs, their labels and their center nodes, as stored by the data
# generation step.
total_subgraph = _read_pickle('list_subgraph.pkl')
info = _read_pickle('label.pkl')
center_node = _read_pickle('center.pkl')

# Training split, served in shuffled mini-batches of 64 subgraphs.
trainset = Subgraphs(path, 'train', total_subgraph, info, center_node)
data_loader = DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate,
)

In [None]:
# Create the model (1 input feature, 256 hidden units, 10 classes), the loss
# and the optimizer.
model = Classifier(1, 256, 10)
model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
model.train()

epoch_losses = []
for epoch in range(200):
    epoch_loss = 0
    num_batches = 0
    for bg, label, to_fetch in data_loader:
        bg = bg.to(device)
        label = label.to(device)
        # The center-node indices are added to on-device offsets inside the
        # model, so they have to live on the same device as the other inputs.
        to_fetch = to_fetch.to(device)
        features = bg.ndata['h']
        prediction = model(bg, to_fetch, features)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        num_batches += 1
    # Mean loss over the batches seen; guarded so an empty loader does not fail.
    epoch_loss /= max(num_batches, 1)
    if epoch % 25 == 0:
        print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

Epoch 0, loss 2.2608
Epoch 25, loss 0.6104
Epoch 50, loss 0.5129
Epoch 75, loss 0.4532
Epoch 100, loss 0.4324


In [None]:
path = '../data/synthetic_setting1_test/'
# NOTE: despite the variable name, this loads the 'test' split.
valset = Subgraphs(path, 'test', total_subgraph, info, center_node)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Convert the list of (graph, label, center_node) triples to three lists.
test_X, test_Y, center_nodes = map(list, zip(*valset))
# Keep the graph returned by `.to`, as the training cell does; discarding it
# leaves the graph on the CPU wherever `.to` is not in-place.
test_bg = dgl.batch(test_X).to(device)
# Labels stay on the CPU: predictions are moved back before the comparison.
test_Y = torch.tensor(test_Y).float().view(-1, 1)
center_nodes = torch.LongTensor(center_nodes).to(device)

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    logits = model(test_bg, center_nodes, test_bg.ndata['h'])
probs_Y = torch.softmax(logits, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.detach().cpu().float()).sum().item() / len(test_Y) * 100))

In [73]:
# Persist only the model's state dict (not the model object) to 'model.pt'.
torch.save(model.state_dict(), 'model.pt')

In [74]:
# Rebuild the architecture with the same sizes used for training, then restore
# the saved weights. The fresh model lives on the CPU, so map the checkpoint
# there; otherwise a checkpoint saved from a CUDA model cannot be loaded on a
# machine without a GPU.
model = Classifier(1, 256, 10)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

<All keys matched successfully>