In [1]:
import torch
import torch_geometric
from custom_dataloder import MYDataLoader
from torch_geometric.datasets import ZINC, TUDataset
from torch_geometric.data import Data

In [2]:
from torch_sparse import SparseTensor

In [3]:
adj = (torch.rand(4, 4) > 0.4).to(torch.float)

In [4]:
sp_adj = SparseTensor.from_dense(adj)
edge_index = torch.stack((sp_adj.storage.row(), sp_adj.storage.col()), dim=0)

In [5]:
adj

tensor([[1., 0., 0., 1.],
        [1., 1., 0., 0.],
        [1., 0., 1., 1.],
        [1., 1., 1., 1.]])

In [6]:
edge_index

tensor([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3],
        [0, 3, 0, 1, 0, 2, 3, 0, 1, 2, 3]])

In [7]:
mask = torch.tensor([1, 1, 0, 0,], dtype=torch.float)

In [8]:
mask_adj = adj * mask[None] * mask[:, None]

In [9]:
mask_adj

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

# part1

In [10]:
mask_adj = torch.nn.Parameter(mask_adj)

In [11]:
torch.manual_seed(2022)
x = torch.rand(4, 3)

In [12]:
logits = mask_adj @ x

In [13]:
logits

tensor([[0.3958, 0.9219, 0.7588],
        [0.7769, 0.9482, 1.1182],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MmBackward0>)

In [14]:
logits.sum().backward()

In [15]:
mask_adj.grad

tensor([[2.0766, 0.7667, 2.0388, 1.5321],
        [2.0766, 0.7667, 2.0388, 1.5321],
        [2.0766, 0.7667, 2.0388, 1.5321],
        [2.0766, 0.7667, 2.0388, 1.5321]])

In [16]:
final_grad = (mask_adj.grad * adj.detach()).sum(1)
final_grad

tensor([3.6087, 2.8433, 5.6474, 6.4142])

# part2

In [17]:
from torch_geometric.nn import MessagePassing

class GINEConv(MessagePassing):
    def __init__(self):
        super(GINEConv, self).__init__(aggr="add")
    
    def forward(self, x, edge_index, edge_weight):
        if edge_weight.ndim < 2:
            edge_weight = edge_weight[:, None]

        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        return out

    def message(self, x_j, edge_weight):
        # x_j has shape [E, out_channels]
        # print(x_j)
        return x_j * edge_weight

    def update(self, aggr_out):
        return aggr_out

In [18]:
torch.manual_seed(2022)
model = GINEConv()

In [19]:
# edge_weight = torch.ones(edge_index.shape[1], dtype=torch.float)
mask_edge_weight = (mask[edge_index[0]] *  mask[edge_index[1]]).to(torch.float)

In [20]:
mask_edge_weight = torch.nn.Parameter(mask_edge_weight)

In [21]:
logits = model(x, edge_index[torch.tensor([1, 0])], mask_edge_weight)

In [22]:
logits

tensor([[0.3958, 0.9219, 0.7588],
        [0.7769, 0.9482, 1.1182],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]], grad_fn=<ScatterAddBackward0>)

In [23]:
logits.sum().backward()

In [24]:
mask_edge_weight.grad

tensor([2.0766, 1.5321, 2.0766, 0.7667, 2.0766, 2.0388, 1.5321, 2.0766, 0.7667,
        2.0388, 1.5321])

In [25]:
from torch_scatter import scatter

final_grad = scatter(mask_edge_weight.grad, edge_index[0], dim=0, reduce='sum')

In [26]:
final_grad

tensor([3.6087, 2.8433, 5.6474, 6.4142])

# part 3

In [27]:
from torch_scatter import scatter

class NodemaskToEdgemask(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mask, *args):
        edge_index, n_nodes = args
        ctx.save_for_backward(mask, edge_index, n_nodes)
        return (mask[edge_index[0]] *  mask[edge_index[1]]).to(torch.float)

    @staticmethod
    def backward(ctx, grad_output):
        mask, edge_index, n_nodes = ctx.saved_tensors
        grad_input = grad_output.clone()
        final_grad = scatter(grad_output, edge_index[0], dim=0, reduce='sum', dim_size=n_nodes)
        return final_grad, None, None

In [28]:
masking = NodemaskToEdgemask.apply

In [29]:
mask = torch.nn.Parameter(mask)

In [30]:
mask_edge_weight = masking(mask, edge_index, torch.tensor(mask.shape[0]))

In [31]:
logits = model(x, edge_index[torch.tensor([1, 0])], mask_edge_weight)

In [32]:
logits.sum().backward()

In [33]:
mask.grad

tensor([3.6087, 2.8433, 5.6474, 6.4142])