# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FEINN(nn.Module):
    """
    Finite Element-Integrated Neural Network (FEINN).

    This model interprets financial time-series data through the lens of continuum mechanics.
    It predicts 'displacements' (price movements) and calculates 'strain' (trend intensity)
    and 'stress' (market volatility) to inform its predictions and loss calculation.

    The core idea is to add a physics-informed regularization term to the standard loss,
    encouraging the model to learn solutions that are consistent with these pseudo-physical laws.
    """
    def __init__(self, input_dim, hidden_dim, output_dim=1, n_layers=2):
        """
        Build the MLP backbone and the learnable constitutive parameter ``C``.

        Args:
            input_dim (int): Number of input features per sample.
            hidden_dim (int): Width of every hidden layer.
            output_dim (int): Number of displacement values predicted per sample.
            n_layers (int): Number of hidden Linear+ReLU layers; must be >= 1.

        Raises:
            ValueError: If ``n_layers`` is less than 1.
        """
        super().__init__()
        if n_layers < 1:
            # range(n_layers - 1) is empty for 0/negative values, which would silently
            # build a one-hidden-layer network instead of what the caller asked for.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.input_dim = input_dim

        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, output_dim))

        self.net = nn.Sequential(*layers)

        # Constitutive matrix 'C' (learnable), analogous to material properties.
        # It maps strain to stress. For simplicity, we model it as a single learnable parameter
        # representing market resilience/reactivity.
        # NOTE(review): initialised from a standard normal, so C can start negative even though
        # the observed stress it is fitted against is non-negative. Consider a softplus
        # reparameterisation if a strictly positive stiffness is required.
        self.C = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """
        Predict the 'displacement' (price change) for a batch of feature vectors.

        Args:
            x (torch.Tensor): Input features of shape (batch_size, input_dim), e.g. the
                features produced by an upstream MaxViT encoder.

        Returns:
            torch.Tensor: Predicted displacement of shape (batch_size, output_dim).
        """
        predicted_displacement = self.net(x)
        return predicted_displacement

    def compute_physics_loss(self, features, predicted_displacement, volatility_index=-1):
        """
        Compute the physics-informed loss component.

        The loss penalizes violations of the assumed linear stress-strain relationship
        ``|volatility| ~= C * |displacement|``.

        Args:
            features (torch.Tensor): The original 2-D input features, (batch_size, feature_dim).
            predicted_displacement (torch.Tensor): The model's output, (batch_size, output_dim).
                A 1-D (batch_size,) tensor is accepted and treated as a single column.
            volatility_index (int): Column of ``features`` used as the volatility measure.
                Defaults to -1 (the last column).

        Returns:
            torch.Tensor: Scalar physics-informed loss (mean squared error).

        Raises:
            ValueError: If ``features`` is not 2-D, or if the batch sizes of ``features`` and
                ``predicted_displacement`` differ.
        """
        if features.dim() != 2:
            raise ValueError(
                f"features must be 2-D (batch_size, feature_dim), got shape {tuple(features.shape)}"
            )
        if predicted_displacement.dim() == 1:
            # A squeezed (batch,) output would otherwise broadcast against the (batch, 1)
            # observed stress into a (batch, batch) grid and give a meaningless loss.
            predicted_displacement = predicted_displacement.unsqueeze(1)
        if predicted_displacement.dim() != 2 or predicted_displacement.shape[0] != features.shape[0]:
            raise ValueError(
                "predicted_displacement must be (batch_size, output_dim) with the same batch "
                f"size as features; got {tuple(predicted_displacement.shape)} vs "
                f"{tuple(features.shape)}"
            )

        # 1. Strain (ε): approximated as the magnitude of the predicted displacement.
        # A larger predicted move implies a higher strain on the market trend.
        strain = torch.abs(predicted_displacement)

        # 2. Observed stress (σ): taken from a single feature column assumed to be a
        # volatility measure (last column by default). This is a simplification; a more
        # complex model could learn it from multiple features.
        stress_observed = torch.abs(features[:, volatility_index]).unsqueeze(1)

        # 3. Predicted stress from the constitutive law: σ_predicted = C * ε.
        stress_predicted = self.C * strain

        # 4. Discrepancy between predicted and observed stress (called the "weak-form loss"
        # in the paper; here it is a pointwise MSE). The observed column is expanded
        # explicitly so multi-output models compare every output against the same
        # per-sample volatility without relying on mse_loss's size-mismatch broadcasting.
        physics_loss = F.mse_loss(stress_predicted, stress_observed.expand_as(stress_predicted))

        return physics_loss

if __name__ == '__main__':
    # Toy end-to-end run: build a model, fabricate one batch, and take a single
    # Adam step on the combined data + physics objective.
    BATCH_SIZE, FEATURE_DIM, HIDDEN_DIM = 8, 64, 128  # FEATURE_DIM mirrors the MaxViT feature size

    model = FEINN(input_dim=FEATURE_DIM, hidden_dim=HIDDEN_DIM)

    # Stand-in encoder features; the final column is overwritten with torch.rand
    # values (in [0, 1)) so it can act as a non-negative volatility signal.
    features = torch.randn(BATCH_SIZE, FEATURE_DIM)
    features[:, -1] = torch.rand(BATCH_SIZE)
    target = torch.randn(BATCH_SIZE, 1)  # stand-in price change per sample

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lambda_physics = 0.1  # weight of the physics term relative to the data term

    # Data term + weighted physics regulariser.
    displacement = model(features)
    data_loss = F.mse_loss(displacement, target)
    phys_loss = model.compute_physics_loss(features, displacement)
    objective = data_loss + lambda_physics * phys_loss

    optimizer.zero_grad()
    objective.backward()
    optimizer.step()

    print("--- FEINN Example ---")
    print(f"Input features shape: {features.shape}")
    print(f"Predicted displacement shape: {displacement.shape}")
    print(f"Prediction Loss: {data_loss.item():.4f}")
    print(f"Physics Loss: {phys_loss.item():.4f}")
    print(f"Total Loss: {objective.item():.4f}")
    print(f"Learnable 'C' parameter: {model.C.item():.4f}")