In [3]:
import torch
from torch_geometric.data import Dataset
import torch.nn.functional as F
from torch_geometric.nn import Linear
from torch_geometric.nn.dense import DenseGCNConv
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.data import Data, Batch

# Seed PyTorch's global RNG so the random graphs, labels and noise drawn below are repeatable.
torch.manual_seed(1)
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_random_graph(num_nodes: int = 75, num_edges: int = 1500) -> Data:
    """Create a random graph with a single scalar feature per node.

    Args:
        num_nodes: Number of nodes; edge endpoints are drawn from ``[0, num_nodes)``.
        num_edges: Number of (source, target) pairs to sample. Both endpoints are
            drawn independently, so self-loops and repeated pairs can occur.

    Returns:
        A coalesced ``Data`` object with ``x`` of shape ``(num_nodes, 1)`` sampled
        uniformly from ``[0, 1)`` and ``edge_index`` of shape ``(2, E)``.
        NOTE(review): PyG's ``coalesce`` sorts the edge index and merges duplicate
        pairs, so ``E`` may be smaller than ``num_edges`` — confirm against the
        installed PyG version.
    """
    # Sample the endpoint pairs before the features: the order of the RNG draws
    # determines the generated graph for a given seed.
    endpoint_pairs = torch.randint(0, num_nodes, (num_edges, 2), dtype=torch.long)
    node_features = torch.rand((num_nodes, 1))

    # Pairs are sampled as rows, but the graph stores them column-wise as (2, E).
    graph = Data(
        x=node_features,
        edge_index=endpoint_pairs.transpose(0, 1).contiguous(),
    )
    return graph.coalesce()


class GCN_Dense(torch.nn.Module):
    """Graph-level model: three dense GCN layers, sum pooling, linear head.

    The attribute names (``conv1``, ``conv2``, ``conv3``, ``lin``) become the
    state-dict keys, so they must stay in sync with any saved checkpoint.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Build the layers.

        Args:
            in_channels: Size of each input node feature vector.
            hidden_channels: Width shared by all three convolution layers.
            out_channels: Size of the per-graph output vector.
        """
        super().__init__()

        # Layers are registered in this order on purpose: registration order
        # fixes both parameter initialisation order and state-dict layout.
        self.conv1 = DenseGCNConv(in_channels, hidden_channels)
        self.conv2 = DenseGCNConv(hidden_channels, hidden_channels)
        self.conv3 = DenseGCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, adj, mask):
        """Run the network on a dense batch of graphs.

        Args:
            x: Dense node features (in this notebook, the first output of
                ``to_dense_batch``).
            adj: Dense adjacency matrices (in this notebook, the output of
                ``to_dense_adj``).
            mask: Node-validity mask forwarded unchanged to every convolution.

        Returns:
            The output of the linear head applied to the features summed over
            dimension 1.
        """
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = torch.relu(conv(hidden, adj, mask))

        # Collapse dimension 1 (the node axis of the dense batch layout) so that
        # each graph is represented by a single vector.
        pooled = hidden.sum(dim=1)
        return self.lin(pooled)


# Instantiate the network with the same hyper-parameters the checkpoint was
# trained with (state-dict shapes must match) and place it on the chosen device.
model = GCN_Dense(
    in_channels=1,
    hidden_channels=32,
    out_channels=10,
).to(device)

# map_location=device remaps the stored tensors onto the active device, so a
# checkpoint written from a GPU session still loads on a CPU-only machine.
# NOTE(security): weights_only=False unpickles arbitrary Python objects; only
# load checkpoints from a trusted source. If this file holds a plain state dict,
# weights_only=True should be sufficient and is safer — TODO confirm and switch.
model.load_state_dict(
    torch.load(
        "checkpoints/dense/small_dense.pt",
        map_location=device,
        weights_only=False,
    )
)

<All keys matched successfully>

In [4]:
batch_size = 32

# Graphs are generated before the labels so the global RNG is consumed in a
# fixed order (graphs first, then labels).
graphs = [generate_random_graph() for _ in range(batch_size)]
data: Batch = Batch.from_data_list(graphs)
data.y = torch.randint(0, 10, (batch_size,))  # Fake labels

# The return value is intentionally ignored.
# NOTE(review): this relies on PyG moving the batch in place — confirm against
# the installed PyG version.
data.to(device)

# Convert the sparse batch into the dense per-graph tensors the dense GCN expects:
# padded node features plus a validity mask, and one adjacency matrix per graph.
x, mask = to_dense_batch(data.x, batch=data.batch)
adj = to_dense_adj(data.edge_index, batch=data.batch)

In [5]:
model.eval()

NUM_STEPS = 20     # number of noisy gradient steps applied to the inputs
NOISE_STD = 0.005  # standard deviation of the Gaussian noise added at every step

# The inputs (not the weights) are what gets updated here, so start from leaf
# tensors that track gradients. NOTE: `x` and `adj` are rebound by this cell, so
# re-running it continues from the previous result instead of the original batch.
x = x.detach().requires_grad_()
adj = adj.detach().requires_grad_()

# Noise buffers are allocated once and re-filled in place at every step.
noise_x = torch.randn(x.shape, device=device)
noise_adj = torch.randn(adj.shape, device=device)

# Row index used to pick, for every graph, the output at its own label.
# Loop-invariant, so it is built once and kept on the same device as the labels.
graph_index = torch.arange(data.y.size(0), device=data.y.device)

for step in range(NUM_STEPS):
    noise_x.normal_(0, NOISE_STD)
    noise_adj.normal_(0, NOISE_STD)

    # One scalar per graph: the model output at that graph's label.
    energy = model(x, adj, mask)[graph_index, data.y]

    # Differentiate w.r.t. the inputs only. Unlike .backward(), this does not
    # accumulate (unused) gradients into the model parameters.
    grad_x, grad_adj = torch.autograd.grad(energy.sum(), (x, adj))

    # Gradient step plus noise. Detaching cuts the link to the previous iterate,
    # so the autograd graph does not grow with every step; the numerical update
    # is identical to the un-detached version.
    x = (x - grad_x + noise_x).detach().requires_grad_()
    adj = (adj - grad_adj + noise_adj).detach().requires_grad_()

print(adj[0])

tensor([[ 2.8194e-02, -6.3914e-02,  1.6147e-02,  ...,  5.4240e-03,
          9.7387e-01,  9.9486e-01],
        [-6.6771e-03,  9.2277e-03,  9.4047e-01,  ...,  6.8134e-02,
          3.7272e-02,  9.3470e-03],
        [ 1.8744e-02,  4.4778e-03,  4.1326e-02,  ...,  9.8044e-01,
          5.2424e-03, -4.1744e-02],
        ...,
        [ 1.2592e-02,  3.6151e-02, -3.2986e-02,  ..., -4.6441e-02,
          9.6355e-01,  2.2499e-02],
        [ 1.0041e+00,  9.8232e-01,  3.5269e-04,  ...,  2.3624e-02,
          1.0314e+00,  9.8360e-01],
        [-1.4371e-03,  2.2294e-02, -1.3815e-03,  ...,  9.6100e-01,
         -1.2142e-02, -1.4905e-02]], device='cuda:0',
       grad_fn=<SelectBackward0>)
