In [1]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import collections
import csv
import random
import pickle
from torch.utils.data import DataLoader
import dgl
import pandas as pd

class Subgraphs(Dataset):


    def __init__(self, root, mode, path_s = 'list_subgraph.pkl', path_l = 'label.pkl'):
       
        # load subgraph list 
        with open(os.path.join(root, path_s), 'rb') as f:
            subgraph_list = pickle.load(f)

        with open(os.path.join(root, path_l), 'rb') as f:
            subgraph2label = pickle.load(f)    

        self.subgraph2label = subgraph2label
        self.subgraph_list = subgraph_list

        self.data = pd.read_csv(os.path.join(root, mode + '.csv'))  # csv path
    
    def __getitem__(self, index):
        #print(index)
        #print(self.data.iloc[index]['name'])
        #print(self.subgraph_list[self.data.iloc[index]['name']])
        return self.subgraph_list[self.data.iloc[index]['name']], self.subgraph2label[self.data.iloc[index]['name']] 

    def __len__(self):
        # as we have built up to batchsz of sets, you can sample some small batch size of sets.
        return len(self.data)
    
def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.LongTensor(labels)


In [2]:
import dgl.function as fn
import torch
import torch.nn as nn


# Sends a message of node feature h.
msg = fn.copy_src(src='h', out='m')

def reduce(nodes):
    """Take an average over all neighbor node features hu and use it to
    overwrite the original node feature."""
    accum = torch.mean(nodes.mailbox['m'], 1)
    return {'h': accum}

class NodeApplyModule(nn.Module):
    """Update the node feature hv with ReLU(Whv+b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Initialize the node features with h.
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

In [3]:
import torch.nn.functional as F


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # For undirected graphs, in_degree is the same as
        # out_degree.
        h = g.in_degrees().view(-1, 1).float().to(device)
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        #print(h.shape)
        hg = dgl.mean_nodes(g, 'h')
        #print(hg.shape)
        #print(h[0].shape)
        #hg = h[g.nodes[0].data['center_node'].detach().numpy()[0]]
        #print(hg.shape)
        return self.classify(hg)

In [4]:
from torch.utils.data import DataLoader
import torch.optim as optim
dataset = 'cora'
path = '/n/scratch2/kexinhuang/MGNN_Data/cora/fold1/'
trainset = Subgraphs(path, mode='train', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
#valset = Subgraphs(path, mode='val', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
#testset = Subgraphs(path, mode='test', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
    
data_loader = DataLoader(trainset, batch_size=64, shuffle=True,
                         collate_fn=collate)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Create model
model = Classifier(1, 256, 10)
model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.008)
model.train()

epoch_losses = []
for epoch in range(500):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        bg = bg.to(device)
        label = label.to(device)
        prediction = model(bg)
        #print(prediction.shape)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

Epoch 0, loss 2.8846
Epoch 1, loss 1.7318
Epoch 2, loss 1.6287
Epoch 3, loss 1.6157
Epoch 4, loss 1.6048
Epoch 5, loss 1.6101
Epoch 6, loss 1.6240
Epoch 7, loss 1.6189
Epoch 8, loss 1.6109
Epoch 9, loss 1.5873
Epoch 10, loss 1.6093
Epoch 11, loss 1.5878
Epoch 12, loss 1.5957
Epoch 13, loss 1.5973
Epoch 14, loss 1.5883
Epoch 15, loss 1.5971
Epoch 16, loss 1.5953
Epoch 17, loss 1.5784
Epoch 18, loss 1.5815
Epoch 19, loss 1.5819
Epoch 20, loss 1.5781
Epoch 21, loss 1.5667
Epoch 22, loss 1.5728
Epoch 23, loss 1.5732
Epoch 24, loss 1.5545
Epoch 25, loss 1.5633
Epoch 26, loss 1.5550
Epoch 27, loss 1.5671
Epoch 28, loss 1.5571
Epoch 29, loss 1.5434
Epoch 30, loss 1.5489
Epoch 31, loss 1.5466
Epoch 32, loss 1.5563
Epoch 33, loss 1.5519
Epoch 34, loss 1.5413
Epoch 35, loss 1.5337
Epoch 36, loss 1.5340
Epoch 37, loss 1.5359
Epoch 38, loss 1.5335
Epoch 39, loss 1.5480
Epoch 40, loss 1.5317
Epoch 41, loss 1.5297
Epoch 42, loss 1.5354
Epoch 43, loss 1.5067
Epoch 44, loss 1.5116
Epoch 45, loss 1.507

Epoch 361, loss 1.3530
Epoch 362, loss 1.3490
Epoch 363, loss 1.3533
Epoch 364, loss 1.3561
Epoch 365, loss 1.3500
Epoch 366, loss 1.3572
Epoch 367, loss 1.3761
Epoch 368, loss 1.3582
Epoch 369, loss 1.3551
Epoch 370, loss 1.3531
Epoch 371, loss 1.3451
Epoch 372, loss 1.3502
Epoch 373, loss 1.3532
Epoch 374, loss 1.3585
Epoch 375, loss 1.3602
Epoch 376, loss 1.3638
Epoch 377, loss 1.3612
Epoch 378, loss 1.3649
Epoch 379, loss 1.3552
Epoch 380, loss 1.3733
Epoch 381, loss 1.3513
Epoch 382, loss 1.3628
Epoch 383, loss 1.3629
Epoch 384, loss 1.3497
Epoch 385, loss 1.3571
Epoch 386, loss 1.3762
Epoch 387, loss 1.3477
Epoch 388, loss 1.3499
Epoch 389, loss 1.3674
Epoch 390, loss 1.3437
Epoch 391, loss 1.3672
Epoch 392, loss 1.3541
Epoch 393, loss 1.3502
Epoch 394, loss 1.3573
Epoch 395, loss 1.3408
Epoch 396, loss 1.3532
Epoch 397, loss 1.3509
Epoch 398, loss 1.3605
Epoch 399, loss 1.3579
Epoch 400, loss 1.3439
Epoch 401, loss 1.3469
Epoch 402, loss 1.3500
Epoch 403, loss 1.3730
Epoch 404, 

In [6]:
torch.save(model.state_dict(), 'model.pt')

In [4]:
model = Classifier(1, 256, 10)
model.load_state_dict(torch.load('model.pt'))

<All keys matched successfully>

In [5]:
path = '/n/scratch2/kexinhuang/MGNN_Data/cora/fold1/'
valset = Subgraphs(path, mode='val', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
#testset = Subgraphs(path, mode='test', path_s = 'list_subgraph.pkl', path_l = 'label.pkl')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*valset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
test_bg.to(device)
test_Y.to(device)
probs_Y = torch.softmax(model(test_bg), 1)
#sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
#print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
#    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'other'

In [13]:
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.detach().cpu().float()).sum().item() / len(test_Y) * 100))

Accuracy of argmax predictions on the test set: 53.874539%
