In [2]:
from model import CCA_SSG, LogReg
from aug import random_aug
from dataset import load

import numpy as np
import torch as th
import torch.nn as nn

import warnings

warnings.filterwarnings('ignore')

from sklearn.metrics import roc_auc_score, average_precision_score
from util import mask_test_edges_dgl

Using backend: pytorch


In [3]:
# parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
# parser.add_argument('--use_mlp', action='store_true', default=False, help='Use MLP instead of GNN')

In [6]:
graph, feat, labels, num_class, train_idx, val_idx, test_idx = load('cora')
graph = graph.remove_self_loop().add_self_loop()
in_dim = feat.shape[1]

hid_dim = 512
out_dim = 512
n_layers = 2

model = CCA_SSG(in_dim, hid_dim, out_dim, n_layers, use_mlp=False)
lr1 = 1e-3 
wd1 = 0
optimizer = th.optim.Adam(model.parameters(), lr=lr1, weight_decay=wd1)

N = graph.number_of_nodes()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    
    dfr = 0.2
    der = 0.2

    graph1, feat1 = random_aug(graph, feat, 0.2, 0.2)
    graph2, feat2 = random_aug(graph, feat, 0.2, 0.2)

    graph1 = graph1.add_self_loop()
    graph2 = graph2.add_self_loop()

    z1, z2 = model(graph1, feat1, graph2, feat2)

    c = th.mm(z1.T, z2)
    c1 = th.mm(z1.T, z1)
    c2 = th.mm(z2.T, z2)

    c = c / N
    c1 = c1 / N
    c2 = c2 / N

    loss_inv = -th.diagonal(c).sum()
    iden = th.tensor(np.eye(c.shape[0]))
    loss_dec1 = (iden - c1).pow(2).sum()
    loss_dec2 = (iden - c2).pow(2).sum()
    
    lambd = 1e-3
    
    loss = loss_inv + lambd * (loss_dec1 + loss_dec2)

    loss.backward()
    optimizer.step()

    print('Epoch={:03d}, loss={:.4f}'.format(epoch, loss.item()))


  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch=000, loss=-333.1239
Epoch=001, loss=-362.0690
Epoch=002, loss=-338.8452
Epoch=003, loss=-375.5197
Epoch=004, loss=-378.3668
Epoch=005, loss=-393.6665
Epoch=006, loss=-404.0918
Epoch=007, loss=-407.9352
Epoch=008, loss=-392.5743
Epoch=009, loss=-410.6282
Epoch=010, loss=-418.6497
Epoch=011, loss=-420.0819
Epoch=012, loss=-419.2513
Epoch=013, loss=-417.7541
Epoch=014, loss=-417.9207
Epoch=015, loss=-423.5932
Epoch=016, loss=-425.0707
Epoch=017, loss=-419.4870
Epoch=018, loss=-426.1390
Epoch=019, loss=-417.1960
Epoch=020, loss=-434.0430
Epoch=021, loss=-433.9163
Epoch=022, loss=-433.7112
Epoch=023, loss=-431.5808
Epoch=024, loss=-426.4957
Epoch=025, loss=-438.7745
Epoch=026, loss=-434.8502
Epoch=027, loss=-432.5161
Epoch=028, loss=-441.5172
Epoch=029, loss=-445.0157
Epoch=030, loss=-447.1457
Epoch=031, l

In [7]:
print("=== Evaluation ===")
graph = graph.remove_self_loop().add_self_loop()

embeds = model.get_embedding(graph, feat)

train_embs = embeds[train_idx]
val_embs = embeds[val_idx]
test_embs = embeds[test_idx]

label = labels

train_labels = label[train_idx]
val_labels = label[val_idx]
test_labels = label[test_idx]

train_feat = feat[train_idx]
val_feat = feat[val_idx]
test_feat = feat[test_idx] 

''' Linear Evaluation '''
logreg = LogReg(train_embs.shape[1], num_class)
lr2 = 1e-2
wd2 = 1e-4
opt = th.optim.Adam(logreg.parameters(), lr=lr2, weight_decay=wd2)

loss_fn = nn.CrossEntropyLoss()

best_val_acc = 0
eval_acc = 0
    
for epoch in range(2000):
    logreg.train()
    opt.zero_grad()
    logits = logreg(train_embs)
    preds = th.argmax(logits, dim=1)
    train_acc = th.sum(preds == train_labels).float() / train_labels.shape[0]
    loss = loss_fn(logits, train_labels)
    loss.backward()
    opt.step()

    logreg.eval()
    with th.no_grad():
        val_logits = logreg(val_embs)
        test_logits = logreg(test_embs)

        val_preds = th.argmax(val_logits, dim=1)
        test_preds = th.argmax(test_logits, dim=1)

        val_acc = th.sum(val_preds == val_labels).float() / val_labels.shape[0]
        test_acc = th.sum(test_preds == test_labels).float() / test_labels.shape[0]

        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            if test_acc > eval_acc:
                eval_acc = test_acc

    print('Epoch:{}, train_acc:{:.4f}, val_acc:{:4f}, test_acc:{:4f}'.format(epoch, train_acc, val_acc, test_acc))
    print('Linear evaluation accuracy:{:.4f}'.format(eval_acc))

=== Evaluation ===
Epoch:0, train_acc:0.1071, val_acc:0.308000, test_acc:0.326000
Linear evaluation accuracy:0.3260
Epoch:1, train_acc:0.5143, val_acc:0.648000, test_acc:0.663000
Linear evaluation accuracy:0.6630
Epoch:2, train_acc:0.8071, val_acc:0.736000, test_acc:0.759000
Linear evaluation accuracy:0.7590
Epoch:3, train_acc:0.8643, val_acc:0.756000, test_acc:0.780000
Linear evaluation accuracy:0.7800
Epoch:4, train_acc:0.8786, val_acc:0.778000, test_acc:0.789000
Linear evaluation accuracy:0.7890
Epoch:5, train_acc:0.8857, val_acc:0.782000, test_acc:0.791000
Linear evaluation accuracy:0.7910
Epoch:6, train_acc:0.9000, val_acc:0.786000, test_acc:0.799000
Linear evaluation accuracy:0.7990
Epoch:7, train_acc:0.8929, val_acc:0.790000, test_acc:0.808000
Linear evaluation accuracy:0.8080
Epoch:8, train_acc:0.8929, val_acc:0.788000, test_acc:0.812000
Linear evaluation accuracy:0.8080
Epoch:9, train_acc:0.8857, val_acc:0.788000, test_acc:0.815000
Linear evaluation accuracy:0.8080
Epoch:10, t

Epoch:84, train_acc:0.9143, val_acc:0.808000, test_acc:0.835000
Linear evaluation accuracy:0.8380
Epoch:85, train_acc:0.9143, val_acc:0.808000, test_acc:0.835000
Linear evaluation accuracy:0.8380
Epoch:86, train_acc:0.9143, val_acc:0.808000, test_acc:0.836000
Linear evaluation accuracy:0.8380
Epoch:87, train_acc:0.9143, val_acc:0.808000, test_acc:0.836000
Linear evaluation accuracy:0.8380
Epoch:88, train_acc:0.9143, val_acc:0.806000, test_acc:0.836000
Linear evaluation accuracy:0.8380
Epoch:89, train_acc:0.9143, val_acc:0.806000, test_acc:0.836000
Linear evaluation accuracy:0.8380
Epoch:90, train_acc:0.9143, val_acc:0.808000, test_acc:0.835000
Linear evaluation accuracy:0.8380
Epoch:91, train_acc:0.9143, val_acc:0.808000, test_acc:0.835000
Linear evaluation accuracy:0.8380
Epoch:92, train_acc:0.9143, val_acc:0.806000, test_acc:0.836000
Linear evaluation accuracy:0.8380
Epoch:93, train_acc:0.9143, val_acc:0.806000, test_acc:0.837000
Linear evaluation accuracy:0.8380
Epoch:94, train_acc:

Epoch:246, train_acc:0.9429, val_acc:0.808000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:247, train_acc:0.9429, val_acc:0.808000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:248, train_acc:0.9429, val_acc:0.808000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:249, train_acc:0.9429, val_acc:0.806000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:250, train_acc:0.9429, val_acc:0.806000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:251, train_acc:0.9429, val_acc:0.806000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:252, train_acc:0.9429, val_acc:0.806000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:253, train_acc:0.9429, val_acc:0.804000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:254, train_acc:0.9429, val_acc:0.804000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:255, train_acc:0.9429, val_acc:0.800000, test_acc:0.841000
Linear evaluation accuracy:0.8410
Epoch:256,

Epoch:409, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:410, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:411, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:412, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:413, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:414, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:415, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:416, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:417, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:418, train_acc:0.9571, val_acc:0.800000, test_acc:0.840000
Linear evaluation accuracy:0.8410
Epoch:419,

Epoch:568, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:569, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:570, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:571, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:572, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:573, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:574, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:575, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:576, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:577, train_acc:0.9571, val_acc:0.796000, test_acc:0.844000
Linear evaluation accuracy:0.8410
Epoch:578,

Epoch:724, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:725, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:726, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:727, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:728, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:729, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:730, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:731, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:732, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:733, train_acc:0.9571, val_acc:0.796000, test_acc:0.846000
Linear evaluation accuracy:0.8410
Epoch:734,

Epoch:878, train_acc:0.9571, val_acc:0.796000, test_acc:0.847000
Linear evaluation accuracy:0.8410
Epoch:879, train_acc:0.9571, val_acc:0.796000, test_acc:0.847000
Linear evaluation accuracy:0.8410
Epoch:880, train_acc:0.9571, val_acc:0.796000, test_acc:0.847000
Linear evaluation accuracy:0.8410
Epoch:881, train_acc:0.9571, val_acc:0.796000, test_acc:0.847000
Linear evaluation accuracy:0.8410
Epoch:882, train_acc:0.9571, val_acc:0.796000, test_acc:0.847000
Linear evaluation accuracy:0.8410
Epoch:883, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:884, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:885, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:886, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:887, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:888,

Epoch:1048, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1049, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1050, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1051, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1052, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1053, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1054, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1055, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1056, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1057, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410


Epoch:1217, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1218, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1219, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1220, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1221, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1222, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1223, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1224, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1225, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1226, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410


Epoch:1379, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1380, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1381, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1382, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1383, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1384, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1385, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1386, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1387, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1388, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410


Epoch:1543, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1544, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1545, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1546, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1547, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1548, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1549, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1550, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1551, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1552, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410


Epoch:1704, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1705, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1706, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1707, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1708, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1709, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1710, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1711, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1712, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1713, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410


Epoch:1843, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1844, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1845, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1846, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1847, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1848, train_acc:0.9571, val_acc:0.796000, test_acc:0.849000
Linear evaluation accuracy:0.8410
Epoch:1849, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1850, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1851, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
Epoch:1852, train_acc:0.9571, val_acc:0.796000, test_acc:0.848000
Linear evaluation accuracy:0.8410
