# Result analysis

## Imports and data loading

In [24]:
import xarray as xr

import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import numpy as np

import xskillscore as xs
import pandas as pd


import re
import skill_metrics as sm

import matplotlib.lines as mlines

from matplotlib.colors import TwoSlopeNorm


In [2]:
# Dataset holding the model forecast variables together with the matching
# '*_synop' observation variables; the statistics and map cells read from it.
results = xr.open_dataset('results_2022_ws_gcshifted_meso_gc.nc')
# Station dataset; the map cell passes it to plot_stations_with_metric, which
# reads the station names and the 'lat'/'lon' coordinates from it.
synop = xr.open_dataset('synop_filtered.nc')

### Statistic functions

In [4]:
def calculate_statistic(results, statistic_method, model_suffix='hres', variable=None, date_ranges=None, by_station=False):
    """
    Score one model's forecast variables against the matching '*_synop' observation variables.

    Parameters:
        results (xarray.Dataset): Dataset containing the forecast variables
            ('<variable>_<model_suffix>') and the observation variables ('<variable>_synop').
        statistic_method (str): Name of the statistic to compute (e.g. 'rmse', 'mae', 'mse',
            'pearson_r', 'spearman_r', 'bias').
        model_suffix (str): Suffix that identifies the model in the forecast variable names
            (default is 'hres').
        variable (str): Base name of a single variable to score. If None, the six default
            variables listed in the function body are scored.
        date_ranges (list of tuples): Optional (start_date, end_date) pairs. The data selected
            by each pair is concatenated along 'time' before the statistic is computed.
        by_station (bool): If True, reduce over 'time' only so the 'station' dimension is kept.
            If False (default), reduce over both 'time' and 'station'.

    Returns:
        xarray.Dataset: One data variable per forecast variable, holding the computed statistic.

    Raises:
        ValueError: If `statistic_method` is not one of the supported names.
    """
    # Accepted method name -> name of the xskillscore function that implements it.
    # ``None`` marks the bias, which is computed directly as mean(forecast - observation).
    # NOTE(review): every xskillscore function is invoked as
    # f(forecast, obs, dim=..., skipna=True); confirm that the probabilistic
    # scores (brier_score, crps_*, roc, rps, ...) accept this call signature.
    supported_methods = {
        'rmse': 'rmse',
        'mae': 'mae',
        'mse': 'mse',
        'pearson_r': 'pearson_r',
        'spearman_r': 'spearman_r',
        'bias': None,
        'mean_absolute_percentage_error': 'mape',
        'brier_score': 'brier_score',
        'threshold_brier_score': 'threshold_brier_score',
        'crps_gaussian': 'crps_gaussian',
        'crps_quadrature': 'crps_quadrature',
        'crps_ensemble': 'crps_ensemble',
        'rank_histogram': 'rank_histogram',
        'roc': 'roc',
        'reliability': 'reliability',
        'discrimination': 'discrimination',
        'rps': 'rps',
    }

    # Reject unknown methods before touching the data
    if statistic_method not in supported_methods:
        raise ValueError(f"Statistic method '{statistic_method}' is not supported. Choose from {list(supported_methods.keys())}.")

    # Keep only the requested periods, stitched together along the time axis
    selected = results
    if date_ranges:
        periods = []
        for start_date, end_date in date_ranges:
            periods.append(results.sel(time=slice(start_date, end_date)))
        selected = xr.concat(periods, dim='time')

    # Base names of the variables to score; forecast and observation names derive from them
    if variable:
        base_names = [variable]
    else:
        base_names = [
            '2m_temperature',
            'mean_sea_level_pressure',
            '10m_v_component_of_wind',
            '10m_u_component_of_wind',
            'total_precipitation_6hr',
            '10m_wind_speed',
        ]
    variable_pairs = [(f'{name}_{model_suffix}', f'{name}_synop') for name in base_names]

    # Dimensions that the statistic collapses
    dims = ['time'] if by_station else ['time', 'station']

    xs_name = supported_methods[statistic_method]

    def score(forecast, obs):
        # Bias has no xskillscore counterpart here: mean signed error
        if xs_name is None:
            return (forecast - obs).mean(dim=dims, skipna=True)
        return getattr(xs, xs_name)(forecast, obs, dim=dims, skipna=True)

    # Score every forecast/observation pair and collect the outcome per forecast variable
    scored = {}
    for forecast_var, obs_var in variable_pairs:
        scored[forecast_var] = score(selected[forecast_var], selected[obs_var])

    return xr.Dataset(scored)



def calculate_all_statistics(results_chunked, metrics, models, by_station=False):
    """
    Evaluate every requested metric for every requested model.

    Parameters:
    - results_chunked: The dataset handed unchanged to calculate_statistic for each combination.
    - metrics: Iterable of metric names (e.g., ['rmse', 'mae']).
    - models: Iterable of model suffixes (e.g., ['hres', 'gc', 'analysis', 'meso']).
    - by_station: Forwarded to calculate_statistic; if True the statistics are kept per station.

    Returns:
    - statistics: A dictionary keyed by '<metric>_<model>' (e.g., 'rmse_hres', 'mae_gc') whose
      values are the results returned by calculate_statistic for that combination.
    """
    # Metric-major ordering: all models for the first metric, then the next metric, ...
    return {
        f"{metric}_{model}": calculate_statistic(
            results_chunked,
            metric,
            model,
            by_station=by_station,
        )
        for metric in metrics
        for model in models
    }

def plot_statistic_results(statistics, error_metric, save_path=None):
    """
    Plot one error metric against forecast lead time, one subplot per variable and one line per model.

    Parameters:
        statistics (dict): Maps '<metric>_<model>' keys to xarray.Dataset objects. Every entry whose
            key starts with '<error_metric>_' is plotted. Each dataset must provide a
            'prediction_timedelta' coordinate and data variables named '<variable>_<model>'.
        error_metric (str): The error metric to plot (e.g., 'rmse', 'mse', 'mae').
        save_path (str, optional): The file path to save the plot. If None, the plot will be displayed.

    Returns:
        None

    Raises:
        ValueError: If no key in `statistics` matches the metric, or if the lead times of the
            matched datasets differ.
        TypeError: If a matched value does not look like an xarray.Dataset.
    """
    # Match the metric as a full key prefix so that e.g. 'mse' does not pick up 'mse2_*'
    key_prefix = error_metric + "_"
    datasets = [ds for key, ds in statistics.items() if key.startswith(key_prefix)]

    if not datasets:
        raise ValueError(f"No dataset matches the error metric: {error_metric}")

    # Check if each dataset is an xarray Dataset
    for position, ds in enumerate(datasets, start=1):
        if not hasattr(ds, 'data_vars'):
            raise TypeError(f"Dataset {position} is not an xarray.Dataset. Ensure you pass the dataset object, not its name as a string.")

    # Variables to plot and the unit shown on the vertical axis
    variables = {
        '2m_temperature': '°C',
        'mean_sea_level_pressure': 'hPa',
        '10m_v_component_of_wind': 'm/s',
        '10m_u_component_of_wind': 'm/s',
        'total_precipitation_6hr': 'mm',
        '10m_wind_speed': 'm/s'
    }

    # Convert lead times to numeric values (hours) for plotting
    lead_times_numeric_list = [ds['prediction_timedelta'].dt.total_seconds() / 3600 for ds in datasets]

    # All datasets share one x-axis, so their lead times must be identical
    reference_lead_times = lead_times_numeric_list[0]
    for lead_times in lead_times_numeric_list[1:]:
        if not lead_times.equals(reference_lead_times):
            raise ValueError("Lead times in the datasets do not match.")

    # Loop-invariant: the tick range is the same for every subplot
    max_lead_time = int(float(reference_lead_times.max()))

    def extract_base_variable_name(var_name):
        # Split '<base>_<suffix>' at the LAST underscore (the first group is greedy)
        match = re.match(r'^(.*)_(.+)$', var_name)
        return match.groups() if match else (var_name, None)

    # Per dataset: base variable name -> full data variable name
    variables_dict_list = [
        {extract_base_variable_name(var)[0]: var for var in ds.data_vars}
        for ds in datasets
    ]

    # Initialize the figure and subplots
    fig, axs = plt.subplots(len(variables), 1, figsize=(10, 3 * len(variables)))

    if len(variables) == 1:
        axs = [axs]

    line_styles = ['--', '-.', ':', '-']

    for ax, (base_var, unit) in zip(axs, variables.items()):
        # Minimum of every series drawn on THIS subplot; drives the y-axis clamp below.
        # (Previously the clamp looked at whatever dataset/variable the inner loop
        # touched last, which could belong to another subplot or not exist at all.)
        series_minima = []

        for j, (ds, variables_dict) in enumerate(zip(datasets, variables_dict_list)):
            var = variables_dict.get(base_var)
            if var is None:
                # This dataset does not contain the variable; nothing to draw
                continue

            _, model_suffix = extract_base_variable_name(var)
            label = model_suffix if len(datasets) > 1 else 'Observation'
            ax.plot(lead_times_numeric_list[j], ds[var], marker='o', linestyle=line_styles[j % len(line_styles)],
                    alpha=0.75, linewidth=1.25, label=label)
            series_minima.append(float(ds[var].min()))

        ax.set_title(f'{base_var.replace("_", " ").title()} vs Observation', fontsize=14)
        ax.set_xlabel('Lead Time (hours)', fontsize=12)
        ax.set_ylabel(f'{error_metric.upper()} ({unit})', fontsize=12)
        ax.grid(True, linestyle='--', linewidth=0.5)
        ax.set_xticks(range(0, max_lead_time + 1, 6))

        # A legend is only meaningful when several models were actually drawn
        if len(datasets) > 1 and series_minima:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.8), fontsize=10, frameon=False)

        # Anchor the axis at zero only when every plotted series is non-negative
        # (a NaN minimum compares False and therefore leaves the axis untouched)
        if series_minima and all(minimum >= 0 for minimum in series_minima):
            ax.set_ylim(bottom=0)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
        plt.close()
    else:
        plt.show()


def plot_statistics_for_variable(statistics, variable, metrics=('rmse', 'mae', 'mse', 'bias'), save_path=None):
    """
    Plots multiple statistical metrics (e.g., rmse, mae, mse, bias) for a single variable
    against forecast lead time, one subplot per metric and one line per model.

    Parameters:
        statistics (dict): A dictionary containing datasets with '<metric>_<model>' keys. Each dataset
            must provide a 'prediction_timedelta' coordinate; the series plotted for a model is the
            data variable named '<variable>_<model>'.
        variable (str): The variable to plot (e.g., '2m_temperature').
        metrics (sequence of str): Metrics to plot, one subplot each, in the given order.
        save_path (str, optional): The file path to save the plot. If None, the plot will be displayed.

    Returns:
        None

    Raises:
        ValueError: If `metrics` is empty.
    """
    variable_units = {
        '2m_temperature': '°C',
        'mean_sea_level_pressure': 'hPa',
        '10m_v_component_of_wind': 'm/s',
        '10m_u_component_of_wind': 'm/s',
        'total_precipitation_6hr': 'mm',
        '10m_wind_speed': 'm/s'
    }

    # Unknown variables are plotted without a unit
    unit = variable_units.get(variable, '')

    metrics = list(metrics)
    if not metrics:
        raise ValueError("At least one metric is required to create the plot.")

    fig, axs = plt.subplots(len(metrics), 1, figsize=(10, 3 * len(metrics)))

    if len(metrics) == 1:
        axs = [axs]

    line_styles = ['--', '-.', ':', '-']

    for ax, metric in zip(axs, metrics):
        # Keys are '<metric>_<model>'. Matching on the full prefix keeps metrics whose
        # names contain underscores (e.g. 'pearson_r') intact and recovers the whole
        # model suffix instead of only the second '_'-separated token.
        key_prefix = f'{metric}_'
        model_keys = [key for key in statistics if key.startswith(key_prefix)]

        # Collected per subplot so the axis decisions below reflect every plotted
        # series, not just the last one the loop happened to visit.
        series_minima = []
        max_lead_times = []

        # The style index restarts for every metric so a model keeps the same line
        # style across subplots (given a consistent model order in `statistics`).
        for j, key in enumerate(model_keys):
            ds = statistics[key]
            model_suffix = key[len(key_prefix):]
            variable_key = f'{variable}_{model_suffix}'

            if variable_key not in ds.data_vars:
                continue

            lead_times = ds['prediction_timedelta'].dt.total_seconds() / 3600
            var_data = ds[variable_key]
            ax.plot(lead_times, var_data, marker='o', linestyle=line_styles[j % len(line_styles)],
                    alpha=0.75, linewidth=1.25, label=model_suffix)
            series_minima.append(float(var_data.min()))
            max_lead_times.append(float(lead_times.max()))

        ax.set_title(f'{metric.upper()} for {variable.replace("_", " ").title()}', fontsize=14)
        ax.set_xlabel('Lead Time (hours)', fontsize=12)
        ax.set_ylabel(f'{metric.upper()} ({unit})', fontsize=12)
        ax.grid(True, linestyle='--', linewidth=0.5)

        # Ticks, y-limit and legend only make sense when something was drawn
        if not series_minima:
            continue

        ax.set_xticks(range(0, int(max(max_lead_times)) + 1, 6))

        # Anchor the axis at zero only when every plotted series is non-negative
        # (a NaN minimum compares False and therefore leaves the axis untouched)
        if all(minimum >= 0 for minimum in series_minima):
            ax.set_ylim(bottom=0)

        ax.legend(loc='center left', bbox_to_anchor=(1, 0.8), fontsize=10, frameon=False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
        plt.close()
    else:
        plt.show()


## Map creation and spatial analysis

### Map functions

In [21]:
# Default Natural Earth shapefiles used for the country outline and province borders
_COUNTRIES_SHAPEFILE = "/home/koenr/thesis_code/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shp"
_PROVINCES_SHAPEFILE = "/home/koenr/thesis_code/ne_10m_admin_1_states_provinces/ne_10m_admin_1_states_provinces.shp"


def plot_stations_with_metric(synop_data, metric_dict=None, metric_label=None, variable_name=None, marker_size=50, save_path=None,
                              countries_shapefile=_COUNTRIES_SHAPEFILE, provinces_shapefile=_PROVINCES_SHAPEFILE):
    """
    Function to plot weather stations and color them based on a given metric for 4 different models.

    One subplot is drawn per model ('gc', 'meso', 'hres', 'analysis'), all sharing the same color
    scale. Models without an entry in `metric_dict` show the stations in a single color.

    Parameters:
    - synop_data (xarray.Dataset): Dataset containing the station data with latitude ('lat'),
      longitude ('lon') and station name ('station').
    - metric_dict (dict, optional): Dictionary with keys as model names and values as per-station
      metrics. Each value must have one entry per station in `synop_data` (same order) or one entry
      per station that remains after filtering. Default is None (no metric is shown).
    - metric_label (str, optional): Name of the metric; also selects the colormap ('rmse'/'mae':
      sequential, 'bias': diverging and centered at zero). Default is None.
    - variable_name (str, optional): Name of the weather variable. Default is None.
    - marker_size (int, optional): Size of the station markers. Default is 50.
    - save_path (str, optional): Path to save the figure. Default is None (the figure is shown).
    - countries_shapefile (str, optional): Path to the country-level shapefile.
    - provinces_shapefile (str, optional): Path to the province-level shapefile.

    Raises:
    - ValueError: If a metric has neither one value per station nor one value per kept station.
    """

    # Define the models you want to plot
    model_suffix = ['gc', 'meso', 'hres', 'analysis']

    # None is the documented default; treat it as "no metrics available"
    metric_dict = metric_dict or {}

    # Extract station information
    # NOTE(review): the coordinates are read at the hard-coded time index 100 —
    # presumably a step where they are populated; confirm against the data.
    station_lats = synop_data['lat'].isel(time=100).values
    station_lons = synop_data['lon'].isel(time=100).values
    station_names = synop_data['station'].values

    # Define the latitude and longitude boundaries
    lat_min, lat_max = 40, 70
    lon_min, lon_max = -5, 16

    # Build ONE mask over the original station axis (valid coordinates AND inside the
    # bounding box) so that the metric values can be filtered exactly like the stations.
    valid_mask = ~np.isnan(station_lats) & ~np.isnan(station_lons)
    keep = valid_mask.copy()
    keep[valid_mask] = (
        (station_lats[valid_mask] >= lat_min) & (station_lats[valid_mask] <= lat_max) &
        (station_lons[valid_mask] >= lon_min) & (station_lons[valid_mask] <= lon_max)
    )
    station_lats = station_lats[keep]
    station_lons = station_lons[keep]
    station_names = station_names[keep]

    # Bring every available metric onto the kept stations. Positional alignment is
    # assumed: the metric's station order must match `synop_data` (TODO confirm).
    aligned_metrics = {}
    for model in model_suffix:
        metric = metric_dict.get(model)
        if metric is None:
            continue
        values = np.asarray(metric)
        if values.shape == keep.shape:
            values = values[keep]
        elif values.shape != station_names.shape:
            raise ValueError(
                f"Metric for model '{model}' has shape {values.shape}; expected one value per "
                f"station {keep.shape} or per plotted station {station_names.shape}."
            )
        aligned_metrics[model] = values

    # Create a GeoDataFrame for the stations
    gdf_stations = gpd.GeoDataFrame(
        {'Station': station_names},
        geometry=[Point(lon, lat) for lon, lat in zip(station_lons, station_lats)],
        crs="EPSG:4326"
    )

    # Load shapefiles for the Netherlands and provinces
    world = gpd.read_file(countries_shapefile)
    netherlands = world[(world.SOVEREIGNT == "Netherlands") & (world.CONTINENT == "Europe")].to_crs(epsg=28992)

    provinces = gpd.read_file(provinces_shapefile)
    netherlands_provinces = provinces[provinces['adm1_code'].str.startswith('NLD')].to_crs(epsg=28992)

    # Reproject stations to EPSG:28992
    gdf_stations = gdf_stations.to_crs(epsg=28992)

    # Calculate global min and max for consistent colorbar scaling across all models
    if aligned_metrics:
        all_metrics = np.concatenate(list(aligned_metrics.values()))
        vmin, vmax = np.nanmin(all_metrics), np.nanmax(all_metrics)
    else:
        vmin = vmax = None

    # Define colormap preferences based on error type
    norm = None
    if metric_label in ['rmse', 'mae']:
        color_map = 'magma'  # Choose a sequential colormap for error magnitude
    elif metric_label == 'bias':
        color_map = 'seismic'  # Diverging colormap centered at zero
        if vmin is not None:
            limit = max(abs(vmin), abs(vmax))
            # TwoSlopeNorm needs vmin < vcenter < vmax, so skip it for all-zero/NaN data
            if limit > 0:
                norm = TwoSlopeNorm(vmin=-limit, vcenter=0, vmax=limit)
    else:
        color_map = 'viridis'  # Default colormap

    # Create 4 subplots for the 4 models
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))
    axes = axes.flatten()

    for ax, model in zip(axes, model_suffix):
        # Use the corresponding (station-aligned) metric for each model
        metric = aligned_metrics.get(model)

        # Update GeoDataFrame with metric
        gdf_stations['Metric'] = metric if metric is not None else np.nan

        # Plot the Netherlands with province boundaries
        netherlands.plot(ax=ax, color='white', edgecolor='black', linewidth=0.5)
        netherlands_provinces.plot(ax=ax, color='none', edgecolor='blue', linewidth=0.25)

        # Plot the stations with color based on the metric
        if metric is not None:
            gdf_stations.plot(ax=ax, column='Metric', cmap=color_map, marker='^', markersize=marker_size, legend=True, vmin=vmin, vmax=vmax, norm=norm)
        else:
            gdf_stations.plot(ax=ax, color='green', marker='^', markersize=marker_size, label='Weather Stations')

        # Annotate the stations
        for _, row in gdf_stations.iterrows():
            ax.annotate(row["Station"], xy=(row.geometry.x, row.geometry.y),
                        xytext=(3, 3), textcoords="offset points", fontsize=6)

        # Set plot limits and remove axes
        padding = 15000
        bounds = gdf_stations.total_bounds
        ax.set_xlim(bounds[0] - padding, bounds[2] + padding)
        ax.set_ylim(bounds[1] - padding - 5000, bounds[3] + padding)
        ax.set_axis_off()

        # Set the title for each subplot
        ax.set_title(f"{model.upper()}")

    # Add a main title with padding above the plots
    plt.suptitle(f"{variable_name} {metric_label} per Weather Station", x=0.56, fontsize=16)

    # Adjust layout to prevent overlap and leave space for suptitle
    plt.tight_layout()

    # Save the figure if a path is provided
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()

        # Print the path where the plot is saved
        print(f"Plot saved to {save_path}")
    else:
        plt.show()


### Map creation

In [23]:
# Models, variables and metrics for which one station map is produced each
model_suffix = ['gc', 'meso', 'hres', 'analysis']

variables = [
                '2m_temperature',
                'mean_sea_level_pressure',
                '10m_wind_speed',
                '10m_u_component_of_wind',
                '10m_v_component_of_wind',
                'total_precipitation_6hr',
             ]

metrics = ['rmse', 'mae', 'bias']

for variable in variables:
    for metric in metrics:
        # Per-model, per-station values of the current metric for the current variable
        metric_values = {}

        for model in model_suffix:
            # Restrict the calculation to the variable being mapped: without
            # `variable=` every call would score all six variables and discard five.
            # The per-station statistic is then averaged over all lead times.
            stat_result = calculate_statistic(
                results, metric, model, variable=variable, by_station=True
            ).mean('prediction_timedelta').compute()

            # Store the result for the current model and variable
            metric_values[model] = stat_result[f'{variable}_{model}']

        # Define the save path for the plot
        save_path = f'map_plots_shifted/map_{variable}_{metric}.png'

        # One 2x2 figure (one panel per model) for this variable/metric combination
        plot_stations_with_metric(synop, metric_dict=metric_values, variable_name=variable, metric_label=metric, save_path=save_path)


Plot saved to map_plots_shifted/map_2m_temperature_rmse.png
Plot saved to map_plots_shifted/map_2m_temperature_mae.png
Plot saved to map_plots_shifted/map_2m_temperature_bias.png
Plot saved to map_plots_shifted/map_mean_sea_level_pressure_rmse.png
Plot saved to map_plots_shifted/map_mean_sea_level_pressure_mae.png
Plot saved to map_plots_shifted/map_mean_sea_level_pressure_bias.png
Plot saved to map_plots_shifted/map_10m_wind_speed_rmse.png
Plot saved to map_plots_shifted/map_10m_wind_speed_mae.png
Plot saved to map_plots_shifted/map_10m_wind_speed_bias.png
Plot saved to map_plots_shifted/map_10m_u_component_of_wind_rmse.png
Plot saved to map_plots_shifted/map_10m_u_component_of_wind_mae.png
Plot saved to map_plots_shifted/map_10m_u_component_of_wind_bias.png
Plot saved to map_plots_shifted/map_10m_v_component_of_wind_rmse.png
Plot saved to map_plots_shifted/map_10m_v_component_of_wind_mae.png
Plot saved to map_plots_shifted/map_10m_v_component_of_wind_bias.png
Plot saved to map_plots_