In [15]:
import sys

# Make the vendored nettack checkout importable. The guard keeps the cell
# idempotent: re-running it must not pile up duplicate sys.path entries.
if "nettack-master" not in sys.path:
    sys.path.append("nettack-master")

In [97]:
import json
import pickle
import random
from glob import glob
from itertools import repeat
from pathlib import Path
from uuid import uuid4

import numpy as np
import tensorflow as tf
from joblib import Parallel, delayed
from tqdm import tqdm

In [4]:
# Sanity check before training: the cell output shows whether TensorFlow reports a usable GPU.
# NOTE(review): newer TensorFlow releases deprecate this call in favour of
# tf.config.list_physical_devices("GPU") — confirm against the installed version.
tf.test.is_gpu_available()

True

In [5]:
from nettack import GCN
from nettack import nettack as ntk
from nettack import utils




In [6]:
def select_subset_of_keys(dict_, *args):
    """Return a new dict holding only the entries of ``dict_`` named in ``args``.

    Entries appear in the order the keys were passed. A key that is missing
    from ``dict_`` raises ``KeyError``.
    """
    return {key: dict_[key] for key in args}

In [7]:
def prepare_surrogate_model_params():
    """Load the Cora graph, clean it up and split its nodes for surrogate training.

    Returns:
        dict: the result of ``locals()``, i.e. EVERY local variable of this
        function keyed by its name (``_A_obs``, ``_X_obs``, ``_z_obs``,
        ``_Z_obs``, ``_An``, ``sizes``, ``degrees``, ``split_train``,
        ``split_val``, ``split_unlabeled``, ``seed``, ...). Renaming, adding
        or removing a local therefore changes the returned dictionary.

    Side effects:
        Seeds NumPy's global RNG with a fixed value (see ``np.random.seed``
        below), which stays in effect after the function returns.
    """
    # Adjacency, features and labels as returned by the nettack loader.
    _A_obs, _X_obs, _z_obs = utils.load_npz("./nettack-master/data/cora.npz")
    # Symmetrize the adjacency, then clamp entries that became 2 back to 1.
    _A_obs = _A_obs + _A_obs.T
    _A_obs[_A_obs > 1] = 1
    # Node indices kept by the nettack helper (largest connected component).
    lcc = utils.largest_connected_components(_A_obs)

    # Restrict the graph to the selected nodes (rows first, then columns).
    _A_obs = _A_obs[lcc][:, lcc]

    # Sanity checks: symmetric, unweighted, and no isolated nodes.
    assert np.abs(_A_obs - _A_obs.T).sum() == 0, "Input graph is not symmetric"
    assert (
        _A_obs.max() == 1 and len(np.unique(_A_obs[_A_obs.nonzero()].A1)) == 1
    ), "Graph must be unweighted"
    assert _A_obs.sum(0).A1.min() > 0, "Graph contains singleton nodes"

    # Keep features and labels aligned with the reduced node set.
    _X_obs = _X_obs[lcc].astype("float32")
    _z_obs = _z_obs[lcc]
    _N = _A_obs.shape[0]  # number of remaining nodes
    # Number of classes; assumes labels are the integers 0..K-1 — TODO confirm.
    _K = _z_obs.max() + 1
    _Z_obs = np.eye(_K)[_z_obs]  # one-hot label rows
    # Presumably the normalized adjacency expected by the GCN — verify in nettack.utils.
    _An = utils.preprocess_graph(_A_obs)
    # Passed to GCN.GCN as its first argument; presumably [hidden units, classes].
    sizes = [16, _K]
    degrees = _A_obs.sum(0).A1  # per-node column sums of the adjacency

    seed = 420
    unlabeled_share = 0.8
    val_share = 0.1
    train_share = 1 - unlabeled_share - val_share
    # Fixed seed for the split below; this mutates the GLOBAL NumPy RNG state.
    np.random.seed(seed)

    # Partition the node indices into train / validation / unlabeled parts,
    # stratified by the class labels.
    split_train, split_val, split_unlabeled = utils.train_val_test_split_tabular(
        np.arange(_N),
        train_size=train_share,
        val_size=val_share,
        test_size=unlabeled_share,
        stratify=_z_obs,
    )

    # Every local above becomes a key of the returned dict.
    return locals()

In [8]:
# Build the graph/split dictionary once; later cells pickle it and pass it to
# both the surrogate training and the attack functions.
surrogate_params = prepare_surrogate_model_params()

Selecting 1 largest connected components


In [9]:
# prepare_surrogate_model_params() pinned NumPy's global RNG to a fixed seed for
# the data split. Calling seed() without an argument re-seeds it, so that later
# uses of the global NumPy RNG are not tied to that fixed seed.
np.random.seed()

In [10]:
# Directory that receives one weight archive (W1, W2) per trained surrogate model.
models_weights_path = Path("./weights/")

In [115]:
# Persist the prepared graph/split dictionary. The context manager guarantees
# the file is flushed and closed even if pickling fails part-way through.
with open('surrogate_params.pickle', 'wb') as f:
    pickle.dump(surrogate_params, f)

In [11]:
def train_surrogate_and_save_weights(surrogate_params, gpu_id):
    """Train one surrogate GCN (``with_relu=False``) and save its weights to disk.

    Args:
        surrogate_params: dict returned by ``prepare_surrogate_model_params``;
            the keys ``sizes``, ``_An``, ``_X_obs``, ``split_train``,
            ``split_val`` and ``_Z_obs`` are read.
        gpu_id: GPU index used for the surrounding ``tf.device`` scope.

    Returns:
        pathlib.Path: location of the written ``.npz`` archive, which holds
        the arrays ``W1`` and ``W2``.
    """
    with tf.device(f"/gpu:{gpu_id}"):
        surrogate_model = GCN.GCN(
            surrogate_params["sizes"],
            surrogate_params["_An"],
            surrogate_params["_X_obs"],
            with_relu=False,
            name="surrogate",
            # NOTE(review): this stays hard-coded to 0 although the function
            # takes ``gpu_id`` — confirm how GCN.GCN interprets it before
            # calling this function with a non-zero ``gpu_id``.
            gpu_id=0,
        )
        try:
            surrogate_model.train(
                surrogate_params["split_train"],
                surrogate_params["split_val"],
                surrogate_params["_Z_obs"],
            )
            W1 = surrogate_model.W1.eval(session=surrogate_model.session)
            W2 = surrogate_model.W2.eval(session=surrogate_model.session)
        finally:
            # Let tf free memory, also when training or evaluation raises.
            surrogate_model.session.close()

    # np.savez fails if the target directory does not exist yet.
    models_weights_path.mkdir(parents=True, exist_ok=True)
    # np.savez appends ".npz" when the name lacks it; spelling the suffix out
    # makes the returned path point at the file that is actually written.
    path = models_weights_path / f"{uuid4()}.npz"
    np.savez(str(path), W1=W1, W2=W2)
    return path

In [13]:
# Train a pool of independently initialised surrogate models on GPU 0; each
# call writes one weight archive. NOTE: the saved cell output below shows a run
# of 100 iterations, i.e. it predates the current count.
n_surrogate_models = 200
for _ in tqdm(range(n_surrogate_models)):
    train_surrogate_and_save_weights(surrogate_params, 0)

  0%|          | 0/100 [00:00<?, ?it/s]

converged after 52 iterations


  2%|▏         | 2/100 [00:02<02:10,  1.33s/it]

converged after 49 iterations


  3%|▎         | 3/100 [00:03<01:55,  1.19s/it]

converged after 48 iterations


  4%|▍         | 4/100 [00:04<01:50,  1.15s/it]

converged after 51 iterations


  5%|▌         | 5/100 [00:05<01:46,  1.12s/it]

converged after 50 iterations
converged after 57 iterations


  7%|▋         | 7/100 [00:07<01:37,  1.04s/it]

converged after 50 iterations


  8%|▊         | 8/100 [00:08<01:37,  1.06s/it]

converged after 53 iterations


  9%|▉         | 9/100 [00:09<01:36,  1.06s/it]

converged after 59 iterations


 10%|█         | 10/100 [00:10<01:39,  1.10s/it]

converged after 61 iterations


 11%|█         | 11/100 [00:11<01:30,  1.02s/it]

converged after 49 iterations


 12%|█▏        | 12/100 [00:12<01:28,  1.00s/it]

converged after 49 iterations


 13%|█▎        | 13/100 [00:13<01:30,  1.04s/it]

converged after 53 iterations


 14%|█▍        | 14/100 [00:14<01:25,  1.01it/s]

converged after 57 iterations


 15%|█▌        | 15/100 [00:15<01:27,  1.03s/it]

converged after 57 iterations


 16%|█▌        | 16/100 [00:16<01:28,  1.05s/it]

converged after 49 iterations


 17%|█▋        | 17/100 [00:17<01:26,  1.05s/it]

converged after 54 iterations


 18%|█▊        | 18/100 [00:18<01:22,  1.00s/it]

converged after 46 iterations


 19%|█▉        | 19/100 [00:19<01:22,  1.02s/it]

converged after 56 iterations


 20%|██        | 20/100 [00:20<01:18,  1.03it/s]

converged after 43 iterations


 21%|██        | 21/100 [00:21<01:14,  1.05it/s]

converged after 47 iterations


 22%|██▏       | 22/100 [00:22<01:18,  1.01s/it]

converged after 50 iterations


 23%|██▎       | 23/100 [00:23<01:16,  1.00it/s]

converged after 57 iterations


 24%|██▍       | 24/100 [00:24<01:16,  1.00s/it]

converged after 54 iterations


 25%|██▌       | 25/100 [00:25<01:20,  1.07s/it]

converged after 72 iterations


 26%|██▌       | 26/100 [00:26<01:17,  1.05s/it]

converged after 50 iterations
converged after 52 iterations


 28%|██▊       | 28/100 [00:29<01:14,  1.04s/it]

converged after 48 iterations


 29%|██▉       | 29/100 [00:30<01:16,  1.08s/it]

converged after 60 iterations


 30%|███       | 30/100 [00:31<01:11,  1.02s/it]

converged after 44 iterations


 31%|███       | 31/100 [00:32<01:11,  1.03s/it]

converged after 57 iterations
converged after 46 iterations


 33%|███▎      | 33/100 [00:34<01:14,  1.11s/it]

converged after 49 iterations


 34%|███▍      | 34/100 [00:35<01:10,  1.06s/it]

converged after 49 iterations


 35%|███▌      | 35/100 [00:37<01:16,  1.17s/it]

converged after 60 iterations


 36%|███▌      | 36/100 [00:38<01:12,  1.14s/it]

converged after 51 iterations


 37%|███▋      | 37/100 [00:39<01:11,  1.14s/it]

converged after 65 iterations


 38%|███▊      | 38/100 [00:40<01:06,  1.07s/it]

converged after 55 iterations


 39%|███▉      | 39/100 [00:41<01:03,  1.04s/it]

converged after 45 iterations


 40%|████      | 40/100 [00:42<01:03,  1.06s/it]

converged after 44 iterations


 41%|████      | 41/100 [00:43<00:59,  1.01s/it]

converged after 44 iterations


 42%|████▏     | 42/100 [00:43<00:56,  1.03it/s]

converged after 52 iterations
converged after 52 iterations


 44%|████▍     | 44/100 [00:46<01:05,  1.16s/it]

converged after 41 iterations


 45%|████▌     | 45/100 [00:47<01:01,  1.12s/it]

converged after 50 iterations


 46%|████▌     | 46/100 [00:48<00:57,  1.06s/it]

converged after 55 iterations


 47%|████▋     | 47/100 [00:49<00:52,  1.02it/s]

converged after 49 iterations
converged after 52 iterations


 48%|████▊     | 48/100 [00:51<00:57,  1.10s/it]

converged after 50 iterations


 50%|█████     | 50/100 [00:53<00:55,  1.12s/it]

converged after 59 iterations


 51%|█████     | 51/100 [00:54<00:51,  1.05s/it]

converged after 46 iterations


 52%|█████▏    | 52/100 [00:55<00:50,  1.05s/it]

converged after 52 iterations
converged after 53 iterations


 54%|█████▍    | 54/100 [00:58<00:57,  1.25s/it]

converged after 51 iterations


 55%|█████▌    | 55/100 [00:59<00:50,  1.11s/it]

converged after 53 iterations


 56%|█████▌    | 56/100 [01:00<00:48,  1.10s/it]

converged after 54 iterations


 57%|█████▋    | 57/100 [01:01<00:45,  1.06s/it]

converged after 52 iterations


 58%|█████▊    | 58/100 [01:02<00:43,  1.03s/it]

converged after 63 iterations


 59%|█████▉    | 59/100 [01:03<00:42,  1.04s/it]

converged after 58 iterations


 60%|██████    | 60/100 [01:04<00:42,  1.05s/it]

converged after 59 iterations


 61%|██████    | 61/100 [01:05<00:42,  1.08s/it]

converged after 58 iterations


 62%|██████▏   | 62/100 [01:06<00:42,  1.11s/it]

converged after 66 iterations


 63%|██████▎   | 63/100 [01:07<00:37,  1.02s/it]

converged after 48 iterations


 64%|██████▍   | 64/100 [01:08<00:35,  1.02it/s]

converged after 51 iterations


 65%|██████▌   | 65/100 [01:09<00:36,  1.03s/it]

converged after 58 iterations


 66%|██████▌   | 66/100 [01:10<00:35,  1.03s/it]

converged after 51 iterations


 67%|██████▋   | 67/100 [01:11<00:33,  1.01s/it]

converged after 49 iterations


 68%|██████▊   | 68/100 [01:12<00:32,  1.02s/it]

converged after 42 iterations


 69%|██████▉   | 69/100 [01:13<00:30,  1.01it/s]

converged after 54 iterations


 70%|███████   | 70/100 [01:14<00:29,  1.01it/s]

converged after 58 iterations


 71%|███████   | 71/100 [01:15<00:29,  1.01s/it]

converged after 49 iterations


 72%|███████▏  | 72/100 [01:16<00:28,  1.03s/it]

converged after 53 iterations


 73%|███████▎  | 73/100 [01:17<00:27,  1.01s/it]

converged after 52 iterations


 74%|███████▍  | 74/100 [01:18<00:26,  1.04s/it]

converged after 69 iterations


 75%|███████▌  | 75/100 [01:19<00:26,  1.07s/it]

converged after 50 iterations


 76%|███████▌  | 76/100 [01:21<00:28,  1.19s/it]

converged after 46 iterations


 77%|███████▋  | 77/100 [01:22<00:28,  1.23s/it]

converged after 58 iterations


 78%|███████▊  | 78/100 [01:23<00:25,  1.16s/it]

converged after 49 iterations


 79%|███████▉  | 79/100 [01:24<00:22,  1.09s/it]

converged after 61 iterations


 80%|████████  | 80/100 [01:25<00:21,  1.09s/it]

converged after 51 iterations


 81%|████████  | 81/100 [01:26<00:20,  1.11s/it]

converged after 54 iterations


 82%|████████▏ | 82/100 [01:27<00:19,  1.07s/it]

converged after 50 iterations


 83%|████████▎ | 83/100 [01:28<00:17,  1.06s/it]

converged after 53 iterations


 84%|████████▍ | 84/100 [01:29<00:17,  1.08s/it]

converged after 45 iterations


 85%|████████▌ | 85/100 [01:31<00:16,  1.13s/it]

converged after 52 iterations


 86%|████████▌ | 86/100 [01:32<00:15,  1.10s/it]

converged after 44 iterations


 87%|████████▋ | 87/100 [01:32<00:13,  1.03s/it]

converged after 57 iterations


 88%|████████▊ | 88/100 [01:34<00:12,  1.04s/it]

converged after 54 iterations


 89%|████████▉ | 89/100 [01:35<00:11,  1.07s/it]

converged after 45 iterations


 90%|█████████ | 90/100 [01:36<00:10,  1.07s/it]

converged after 61 iterations


 91%|█████████ | 91/100 [01:37<00:09,  1.06s/it]

converged after 53 iterations


 92%|█████████▏| 92/100 [01:38<00:08,  1.10s/it]

converged after 60 iterations


 93%|█████████▎| 93/100 [01:39<00:07,  1.07s/it]

converged after 43 iterations


 94%|█████████▍| 94/100 [01:40<00:06,  1.04s/it]

converged after 48 iterations


 95%|█████████▌| 95/100 [01:41<00:05,  1.06s/it]

converged after 59 iterations


 96%|█████████▌| 96/100 [01:42<00:04,  1.08s/it]

converged after 48 iterations


 97%|█████████▋| 97/100 [01:43<00:03,  1.13s/it]

converged after 53 iterations


 98%|█████████▊| 98/100 [01:45<00:02,  1.17s/it]

converged after 56 iterations


 99%|█████████▉| 99/100 [01:46<00:01,  1.13s/it]

converged after 51 iterations


100%|██████████| 100/100 [01:47<00:00,  1.07s/it]

converged after 65 iterations





# NETTACK time: attack the surrogate models and build the dataset

In [90]:
# Directory that receives one pickle file per Nettack run.
dataset_path = Path("./dataset/")

In [108]:
def get_nettack_params(surrogate_params):
    """Draw a random target node and assemble the attack settings for it.

    Args:
        surrogate_params: mapping that provides ``split_unlabeled`` (candidate
            target nodes) and ``degrees`` (indexable by node id).

    Returns:
        dict with the keys ``u``, ``direct_attack``, ``n_influencers``,
        ``n_perturbations``, ``perturb_features`` and ``perturb_structure``.
    """
    # Randomly pick the attacked node.
    target = random.choice(surrogate_params["split_unlabeled"])
    target_degree = int(surrogate_params["degrees"][target])

    # Cap the number of perturbations, since the computation takes too much
    # time otherwise (max node degree is around 160).
    budget = min(50, target_degree)

    return {
        "u": target,
        "direct_attack": False,
        "n_influencers": target_degree,
        "n_perturbations": budget,
        "perturb_features": True,
        "perturb_structure": True,
    }

In [109]:
def perform_nettack_on_a_surrogate_graph(surrogate_params, weights_path):
    """Run one Nettack attack with stored surrogate weights and pickle the result.

    Args:
        surrogate_params: dict returned by ``prepare_surrogate_model_params``;
            ``_A_obs``, ``_X_obs`` and ``_z_obs`` are handed to Nettack, the
            rest is consumed by ``get_nettack_params``.
        weights_path: path of an ``.npz`` archive containing ``W1`` and ``W2``.

    Side effects:
        Writes ``<uuid>.pickle`` into ``dataset_path`` with the keys
        ``params``, ``structure_perturbations`` and ``feature_perturbations``.
    """
    nettack_params = get_nettack_params(surrogate_params)

    # np.load keeps the archive open until it is closed explicitly. Read both
    # arrays inside a context manager so that workers running many tasks do
    # not accumulate open file handles.
    with np.load(weights_path) as weights:
        W1 = weights["W1"]
        W2 = weights["W2"]

    nettack = ntk.Nettack(
        surrogate_params["_A_obs"],
        surrogate_params["_X_obs"],
        surrogate_params["_z_obs"],
        W1,
        W2,
        nettack_params["u"],
        verbose=False,
    )

    nettack.reset()

    nettack.attack_surrogate(
        nettack_params["n_perturbations"],
        perturb_structure=nettack_params["perturb_structure"],
        perturb_features=nettack_params["perturb_features"],
        direct=nettack_params["direct_attack"],
        n_influencers=nettack_params["n_influencers"],
    )

    out = {
        "params": nettack_params,
        "structure_perturbations": nettack.structure_perturbations,
        "feature_perturbations": nettack.feature_perturbations,
    }
    # open() fails if the directory is missing; exist_ok keeps this safe when
    # several workers reach this line at the same time.
    dataset_path.mkdir(parents=True, exist_ok=True)
    with open(dataset_path / f"{uuid4()}.pickle", "wb") as f:
        pickle.dump(out, f)

In [110]:
# perform_nettack_on_a_surrogate_graph(
#     surrogate_params, "./weights/00c1590c-cd03-4351-97cb-503c5a841224.npz"
# )

In [111]:
def dataset_generator(weights_glob="./weights/*.npz", attacks_per_model=100):
    """Yield every matching weight file path ``attacks_per_model`` times.

    Args:
        weights_glob: glob pattern selecting the weight archives. The default
            matches the location used by the rest of the notebook.
        attacks_per_model: how often each path is yielded, i.e. how many
            attacks are scheduled per surrogate model.

    Yields:
        str: a weight file path. Paths are visited in sorted order, because
        ``glob`` itself returns matches in arbitrary order.
    """
    for weights_path in sorted(glob(weights_glob)):
        yield from repeat(weights_path, attacks_per_model)

In [None]:
# One delayed attack per path yielded by dataset_generator(); the generator is
# consumed lazily by joblib, which fans the tasks out over 32 workers.
attack_jobs = (
    delayed(perform_nettack_on_a_surrogate_graph)(surrogate_params, weights_path)
    for weights_path in dataset_generator()
)
Parallel(n_jobs=32, verbose=10)(attack_jobs)

[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.
[Parallel(n_jobs=32)]: Done   8 tasks      | elapsed:   22.1s
[Parallel(n_jobs=32)]: Done  21 tasks      | elapsed:   26.0s
[Parallel(n_jobs=32)]: Done  34 tasks      | elapsed:   31.6s
[Parallel(n_jobs=32)]: Done  49 tasks      | elapsed:   37.6s
[Parallel(n_jobs=32)]: Done  64 tasks      | elapsed:   42.7s
[Parallel(n_jobs=32)]: Done  81 tasks      | elapsed:   52.7s
[Parallel(n_jobs=32)]: Done  98 tasks      | elapsed:   58.8s
[Parallel(n_jobs=32)]: Done 117 tasks      | elapsed:  1.2min
[Parallel(n_jobs=32)]: Done 136 tasks      | elapsed:  1.3min
[Parallel(n_jobs=32)]: Done 157 tasks      | elapsed:  1.6min
[Parallel(n_jobs=32)]: Done 178 tasks      | elapsed:  1.8min
[Parallel(n_jobs=32)]: Done 201 tasks      | elapsed:  2.1min
[Parallel(n_jobs=32)]: Done 224 tasks      | elapsed:  2.4min
[Parallel(n_jobs=32)]: Done 249 tasks      | elapsed:  2.7min
[Parallel(n_jobs=32)]: Done 274 tasks      | elapsed:  