In [1]:
import sys
import os
sys.path.append(os.path.join('../..'))

################################################
# Arguments
################################################

import argparse

# Every tunable of the experiment is declared here so that a reader can see
# (and override) the full configuration in one place.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=123, help='Random seed for model')

# Model hyper-parameters (forwarded to the GCN constructor further down).
parser.add_argument('--model_lr', type=float, default=0.01, help='Initial learning rate')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters)')
parser.add_argument('--hidden_layers', type=int, default=32, help='Number of hidden layers')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate for GCN')

# protect_size is compared against one uniform draw per node, so it is a
# probability / expected fraction, not a node count.
parser.add_argument('--protect_size', type=float, default=0.1, help='Probability that each node is independently marked as protected (expected fraction of protected nodes)')
parser.add_argument('--ptb_rate', type=float, default=0.25, help='Perturbation rate (percentage of available edges)')

# Sampling options (not consumed by the cells visible in this notebook).
parser.add_argument('--do_sampling', type=str, default='Y', help='to do sampling or not')
parser.add_argument('--sample_size', type=int, default=500, help='')
parser.add_argument('--num_samples', type=int, default=20, help='')


# Epoch budgets.
parser.add_argument('--reg_epochs', type=int, default=100, help='Epochs to train models')
parser.add_argument('--ptb_epochs', type=int, default=30, help='Epochs to perturb adj matrix')
parser.add_argument('--surrogate_epochs', type=int, default=0, help='Epochs to train surrogate before perturb')

# Output and dataset selection.
parser.add_argument('--save', type=str, default='N', help='save the outputs to csv')
parser.add_argument('--save_location', type=str, default="./SelectiveAttack.csv", help='where to save the outputs to csv')
parser.add_argument('--dataset', type=str, default='cora', help='dataset')


parser.add_argument('--check_universal', type=str, default='N', help='check universal protection')

# Parse an explicit empty argument list: inside a notebook sys.argv belongs to
# the kernel launcher, so only the defaults declared above should apply.
args = parser.parse_args([])

################################################
# Environment
################################################

import torch
import numpy as np

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed both numpy and torch so that a fresh-kernel run draws the same numbers.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# `device` is a torch.device object, not a str: comparing it directly with the
# string 'cpu' does not test the device kind, so inspect `.type` instead.
if device.type == 'cuda':
    torch.cuda.manual_seed(args.seed)

print('==== Environment ====')
print(f'  torch version: {torch.__version__}')
print(f'  device: {device}')
print(f'  torch seed: {args.seed}')

################################################
# Dataset
################################################

from Utils import GraphData

# Announce the dataset before the loader starts emitting its own messages.
print(f'==== Dataset: {args.dataset} ====')

# Positional arguments, in order: dataset root directory, dataset name, the
# literal "gcn", the seed and the target device.
# NOTE(review): the recorded summary prints "split seed: 123", so the seed
# presumably controls the train/val/test split -- confirm in GraphData.
dataset_root = "../../Datasets"
graph = GraphData.getGraph(dataset_root, args.dataset, "gcn", args.seed, device)

# Print the loader's own overview of the loaded graph.
graph.summarize()

################################################
# Designate protected
################################################

# One uniform draw per node (one per row of the feature matrix); a node is
# protected when its draw is at most args.protect_size, so the realised share
# of protected nodes only approximates that value.
num_nodes = graph.features.shape[0]
g0 = torch.rand(num_nodes).le(args.protect_size).to(device)
# g0 = graph.labels == 5 

# gX is the complement of g0: every node that is not protected.
gX = torch.logical_not(g0)

num_protected = g0.sum()
print(f"Number of protected nodes: {num_protected:.0f}")
print(f"Protected Size: {num_protected / num_nodes:.2%}")

==== Environment ====
  torch version: 1.10.2
  device: cpu
  torch seed: 123
==== Dataset: cora ====
Loading cora dataset...

[i] Dataset Summary: 
	adj shape: [2708, 2708]
	feature shape: [2708, 1433]
	num labels: 7
	split seed: 123
	train|val|test: 140|500|1000
Number of protected nodes: 285
Protected Size: 10.52%


In [2]:
# Display the boolean protected-node mask built in the previous cell.
g0

tensor([False, False, False,  ..., False, False, False])

In [3]:
################################################
# Generate Perturbations
################################################
import Utils.Utils as Utils
import torch.nn.functional as F

# Build the "locked" adjacency matrix: the original graph with every edge that
# touches a protected (g0) node removed.
#
# torch.clamp returns a new tensor, so graph.adj itself is left unmodified and
# this cell can be re-run safely. Entries are limited to [0, 1] exactly as
# before. The former intermediate N x N mask that cleared only the g0-g0 edges
# is gone: the row/column fills below already zero those entries as well.
locked_adj = torch.clamp(graph.adj, 0, 1)

# Indices of the protected nodes, computed once and reused for both axes.
protected_idx = Utils.bool_to_idx(g0).squeeze()
locked_adj.index_fill_(0, protected_idx, 0)  # clear the rows of protected nodes
locked_adj.index_fill_(1, protected_idx, 0)  # clear the columns of protected nodes

# Sanity check: all entries are in [0, 1], so a zero sum means no edge remains
# on any protected row or column.
assert locked_adj[g0].sum() == 0 and locked_adj[:, g0].sum() == 0, \
    "protected nodes must be isolated in locked_adj"

# Last expression of the cell: show the resulting matrix.
locked_adj

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [4]:


################################################
# Evaluation
################################################
from Models.GCN import GCN
import Utils.Metrics as Metrics

# baseline_model = GCN(
#     input_features=graph.features.shape[1],
#     output_classes=graph.labels.max().item()+1,
#     hidden_layers=args.hidden_layers,
#     device=device,
#     lr=args.model_lr,
#     dropout=args.dropout,
#     weight_decay=args.weight_decay,
#     name=f"baseline"
# ).to(device)

# baseline_model.fit(graph, args.reg_epochs)

# pred = baseline_model(graph.features, graph.adj)
# baseline_acc = Metrics.partial_acc(pred, graph.labels, g0, gX)

# locked_adj = Utils.get_modified_adj(graph.adj, best)

# GCN trained on the locked adjacency matrix (protected nodes isolated).
# The model is moved to `device` explicitly, mirroring the baseline
# construction kept in the commented block above, so that its parameters live
# on the same device as the graph tensors.
locked_model = GCN(
    input_features=graph.features.shape[1],
    output_classes=graph.labels.max().item()+1,  # labels are 0-based class ids
    hidden_layers=args.hidden_layers,
    device=device,
    lr=args.model_lr,
    dropout=args.dropout,
    weight_decay=args.weight_decay,
    name="locked"
).to(device)

locked_model.fitManual(graph.features, locked_adj, graph.labels, graph.idx_train, graph.idx_test, args.reg_epochs)

# Predict in evaluation mode and without autograd bookkeeping so that dropout
# cannot perturb the reported accuracies.
# NOTE(review): this is a harmless no-op if fitManual already leaves the model
# in eval mode -- confirm in Models.GCN.
locked_model.eval()
with torch.no_grad():
    pred = locked_model(graph.features, locked_adj)

# Accuracy reported separately for the protected (g0) and remaining (gX) nodes.
locked_acc = Metrics.partial_acc(pred, graph.labels, g0, gX)

################################################
# Summarize
################################################

# dg0 = locked_acc["g0"] - baseline_acc["g0"]
# dgX = locked_acc["gX"] - baseline_acc["gX"]

# print("==== Accuracies ====")
# print(f"         ΔG0\tΔGX")
# print(f"task1 | {dg0:.1%}\t{dgX:.1%}")

# Signed edge changes introduced by locking: an entry is negative where an
# edge of graph.adj was removed and would be positive where one was added.
diff = torch.sub(locked_adj, graph.adj)

# Tabulate the changes with respect to the protected set g0.
# NOTE(review): judging by the recorded output, show_metrics prints the table
# itself and returns the counts as a nested dict -- confirm in Utils.Metrics.
diffSummary = Metrics.show_metrics(diff, graph.labels, g0, device)

print(diffSummary)

Training locked: 100%|██████████| 100/100 [00:07<00:00, 13.68it/s, loss=0.03]


G0: 53.33%
GX: 76.85%
     Within G0 ====
                A-A	A-B	TOTAL
          (+)   0  	0  	0
          (-)   69  	16  	85
     Within GX ====
                A-A	A-B	TOTAL
          (+)   0  	0  	0
          (-)   0  	0  	0
     Between G0-GX ====
                A-A	A-B	TOTAL
          (+)   0  	0  	0
          (-)   845  	235  	1080

        TOTAL   914  	251  	1165
{'g0': {'add': {'same': 0, 'diff': 0, 'total': 0}, 'remove': {'same': 69, 'diff': 16, 'total': 85}, 'total': 85}, 'gX': {'add': {'same': 0, 'diff': 0, 'total': 0}, 'remove': {'same': 0, 'diff': 0, 'total': 0}, 'total': 0}, 'g0gX': {'add': {'same': 0, 'diff': 0, 'total': 0}, 'remove': {'same': 845, 'diff': 235, 'total': 1080}, 'total': 1080}}
