In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class AttriMIL(nn.Module):
    """Attention-based MIL pooling over a graph, conditioned on node attributes.

    A stack of ``GCNConv`` layers embeds each node. The final node embeddings
    are concatenated with a per-node attribute vector and scored by a linear
    layer. The scores are softmax-normalised over nodes and used to pool the
    node embeddings into a single global feature row.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the intermediate GCN layers (unused when
            ``num_layers == 1``).
        out_channels: Size of the final node embeddings and of the pooled feature.
        attr_dim: Number of attribute columns passed to ``forward``.
        num_layers: Total number of GCN layers; must be >= 1. A single layer
            maps ``in_channels`` directly to ``out_channels``.
        dropout: Dropout probability applied (in training mode only) after
            every GCN layer except the last.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, attr_dim, num_layers=4, dropout=0.5):
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.feature_extraction_layers = nn.ModuleList()
        if num_layers == 1:
            # With a single layer there is no hidden stage: go straight to out_channels.
            self.feature_extraction_layers.append(GCNConv(in_channels, out_channels))
        else:
            self.feature_extraction_layers.append(GCNConv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.feature_extraction_layers.append(GCNConv(hidden_channels, hidden_channels))
            self.feature_extraction_layers.append(GCNConv(hidden_channels, out_channels))

        self.dropout = dropout
        # Scores each node from its final embedding concatenated with its attributes.
        self.attribute_attention_layer = nn.Linear(out_channels + attr_dim, 1)

    def forward(self, node_features, edge_index, attributes):
        """Embed nodes, compute attribute-aware attention, and pool.

        Args:
            node_features: Node feature matrix, presumably
                ``(num_nodes, in_channels)`` — TODO confirm with callers.
            edge_index: Graph connectivity forwarded unchanged to every ``GCNConv``.
            attributes: Per-node attribute matrix; it is concatenated with the
                node embeddings along dim 1, so it needs one row per node and
                ``attr_dim`` columns.

        Returns:
            Tuple ``(global_feature, attention_weights)``. ``global_feature`` is
            the attention-weighted sum of node embeddings with a kept leading
            dimension of size 1. ``attention_weights`` has one score per node
            (last dim 1) and is softmax-normalised over nodes.
        """
        x = node_features
        last_index = len(self.feature_extraction_layers) - 1
        for i, conv in enumerate(self.feature_extraction_layers):
            x = conv(x, edge_index)
            # Non-linearity and dropout only between layers, never after the last one.
            if i < last_index:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        # Concatenate node embeddings with their attributes.
        combined_features = torch.cat([x, attributes], dim=1)

        # Softmax over dim 0 normalises the attention scores across nodes.
        attention_weights = torch.softmax(self.attribute_attention_layer(combined_features), dim=0)

        # Global feature: attention-weighted sum of the node embeddings.
        global_feature = torch.sum(attention_weights * x, dim=0, keepdim=True)

        return global_feature, attention_weights

# Graph construction helper function
def construct_graph(node_features, adjacency_matrix):
    """Build a PyG ``Data`` object from node features and a dense adjacency matrix.

    Every non-zero entry ``adjacency_matrix[i, j]`` becomes one column
    ``(i, j)`` of ``edge_index``, which therefore has shape ``(2, num_edges)``.

    Args:
        node_features: Node feature tensor with one row per node; stored as ``x``.
        adjacency_matrix: Square 2-D tensor whose side equals the number of nodes.

    Returns:
        ``Data`` with ``x`` set to ``node_features`` and the derived ``edge_index``.

    Raises:
        ValueError: If ``adjacency_matrix`` is not a square 2-D tensor, or its
            size does not match the number of rows in ``node_features``.
    """
    if adjacency_matrix.dim() != 2 or adjacency_matrix.size(0) != adjacency_matrix.size(1):
        raise ValueError(
            f"adjacency_matrix must be a square 2-D tensor, got shape {tuple(adjacency_matrix.shape)}"
        )
    if adjacency_matrix.size(0) != node_features.size(0):
        raise ValueError(
            f"adjacency_matrix has {adjacency_matrix.size(0)} nodes but "
            f"node_features has {node_features.size(0)} rows"
        )

    # nonzero() returns (num_edges, 2) index pairs; transpose to (2, num_edges)
    # and make the result contiguous, following the PyG edge_index convention.
    edge_index = torch.nonzero(adjacency_matrix, as_tuple=False).t().contiguous()
    return Data(x=node_features, edge_index=edge_index)

# Example usage: build a small random graph and run the model on it
if __name__ == "__main__":
    # Demo: build a small random graph and run one inference pass of AttriMIL.
    torch.manual_seed(0)  # make the printed demo output reproducible

    num_nodes = 10
    feature_dim = 1024
    attr_dim = 128  # dimension of each node's attribute vector

    node_features = torch.rand((num_nodes, feature_dim))
    attributes = torch.rand((num_nodes, attr_dim))  # random per-node attributes

    # Each off-diagonal entry is an edge with probability 0.5, and every node
    # gets a self-loop. The resulting 0/1 adjacency matrix is not necessarily
    # symmetric (directed edges).
    random_edges = torch.rand((num_nodes, num_nodes)) > 0.5
    self_loops = torch.eye(num_nodes, dtype=torch.bool)
    adjacency_matrix = (random_edges | self_loops).int()

    graph_data = construct_graph(node_features, adjacency_matrix)

    model = AttriMIL(in_channels=feature_dim, hidden_channels=512, out_channels=128, attr_dim=attr_dim, num_layers=4, dropout=0.5)

    # Inference only: eval() disables dropout, no_grad() skips autograd tracking.
    model.eval()
    with torch.no_grad():
        global_feature, attention_weights = model(graph_data.x, graph_data.edge_index, attributes)
    print("Global Feature:", global_feature)
    print("Attention Weights:", attention_weights)


Global Feature: tensor([[ 0.2639,  0.6149, -0.1608, -0.0069, -0.3186, -0.1865,  0.3596, -0.0507,
          0.1148, -0.1079,  0.0648,  0.1059, -0.4524, -0.0956,  0.0757,  0.0203,
          0.2495, -0.0522,  0.2196,  0.3129, -0.0294, -0.3026,  0.0332,  0.0874,
          0.0477,  0.1471,  0.4346, -0.1119, -0.4331, -0.1022,  0.2118,  0.2733,
         -0.4519, -0.2439,  0.3235, -0.1835, -0.2859, -0.1376,  0.0564,  0.2608,
          0.1397,  0.3809,  0.0424, -0.0729,  0.2603, -0.0425, -0.0673,  0.0048,
          0.1821,  0.0872, -0.1333, -0.0859,  0.1721, -0.2886,  0.2003, -0.0436,
         -0.0066, -0.0401,  0.0611,  0.3757, -0.0481,  0.2436, -0.2375, -0.4859,
          0.1785, -0.3819, -0.2940, -0.0479, -0.2143, -0.2988,  0.1964, -0.1009,
          0.2479,  0.1049,  0.5147, -0.0489, -0.1047,  0.0868, -0.3590, -0.1017,
         -0.1807, -0.1208, -0.0523, -0.3798, -0.0252, -0.0761, -0.0215, -0.0642,
          0.3168,  0.0163, -0.0335, -0.0545,  0.2666,  0.3459, -0.3372, -0.5199,
         -0.