In [68]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import collections
import csv
import random
import pickle
from torch.utils.data import DataLoader
import dgl
import pandas as pd

class Subgraphs(Dataset):
    """Map-style dataset of pre-built subgraphs and their labels for one split.

    The split is described by ``<root>/<mode>.csv``; its ``name`` column holds,
    for each sample, the key under which the subgraph and its label are stored
    in two pickled containers.
    """

    def __init__(self, root, mode, path_s = 'list_subgraph_graphwave.pkl', path_l = 'label_graphwave.pkl'):
        """
        Args:
            root: directory containing the two pickles and the split CSV.
            mode: split name; ``<root>/<mode>.csv`` is read.
            path_s: pickle (relative to ``root``) holding the subgraphs,
                indexable by the values of the CSV's ``name`` column.
            path_l: pickle (relative to ``root``) holding the labels,
                indexable by the same keys.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files produced by a trusted preprocessing step.
        with open(os.path.join(root, path_s), 'rb') as f:
            self.subgraph_list = pickle.load(f)

        with open(os.path.join(root, path_l), 'rb') as f:
            self.subgraph2label = pickle.load(f)

        self.data = pd.read_csv(os.path.join(root, mode + '.csv'))  # csv path

    def __getitem__(self, index):
        """Return ``(subgraph, label)`` for the ``index``-th row of the split CSV."""
        # Read the key from the column rather than via ``iloc[index]['name']``:
        # materialising a whole row upcasts an int key to float when the CSV
        # also has float columns (e.g. 2 -> 2.0, which cannot index a list).
        # It also avoids resolving the same row three times.
        name = self.data['name'].iloc[index]
        return self.subgraph_list[name], self.subgraph2label[name]

    def __len__(self):
        """Number of samples, i.e. rows in the split CSV."""
        return len(self.data)
    
def collate(samples):
    """Collate a list of ``(graph, label)`` pairs into one mini-batch.

    Args:
        samples: list of ``(graph, label)`` pairs, e.g. items returned by
            ``Subgraphs.__getitem__``.

    Returns:
        ``(batched_graph, labels)``: the graphs merged with ``dgl.batch`` and
        the labels as an int64 tensor (1-D when each label is a scalar).
    """
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    # torch.tensor is the recommended factory; torch.LongTensor(...) is the
    # legacy type-constructor form.
    return batched_graph, torch.tensor(labels, dtype=torch.long)


In [69]:
import dgl.function as fn
import torch
import torch.nn as nn


# Built-in message function: copy each source node's feature 'h' onto its
# out-edges as message 'm'. ``copy_u`` replaces ``copy_src``, a deprecated
# alias that newer DGL releases no longer provide.
msg = fn.copy_u(u='h', out='m')

def reduce(nodes):
    """Reduce UDF: average the received messages 'm' into node feature 'h'.

    The mean is taken over dimension 1 of ``nodes.mailbox['m']``, the
    per-node message axis in DGL's mailbox layout. The result is returned
    under key 'h', so DGL overwrites the node's previous 'h'.
    """
    incoming = nodes.mailbox['m']
    averaged = incoming.mean(dim=1)
    return {'h': averaged}

class NodeApplyModule(nn.Module):
    """Node-wise update ``h_v <- activation(W h_v + b)``.

    Args:
        in_feats: size of the incoming node feature 'h'.
        out_feats: size of the updated node feature.
        activation: callable applied to the output of the linear map.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Return ``{'h': activation(linear(node.data['h']))}`` for DGL to store."""
        projected = self.linear(node.data['h'])
        return {'h': self.activation(projected)}

class GCN(nn.Module):
    """One message-passing layer: mean over in-neighbours, then node update.

    Messages are built with the module-level ``msg`` and aggregated with
    ``reduce``. Each node's aggregate is then passed through a
    ``NodeApplyModule(in_feats, out_feats, activation)``.
    """

    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Apply the layer to graph ``g`` with node features ``feature``.

        Returns the updated node-feature tensor. The work is done inside
        ``g.local_scope()``, so the caller's ``g.ndata`` is not modified (an
        existing 'h' field on ``g`` is neither overwritten nor removed).
        """
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(msg, reduce)
            g.apply_nodes(func=self.apply_mod)
            return g.ndata['h']

In [95]:
import torch.nn.functional as F


class Classifier(nn.Module):
    """Graph classifier: two GCN layers, mean-over-nodes readout, linear head.

    The only node input feature is the in-degree (one column), so ``in_dim``
    must be 1.

    Args:
        in_dim: size of the input node feature (1 for the degree feature).
        hidden_dim: output width of both GCN layers.
        n_classes: number of logits produced per graph.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return unnormalised class logits, one row per graph in (batched) ``g``."""
        # Node features = in-degree as a single float column. For undirected
        # graphs in-degree equals out-degree.
        h = g.in_degrees().view(-1, 1).float()
        for conv in self.layers:
            h = conv(g, h)
        # Read out inside a local scope so 'h' is not left on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

In [101]:
from torch.utils.data import DataLoader
import torch.optim as optim
# --- configuration -----------------------------------------------------------
SEED = 0
dataset = 'cora'
path = '../data/fake_data/fold1/'
BATCH_SIZE = 32
HIDDEN_DIM = 256
N_CLASSES = 10  # TODO(review): confirm this matches the number of distinct labels
LR = 0.005
N_EPOCHS = 80

# Seed every RNG in play so shuffling and weight init are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def _make_split(split):
    """Build the ``Subgraphs`` dataset for split ``<split>_<dataset>``."""
    return Subgraphs(path, mode=split + '_' + dataset,
                     path_s='list_subgraph_' + dataset + '.pkl',
                     path_l='label_' + dataset + '.pkl')


trainset = _make_split('train')
valset = _make_split('val')  # NOTE(review): not used by the visible cells
testset = _make_split('test')

data_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                         collate_fn=collate)

# Create model
model = Classifier(1, HIDDEN_DIM, N_CLASSES)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
model.train()

epoch_losses = []
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    # Loop variables deliberately avoid shadowing the builtin `iter`, which
    # would otherwise stay rebound to an int in the notebook namespace.
    for bg, label in data_loader:
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Average over the batches actually seen; safe for an empty loader.
    epoch_loss /= max(n_batches, 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

Epoch 0, loss 1.8910
Epoch 1, loss 1.6278
Epoch 2, loss 1.6018
Epoch 3, loss 1.6225
Epoch 4, loss 1.6125
Epoch 5, loss 1.5957
Epoch 6, loss 1.6131
Epoch 7, loss 1.6014
Epoch 8, loss 1.5998
Epoch 9, loss 1.5885
Epoch 10, loss 1.5924
Epoch 11, loss 1.5991
Epoch 12, loss 1.5763
Epoch 13, loss 1.5859
Epoch 14, loss 1.5942
Epoch 15, loss 1.5711
Epoch 16, loss 1.5698
Epoch 17, loss 1.5617
Epoch 18, loss 1.5635
Epoch 19, loss 1.5782
Epoch 20, loss 1.5655
Epoch 21, loss 1.5698
Epoch 22, loss 1.5506
Epoch 23, loss 1.5551
Epoch 24, loss 1.5457
Epoch 25, loss 1.5558
Epoch 26, loss 1.5594
Epoch 27, loss 1.5418
Epoch 28, loss 1.5326
Epoch 29, loss 1.5469
Epoch 30, loss 1.5345
Epoch 31, loss 1.5295
Epoch 32, loss 1.5358
Epoch 33, loss 1.5249
Epoch 34, loss 1.5298
Epoch 35, loss 1.5183
Epoch 36, loss 1.5254
Epoch 37, loss 1.5316
Epoch 38, loss 1.5173
Epoch 39, loss 1.5007
Epoch 40, loss 1.5017
Epoch 41, loss 1.5034
Epoch 42, loss 1.5062
Epoch 43, loss 1.4968
Epoch 44, loss 1.4973
Epoch 45, loss 1.488

In [102]:
model.eval()
# Collect the test split as two parallel lists. Index explicitly instead of
# iterating the Dataset, which would rely on __getitem__ raising IndexError.
test_X, test_Y = map(list, zip(*(testset[i] for i in range(len(testset)))))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
# Inference only: do not build an autograd graph.
with torch.no_grad():
    probs_Y = torch.softmax(model(test_bg), 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
accuracy = (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100
# '{:.4f}' = 4 decimals; the previous '{:4f}' set a width of 4 and printed 6 decimals.
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(accuracy))

Accuracy of argmax predictions on the test set: 48.339483%
