In [None]:
#!pip install hydrogapai --upgrade

In [1]:
import os
from datetime import datetime
import csv
import glob
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm  # Import tqdm for the progress bar
import gc
from hydrogapai.gap_prediction import (
    predict_station_gaps,
    plot_fully,
    plot_yearly,
)

#Change this folder path to your own
folder_path = r"<address_of_folder_with_csv_dataset(s)>"
all_metrics_folder = os.path.join(folder_path, 'all_metrics')
os.makedirs(all_metrics_folder, exist_ok=True)
# Define the model you want to use: lr  = linear regression = sdgRegressor;
#                                   knn = k nearest neighbours
#                                   svr = support vector regression
#                                   rf  = random forest
#                                   xgb = extreme gradient boost
#                                   lgb = light gradient boostt
model_type = "rf"
# Define the number of lag times for partial auto correlation ('obsdis' streamflow values)
num_pacf_lags=3
# Define the start of the lag time for cross correlation ('tp' precipitation values)
plag_start=1
# Define the number of lag times for cross correlation ('tp' precipitation values)
num_ccf_lags=30

# Define if you want to do hyper-parameter optimization
hyper_opt=False

# Get a list of all CSV files in the folder
station_files = glob.glob(os.path.join(folder_path, "*.csv"))

#Define start and end date of a plot if you need a specific range
plot_start_date = None
plot_end_date = None

headers = ["input_file_path", 
           "missing_rate",
           "min_gap_length", 
           "mean_gap_length", 
           "median_gap_length", 
           "max_gap_length",
           "std_gap_length", 
           "range_gap_length", 
           "gap_density", 
           "nr_gap_days", 
           "nr_gaps", 
           "min_value", 
           "mean_value", 
           "median_value", 
           "max_value", 
           "std_value", 
           "range_value", 
           "skew_value", 
           "kurtosis_value", 
           "Q_lags", 
           "Q_lags_Coefficients", 
           "P_lags", 
           "P_lags_Coefficients",
           "R2_score", 
           "RMSE", 
           "Mean Bias Error", 
           "MAE", 
           "Percentage Error (%)", 
           "Nash-Sutcliffe", 
           "Index of Agreement", 
           "Correlation Coefficient", 
           "KGE Overall", 
           "KGE Correlation", 
           "KGE Bias", 
           "KGE Variability"
]

# Create log CSV file at the start with headers for metrics of all datasets/csv files
current_datetime = datetime.now().strftime("%Y%m%d_%H%M")
output_file = os.path.join(all_metrics_folder, f"_all_metrics_{model_type}.csv")
log_file = os.path.join(all_metrics_folder, f"log_file_{model_type}_{current_datetime}.txt")

# Function to check if headers exist
def check_headers_exist(file_path, headers):
    if not os.path.exists(file_path):
        return False
    with open(file_path, 'r', newline='') as file:
        reader = csv.reader(file)
        existing_headers = next(reader, None)
        return existing_headers == headers

headers_exist = check_headers_exist(output_file, headers)# Check if headers exist
with open(output_file, mode='a', newline='') as file:
    writer = csv.DictWriter(file, fieldnames=headers)
    # Only write the headers if they don't exist
    if not headers_exist:
        writer.writeheader()

with tqdm(total=len(station_files), desc="Processing Files", unit="file") as pbar:
    for station_file in station_files:
        try:
            # Open log file and append the station file path as a new line
            with open(log_file, mode='a') as log:
                log.write(f"Processing file: {station_file}\n")
            
            # Extract the filename from the input file path
            filename = os.path.basename(station_file)
            filename_no_ext = filename.rsplit('.', 1)[0]
            results_folder = os.path.join(
                os.path.dirname(station_file), 
                f"results_{filename_no_ext}"
            )
            os.makedirs(results_folder, exist_ok=True)
            
            all_combined_dfs, val_full, metrics_gaps, real_predictions = predict_station_gaps(
                station_file, results_folder, model_type=model_type, hyper_opt=hyper_opt, num_pacf_lags=num_pacf_lags,
                plag_start=plag_start, num_ccf_lags=num_ccf_lags
            )

            # Save combined plot for the full dataset
            plot_fully(
                results_folder,
                val_full,
                real_predictions,
                filename,
                model_type,
                plot_start_date,
                plot_end_date,
            )

            # Save combined plot per year
            plot_yearly(results_folder, val_full, real_predictions, filename, model_type)

            # Save the dataset with the real gaps filled with predictions
            output_csv_filename = os.path.join(results_folder, f"pred_{filename}")
            val_full = val_full.drop(columns=["year"])
            val_full.to_csv(output_csv_filename, index=True)

            # Create CSV file at the start with headers
            output_file1 = os.path.join(results_folder, f'{filename}_metrics_{model_type}.csv')
            # Append the new metrics to the existing CSV filea
            with open(output_file1, mode='a', newline='') as file:
                writer = csv.DictWriter(file, fieldnames=headers)
                writer.writeheader()
                writer.writerow(metrics_gaps)
            #os.makedirs(results_folder, exist_ok=True)
            with open(output_file, mode='a', newline='') as file:
                writer = csv.DictWriter(file, fieldnames=headers)
                #writer.writeheader()
                writer.writerow(metrics_gaps)
        except Exception as e:
            # Log the error and continue with the next file
            with open(log_file, mode='a') as log:
                log.write(f"Error processing file: {station_file}\n")
                log.write(f"Error details: {str(e)}\n")
        gc.collect
        # Update the progress bar
        pbar.update(1)
        
print(f"Processing completed for {filename}")

Processing Files: 100%|██████████████████████████████████████████████████████████████| 49/49 [41:09<00:00, 50.40s/file]

Processing completed for station_851.0_cleaned.csv



