# Maps


## Variables

* [ ] Temperature, [$K$]
* [ ] Salinity, [$PSU$]
* [ ] Currents, [$m s^{-1}$]
* [ ] SSH, [$cm$]
* [ ] MLD, [$m$]
* [ ] Density, [$g cm^{-3}$]
* [ ] Geostrophic Currents, [$m s^{-1}$]
* [ ] Vorticity, [$s^{-1}$]
* [ ] Strain, [$s^{-1}$]
* [ ] Lagrangian Trajectories, [$km$]

## Models

**"Truth"**
* [ ] GLORYS - Reanalysis (Simulation + Observations)
* [ ] GLO12 - Analysis (Forecast + Observations)

**Models**
* Model (x3) (Forecast)
  * [ ] GLO12 (Physical)
  * [ ] GLONET (ML)
  * [ ] XiHe (ML)
  * [ ] WenHai (ML)



In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
%load_ext autoreload
%autoreload 2

## Plotting Code

In [None]:
from typing import Callable
import cmocean
from matplotlib import ticker
import pandas as pd
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
# Reset seaborn's styling first, then apply the "talk" plotting context with
# the font scale reduced to 0.7. The order matters: the reset must come first.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
from dataclasses import dataclass, field
from utils.types import (
    SeaSurfaceHeight,
    Temperature,
    Salinity,
    ZonalVelocity,
    MeridionalVelocity,
    GeostrophicZonalVelocity,
    GeostrophicMeridionalVelocity,
    MixedLayerDepth,
)


@dataclass
class RMSE:
    """Plot configuration describing a root-mean-squared-error field.

    The plotter classes in this file read ``units``, ``standard_name``,
    ``long_name`` and ``cmap`` from their ``config`` object, so all of them
    are provided here.
    """

    name: str = "rmse"
    standard_name: str = "root_mean_squared_error"
    long_name: str = "Root Mean Squared Error"
    cmap: str = "Reds"
    # An error carries the units of the variable it was computed from, so
    # there is no meaningful universal default; set it per variable.
    units: str = ""

@dataclass
class MAE:
    """Plot configuration describing a mean-absolute-error field.

    The plotter classes in this file read ``units``, ``standard_name``,
    ``long_name`` and ``cmap`` from their ``config`` object, so all of them
    are provided here.
    """

    name: str = "mae"
    standard_name: str = "mean_absolute_error"
    long_name: str = "Mean Absolute Error"
    cmap: str = "Reds"
    # An error carries the units of the variable it was computed from, so
    # there is no meaningful universal default; set it per variable.
    units: str = ""

@dataclass
class PlotterContour:
    """Draw a DataArray as a colour mesh on a Plate Carree map, optionally with contour lines.

    Attributes:
        da: field to plot; it must expose ``lon`` and ``lat`` coordinates.
        config: object providing ``units``, ``standard_name``, ``long_name``
            and ``cmap``; ``levels`` is optional.
    """

    da: xr.DataArray
    config: object

    def __post_init__(self):
        self.correct_labels()

    def correct_labels(self):
        """Write units and names onto the array and its lon/lat coordinates.

        The attributes are set in place, i.e. on the DataArray that was passed in.
        """
        self.da.attrs["units"] = self.config.units
        self.da.attrs["standard_name"] = self.config.standard_name
        self.da.attrs["long_name"] = self.config.long_name
        for coord, standard_name, long_name in (
            ("lon", "longitude", "Longitude"),
            ("lat", "latitude", "Latitude"),
        ):
            self.da[coord].attrs["units"] = "degrees"
            self.da[coord].attrs["standard_name"] = standard_name
            self.da[coord].attrs["long_name"] = long_name

    @staticmethod
    def _decorate_map(ax):
        """Add coastlines, labelled gridlines and a light-gray land mask to ``ax``."""
        ax.coastlines(linewidth=1)

        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=0.1, color='k', alpha=1,
                          linestyle='--')
        # Labels only on the bottom and left edges.
        gl.top_labels = False
        gl.right_labels = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 12}
        gl.ylabel_style = {'size': 12}

        ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '10m',
                                                    facecolor='lightgray'))

    def plot_figure(self, **kwargs):
        """Create the map figure.

        Keyword Args:
            vmin, vmax: colour limits; default to the data minimum / maximum.
            cmap: colormap; defaults to ``config.cmap``.
            levels: number of contour levels handed to ``MaxNLocator``;
                defaults to ``config.levels`` when the config defines it.
                A falsy value disables the contour overlay.
            cbar_kwargs: forwarded to the colour bar of the mesh plot.
            **kwargs: any remaining keywords are forwarded to ``pcolormesh``.

        Returns:
            The ``(fig, ax)`` pair of the new figure.
        """
        # Defaults are resolved lazily: the reductions over the whole array
        # and the lookup of ``config.levels`` only happen when the caller did
        # not supply the corresponding keyword.
        vmin = kwargs.pop("vmin") if "vmin" in kwargs else self.da.min().values
        vmax = kwargs.pop("vmax") if "vmax" in kwargs else self.da.max().values
        cmap = kwargs.pop("cmap", self.config.cmap)
        if "levels" in kwargs:
            levels = kwargs.pop("levels")
        else:
            levels = getattr(self.config, "levels", None)
        cbar_kwargs = kwargs.pop("cbar_kwargs", None)

        fig, ax = plt.subplots(figsize=(8, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        self.da.plot.pcolormesh(
            ax=ax, vmin=vmin, vmax=vmax, cmap=cmap,
            transform=ccrs.PlateCarree(),
            cbar_kwargs=cbar_kwargs,
            **kwargs,
        )
        if levels:
            # Contour levels span the data range, not the (possibly clipped)
            # colour range of the mesh.
            loc = ticker.MaxNLocator(levels)
            contour_levels = loc.tick_values(self.da.min().values, self.da.max().values)
            # NOTE(review): "black" is a colour name, not a registered
            # colormap; presumably this relies on xarray falling back to a
            # single-colour map -- confirm before switching to colors="black".
            self.da.plot.contour(
                ax=ax,
                alpha=0.5, linewidths=1, cmap="black",
                levels=contour_levels,
            )

        self._decorate_map(ax)

        # xarray sets a title of its own; the figures are saved without one.
        ax.set_title("")
        fig.tight_layout()
        fig.set(dpi=300)

        return fig, ax
        
@dataclass
class PlotterError:
    """Draw an error field as a colour mesh on a Plate Carree map.

    Attributes:
        da: field to plot; it must expose ``lon`` and ``lat`` coordinates.
        config: object providing ``units``, ``standard_name``, ``long_name``
            and ``cmap``.
    """

    da: xr.DataArray
    config: object

    def __post_init__(self):
        self.correct_labels()

    def correct_labels(self):
        """Write units and names onto the array and its lon/lat coordinates.

        The attributes are set in place, i.e. on the DataArray that was passed in.
        """
        self.da.attrs["units"] = self.config.units
        self.da.attrs["standard_name"] = self.config.standard_name
        self.da.attrs["long_name"] = self.config.long_name
        for coord, standard_name, long_name in (
            ("lon", "longitude", "Longitude"),
            ("lat", "latitude", "Latitude"),
        ):
            self.da[coord].attrs["units"] = "degrees"
            self.da[coord].attrs["standard_name"] = standard_name
            self.da[coord].attrs["long_name"] = long_name

    @staticmethod
    def _decorate_map(ax):
        """Add coastlines, labelled gridlines and a light-gray land mask to ``ax``."""
        ax.coastlines(linewidth=1)

        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=0.1, color='k', alpha=1,
                          linestyle='--')
        # Labels only on the bottom and left edges.
        gl.top_labels = False
        gl.right_labels = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 12}
        gl.ylabel_style = {'size': 12}

        ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '10m',
                                                    facecolor='lightgray'))

    def plot_figure(self, **kwargs):
        """Create the map figure.

        Keyword Args:
            vmin, vmax: colour limits; default to the data minimum / maximum.
            cmap: colormap; defaults to ``config.cmap``.
            cbar_kwargs: forwarded to the colour bar of the mesh plot.
            **kwargs: any remaining keywords are forwarded to ``pcolormesh``.

        Returns:
            The ``(fig, ax)`` pair of the new figure.
        """
        # The reductions over the whole array only run when the caller did
        # not supply the corresponding bound.
        vmin = kwargs.pop("vmin") if "vmin" in kwargs else self.da.min().values
        vmax = kwargs.pop("vmax") if "vmax" in kwargs else self.da.max().values
        cmap = kwargs.pop("cmap", self.config.cmap)
        cbar_kwargs = kwargs.pop("cbar_kwargs", None)

        fig, ax = plt.subplots(figsize=(8, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        self.da.plot.pcolormesh(
            ax=ax, vmin=vmin, vmax=vmax, cmap=cmap,
            transform=ccrs.PlateCarree(),
            cbar_kwargs=cbar_kwargs,
            **kwargs,
        )

        self._decorate_map(ax)

        # xarray sets a title of its own; the figures are saved without one.
        ax.set_title("")
        fig.tight_layout()
        fig.set(dpi=300)

        return fig, ax

In [None]:
from utils.types import generate_wednesdays_in_year
from utils.preprocessing import (
    latlon_deg2m,
    rectilinear_to_regular_grid,
    time_rescale,
    validate_latitude,
    validate_longitude,
    xr_cond_average,
    fuse_base_coords
)

# Date strings (YYYYMMDD), one per value yielded by
# generate_wednesdays_in_year(2024); they are used below to build the
# per-date dataset paths.
dates = [wednesday.strftime('%Y%m%d') for wednesday in generate_wednesdays_in_year(2024)]

len(dates)

def geoprocess_fn(ds, fill_value: int | float | None = None, lead_time: int = 0, idepth: int = 0):
    """Reduce one dataset to a single lead time (and depth level, when present).

    Args:
        ds: dataset to preprocess; it must have a ``time`` dimension.
        fill_value: accepted for interface compatibility; not used here.
        lead_time: positional index along ``time`` to keep.
        idepth: positional index along ``depth`` to keep; ignored when the
            dataset has no ``depth`` dimension.

    Returns:
        The selected slice with ``time`` re-attached as a length-one
        coordinate, passed through ``validate_latitude`` and then
        ``validate_longitude``.
    """
    # Surface-only products carry no depth dimension, so depth is only
    # indexed when it is actually there.
    indexers = {"time": lead_time}
    if "depth" in ds.dims:
        indexers = {"depth": idepth, "time": lead_time}
    selected = ds.isel(indexers)

    # The integer selection leaves a scalar time; store it as a 1-D coordinate.
    time_axis = np.atleast_1d(selected.time)
    selected = selected.assign_coords({"time": time_axis})

    # Coordinate validation: latitude first, then longitude.
    return validate_longitude(validate_latitude(selected))

## Sea Surface Height

In [None]:
from tqdm.auto import tqdm
from functools import partial
from utils.types import ForecastDataset, DiagnosticDataset
# Registries of the forecast and diagnostic model outputs.
forecast_config = ForecastDataset()
diagnostic_config = DiagnosticDataset()
# Alias kept for the original, misspelled variable name.
diagonstic_config = diagnostic_config

In [None]:
# Start from a fresh forecast registry, then take "glo12" out of it:
# it becomes the reference (`base_model`) and the models left in
# `forecast_config.models` are the ones compared against it.
# The two prints show the model count before and after the removal.
forecast_config = ForecastDataset()
print(len(forecast_config.models))
base_model = forecast_config.models.pop("glo12")
print(len(forecast_config.models))

In [None]:
# Absolute-error maps of sea surface height (`zos`): every remaining forecast
# model is compared with the reference `base_model` at several lead times.
lead_times = [0, 4, 9]
idepth = 0
itime = 0

ivariable = "zos"

# The reference paths depend on neither the model nor the lead time.
base_paths = [str(base_model.forecast_path(idate)) for idate in dates]


def _abs_error(imodel, lead_time):
    """Return the absolute error of `ivariable` and the compared model's dataset.

    Both forecasts are opened through `geoprocess_fn` bound to `lead_time`;
    the error is evaluated on the date found at index `itime` of the
    reference. sqrt(x**2) equals |x|, so this is a pointwise absolute error,
    not a time-averaged RMSE.
    """
    fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)
    compare_paths = [str(imodel.forecast_path(idate)) for idate in dates]

    base_model_results = xr.open_mfdataset(base_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")
    compare_model_results = xr.open_mfdataset(compare_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")

    idate = pd.to_datetime(base_model_results.isel(time=itime).time.values).strftime("%Y-%m-%d")
    diff = base_model_results.sel(time=idate)[ivariable] - compare_model_results.sel(time=idate)[ivariable]
    return np.sqrt(diff**2).compute(), compare_model_results


# Pass 1: one colour limit shared by all maps, taken as the largest
# 99.9th-percentile error over every (lead time, model) pair.
# New progress bars are built for every loop: a tqdm instance that was closed
# by a `with` block still iterates but no longer displays any progress.
max_values = []
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, _ = _abs_error(imodel, lead_time)
        max_values.append(diff.quantile(0.999).values)
vmax = np.max(max_values)

# Pass 2: draw and save one map per (lead time, model) pair.
path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)
cbar_kwargs = {
    "fraction": 0.02,
    "pad": 0.045,
    "orientation": "vertical",
}
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, compare_model_results = _abs_error(imodel, lead_time)
        diff = diff.squeeze().sortby("lon").sortby("lat")

        demo_ds = compare_model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        config = SeaSurfaceHeight()
        config.cmap = "Reds"
        error_plot = PlotterContour(da=diff, config=config)
        fig, ax = error_plot.plot_figure(cbar_kwargs=dict(cbar_kwargs), vmin=0.0, vmax=vmax, levels=None)
        fig.set(dpi=300)
        # The `l` label is the 1-based lead time of the map; `itime` is always
        # 0 here and would label every file as lead 1.
        save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{lead_time+1}_z{idepth:.2f}.png")
        fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.close(fig)

## Temperature

In [None]:
# Absolute-error maps of temperature (`thetao`): every remaining forecast
# model is compared with the reference `base_model` at several lead times.
lead_times = [0, 4, 9]
idepth = 0
itime = 0

ivariable = "thetao"

# The reference paths depend on neither the model nor the lead time.
base_paths = [str(base_model.forecast_path(idate)) for idate in dates]


def _abs_error(imodel, lead_time):
    """Return the absolute error of `ivariable` and the compared model's dataset.

    Both forecasts are opened through `geoprocess_fn` bound to `lead_time`;
    the error is evaluated on the date found at index `itime` of the
    reference. sqrt(x**2) equals |x|, so this is a pointwise absolute error,
    not a time-averaged RMSE.
    """
    fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)
    compare_paths = [str(imodel.forecast_path(idate)) for idate in dates]

    base_model_results = xr.open_mfdataset(base_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")
    compare_model_results = xr.open_mfdataset(compare_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")

    idate = pd.to_datetime(base_model_results.isel(time=itime).time.values).strftime("%Y-%m-%d")
    diff = base_model_results.sel(time=idate)[ivariable] - compare_model_results.sel(time=idate)[ivariable]
    return np.sqrt(diff**2).compute(), compare_model_results


# Pass 1: one colour limit shared by all maps, taken as the largest
# 99.9th-percentile error over every (lead time, model) pair.
# New progress bars are built for every loop: a tqdm instance that was closed
# by a `with` block still iterates but no longer displays any progress.
max_values = []
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, _ = _abs_error(imodel, lead_time)
        max_values.append(diff.quantile(0.999).values)
vmax = np.max(max_values)

# Pass 2: draw and save one map per (lead time, model) pair.
path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)
cbar_kwargs = {
    "fraction": 0.02,
    "pad": 0.045,
    "orientation": "vertical",
}
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, compare_model_results = _abs_error(imodel, lead_time)
        diff = diff.squeeze().sortby("lon").sortby("lat")

        demo_ds = compare_model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        config = Temperature()
        config.cmap = "Reds"
        error_plot = PlotterContour(da=diff, config=config)
        fig, ax = error_plot.plot_figure(cbar_kwargs=dict(cbar_kwargs), vmin=0.0, vmax=vmax, levels=None)
        fig.set(dpi=300)
        # The `l` label is the 1-based lead time of the map; `itime` is always
        # 0 here and would label every file as lead 1.
        save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{lead_time+1}_z{idepth:.2f}.png")
        fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.close(fig)

## Currents

In [None]:
# Absolute-error maps of the zonal current (`uo`): every remaining forecast
# model is compared with the reference `base_model` at several lead times.
lead_times = [0, 4, 9]
idepth = 0
itime = 0

ivariable = "uo"

# The reference paths depend on neither the model nor the lead time.
base_paths = [str(base_model.forecast_path(idate)) for idate in dates]


def _abs_error(imodel, lead_time):
    """Return the absolute error of `ivariable` and the compared model's dataset.

    Both forecasts are opened through `geoprocess_fn` bound to `lead_time`;
    the error is evaluated on the date found at index `itime` of the
    reference. sqrt(x**2) equals |x|, so this is a pointwise absolute error,
    not a time-averaged RMSE.
    """
    fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)
    compare_paths = [str(imodel.forecast_path(idate)) for idate in dates]

    base_model_results = xr.open_mfdataset(base_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")
    compare_model_results = xr.open_mfdataset(compare_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")

    idate = pd.to_datetime(base_model_results.isel(time=itime).time.values).strftime("%Y-%m-%d")
    diff = base_model_results.sel(time=idate)[ivariable] - compare_model_results.sel(time=idate)[ivariable]
    return np.sqrt(diff**2).compute(), compare_model_results


# Pass 1: one colour limit shared by all maps, taken as the largest
# 99.9th-percentile error over every (lead time, model) pair.
# New progress bars are built for every loop: a tqdm instance that was closed
# by a `with` block still iterates but no longer displays any progress.
max_values = []
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, _ = _abs_error(imodel, lead_time)
        max_values.append(diff.quantile(0.999).values)
vmax = np.max(max_values)

# Pass 2: draw and save one map per (lead time, model) pair.
path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)
cbar_kwargs = {
    "fraction": 0.02,
    "pad": 0.045,
    "orientation": "vertical",
}
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, compare_model_results = _abs_error(imodel, lead_time)
        diff = diff.squeeze().sortby("lon").sortby("lat")

        demo_ds = compare_model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        config = ZonalVelocity()
        config.cmap = "Reds"
        error_plot = PlotterContour(da=diff, config=config)
        fig, ax = error_plot.plot_figure(cbar_kwargs=dict(cbar_kwargs), vmin=0.0, vmax=vmax, levels=None)
        fig.set(dpi=300)
        # The `l` label is the 1-based lead time of the map; `itime` is always
        # 0 here and would label every file as lead 1.
        save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{lead_time+1}_z{idepth:.2f}.png")
        fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.close(fig)

In [None]:
# Absolute-error maps of the meridional current (`vo`): every remaining
# forecast model is compared with the reference `base_model` at several lead
# times.
lead_times = [0, 4, 9]
idepth = 0
itime = 0

ivariable = "vo"

# The reference paths depend on neither the model nor the lead time.
base_paths = [str(base_model.forecast_path(idate)) for idate in dates]


def _abs_error(imodel, lead_time):
    """Return the absolute error of `ivariable` and the compared model's dataset.

    Both forecasts are opened through `geoprocess_fn` bound to `lead_time`;
    the error is evaluated on the date found at index `itime` of the
    reference. sqrt(x**2) equals |x|, so this is a pointwise absolute error,
    not a time-averaged RMSE.
    """
    fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)
    compare_paths = [str(imodel.forecast_path(idate)) for idate in dates]

    base_model_results = xr.open_mfdataset(base_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")
    compare_model_results = xr.open_mfdataset(compare_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")

    idate = pd.to_datetime(base_model_results.isel(time=itime).time.values).strftime("%Y-%m-%d")
    diff = base_model_results.sel(time=idate)[ivariable] - compare_model_results.sel(time=idate)[ivariable]
    return np.sqrt(diff**2).compute(), compare_model_results


# Pass 1: one colour limit shared by all maps, taken as the largest
# 99.9th-percentile error over every (lead time, model) pair.
# New progress bars are built for every loop: a tqdm instance that was closed
# by a `with` block still iterates but no longer displays any progress.
max_values = []
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, _ = _abs_error(imodel, lead_time)
        max_values.append(diff.quantile(0.999).values)
vmax = np.max(max_values)

# Pass 2: draw and save one map per (lead time, model) pair.
path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)
cbar_kwargs = {
    "fraction": 0.02,
    "pad": 0.045,
    "orientation": "vertical",
}
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, compare_model_results = _abs_error(imodel, lead_time)
        diff = diff.squeeze().sortby("lon").sortby("lat")

        demo_ds = compare_model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # `vo` is the meridional component, so it gets the meridional
        # configuration (the zonal one was a copy-paste leftover).
        config = MeridionalVelocity()
        config.cmap = "Reds"
        error_plot = PlotterContour(da=diff, config=config)
        fig, ax = error_plot.plot_figure(cbar_kwargs=dict(cbar_kwargs), vmin=0.0, vmax=vmax, levels=None)
        fig.set(dpi=300)
        # The `l` label is the 1-based lead time of the map; `itime` is always
        # 0 here and would label every file as lead 1.
        save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{lead_time+1}_z{idepth:.2f}.png")
        fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.close(fig)

## Salinity

In [None]:
# Absolute-error maps of salinity (`so`): every remaining forecast model is
# compared with the reference `base_model` at several lead times.
lead_times = [0, 4, 9]
idepth = 0
itime = 0
ivariable = "so"

# The reference paths depend on neither the model nor the lead time.
base_paths = [str(base_model.forecast_path(idate)) for idate in dates]


def _abs_error(imodel, lead_time):
    """Return the absolute error of `ivariable` and the compared model's dataset.

    Both forecasts are opened through `geoprocess_fn` bound to `lead_time`;
    the error is evaluated on the date found at index `itime` of the
    reference. sqrt(x**2) equals |x|, so this is a pointwise absolute error,
    not a time-averaged RMSE.
    """
    fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)
    compare_paths = [str(imodel.forecast_path(idate)) for idate in dates]

    base_model_results = xr.open_mfdataset(base_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")
    compare_model_results = xr.open_mfdataset(compare_paths[:40], preprocess=fn, combine="by_coords", engine="zarr")

    idate = pd.to_datetime(base_model_results.isel(time=itime).time.values).strftime("%Y-%m-%d")
    diff = base_model_results.sel(time=idate)[ivariable] - compare_model_results.sel(time=idate)[ivariable]
    return np.sqrt(diff**2).compute(), compare_model_results


# Pass 1: one colour limit shared by all maps, taken as the largest
# 95th-percentile error over every (lead time, model) pair.
# New progress bars are built for every loop: a tqdm instance that was closed
# by a `with` block still iterates but no longer displays any progress.
max_values = []
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, _ = _abs_error(imodel, lead_time)
        max_values.append(diff.quantile(0.95).values)
vmax = np.max(max_values)

# Pass 2: draw and save one map per (lead time, model) pair.
path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)
cbar_kwargs = {
    "fraction": 0.02,
    "pad": 0.045,
    "orientation": "vertical",
}
for lead_time in tqdm(lead_times, leave=True):
    for imodel in tqdm(forecast_config.models.values(), leave=False):
        diff, compare_model_results = _abs_error(imodel, lead_time)
        diff = diff.squeeze().sortby("lon").sortby("lat")

        demo_ds = compare_model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # `so` is salinity, so it gets the salinity configuration (the zonal
        # velocity one was a copy-paste leftover).
        config = Salinity()
        config.cmap = "Reds"
        error_plot = PlotterContour(da=diff, config=config)
        fig, ax = error_plot.plot_figure(cbar_kwargs=dict(cbar_kwargs), vmin=0.0, vmax=vmax, levels=None)
        fig.set(dpi=300)
        # The `l` label is the 1-based lead time of the map; `itime` is always
        # 0 here and would label every file as lead 1.
        save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{lead_time+1}_z{idepth:.2f}.png")
        fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.close(fig)

## Diagnostic Variables

In [None]:
# Diagnostic registry: "glo12" is popped out and becomes the reference, the
# remaining entries are the models compared against it.
# NOTE: this rebinds `base_model`, which until here held the forecast reference.
# The two prints show the model count before and after the removal.
diagnostic_config = DiagnosticDataset()
print(len(diagnostic_config.models))
base_model = diagnostic_config.models.pop("glo12")
print(len(diagnostic_config.models))

## Geostrophic Zonal Velocity

In [None]:
# Colour limit for the zonal geostrophic velocity (`u_geo`) error maps.
# NOTE: pbar_models, fn, itime, ivariable, idepth and vmax defined in this
# cell are read again by the plotting cell that follows.
pbar_models = tqdm(diagnostic_config.models.values(), leave=False)
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "u_geo"


with pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths
        base_paths = [str(base_model.diagnostic_path(idate)) for idate in dates]
        compare_paths = [str(imodel.diagnostic_path(idate)) for idate in dates]

        # load model results (only the first five dates are opened)
        base_model_results = xr.open_mfdataset(base_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")
        compare_model_results = xr.open_mfdataset(compare_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

        # absolute difference at time index `itime`: sqrt(x**2) == |x|, so
        # this is a pointwise absolute error rather than a time-averaged RMSE
        diff = (base_model_results.isel(time=itime)[ivariable] - compare_model_results.isel(time=itime)[ivariable])
        diff = np.sqrt(diff**2).compute()
        
        

        # per-model 98th percentile of the error
        max_values.append(np.abs(diff).quantile(0.98).values)

# shared upper colour limit: the largest per-model percentile
vmax = np.max(max_values)
vmax

In [None]:
# Draw and save the absolute-error map of `ivariable` for every diagnostic
# model. fn, itime, ivariable, idepth and vmax come from the preceding cell.

# The reference dataset is the same for every compared model.
base_paths = [str(base_model.diagnostic_path(idate)) for idate in dates]
base_model_results = xr.open_mfdataset(base_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# A new progress bar is built here: `pbar_models` was closed by the `with`
# block of the preceding cell and would iterate without showing progress.
for imodel in tqdm(diagnostic_config.models.values(), leave=False):

    compare_paths = [str(imodel.diagnostic_path(idate)) for idate in dates]
    compare_model_results = xr.open_mfdataset(compare_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

    # sqrt(x**2) == |x|: pointwise absolute error at time index `itime`
    diff = (base_model_results.isel(time=itime)[ivariable] - compare_model_results.isel(time=itime)[ivariable])
    diff = np.sqrt(diff**2).compute()

    diff = diff.sortby("lon").sortby("lat")

    cbar_kwargs = {
        "fraction": 0.02,
        "pad": 0.045,
        "orientation": "vertical",
    }

    demo_ds = compare_model_results.isel(time=itime)[ivariable]
    date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
    config = GeostrophicZonalVelocity()
    config.cmap = "Reds"
    error_plot = PlotterContour(da=diff, config=config)
    fig, ax = error_plot.plot_figure(cbar_kwargs=cbar_kwargs, vmin=0.0, vmax=vmax, levels=None)
    fig.set(dpi=300)
    save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png")
    fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
    # Close this figure explicitly instead of whichever one is current.
    plt.close(fig)

## Geostrophic Meridional Velocity

In [None]:
# Colour limit for the meridional geostrophic velocity (`v_geo`) error maps.
# NOTE: pbar_models, fn, itime, ivariable, idepth and vmax defined in this
# cell are read again by the plotting cell that follows.
pbar_models = tqdm(diagnostic_config.models.values(), leave=False)
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "v_geo"


with pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths
        base_paths = [str(base_model.diagnostic_path(idate)) for idate in dates]
        compare_paths = [str(imodel.diagnostic_path(idate)) for idate in dates]

        # load model results (only the first five dates are opened)
        base_model_results = xr.open_mfdataset(base_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")
        compare_model_results = xr.open_mfdataset(compare_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

        # absolute difference at time index `itime`: sqrt(x**2) == |x|, so
        # this is a pointwise absolute error rather than a time-averaged RMSE
        diff = (base_model_results.isel(time=itime)[ivariable] - compare_model_results.isel(time=itime)[ivariable])
        diff = np.sqrt(diff**2).compute()
        
        

        # per-model 98th percentile of the error
        max_values.append(np.abs(diff).quantile(0.98).values)

# shared upper colour limit: the largest per-model percentile
vmax = np.max(max_values)
vmax

In [None]:
# Draw and save the absolute-error map of `ivariable` for every diagnostic
# model. fn, itime, ivariable, idepth and vmax come from the preceding cell.

# The reference dataset is the same for every compared model.
base_paths = [str(base_model.diagnostic_path(idate)) for idate in dates]
base_model_results = xr.open_mfdataset(base_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# A new progress bar is built here: `pbar_models` was closed by the `with`
# block of the preceding cell and would iterate without showing progress.
for imodel in tqdm(diagnostic_config.models.values(), leave=False):

    compare_paths = [str(imodel.diagnostic_path(idate)) for idate in dates]
    compare_model_results = xr.open_mfdataset(compare_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

    # sqrt(x**2) == |x|: pointwise absolute error at time index `itime`
    diff = (base_model_results.isel(time=itime)[ivariable] - compare_model_results.isel(time=itime)[ivariable])
    diff = np.sqrt(diff**2).compute()

    diff = diff.sortby("lon").sortby("lat")

    cbar_kwargs = {
        "fraction": 0.02,
        "pad": 0.045,
        "orientation": "vertical",
    }

    demo_ds = compare_model_results.isel(time=itime)[ivariable]
    date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
    config = GeostrophicMeridionalVelocity()
    config.cmap = "Reds"
    error_plot = PlotterContour(da=diff, config=config)
    fig, ax = error_plot.plot_figure(cbar_kwargs=cbar_kwargs, vmin=0.0, vmax=vmax, levels=None)
    fig.set(dpi=300)
    save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png")
    fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
    # Close this figure explicitly instead of whichever one is current.
    plt.close(fig)

## Mixed Layer Depth

In [None]:
# Colour limit for the mixed layer depth (`MLD`) error maps.
# NOTE: pbar_models, fn, itime, ivariable, idepth and vmax defined in this
# cell are read again by the plotting cell further down.
pbar_models = tqdm(diagnostic_config.models.values(), leave=False)
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "MLD"


with pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths
        base_paths = [str(base_model.diagnostic_path(idate)) for idate in dates]
        compare_paths = [str(imodel.diagnostic_path(idate)) for idate in dates]

        # load model results (only the first five dates are opened)
        base_model_results = xr.open_mfdataset(base_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")
        compare_model_results = xr.open_mfdataset(compare_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")

        # NOTE(review): fuse_base_coords is a project helper not shown here;
        # presumably it reconciles the coordinates of `ivariable` so the two
        # datasets can be subtracted -- confirm against utils.preprocessing.
        base_model_results = fuse_base_coords(base_model_results, ivariable)
        compare_model_results = fuse_base_coords(compare_model_results, ivariable)
        
        # absolute difference at time index `itime`: sqrt(x**2) == |x|, so
        # this is a pointwise absolute error rather than a time-averaged RMSE
        diff = (base_model_results.isel(time=itime)[ivariable] - compare_model_results.isel(time=itime)[ivariable])
        diff = np.sqrt(diff**2).compute()
        
        

        # per-model 99th percentile of the error
        max_values.append(np.abs(diff).quantile(0.99).values)

# shared upper colour limit: the largest per-model percentile
vmax = np.max(max_values)
vmax

In [None]:
# Inspection only: display the dataset of the last model processed by the
# loop in the previous cell.
compare_model_results

In [None]:
# Draw and save the absolute-error map of `ivariable` for every diagnostic
# model. fn, itime, ivariable, idepth and vmax come from the cell that
# computed the colour limit.

# The reference dataset is the same for every compared model.
base_paths = [str(base_model.diagnostic_path(idate)) for idate in dates]
base_model_results = xr.open_mfdataset(base_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")
base_model_results = fuse_base_coords(base_model_results, ivariable)

path = Path("OceanBenchFigures/maps/errors")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# A new progress bar is built here: `pbar_models` was closed by the `with`
# block of the cell that created it and would iterate without showing progress.
for imodel in tqdm(diagnostic_config.models.values(), leave=False):

    compare_paths = [str(imodel.diagnostic_path(idate)) for idate in dates]
    compare_model_results = xr.open_mfdataset(compare_paths[:5], preprocess=fn, combine="by_coords", engine="zarr")
    compare_model_results = fuse_base_coords(compare_model_results, ivariable)

    # sqrt(x**2) == |x|: pointwise absolute error at time index `itime`
    diff = (base_model_results.isel(time=itime)[ivariable] - compare_model_results.isel(time=itime)[ivariable])
    diff = np.sqrt(diff**2).compute()

    diff = diff.sortby("lon").sortby("lat")

    cbar_kwargs = {
        "fraction": 0.02,
        "pad": 0.045,
        "orientation": "vertical",
    }

    demo_ds = compare_model_results.isel(time=itime)[ivariable]
    date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
    config = MixedLayerDepth()
    config.cmap = "Reds"
    error_plot = PlotterContour(da=diff, config=config)
    fig, ax = error_plot.plot_figure(cbar_kwargs=cbar_kwargs, vmin=0.0, vmax=vmax, levels=None)
    fig.set(dpi=300)
    save_name = Path(f"maps_global_rmse_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png")
    fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
    # Close this figure explicitly instead of whichever one is current.
    plt.close(fig)