In [1]:
from model import CCA_SSG, LogReg
from aug import random_aug
from dataset import load

import numpy as np
import torch as th
import torch.nn as nn

import warnings

warnings.filterwarnings('ignore')

from sklearn.metrics import roc_auc_score, average_precision_score
from util import mask_test_edges_dgl
import dgl
import torch.nn.functional as F
import pdb

Using backend: pytorch


In [2]:
# parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
# parser.add_argument('--use_mlp', action='store_true', default=False, help='Use MLP instead of GNN')

In [3]:
def get_scores(edges_pos, edges_neg, adj_rec):
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    adj_rec = adj_rec.cpu()
    # Predict on test set of edges
    preds = []
    for e in edges_pos:
        preds.append(sigmoid(adj_rec[e[0], e[1]].item()))

    preds_neg = []
    for e in edges_neg:
        preds_neg.append(sigmoid(adj_rec[e[0], e[1]].data))

    preds_all = np.hstack([preds, preds_neg])
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score

In [4]:
graph, feat, labels, num_class, train_idx, val_idx, test_idx = load('cora')
adj_orig = graph.adj().to_dense()
train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig)

# create train graph
train_edge_idx = th.tensor(train_edge_idx)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True)

# add self loop
#train_graph = dgl.remove_self_loop(train_graph)
#train_graph = dgl.add_self_loop(train_graph)
#n_edges = train_graph.number_of_edges()
#adj = train_graph.adjacency_matrix().to_dense()

# normalization
#degs = train_graph.in_degrees().float()
#norm = th.pow(degs, -0.5)
#norm[th.isinf(norm)] = 0
#train_graph.ndata['norm'] = norm.unsqueeze(1)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [5]:
# graph, feat, labels, num_class, train_idx, val_idx, test_idx = load('cora')

in_dim = feat.shape[1]

hid_dim = 512
out_dim = 512
n_layers = 2

model = CCA_SSG(in_dim, hid_dim, out_dim, n_layers, use_mlp=False)
lr1 = 1e-3 
wd1 = 0
optimizer = th.optim.Adam(model.parameters(), lr=lr1, weight_decay=wd1)

N = graph.number_of_nodes()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    
    dfr = 0.2
    der = 0.2

    graph1, feat1 = random_aug(graph, feat, 0.2, 0.2)
    graph2, feat2 = random_aug(graph, feat, 0.2, 0.2)

    graph1 = graph1.add_self_loop()
    graph2 = graph2.add_self_loop()

    z1, z2 = model(graph1, feat1, graph2, feat2)

    c = th.mm(z1.T, z2)
    c1 = th.mm(z1.T, z1)
    c2 = th.mm(z2.T, z2)

    c = c / N
    c1 = c1 / N
    c2 = c2 / N

    loss_inv = -th.diagonal(c).sum()
    iden = th.tensor(np.eye(c.shape[0]))
    loss_dec1 = (iden - c1).pow(2).sum()
    loss_dec2 = (iden - c2).pow(2).sum()
    
    lambd = 1e-3
    
    loss = loss_inv + lambd * (loss_dec1 + loss_dec2)

    loss.backward()
    optimizer.step()

    print('Epoch={:03d}, loss={:.4f}'.format(epoch, loss.item()))


Epoch=000, loss=-359.3125
Epoch=001, loss=-351.0300
Epoch=002, loss=-352.7984
Epoch=003, loss=-358.8241
Epoch=004, loss=-396.2486
Epoch=005, loss=-399.4687
Epoch=006, loss=-377.7650
Epoch=007, loss=-408.0907
Epoch=008, loss=-401.6025
Epoch=009, loss=-413.8253
Epoch=010, loss=-402.4824
Epoch=011, loss=-404.1632
Epoch=012, loss=-415.7970
Epoch=013, loss=-408.9606
Epoch=014, loss=-412.8656
Epoch=015, loss=-416.9501
Epoch=016, loss=-420.7227
Epoch=017, loss=-422.6843
Epoch=018, loss=-431.5959
Epoch=019, loss=-419.6221
Epoch=020, loss=-434.7444
Epoch=021, loss=-426.3523
Epoch=022, loss=-428.5444
Epoch=023, loss=-430.0735
Epoch=024, loss=-431.1600
Epoch=025, loss=-439.5212
Epoch=026, loss=-432.6340
Epoch=027, loss=-434.9827
Epoch=028, loss=-435.3034
Epoch=029, loss=-436.5374
Epoch=030, loss=-437.5219
Epoch=031, loss=-441.5498
Epoch=032, loss=-440.2833
Epoch=033, loss=-445.4271
Epoch=034, loss=-437.4046
Epoch=035, loss=-439.2170
Epoch=036, loss=-445.2807
Epoch=037, loss=-444.1143
Epoch=038, l

In [6]:
print("=== Evaluation ===")
graph = train_graph.remove_self_loop().add_self_loop()
adj = graph.adj().to_dense()

pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)

embeds = model.get_embedding(graph, feat)

# loss_fn = F.binary_cross_entropy
loss_fn = F.binary_cross_entropy_with_logits
output_activation = nn.Sigmoid()
logreg = LogReg(embeds.shape[1], adj.shape[1])

logits_temp = logreg(embeds)
logits = output_activation(th.mm(logits_temp, logits_temp.t()))

val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)
print(test_roc, test_ap)

=== Evaluation ===
0.9415646548819658 0.939592793268594


In [7]:
print("=== Evaluation ===")
graph = train_graph.remove_self_loop().add_self_loop()
adj = graph.adj().to_dense()

pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)

embeds = model.get_embedding(graph, feat)

''' Linear Evaluation '''
logreg = LogReg(embeds.shape[1], adj.shape[1])
lr2 = 1e-2
wd2 = 1e-4
opt = th.optim.Adam(logreg.parameters(), lr=lr2, weight_decay=wd2)

loss_fn = F.binary_cross_entropy
# loss_fn = F.binary_cross_entropy_with_logits
output_activation = nn.Sigmoid()

#nn.CrossEntropyLoss()

best_val_roc = 0
eval_roc = 0
best_val_ap = 0
eval_ap = 0
    
for epoch in range(2000):
    logreg.train()
    opt.zero_grad()
    logits_temp = logreg(embeds)
    logits = output_activation(th.mm(logits_temp, logits_temp.t()))
    
    # pdb.set_trace()
    loss = loss_fn(logits, adj)
    # loss = norm * loss_fn(logits, adj, pos_weight = pos_weight)
    loss.backward()
    opt.step()

    logreg.eval()
    with th.no_grad():
        val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
        test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)

        if val_roc >= best_val_roc:
            best_val_roc = val_roc
            if test_roc > eval_roc:
                eval_roc = test_roc
        
        if val_ap >= best_val_ap:
            best_val_ap = val_ap
            if test_ap > eval_ap:
                eval_ap = test_ap

    print('Epoch:{}, val_ap:{:.4f}, val_roc:{:4f}, test_ap:{:4f}, test_roc:{:4f}'.format(epoch, val_ap, val_roc, test_ap, test_roc))
    print('Linear evaluation AP:{:.4f}'.format(eval_ap))
    print('Linear evaluation ROC:{:.4f}'.format(eval_roc))

=== Evaluation ===
Epoch:0, val_ap:0.9381, val_roc:0.944172, test_ap:0.947187, test_roc:0.946114
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:1, val_ap:0.6854, val_roc:0.666862, test_ap:0.683328, test_roc:0.671102
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:2, val_ap:0.6605, val_roc:0.640585, test_ap:0.656147, test_roc:0.645925
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:3, val_ap:0.6651, val_roc:0.645943, test_ap:0.660982, test_roc:0.651300
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:4, val_ap:0.6878, val_roc:0.669858, test_ap:0.685841, test_roc:0.673870
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:5, val_ap:0.7611, val_roc:0.746173, test_ap:0.762262, test_roc:0.741208
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:6, val_ap:0.8961, val_roc:0.890364, test_ap:0.904065, test_roc:0.888968
Linear evaluation AP:0.9472
Linear evaluation ROC:0.9461
Epoch:7, val_ap:0.8593, val_roc:0.84

Epoch:61, val_ap:0.9330, val_roc:0.946351, test_ap:0.946647, test_roc:0.954282
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:62, val_ap:0.9307, val_roc:0.944406, test_ap:0.944952, test_roc:0.952822
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:63, val_ap:0.9285, val_roc:0.942257, test_ap:0.943021, test_roc:0.951130
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:64, val_ap:0.9270, val_roc:0.940874, test_ap:0.941894, test_roc:0.950093
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:65, val_ap:0.9271, val_roc:0.940993, test_ap:0.942037, test_roc:0.950326
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:66, val_ap:0.9290, val_roc:0.942757, test_ap:0.943433, test_roc:0.951753
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:67, val_ap:0.9310, val_roc:0.944946, test_ap:0.945169, test_roc:0.953485
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:68, val_ap:0.9328, val_roc:0.946675, test_

Epoch:122, val_ap:0.8901, val_roc:0.913844, test_ap:0.906846, test_roc:0.924938
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:123, val_ap:0.8892, val_roc:0.913311, test_ap:0.906201, test_roc:0.924437
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:124, val_ap:0.8886, val_roc:0.912836, test_ap:0.905442, test_roc:0.923924
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:125, val_ap:0.8879, val_roc:0.912245, test_ap:0.904641, test_roc:0.923347
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:126, val_ap:0.8870, val_roc:0.911514, test_ap:0.903865, test_roc:0.922690
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:127, val_ap:0.8859, val_roc:0.910636, test_ap:0.903004, test_roc:0.921958
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:128, val_ap:0.8850, val_roc:0.909826, test_ap:0.902113, test_roc:0.921208
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:129, val_ap:0.8842, val_roc:0.90918

Epoch:182, val_ap:0.8384, val_roc:0.872901, test_ap:0.856831, test_roc:0.885363
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:183, val_ap:0.8378, val_roc:0.872304, test_ap:0.855881, test_roc:0.884628
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:184, val_ap:0.8369, val_roc:0.871609, test_ap:0.855150, test_roc:0.883945
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:185, val_ap:0.8360, val_roc:0.870867, test_ap:0.854418, test_roc:0.883287
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:186, val_ap:0.8353, val_roc:0.870266, test_ap:0.853638, test_roc:0.882623
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:187, val_ap:0.8346, val_roc:0.869625, test_ap:0.852777, test_roc:0.881912
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:188, val_ap:0.8339, val_roc:0.868883, test_ap:0.851971, test_roc:0.881231
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:189, val_ap:0.8331, val_roc:0.86823

Epoch:242, val_ap:0.7937, val_roc:0.835426, test_ap:0.810309, test_roc:0.845689
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:243, val_ap:0.7930, val_roc:0.834864, test_ap:0.809552, test_roc:0.845047
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:244, val_ap:0.7924, val_roc:0.834328, test_ap:0.808845, test_roc:0.844448
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:245, val_ap:0.7916, val_roc:0.833737, test_ap:0.808229, test_roc:0.843809
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:246, val_ap:0.7909, val_roc:0.833111, test_ap:0.807496, test_roc:0.843190
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:247, val_ap:0.7903, val_roc:0.832545, test_ap:0.806820, test_roc:0.842589
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:248, val_ap:0.7897, val_roc:0.832056, test_ap:0.806141, test_roc:0.841970
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:249, val_ap:0.7889, val_roc:0.83143

Epoch:302, val_ap:0.7584, val_roc:0.804363, test_ap:0.772794, test_roc:0.811326
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625
Epoch:303, val_ap:0.7579, val_roc:0.803848, test_ap:0.772254, test_roc:0.810742
Linear evaluation AP:0.9563
Linear evaluation ROC:0.9625


KeyboardInterrupt: 