In [1]:
import random
from collections import deque

import torch
from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, DenseGCNConv


def get_neighbor_subgraphs(adj: torch.Tensor, size: int, n: int) -> torch.Tensor:
    """
    Sample `n` node sets of `size` distinct nodes each by breadth-first expansion.

    Each sample starts from a random node (drawn with the `random` module) and
    visits nodes in BFS order, following the non-zero entries of `adj`. If the
    reachable nodes run out before `size` nodes are collected, the traversal
    restarts from another random node. A sample may therefore span several
    connected components.

    Args:
        adj: square adjacency matrix; any non-zero entry in row `i` is treated
            as an edge from node `i`.
        size: number of distinct nodes per sample.
        n: number of samples.

    Returns:
        A `torch.long` tensor of shape `(n, size)`. Row `i` holds the node
        indices of sample `i`, in visiting order.

    Raises:
        ValueError: if `size` exceeds the number of nodes in `adj`. In that
            case `size` distinct nodes can never be collected, and the
            traversal would otherwise never terminate.
    """
    num_nodes = adj.shape[0]
    if size > num_nodes:
        raise ValueError(
            f"size ({size}) cannot exceed the number of nodes ({num_nodes})"
        )

    res = torch.zeros([n, size], dtype=torch.long)

    for i in range(n):
        out = []      # visited nodes, in visiting order
        seen = set()  # same nodes as `out`, for O(1) membership tests
        queue = deque([random.randint(0, num_nodes - 1)])
        while len(out) < size:
            if not queue:
                # Reachable nodes exhausted: restart from a fresh random node.
                queue.append(random.randint(0, num_nodes - 1))
            cur_node = queue.pop()
            if cur_node in seen:
                continue
            seen.add(cur_node)
            out.append(cur_node)
            children = adj[cur_node].nonzero(as_tuple=True)[0].cpu().tolist()
            # Prepend the neighbours in their original order. Popping from
            # the right then yields first-in-first-out (BFS) order.
            queue.extendleft(reversed(children))
        res[i] = torch.tensor(out, dtype=torch.long)

    return res


class PGDAttack:
    """
    Scaffold for a graph attack (presumably projected gradient descent, per the
    class name). Work in progress.

    The constructor pre-samples `num_samples` node sets of `sample_size` nodes
    each via `get_neighbor_subgraphs`. `attack` currently only builds a
    surrogate GCN and returns a placeholder tensor.
    """

    def __init__(self,
                 adj: torch.Tensor, feat: torch.Tensor, label: torch.Tensor,
                 sample_size=64, num_samples=1000, device='cpu') -> None:
        """
        Args:
            adj: adjacency matrix of the graph, passed to
                `get_neighbor_subgraphs`. Presumably dense and square.
                TODO confirm.
            feat: node feature matrix. `feat.shape[1]` is the input width of
                the surrogate model.
            label: node labels. `label.max() + 1` is used as the class count,
                so labels are assumed to be 0..C-1. TODO confirm.
            sample_size: number of nodes in each sampled subgraph.
            num_samples: number of subgraphs to sample.
            device: device the surrogate model is moved to in `attack`.
        """
        self.adj = adj
        self.feat = feat
        self.label = label
        self.device = device
        self.sample_size = sample_size
        self.num_samples = num_samples
        # (num_samples, sample_size) tensor of node indices. Honors the
        # caller's sample_size rather than a fixed size.
        self.subgraph_ids = get_neighbor_subgraphs(adj, sample_size, num_samples)

    def attack(self, num_epochs=1000, ptb_rate=0.25, hid=16) -> torch.Tensor:
        """
        Run the attack (not yet implemented).

        Args:
            num_epochs: currently unused. Presumably the number of
                optimisation steps.
            ptb_rate: currently unused. Presumably the perturbation budget.
            hid: hidden width of the surrogate GCN layers.

        Returns:
            A placeholder 1x1 zero tensor.
        """
        # Surrogate model: two DenseGCNConv + ReLU layers, then a linear head
        # with one output per label value (0..label.max()).
        surrogate_gcn = Sequential('x, adj', [
            (DenseGCNConv(self.feat.shape[1], hid), 'x, adj -> x'),
            ReLU(inplace=True),
            (DenseGCNConv(hid, hid), 'x, adj -> x'),
            ReLU(inplace=True),
            Linear(hid, int(self.label.max()) + 1),
        ]).to(self.device)
        # TODO: use `surrogate_gcn`, `num_epochs`, `ptb_rate` and
        # `self.subgraph_ids` to compute the perturbation. Until then, return
        # a placeholder.
        return torch.zeros([1, 1])