# Maps


## Variables

* [ ] Temperature, [$K$]
* [ ] Salinity, [$PSU$]
* [ ] Currents, [$m s^{-1}$]
* [ ] SSH, [$cm$]
* [ ] MLD, [$m$]
* [ ] Density, [$g cm^{-3}$]
* [ ] Geostrophic Currents, [$m s^{-1}$]
* [ ] Vorticity, [$s^{-1}$]
* [ ] Strain, [$s^{-1}$]
* [ ] Lagrangian Trajectories, [$km$]

## Models

**"Truth"**
* [ ] GLORYS - Reanalysis (Simulation + Observations)
* [ ] GLO12 - Analysis (Forecast + Observations)

**Models**
* Model (x3) (Forecast)
  * [ ] GLO12 (Physical)
  * [ ] GLONET (ML)
  * [ ] XiHe (ML)
  * [ ] WenHai (ML)



In [None]:
import autoroot
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

## Plotting Code

In [None]:
from typing import Callable
import cmocean
from matplotlib import ticker
import pandas as pd
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
sns.reset_defaults()  # restore default rc parameters before applying a context
sns.set_context("talk", font_scale=0.7)  # "talk" sizing with fonts scaled to 70%
from dataclasses import dataclass, field


@dataclass
class SSHConfig:
    name: str = "zos"
    standard_name: str = "sea_surface_height"
    long_name: str = "Sea Surface Height"
    units: str = "m"
    cmap: str = "viridis"
    levels: int = 5
    linestyles: str | Callable = lambda levels: np.where(levels >= 0, "-", "--")


@dataclass
class TemperatureConfig:
    """Plot settings for sea water potential temperature (variable ``thetao``)."""

    name: str = "thetao"
    standard_name: str = "temperature"
    long_name: str = "Sea Water Potential Temperature"
    # Degree Celsius. It is written as an escape so the source stays ASCII; the
    # previous literal had been mis-decoded into "Â°C".
    units: str = "\u00b0C"
    cmap: str = "RdBu_r"
    levels: int = 5
    linestyles: str = "-"


@dataclass
class ZonalVelConfig:
    """Plot settings for the zonal current (variable ``uo``).

    Base class of the other velocity configs in this notebook, which override
    only ``name``, ``standard_name`` and ``long_name``.
    """
    name: str = "uo"
    standard_name: str = "zonal_current"
    long_name: str = "Zonal Current"
    units: str = "m/s"
    # NOTE(review): annotated as ``str`` but the default is a cmocean colormap
    # object. ``speed`` is a sequential map, while the velocity maps are drawn
    # on a symmetric [-vmax, vmax] range; a diverging map may be intended.
    cmap: str = field(default_factory=lambda: cmocean.cm.speed) #"YlGnBu_r"
    levels: int = 5
    linestyles: str = "-"

@dataclass
class MeridionalVelConfig(ZonalVelConfig):
    """Plot settings for the meridional current (variable ``vo``).

    Units, colormap, levels and line style are inherited from ``ZonalVelConfig``.
    """
    name: str = "vo"
    standard_name: str = "meridional_current"
    long_name: str = "Meridional Current"

@dataclass
class GeoZonalVelConfig(ZonalVelConfig):
    """Plot settings for the geostrophic zonal velocity (variable ``u_geo``).

    Units, colormap, levels and line style are inherited from ``ZonalVelConfig``.
    """
    name: str = "u_geo"
    standard_name: str = "geostrophic_eastward_sea_water_velocity"
    long_name: str = "Geostrophic Zonal Velocity"

@dataclass
class GeoMeridionalVelConfig(ZonalVelConfig):
    """Plot settings for the geostrophic meridional velocity (variable ``v_geo``).

    Inherits directly from ``ZonalVelConfig`` (units, colormap, levels, line
    style) and overrides only the naming fields.
    """
    name: str = "v_geo"
    standard_name: str = "geostrophic_northward_sea_water_velocity"
    long_name: str = "Geostrophic Meridional Velocity"


@dataclass
class SalinityConfig:
    """Plot settings for sea water salinity (variable ``so``).

    The default colormap is cmocean's ``haline``, built per instance through a
    ``default_factory``.
    """
    name: str = "so"
    standard_name: str = "sea_water_salinity"
    long_name: str = "Sea Water Salinity"
    units: str = "PSU"
    # NOTE(review): annotated as ``str`` but the default is a colormap object.
    cmap: str = field(default_factory=lambda: cmocean.cm.haline) #"YlGnBu_r"
    levels: int = 5
    linestyles: str = "-"


@dataclass
class MLDConfig:
    """Plot settings for the mixed layer depth (variable ``MLD``).

    The default colormap is cmocean's ``deep``, built per instance through a
    ``default_factory``.
    """
    name: str = "MLD"
    standard_name: str = "mixed_layer_depth"
    long_name: str = "Mixed Layer Depth"
    units: str = "m"
    # NOTE(review): annotated as ``str`` but the default is a colormap object.
    cmap: str = field(default_factory=lambda: cmocean.cm.deep) #"YlGnBu_r"
    levels: int = 5
    linestyles: str = "-"
    

@dataclass
class PlotterContour:
    """Global map of a 2-D lat/lon field on a PlateCarree projection.

    Draws a ``pcolormesh`` of ``da`` with optional black contour lines,
    coastlines, a light-grey land mask and lat/lon gridline labels.

    Attributes:
        da: Field to plot; must have ``lon`` and ``lat`` coordinates.
            ``correct_labels`` (run on construction) writes attrs onto it and
            onto its coordinates in place, so the caller's DataArray is modified.
        config: A ``*Config`` dataclass providing ``units``, ``standard_name``,
            ``long_name``, ``cmap`` and ``levels``.
    """

    da: xr.DataArray
    config: object

    def __post_init__(self):
        self.correct_labels()

    def correct_labels(self):
        """Set ``units``/``standard_name``/``long_name`` attrs, in place.

        The field's attrs come from ``config``. The ``lon``/``lat``
        coordinates get fixed geographic labels.
        """
        self.da["lon"].attrs["units"] = "degrees"
        self.da["lat"].attrs["units"] = "degrees"
        self.da.attrs["units"] = self.config.units
        self.da.attrs["standard_name"] = self.config.standard_name
        self.da.attrs["long_name"] = self.config.long_name
        self.da["lon"].attrs["standard_name"] = "longitude"
        self.da["lat"].attrs["standard_name"] = "latitude"
        self.da["lat"].attrs["long_name"] = "Latitude"
        self.da["lon"].attrs["long_name"] = "Longitude"

    def plot_figure(self, **kwargs):
        """Draw the map and return ``(fig, ax)``.

        Keyword arguments consumed here (all optional):
            vmin, vmax: Colour limits. They default to the data min / max.
            cmap: Colormap. Defaults to ``config.cmap``.
            levels: Number of contour levels. Defaults to ``config.levels``; a
                falsy value (``None``, ``0``) disables the contour overlay.
            cbar_kwargs: Passed through to the colorbar.
        Any other keyword argument is forwarded to ``DataArray.plot.pcolormesh``.
        """
        fig, ax = plt.subplots(figsize=(8, 7), subplot_kw={"projection": ccrs.PlateCarree()})

        # Compute the data extrema lazily and at most once. On a dask-backed
        # array every ``.values`` triggers a full computation, and the extrema
        # are not needed at all when vmin/vmax are given and contours are off.
        extrema = None

        def data_min_max():
            nonlocal extrema
            if extrema is None:
                extrema = (self.da.min().values, self.da.max().values)
            return extrema

        vmin = kwargs.pop("vmin") if "vmin" in kwargs else data_min_max()[0]
        vmax = kwargs.pop("vmax") if "vmax" in kwargs else data_min_max()[1]
        cmap = kwargs.pop("cmap", self.config.cmap)
        levels = kwargs.pop("levels", self.config.levels)
        cbar_kwargs = kwargs.pop("cbar_kwargs", None)

        self.da.plot.pcolormesh(
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            transform=ccrs.PlateCarree(),
            cbar_kwargs=cbar_kwargs,
            **kwargs,
        )

        if levels:
            # Contour lines at "nice" values spanning the data range (not the
            # vmin/vmax colour range).
            data_lo, data_hi = data_min_max()
            contour_levels = ticker.MaxNLocator(levels).tick_values(data_lo, data_hi)
            self.da.plot.contour(
                ax=ax,
                alpha=0.5,
                linewidths=1,
                colors="black",
                levels=contour_levels,
            )

        ax.coastlines(linewidth=1)

        gl = ax.gridlines(
            crs=ccrs.PlateCarree(),
            draw_labels=True,
            linewidth=0.1,
            color="k",
            alpha=1,
            linestyle="--",
        )
        # Label only the bottom and left edges of the map.
        gl.top_labels = False
        gl.right_labels = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {"size": 12}
        gl.ylabel_style = {"size": 12}

        # Grey out land using the 1:10m Natural Earth land polygons.
        ax.add_feature(cfeature.NaturalEarthFeature("physical", "land", "10m", facecolor="lightgray"))

        # Clear the title xarray derives from scalar coordinates (e.g. the selected time).
        ax.set_title("")
        fig.tight_layout()
        fig.set(dpi=300)

        return fig, ax


In [None]:
# Quick check that the subclass overrides the inherited `standard_name`
# (the cell displays "meridional_current").
MeridionalVelConfig().standard_name

In [None]:
from src.types import generate_wednesdays_in_year
from src.preprocessing import (
    latlon_deg2m,
    rectilinear_to_regular_grid,
    time_rescale,
    validate_latitude,
    validate_longitude,
    xr_cond_average,
    fuse_base_coords
)

# Dates used to build the forecast/diagnostic paths: every Wednesday of 2024,
# formatted as "YYYYMMDD".
dates = [wednesday.strftime("%Y%m%d") for wednesday in generate_wednesdays_in_year(2024)]

def geoprocess_fn(ds, fill_value: int | float | None = None, lead_time: int = 0, idepth: int = 0):
    """Reduce one opened file to a single lead time (and depth level).

    Used as the ``preprocess`` hook of ``xr.open_mfdataset`` in this notebook.

    Args:
        ds: Dataset opened from one file.
        fill_value: Not used by this function; kept for interface compatibility.
        lead_time: Index along ``time`` to keep.
        idepth: Index along ``depth`` to keep. It is ignored when ``ds`` has no
            ``depth`` dimension.

    Returns:
        The reduced dataset. ``time`` is re-attached as a length-1 coordinate,
        and latitude/longitude are checked by the project validators.
    """
    # Only select along depth when that dimension exists. This replaces a
    # catch-all ValueError retry that could also hide unrelated errors.
    indexers = {"time": lead_time}
    if "depth" in ds.dims:
        indexers["depth"] = idepth
    ds = ds.isel(indexers)

    # Integer isel drops the time dimension. Re-attach it as a 1-D coordinate
    # so the per-file results can be combined by coordinates.
    ds = ds.assign_coords({"time": np.atleast_1d(ds.time)})

    # validate coordinates
    ds = validate_longitude(validate_latitude(ds))

    return ds

## Sea Surface Height

In [None]:
from tqdm.auto import tqdm
from functools import partial
from src.types import ForecastDataset, DiagnosticDataset
forecast_config = ForecastDataset()
diagnostic_config = DiagnosticDataset()
# Misspelled alias kept because later cells refer to `diagonstic_config`.
diagonstic_config = diagnostic_config

In [None]:
# Colour bound for the SSH maps: the 99th percentile of |zos| at the first time
# step, computed per forecast model; the next cell takes the largest.
# NOTE: pbar_models, fn, itime, idepth and max_values are notebook globals that
# the following SSH cells read.
pbar_models = tqdm(forecast_config.models.values(), leave=False)
idepth = 0
lead_time = 0
itime = 0
# geoprocess_fn with the depth/lead-time indices fixed; used as the
# open_mfdataset preprocess hook.
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "zos"

with pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths (one per date)
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results; the last date is left out of the open
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds
        max_values.append(np.abs(model_results.isel(time=itime)[ivariable]).quantile(0.99).values)



In [None]:
# Shared colour limit: the largest of the per-model quantiles.
vmax = np.asarray(max_values).max()
print(vmax)

In [None]:
# NOTE(review): scratch cell. Each plotting cell reassigns `date` before
# reading it, so this assignment has no downstream effect. The bare `dates`
# expression just displays the list of dates.
date = "2025"

dates

In [None]:
# One global SSH map per forecast model for the 2024-01-03 forecast, using the
# symmetric colour range [-vmax, vmax] from the bounds cell above.
# NOTE(review): the other variables are saved under OceanBenchFigures/maps/;
# confirm which output location is intended.
path = Path(autoroot.root.joinpath("figures/maps"))
path.mkdir(parents=True, exist_ok=True)

# Use a fresh progress bar; the one from the bounds cell is already closed.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load the forecast for a single date
        forecast_file = imodel.forecast_path("20240103")

        # load model results
        model_results = xr.open_mfdataset([forecast_file], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime).zos
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=SSHConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=-vmax, vmax=vmax, levels=None)
        save_name = f"maps_global_zos_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

## Temperature

In [None]:
# Colour range for the temperature maps: the 1st/99th percentiles of thetao at
# the first time step, per forecast model; the next cell pools them.
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "thetao"

# `models` is a mapping (the SSH cell iterates `.values()`). Iterating the
# mapping itself would yield its keys, which have no `forecast_path`.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    max_values = []
    min_values = []
    for imodel in pbar_models:

        # load paths (one per date; the last date is left out of the open below)
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds: both quantiles in a single pass over the lazy data
        low, high = model_results.isel(time=itime)[ivariable].quantile([0.01, 0.99]).values
        min_values.append(low)
        max_values.append(high)

In [None]:
# Shared colour limits: the widest range spanned by the per-model quantiles.
vmin, vmax = np.asarray(min_values).min(), np.asarray(max_values).max()
print(vmin, vmax)

In [None]:
# One global temperature map per forecast model, using the shared colour range
# (vmin, vmax) from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Iterate the model objects (`models.values()`, as in the SSH cell) with a
# fresh progress bar instead of re-using an already-closed one.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=TemperatureConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=vmin, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

## Currents

In [None]:
# Colour bound for the zonal-current maps: the 99.5th percentile of |uo| at the
# first time step, per forecast model (the maps use [-vmax, vmax]).
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "uo"

# `models` is a mapping (the SSH cell iterates `.values()`). Iterating the
# mapping itself would yield its keys, which have no `forecast_path`.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths (one per date; the last date is left out of the open below)
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds
        max_values.append(np.abs(model_results.isel(time=itime)[ivariable]).quantile(0.995).values)

In [None]:
# Shared colour limit: the largest of the per-model quantiles.
vmax = np.asarray(max_values).max()
print(vmax)

In [None]:
# One global zonal-current map per forecast model, using the symmetric colour
# range [-vmax, vmax] from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Iterate the model objects (`models.values()`, as in the SSH cell) with a
# fresh progress bar instead of re-using an already-closed one.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=ZonalVelConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=-vmax, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

In [None]:
# Colour bound for the meridional-current maps: the 99.5th percentile of |vo|
# at the first time step, per forecast model (the maps use [-vmax, vmax]).
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "vo"

# `models` is a mapping (the SSH cell iterates `.values()`). Iterating the
# mapping itself would yield its keys, which have no `forecast_path`.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths (one per date; the last date is left out of the open below)
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds
        max_values.append(np.abs(model_results.isel(time=itime)[ivariable]).quantile(0.995).values)

In [None]:
# Shared colour limit: the largest of the per-model quantiles.
vmax = np.asarray(max_values).max()
print(vmax)

In [None]:
# One global meridional-current map per forecast model, using the symmetric
# colour range [-vmax, vmax] from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Iterate the model objects (`models.values()`, as in the SSH cell) with a
# fresh progress bar instead of re-using an already-closed one.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths
        paths = [str(imodel.forecast_path(idate)) for idate in dates]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=MeridionalVelConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=-vmax, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

## Salinity

In [None]:
# Colour range for the salinity maps: the 1st/99th percentiles of so at the
# first time step, per forecast model; the next cell pools them.
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "so"

# `models` is a mapping (the SSH cell iterates `.values()`). Iterating the
# mapping itself would yield its keys, which have no `forecast_path`.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    max_values = []
    min_values = []
    for imodel in pbar_models:

        # load paths: dates[:2] followed by [:-1] below means only the first
        # date is actually opened
        paths = [str(imodel.forecast_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds: both quantiles in a single pass over the lazy data
        low, high = model_results.isel(time=itime)[ivariable].quantile([0.01, 0.99]).values
        min_values.append(low)
        max_values.append(high)

In [None]:
# Shared colour limits: the widest range spanned by the per-model quantiles.
vmin, vmax = np.asarray(min_values).min(), np.asarray(max_values).max()
print(vmin, vmax)

In [None]:
# One global salinity map per forecast model, using the shared colour range
# (vmin, vmax) from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Iterate the model objects (`models.values()`, as in the SSH cell) with a
# fresh progress bar instead of re-using an already-closed one.
with tqdm(forecast_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths (only the first date is opened, as in the bounds cell)
        paths = [str(imodel.forecast_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=SalinityConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=vmin, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

## Geostrophic Zonal Velocity

In [None]:
# Colour bound for the geostrophic zonal-velocity maps: the 98th percentile of
# |u_geo| at the first time step, per diagnostic model (the maps use [-vmax, vmax]).
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "u_geo"

# `models` is a mapping (the SSH cell iterates `.values()`). Iterating the
# mapping itself would yield its keys, which have no `diagnostic_path`.
with tqdm(diagonstic_config.models.values(), leave=False) as pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths: dates[:2] followed by [:-1] below means only the first
        # date is actually opened
        paths = [str(imodel.diagnostic_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds
        max_values.append(np.abs(model_results.isel(time=itime)[ivariable]).quantile(0.98).values)



In [None]:
# Shared colour limit: the largest of the per-model quantiles.
vmax = np.asarray(max_values).max()
print(vmax)

In [None]:
# One global geostrophic zonal-velocity map per diagnostic model, using the
# symmetric colour range [-vmax, vmax] from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Iterate the model objects (`models.values()`, as in the SSH cell) with a
# fresh progress bar instead of re-using an already-closed one.
with tqdm(diagonstic_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths (only the first date is opened, as in the bounds cell)
        paths = [str(imodel.diagnostic_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=GeoZonalVelConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=-vmax, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

## Geostrophic Meridional Velocity

In [None]:
# Colour bound for the geostrophic meridional-velocity maps: the 98th
# percentile of |v_geo| at the first time step, per diagnostic model.
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "v_geo"

# v_geo is a diagnostic, so iterate the diagnostic models (as the u_geo and MLD
# cells do), not the forecast models. `models` is a mapping (the SSH cell
# iterates `.values()`); iterating it directly would yield keys.
with tqdm(diagonstic_config.models.values(), leave=False) as pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths: dates[:2] followed by [:-1] below means only the first
        # date is actually opened
        paths = [str(imodel.diagnostic_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds
        max_values.append(np.abs(model_results.isel(time=itime)[ivariable]).quantile(0.98).values)




In [None]:
# Shared colour limit: the largest of the per-model quantiles.
vmax = np.asarray(max_values).max()
print(vmax)

In [None]:
# One global geostrophic meridional-velocity map per diagnostic model, using
# the symmetric colour range [-vmax, vmax] from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Diagnostic models (paths come from `diagnostic_path`), iterated through
# `models.values()` with a fresh progress bar.
with tqdm(diagonstic_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths (only the first date is opened, as in the bounds cell)
        paths = [str(imodel.diagnostic_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")
        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=GeoMeridionalVelConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=-vmax, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)

## Mixed Layer Depth

In [None]:
# Colour bound for the mixed-layer-depth maps: the 99th percentile of |MLD| at
# the first time step, per diagnostic model (the maps use [0, vmax]).
idepth = 0
lead_time = 0
itime = 0
fn = partial(geoprocess_fn, idepth=idepth, lead_time=lead_time)

ivariable = "MLD"

# `models` is a mapping (the SSH cell iterates `.values()`). Iterating the
# mapping itself would yield its keys, which have no `diagnostic_path`.
with tqdm(diagonstic_config.models.values(), leave=False) as pbar_models:
    max_values = []
    for imodel in pbar_models:

        # load paths: dates[:2] followed by [:-1] below means only the first
        # date is actually opened
        paths = [str(imodel.diagnostic_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        # load the bounds
        max_values.append(np.abs(model_results.isel(time=itime)[ivariable]).quantile(0.99).values)



In [None]:
# Shared colour limit: the largest of the per-model quantiles.
vmax = np.asarray(max_values).max()
print(vmax)

In [None]:
# One global mixed-layer-depth map per diagnostic model, using the colour range
# [0, vmax] from the bounds cells above.
path = Path("OceanBenchFigures/maps/")
# savefig does not create missing directories.
path.mkdir(parents=True, exist_ok=True)

# Iterate the model objects (`models.values()`, as in the SSH cell) with a
# fresh progress bar instead of re-using an already-closed one.
with tqdm(diagonstic_config.models.values(), leave=False) as pbar_models:
    for imodel in pbar_models:

        # load paths (only the first date is opened, as in the bounds cell)
        paths = [str(imodel.diagnostic_path(idate)) for idate in dates[:2]]

        # load model results
        model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

        model_results = fuse_base_coords(model_results, ivariable)

        model_results = model_results.sortby("time").sortby("lon").sortby("lat")

        cbar_kwargs = {
            "fraction": 0.02,
            "pad": 0.045,
            "orientation": "vertical",
        }

        demo_ds = model_results.isel(time=itime)[ivariable]
        date = pd.to_datetime(demo_ds.time.values).strftime("%Y-%m-%d")
        # PlotterContour applies the CF labels itself on construction.
        plotter = PlotterContour(da=demo_ds, config=MLDConfig())
        fig, ax = plotter.plot_figure(cbar_kwargs=cbar_kwargs, vmin=0.0, vmax=vmax, levels=None)
        save_name = f"maps_global_{ivariable}_m{imodel.name}_t{date}_l{itime+1}_z{idepth:.2f}.png"
        fig.savefig(path / save_name, bbox_inches="tight", transparent=True)
        # Close this specific figure so figures do not pile up across models.
        plt.close(fig)