In [None]:
# -----------------------------
# 0. Load dependencies
# -----------------------------
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# -----------------------------
# 3. Define models
# -----------------------------
class GCN(nn.Module):
    """Two-layer graph convolutional encoder.

    Stacks two ``GCNConv`` layers with a ReLU in between; no activation is
    applied after the second layer.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        """Build the two graph convolution layers.

        Args:
            input_dim: Size of each input node feature vector.
            hidden_dim: Size passed as the output dimension of both layers.
        """
        super().__init__()
        # Creation order is kept (conv1, then conv2) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run message passing over the graph.

        Args:
            x: Input node features.
            edge_index: Graph edge indices (COO).

        Returns:
            Updated node features after both convolution layers.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

class EdgeMLP(nn.Module):
    """Two-layer perceptron that turns edge features into class logits."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        """Build the two linear layers.

        Args:
            input_dim: Size of each input edge feature vector
                (concatenated node embeddings).
            hidden_dim: Width of the hidden layer.
            output_dim: Number of output logits (one per class).
        """
        super().__init__()
        # Creation order is kept (fc1, then fc2) so parameter initialisation
        # consumes the RNG in the same sequence.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify each edge.

        Args:
            x: Input edge features (concatenated node embeddings).

        Returns:
            Unnormalised class logits, one row per input row.
        """
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits

In [None]:
# Function to create a single dummy graph
def create_dummy_graph(
    max_num_nodes: int, feature_dim: int, num_classes: int
) -> tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a dummy graph with random node features, edges and edge labels.

    The node count is drawn uniformly from ``[2, max_num_nodes]`` (inclusive)
    and the edge count from ``[1, 2 * num_nodes - 1]``. Both endpoints of every
    edge are sampled independently, so self-loops and duplicate edges can occur.

    Args:
        max_num_nodes: Maximum number of nodes in the graph (inclusive, >= 2)
        feature_dim: Dimension of the node features
        num_classes: Number of target classes (>= 1)

    Returns:
        num_nodes: Number of nodes in the graph
        node_features: Node features tensor of shape (num_nodes, feature_dim)
        edge_index: Graph edge indices tensor of shape (2, num_edges)
        edge_labels: Edge labels tensor of shape (num_edges,)

    Raises:
        ValueError: If ``max_num_nodes < 2`` or ``num_classes < 1``.
    """
    if max_num_nodes < 2:
        raise ValueError(f"max_num_nodes must be at least 2, got {max_num_nodes}")
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}")

    # torch.randint excludes its upper bound, so add 1 to make max_num_nodes
    # itself reachable (and to make max_num_nodes == 2 a valid request).
    num_nodes = int(torch.randint(2, max_num_nodes + 1, (1,)).item())
    node_features = torch.randn(num_nodes, feature_dim)

    # Random connectivity: between 1 and 2 * num_nodes - 1 directed edges.
    num_edges = int(torch.randint(1, num_nodes * 2, (1,)).item())
    edge_index = torch.randint(0, num_nodes, (2, num_edges))

    # Assign a random class label to every edge
    edge_labels = torch.randint(0, num_classes, (num_edges,))

    return num_nodes, node_features, edge_index, edge_labels

In [None]:
# Data configuration
num_total_entities = 32  # Upper bound handed to create_dummy_graph for the node count
embedding_dim = 16  # Length of each node feature vector
num_classes = 2  # Number of edge classes to predict

# Training hyper-parameters
hidden_dim = 32
num_graphs = 256  # Size of the synthetic dataset
batch_size = 128
epochs = 1000
learning_rate = 0.001


# Models: the GCN encodes nodes, the MLP classifies edges from the
# concatenated embeddings of their two endpoints (hence 2 * hidden_dim).
gcn_model = GCN(input_dim=embedding_dim, hidden_dim=hidden_dim)
edge_mlp = EdgeMLP(
    input_dim=2 * hidden_dim, hidden_dim=hidden_dim, output_dim=num_classes
)

# Synthetic dataset of (num_nodes, node_features, edge_index, edge_labels) tuples
dataset = [
    create_dummy_graph(num_total_entities, embedding_dim, num_classes)
    for _ in range(num_graphs)
]


# A single optimizer updates both models jointly
optimizer = torch.optim.Adam(
    [*gcn_model.parameters(), *edge_mlp.parameters()],
    lr=learning_rate,
)
criterion = nn.CrossEntropyLoss()


# Training loop
for epoch in range(1, epochs + 1):
    total_loss = 0
    total_correct = 0
    total_samples = 0

    # Reshuffle in place every epoch, then slice into consecutive mini-batches
    random.shuffle(dataset)
    batches = [
        dataset[start : start + batch_size]
        for start in range(0, len(dataset), batch_size)
    ]

    for batch in batches:
        # Transpose the list of graph tuples into per-field sequences
        node_counts, feature_list, edge_list, label_list = zip(*batch)
        node_feature_batch = torch.cat(feature_list, dim=0)
        edge_labels_batch = torch.cat(label_list, dim=0)

        # Merge the graphs into one big disconnected graph: node features are
        # stacked, so each graph's edge indices must be shifted by the number
        # of nodes contributed by the graphs placed before it.
        shifted_edges = []
        node_offset = 0
        for count, edges in zip(node_counts, edge_list):
            shifted_edges.append(edges + node_offset)
            node_offset += count
        edge_index_batch = torch.cat(shifted_edges, dim=1)

        # Node embeddings from the GCN
        node_embeddings_out = gcn_model(node_feature_batch, edge_index_batch)

        # Edge features: embeddings of the head and tail node, side by side
        head_nodes, tail_nodes = edge_index_batch
        edge_embeddings = torch.cat(
            [node_embeddings_out[head_nodes], node_embeddings_out[tail_nodes]],
            dim=1,
        )

        # Edge class logits and loss
        out = edge_mlp(edge_embeddings)
        loss = criterion(out, edge_labels_batch)

        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running metrics for this epoch
        preds = out.argmax(dim=1)
        total_loss += loss.item()
        total_correct += (preds == edge_labels_batch).sum().item()
        total_samples += edge_labels_batch.size(0)

    # Loss is averaged per batch, accuracy per edge
    avg_loss = total_loss / len(batches)
    accuracy = total_correct / total_samples

    print(f"Epoch {epoch:02d} | Loss: {avg_loss:.4f} | Accuracy: {accuracy:.4f}")