# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

from src.data_loader import get_data
from src.model import DyanFHEGPIAN
from src.poa import PufferfishOptimizer
from src.utils import create_results_dir, plot_training_history, calculate_metrics

# --- Configuration ---
# Ticker symbol handed to get_data().
TICKER = "MSFT"
# Window length: 60 consecutive days of history feed one next-day prediction.
SEQ_LEN = 60
# Column of the fetched frame that the model learns to predict.
TARGET_COL = 'Close'
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
print(f"Using device: {DEVICE}")

def create_sequences(data, seq_length):
    """Slice a time-ordered array into sliding windows and next-step labels.

    Window ``i`` is ``data[i:i + seq_length]`` and its label is the entry that
    immediately follows it, ``data[i + seq_length]``.

    Args:
        data: Array-like indexed along its first axis by time step.
        seq_length: Number of consecutive steps in each window.

    Returns:
        A tuple ``(xs, ys)`` of NumPy arrays with one entry per window. When
        ``data`` has ``seq_length`` or fewer entries no window fits; the
        arrays are then empty but keep their trailing dimensions (``xs`` has
        shape ``(0, seq_length, ...)`` and ``ys`` has shape ``(0, ...)``) so
        callers can still index the feature axis.
    """
    data = np.asarray(data)
    n_windows = len(data) - seq_length
    if n_windows <= 0:
        # np.array([]) would collapse to shape (0,) and break downstream
        # indexing such as ``ys[:, -1]``; build correctly shaped empties.
        trailing = data.shape[1:]
        xs = np.empty((0, seq_length) + trailing, dtype=data.dtype)
        ys = np.empty((0,) + trailing, dtype=data.dtype)
        return xs, ys

    xs = np.array([data[i:i + seq_length] for i in range(n_windows)])
    ys = np.array([data[i + seq_length] for i in range(n_windows)])
    return xs, ys

def prepare_data():
    """Fetch the price history and turn it into scaled train/validation tensors.

    Steps: load the frame for ``TICKER``, min-max scale the feature columns and
    the ``TARGET_COL`` column with separate scalers, cut the scaled rows into
    ``SEQ_LEN``-step windows, and split the windows chronologically 80/20.

    Both scalers are fitted only on the rows that the training windows can
    see, so validation data does not leak into the scaling statistics. As a
    consequence, scaled validation values may fall outside ``[0, 1]``.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val, scaler_target)``. The ``X_*``
        tensors hold the windows; the ``y_*`` tensors hold the scaled target
        of the step following each window, with a trailing dimension of 1.
        All tensors are float32 on ``DEVICE``. ``scaler_target`` is the fitted
        target scaler, returned so predictions can be mapped back to the
        original scale.

    Raises:
        ValueError: If the frame has too few rows to form at least one
            training window and one validation window.
    """
    df = get_data(ticker=TICKER)

    # Select features and target
    features = df.drop(columns=[TARGET_COL])
    target = df[[TARGET_COL]]

    n_sequences = len(df) - SEQ_LEN
    if n_sequences < 2:
        raise ValueError(
            f"Need at least {SEQ_LEN + 2} rows to build a train and a "
            f"validation window, got {len(df)}."
        )

    # Ask train_test_split itself how many windows land in the training set so
    # this boundary always agrees with the split performed further down.
    train_idx, _ = train_test_split(
        np.arange(n_sequences), test_size=0.2, shuffle=False
    )
    # Training window i covers rows i .. i + SEQ_LEN (label row included), so
    # the training windows together touch the first n_train + SEQ_LEN rows.
    n_train_rows = len(train_idx) + SEQ_LEN

    # Fit on the training rows only, then apply the same scaling to all rows.
    scaler_features = MinMaxScaler().fit(features.iloc[:n_train_rows])
    scaler_target = MinMaxScaler().fit(target.iloc[:n_train_rows])

    features_scaled = scaler_features.transform(features)
    target_scaled = scaler_target.transform(target)

    # Combine back for sequencing; the target ends up as the last column.
    data_scaled = np.concatenate([features_scaled, target_scaled], axis=1)

    # Create sequences
    X, y = create_sequences(data_scaled, SEQ_LEN)

    # Each label row holds every column; keep only the (last) target column.
    y = y[:, -1]

    # Chronological split: no shuffling, validation is the most recent 20 %.
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, shuffle=False)

    # Convert to tensors
    X_train = torch.tensor(X_train, dtype=torch.float32).to(DEVICE)
    y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(DEVICE)
    X_val = torch.tensor(X_val, dtype=torch.float32).to(DEVICE)
    y_val = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1).to(DEVICE)

    print(f"Training data shape: {X_train.shape}")
    print(f"Validation data shape: {X_val.shape}")

    return X_train, y_train, X_val, y_val, scaler_target

def objective_function(params):
    """
    Fitness function evaluated by the Pufferfish Optimizer.

    Builds a fresh DyanFHEGPIAN from the candidate hyperparameters, trains it
    for a fixed 30 epochs on the module-level training tensors, and scores it
    by the MSE on the module-level validation tensors.

    Args:
        params: Mapping with the keys 'lr', 'embed_dim', 'hidden_dim',
            'num_heads' and 'lambda_physics'.

    Returns:
        The validation MSE as a Python float.
    """
    # Candidate hyperparameters
    learning_rate = params['lr']
    embedding_size = params['embed_dim']
    width = params['hidden_dim']
    heads = params['num_heads']
    physics_weight = params['lambda_physics']
    n_epochs = 30  # Same training budget for every candidate POA evaluates

    # --- Model Setup ---
    # The last axis of the training windows is the per-step feature count.
    model = DyanFHEGPIAN(
        input_dim=X_train.shape[2],
        seq_len=SEQ_LEN,
        maxvit_params={'embed_dim': embedding_size, 'num_heads': heads, 'num_blocks': 2},
        feinn_params={'hidden_dim': width, 'output_dim': 1, 'n_layers': 2},
        dhgann_params={'hidden_dim': width, 'output_dim': width // 2, 'num_heads': heads},
    ).to(DEVICE)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    batches = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)

    # --- Training Loop ---
    for _ in range(n_epochs):
        model.train()
        for inputs, targets in batches:
            optimizer.zero_grad()

            # Data term: MSE between the model's prediction and the target.
            preds, feinn_inputs = model(inputs)
            data_term = loss_fn(preds, targets)

            # Physics term: the FEINN sub-module scores its own output on the
            # features the model returned alongside the prediction.
            physics_term = model.feinn.compute_physics_loss(
                feinn_inputs, model.feinn(feinn_inputs)
            )

            # Weighted sum of both terms drives the update.
            (data_term + physics_weight * physics_term).backward()
            optimizer.step()

    # --- Evaluation ---
    # Single forward pass over the full validation set, gradients disabled.
    model.eval()
    with torch.no_grad():
        val_pred, _ = model(X_val)
        val_loss = loss_fn(val_pred, y_val).item()

    print(f"Params: {params} -> Validation Loss: {val_loss:.6f}")
    return val_loss

def train_final_model(params):
    """
    Trains the final model using the best hyperparameters found by POA.

    A fresh DyanFHEGPIAN is trained for 100 epochs on the module-level
    training tensors with the combined loss
    ``MSE + lambda_physics * physics_loss``. After every epoch the validation
    MSE and a directional accuracy are recorded. When training finishes the
    weights are saved to ``dyan_fheg_pian_model.pth`` in the current working
    directory.

    Directional accuracy compares, for each pair of consecutive validation
    samples, whether the target rose against whether the prediction rose. It
    is computed over the whole (unshuffled) validation set, so the transitions
    between mini-batches are counted as well.

    Args:
        params: Mapping with the keys 'lr', 'embed_dim', 'hidden_dim',
            'num_heads' and 'lambda_physics'.

    Returns:
        Tuple ``(model, history)``. ``history`` maps 'loss', 'val_loss',
        'accuracy' and 'val_accuracy' to per-epoch lists. No training accuracy
        is computed: 'accuracy' holds the same values as 'val_accuracy'.
    """
    print("\n--- Training Final Model with Best Hyperparameters ---")

    # Unpack best hyperparameters
    lr = params['lr']
    embed_dim = params['embed_dim']
    hidden_dim = params['hidden_dim']
    num_heads = params['num_heads']
    lambda_physics = params['lambda_physics']
    epochs = 100 # Train for more epochs on the final model

    # Model Setup
    input_dim = X_train.shape[2]
    maxvit_params = {'embed_dim': embed_dim, 'num_heads': num_heads, 'num_blocks': 2}
    feinn_params = {'hidden_dim': hidden_dim, 'output_dim': 1, 'n_layers': 2}
    dhgann_params = {'hidden_dim': hidden_dim, 'output_dim': hidden_dim // 2, 'num_heads': num_heads}

    model = DyanFHEGPIAN(
        input_dim=input_dim,
        seq_len=SEQ_LEN,
        maxvit_params=maxvit_params,
        feinn_params=feinn_params,
        dhgann_params=dhgann_params
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
    # Not shuffled: validation samples stay in chronological order, which the
    # directional-accuracy computation below relies on.
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=32)

    history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            y_pred, features = model(x_batch)
            pred_loss = criterion(y_pred, y_batch)
            physics_loss = model.feinn.compute_physics_loss(features, model.feinn(features))
            total_loss = pred_loss + lambda_physics * physics_loss
            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()

        # Validation
        model.eval()
        val_epoch_loss = 0
        val_targets, val_preds = [], []
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                y_val_pred, _ = model(x_batch)
                val_epoch_loss += criterion(y_val_pred, y_batch).item()

                # Keep the raw batches; direction is derived after the loop.
                val_targets.append(y_batch)
                val_preds.append(y_val_pred)

        avg_loss = epoch_loss / len(train_loader)
        avg_val_loss = val_epoch_loss / len(val_loader)

        # Calculate directional accuracy on the concatenated sequence.
        # Differencing inside each batch would silently drop the step from
        # the last sample of one batch to the first sample of the next.
        all_targets = torch.cat(val_targets)
        all_preds = torch.cat(val_preds)
        y_true_cls = (all_targets[1:] > all_targets[:-1]).cpu().numpy()
        y_pred_cls = (all_preds[1:] > all_preds[:-1]).cpu().numpy()
        acc = calculate_metrics(y_true_cls, y_pred_cls)['Accuracy']

        history['loss'].append(avg_loss)
        history['val_loss'].append(avg_val_loss)
        history['accuracy'].append(acc) # Simplified: reuses the validation accuracy
        history['val_accuracy'].append(acc)

        print(f"Epoch {epoch+1}/{epochs} -> Loss: {avg_loss:.6f}, Val Loss: {avg_val_loss:.6f}, Val Acc: {acc:.4f}")

    # Save the model
    torch.save(model.state_dict(), 'dyan_fheg_pian_model.pth')
    print("Final model saved to dyan_fheg_pian_model.pth")

    return model, history

if __name__ == '__main__':
    create_results_dir()

    # These names stay at module level on purpose: objective_function and
    # train_final_model read X_train / y_train / X_val / y_val as globals.
    X_train, y_train, X_val, y_val, scaler_target = prepare_data()

    # --- POA Optimization ---
    # Search space per hyperparameter; each entry looks like
    # (lower bound, upper bound, value kind) -- confirm against
    # PufferfishOptimizer.
    param_bounds = dict(
        lr=(1e-4, 1e-2, 'float'),
        embed_dim=(32, 128, 'int'),
        hidden_dim=(64, 256, 'int'),
        num_heads=(2, 8, 'int'),
        lambda_physics=(0.01, 0.5, 'float'),
    )

    # NOTE: every candidate trains a model from scratch, so the search is
    # expensive. Population and generation counts are kept small for speed.
    poa = PufferfishOptimizer(
        objective_function=objective_function,
        param_bounds=param_bounds,
        population_size=5,
        max_generations=3,
    )
    best_params, _ = poa.optimize()

    # --- Final Training ---
    final_model, history = train_final_model(best_params)

    # --- Final Evaluation ---
    plot_training_history(history)

    # Predict the whole validation set in one pass, without gradients.
    final_model.eval()
    with torch.no_grad():
        final_preds_scaled, _ = final_model(X_val)

    # Undo the target scaling so metrics are computed on the original scale.
    preds_np = final_preds_scaled.cpu().numpy()
    actual_np = y_val.cpu().numpy()
    final_preds = scaler_target.inverse_transform(preds_np)
    y_val_actual = scaler_target.inverse_transform(actual_np)

    calculate_metrics(y_val_actual, final_preds)