In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.feature_extraction import MaxViTFeatureExtractor
from src.feinn import FEINN
from src.dhgann import DHGANN

class DyanFHEGPIAN(nn.Module):
    """
    The complete Dyan-FHEG-PIAN model.

    Integrates the MaxViT feature extractor, the FEINN, and the DHGANN into a
    single end-to-end framework for stock forecasting. The prediction head
    fuses MaxViT and DHGANN features. FEINN is held as a submodule but is not
    called in ``forward``; the training loop calls it on the MaxViT features
    that ``forward`` returns in order to compute the physics-informed loss.
    """

    def __init__(self, input_dim, seq_len, maxvit_params, feinn_params,
                 dhgann_params, *, fusion_dropout=0.3):
        """
        Args:
            input_dim (int): The number of raw input features (e.g., Open, High, etc.).
            seq_len (int): The length of the input time-series sequence.
            maxvit_params (dict): Parameters for the MaxViTFeatureExtractor.
                Must contain ``'embed_dim'``, which is used as the width of the
                MaxViT features fed to FEINN and to the fusion head.
            feinn_params (dict): Parameters for the FEINN module.
            dhgann_params (dict): Parameters for the DHGANN module.
                Must contain ``'output_dim'``, the width of the DHGANN features
                fed to the fusion head.
            fusion_dropout (float): Dropout probability inside the fusion head.
                Defaults to 0.3, the value that was previously hard-coded.
        """
        super().__init__()

        # 1. Feature extraction: MaxViT over the raw sequence.
        self.maxvit = MaxViTFeatureExtractor(
            input_dim=input_dim,
            sequence_length=seq_len,
            **maxvit_params
        )
        # The output dimension from MaxViT is its embedding dimension.
        feature_extractor_out_dim = maxvit_params['embed_dim']

        # 2. FEINN: consumes the MaxViT features. It is not invoked in
        # forward(); the training loop calls it (and compute_physics_loss)
        # explicitly.
        self.feinn = FEINN(
            input_dim=feature_extractor_out_dim,
            **feinn_params
        )

        # 3. DHGANN: consumes the original time series to build the graph.
        # NOTE(review): DHGANN is built with input_dim=seq_len, but forward()
        # passes x_seq laid out as (batch, seq_len, input_dim). If DHGANN reads
        # the last axis as per-node features, it would need
        # x_seq.transpose(1, 2). TODO: confirm against DHGANN.
        self.dhgann = DHGANN(
            input_dim=seq_len,
            **dhgann_params
        )

        # 4. Fusion head: concatenated [MaxViT | DHGANN] features -> one value.
        combined_dim = feature_extractor_out_dim + dhgann_params['output_dim']
        self.fusion_layer = self._build_fusion_head(combined_dim, fusion_dropout)

    @staticmethod
    def _build_fusion_head(combined_dim, dropout):
        """Build the MLP head that maps fused features to a single output.

        The layer order (Linear, ReLU, Dropout, Linear) is kept fixed so the
        ``fusion_layer.*`` state_dict keys stay compatible with existing
        checkpoints.

        Args:
            combined_dim (int): Width of the concatenated feature vector.
            dropout (float): Dropout probability between the two Linear layers.

        Returns:
            nn.Sequential: The fusion head.
        """
        # Halve the width for the hidden layer, but never collapse it to zero
        # units. With combined_dim == 1, the bare // 2 would give Linear(1, 0).
        hidden_dim = max(1, combined_dim // 2)
        return nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),  # Final output is a single value (price change)
        )

    def forward(self, x_seq):
        """
        The main forward pass for the integrated model.

        Args:
            x_seq (torch.Tensor): Input time-series data of shape (batch_size, seq_len, input_dim).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(final_prediction, maxvit_features)``.
                ``final_prediction`` is the fusion-head output of shape
                (batch_size, 1). ``maxvit_features`` are the MaxViT features,
                returned so the caller can pass them to FEINN for the physics
                loss.
        """
        # 1. High-level temporal features from MaxViT.
        # Shape: (batch_size, maxvit_embed_dim)
        maxvit_features = self.maxvit(x_seq)

        # 2. Relational features from DHGANN on the raw sequence.
        # Shape: (batch_size, dhgann_output_dim)
        dhgann_features = self.dhgann(x_seq)

        # 3. Fuse both streams along the feature axis.
        combined_features = torch.cat([maxvit_features, dhgann_features], dim=1)

        # 4. Final prediction from the fusion head.
        final_prediction = self.fusion_layer(combined_features)

        # FEINN is deliberately not called here. Its physics-informed loss acts
        # as a regularizer and is computed in the training loop on the
        # returned MaxViT features.
        return final_prediction, maxvit_features

if __name__ == '__main__':
    # Demo: build the full model with small hyperparameters, run one forward
    # pass on random data, then compute FEINN's physics loss the way a
    # training loop would.
    batch, steps, n_features = 4, 60, 10

    # Constructor arguments, including per-module hyperparameters.
    model_config = dict(
        input_dim=n_features,
        seq_len=steps,
        maxvit_params=dict(embed_dim=64, num_heads=4, num_blocks=2),
        feinn_params=dict(hidden_dim=128, output_dim=1, n_layers=2),
        dhgann_params=dict(hidden_dim=64, output_dim=32, num_heads=2),
    )
    model = DyanFHEGPIAN(**model_config)

    # Random input batch, drawn after the model has been built.
    sample = torch.randn(batch, steps, n_features)
    prediction, loss_features = model(sample)

    # --- Physics loss, as it would be computed inside a training loop ---
    # FEINN lives on the model; it maps the returned features to a
    # displacement, and both are fed to its physics-loss method.
    physics_module = model.feinn
    displacement = physics_module(loss_features)
    physics_loss = physics_module.compute_physics_loss(loss_features, displacement)

    print("--- Dyan-FHEG-PIAN Integrated Model ---")
    print(f"Input shape: {sample.shape}")
    print(f"Final prediction shape: {prediction.shape}")
    print(f"Features for physics loss shape: {loss_features.shape}")
    print(f"Calculated Physics Loss: {physics_loss.item():.4f}")