In [None]:
import torch
import torch.nn as nn

class GCN(nn.Module):
    """Graph-level classifier built from 1x1 convolutions over dense graph tensors.

    Despite the name, no message passing over the adjacency matrix happens here:
    each input is treated as a (batch, channels, height, width) image, projected
    channel-wise by a 1x1 ``nn.Conv2d``, globally mean-pooled over its two spatial
    dimensions, and the pooled vectors are concatenated (16 + 32 + 16 + 16 = 80
    features) and fed to a linear head.

    Args:
        num_nodes: Number of graph nodes. Not used by any layer (1x1 convolutions
            and global pooling accept any spatial size); kept for interface
            compatibility.
        num_position_features: Channel count of the ``position_features`` input.
        num_classes: Length of the output logit vector per graph.
        num_structuring_pos_features: Channel count of the
            ``structuring_element_pos`` input (defaults to 2, the value that used
            to be hard-coded).
    """

    def __init__(self, num_nodes, num_position_features, num_classes, num_structuring_pos_features=2):
        super().__init__()

        # One 1x1 projection per input; output channel counts are fixed.
        self.adjacency_conv = nn.Conv2d(1, 16, kernel_size=1)
        self.position_conv = nn.Conv2d(num_position_features, 32, kernel_size=1)
        self.structuring_element_adj_conv = nn.Conv2d(1, 16, kernel_size=1)
        self.structuring_element_pos_conv = nn.Conv2d(num_structuring_pos_features, 16, kernel_size=1)
        # Input width is the concatenation of the four pooled branches.
        self.fc = nn.Linear(16 + 32 + 16 + 16, num_classes)

    def forward(self, adjacency, position_features, structuring_element_adj, structuring_element_pos):
        """Compute per-graph class logits.

        All inputs must be 4-D ``(B, C, H, W)`` tensors sharing the same batch
        size ``B``; their spatial sizes may differ from one another.

        Args:
            adjacency: Tensor with 1 channel.
            position_features: Tensor with ``num_position_features`` channels.
            structuring_element_adj: Tensor with 1 channel.
            structuring_element_pos: Tensor with ``num_structuring_pos_features`` channels.

        Returns:
            Tensor of shape ``(B, num_classes)`` holding unnormalised logits.
        """
        branches = (
            self.adjacency_conv(adjacency),                               # (B, 16, ., .)
            self.position_conv(position_features),                        # (B, 32, ., .)
            self.structuring_element_adj_conv(structuring_element_adj),   # (B, 16, ., .)
            self.structuring_element_pos_conv(structuring_element_pos),   # (B, 16, ., .)
        )

        # The branches differ in channel count and spatial size (e.g. 100x100 vs
        # 4x4), so they cannot be summed elementwise. Pool each one to a vector
        # first, then concatenate to the 80 features the linear head expects.
        pooled = [features.mean(dim=(2, 3)) for features in branches]
        graph_features = torch.cat(pooled, dim=1)

        return self.fc(graph_features)

# Example usage:
# Seed first so the random inputs and the layers' weight initialisation are
# reproducible under Restart & Run All.
torch.manual_seed(0)

num_nodes = 100
num_position_features = 2
num_classes = 2

# Generate dummy input data (a batch containing a single graph)
adjacency = torch.randn(1, 1, num_nodes, num_nodes)
position_features = torch.randn(1, num_position_features, num_nodes, num_nodes)
structuring_element_adj = torch.randn(1, 1, 4, 4)
structuring_element_pos = torch.randn(1, 2, 4, 4)

# Create GCN model
model = GCN(num_nodes, num_position_features, num_classes)

# Forward pass: one logit vector per graph in the batch
output = model(adjacency, position_features, structuring_element_adj, structuring_element_pos)
assert output.shape == (1, num_classes), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)


In [None]:
# Demo: run a forward pass through the GCN on dummy data and compute a training loss.

import torch
import torch.nn as nn

# NOTE(review): this duplicates the GCN class defined in the first cell; the later
# definition silently shadows the earlier one, so keep the two identical (or keep
# only one definition).
class GCN(nn.Module):
    """Graph-level classifier built from 1x1 convolutions over dense graph tensors.

    Despite the name, no message passing over the adjacency matrix happens here:
    each input is treated as a (batch, channels, height, width) image, projected
    channel-wise by a 1x1 ``nn.Conv2d``, globally mean-pooled over its two spatial
    dimensions, and the pooled vectors are concatenated (16 + 32 + 16 + 16 = 80
    features) and fed to a linear head.

    Args:
        num_nodes: Number of graph nodes. Not used by any layer (1x1 convolutions
            and global pooling accept any spatial size); kept for interface
            compatibility.
        num_position_features: Channel count of the ``position_features`` input.
        num_classes: Length of the output logit vector per graph.
        num_structuring_pos_features: Channel count of the
            ``structuring_element_pos`` input (defaults to 2, the value that used
            to be hard-coded).
    """

    def __init__(self, num_nodes, num_position_features, num_classes, num_structuring_pos_features=2):
        super().__init__()

        # One 1x1 projection per input; output channel counts are fixed.
        self.adjacency_conv = nn.Conv2d(1, 16, kernel_size=1)
        self.position_conv = nn.Conv2d(num_position_features, 32, kernel_size=1)
        self.structuring_element_adj_conv = nn.Conv2d(1, 16, kernel_size=1)
        self.structuring_element_pos_conv = nn.Conv2d(num_structuring_pos_features, 16, kernel_size=1)
        # Input width is the concatenation of the four pooled branches.
        self.fc = nn.Linear(16 + 32 + 16 + 16, num_classes)

    def forward(self, adjacency, position_features, structuring_element_adj, structuring_element_pos):
        """Compute per-graph class logits.

        All inputs must be 4-D ``(B, C, H, W)`` tensors sharing the same batch
        size ``B``; their spatial sizes may differ from one another.

        Args:
            adjacency: Tensor with 1 channel.
            position_features: Tensor with ``num_position_features`` channels.
            structuring_element_adj: Tensor with 1 channel.
            structuring_element_pos: Tensor with ``num_structuring_pos_features`` channels.

        Returns:
            Tensor of shape ``(B, num_classes)`` holding unnormalised logits.
        """
        branches = (
            self.adjacency_conv(adjacency),                               # (B, 16, ., .)
            self.position_conv(position_features),                        # (B, 32, ., .)
            self.structuring_element_adj_conv(structuring_element_adj),   # (B, 16, ., .)
            self.structuring_element_pos_conv(structuring_element_pos),   # (B, 16, ., .)
        )

        # The branches differ in channel count and spatial size (e.g. 100x100 vs
        # 4x4), so they cannot be summed elementwise. Pool each one to a vector
        # first, then concatenate to the 80 features the linear head expects.
        pooled = [features.mean(dim=(2, 3)) for features in branches]
        graph_features = torch.cat(pooled, dim=1)

        return self.fc(graph_features)

# Example usage:
# Seed first so the random inputs and the layers' weight initialisation are
# reproducible under Restart & Run All.
torch.manual_seed(0)

num_nodes = 100
num_position_features = 2
num_classes = 2

# Generate dummy input data (a batch containing a single graph)
adjacency = torch.randn(1, 1, num_nodes, num_nodes)
position_features = torch.randn(1, num_position_features, num_nodes, num_nodes)
structuring_element_adj = torch.randn(1, 1, 4, 4)
structuring_element_pos = torch.randn(1, 2, 4, 4)

# Create GCN model
model = GCN(num_nodes, num_position_features, num_classes)

# Forward pass: one logit vector per graph in the batch
output = model(adjacency, position_features, structuring_element_adj, structuring_element_pos)
assert output.shape == (1, num_classes), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)

# Generate dummy target data.
# The model emits graph-level class logits of shape (batch, num_classes), so the
# target must be one class index per graph. Node-level tensors such as
# (1, 1 + num_position_features, num_nodes, num_nodes) cannot be broadcast
# against the logits, and comparing them with MSELoss raises a RuntimeError.
target_class = torch.randint(0, num_classes, (output.shape[0],))

# Cross-entropy is the standard loss for unnormalised class logits vs. class indices.
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target_class)
print(loss)
