## PyTorch GVAE
基于 PyTorch 实现图变分自编码器（Graph Variational AutoEncoder，GVAE）。

### 导入基本包

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

### GCN encoder layer

In [2]:
class GCN(Module):
    """
    Simple graph-convolution layer, similar to https://arxiv.org/abs/1609.02907

    The forward pass applies dropout to the input features, multiplies them by
    the learnable ``weight`` matrix, multiplies the result by ``adj`` (via
    ``torch.spmm``) and finally applies the activation ``act``.

    Args:
        in_features: Number of rows of the weight matrix.
        out_features: Number of columns of the weight matrix.
        dropout: Dropout probability applied to the input in training mode.
        act: Callable applied to the propagated features.
    """

    def __init__(self, in_features, out_features, dropout=0., act=F.relu):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.dropout = dropout
        self.act = act
        # Allocated uninitialised; the values are filled in by reset_parameters().
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` in place with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """Return ``act(adj @ (dropout(input) @ weight))``."""
        dropped = F.dropout(input, self.dropout, self.training)
        propagated = torch.spmm(adj, torch.mm(dropped, self.weight))
        return self.act(propagated)

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)
# Quick check of the layer: build a 64 -> 32 GCN. The last expression references
# the bound method without calling it, so the cell displays the method object.
model = GCN(64,32)
model.__repr__

<bound method GCN.__repr__ of GCN (64 -> 32)>

### Inner Product Decoder

In [3]:
class InnerProductDecoder(nn.Module):
    """
    Decoder that scores every pair of rows of ``z`` by their inner product.

    Args:
        dropout: Dropout probability applied to ``z`` in training mode.
        act: Callable applied to the matrix of inner products.
    """

    def __init__(self, dropout, act=torch.sigmoid):
        super().__init__()
        self.act = act
        self.dropout = dropout

    def forward(self, z):
        """Return ``act(z @ z.T)`` computed on the dropout-masked ``z``."""
        dropped = F.dropout(z, self.dropout, training=self.training)
        scores = torch.mm(dropped, dropped.t())
        return self.act(scores)

### GVAE 模型

In [4]:
class GVAE(nn.Module):
    """
    Graph variational auto-encoder.

    A ReLU GCN layer (``gc1``) feeds two identity-activation GCN heads:
    ``gc2`` produces ``mu`` and ``gc3`` produces ``logvar``. The latent code is
    decoded by an inner-product decoder with identity activation.

    Args:
        input_feat_dim: Input feature size of ``gc1``.
        hidden_dim1: Output size of ``gc1`` / input size of both heads.
        hidden_dim2: Output size of both heads.
        dropout: Dropout probability handed to every sub-module.
    """

    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, dropout):
        super().__init__()
        # Creation order is kept (gc1, gc2, gc3) so random initialisation is unchanged.
        self.gc1 = GCN(input_feat_dim, hidden_dim1, dropout, act=F.relu)
        self.gc2 = GCN(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.gc3 = GCN(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.dc = InnerProductDecoder(dropout, act=lambda x: x)

    def encode(self, x, adj):
        """Return the pair ``(gc2(h, adj), gc3(h, adj))`` with ``h = gc1(x, adj)``."""
        shared = self.gc1(x, adj)
        mean_head = self.gc2(shared, adj)
        spread_head = self.gc3(shared, adj)
        return mean_head, spread_head

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        # exp() is applied to ``logvar`` directly, i.e. the value is used as the
        # log of the standard deviation.
        scale = torch.exp(logvar)
        noise = torch.randn_like(scale)
        return noise * scale + mu

    def forward(self, x, adj):
        """Return ``(decoder_output, mu, logvar)`` for features ``x`` and graph ``adj``."""
        mu, logvar = self.encode(x, adj)
        latent = self.reparameterize(mu, logvar)
        return self.dc(latent), mu, logvar

### 损失函数定义

In [5]:
def loss_function(preds, labels, mu, logvar, n_nodes, norm, pos_weight):
    """
    Return the weighted reconstruction loss plus the KL term.

    The reconstruction part is ``norm`` times the binary cross-entropy with
    logits between ``preds`` and ``labels`` (positives weighted by
    ``pos_weight``). The KL part follows Appendix B of Kingma and Welling,
    "Auto-Encoding Variational Bayes", ICLR 2014 (https://arxiv.org/abs/1312.6114):
    0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), written here with
    ``2 * logvar`` and ``exp(logvar)^2``, summed over dim 1, averaged and
    scaled by ``-0.5 / n_nodes``.
    """
    reconstruction = F.binary_cross_entropy_with_logits(
        preds, labels, pos_weight=pos_weight)
    per_row = torch.sum(
        1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1)
    kl_term = -0.5 / n_nodes * torch.mean(per_row)
    return norm * reconstruction + kl_term

### Train

In [6]:
# Hyperparameters
model = 'GVAE'  # label only; the name `model` is rebound to the network instance further down
seed = 42  # random seed value
epochs = 200  # number of iterations of the training loop
hidden1 = 32  # passed to GVAE as hidden_dim1
hidden2 = 16  # passed to GVAE as hidden_dim2
lr = 0.01  # learning rate given to the Adam optimizer
dropout = 0.  # dropout probability passed to GVAE
dataset = 'cora'  # dataset name

In [7]:
from tqdm import tqdm
from tqdm import trange

import time
import numpy as np
import scipy.sparse as sp
import networkx as nx
import pickle as pkl

# utils

def load_data(dataset, data_dir="./data/Cora/raw"):
    """
    Load the node features and the adjacency matrix of a dataset.

    Reads the pickled files ``ind.<dataset>.x``, ``.tx``, ``.allx`` and
    ``.graph`` plus the text file ``ind.<dataset>.test.index`` from
    ``data_dir``.

    Args:
        dataset: Dataset name used inside the file names, e.g. ``'cora'``.
        data_dir: Directory that contains the raw files. The default is the
            location that used to be hard-coded.

    Returns:
        ``(adj, features)``: ``adj`` is the sparse adjacency matrix built by
        networkx from the ``graph`` adjacency lists; ``features`` is a dense
        ``torch.FloatTensor`` made of the stacked ``allx`` and ``tx`` rows.

    NOTE: ``pickle`` can execute arbitrary code while loading; only read
    dataset files that come from a trusted source.
    NOTE: the isolated-node fix-up required by the citeseer dataset is not
    implemented here.
    """
    import os

    names = ['x', 'tx', 'allx', 'graph']
    objects = []

    pbar = tqdm(names)
    for name in pbar:
        file_name = "ind.{}.{}".format(dataset, name)
        pbar.set_description("Load data {}".format(file_name))
        with open(os.path.join(data_dir, file_name), 'rb') as rf:
            objects.append(pkl.load(rf, encoding='latin1'))
    # `x` is loaded for completeness but is not needed below.
    _, tx, allx, graph = objects

    # One integer per line; the context manager closes the file handle.
    index_path = os.path.join(data_dir, "ind.{}.test.index".format(dataset))
    with open(index_path, 'rb') as index_file:
        test_idx_reorder = [int(line.strip()) for line in index_file]
    test_idx_range = np.sort(test_idx_reorder)

    # LIL format makes the row assignment below cheap.
    features = sp.vstack((allx, tx)).tolil()
    # Rows found at the sorted test indices are written to the positions
    # listed (in file order) by the index file.
    features[test_idx_reorder, :] = features[test_idx_range, :]
    features = torch.FloatTensor(np.array(features.todense()))
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    return adj, features


def sparse_to_tuple(sparse_mx):
    """
    Break a scipy sparse matrix into ``(coords, values, shape)``.

    ``coords`` has one ``[row, col]`` pair per stored entry, ``values`` holds
    the matching data and ``shape`` is the matrix shape. Inputs that are not
    already COO matrices are converted with ``tocoo()`` first.
    """
    coo = sparse_mx if sp.isspmatrix_coo(sparse_mx) else sparse_mx.tocoo()
    coords = np.array((coo.row, coo.col)).T
    return coords, coo.data, coo.shape


def mask_test_edges(adj):
    """
    Split the edges of a graph into train / validation / test sets.

    The diagonal of ``adj`` is removed, then the upper-triangular edges are
    shuffled: ``floor(E / 20)`` become validation edges, ``floor(E / 10)``
    become test edges and the rest are training edges. For the validation and
    the test set the same number of negative pairs (node pairs that are not
    connected in either direction) is drawn at random.

    NOTE: splits are randomized (``np.random``), so results may deviate
    slightly from numbers reported elsewhere.

    Args:
        adj: Square scipy sparse adjacency matrix. The split is built from its
            upper triangle, so the matrix is expected to be symmetric.

    Returns:
        ``(adj_train, train_edges, val_edges, val_edges_false, test_edges,
        test_edges_false)``. ``adj_train`` is the symmetric CSR adjacency built
        from the training edges only. The positive edge arrays have shape
        ``(k, 2)`` and contain a single direction per edge; the negative edge
        sets are lists of ``[i, j]`` pairs. Validation and test negatives are
        drawn independently and may share pairs.

    Raises:
        ValueError: If the graph has fewer non-edges than negatives required
            (the sampling could never finish).
    """
    # Remove diagonal elements (self-loops).
    adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
    adj.eliminate_zeros()
    # Check that diag is zero without building a dense matrix.
    assert adj.diagonal().sum() == 0

    n_nodes = adj.shape[0]

    upper = sp.triu(adj).tocoo()
    edges = np.column_stack((upper.row, upper.col))

    # Every existing edge as an unordered (low, high) pair for O(1) lookups.
    full = adj.tocoo()
    existing = set(zip(np.minimum(full.row, full.col).tolist(),
                       np.maximum(full.row, full.col).tolist()))

    num_test = int(np.floor(edges.shape[0] / 10.))
    num_val = int(np.floor(edges.shape[0] / 20.))

    all_edge_idx = list(range(edges.shape[0]))
    np.random.shuffle(all_edge_idx)
    val_edge_idx = all_edge_idx[:num_val]
    test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
    test_edges = edges[test_edge_idx]
    val_edges = edges[val_edge_idx]

    # A boolean mask also works when nothing is held out (tiny graphs).
    keep = np.ones(edges.shape[0], dtype=bool)
    keep[np.asarray(test_edge_idx + val_edge_idx, dtype=np.intp)] = False
    train_edges = edges[keep]

    free_pairs = n_nodes * (n_nodes - 1) // 2 - len(existing)
    if max(num_test, num_val) > free_pairs:
        raise ValueError(
            "graph has only {} non-edges but {} negative samples are required"
            .format(free_pairs, max(num_test, num_val)))

    def sample_non_edges(count):
        # Rejection-sample `count` distinct node pairs that are not edges.
        picked, taken = [], set()
        while len(picked) < count:
            idx_i = np.random.randint(0, n_nodes)
            idx_j = np.random.randint(0, n_nodes)
            if idx_i == idx_j:
                continue
            key = (min(idx_i, idx_j), max(idx_i, idx_j))
            # Checking against ALL edges keeps held-out test/val positives
            # from ever being returned as negatives.
            if key in existing or key in taken:
                continue
            taken.add(key)
            picked.append([idx_i, idx_j])
        return picked

    test_edges_false = sample_non_edges(len(test_edges))
    val_edges_false = sample_non_edges(len(val_edges))

    data = np.ones(train_edges.shape[0])

    # Re-build adj matrix from the training edges and symmetrise it.
    adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
    adj_train = adj_train + adj_train.T

    # NOTE: these edge lists only contain single direction of edge!
    return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false


def preprocess_graph(adj):
    """
    Normalise an adjacency matrix and return it as a torch sparse tensor.

    The identity is added to ``adj`` (self-loops), ``D`` is the diagonal matrix
    of the resulting row sums, and the returned matrix is
    ``((A + I) D^-1/2)^T D^-1/2``.
    """
    coo = sp.coo_matrix(adj)
    with_self_loops = coo + sp.eye(coo.shape[0])
    row_totals = np.array(with_self_loops.sum(1))
    d_inv_sqrt = sp.diags(np.power(row_totals, -0.5).flatten())
    scaled_cols = with_self_loops.dot(d_inv_sqrt)
    normalized = scaled_cols.transpose().dot(d_inv_sqrt).tocoo()
    return sparse_mx_to_torch_sparse_tensor(normalized)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """
    Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sparse_mx: scipy sparse matrix of any format.

    Returns:
        A torch sparse tensor with ``float32`` values, ``int64`` indices and
        the same shape as ``sparse_mx``.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is a deprecated constructor;
    # sparse_coo_tensor is the supported way to build the same tensor.
    return torch.sparse_coo_tensor(indices, values, shape)


from sklearn.metrics import roc_auc_score, average_precision_score


def get_roc_score(emb, adj_orig, edges_pos, edges_neg):
    """
    Score link prediction from node embeddings.

    Each candidate pair ``(i, j)`` is scored with ``sigmoid(emb[i] . emb[j])``.
    Pairs in ``edges_pos`` are labelled 1 and pairs in ``edges_neg`` are
    labelled 0.

    Args:
        emb: 2-D numpy array with one embedding per row.
        adj_orig: Unused; kept so existing callers keep working.
        edges_pos: Sequence of ``[i, j]`` pairs that are true edges.
        edges_neg: Sequence of ``[i, j]`` pairs that are non-edges.

    Returns:
        ``(roc_score, ap_score)``: ROC-AUC and average precision.
    """
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Reconstructed score for every node pair.
    adj_rec = np.dot(emb, emb.T)

    def edge_scores(edge_list):
        # reshape(-1, 2) keeps an empty edge list indexable.
        pairs = np.asarray(edge_list, dtype=np.int64).reshape(-1, 2)
        return sigmoid(adj_rec[pairs[:, 0], pairs[:, 1]])

    preds = edge_scores(edges_pos)
    preds_neg = edge_scores(edges_neg)

    preds_all = np.hstack([preds, preds_neg])
    # One zero label per NEGATIVE sample, so the label vector matches
    # preds_all even when the two edge sets differ in size.
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score

In [8]:
# Apply the seed hyperparameter so the random edge split and the weight
# initialisation are repeatable between runs.
np.random.seed(seed)
torch.manual_seed(seed)

adj, features = load_data(dataset)
n_nodes, feat_dim = features.shape
print(n_nodes, feat_dim)

# Keep a copy of the full graph, with its diagonal removed, for evaluation.
adj_orig = adj
adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
adj_orig.eliminate_zeros()  # drop explicitly stored zeros

adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
# From here on `adj` is the training graph only.
adj = adj_train

# Some preprocessing
adj_norm = preprocess_graph(adj)
# Reconstruction target: training edges plus the identity, as a dense tensor.
adj_label = adj_train + sp.eye(adj_train.shape[0])
adj_label = torch.FloatTensor(adj_label.toarray())

# Ratio of zero entries to non-zero entries of the training adjacency.
pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
pos_weight = torch.tensor(pos_weight)

norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)

model = GVAE(feat_dim, hidden1, hidden2, dropout)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

hidden_emb = None
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    recovered, mu, logvar = model(features, adj_norm)
    loss = loss_function(preds=recovered, labels=adj_label,
                         mu=mu, logvar=logvar, n_nodes=n_nodes,
                         norm=norm, pos_weight=pos_weight)
    loss.backward()
    optimizer.step()

    # `mu` is used as the node embedding; detach() takes it out of the graph.
    hidden_emb = mu.detach().numpy()
    roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false)

    if (epoch + 1) % 10 == 0:
        print("Epoch[{}/{}], Loss: {:.4f}, Val_ap: {:.4f}"
              .format(epoch + 1, epochs, loss.item(), ap_curr))

print("Optimization Finished!")

roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
print('Test ROC score: ' + str(roc_score))
print('Test AP score: ' + str(ap_score))

Load data ind.cora.graph: 100%|██████████| 4/4 [00:00<00:00, 570.03it/s]


2708 1433
Epoch[10/200], Loss: 0.6952, Val_ap: 0.7553
Epoch[20/200], Loss: 0.5651, Val_ap: 0.8352
Epoch[30/200], Loss: 0.5069, Val_ap: 0.8874
Epoch[40/200], Loss: 0.4816, Val_ap: 0.8912
Epoch[50/200], Loss: 0.4664, Val_ap: 0.8934
Epoch[60/200], Loss: 0.4591, Val_ap: 0.8981
Epoch[70/200], Loss: 0.4526, Val_ap: 0.9012
Epoch[80/200], Loss: 0.4474, Val_ap: 0.9069
Epoch[90/200], Loss: 0.4427, Val_ap: 0.9121
Epoch[100/200], Loss: 0.4401, Val_ap: 0.9110
Epoch[110/200], Loss: 0.4385, Val_ap: 0.9112
Epoch[120/200], Loss: 0.4369, Val_ap: 0.9116
Epoch[130/200], Loss: 0.4355, Val_ap: 0.9099
Epoch[140/200], Loss: 0.4338, Val_ap: 0.9052
Epoch[150/200], Loss: 0.4322, Val_ap: 0.9062
Epoch[160/200], Loss: 0.4309, Val_ap: 0.9064
Epoch[170/200], Loss: 0.4301, Val_ap: 0.9059
Epoch[180/200], Loss: 0.4288, Val_ap: 0.9047
Epoch[190/200], Loss: 0.4277, Val_ap: 0.9048
Epoch[200/200], Loss: 0.4269, Val_ap: 0.9035
Optimization Finished!
Test ROC score: 0.9100634071342927
Test AP score: 0.9247575722131322
