In [192]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from google.colab import drive
import networkx as nx
import copy
import random


device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    if gpu_info.find('failed') >= 0:
        print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
        print('and then re-execute this cell.')
    else:
        print(gpu_info)
print('device :',device)
print('torch.version :',torch.__version__)

device : cpu
torch.version : 1.7.0+cu101


In [193]:
drive.mount('/content/drive')
folder_dir = '/content/drive/My Drive/cora'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [194]:
class Cora(object):
    
    def __init__(self, base_dir):
        data = np.genfromtxt(base_dir + '/cora/cora.content', dtype=np.str)
        

        self.category_list = list(set(data[:, -1]))
        self.cat2lab = {category:label for label, category in enumerate(self.category_list)}
        self.lab2cat = {label:category for label, category in enumerate(self.category_list)}
        
        self.paper_list = np.array(data[:, 0], dtype=np.int)
        self.paper2node = {paper_id:node_id for node_id, paper_id in enumerate(self.paper_list)}
        self.node2paper = {node_id:paper_id for node_id, paper_id in enumerate(self.paper_list)}


        self.node_feature = np.array(data[:, 1:-1], dtype=int)
        self.node_label = [self.cat2lab[category] for category in data[:, -1]]
        
        self.edge = np.genfromtxt(base_dir + '/cora/cora.cites',   dtype=np.int)
        
        self.node_connection_list = [[] for _ in range(data.shape[0])]

        for i, (cited, citing) in enumerate(self.edge):
            cited_node = self.paper2node[cited]
            citing_node = self.paper2node[citing]
            self.node_connection_list[cited_node].append(citing_node)
            

    def get_data(self):
        node_feature = np.copy(self.node_feature)
        node_label = np.copy(self.node_label)
        node_connection_list = [node_list for node_list in self.node_connection_list]
        return node_feature, node_label, node_connection_list



def to_undirected_connection_list(src):
    tar = copy.deepcopy(src)
    for i, node_list in enumerate(src):
        for node in node_list:
            tar[node].append(i)
        tar[i].append(i) # self loof
       
    for i in range(len(src)):
        tar[i] = sorted(list(set(tar[i])))
        
    return tar

cora = Cora(base_dir='/content/drive/My Drive')
node_feature, node_label, node_connection_list = cora.get_data()
undirected_connection_list = to_undirected_connection_list(node_connection_list)

In [195]:
def choice(neighbor_list, num_samples):
    neighbor_list = copy.deepcopy(neighbor_list)
    if len(neighbor_list) < num_samples:
        return neighbor_list
        
    choiced = []
    for _ in range(num_samples):
        if len(neighbor_list) == 0:
            break
        v = random.choice(neighbor_list)
        choiced.append(v)
        neighbor_list.remove(v)
    return choiced


def sample(target_nodes, edges, num_samples):

    layer_nodes = [copy.deepcopy(target_nodes)]
    next_mapping_list = []

    for k, n_k in enumerate(num_samples):    

        src_nodes = layer_nodes[-1]
        
        tar_nodes = copy.deepcopy(src_nodes)
        
        for u in src_nodes:
            tar_nodes += choice(edges[u], n_k)

        tar_nodes = list(set(tar_nodes))

        src2tar_map = [[] for _ in src_nodes]

        for src_idx, u in enumerate(src_nodes):
            for tar_idx, v in enumerate(tar_nodes):
                if u in edges[v]:
                    src2tar_map[src_idx].append(tar_idx)
        
        layer_nodes.append(tar_nodes)
        next_mapping_list.append(src2tar_map)


    first_layer_nodes = layer_nodes.pop()
    
    layer_nodes.reverse()
    next_mapping_list.reverse()
        
    return first_layer_nodes, layer_nodes, next_mapping_list


In [196]:
B0, Bn, ptrs = sample([4, 3], undirected_connection_list, [25, 25])
print(B0)
print(Bn)
print(ptrs[0])
print(ptrs[1])

[3, 4, 197, 483, 295, 552, 611, 170, 749, 333, 463, 1741, 633, 564, 565, 601, 250, 477]
[[3, 4, 197, 170, 463, 601], [4, 3]]
[[0, 2, 10, 15], [1, 7], [0, 2, 4, 8, 12], [1, 7, 16], [0, 3, 9, 10, 13, 17], [0, 5, 6, 11, 14, 15]]
[[1, 3], [0, 2, 4, 5]]


In [197]:
class GraphSAGE(nn.Module):

    def __init__(self):
        super().__init__()

        self.net = nn.ModuleList([
            nn.Sequential(nn.Linear(in_features=2866, out_features=16, bias=False),
                          nn.ReLU(),
                          nn.BatchNorm1d(16),
                          nn.Dropout(0.5)),
            nn.Linear(in_features=32, out_features=7, bias=False)
        ])


    def forward(self, feature, B, ptrs):
        # B는 len(B[k]) 때문에 필요 
        x = feature
        for k in range(len(ptrs)):

            next_x = []

            for i in range(len(B[k])):
                h_u = x[i]
                h_nu = self.aggregate(x[ptrs[k][i], :])
                next_x.append(torch.cat([h_u, h_nu], dim=0))
    
            x = torch.stack(next_x, dim=0)
            x = self.net[k](x)
            if k < len(ptrs) - 1:
                x = x / (x.norm(dim=1, keepdim=True) + 1e-6)

        return x


    def aggregate(self, x):
        return x.mean(dim=0)



X = torch.Tensor(node_feature.astype(np.float))
X = X / X.norm(dim=1, keepdim=True)
Y = torch.LongTensor(node_label)
model = GraphSAGE()
solver = optim.Adam(model.parameters(),  lr=1e-3)


In [198]:
batch_size = 16

idx_train = np.arange(0, 140)

for epoch in range(10000):
    batch_index = np.random.choice(idx_train, batch_size, replace=False).tolist()


    first_layer_nodes, layer_nodes, next_mapping_list = sample(batch_index, undirected_connection_list, [10, 25])


    solver.zero_grad()
    pred = model(X[first_layer_nodes], layer_nodes, next_mapping_list)

    loss = F.cross_entropy(pred, Y[batch_index], reduction='mean')

    loss.backward()

    solver.step()
    if epoch % 1000 == 999:
        print(loss.item())

0.5760753154754639
0.22811304032802582
0.11270955950021744
0.07461027055978775
0.059563785791397095
0.05269928649067879
0.03213858604431152
0.017506344243884087
0.006018992979079485
0.005394180305302143


In [199]:

model.eval()

acc = 0.0
for i in range(140):
    first_layer_nodes, layer_nodes, next_mapping_list = sample([i], undirected_connection_list, [25, 25])
    pred = model(X[first_layer_nodes], layer_nodes, next_mapping_list)
    acc += (torch.argmax(pred, dim=1) == Y[i]).item()
print('train accuracy : ', acc/140)

acc = 0.0
for i in range(140, 2708):
    first_layer_nodes, layer_nodes, next_mapping_list = sample([i], undirected_connection_list, [25, 25])
    pred = model(X[first_layer_nodes], layer_nodes, next_mapping_list)
    acc += (torch.argmax(pred, dim=1) == Y[i]).item()
print('test accuracy : ', acc/(2708 - 140))





train accuracy :  1.0
test accuracy :  0.7348130841121495
