In [1]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from deeprobust.graph.defense import GCN
from deeprobust.graph.global_attack import MetaApprox, Metattack
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset
import argparse

from config import *
from utils import *
from metrics import *
###
import tensorflow as tf
import time
from models import GCN_dropedge


 # Settings
len_tmp = labels.shape[0]



import os
os.environ["CUDA_VISIBLE_DEVICES"]="4"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if device != 'cpu':
    torch.cuda.manual_seed(args.seed)

data = Dataset(root='/tmp/', name=args.dataset, setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
print(features.shape)
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
idx_unlabeled = np.union1d(idx_val, idx_test)

ptb_rate = 0.01
perturbations = int(ptb_rate * (adj.sum()//2))
adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
print(features.shape)

def test(adj):
    ''' test on GCN '''

    # adj = normalize_adj_tensor(adj)
    gcn = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=256,
        dropout=0.0, with_relu=True, with_bias=False, weight_decay=0, device=device)
    gcn = gcn.to(device)
    gcn.fit(features, adj, labels, idx_train) # train without model picking
    # gcn.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
    output = gcn.output.cpu()
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

    return acc_test.item()

KeyboardInterrupt: 

In [None]:
train_mask = np.zeros(len_tmp)
train_mask[idx_train] = 1

val_mask = np.zeros(len_tmp)
val_mask[idx_val] = 1

test_mask = np.zeros(len_tmp)
test_mask[idx_test] = 1

optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)

#tuple_adj = sparse_to_tuple(adj.tocoo())
features_tensor = tf.convert_to_tensor(features,dtype=dtype)
#adj_tensor = tf.SparseTensor(*tuple_adj)

labels_tmp = F.one_hot(labels)


y_train_tensor = tf.convert_to_tensor(labels_tmp,dtype=dtype)
train_mask_tensor = tf.convert_to_tensor(train_mask)
y_test_tensor = tf.convert_to_tensor(labels_tmp,dtype=dtype)
test_mask_tensor = tf.convert_to_tensor(test_mask)
y_val_tensor = tf.convert_to_tensor(labels_tmp,dtype=dtype)
val_mask_tensor = tf.convert_to_tensor(val_mask)

best_test_acc = 0
best_val_acc_trail = 0
best_val_loss = 10000

# Setup Surrogate model

surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,weight_decay=0.0,
                                nhid=256, dropout=0, with_relu=True, with_bias=False, device='cpu').to('cpu')
surrogate.fit(features, adj, labels, idx_train, idx_val, patience=100)
# Setup Attack Model
model = Metattack(surrogate, nnodes=adj.shape[0], feature_shape=features.shape,
        attack_structure=True, attack_features=False, device='cpu', lambda_=0).to('cpu')
# Attack

model.attack(features, adj, labels, idx_train, idx_unlabeled, n_perturbations=perturbations, ll_constraint=False)
modified_adj = model.modified_adj




Perturbing graph:   0%|          | 0/50 [00:00<?, ?it/s]

GCN loss on unlabled data: 0.49987319111824036
GCN acc on unlabled data: 0.8569512740277156
attack loss: 0.2617972493171692


Perturbing graph:   2%|▏         | 1/50 [00:09<07:39,  9.39s/it]

GCN loss on unlabled data: 0.4988565742969513
GCN acc on unlabled data: 0.8582923558337059
attack loss: 0.2586739659309387


Perturbing graph:   4%|▍         | 2/50 [00:18<07:22,  9.21s/it]

GCN loss on unlabled data: 0.5092517137527466
GCN acc on unlabled data: 0.8524810013410818
attack loss: 0.26840662956237793


Perturbing graph:   6%|▌         | 3/50 [00:25<06:24,  8.19s/it]

GCN loss on unlabled data: 0.5231226086616516
GCN acc on unlabled data: 0.8511399195350916
attack loss: 0.27679237723350525


Perturbing graph:   8%|▊         | 4/50 [00:32<05:52,  7.67s/it]

GCN loss on unlabled data: 0.5226290225982666
GCN acc on unlabled data: 0.8506928922664283
attack loss: 0.27846354246139526


Perturbing graph:  10%|█         | 5/50 [00:41<06:05,  8.11s/it]

GCN loss on unlabled data: 0.5252869129180908
GCN acc on unlabled data: 0.8529280286097451
attack loss: 0.276610404253006


Perturbing graph:  12%|█▏        | 6/50 [00:50<06:12,  8.46s/it]

GCN loss on unlabled data: 0.530955970287323
GCN acc on unlabled data: 0.8497988377291015
attack loss: 0.2848435938358307


Perturbing graph:  14%|█▍        | 7/50 [00:59<06:18,  8.81s/it]

GCN loss on unlabled data: 0.528327465057373
GCN acc on unlabled data: 0.8511399195350916
attack loss: 0.2821969985961914


Perturbing graph:  16%|█▌        | 8/50 [01:51<15:36, 22.30s/it]

GCN loss on unlabled data: 0.5324933528900146
GCN acc on unlabled data: 0.851586946803755
attack loss: 0.2877604067325592


Perturbing graph:  18%|█▊        | 9/50 [02:33<19:29, 28.53s/it]

GCN loss on unlabled data: 0.534060537815094
GCN acc on unlabled data: 0.8520339740724184
attack loss: 0.2940361499786377


Perturbing graph:  20%|██        | 10/50 [03:01<18:55, 28.38s/it]

GCN loss on unlabled data: 0.5442187786102295
GCN acc on unlabled data: 0.8475637013857845
attack loss: 0.30321604013442993


Perturbing graph:  22%|██▏       | 11/50 [03:45<21:29, 33.06s/it]

GCN loss on unlabled data: 0.5432174801826477
GCN acc on unlabled data: 0.8462226195797944
attack loss: 0.2983416020870209


Perturbing graph:  24%|██▍       | 12/50 [04:25<22:21, 35.31s/it]

GCN loss on unlabled data: 0.5399670600891113
GCN acc on unlabled data: 0.845775592311131
attack loss: 0.30239349603652954


Perturbing graph:  26%|██▌       | 13/50 [05:13<24:09, 39.18s/it]

GCN loss on unlabled data: 0.5439217686653137
GCN acc on unlabled data: 0.8448815377738041
attack loss: 0.30126509070396423


Perturbing graph:  28%|██▊       | 14/50 [05:47<22:35, 37.66s/it]

GCN loss on unlabled data: 0.5565455555915833
GCN acc on unlabled data: 0.8462226195797944
attack loss: 0.31246116757392883


Perturbing graph:  30%|███       | 15/50 [06:33<23:18, 39.97s/it]

GCN loss on unlabled data: 0.5540516972541809
GCN acc on unlabled data: 0.8426464014304873
attack loss: 0.3116917908191681


Perturbing graph:  32%|███▏      | 16/50 [07:09<22:02, 38.89s/it]

GCN loss on unlabled data: 0.548148512840271
GCN acc on unlabled data: 0.8462226195797944
attack loss: 0.3143134117126465


Perturbing graph:  34%|███▍      | 17/50 [07:55<22:36, 41.10s/it]

GCN loss on unlabled data: 0.5547254681587219
GCN acc on unlabled data: 0.8426464014304873
attack loss: 0.31832852959632874


Perturbing graph:  36%|███▌      | 18/50 [08:25<20:08, 37.78s/it]

GCN loss on unlabled data: 0.5566093921661377
GCN acc on unlabled data: 0.8413053196244971
attack loss: 0.32247859239578247


Perturbing graph:  38%|███▊      | 19/50 [09:08<20:16, 39.24s/it]

GCN loss on unlabled data: 0.5526655316352844
GCN acc on unlabled data: 0.8413053196244971
attack loss: 0.3246949017047882


Perturbing graph:  40%|████      | 20/50 [09:47<19:36, 39.23s/it]

GCN loss on unlabled data: 0.5628324747085571
GCN acc on unlabled data: 0.8408582923558338
attack loss: 0.33063891530036926


Perturbing graph:  42%|████▏     | 21/50 [10:24<18:33, 38.41s/it]

GCN loss on unlabled data: 0.5697439312934875
GCN acc on unlabled data: 0.8426464014304873
attack loss: 0.33770951628685


Perturbing graph:  44%|████▍     | 22/50 [10:48<16:00, 34.32s/it]

GCN loss on unlabled data: 0.5702192783355713
GCN acc on unlabled data: 0.8426464014304873
attack loss: 0.3394838571548462


Perturbing graph:  46%|████▌     | 23/50 [11:00<12:25, 27.60s/it]

GCN loss on unlabled data: 0.5683210492134094
GCN acc on unlabled data: 0.8426464014304873
attack loss: 0.33527594804763794


Perturbing graph:  48%|████▊     | 24/50 [11:12<09:53, 22.85s/it]

GCN loss on unlabled data: 0.5815107226371765
GCN acc on unlabled data: 0.8413053196244971
attack loss: 0.34328493475914


Perturbing graph:  50%|█████     | 25/50 [12:02<12:54, 30.96s/it]

GCN loss on unlabled data: 0.5752321481704712
GCN acc on unlabled data: 0.8413053196244971
attack loss: 0.3489215075969696


Perturbing graph:  52%|█████▏    | 26/50 [12:21<11:00, 27.54s/it]

GCN loss on unlabled data: 0.5788125991821289
GCN acc on unlabled data: 0.8399642378185069
attack loss: 0.35380011796951294


Perturbing graph:  54%|█████▍    | 27/50 [12:45<10:09, 26.48s/it]

GCN loss on unlabled data: 0.5735634565353394
GCN acc on unlabled data: 0.8408582923558338
attack loss: 0.34825676679611206


Perturbing graph:  56%|█████▌    | 28/50 [13:38<12:34, 34.29s/it]

GCN loss on unlabled data: 0.5853952765464783
GCN acc on unlabled data: 0.8372820742065266
attack loss: 0.36046600341796875


Perturbing graph:  58%|█████▊    | 29/50 [14:16<12:26, 35.54s/it]

GCN loss on unlabled data: 0.5855368971824646
GCN acc on unlabled data: 0.8395172105498435
attack loss: 0.36060014367103577


Perturbing graph:  60%|██████    | 30/50 [15:02<12:48, 38.41s/it]

GCN loss on unlabled data: 0.5914331078529358
GCN acc on unlabled data: 0.8372820742065266
attack loss: 0.3635765314102173


Perturbing graph:  62%|██████▏   | 31/50 [15:49<13:03, 41.25s/it]

GCN loss on unlabled data: 0.5798417925834656
GCN acc on unlabled data: 0.8408582923558338
attack loss: 0.35571563243865967


Perturbing graph:  64%|██████▍   | 32/50 [16:32<12:31, 41.74s/it]

GCN loss on unlabled data: 0.5876798629760742
GCN acc on unlabled data: 0.8417523468931605
attack loss: 0.3673661947250366


Perturbing graph:  66%|██████▌   | 33/50 [17:17<12:02, 42.52s/it]

GCN loss on unlabled data: 0.5787569284439087
GCN acc on unlabled data: 0.8421993741618239
attack loss: 0.36202114820480347


Perturbing graph:  68%|██████▊   | 34/50 [17:53<10:49, 40.59s/it]

GCN loss on unlabled data: 0.5903906226158142
GCN acc on unlabled data: 0.8399642378185069
attack loss: 0.37824878096580505


Perturbing graph:  70%|███████   | 35/50 [18:31<09:59, 39.94s/it]

: 

: 

In [None]:
modified_adj=sp.csr_array(modified_adj.int())

tuple_adj = sparse_to_tuple(modified_adj.tocoo())
adj_tensor = tf.SparseTensor(*tuple_adj)


7


In [None]:

# Settings
optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
model = GCN_dropedge(input_dim=features.shape[1], output_dim=labels.max().item()+1, adj=adj_tensor)


best_test_acc = 0
best_val_acc = 0
best_val_loss = 10000


curr_step = 0
for epoch in range(args.epochs):

    with tf.GradientTape() as tape:
        output = model.call((features_tensor),training=True)
        cross_loss = masked_softmax_cross_entropy(output, y_train_tensor,train_mask_tensor)
        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in model.trainable_variables])
        loss = cross_loss #+ args.weight_decay*lossL2
        grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    output = model.call((features_tensor), training=False)
    train_acc = masked_accuracy(output, y_train_tensor,train_mask_tensor)
    val_acc  = masked_accuracy(output, y_val_tensor,val_mask_tensor)
    val_loss = masked_softmax_cross_entropy(output, y_val_tensor, val_mask_tensor)
    test_acc  = masked_accuracy(output, y_test_tensor,test_mask_tensor)

    if val_acc > best_val_acc:
        curr_step = 0
        best_test_acc = test_acc
        best_val_acc = val_acc
        best_val_loss= val_loss
        # Print results

    else:
        curr_step +=1
    if curr_step > args.early_stop:
        print("Early stopping...")
        break

    print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cross_loss),"val_loss=", "{:.5f}".format(val_loss),
    "train_acc=", "{:.5f}".format(val_acc), "val_acc=", "{:.5f}".format(val_acc),
    "test_acc=", "{:.5f}".format(best_test_acc))

0.0
Epoch: 0001 train_loss= 1.93524 val_loss= 1.89642 train_acc= 0.41365 val_acc= 0.41365 test_acc= 0.38984
Epoch: 0002 train_loss= 1.87406 val_loss= 1.85383 train_acc= 0.49398 val_acc= 0.49398 test_acc= 0.47485
Epoch: 0003 train_loss= 1.81402 val_loss= 1.81144 train_acc= 0.51004 val_acc= 0.51004 test_acc= 0.49799
Epoch: 0004 train_loss= 1.75464 val_loss= 1.76905 train_acc= 0.50602 val_acc= 0.50602 test_acc= 0.49799
Epoch: 0005 train_loss= 1.69535 val_loss= 1.72631 train_acc= 0.50602 val_acc= 0.50602 test_acc= 0.49799
Epoch: 0006 train_loss= 1.63565 val_loss= 1.68287 train_acc= 0.50201 val_acc= 0.50201 test_acc= 0.49799
Epoch: 0007 train_loss= 1.57523 val_loss= 1.63880 train_acc= 0.51004 val_acc= 0.51004 test_acc= 0.49799
Epoch: 0008 train_loss= 1.51408 val_loss= 1.59434 train_acc= 0.52610 val_acc= 0.52610 test_acc= 0.56891
Epoch: 0009 train_loss= 1.45231 val_loss= 1.54956 train_acc= 0.55020 val_acc= 0.55020 test_acc= 0.58400
Epoch: 0010 train_loss= 1.39007 val_loss= 1.50448 train_acc=