In [None]:
import os
from datetime import datetime
import csv
import glob
import matplotlib.pyplot as plt
from PIL import Image
import glob
import numpy as np


from hydrogapai.gap_prediction import (
    predict_station_gaps,
    plot_fully,
    plot_yearly,
)

In [None]:
# Model identifier passed to `predict_station_gaps`; it is also read as a
# module-level global by `snow_station_outputs` (plot calls and metrics file name).
model_type = "xgb"

# `glob.glob` returns matches in arbitrary, filesystem-dependent order. Sort them
# so the processing order -- and therefore the meaning of `results[0]` in later
# cells -- is reproducible across machines and runs.
station_files = sorted(glob.glob("data/lib/station_*_cleaned.csv"))
if not station_files:
    print("No station files matched 'data/lib/station_*_cleaned.csv'")

# One timestamped output directory per notebook run.
current_datetime = datetime.now().strftime("%Y%m%d_%H%M")
results_base = f"./output/results_{model_type}_{current_datetime}"

# One entry per successfully processed station:
# {"station_file": <path>, "outputs": [results_folder, all_combined_dfs,
#                                       val_full, metrics_gaps, real_predictions]}
results = []

for station_file in station_files:
    # File name without its extension; `splitext` strips only the final ".csv"
    # (same derivation as the station sub-folder in `snow_station_outputs`).
    station_name = os.path.splitext(os.path.basename(station_file))[0]
    results_folder = os.path.join(results_base, station_name)

    try:
        all_combined_dfs, val_full, metrics_gaps, real_predictions = (
            predict_station_gaps(
                station_file,
                results_folder=results_folder,
                model_type=model_type,
                hyper_opt=False,
            )
        )
    except ValueError as ex:
        # Per the original author's note this is expected when a station has no
        # gaps; any other ValueError raised by the call is reported the same way.
        print(f"{ex} for {station_file}")
        continue

    results.append(
        {
            "station_file": station_file,
            "outputs": [
                results_folder,
                all_combined_dfs,
                val_full,
                metrics_gaps,
                real_predictions,
            ],
        }
    )

In [None]:
def _show_images(folder):
    """Display every ``*.png`` and then every ``*.jpg`` found directly in ``folder``.

    A folder that does not exist simply yields no matches, so nothing is shown.
    """
    # glob order is filesystem-dependent; sort each pattern for a stable display order.
    image_paths = sorted(glob.glob(os.path.join(folder, "*.png"))) + sorted(
        glob.glob(os.path.join(folder, "*.jpg"))
    )
    for image_path in image_paths:
        # Context manager releases the file handle once the image has been drawn.
        with Image.open(image_path) as img:
            plt.imshow(img)
        plt.show()


def snow_station_outputs(
    station_file,
    results_folder,
    all_combined_dfs,
    val_full,
    metrics_gaps,
    real_predictions,
):
    """Plot, export and display the gap-prediction results of one station.

    Steps, in order:

    1. Create ``results_folder`` if needed.
    2. Call ``plot_fully`` and ``plot_yearly`` for the station.
    3. Write ``val_full`` (without its ``"year"`` column) to
       ``<results_folder>/pred_<station file name>``.
    4. Append ``metrics_gaps`` as one row to
       ``<results_folder>/gaps_evaluation_metrics_<model_type>.csv``, writing the
       header row first when that file is new or empty.
    5. Display all ``*.png`` / ``*.jpg`` images found in ``results_folder`` and in
       the sub-folder named after the station file (without extension).

    Note: the function reads the module-level global ``model_type``; it must be
    defined before this function is called.

    Parameters
    ----------
    station_file : str
        Path of the station input file; only its base name is used here.
    results_folder : str
        Directory that receives the CSV outputs and is scanned for images.
    all_combined_dfs
        Not used by this function; kept so the call signature mirrors the
        values returned by ``predict_station_gaps``.
    val_full
        Object providing ``drop(columns=...)`` and ``to_csv(...)`` with a
        ``"year"`` column (presumably a pandas DataFrame -- confirm upstream).
    metrics_gaps : dict
        Row of metrics. Keys must be a subset of the header names listed below,
        otherwise ``csv.DictWriter`` raises ``ValueError``; missing keys are
        written as empty fields.
    real_predictions
        Passed through unchanged to ``plot_fully`` and ``plot_yearly``.

    Raises
    ------
    KeyError
        If ``val_full`` has no ``"year"`` column.
    ValueError
        If ``metrics_gaps`` contains a key that is not a known header.
    """
    # No date restriction is applied to the full-period plot.
    plot_start_date = None
    plot_end_date = None

    os.makedirs(results_folder, exist_ok=True)

    # Base name of the input file, reused in the plot calls and output file names.
    filename = os.path.basename(station_file)

    # Combined plot for the full dataset.
    plot_fully(
        results_folder,
        val_full,
        real_predictions,
        filename,
        model_type,
        plot_start_date,
        plot_end_date,
    )
    # Combined plot per year.
    plot_yearly(results_folder, val_full, real_predictions, filename, model_type)

    # Save the dataset with the real gaps filled with predictions. `drop` returns
    # a new object, so the caller's `val_full` is left untouched.
    output_csv_filename = os.path.join(results_folder, f"pred_{filename}")
    val_full = val_full.drop(columns=["year"])
    val_full.to_csv(output_csv_filename, index=True)

    headers = [
        "input_file_path",
        "missing_rate",
        "min_gap_length",
        "mean_gap_length",
        "median_gap_length",
        "max_gap_length",
        "std_gap_length",
        "range_gap_length",
        "gap_density",
        "nr_gap_days",
        "nr_gaps",
        "min_value",
        "mean_value",
        "median_value",
        "max_value",
        "std_value",
        "range_value",
        "skew_value",
        "kurtosis_value",
        "Q_lags",
        "Q_lags_Coefficients",
        "P_lags",
        "P_lags_Coefficients",
        "R2_score",
        "RMSE",
        "Mean Bias Error",
        "MAE",
        "Percentage Error (%)",
        "Nash-Sutcliffe",
        "Index of Agreement",
        "Correlation Coefficient",
        "KGE Overall",
        "KGE Correlation",
        "KGE Bias",
        "KGE Variability",
    ]

    output_file1 = os.path.join(
        results_folder, f"gaps_evaluation_metrics_{model_type}.csv"
    )
    # The file is opened in append mode, so the header must only be emitted when
    # the file does not exist yet (or is empty); otherwise rows would either have
    # no header at all or a header repeated on every call.
    needs_header = (
        not os.path.exists(output_file1) or os.path.getsize(output_file1) == 0
    )
    with open(output_file1, mode="a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=headers)
        if needs_header:
            writer.writeheader()
        writer.writerow(metrics_gaps)

    print(f"Processing completed for {filename}")

    # Sub-folder named after the station file (without extension); it is scanned
    # for images in addition to `results_folder` itself.
    station_dir = os.path.splitext(os.path.basename(station_file))[0]
    station_dir = os.path.join(results_folder, station_dir)

    _show_images(results_folder)
    _show_images(station_dir)

In [None]:
# Query the results of one station so they can be displayed in the next cell.
# `station_index` is the position in `results` (0 = first successfully processed
# station in this example); change it to inspect another station.
station_index = 0

# Stations for which prediction raised ValueError are not stored in `results`,
# so the list can be empty; fail with an explanatory message instead of a bare
# "list index out of range".
if not results:
    raise IndexError(
        "`results` is empty: no station was processed successfully, "
        "so there is nothing to display."
    )

selected_result = results[station_index]

# NOTE: despite its name, `station_name` holds the station file *path*.
station_name = selected_result["station_file"]
results_folder, all_combined_dfs, val_full, metrics_gaps, real_predictions = (
    selected_result["outputs"]
)

In [None]:
# Plot, export and display the outputs of the selected station.
# `station_name` holds the station file path, hence it is passed as `station_file`.
snow_station_outputs(
    station_file=station_name,
    results_folder=results_folder,
    all_combined_dfs=all_combined_dfs,
    val_full=val_full,
    metrics_gaps=metrics_gaps,
    real_predictions=real_predictions,
)