In [1]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import collections
import csv
import random
import pickle
from torch.utils.data import DataLoader
import dgl
import pandas as pd

class Subgraphs(Dataset):
    """Dataset of pre-extracted subgraphs selected by a split CSV.

    The split file ``<root>/<mode>.csv`` must contain a ``name`` column. Each
    value in that column is used as the key (or index) into the three lookup
    containers passed to the constructor.

    Args:
        root: Directory that contains the split CSV files.
        mode: Split name; the file read is ``mode + '.csv'``.
        subgraph_list: Container mapping a subgraph name to its graph object.
        subgraph2label: Container mapping a subgraph name to its label.
        subgraph2center_node: Container mapping a subgraph name to the index
            of its center node.
    """

    def __init__(self, root, mode, subgraph_list, subgraph2label, subgraph2center_node):
        self.subgraph2label = subgraph2label
        self.subgraph_list = subgraph_list
        self.subgraph2center_node = subgraph2center_node

        self.data = pd.read_csv(os.path.join(root, mode + '.csv'))  # csv path

    def __getitem__(self, index):
        """Return ``(graph, label, center_node)`` for the ``index``-th CSV row.

        Raises:
            IndexError: If ``index`` is outside the rows of the split CSV.
        """
        # Select the column first and look the name up once. Going through
        # ``self.data.iloc[index]`` builds a whole row Series, which upcasts an
        # integer name to float whenever the CSV also has a float column, and
        # the original code paid for that row construction three times.
        name = self.data['name'].iloc[index]
        return (
            self.subgraph_list[name],
            self.subgraph2label[name],
            self.subgraph2center_node[name],
        )

    def __len__(self):
        """Number of subgraphs in this split (one per CSV row)."""
        return len(self.data)
    
def collate(samples):
    """Merge a list of dataset items into one mini-batch.

    Args:
        samples: List of ``(graph, label, center_node)`` triples, as returned
            by ``Subgraphs.__getitem__``.

    Returns:
        A tuple ``(batched_graph, labels, center_nodes)`` where the graphs are
        merged with ``dgl.batch`` and the other two fields are ``LongTensor``s
        in the same order as ``samples``.
    """
    # Transpose the list of triples into three parallel lists.
    graphs, labels, center_nodes = (list(column) for column in zip(*samples))
    label_tensor = torch.LongTensor(labels)
    center_tensor = torch.LongTensor(center_nodes)
    return dgl.batch(graphs), label_tensor, center_tensor


In [2]:
import dgl.function as fn
import torch
import torch.nn as nn


# Message function used by GCN.forward: every edge sends its source node's
# feature 'h' to the destination node, where it arrives in the mailbox under
# the key 'm' (read back by `reduce` below).
# NOTE(review): `copy_src` appears to be deprecated in newer DGL releases in
# favour of `copy_u` - confirm against the installed DGL version.
msg = fn.copy_src(src='h', out='m')

def reduce(nodes):
    """Replace each node's feature with the mean of its incoming messages.

    Reads the messages stored under ``'m'`` in ``nodes.mailbox`` and averages
    them over dimension 1, returning the result as the new ``'h'`` feature.
    """
    incoming = nodes.mailbox['m']
    return {'h': incoming.mean(dim=1)}

class NodeApplyModule(nn.Module):
    """Node-wise update: ``h <- activation(W h + b)``.

    Args:
        in_feats: Size of the incoming node feature.
        out_feats: Size of the produced node feature.
        activation: Callable applied to the output of the linear layer.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Transform ``node.data['h']`` and return it as the new ``'h'``."""
        projected = self.linear(node.data['h'])
        return {'h': self.activation(projected)}

class GCN(nn.Module):
    """One graph-convolution layer: neighbour averaging followed by a
    learned node-wise transform (``NodeApplyModule``).

    Args:
        in_feats: Size of the input node feature.
        out_feats: Size of the output node feature.
        activation: Callable forwarded to ``NodeApplyModule``.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Run the layer on graph ``g`` and return the updated node features.

        ``feature`` is written to ``g.ndata['h']``, propagated with the
        module-level ``msg`` / ``reduce`` pair, transformed node-wise, and
        finally removed from the graph again before being returned.
        """
        g.ndata['h'] = feature              # seed node state
        g.update_all(msg, reduce)           # aggregate over incoming edges
        g.apply_nodes(func=self.apply_mod)  # learned per-node update
        updated = g.ndata.pop('h')
        return updated

In [11]:
import torch.nn.functional as F


class Classifier(nn.Module):
    """Two stacked GCN layers followed by a linear read-out on one selected
    node per graph of a batched graph.

    Args:
        in_dim: Size of the input node feature. ``forward`` feeds the node
            in-degree as a single scalar feature, so this should be 1.
        hidden_dim: Width of both GCN layers.
        n_classes: Number of output logits.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, to_fetch):
        """Return class logits for the selected node of every graph in ``g``.

        Args:
            g: Batched DGL graph.
            to_fetch: One node index per graph in the batch, local to that
                graph (i.e. before batching shifted the node ids).

        Returns:
            Tensor with one row of ``n_classes`` logits per graph.
        """
        # Follow the model's own parameters instead of a notebook-global
        # `device` or a hard-coded `.cuda()`, so the module also works on
        # CPU-only machines and does not depend on cell execution order.
        run_device = self.classify.weight.device

        # For undirected graphs, in_degree is the same as out_degree.
        h = g.in_degrees().view(-1, 1).float().to(run_device)
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h

        # `batch_num_nodes` is a list attribute in older DGL releases and a
        # method returning a tensor in newer ones; accept both.
        counts = g.batch_num_nodes
        if callable(counts):
            counts = counts()
        # Build a fresh tensor: the original code inserted a leading 0 into
        # the list owned by the graph, corrupting its batch metadata.
        counts = torch.as_tensor(counts, dtype=torch.long, device=run_device)

        # Exclusive prefix sum = id of the first node of each graph in the batch.
        offset = torch.cumsum(counts, dim=0) - counts

        to_fetch = torch.as_tensor(to_fetch, dtype=torch.long, device=run_device)
        hg = h[to_fetch + offset]
        return self.classify(hg)

In [4]:
from torch.utils.data import DataLoader
import torch.optim as optim
dataset = 'cora'
path = '/n/scratch2/kexinhuang/MGNN_Data/cora/fold1/'


def _unpickle(directory, file_name):
    """Load one pickled object stored as ``directory + file_name``.

    NOTE(review): ``pickle.load`` executes arbitrary code embedded in the
    file - only point this at files you produced yourself.
    """
    with open(directory + file_name, 'rb') as handle:
        return pickle.load(handle)


# Lookup containers shared by every split (train / val / test).
total_subgraph = _unpickle(path, 'list_subgraph.pkl')
info = _unpickle(path, 'label.pkl')
center_node = _unpickle(path, 'center.pkl')

# Training split, served in shuffled mini-batches of 64 subgraphs.
trainset = Subgraphs(path, 'train', total_subgraph, info, center_node)
data_loader = DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate,
)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# Create model: 1 input feature (node in-degree), 64 hidden units, 10 logits.
model = Classifier(1, 64, 10)
model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
model.train()

# Mean training loss of every epoch, in order.
epoch_losses = []
for epoch in range(500):
    epoch_loss = 0
    num_batches = 0
    for bg, label, to_fetch in data_loader:
        # Every input of the model has to live on the same device; the
        # center-node indices were previously left on the CPU.
        bg = bg.to(device)
        label = label.to(device)
        to_fetch = to_fetch.to(device)

        prediction = model(bg, to_fetch)
        loss = loss_func(prediction, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
    # Count batches explicitly instead of reusing a loop variable named
    # `iter` (which shadowed the builtin and is undefined for an empty loader).
    epoch_loss /= max(num_batches, 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

Epoch 0, loss 2.0257
Epoch 1, loss 1.7696
Epoch 2, loss 1.7480
Epoch 3, loss 1.7450
Epoch 4, loss 1.7273
Epoch 5, loss 1.7328
Epoch 6, loss 1.7157
Epoch 7, loss 1.7156
Epoch 8, loss 1.7227
Epoch 9, loss 1.7131
Epoch 10, loss 1.7069
Epoch 11, loss 1.7088
Epoch 12, loss 1.7095
Epoch 13, loss 1.7066
Epoch 14, loss 1.7060
Epoch 15, loss 1.6995
Epoch 16, loss 1.6950
Epoch 17, loss 1.6975
Epoch 18, loss 1.6972
Epoch 19, loss 1.6914
Epoch 20, loss 1.6876
Epoch 21, loss 1.6892
Epoch 22, loss 1.6886
Epoch 23, loss 1.7038
Epoch 24, loss 1.6942
Epoch 25, loss 1.6867
Epoch 26, loss 1.6941
Epoch 27, loss 1.6844
Epoch 28, loss 1.6911
Epoch 29, loss 1.6871
Epoch 30, loss 1.6835
Epoch 31, loss 1.6933
Epoch 32, loss 1.6849
Epoch 33, loss 1.6948
Epoch 34, loss 1.6792
Epoch 35, loss 1.6855
Epoch 36, loss 1.6776
Epoch 37, loss 1.6718
Epoch 38, loss 1.6829
Epoch 39, loss 1.6838
Epoch 40, loss 1.6776
Epoch 41, loss 1.6736
Epoch 42, loss 1.6776
Epoch 43, loss 1.6712
Epoch 44, loss 1.6759
Epoch 45, loss 1.685

Epoch 361, loss 1.5906
Epoch 362, loss 1.5877
Epoch 363, loss 1.5903
Epoch 364, loss 1.5856
Epoch 365, loss 1.5818
Epoch 366, loss 1.5985
Epoch 367, loss 1.5947
Epoch 368, loss 1.5898
Epoch 369, loss 1.5833
Epoch 370, loss 1.5833
Epoch 371, loss 1.5830
Epoch 372, loss 1.5815
Epoch 373, loss 1.5885
Epoch 374, loss 1.5958
Epoch 375, loss 1.5817
Epoch 376, loss 1.5878
Epoch 377, loss 1.5854
Epoch 378, loss 1.5900
Epoch 379, loss 1.5825
Epoch 380, loss 1.5774
Epoch 381, loss 1.5831
Epoch 382, loss 1.5865
Epoch 383, loss 1.5876
Epoch 384, loss 1.5777
Epoch 385, loss 1.5806
Epoch 386, loss 1.5875
Epoch 387, loss 1.5938
Epoch 388, loss 1.5868
Epoch 389, loss 1.5810
Epoch 390, loss 1.5766
Epoch 391, loss 1.5856
Epoch 392, loss 1.5822
Epoch 393, loss 1.5854
Epoch 394, loss 1.5817
Epoch 395, loss 1.5800
Epoch 396, loss 1.5781
Epoch 397, loss 1.5805
Epoch 398, loss 1.5783
Epoch 399, loss 1.5887
Epoch 400, loss 1.5778
Epoch 401, loss 1.5808
Epoch 402, loss 1.5814
Epoch 403, loss 1.5794
Epoch 404, 

In [8]:
# Persist only the learned parameters (the state_dict, not the whole module)
# to 'model.pt', resolved relative to the current working directory.
torch.save(model.state_dict(), 'model.pt')

In [12]:
# Rebuild the network with the same sizes used for training, then restore the
# saved parameters. `map_location='cpu'` makes a checkpoint written from a
# CUDA model loadable on a machine without a GPU; the model is moved to the
# target device afterwards with `model.to(device)`.
model = Classifier(1, 64, 10)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

<All keys matched successfully>

In [13]:
path = '/n/scratch2/kexinhuang/MGNN_Data/cora/fold1/'
valset = Subgraphs(path, 'val', total_subgraph, info, center_node)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Convert a list of (graph, label, center_node) triples to three lists.
test_X, test_Y, center_nodes = map(list, zip(*valset))

# `.to(device)` returns the moved object, so its result has to be kept (the
# training loop already does this); the previous bare `test_bg.to(device)` and
# `test_Y.to(device)` calls threw the result away.
test_bg = dgl.batch(test_X)
test_bg = test_bg.to(device)
center_nodes = torch.LongTensor(center_nodes).to(device)

# Labels stay on the CPU as a float column vector; predictions are brought
# back to the CPU for the comparison below.
test_Y = torch.tensor(test_Y).float().view(-1, 1)

# Inference only: skip building the autograd graph for the whole split.
with torch.no_grad():
    probs_Y = torch.softmax(model(test_bg, center_nodes), 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)

accuracy = (test_Y == argmax_Y.cpu().float()).sum().item() / len(test_Y) * 100
# This cell evaluates the *validation* split, so the message says so; the
# format is `{:.4f}` (4 decimals) to match the training log, not `{:4f}`.
print('Accuracy of argmax predictions on the validation set: {:.4f}%'.format(accuracy))

Accuracy of argmax predictions on the test set: 39.114391%
