In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, RobustScaler

from src.DatabaseGetInfo import DatabaseAnalyzer


def wind_power_forecast(
    X_,
    y_,
    wf_id,
    start_date,
    end_date,
    top_k=5,
):
    """
    Split and scale the data, then rank features by correlation with the target.

    The data is split chronologically (``shuffle=False``; the last 20% of rows
    form the test set). Features and target are each scaled with a
    ``RobustScaler`` fitted on the training split only, so no test information
    leaks into the scaling. Each feature's Pearson correlation with the scaled
    training target is then plotted (sorted by absolute value) and the
    ``top_k`` strongest are printed. No model is trained here.

    Args:
        X_ (pd.DataFrame): Feature data. Must expose ``.columns`` because the
            column names label the plot and the printout.
        y_ (pd.Series): Target variable (wind power production). It is matched
            to ``X_`` by position, not by index.
        wf_id (int): Wind farm ID. Currently unused; kept for interface compatibility.
        start_date (str): Start date ('YYYY-MM-DD'). Currently unused.
        end_date (str): End date ('YYYY-MM-DD'). Currently unused.
        top_k (int, optional): Number of most-correlated features to print.
            Capped at the number of features. Defaults to 5.

    Returns:
        tuple: ``(X_train, X_test, y_train, y_test)`` as scaled NumPy arrays;
        the target arrays have shape ``(n_rows, 1)``.
    """
    X = np.asarray(X_)
    y = np.asarray(y_)

    # Chronological split: time-series rows must not be shuffled across the split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, shuffle=False
    )

    # Fit the scalers on the training split only and reuse them for the test split.
    x_scaler = RobustScaler()
    y_scaler = RobustScaler()
    X_train = x_scaler.fit_transform(X_train)
    X_test = x_scaler.transform(X_test)
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1))
    y_test = y_scaler.transform(y_test.reshape(-1, 1))

    # Rows of X_train.T are features and y_train.T is one extra row, so the last
    # row of the joint correlation matrix holds corr(target, feature_i).
    corr_matrix = np.corrcoef(X_train.T, y_train.T)
    feature_corr = corr_matrix[-1, :-1]
    n_features = feature_corr.shape[0]

    # Zero-variance (constant) features yield NaN correlations. np.argsort puts
    # NaN last, which would rank them as the *most* important features, so
    # treat them as uncorrelated for ranking and plotting.
    abs_corr = np.nan_to_num(np.abs(feature_corr), nan=0.0)
    sorted_indices = np.argsort(abs_corr)

    # Bar plot of |correlation|, ascending from left to right.
    positions = np.arange(n_features)
    plt.figure(figsize=(40, 10))
    plt.bar(positions, abs_corr[sorted_indices])
    plt.title("Correlation between features and wind power")
    plt.xticks(positions, X_.columns[sorted_indices])
    plt.show()

    # Print the strongest features (signed correlation), never more than exist.
    k = min(top_k, n_features)
    print(f"\nTop {k} features:")
    for idx in sorted_indices[::-1][:k]:
        print(f"{X_.columns[idx]}: {feature_corr[idx]}")

    return X_train, X_test, y_train, y_test


def exploratory_data_analysis(
    df, target_variable, bins=10, hist=False, kde=True
):
    """
    Plot the distribution of every non-target column of a DataFrame.

    Numerical columns are drawn as a histogram and/or a kernel density
    estimate, as selected by ``hist`` and ``kde``:

    - ``hist=True``: a density-normalised histogram, with a KDE overlay if
      ``kde`` is also True.
    - ``hist=False, kde=True``: a KDE curve only.
    - Both False: no numerical plots are drawn.

    Non-numerical columns are drawn as count plots. The target column is
    skipped in both cases.

    Args:
        df (pd.DataFrame): DataFrame containing the data.
        target_variable (str): Name of the target variable column (skipped).
        bins (int, optional): Number of histogram bins. Defaults to 10.
        hist (bool, optional): Whether to plot histograms. Defaults to False.
        kde (bool, optional): Whether to plot kernel density estimates. Defaults to True.
    """
    # Separate numerical and categorical columns
    num_cols = df.select_dtypes(include=np.number).columns.tolist()
    cat_cols = df.select_dtypes(exclude=np.number).columns.tolist()

    # Distribution plots for numerical features
    if hist or kde:
        for col in num_cols:
            if col == target_variable:
                continue  # Skip target variable
            plt.figure(figsize=(10, 5))
            if hist:
                # stat="density" makes the bar heights match the "Density"
                # y-label and puts any KDE overlay on the same scale.
                sns.histplot(df[col], bins=bins, kde=kde, stat="density")
            else:
                sns.kdeplot(df[col])
            plt.title(f"Density Plot of {col}")
            plt.xlabel(col)
            plt.ylabel("Density")
            plt.show()

    # Count plots for categorical features
    for col in cat_cols:
        if col == target_variable:
            continue  # Skip target variable
        plt.figure(figsize=(10, 5))
        sns.countplot(x=col, data=df)
        plt.title(f"Count Plot of {col}")
        plt.xlabel(col)
        plt.ylabel("Count")
        plt.show()


def plot_box(df, cols, col_y="power"):
    """
    Draw one vertical box plot of ``col_y`` per category of each column in ``cols``.

    Each distinct value of a column becomes its own box. Continuous columns
    (for example wind directions in degrees) therefore produce one box per
    distinct value and should be binned before being passed in.

    Args:
        df (pd.DataFrame): DataFrame containing the data.
        cols (list): Columns to use as the (categorical) x axis, one figure each.
        col_y (str, optional): Name of the target variable column. Defaults to "power".
    """
    # The style is global state; setting it once is equivalent to per iteration.
    sns.set_style("whitegrid")
    for col in cols:
        plt.figure(figsize=(15, 10))
        ax = sns.boxplot(x=col, y=col_y, data=df, orient="v")

        # Rotate the labels seaborn already placed instead of re-creating the
        # ticks. Counting df[col].unique() would include NaN, which boxplot
        # drops, and so could disagree with the number of boxes drawn.
        ax.tick_params(axis="x", labelrotation=90)

        ax.set_xlabel(col)
        ax.set_ylabel(col_y)
        plt.show()


def plot_density_2d(df, cols, col_y="power", kind="kde"):
    """
    Draw a seaborn joint plot of each column in ``cols`` against ``col_y``.

    Args:
        df (pd.DataFrame): DataFrame containing the data.
        cols (list): Numerical columns to plot on the x axis, one figure each.
        col_y (str, optional): Name of the target variable column. Defaults to "power".
        kind (str, optional): ``kind`` passed to ``sns.jointplot``. Defaults to "kde".
    """
    # The style is global state; setting it once is equivalent to per iteration.
    sns.set_style("whitegrid")
    for col in cols:
        grid = sns.jointplot(x=col, y=col_y, data=df, kind=kind)
        # Label the joint Axes explicitly. plt.xlabel/plt.ylabel act on the
        # "current" Axes, which after jointplot is one of the marginal axes.
        grid.set_axis_labels(col, col_y)
        plt.show()


def plot_feature_correlation(df):
    """
    Plot an annotated heatmap of the pairwise correlation of the numeric columns.

    Non-numeric columns are excluded before computing the correlation, because
    ``DataFrame.corr`` raises on them in recent pandas versions.

    Args:
        df (pd.DataFrame): DataFrame containing the data.
    """
    numeric = df.select_dtypes(include=np.number)
    plt.figure(figsize=(21, 13))
    sns.heatmap(numeric.corr(), cmap="Greens", annot=True)
    plt.title("Feature correlation")
    plt.show()


# Initialize the database analyzer.
# Database location per known machine. Set the WFD_DB_PATH environment
# variable to override it or to run on any other machine.
DB_PATHS = {
    "penguin": "/home/kutlay/WFD/wfd.db",
    "GLaDOS": "/home/wfd/WFD/wfd.db",  # remote server
}
hostname = os.uname()[1]
path = os.environ.get("WFD_DB_PATH") or DB_PATHS.get(hostname)
if path is None:
    # Fail with a clear message instead of a NameError on an unknown host.
    raise RuntimeError(
        f"No database path configured for host {hostname!r}; "
        "set WFD_DB_PATH or add the host to DB_PATHS."
    )
path = os.path.abspath(path)
analyzer = DatabaseAnalyzer.WindFarmAnalyzer(path)

# Wind farms to analyse. Use analyzer.get_wf_id_list() to process all of them.
wf_id_list = [9]
start_date = "2022-01-01"
end_date = "2023-12-31"

# ERA5 variables requested from the database.
ERA5_VARIABLES = [
    "temperature",
    "pressure",
    "dew_point",
    "surface_sensible_heat_flux",
    "mean_sea_level_pressure",
    "u10n",
    "v10n",
    "fg10",
    "i10fg",
    "surface_latent_heat_flux",
    "boundary_layer_dissipation",
    "boundary_layer_height",
    "charnock",
    "forecast_surface_roughness",
    "friction_velocity",
    "ws100",
    "wd100",
    "ws10",
    "wd10",
]
# Wind-direction columns get box plots; every other variable gets a 2D density plot.
DIRECTION_COLUMNS = ["wd100", "wd10"]
DENSITY_COLUMNS = [v for v in ERA5_VARIABLES if v not in DIRECTION_COLUMNS]

# Loop through wind farm IDs
for wf_id in wf_id_list:
    # Get ERA5 data (pass a copy so the shared list cannot be mutated by the callee)
    era5 = analyzer.get_era5_data(
        wf_id,
        start_date,
        end_date,
        variables_to_plot=list(ERA5_VARIABLES),
    )
    if era5 is None:
        print(f"No ERA5 data found for {wf_id}, skipping")
        continue
    # NOTE(review): zero-filling physical variables (temperature, pressure, ...)
    # distorts their distributions and correlations; confirm this is intended.
    X_ = era5.drop(columns=["timestamp"]).fillna(0)

    # Get production data. Check the result before subscripting it, otherwise a
    # None result raises TypeError before the guard can run.
    production = analyzer.get_wind_production_data(
        wf_id, start_date, end_date, CF=False, frequency="hourly"
    )
    if production is None or len(production) == 0:
        print(f"No production data found for {wf_id}, skipping")
        continue
    y_ = production["production"]

    # Feature ranking on chronologically split, robust-scaled data
    wind_power_forecast(
        X_,
        y_,
        wf_id,
        start_date,
        end_date,
    )

    # NOTE(review): pd.concat(axis=1) aligns X_ and y_ by index, while
    # wind_power_forecast pairs them by position. Confirm both frames share
    # the same (timestamp-ordered) index, or rows may be misaligned or NaN.
    data = pd.concat([X_, y_], axis=1)

    # Perform exploratory data analysis on the combined features and target
    exploratory_data_analysis(
        data,
        "production",
        bins=10,
        hist=False,
        kde=True,
    )

    # Call the other plotting functions
    plot_box(data, DIRECTION_COLUMNS, "production")
    plot_density_2d(data, DENSITY_COLUMNS, "production")
    plot_feature_correlation(data)