In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class PatchGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=4, dropout=0.5):
        super(PatchGCN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(GCNConv(in_channels, hidden_channels))

        for _ in range(num_layers - 2):
            self.layers.append(GCNConv(hidden_channels, hidden_channels))

        self.layers.append(GCNConv(hidden_channels, out_channels))

        self.dropout = dropout
        self.attention_pooling = nn.Linear(out_channels, 1)  # Match out_channels

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if i < len(self.layers) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        attention_weights = torch.softmax(self.attention_pooling(x), dim=0)
        global_feature = torch.sum(attention_weights * x, dim=0, keepdim=True)

        return global_feature, attention_weights

# Graph construction helper function
def construct_graph(features, adjacency_matrix):
    edge_index = torch.nonzero(adjacency_matrix, as_tuple=False).T
    return Data(x=features, edge_index=edge_index)

# Example usage
if __name__ == "__main__":
    # Sample node features and adjacency matrix
    num_nodes = 10
    feature_dim = 1024

    features = torch.rand((num_nodes, feature_dim))
    adjacency_matrix = torch.eye(num_nodes) + torch.rand((num_nodes, num_nodes)) > 0.5
    adjacency_matrix = adjacency_matrix.int()

    graph_data = construct_graph(features, adjacency_matrix)

    model = PatchGCN(in_channels=feature_dim, hidden_channels=512, out_channels=128, num_layers=4, dropout=0.5)

    global_feature, attention_weights = model(graph_data.x, graph_data.edge_index)
    print("Global Feature:", global_feature)
    print("Attention Weights:", attention_weights)


Global Feature: tensor([[-0.3328, -0.0308,  0.1009,  0.4461,  0.5876,  0.2351, -0.0224, -0.1435,
          0.7598, -0.4877, -0.0706, -0.0286, -0.0939, -0.1340,  0.0523, -0.1344,
         -0.2848, -0.1130,  0.1791,  0.3997, -0.0550, -0.0801,  0.4190, -0.1020,
          0.2094,  0.0101,  0.5934, -0.4145, -0.1456, -0.2173, -0.3908,  0.2839,
          0.1893,  0.3146,  0.1488, -0.0162,  0.8680,  0.6971,  0.5999,  0.1355,
          0.2126,  0.6969,  0.0809, -0.4190,  0.2180,  0.1755, -1.2294, -0.1417,
          0.1397,  0.3049, -0.5066, -0.2639,  0.0325, -0.0796,  0.4065,  0.3282,
          0.2626,  0.5231,  0.0906, -0.4036, -0.2759, -0.1835, -0.1665, -0.1994,
          0.0653,  0.4721, -0.0136,  0.8431, -0.0410, -0.1298,  0.2136,  0.0429,
          0.4508, -0.2869, -0.1636,  0.4278, -0.5509, -0.2560,  0.2993,  0.2012,
         -0.0703, -0.2811,  0.6018,  0.0763,  0.4343, -0.5704, -0.3667, -0.3250,
         -0.4640,  0.5785, -0.1030, -0.1797,  0.0746, -0.2202,  0.0262,  0.4571,
         -0.