In [2]:
import dgl
%matplotlib inline
import networkx as nx 
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import os
os.chdir('../graphwave/')

import matplotlib.pyplot as plt
import graphwave
from graphwave.shapes import build_graph

import dgl.function as fn
from dgl import DGLGraph

import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
class graph_data(Dataset):
    """Episodic few-shot task sampler over the rows of ``<root>/<mode>.csv``.

    Each item is one freshly sampled task. For every distinct value of the
    CSV's ``label`` column, ``k_shot`` support names and ``n_qry`` query names
    are drawn without replacement from that label's rows; query candidates
    exclude every support name drawn so far in the task.

    Parameters
    ----------
    root : str
        Directory that holds the split CSV files.
    mode : str
        Split name; the file read is ``os.path.join(root, mode + '.csv')``.
        The CSV must provide ``name`` and ``label`` columns.
    subgraph2label : mapping
        Maps every value of the ``name`` column to its label (a value of the
        ``label`` column).
    k_shot : int
        Support samples drawn per label.
    n_qry : int
        Query samples drawn per label.
    batchsz : int, default 100
        Value reported by ``len()``; the sampled task does not depend on the
        index, so this only controls how many tasks one pass yields.
    """

    def __init__(self, root, mode, subgraph2label, k_shot, n_qry, batchsz = 100):

        self.subgraph2label = subgraph2label

        self.data = pd.read_csv(os.path.join(root, mode + '.csv'))  # csv path
        self.k_shot = k_shot
        self.n_qry = n_qry
        self.batchsz = batchsz

        # Sorted distinct labels, and their remapping to 0..n_labels-1.
        self.labels = np.unique(self.data.label.values)
        self.labels_dict = dict(zip(list(self.labels), list(range(len(self.labels)))))

        # The per-label row subsets never change after loading, so build them
        # once here instead of re-filtering the whole frame on every draw.
        # Order follows ``self.labels``.
        self._rows_by_label = [
            self.data[self.data.label == label].reset_index(drop = True)
            for label in self.labels
        ]

    def __getitem__(self, index):
        """Sample one task; ``index`` is ignored.

        Returns
        -------
        tuple
            ``(support, support_labels, query, query_labels)`` where the name
            lists are ordered label by label and the label lists hold the
            remapped integer label of each name.

        Raises
        ------
        ValueError
            Raised by pandas when a label has fewer rows than requested.
        """
        support = []
        query = []

        for rows in self._rows_by_label:
            support.extend(rows.sample(n = self.k_shot)['name'].values)
            # Exclude everything already used as support (including this
            # label's fresh picks) before drawing the query set.
            candidates = rows[~rows['name'].isin(support)]
            query.extend(candidates.sample(n = self.n_qry)['name'].values)

        support_labels = [self.labels_dict[self.subgraph2label[name]] for name in support]
        query_labels = [self.labels_dict[self.subgraph2label[name]] for name in query]
        return support, support_labels, query, query_labels

    def __len__(self):
        # as we have built up to batchsz of sets, you can sample some small batch size of sets.
        return self.batchsz
    
def collate(samples):
    """Batch a list of sampled tasks into four ``LongTensor`` objects.

    ``samples`` is a list of 4-tuples
    ``(support, support_labels, query, query_labels)``; the i-th field of
    every tuple is gathered into one list, and each gathered list is turned
    into a ``torch.LongTensor`` whose first dimension indexes the tasks.
    """
    gathered = [list(field) for field in zip(*samples)]
    spt_idx, l_spt, qry_idx, l_qry = gathered
    return (
        torch.LongTensor(spt_idx),
        torch.LongTensor(l_spt),
        torch.LongTensor(qry_idx),
        torch.LongTensor(l_qry),
    )

In [20]:
import torch.nn as nn
import torch.nn.functional as F

# Define the message and reduce function
# NOTE: We ignore the GCN's normalization constant c_ij for this tutorial.
def gcn_message(edges):
    """Message UDF: forward each edge's source-node feature 'h' as 'msg'.

    `edges` is a batch of edges; the returned dict has the single key 'msg'.
    """
    source_h = edges.src['h']
    return dict(msg=source_h)

def gcn_reduce(nodes):
    """Reduce UDF: sum the 'msg' entries in each node's mailbox into a new 'h'.

    `nodes` is a batch of nodes; the mailbox tensor is summed over dim 1.
    """
    received = nodes.mailbox['msg']
    summed = torch.sum(received, dim=1)
    return dict(h=summed)

# Define the GCNLayer module
class GCNLayer(nn.Module):
    """Graph layer: aggregate features with gcn_message / gcn_reduce, then a linear map.

    in_feats / out_feats are the input and output sizes of the linear map.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        """Run one round of message passing on graph `g` and project the result.

        `inputs` are the node features; they are stored under 'h' for the
        duration of the call and removed from the graph before returning.
        """
        # NOTE(review): explicit send/recv belongs to the older DGL API — confirm
        # the installed DGL version still provides both methods.
        g.ndata['h'] = inputs
        every_edge = g.edges()
        g.send(every_edge, gcn_message)      # emit messages on all edges
        every_node = g.nodes()
        g.recv(every_node, gcn_reduce)       # aggregate at all nodes
        aggregated = g.ndata.pop('h')        # read back and clean up the graph
        return self.linear(aggregated)

In [21]:
# DGL built-in message function: copy the source node's 'h' into message field 'm'.
gcn_msg = fn.copy_src(src='h', out='m')
# DGL built-in reduce function: sum message field 'm' into node field 'h'.
# NOTE(review): this rebinds the module-level name `gcn_reduce`, replacing the
# user-defined function `gcn_reduce(nodes)` from the earlier cell; anything that
# looks the name up after this cell runs (e.g. GCNLayer.forward) gets this built-in.
gcn_reduce = fn.sum(msg='m', out='h')

class NodeApplyModule(nn.Module):
    """Per-node update: linear map of the node field 'h', then an optional activation.

    `activation` is a callable applied to the linear output, or None to skip it.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Return {'h': new_features} computed from `node.data['h']`."""
        projected = self.linear(node.data['h'])
        if self.activation is None:
            return {'h': projected}
        return {'h': self.activation(projected)}
    
class GCN(nn.Module):
    """One graph-convolution layer built from DGL message passing plus NodeApplyModule.

    in_feats / out_feats / activation are forwarded unchanged to NodeApplyModule.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Propagate `feature` over graph `g` and return the updated node features.

        The features live in the node field 'h' during the call and are popped
        from the graph before returning.
        """
        g.ndata['h'] = feature
        # message passing over the whole graph with the module-level functions
        g.update_all(gcn_msg, gcn_reduce)
        # per-node linear map (+ activation)
        g.apply_nodes(func=self.apply_mod)
        updated = g.ndata.pop('h')
        return updated

In [22]:
class Net(nn.Module):
    """Two stacked GCN layers: in_feats -> hid_dim (ReLU) -> out_feats (no activation)."""

    def __init__(self, in_feats, hid_dim, out_feats):
        super().__init__()
        self.gcn1 = GCN(in_feats, hid_dim, F.relu)
        self.gcn2 = GCN(hid_dim, out_feats, None)

    def forward(self, g, features):
        """Return the second layer's output for node `features` on graph `g`."""
        hidden = self.gcn1(g, features)
        return self.gcn2(g, hidden)
    

Net(
  (gcn1): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=1433, out_features=16, bias=True)
    )
  )
  (gcn2): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=16, out_features=7, bias=True)
    )
  )
)


In [16]:
# multiple graph setting 
# Loads the adjacency matrices, converts each to a DGL graph, and builds the
# train/val/test task samplers and their loaders for one fold.

# task_num is used below as the batch size of the training loader;
# k_shot / n_qry are the per-label support / query counts given to graph_data.
task_num = 4
k_shot = 1
n_qry = 6

path = '../data/multiple_graph/BA/META_LABEL/'
fold_n = 1
# NOTE(review): allow_pickle=True unpickles objects stored in the file — only
# load files from a trusted source.
adjs = np.load(path + 'graphs_adj.npy', allow_pickle=True)
dgl_Gs = []

# One DGL graph per adjacency matrix, built via an intermediate networkx graph.
# NOTE(review): nx.from_numpy_matrix and DGLGraph.from_networkx belong to older
# networkx / DGL releases — confirm the installed versions still provide them.
for i in range(adjs.shape[0]):
    adj = adjs[i]
    G = nx.from_numpy_matrix(adj)
    S = dgl.DGLGraph()
    S.from_networkx(G)
    dgl_Gs.append(S)
    
    
# From here on `path` points at the fold directory that holds the split CSVs.
path = path + 'fold'+str(fold_n)+'/'
# NOTE(review): `info` is not defined in any cell visible here; it is passed as
# graph_data's `subgraph2label` argument, so it presumably maps each subgraph
# name to its label — confirm where it is created.
# The last argument is graph_data's `batchsz`, i.e. the value len() reports.
trainset = graph_data(path, 'train', info, k_shot, n_qry, 1000)
valset = graph_data(path, 'val', info, k_shot, n_qry, 100)
testset = graph_data(path, 'test', info, k_shot, n_qry, 100)

# Validation and test yield one task per batch; training yields task_num tasks.
data_loader_test = DataLoader(testset, batch_size=1, shuffle=True, collate_fn=collate)
data_loader_val = DataLoader(valset, batch_size=1, shuffle=True, collate_fn=collate)
data_loader_train = DataLoader(trainset, batch_size=task_num, shuffle=True, collate_fn=collate)


In [47]:
# Training cell (work in progress): as written it does not match the
# definitions in the cells above — see the NOTE(review) comments.

# in_feats=1 matches the single in-degree feature built per node below.
# NOTE(review): `labels` is not defined in any cell visible here.
net = Net(1, 64, np.unique(labels).shape[0])
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

    
# NOTE(review): `collate` returns four tensors per batch, but six names are
# unpacked here, so the first iteration would raise a ValueError as written.
# NOTE(review): `iter` shadows the Python builtin of the same name.
for iter, (bg, label, to_fetch, bg_qry, label_qry, to_fetch_qry) in enumerate(data_loader_train):
    
    
    for graph_idx, S in enumerate(dgl_Gs):    
        # node feature = in-degree, as a float column vector
        # NOTE(review): `device` is not defined in any cell visible here.
        features = S.in_degrees().view(-1, 1).float().to(device)    
        # NOTE(review): Net.forward takes (g, features); the extra `idx`
        # argument (also undefined in the visible cells) does not fit it.
        logits = net(S, features, idx)
        
        
        
    # `logits` is overwritten on every pass of the inner loop, so only the
    # output for the last graph in dgl_Gs reaches the loss below.
    # we save the logits for visualization later
    # NOTE(review): `all_logits`, `labeled_nodes`, `labels_train` and `epoch`
    # are not defined in any cell visible here.
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)
    # we only compute loss for labeled nodes
    loss = F.nll_loss(logp[labeled_nodes], labels_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))


Epoch 0 | Loss: 2.3168
Epoch 1 | Loss: 2.1854
Epoch 2 | Loss: 2.0515
Epoch 3 | Loss: 1.9082
Epoch 4 | Loss: 1.7647
Epoch 5 | Loss: 1.6287
Epoch 6 | Loss: 1.4966
Epoch 7 | Loss: 1.3618
Epoch 8 | Loss: 1.2255
Epoch 9 | Loss: 1.0934
Epoch 10 | Loss: 0.9700
Epoch 11 | Loss: 0.8562
Epoch 12 | Loss: 0.7535
Epoch 13 | Loss: 0.6632
Epoch 14 | Loss: 0.5845
Epoch 15 | Loss: 0.5139
Epoch 16 | Loss: 0.4495
Epoch 17 | Loss: 0.3933
Epoch 18 | Loss: 0.3482
Epoch 19 | Loss: 0.3133
Epoch 20 | Loss: 0.2853
Epoch 21 | Loss: 0.2632
Epoch 22 | Loss: 0.2456
Epoch 23 | Loss: 0.2296
Epoch 24 | Loss: 0.2139
Epoch 25 | Loss: 0.1989
Epoch 26 | Loss: 0.1850
Epoch 27 | Loss: 0.1718
Epoch 28 | Loss: 0.1596
Epoch 29 | Loss: 0.1497
Epoch 30 | Loss: 0.1424
Epoch 31 | Loss: 0.1371
Epoch 32 | Loss: 0.1324
Epoch 33 | Loss: 0.1274
Epoch 34 | Loss: 0.1218
Epoch 35 | Loss: 0.1162
Epoch 36 | Loss: 0.1114
Epoch 37 | Loss: 0.1080
Epoch 38 | Loss: 0.1057
Epoch 39 | Loss: 0.1037
Epoch 40 | Loss: 0.1015
Epoch 41 | Loss: 0.0988
Ep

In [48]:
# Forward pass in evaluation mode to get per-node log-probabilities.
net.eval()
# NOTE(review): `S` is whatever graph was bound last by an earlier cell, and
# `inputs` is not defined in any cell visible here — confirm both.
logits = net(S, inputs)
# we save the logits for visualization later
logp = F.log_softmax(logits, 1)
# we only compute loss for labeled nodes
#loss = F.nll_loss(logp[labeled_nodes], labels_train)

In [49]:
# torch.max over dim 1 returns (values, indices); [1] keeps the indices, i.e.
# the predicted class of each selected node.
# NOTE(review): `unlabelled_nodes` is not defined in any cell visible here.
argmax_Y = torch.max(logp[unlabelled_nodes], 1)[1]

In [50]:
# Share of positions where the prediction equals the label, as a percentage.
# NOTE(review): `test_label` is not defined in any cell visible here; the
# predictions are cast to float before comparison, so it is presumably a float
# tensor — confirm.
# NOTE(review): '{:4f}' means minimum width 4 with the default six decimals;
# '{:.4f}' may have been intended — confirm before changing the output format.
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_label == argmax_Y.float()).sum().item() / len(test_label) * 100))

Accuracy of argmax predictions on the test set: 44.751381%
