In [1]:
from model import CCA_SSG, LogReg
from aug import random_aug
from dataset import load

import numpy as np
import torch as th
import torch.nn as nn

import warnings

warnings.filterwarnings('ignore')

from sklearn.metrics import roc_auc_score, average_precision_score
from util import mask_test_edges_dgl
import dgl
import torch.nn.functional as F
import pdb

Using backend: pytorch


In [2]:
# parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
# parser.add_argument('--use_mlp', action='store_true', default=False, help='Use MLP instead of GNN')

In [3]:
def compute_loss_para(adj):
    pos_weight = ((adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum())
    norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
    weight_mask = adj.view(-1) == 1
    weight_tensor = th.ones(weight_mask.size(0))
    weight_tensor[weight_mask] = pos_weight
    return weight_tensor, norm

In [4]:
def get_scores(edges_pos, edges_neg, adj_rec):
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    adj_rec = adj_rec.cpu()
    # Predict on test set of edges
    preds = []
    for e in edges_pos:
        preds.append(sigmoid(adj_rec[e[0], e[1]].item()))

    preds_neg = []
    for e in edges_neg:
        preds_neg.append(sigmoid(adj_rec[e[0], e[1]].data))

    preds_all = np.hstack([preds, preds_neg])
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score

In [5]:
graph, feat, labels, num_class, train_idx, val_idx, test_idx = load('cora')
adj_orig = graph.adj().to_dense()
train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig)

# create train graph
train_edge_idx = th.tensor(train_edge_idx)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True)

# add self loop
#train_graph = dgl.remove_self_loop(train_graph)
#train_graph = dgl.add_self_loop(train_graph)
#n_edges = train_graph.number_of_edges()
#adj = train_graph.adjacency_matrix().to_dense()

# normalization
#degs = train_graph.in_degrees().float()
#norm = th.pow(degs, -0.5)
#norm[th.isinf(norm)] = 0
#train_graph.ndata['norm'] = norm.unsqueeze(1)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [6]:
# graph, feat, labels, num_class, train_idx, val_idx, test_idx = load('cora')

in_dim = feat.shape[1]

hid_dim = 512
out_dim = 512
n_layers = 2

model = CCA_SSG(in_dim, hid_dim, out_dim, n_layers, use_mlp=False)
lr1 = 1e-3 
wd1 = 0
optimizer = th.optim.Adam(model.parameters(), lr=lr1, weight_decay=wd1)

N = graph.number_of_nodes()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    
    dfr = 0.2
    der = 0.2

    graph1, feat1 = random_aug(graph, feat, 0.2, 0.2)
    graph2, feat2 = random_aug(graph, feat, 0.2, 0.2)

    graph1 = graph1.add_self_loop()
    graph2 = graph2.add_self_loop()

    z1, z2 = model(graph1, feat1, graph2, feat2)

    c = th.mm(z1.T, z2)
    c1 = th.mm(z1.T, z1)
    c2 = th.mm(z2.T, z2)

    c = c / N
    c1 = c1 / N
    c2 = c2 / N

    loss_inv = -th.diagonal(c).sum()
    iden = th.tensor(np.eye(c.shape[0]))
    loss_dec1 = (iden - c1).pow(2).sum()
    loss_dec2 = (iden - c2).pow(2).sum()
    
    lambd = 1e-3
    
    loss = loss_inv + lambd * (loss_dec1 + loss_dec2)

    loss.backward()
    optimizer.step()

    print('Epoch={:03d}, loss={:.4f}'.format(epoch, loss.item()))


Epoch=000, loss=-335.8493
Epoch=001, loss=-342.0265
Epoch=002, loss=-362.4157
Epoch=003, loss=-382.1460
Epoch=004, loss=-396.3875
Epoch=005, loss=-381.8439
Epoch=006, loss=-388.4731
Epoch=007, loss=-387.4223
Epoch=008, loss=-404.9112
Epoch=009, loss=-405.2298
Epoch=010, loss=-410.5345
Epoch=011, loss=-419.2475
Epoch=012, loss=-416.3677
Epoch=013, loss=-412.4417
Epoch=014, loss=-414.8666
Epoch=015, loss=-416.0368
Epoch=016, loss=-412.0549
Epoch=017, loss=-417.2782
Epoch=018, loss=-418.3943
Epoch=019, loss=-432.0712
Epoch=020, loss=-420.9087
Epoch=021, loss=-427.2425
Epoch=022, loss=-424.4462
Epoch=023, loss=-433.9296
Epoch=024, loss=-431.0464
Epoch=025, loss=-433.8957
Epoch=026, loss=-433.3944
Epoch=027, loss=-433.0483
Epoch=028, loss=-442.7434
Epoch=029, loss=-436.7130
Epoch=030, loss=-435.5874
Epoch=031, loss=-437.5592
Epoch=032, loss=-440.6771
Epoch=033, loss=-445.4746
Epoch=034, loss=-442.1152
Epoch=035, loss=-436.9419
Epoch=036, loss=-445.9366
Epoch=037, loss=-443.0281
Epoch=038, l

In [7]:
print("=== Evaluation ===")
graph = train_graph.remove_self_loop().add_self_loop()
adj = graph.adj().to_dense()

weight_tensor, norm = compute_loss_para(adj)

embeds = model.get_embedding(graph, feat)

# loss_fn = F.binary_cross_entropy
loss_fn = F.binary_cross_entropy
output_activation = nn.Sigmoid()
logreg = LogReg(embeds.shape[1], adj.shape[1])

logits_temp = logreg(embeds)
logits = output_activation(th.mm(logits_temp, logits_temp.t()))

val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)
print(test_roc, test_ap)

=== Evaluation ===
0.9342036342400215 0.9343107150425899


In [None]:
print("=== Evaluation ===")
graph = train_graph.remove_self_loop().add_self_loop()
adj = graph.adj().to_dense()

weight_tensor, norm = compute_loss_para(adj)

embeds = model.get_embedding(graph, feat)

''' Linear Evaluation '''
logreg = LogReg(embeds.shape[1], adj.shape[1])
lr2 = 1e-2
wd2 = 1e-4
opt = th.optim.Adam(logreg.parameters(), lr=lr2, weight_decay=wd2)

loss_fn = F.binary_cross_entropy
output_activation = nn.Sigmoid()

best_val_roc = 0
eval_roc = 0
best_val_ap = 0
eval_ap = 0
    
for epoch in range(2000):
    logreg.train()
    opt.zero_grad()
    logits_temp = logreg(embeds)
    logits = output_activation(th.mm(logits_temp, logits_temp.t()))
    
    # pdb.set_trace()
    loss = norm*loss_fn(logits.view(-1), adj.view(-1), weight = weight_tensor)
    loss.backward()
    opt.step()

    logreg.eval()
    with th.no_grad():
        val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)
        test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)

        if val_roc >= best_val_roc:
            best_val_roc = val_roc
            if test_roc > eval_roc:
                eval_roc = test_roc
        
        if val_ap >= best_val_ap:
            best_val_ap = val_ap
            if test_ap > eval_ap:
                eval_ap = test_ap

    print('Epoch:{}, val_ap:{:.4f}, val_roc:{:4f}, test_ap:{:4f}, test_roc:{:4f}'.format(epoch, val_ap, val_roc, test_ap, test_roc))
    print('Linear evaluation AP:{:.4f}'.format(eval_ap))
    print('Linear evaluation ROC:{:.4f}'.format(eval_roc))

=== Evaluation ===
Epoch:0, val_ap:0.9409, val_roc:0.942404, test_ap:0.935086, test_roc:0.936019
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:1, val_ap:0.6876, val_roc:0.663431, test_ap:0.693164, test_roc:0.665078
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:2, val_ap:0.6700, val_roc:0.635526, test_ap:0.678623, test_roc:0.641258
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:3, val_ap:0.6900, val_roc:0.658134, test_ap:0.695163, test_roc:0.661009
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:4, val_ap:0.7310, val_roc:0.705915, test_ap:0.734049, test_roc:0.704682
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:5, val_ap:0.7968, val_roc:0.776656, test_ap:0.798646, test_roc:0.773840
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:6, val_ap:0.8477, val_roc:0.832880, test_ap:0.855743, test_roc:0.839316
Linear evaluation AP:0.9351
Linear evaluation ROC:0.9360
Epoch:7, val_ap:0.8705, val_roc:0.86

Epoch:61, val_ap:0.9591, val_roc:0.963119, test_ap:0.954320, test_roc:0.959889
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:62, val_ap:0.9591, val_roc:0.963101, test_ap:0.954299, test_roc:0.959898
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:63, val_ap:0.9591, val_roc:0.963137, test_ap:0.954279, test_roc:0.959902
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:64, val_ap:0.9591, val_roc:0.963155, test_ap:0.954238, test_roc:0.959916
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:65, val_ap:0.9592, val_roc:0.963180, test_ap:0.954233, test_roc:0.959930
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:66, val_ap:0.9592, val_roc:0.963223, test_ap:0.954213, test_roc:0.959929
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:67, val_ap:0.9593, val_roc:0.963302, test_ap:0.954205, test_roc:0.959943
Linear evaluation AP:0.9542
Linear evaluation ROC:0.9598
Epoch:68, val_ap:0.9594, val_roc:0.963364, test_