In [2]:
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import LogReg
from utils import process

dataset = 'cora'

# training params
batch_size = 1
nb_epochs = 2000
patience = 100
lr = 0.001
l2_coef = 0.0
drop_prob = 0.5
hid_units = 256
sparse = True
nonlinearity = 'prelu' # special name to separate parameters

In [3]:
adj, features, labels, idx_train, idx_val, idx_test = process.load_data(dataset)
features, _ = process.preprocess_features(features)

nb_nodes = features.shape[0]
ft_size = features.shape[1]
nb_classes = labels.shape[1]

adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))

if sparse:
    sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)
else:
    adj = (adj + sp.eye(adj.shape[0])).todense()

features = torch.FloatTensor(features[np.newaxis])
if not sparse:
    adj = torch.FloatTensor(adj[np.newaxis])
labels = torch.FloatTensor(labels[np.newaxis])
idx_train = torch.LongTensor(idx_train)
idx_val = torch.LongTensor(idx_val)
idx_test = torch.LongTensor(idx_test)

In [219]:
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 8)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        c_x = torch.unsqueeze(c, 1)
        c_x = c_x.expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x))
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x))

        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2

        logits = torch.cat((sc_1, sc_2), 0)

        return logits

In [220]:
from layers import GCN, AvgReadout

class DGI(nn.Module):
    def __init__(self, n_in, n_h, activation):
        super(DGI, self).__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.gcn2 = GCN(n_h, n_h, activation)
        self.gcn3 = GCN(n_h, n_h, activation)
        
        self.read = AvgReadout()

        self.sigm = nn.Sigmoid()

        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        h_1 = self.gcn(seq1, adj, sparse)
        h_1 = self.gcn2(h_1, adj, sparse)
        h_1 = self.gcn3(h_1, adj, sparse)

        c = self.read(h_1, msk)
        c = self.sigm(c)

        h_2 = self.gcn(seq2, adj, sparse)
        h_2 = self.gcn2(h_2, adj, sparse)
        h_2 = self.gcn3(h_2, adj, sparse)
        
        h_1 = F.dropout(h_1, drop_prob, training=self.training)
        h_2 = F.dropout(h_2, drop_prob, training=self.training)
        
        ret = self.disc(c, h_1, h_2, samp_bias1, samp_bias2)

        return ret

    # Detach the return variables
    def embed(self, seq, adj, sparse, msk):
        h_1 = self.gcn(seq, adj, sparse)
        h_1 = self.gcn2(h_1, adj, sparse)
        c = self.read(h_1, msk)

        return h_1.detach(), c.detach()

In [221]:
model = DGI(ft_size, hid_units, nonlinearity)
optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)

if torch.cuda.is_available():
    print('Using CUDA')
    model.cuda()
    features = features.cuda()
    if sparse:
        sp_adj = sp_adj.cuda()
    else:
        adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()
    
b_xent = nn.BCEWithLogitsLoss()
xent = nn.CrossEntropyLoss()
nll = nn.NLLLoss()
cnt_wait = 0
best = 1e9
best_t = 0

train_lbls = torch.argmax(labels[0, idx_train], dim=1)
val_lbls = torch.argmax(labels[0, idx_val], dim=1)
test_lbls = torch.argmax(labels[0, idx_test], dim=1)

Using CUDA


In [222]:
for epoch in range(nb_epochs):
    model.train()
    optimiser.zero_grad()

    idx = np.random.permutation(nb_nodes)
    shuf_fts = features[:, idx, :]

    lbl_1 = torch.zeros(nb_nodes, dtype=torch.long)
    lbl_2 = torch.ones(nb_nodes, dtype=torch.long)
    lbl = torch.cat((lbl_1, lbl_2), 0)

    if torch.cuda.is_available():
        shuf_fts = shuf_fts.cuda()
        lbl = lbl.cuda()
    
    logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None) 
    
    # SSL Loss
    sup_loss = xent(logits[idx_train], train_lbls)/8.

    log = nn.functional.log_softmax(logits, dim=1)
    ul_logits = torch.cat((torch.sum(log[:,:-1], dim=1).reshape(-1,1),log[:,-1].reshape(-1,1)),1)
    ul_loss = nll(ul_logits, lbl)

    loss = sup_loss + ul_loss
    
    print('Loss: %.4f, UL: %.4f, Sup: %.4f' % (loss.item(), ul_loss.item(), sup_loss.item()))

    if loss < best:
        best = loss
        best_t = epoch
        cnt_wait = 0
        torch.save(model.state_dict(), 'best_dgi.pkl')
    else:
        cnt_wait += 1

    if cnt_wait == patience:
        print('Early stopping!')
        break

    loss.backward()
    optimiser.step()
    
    model.eval()
    logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None) 
    log = nn.functional.log_softmax(logits, dim=1)
    preds = torch.argmax(log[:,:-1], dim=1)
    train_acc = torch.sum(preds[idx_train] == train_lbls).float() / train_lbls.shape[0]
    acc = torch.sum(preds[idx_test] == test_lbls).float() / test_lbls.shape[0]
    print("%d Test Acc: %.4f, Train Acc: %.4f\n" % (epoch, acc.item(), train_acc.item()))

Loss: 8.5778, UL: 8.3178, Sup: 0.2600
0 Test Acc: 0.4960, Train Acc: 0.6214

Loss: 8.5747, UL: 8.3162, Sup: 0.2585
1 Test Acc: 0.6390, Train Acc: 0.7357

Loss: 8.5703, UL: 8.3136, Sup: 0.2567
2 Test Acc: 0.5980, Train Acc: 0.8143

Loss: 8.5632, UL: 8.3073, Sup: 0.2559
3 Test Acc: 0.6130, Train Acc: 0.8500

Loss: 8.5518, UL: 8.2978, Sup: 0.2539
4 Test Acc: 0.6190, Train Acc: 0.8429

Loss: 8.5337, UL: 8.2820, Sup: 0.2517
5 Test Acc: 0.6690, Train Acc: 0.8714

Loss: 8.5109, UL: 8.2624, Sup: 0.2485
6 Test Acc: 0.6800, Train Acc: 0.8714

Loss: 8.4791, UL: 8.2332, Sup: 0.2459
7 Test Acc: 0.6920, Train Acc: 0.8929

Loss: 8.4430, UL: 8.2030, Sup: 0.2399
8 Test Acc: 0.6860, Train Acc: 0.8714

Loss: 8.4373, UL: 8.1900, Sup: 0.2473
9 Test Acc: 0.7140, Train Acc: 0.8786

Loss: 8.3908, UL: 8.1600, Sup: 0.2308
10 Test Acc: 0.7220, Train Acc: 0.9286

Loss: 8.3198, UL: 8.0883, Sup: 0.2316
11 Test Acc: 0.7440, Train Acc: 0.9214

Loss: 8.3177, UL: 8.0777, Sup: 0.2400
12 Test Acc: 0.7410, Train Acc: 0.92

KeyboardInterrupt: 

In [195]:
print('Loading {}th epoch'.format(best_t))
model.load_state_dict(torch.load('best_dgi.pkl'))

model.eval()
logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None) 
log = nn.functional.log_softmax(logits, dim=1)


Loading 130th epoch


In [196]:
preds = torch.argmax(log, dim=1)

In [197]:
count = np.zeros(8, dtype=int)
for i in range(2708):
    count[preds[i]] += 1
print(count)

count = np.zeros(8, dtype=int)
for i in range(2708,5416):
    count[preds[i]] += 1
print(count)

[360 334 447 580 435 279 273   0]
[   8    3    5   18   18    8    0 2648]


In [198]:
preds = torch.argmax(log[:,:-1], dim=1)
acc = torch.sum(preds[idx_test] == test_lbls).float() / test_lbls.shape[0]
print("Test Acc", acc)

Test Acc tensor(0.7850, device='cuda:0')


In [199]:
print('Loading {}th epoch'.format(best_t))
model.load_state_dict(torch.load('best_dgi.pkl'))

logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None) 
log = nn.functional.log_softmax(logits, dim=1)
preds = torch.argmax(log[:,:-1], dim=1)
acc = torch.sum(preds[idx_test] == test_lbls).float() / test_lbls.shape[0]
acc

Loading 130th epoch


tensor(0.7850, device='cuda:0')

In [200]:
print('Loading {}th epoch'.format(best_t))
model.load_state_dict(torch.load('best_dgi.pkl'))

embeds, _ = model.embed(features, sp_adj if sparse else adj, sparse, None)
train_embs = embeds[0, idx_train]
val_embs = embeds[0, idx_val]
test_embs = embeds[0, idx_test]

train_lbls = torch.argmax(labels[0, idx_train], dim=1)
val_lbls = torch.argmax(labels[0, idx_val], dim=1)
test_lbls = torch.argmax(labels[0, idx_test], dim=1)

tot = torch.zeros(1)
tot = tot.cuda()

accs = []

for _ in range(50):
    log = LogReg(hid_units, nb_classes)
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
    log.cuda()

    pat_steps = 0
    best_acc = torch.zeros(1)
    best_acc = best_acc.cuda()
    for _ in range(100):
        log.train()
        opt.zero_grad()

        logits = log(train_embs)
        loss = xent(logits, train_lbls)
        
        loss.backward()
        opt.step()

        
    tlogits = log(train_embs)
    tpreds = torch.argmax(tlogits, dim=1)
    train_acc = torch.sum(tpreds == train_lbls).float() / train_lbls.shape[0]

    logits = log(test_embs)
    preds = torch.argmax(logits, dim=1)
    acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
    print("Test Acc: %.4f, Train Acc: %.4f" % (acc.item(), train_acc.item()))
    accs.append(acc * 100)
    tot += acc

print('Average accuracy:', tot / 50)

accs = torch.stack(accs)
print(accs.mean())
print(accs.std())

Loading 130th epoch
Test Acc: 0.8090, Train Acc: 0.9286
Test Acc: 0.8110, Train Acc: 0.9286
Test Acc: 0.8070, Train Acc: 0.9214
Test Acc: 0.8120, Train Acc: 0.9286
Test Acc: 0.8110, Train Acc: 0.9214
Test Acc: 0.8150, Train Acc: 0.9286
Test Acc: 0.8110, Train Acc: 0.9286
Test Acc: 0.8130, Train Acc: 0.9286
Test Acc: 0.8130, Train Acc: 0.9286
Test Acc: 0.8090, Train Acc: 0.9214
Test Acc: 0.8100, Train Acc: 0.9286
Test Acc: 0.8090, Train Acc: 0.9214
Test Acc: 0.8140, Train Acc: 0.9286
Test Acc: 0.8140, Train Acc: 0.9214
Test Acc: 0.8140, Train Acc: 0.9214
Test Acc: 0.8150, Train Acc: 0.9286
Test Acc: 0.8150, Train Acc: 0.9214
Test Acc: 0.8100, Train Acc: 0.9214
Test Acc: 0.8100, Train Acc: 0.9286
Test Acc: 0.8150, Train Acc: 0.9214
Test Acc: 0.8080, Train Acc: 0.9286
Test Acc: 0.8160, Train Acc: 0.9286
Test Acc: 0.8140, Train Acc: 0.9214
Test Acc: 0.8130, Train Acc: 0.9286
Test Acc: 0.8080, Train Acc: 0.9214
Test Acc: 0.8100, Train Acc: 0.9286
Test Acc: 0.8130, Train Acc: 0.9286
Test Acc