In [1]:
from model import CCA_SSG, LogReg
from model2 import TOCCA_link
from aug import random_aug
from dataset import load

import numpy as np
import torch as th
import torch.nn as nn

import warnings

warnings.filterwarnings('ignore')

from sklearn.metrics import roc_auc_score, average_precision_score
from util import mask_test_edges_dgl
import dgl
import torch.nn.functional as F
import pdb

Using backend: pytorch


In [2]:
# parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
# parser.add_argument('--use_mlp', action='store_true', default=False, help='Use MLP instead of GNN')

In [3]:
def get_scores(edges_pos, edges_neg, adj_rec):
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    adj_rec = adj_rec.cpu()
    # Predict on test set of edges
    preds = []
    for e in edges_pos:
        preds.append(sigmoid(adj_rec[e[0], e[1]].item()))

    preds_neg = []
    for e in edges_neg:
        preds_neg.append(sigmoid(adj_rec[e[0], e[1]].data))

    preds_all = np.hstack([preds, preds_neg])
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score

In [4]:
graph, feat, labels, num_class, train_idx, val_idx, test_idx = load('cora')
adj_orig = graph.adj().to_dense()
train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig)

# create train graph
train_edge_idx = th.tensor(train_edge_idx)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [6]:
graph = train_graph.remove_self_loop().add_self_loop()
adj = graph.adj().to_dense()

in_dim = feat.shape[1]

hid_dim = 512
out_dim = 512
n_layers = 2
n_adj = adj.shape[1]

loss_fn = F.binary_cross_entropy
output_activation = nn.Sigmoid()

best_val_roc = 0
eval_roc = 0
best_val_ap = 0
eval_ap = 0

model = TOCCA_link(in_dim, hid_dim, out_dim, n_adj, n_layers, use_mlp=False)
lr1 = 1e-3
wd1 = 0
wd2 = 1e-4
optimizer = th.optim.Adam(model.parameters(), lr=lr1, weight_decay=wd2)

N = graph.number_of_nodes()


for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    dfr = 0.2
    der = 0.2

    graph1, feat1 = random_aug(graph, feat, 0.2, 0.2)
    graph2, feat2 = random_aug(graph, feat, 0.2, 0.2)

    graph1 = graph1.add_self_loop()
    graph2 = graph2.add_self_loop()

    z1, z2, logits_temp = model(graph1, feat1, graph2, feat2)
    
    logits = output_activation(logits_temp)
    
    c = th.mm(z1.T, z2)
    c1 = th.mm(z1.T, z1)
    c2 = th.mm(z2.T, z2)

    c = c / N
    c1 = c1 / N
    c2 = c2 / N

    loss_inv = -th.diagonal(c).sum()
    iden = th.tensor(np.eye(c.shape[0]))
    loss_dec1 = (iden - c1).pow(2).sum()
    loss_dec2 = (iden - c2).pow(2).sum()

    lambd1 = 10^2
    lambd2 = 1e-3

    loss_task = loss_fn(logits, adj)

    loss = loss_task + lambd1 * loss_inv + lambd2 * (loss_dec1 + loss_dec2)

    loss.backward()
    optimizer.step()

    model.eval()
    with th.no_grad():
        val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
        test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)

        if val_roc >= best_val_roc:
            best_val_roc = val_roc
            if test_roc > eval_roc:
                eval_roc = test_roc
        
        if val_ap >= best_val_ap:
            best_val_ap = val_ap
            if test_ap > eval_ap:
                eval_ap = test_ap

    print('Epoch:{}, val_ap:{:.4f}, val_roc:{:4f}, test_ap:{:4f}, test_roc:{:4f}'.format(epoch, val_ap, val_roc, test_ap, test_roc))
    print('Linear evaluation AP:{:.4f}'.format(eval_ap))
    print('Linear evaluation ROC:{:.4f}'.format(eval_roc))

Epoch:0, val_ap:0.4956, val_roc:0.488775, test_ap:0.500527, test_roc:0.488114
Linear evaluation AP:0.5005
Linear evaluation ROC:0.4881
Epoch:1, val_ap:0.5057, val_roc:0.490363, test_ap:0.513020, test_roc:0.504792
Linear evaluation AP:0.5130
Linear evaluation ROC:0.5048
Epoch:2, val_ap:0.5291, val_roc:0.510328, test_ap:0.512731, test_roc:0.503181
Linear evaluation AP:0.5130
Linear evaluation ROC:0.5048
Epoch:3, val_ap:0.5249, val_roc:0.509662, test_ap:0.515626, test_roc:0.510948
Linear evaluation AP:0.5130
Linear evaluation ROC:0.5048
Epoch:4, val_ap:0.5462, val_roc:0.537484, test_ap:0.525902, test_roc:0.532856
Linear evaluation AP:0.5259
Linear evaluation ROC:0.5329
Epoch:5, val_ap:0.5694, val_roc:0.551865, test_ap:0.537680, test_roc:0.531210
Linear evaluation AP:0.5377
Linear evaluation ROC:0.5329
Epoch:6, val_ap:0.5905, val_roc:0.571424, test_ap:0.547057, test_roc:0.542999
Linear evaluation AP:0.5471
Linear evaluation ROC:0.5430
Epoch:7, val_ap:0.5940, val_roc:0.571125, test_ap:0.572

Epoch:61, val_ap:0.8452, val_roc:0.817041, test_ap:0.856480, test_roc:0.834621
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:62, val_ap:0.8497, val_roc:0.820606, test_ap:0.857354, test_roc:0.835763
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:63, val_ap:0.8437, val_roc:0.812007, test_ap:0.860798, test_roc:0.839649
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:64, val_ap:0.8455, val_roc:0.816145, test_ap:0.861594, test_roc:0.839017
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:65, val_ap:0.8442, val_roc:0.816667, test_ap:0.862021, test_roc:0.841958
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:66, val_ap:0.8418, val_roc:0.811827, test_ap:0.861711, test_roc:0.840238
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:67, val_ap:0.8433, val_roc:0.816004, test_ap:0.860565, test_roc:0.840569
Linear evaluation AP:0.8592
Linear evaluation ROC:0.8338
Epoch:68, val_ap:0.8435, val_roc:0.814758, test_