In [57]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import collections
import csv
import random
import pickle
from torch.utils.data import DataLoader
import dgl
import pandas as pd

class Subgraphs(Dataset):


    def __init__(self, root, mode, subgraph_list, subgraph2label, subgraph2center_node, nodes_aux, labels_aux):   

        self.subgraph2label = subgraph2label
        self.subgraph_list = subgraph_list
        self.subgraph2center_node = subgraph2center_node
        self.subgraph2nodes_aux  = nodes_aux
        self.subgraph2labels_aux  = labels_aux
        
        self.data = pd.read_csv(os.path.join(root, mode + '.csv'))  # csv path
    
    def __getitem__(self, index):
        
        return self.subgraph_list[self.data.iloc[index]['name']], self.subgraph2label[self.data.iloc[index]['name']], self.subgraph2center_node[self.data.iloc[index]['name']], self.subgraph2nodes_aux[self.data.iloc[index]['name']], self.subgraph2labels_aux[self.data.iloc[index]['name']]

    def __len__(self):
        # as we have built up to batchsz of sets, you can sample some small batch size of sets.
        return len(self.data)
    
def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels, center_nodes, nodes_auxill, labels_auxill = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    #print(nodes_auxill)
    #nodes_auxill = np.vstack(nodes_auxill)
    labels_auxill = np.hstack(labels_auxill)    
    return batched_graph, torch.LongTensor(labels), torch.LongTensor(center_nodes), nodes_auxill, torch.LongTensor(labels_auxill)


In [39]:
import dgl.function as fn
import torch
import torch.nn as nn


# Sends a message of node feature h.
msg = fn.copy_src(src='h', out='m')

def reduce(nodes):
    """Take an average over all neighbor node features hu and use it to
    overwrite the original node feature."""
    accum = torch.mean(nodes.mailbox['m'], 1)
    return {'h': accum}

class NodeApplyModule(nn.Module):
    """Update the node feature hv with ReLU(Whv+b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Initialize the node features with h.
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

In [40]:
import torch.nn.functional as F


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, to_fetch_center, to_fetch_aux):
        # to_fetch_aux: [[idx1, .., idx_n1], [idx1, .., idx_n2], [idx1, .., idx_n3]]
        h = g.in_degrees().view(-1, 1).float().to(device)
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        #print(h.shape)
        #hg = dgl.mean_nodes(g, 'h')
        #print(to_fetch)
        num_nodes_ = g.batch_num_nodes
        num_nodes_.insert(0, 0)
        offset = torch.cumsum(torch.LongTensor(num_nodes_), dim = 0)[:-1].to(device)
        hg = h[to_fetch_center + offset]
        offset_aux = np.hstack([np.array([offset[i]] * len(ego)) + np.array(ego) for i, ego in enumerate(to_fetch_aux)])
        h_aux = h[offset_aux]
        #print(hg.shape)
        #print(hg.shape)
        #print(h[0].shape)
        #hg = h[g.nodes[0].data['center_node'].detach().numpy()[0]]
        #print(hg.shape)
        return self.classify(hg), self.classify(h_aux)

In [44]:
from torch.utils.data import DataLoader
import torch.optim as optim
dataset = 'cora'
path = '/n/scratch2/kexinhuang/MGNN_Data/cora/fold1/'
path = '../../MGNN_Local_Data/cora/'

with open(path + 'list_subgraph.pkl', 'rb') as f:
    total_subgraph = pickle.load(f)
    
with open(path + 'label.pkl', 'rb') as f:
    info = pickle.load(f)

with open(path + 'center.pkl', 'rb') as f:
    center_node = pickle.load(f)

with open(path + 'nodes_aux.pkl', 'rb') as f:
    nodes_aux = pickle.load(f)
    
with open(path + 'labels_aux.pkl', 'rb') as f:
    labels_aux = pickle.load(f)   
    
trainset = Subgraphs(path, 'train', total_subgraph, info, center_node, nodes_aux, labels_aux)
#valset = Subgraphs(path, mode='val', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
#testset = Subgraphs(path, mode='test', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
    
data_loader = DataLoader(trainset, batch_size=64, shuffle=True,
                         collate_fn=collate)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [46]:
# Create model
model = Classifier(1, 64, 10)
model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(500):
    epoch_loss = 0
    center_loss = 0
    aux_loss = 0
    for iter, (bg, label, to_fetch, to_fetch_aux, labels_aux) in enumerate(data_loader):
        bg = bg.to(device)
        label = label.to(device)
        labels_aux = labels_aux.to(device)
        prediction_center, prediction_aux = model(bg, to_fetch, to_fetch_aux)
        #print(prediction.shape)
        loss1 = loss_func(prediction_center, label)
        loss2 = loss_func(prediction_aux, labels_aux)
        loss = loss1 + loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        center_loss += loss1.detach().item()
        aux_loss += loss2.detach().item()

    epoch_loss /= (iter + 1)
    center_loss /= (iter + 1)
    aux_loss /= (iter + 1)
    
    print('Epoch {}, loss_center {:.4f}, loss_aux {:.4f}, loss {:.4f}'.format(epoch, center_loss, aux_loss, epoch_loss))
    epoch_losses.append(epoch_loss)

Epoch 0, loss_center 2.3450, loss_aux 2.5646, loss 4.9096
Epoch 1, loss_center 2.0903, loss_aux 1.7203, loss 3.8105
Epoch 2, loss_center 1.9727, loss_aux 1.6371, loss 3.6099
Epoch 3, loss_center 1.8926, loss_aux 1.5713, loss 3.4639
Epoch 4, loss_center 1.8414, loss_aux 1.5345, loss 3.3759
Epoch 5, loss_center 1.8060, loss_aux 1.5000, loss 3.3060
Epoch 6, loss_center 1.7845, loss_aux 1.4977, loss 3.2822
Epoch 7, loss_center 1.7796, loss_aux 1.4818, loss 3.2614
Epoch 8, loss_center 1.7692, loss_aux 1.4731, loss 3.2423
Epoch 9, loss_center 1.7681, loss_aux 1.4682, loss 3.2363
Epoch 10, loss_center 1.7621, loss_aux 1.4638, loss 3.2259
Epoch 11, loss_center 1.7590, loss_aux 1.4636, loss 3.2227
Epoch 12, loss_center 1.7583, loss_aux 1.4629, loss 3.2212
Epoch 13, loss_center 1.7583, loss_aux 1.4595, loss 3.2178
Epoch 14, loss_center 1.7541, loss_aux 1.4611, loss 3.2153
Epoch 15, loss_center 1.7579, loss_aux 1.4588, loss 3.2166
Epoch 16, loss_center 1.7569, loss_aux 1.4541, loss 3.2111
Epoch 1

Epoch 139, loss_center 1.7203, loss_aux 1.3478, loss 3.0681
Epoch 140, loss_center 1.7193, loss_aux 1.3584, loss 3.0777
Epoch 141, loss_center 1.7166, loss_aux 1.3539, loss 3.0705
Epoch 142, loss_center 1.7164, loss_aux 1.3500, loss 3.0664
Epoch 143, loss_center 1.7163, loss_aux 1.3537, loss 3.0700
Epoch 144, loss_center 1.7185, loss_aux 1.3524, loss 3.0708
Epoch 145, loss_center 1.7194, loss_aux 1.3490, loss 3.0684
Epoch 146, loss_center 1.7183, loss_aux 1.3499, loss 3.0681
Epoch 147, loss_center 1.7231, loss_aux 1.3578, loss 3.0810
Epoch 148, loss_center 1.7148, loss_aux 1.3642, loss 3.0790
Epoch 149, loss_center 1.7185, loss_aux 1.3463, loss 3.0648
Epoch 150, loss_center 1.7155, loss_aux 1.3416, loss 3.0570
Epoch 151, loss_center 1.7169, loss_aux 1.3380, loss 3.0549
Epoch 152, loss_center 1.7160, loss_aux 1.3534, loss 3.0694
Epoch 153, loss_center 1.7224, loss_aux 1.3590, loss 3.0815
Epoch 154, loss_center 1.7173, loss_aux 1.3685, loss 3.0858
Epoch 155, loss_center 1.7132, loss_aux 

Epoch 276, loss_center 1.7143, loss_aux 1.3323, loss 3.0466
Epoch 277, loss_center 1.7082, loss_aux 1.3324, loss 3.0406
Epoch 278, loss_center 1.7102, loss_aux 1.3269, loss 3.0371
Epoch 279, loss_center 1.7069, loss_aux 1.3234, loss 3.0303
Epoch 280, loss_center 1.7109, loss_aux 1.3336, loss 3.0444
Epoch 281, loss_center 1.7088, loss_aux 1.3194, loss 3.0283
Epoch 282, loss_center 1.7088, loss_aux 1.3206, loss 3.0294
Epoch 283, loss_center 1.7070, loss_aux 1.3262, loss 3.0332
Epoch 284, loss_center 1.7102, loss_aux 1.3190, loss 3.0293
Epoch 285, loss_center 1.7122, loss_aux 1.3282, loss 3.0404
Epoch 286, loss_center 1.7140, loss_aux 1.3454, loss 3.0594
Epoch 287, loss_center 1.7115, loss_aux 1.3123, loss 3.0237
Epoch 288, loss_center 1.7109, loss_aux 1.3226, loss 3.0335
Epoch 289, loss_center 1.7084, loss_aux 1.3129, loss 3.0213
Epoch 290, loss_center 1.7125, loss_aux 1.3318, loss 3.0443
Epoch 291, loss_center 1.7106, loss_aux 1.3266, loss 3.0372
Epoch 292, loss_center 1.7094, loss_aux 

Epoch 413, loss_center 1.7054, loss_aux 1.3123, loss 3.0177
Epoch 414, loss_center 1.7013, loss_aux 1.3034, loss 3.0047
Epoch 415, loss_center 1.7034, loss_aux 1.3072, loss 3.0106
Epoch 416, loss_center 1.7015, loss_aux 1.3053, loss 3.0068
Epoch 417, loss_center 1.7026, loss_aux 1.3002, loss 3.0027
Epoch 418, loss_center 1.7038, loss_aux 1.3036, loss 3.0074
Epoch 419, loss_center 1.7020, loss_aux 1.2992, loss 3.0012
Epoch 420, loss_center 1.7031, loss_aux 1.3113, loss 3.0144
Epoch 421, loss_center 1.7047, loss_aux 1.3036, loss 3.0083
Epoch 422, loss_center 1.7080, loss_aux 1.3287, loss 3.0367
Epoch 423, loss_center 1.7034, loss_aux 1.3153, loss 3.0187
Epoch 424, loss_center 1.7020, loss_aux 1.3029, loss 3.0049
Epoch 425, loss_center 1.7076, loss_aux 1.3075, loss 3.0151
Epoch 426, loss_center 1.7035, loss_aux 1.3036, loss 3.0070
Epoch 427, loss_center 1.7034, loss_aux 1.3224, loss 3.0258
Epoch 428, loss_center 1.7013, loss_aux 1.2932, loss 2.9945
Epoch 429, loss_center 1.7048, loss_aux 

In [47]:
torch.save(model.state_dict(), 'model.pt')

In [48]:
model = Classifier(1, 64, 10)
model.load_state_dict(torch.load('model.pt'))

<All keys matched successfully>

In [58]:
path = '/n/scratch2/kexinhuang/MGNN_Data/cora/fold1/'
path = '../../MGNN_Local_Data/cora/'

valset = Subgraphs(path, 'val', total_subgraph, info, center_node, nodes_aux, labels_aux)
#testset = Subgraphs(path, mode='test', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y, center_nodes, nodes_aux, labels_aux = map(list, zip(*valset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
test_bg.to(device)
test_Y.to(device)
center_nodes = torch.LongTensor(center_nodes).to(device)

probs_Y = torch.softmax(model(test_bg, center_nodes, nodes_aux), 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.detach().cpu().float()).sum().item() / len(test_Y) * 100))

ValueError: not enough values to unpack (expected 5, got 0)