In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        # Weight matrix for the linear transformation
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        # Initialization of the weight matrix
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize the weight matrix with Glorot/Xavier initialization
        stdv = 1. / (self.weight.size(1) ** 0.5)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """
        Forward pass of the GCN layer.

        Parameters:
        - input: Node feature matrix (N x in_features), where N is the number of nodes.
        - adj: Adjacency matrix of the graph (N x N), possibly normalized.

        Returns:
        - Output feature matrix for the next layer (N x out_features)
        """
        # Aggregation: Matrix multiplication between the adjacency matrix and the input features
        support = torch.mm(input, self.weight)
        output = torch.mm(adj, support)

        return output

# Example usage
if __name__ == "__main__":
    # Number of nodes N, input features F_in, and output features F_out
    N = 4  # Example: 4 nodes
    F_in = 5  # Example: 5 input features
    F_out = 2  # Example: 2 output features

    # Example input feature matrix (N x F_in)
    X = torch.rand((N, F_in))

    # Example adjacency matrix (N x N)
    # For simplicity, using an unnormalized adjacency matrix here
    A = torch.tensor([[0, 1, 0, 0],
                      [1, 0, 1, 1],
                      [0, 1, 0, 1],
                      [0, 1, 1, 0]], dtype=torch.float)

    # Initialize the GCN layer
    gcn_layer = GCNLayer(F_in, F_out)

    # Forward pass
    output_features = gcn_layer(X, A)
    print("Output features:\n", output_features)


Output features:
 tensor([[-0.0547, -1.2844],
        [-2.1244, -1.8702],
        [-0.4897, -2.1602],
        [-0.6673, -1.8903]], grad_fn=<MmBackward0>)
