In [8]:
import torch
import torch_geometric
from data.custom_dataloader import MYDataLoader
from torch_geometric.datasets import ZINC, TUDataset
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

In [9]:
from torch_sparse import SparseTensor

In [32]:
# Sample a random 4-node graph and symmetrise it so that every edge exists in
# both directions. The logical OR states that intent directly: the previous
# `adj + adj.t() / 2` divided only the transpose (operator precedence) and
# relied on implicit bool -> float promotion, although the `> 0` test happened
# to give the same result.
# NOTE: no seed is set before this draw, so the sampled graph changes on every
# run; diagonal entries (self-loops) are kept if they are sampled.
adj = torch.rand(4, 4) > 0.7
adj = (adj | adj.t()).to(torch.float)

In [33]:
# Convert the dense adjacency into a 2 x E index tensor: the first row holds
# the row index and the second row the column index of each stored entry.
sp_adj = SparseTensor.from_dense(adj)
edge_index = torch.stack(
    [sp_adj.storage.row(), sp_adj.storage.col()],
    dim=0,
)

In [34]:
edge_index

tensor([[0, 0, 0, 1, 1, 2, 3, 3],
        [1, 2, 3, 0, 3, 0, 0, 1]])

In [35]:
# Per-node mask: nodes 0 and 1 carry 1.0, nodes 2 and 3 carry 0.0.
mask = torch.tensor([1.0, 1.0, 0.0, 0.0], dtype=torch.float)

In [36]:
# Scale entry (i, j) of the adjacency by mask[j] and by mask[i]; with a 0/1
# mask this zeroes every row and column that belongs to a masked-out node.
mask_adj = adj * mask.unsqueeze(0) * mask.unsqueeze(1)

In [37]:
mask_adj

tensor([[0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

# Part 1: gradient through a dense masked adjacency matrix

In [38]:
# Re-wrap the masked adjacency as a trainable leaf so that autograd stores a
# gradient for each of its entries in `mask_adj.grad`.
mask_adj = torch.nn.Parameter(mask_adj, requires_grad=True)

In [39]:
# Fix the seed so the random node features (4 nodes x 3 channels) are the same
# on every run.
torch.manual_seed(2022)
x = torch.rand((4, 3))

In [40]:
# Dense aggregation: row i of the result is the mask_adj-weighted sum of the
# rows of x.
logits = torch.matmul(mask_adj, x)

In [41]:
logits

tensor([[0.3811, 0.0262, 0.3594],
        [0.3958, 0.9219, 0.7588],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MmBackward0>)

In [42]:
# Reduce to a scalar and back-propagate; the gradient ends up in mask_adj.grad.
torch.sum(logits).backward()

In [43]:
mask_adj.grad

tensor([[2.0766, 0.7667, 2.0388, 1.5321],
        [2.0766, 0.7667, 2.0388, 1.5321],
        [2.0766, 0.7667, 2.0388, 1.5321],
        [2.0766, 0.7667, 2.0388, 1.5321]])

In [44]:
# Keep only the gradient entries that sit on entries of `adj` that are non-zero,
# then sum each row to obtain one value per node.
final_grad = torch.sum(mask_adj.grad * adj.detach(), dim=1)
final_grad

tensor([4.3376, 3.6087, 2.0766, 2.8433])

# Part 2: the same gradient via message passing with per-edge weights

In [45]:
from torch_geometric.nn import MessagePassing

class GINEConv(MessagePassing):
    """Weight-free message-passing layer using ``aggr="add"``.

    Every message is the feature row ``x_j`` multiplied by the weight of its
    edge; the aggregated result is returned unchanged. The class defines no
    learnable parameters of its own.
    """

    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index, edge_weight):
        """Propagate ``x`` along ``edge_index``, scaling messages by ``edge_weight``."""
        needs_channel_dim = edge_weight.ndim < 2
        if needs_channel_dim:
            # Add a trailing axis so the weight broadcasts over feature channels.
            edge_weight = edge_weight[:, None]

        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        # x_j has shape [E, out_channels]
        scaled_message = x_j * edge_weight
        return scaled_message

    def update(self, aggr_out):
        # Identity: the aggregated messages are the layer output.
        return aggr_out

In [46]:
# The GINEConv class defined in this notebook declares no weights of its own, so
# the seed is not needed to initialise it.
# NOTE(review): the call still resets the global RNG state; confirm nothing
# later relies on that before removing it.
torch.manual_seed(2022)
model = GINEConv()

In [47]:
# Per-edge weight: the product of the mask values at the two endpoints of each
# edge (gather both endpoint rows at once, then multiply them together).
mask_edge_weight = mask[edge_index].prod(dim=0).to(torch.float)

In [48]:
# Make the edge weights a trainable leaf so that their gradient is stored in
# `mask_edge_weight.grad`.
mask_edge_weight = torch.nn.Parameter(mask_edge_weight, requires_grad=True)

In [50]:
# Sparse counterpart of the dense product: message passing with edge weights.
logits = model(x=x, edge_index=edge_index, edge_weight=mask_edge_weight)

In [51]:
logits

tensor([[0.3811, 0.0262, 0.3594],
        [0.3958, 0.9219, 0.7588],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]], grad_fn=<ScatterAddBackward0>)

In [52]:
# Reduce to a scalar and back-propagate; the per-edge gradient ends up in
# mask_edge_weight.grad.
torch.sum(logits).backward()

In [53]:
mask_edge_weight.grad

tensor([2.0766, 2.0766, 2.0766, 0.7667, 0.7667, 2.0388, 1.5321, 1.5321])

In [56]:
from torch_scatter import scatter

# Sum the per-edge gradients onto the node index stored in edge_index[1],
# giving one value per node.
final_grad = scatter(
    src=mask_edge_weight.grad,
    index=edge_index[1],
    dim=0,
    reduce="sum",
)

In [57]:
final_grad

tensor([4.3376, 3.6087, 2.0766, 2.8433])

# Part 3: custom autograd function mapping a node mask to an edge mask

In [66]:
from torch_scatter import scatter

class NodemaskToEdgemask(torch.autograd.Function):
    """Turn a per-node mask into a per-edge mask with a hand-written gradient.

    Forward: edge ``e`` receives ``mask[edge_index[0, e]] * mask[edge_index[1, e]]``.

    Backward: the incoming per-edge gradient is summed onto the node stored in
    ``edge_index[1]``. This is not the analytic derivative of the forward
    product (the other endpoint's mask value is not multiplied in); it is the
    same per-node sum that the notebook obtains earlier with
    ``scatter(mask_edge_weight.grad, edge_index[1], dim=0, reduce='sum')``.

    Call as ``NodemaskToEdgemask.apply(mask, edge_index, n_nodes)`` where
    ``n_nodes`` is the length of the returned node gradient, given either as a
    Python int or as a 0-d tensor.
    """

    @staticmethod
    def forward(ctx, mask, *args):
        edge_index, n_nodes = args
        # save_for_backward only accepts tensors, so the node count is kept as
        # a plain attribute. int() handles both a Python int and a 0-d tensor.
        ctx.n_nodes = int(n_nodes)
        # The mask itself is not needed in backward, so only the indices are saved.
        ctx.save_for_backward(edge_index)
        return (mask[edge_index[0]] * mask[edge_index[1]]).to(torch.float)

    @staticmethod
    def backward(ctx, grad_output):
        (edge_index,) = ctx.saved_tensors
        # Sum-scatter along dim 0 using torch's own index_add_: one output row
        # per node, trailing dimensions (if any) are preserved.
        node_grad = grad_output.new_zeros(
            (ctx.n_nodes,) + tuple(grad_output.shape[1:])
        )
        node_grad.index_add_(0, edge_index[1], grad_output)
        # edge_index and n_nodes are not differentiable inputs.
        return node_grad, None, None

In [67]:
# Custom autograd functions are invoked through `.apply`; bind it to a short
# name so it can be called like a regular function.
masking = NodemaskToEdgemask.apply

In [68]:
# Make the node mask a trainable leaf so that its gradient is stored in
# `mask.grad`.
mask = torch.nn.Parameter(mask, requires_grad=True)

In [69]:
# Build the edge mask through the custom autograd function; the node count is
# handed over as a 0-d tensor.
mask_edge_weight = masking(mask, edge_index, torch.tensor(mask.size(0)))

In [70]:
# Message passing with the edge mask produced by the custom autograd function.
logits = model(x=x, edge_index=edge_index, edge_weight=mask_edge_weight)

In [71]:
# Reduce to a scalar and back-propagate; the custom backward routes the
# per-edge gradient into mask.grad.
torch.sum(logits).backward()

In [72]:
mask.grad

tensor([4.3376, 3.6087, 2.0766, 2.8433])