In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, degree

In [36]:
from typing import Union


from torch import Tensor
from torch_geometric.typing import SparseTensor


class CustomGNNLayer(MessagePassing):
    """Toy message-passing layer that yields a node output and an edge output.

    The base class is configured with ``aggr='add'``. Two linear encoders are
    registered in the constructor, but no method of this class applies them.
    """

    def __init__(self, in_channels, out_channels):
        """Configure the aggregation scheme and register both encoders.

        Args:
            in_channels: Input feature size used for both ``nn.Linear`` layers.
            out_channels: Output feature size used for both ``nn.Linear`` layers.
        """
        super().__init__(aggr='add')
        # NOTE(review): neither encoder is called anywhere in this class, and
        # edge_encoder is sized with in_channels as well — confirm that edge
        # features are meant to have the same width as node features.
        self.edge_encoder = nn.Linear(in_channels, out_channels)
        self.node_encoder = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Run node propagation and the edge updater on the same inputs.

        Args:
            x: Node features, forwarded under the ``x`` keyword.
            edge_index: Graph connectivity, passed positionally to both calls.
            edge_attr: Edge features, forwarded under the ``edge_attr`` keyword.

        Returns:
            A 2-tuple ``(node_out, edge_out)``: the result of ``propagate``
            followed by the result of ``edge_updater``.
        """
        shared_inputs = dict(x=x, edge_attr=edge_attr)
        # The extra ``forward=True`` keyword is handed to ``message``.
        node_out = self.propagate(edge_index, **shared_inputs, forward=True)
        edge_out = self.edge_updater(edge_index, **shared_inputs)
        return node_out, edge_out

    def message(self, x_i, x_j, edge_attr, forward=True):
        """Build the per-edge message ``x_i + x_j + edge_attr``.

        ``forward`` is accepted so the keyword sent by ``forward()`` has a
        destination; its value is not read here.
        """
        endpoint_sum = x_i + x_j
        return endpoint_sum + edge_attr

    def update(self, aggr_out):
        """Return an all-ones tensor with the same shape/dtype as ``aggr_out``.

        NOTE(review): the aggregated values themselves are discarded, so the
        node output does not depend on the messages. This looks like a
        placeholder (the alternative would be returning ``aggr_out``
        unchanged) — confirm before relying on the node output.
        """
        placeholder = torch.ones_like(aggr_out)
        return placeholder

    def edge_update(self, x_i, x_j, edge_attr):
        """Compute the new edge attributes as ``edge_attr + x_i + x_j``."""
        shifted_attr = edge_attr + x_i
        return shifted_attr + x_j



In [37]:
# Toy graph: 3 nodes and 4 directed edges (0<->1 and 1<->2, one entry per direction).
# Node and edge features are single-column float tensors.
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float).reshape(-1, 1)
edge_attr = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float).reshape(-1, 1)

# Each inner pair lists the two node indices of one directed edge; transposing
# gives the 2 x 4 layout with one column per edge.
edge_index = torch.tensor(
    [[0, 1], [1, 0], [1, 2], [2, 1]],
    dtype=torch.long,
).t().contiguous()

# Bundle everything into a single graph object
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Build the layer with one input and one output channel
gnn_layer = CustomGNNLayer(in_channels=1, out_channels=1)

# Run the layer; the result is a (node_out, edge_out) tuple
output = gnn_layer(data.x, data.edge_index, data.edge_attr)

# Show the original edge attributes, then the layer's result for comparison
print(edge_attr)

print("Output after GNN layer:\n", output)

tensor([[0.1000],
        [0.2000],
        [0.3000],
        [0.4000]])
Output after GNN layer:
 (tensor([[1.],
        [1.],
        [1.]]), tensor([[3.1000],
        [3.2000],
        [5.3000],
        [5.4000]]))
