# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data

class DHGANN(nn.Module):
    """
    Dynamic Heterogeneous Graph Attention Network (DHGANN).

    Every feature channel of the input time series becomes one graph node whose
    raw node features are that channel's values over the whole sequence. Nodes
    are linearly embedded, passed through two graph-attention (GAT) layers over
    a fully connected graph, and mean-pooled into one embedding per sample.

    "Heterogeneous" refers to the nodes standing for different kinds of
    features (e.g., price, volume, momentum); the code itself treats all nodes
    uniformly. "Dynamic" refers to the attention weights, which are recomputed
    from each input. The graph topology itself is fixed (fully connected).

    Args:
        input_dim: Length of the time sequence; this is the raw feature size of
            each node.
        hidden_dim: Size of the initial node embeddings and of each attention
            head in the first GAT layer.
        output_dim: Size of the final graph embedding.
        num_heads: Number of attention heads in the first GAT layer.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_heads: int = 4) -> None:
        super().__init__()
        self.input_dim = input_dim  # This is the sequence length
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_heads = num_heads

        # Each feature is a node; its time series (length input_dim) is projected
        # to a hidden_dim-sized initial embedding.
        self.node_embedding = nn.Linear(self.input_dim, hidden_dim)

        # Graph Attention Layers. The first layer concatenates its heads, so the
        # second layer consumes hidden_dim * num_heads channels.
        self.gat_conv1 = GATConv(hidden_dim, hidden_dim, heads=num_heads, concat=True)
        self.gat_conv2 = GATConv(hidden_dim * num_heads, output_dim, heads=1, concat=False)

        self.elu = nn.ELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for DHGANN.

        Args:
            x (torch.Tensor): The input time-series data of shape
                              (batch_size, sequence_length, num_features).

        Returns:
            torch.Tensor: A graph-based feature representation of shape
                          (batch_size, output_dim).

        Raises:
            ValueError: If ``x`` is not 3-dimensional or its sequence length
                differs from ``input_dim``.
        """
        if x.dim() != 3:
            raise ValueError(
                "DHGANN expects input of shape (batch_size, sequence_length, "
                f"num_features), got {tuple(x.shape)}"
            )
        batch_size, seq_len, num_features = x.shape
        if seq_len != self.input_dim:
            raise ValueError(
                f"sequence_length ({seq_len}) must equal input_dim ({self.input_dim})"
            )
        if batch_size == 0:
            # Nothing to pool over; return an empty batch of embeddings.
            return x.new_zeros((0, self.output_dim))

        # 1. Treat features as nodes and the time sequence as node features:
        #    (batch, seq_len, num_features) -> (batch, num_features, seq_len),
        #    then embed all nodes of all samples with a single linear call.
        h = self.node_embedding(x.permute(0, 2, 1))
        # Flatten to one node list; node n of sample b sits at row b * num_features + n.
        h = h.reshape(batch_size * num_features, self.hidden_dim)

        # 2. Build the fully connected edge index once (it is identical for every
        #    sample) and replicate it per sample with node-id offsets. The result
        #    is one block-diagonal graph: no edge crosses samples, so attention
        #    stays within each sample exactly as when graphs are processed one by one.
        edge_index = self._create_fully_connected_graph(num_features, h.device)
        offsets = torch.arange(batch_size, device=h.device) * num_features
        batched_edge_index = (
            edge_index.unsqueeze(1) + offsets.view(1, -1, 1)
        ).reshape(2, batch_size * edge_index.size(1))

        # 3. Apply Graph Attention Layers to the whole batch at once
        h = self.elu(self.gat_conv1(h, batched_edge_index))
        h = self.gat_conv2(h, batched_edge_index)

        # 4. Mean-pool the nodes of each sample into a single graph embedding
        h = h.reshape(batch_size, num_features, self.output_dim)
        return h.mean(dim=1)

    def _create_fully_connected_graph(self, num_nodes: int, device) -> torch.Tensor:
        """
        Helper to create a fully connected edge index.

        Args:
            num_nodes: Number of nodes in the graph.
            device: Device on which to allocate the edge index.

        Returns:
            torch.Tensor: Integer tensor of shape (2, num_nodes * (num_nodes - 1))
            holding every ordered (source, target) pair with source != target.
        """
        nodes = torch.arange(num_nodes, device=device)
        # Create all possible ordered pairs of nodes (edges)
        edge_index = torch.cartesian_prod(nodes, nodes).t()
        # Remove self-loops.
        # NOTE(review): GATConv's default add_self_loops=True presumably re-adds
        # them internally -- confirm against the installed PyG version.
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
        return edge_index

if __name__ == '__main__':
    # Demo: run one random batch through DHGANN and report the tensor shapes.
    BATCH_SIZE, SEQ_LEN, NUM_FEATURES = 4, 60, 10
    HIDDEN_DIM, OUTPUT_DIM = 128, 64  # OUTPUT_DIM: size of the final graph embedding

    # The sequence length doubles as the per-node input dimension.
    model = DHGANN(input_dim=SEQ_LEN, hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM)

    # Random stand-in for real data, laid out as (batch, seq_len, num_features).
    dummy_input = torch.randn(BATCH_SIZE, SEQ_LEN, NUM_FEATURES)
    output = model(dummy_input)

    print("--- DHGANN Example ---")
    print(f"Input shape: {dummy_input.shape}")
    # Expected output shape: (BATCH_SIZE, OUTPUT_DIM)
    print(f"Output graph embedding shape: {output.shape}")