In [1]:
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm

sys.path.append("..")

from mlg.utils import get_summary_writer

In [2]:
# Load the "Cora" Planetoid dataset from ../datasets/cora (relative to the
# notebook's working directory) with the NormalizeFeatures transform attached.
dataset = Planetoid(
    root=Path("..", "datasets", "cora"),
    name="Cora",
    transform=T.NormalizeFeatures(),
)

# Surrogate GCN 

In [3]:
class SurrogateNet(torch.nn.Module):
    """Two-layer GCN without a hidden non-linearity, used as the NETTACK surrogate.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes.
        hidden_channels: Width of the intermediate representation between the
            two graph convolutions. Defaults to 16.
    """

    def __init__(self, num_features, num_classes, hidden_channels=16):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        """Return per-node log-probabilities over the classes.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor of log-softmax scores, normalised over dimension 1.
        """
        edge_index = data.edge_index

        # No non-linearity between the two convolutions, per the NETTACK paper.
        hidden = self.conv1(data.x, edge_index)
        scores = self.conv2(hidden, edge_index)
        return F.log_softmax(scores, dim=1)

In [4]:
# Seed the global RNG before any parameters are created so that the weight
# initialisation is repeatable across "Restart Kernel -> Run All".
# NOTE(review): some GPU kernels may still be non-deterministic; this only
# pins the random number streams.
torch.manual_seed(0)

# Prefer the GPU when one is available; model and graph must share a device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SurrogateNet(dataset.num_features, dataset.num_classes).to(device)

# The dataset's first (index 0) graph is the one trained on.
data = dataset[0].to(device)

# Only perform weight-decay on the first convolution.
optimizer = torch.optim.Adam(
    [
        dict(params=model.conv1.parameters(), weight_decay=5e-4),
        dict(params=model.conv2.parameters(), weight_decay=0),
    ],
    lr=0.01,
)
writer = get_summary_writer("NettackPytorch")

In [5]:
for epoch in range(200):
    # --- Optimisation step: full forward pass, loss on training nodes only. ---
    model.train()
    optimizer.zero_grad()
    pred = model(data)
    label = data.y
    loss = F.nll_loss(pred[data.train_mask], label[data.train_mask])
    writer.add_scalar("Loss/train", loss, epoch)
    loss.backward()
    optimizer.step()

    # --- Evaluation: re-run the model with autograd switched off. ---
    model.eval()
    with torch.no_grad():
        pred = model(data)

        # Highest-scoring class per node, compared against the labels of
        # each split; accuracy = correct / number of nodes in the split.
        pred_train = pred[data.train_mask].argmax(dim=1)
        n_train_correct = (pred_train == data.y[data.train_mask]).sum().item()
        acc_train = n_train_correct / data.train_mask.sum().item()

        pred_val = pred[data.val_mask].argmax(dim=1)
        n_val_correct = (pred_val == data.y[data.val_mask]).sum().item()
        acc_val = n_val_correct / data.val_mask.sum().item()

        writer.add_scalar("Accuracy/train", acc_train, epoch)
        writer.add_scalar("Accuracy/validation", acc_val, epoch)

The weight matrices of `model.conv1` and `model.conv2` correspond to `W1` and `W2` in the NETTACK paper.

# Nettack (TODO)

In [42]:
# Attack settings (names follow the NETTACK reference implementation;
# NOTE(review): confirm the exact meaning of each against the paper).
n_perturbations = 10
perturb_structure = True
perturb_features = True
direct = True
n_influencers = 0
delta_cutoff = 0.004

attacked_node = 0
# Empty set of influencer node ids. The dtype must be integral: a bare
# torch.tensor([]) defaults to float32, which cannot be used to index tensors.
influencer_nodes = torch.tensor([], dtype=torch.long)

In [66]:
# The model outputs log-probabilities for every node. .detach() returns a
# tensor cut off from the autograd graph, so nothing computed below can
# backpropagate into the model.
logits = model(data).detach()

true_class = data.y[attacked_node]
attacked_logit = logits[attacked_node].clone()

# Strongest competing class: mask the true class with -inf on a scratch copy,
# then take the argmax of what remains.
temp_attacked_logit = attacked_logit.clone()
temp_attacked_logit[true_class] = float("-inf")
best_wrong_class = temp_attacked_logit.argmax().item()

# Margin between the true class and its best rival for the attacked node.
surrogate_loss = (
    attacked_logit[true_class] - attacked_logit[best_wrong_class]
).item()

In [39]:
# Display the graph's edge list. The output below shows two rows of node
# indices; presumably each column is one edge (row 0 -> row 1) — TODO confirm.
data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [None]:
def get_attacker_nodes():
    """Work in progress: so far only gathers the neighbors of ``attacked_node``.

    Reads the notebook-level ``data`` and ``attacked_node``. As far as the
    visible code goes, ``neighbors`` is computed but not yet returned or used
    (NOTE(review): the attacker-node selection itself is still to be written).
    """
    # Split the two rows of the edge list into paired endpoint tensors.
    row, col = data.edge_index
    # Second endpoint of every edge whose first endpoint is the attacked node.
    neighbors = col[row == attacked_node]
    
    