# Maps


## Variables

* [ ] Temperature, [$K$]
* [ ] Salinity, [$PSU$]
* [ ] Currents, [$m s^{-1}$]
* [ ] SSH, [$cm$]
* [ ] MLD, [$m$]
* [ ] Density, [$g cm^{-3}$]
* [ ] Geostrophic Currents, [$m s^{-1}$]
* [ ] Vorticity, [$s^{-1}$]
* [ ] Strain, [$s^{-1}$]
* [ ] Lagrangian Trajectories, [$km$]

## Models

**"Truth"**
* [ ] GLORYS - Reanalysis (Simulation + Observations)
* [ ] GLO12 - Analysis (Forecast + Observations)

**Models**
* Model (x4) (Forecast)
  * [ ] GLO12 (Physical)
  * [ ] GLONET (ML)
  * [ ] XiHe (ML)
  * [ ] WenHai (ML)



In [None]:
import autoroot
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from src.preprocessing import fuse_base_coords
import seaborn as sns
from functools import partial
from src.psd import PlotPSDIsotropic, PlotPSDSpaceTime
import xrft
import tqdm
from dask.diagnostics import ProgressBar

# Reset seaborn styling to its defaults, then use the "talk" context with
# fonts scaled to 70%.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# IPython magics: automatically reload edited modules (e.g. src.*) before
# running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

In [None]:
from src.types import GLONET, GLO12, WENHAI, XIHE, generate_wednesdays_in_year, SeaSurfaceHeight
from src.types import ForecastDataset, DiagnosticDataset

In [None]:
from src.preprocessing import (
    latlon_deg2m,
    rectilinear_to_regular_grid,
    time_rescale,
    validate_latitude,
    validate_longitude,
    xr_cond_average
)


def geoprocess_fn(da, fill_value: int | float | None = None):
    """Put a model field on a regular, metric, sorted grid.

    Steps, in order: validate latitude/longitude, linearly interpolate onto a
    regular grid, convert lat/lon from degrees to metres, rescale time to days,
    optionally replace missing values, then sort along time, lon and lat.

    Args:
        da: field to process (presumably an ``xr.DataArray`` with ``time``,
            ``lat`` and ``lon`` coordinates -- TODO confirm).
        fill_value: when not None, value used by ``fillna``.

    Returns:
        The processed field.
    """
    validated = validate_longitude(validate_latitude(da))
    gridded = rectilinear_to_regular_grid(validated, method="linear")
    # lat/lon: degrees -> metres
    metric = latlon_deg2m(gridded)
    # time: datetime -> days (1-day step)
    out = time_rescale(metric, t0=None, freq_dt=1, freq_unit="D")

    if fill_value is not None:
        out = out.fillna(fill_value)

    for coord in ("time", "lon", "lat"):
        out = out.sortby(coord)
    return out

def preprocess_single_leadtime_fn(ds, lead_time: int = 0, idepth: int = 0):
    """Reduce one store to a single lead time (and depth level).

    Used as the ``preprocess`` hook of ``xr.open_mfdataset`` in this notebook.

    Args:
        ds: dataset opened from a single forecast/diagnostic store.
        lead_time: positional index along ``time`` to keep.
        idepth: positional index along ``depth`` to keep; skipped when the
            combined selection raises ``ValueError`` (presumably no ``depth``
            dimension).

    Returns:
        The reduced dataset, with ``time`` kept as a length-1 coordinate.
    """
    try:
        selected = ds.isel(time=lead_time, depth=idepth)
    except ValueError:
        # fall back to selecting time only
        selected = ds.isel(time=lead_time)

    # integer isel turns ``time`` into a scalar; restore it as a 1-D coordinate
    one_d_time = np.atleast_1d(selected.time)
    return selected.assign_coords({"time": one_d_time})

def preprocess_all_leadtime_fn(ds, idepth: int = 0):
    """Turn one forecast store into a single sample indexed by its start time.

    The store's ``time`` axis is renamed to ``lead_time`` and a new length-1
    ``time`` dimension, holding the store's first time stamp, is added. This
    lets ``xr.open_mfdataset(..., combine="by_coords")`` stack forecasts along
    ``time``.

    Args:
        ds: dataset opened from a single forecast store.
        idepth: positional index along ``depth`` to keep; skipped when the
            selection raises ``ValueError`` (presumably no ``depth`` dimension).

    Returns:
        Dataset with a length-1 ``time`` dimension and a ``lead_time``
        dimension labelled ``1..N``, where N is the number of time steps in
        the store.
    """
    # select depth level, if any
    try:
        ds = ds.isel(depth=idepth)
    except ValueError:
        pass

    # start time = first time stamp of the store
    t0 = ds.isel(time=0).time

    # the store's own time axis becomes the lead-time axis
    ds = ds.rename({"time": "lead_time"})

    # new leading length-1 "time" dimension for the start time
    ds = ds.expand_dims("time")

    # label lead times 1..N, with N taken from the data rather than fixed at 10,
    # so stores with a different number of steps no longer fail here
    n_lead = ds.sizes["lead_time"]
    ds = ds.assign_coords({"time": np.atleast_1d(t0), "lead_time": np.arange(1, n_lead + 1)})
    return ds


def zonal_lon_psd(da: xr.DataArray) -> xr.DataArray:
    """Zonal (along-longitude) power spectrum, reduced over time and latitude.

    Each (time, lat) row is transformed along ``lon`` with ``xrft`` (linear
    detrend, Tukey window, window correction, true amplitude, truncation).
    The spectra are then reduced over ``time`` and ``lat`` with the project
    helper ``xr_cond_average`` and computed eagerly under a dask progress bar.

    Args:
        da: field with ``time``, ``lat`` and ``lon`` dims (e.g. the output of
            ``geoprocess_fn``).

    Returns:
        The reduced spectrum, computed in memory.
    """
    # one chunk per (time, lat) row; lon kept whole because dask FFTs need the
    # transformed axis in a single chunk
    chunking = {"time": 1, "lon": da.lon.shape[0], "lat": 1}
    spectrum_kwargs = dict(
        dim=["lon"],
        detrend="linear",
        window="tukey",
        nfactor=2,
        window_correction=True,
        true_amplitude=True,
        truncate=True,
    )
    with ProgressBar():
        spectrum = xrft.power_spectrum(da.chunk(chunking), **spectrum_kwargs)
        reduced = xr_cond_average(spectrum, dims=["time", "lat"], drop=True)
        return reduced.compute()

def space_time_psd(da: xr.DataArray) -> xr.DataArray:
    """Joint (longitude, time) power spectrum, reduced over latitude.

    For every latitude row the field is transformed along both ``lon`` and
    ``time`` with ``xrft`` (linear detrend, Tukey window, window correction,
    true amplitude, truncation). The spectra are then reduced over ``lat``
    with the project helper ``xr_cond_average`` and computed eagerly under a
    dask progress bar.

    Args:
        da: field with ``time``, ``lat`` and ``lon`` dims (e.g. the output of
            ``geoprocess_fn``).

    Returns:
        The reduced space-time spectrum, computed in memory.
    """
    # lon and time each kept in a single chunk (dask FFTs need the transformed
    # axes unchunked); one chunk per latitude row
    chunking = {
        "time": da.time.shape[0],
        "lon": da.lon.shape[0],
        "lat": 1,
    }
    with ProgressBar():
        spectrum = xrft.power_spectrum(
            da.chunk(chunking),
            dim=["lon", "time"],
            detrend="linear",
            window="tukey",
            nfactor=2,
            window_correction=True,
            true_amplitude=True,
            truncate=True,
        )
        return xr_cond_average(spectrum, dims=["lat"], drop=True).compute()

In [None]:
# Sampling dates: every Wednesday of 2024 as YYYYMMDD strings (the format
# passed to the model path helpers below).
dates = [wednesday.strftime("%Y%m%d") for wednesday in generate_wednesdays_in_year(2024)]

len(dates)

### Demo 

In [None]:
# Demo: run the zonal-PSD pipeline for a single model ("glo12").
# initialize forecast dataset config (swap in DiagnosticDataset() to use diagnostics)
forecast_config = ForecastDataset() # DiagnosticDataset() # 

demo_model = forecast_config.models["glo12"]


# one forecast store per sampling date
paths = [str(demo_model.forecast_path(idate)) for idate in dates]
idepth = 0
lead_time = 0  # not used by the all-lead-time preprocessing in this cell

# keep every lead time of each store; stores are stacked along a new "time" dim
fn = partial(preprocess_all_leadtime_fn, idepth=idepth)

model_results = xr.open_mfdataset(paths, preprocess=fn, combine="by_coords", engine="zarr")

In [None]:
# display the opened multi-file dataset
model_results

In [None]:
# variable config for "zos" (presumably sea-surface height -- confirm), used by the demo
demo_variable = forecast_config.variables['zos']

In [None]:

# select variable and apply the variable config's attribute corrections
da = model_results["zos"]
da = demo_variable.correct_real_attrs(da)

# geoprocess variables: regular metric grid, time in days, NaNs -> 0.0
fill_value = 0.0
da = geoprocess_fn(da, fill_value)
da

In [None]:
# cast to float32, then compute the zonal PSD (triggers the dask computation)
da = da.astype(np.float32)

psd_iso_signal = zonal_lon_psd(da)
psd_iso_signal

In [None]:


# Save the demo spectrum. The name is built from the demo objects, because
# ``ivariable`` / ``imodel`` are only defined by the loop cells further down.
psd_iso_signal.name = demo_model.name
path = Path("OceanBenchFigures/psd/zonal_lon/")
path.mkdir(parents=True, exist_ok=True)
save_name = Path(f"psd_zonallon_global_{demo_variable.name}_{demo_model.name}_z{idepth}.nc")
psd_iso_signal.to_netcdf(str(path.joinpath(save_name)))

## Result Datasets

### Zonal Spectrum

#### Forecast Dataset

In [None]:
# Compute and save the zonal (along-longitude) PSD of every variable for
# every forecast model, keeping all lead times of each forecast.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# loop invariants: depth level (also used in the output file name), NaN fill
# value and output directory
idepth = 0
fill_value = 0.0
path = Path("OceanBenchFigures/psd/zonal_lon/")
path.mkdir(parents=True, exist_ok=True)

# keep every lead time; pass idepth explicitly since it also labels the file
fn = partial(preprocess_all_leadtime_fn, idepth=idepth)

pbar_variable = tqdm.tqdm(forecast_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")
        results = {}
        # fresh model bar per variable: a tqdm bar closed by a previous
        # ``with`` block no longer displays progress when re-iterated
        pbar_model = tqdm.tqdm(forecast_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                paths = [str(imodel.forecast_path(idate)) for idate in dates]
                model_results = xr.open_mfdataset(paths, preprocess=fn, combine="by_coords", engine="zarr")

                # select variable and apply its attribute corrections
                da = model_results[ivariable.name]
                da = ivariable.correct_real_attrs(da)

                # regular metric grid, NaNs -> fill_value, float32
                da = geoprocess_fn(da, fill_value)
                da = da.astype(np.float32)

                psd_iso_signal = zonal_lon_psd(da)
                psd_iso_signal.name = imodel.name

                # NOTE(review): the zonal visualization cell reads
                # "..._t{lead_time+1}_z{idepth:.2f}.nc", but this name has no
                # lead-time tag and formats depth as "z{idepth}"; reconcile before plotting.
                save_name = Path(f"psd_zonallon_global_{ivariable.name}_{imodel.name}_z{idepth}.nc")
                psd_iso_signal.to_netcdf(str(path.joinpath(save_name)))

#### Diagnostic Dataset

In [None]:
# Compute and save the zonal PSD of every variable for every model of the
# diagnostic dataset, using a single lead time per date.
diagnostic_config = DiagnosticDataset()  # or ForecastDataset()

# loop invariants (previously re-assigned on every model iteration)
idepth = 0
lead_time = 0
fill_value = 0.0
path = Path("OceanBenchFigures/psd/zonal_lon/")
path.mkdir(parents=True, exist_ok=True)

# keep only ``lead_time`` (and depth ``idepth`` when present) of each store
fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)

pbar_variable = tqdm.tqdm(diagnostic_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")
        results = {}
        # fresh model bar per variable: a closed tqdm bar stops displaying when re-iterated
        pbar_model = tqdm.tqdm(diagnostic_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                paths = [str(imodel.diagnostic_path(idate)) for idate in dates]
                model_results = xr.open_mfdataset(paths, preprocess=fn, combine="by_coords", engine="zarr")

                # select variable and apply its attribute corrections
                da = model_results[ivariable.name]
                da = ivariable.correct_real_attrs(da)

                # regular metric grid, NaNs -> fill_value, float32
                da = geoprocess_fn(da, fill_value)
                da = da.astype(np.float32)

                psd_iso_signal = zonal_lon_psd(da)
                psd_iso_signal.name = imodel.name

                save_name = Path(f"psd_zonallon_global_{ivariable.name}_{imodel.name}_t{lead_time+1}.nc")
                psd_iso_signal.to_netcdf(str(path.joinpath(save_name)))

### Space-Time Spectrum

#### Forecast Dataset

In [None]:
# Compute and save the space-time (lon, time) PSD of every variable for every
# forecast model, using only the first lead time of each forecast.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# loop invariants (previously re-assigned on every model iteration)
idepth = 0
lead_time = 0
fill_value = 0.0
path = Path("OceanBenchFigures/psd/space_time/")
path.mkdir(parents=True, exist_ok=True)

# keep only ``lead_time`` (and depth ``idepth`` when present) of each store
fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)

pbar_variable = tqdm.tqdm(forecast_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")
        results = {}
        # fresh model bar per variable: a closed tqdm bar stops displaying when re-iterated
        pbar_model = tqdm.tqdm(forecast_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                paths = [str(imodel.forecast_path(idate)) for idate in dates]
                model_results = xr.open_mfdataset(paths, preprocess=fn, combine="by_coords", engine="zarr")

                # select variable and apply its attribute corrections
                da = model_results[ivariable.name]
                da = ivariable.correct_real_attrs(da)

                # regular metric grid, NaNs -> fill_value, float32
                da = geoprocess_fn(da, fill_value)
                da = da.astype(np.float32)

                psd_iso_signal = space_time_psd(da)
                psd_iso_signal.name = imodel.name

                save_name = Path(f"psd_spacetime_global_{ivariable.name}_{imodel.name}_t{lead_time+1}_z{idepth:.2f}.nc")
                psd_iso_signal.to_netcdf(str(path.joinpath(save_name)))

#### Diagnostic Dataset

In [None]:
# Compute and save the space-time (lon, time) PSD of every variable for every
# model of the diagnostic dataset, using a single lead time per date.
diagnostic_config = DiagnosticDataset()  # or ForecastDataset()

# loop invariants (previously re-assigned on every model iteration)
idepth = 0
lead_time = 0
fill_value = 0.0
path = Path("OceanBenchFigures/psd/space_time/")
path.mkdir(parents=True, exist_ok=True)

# keep only ``lead_time`` (and depth ``idepth`` when present) of each store
fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)

pbar_variable = tqdm.tqdm(diagnostic_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")
        results = {}
        # fresh model bar per variable: a closed tqdm bar stops displaying when re-iterated
        pbar_model = tqdm.tqdm(diagnostic_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                paths = [str(imodel.diagnostic_path(idate)) for idate in dates]
                model_results = xr.open_mfdataset(paths, preprocess=fn, combine="by_coords", engine="zarr")

                # select variable and apply its attribute corrections
                da = model_results[ivariable.name]
                da = ivariable.correct_real_attrs(da)

                # regular metric grid, NaNs -> fill_value, float32
                da = geoprocess_fn(da, fill_value)
                da = da.astype(np.float32)

                psd_iso_signal = space_time_psd(da)
                psd_iso_signal.name = imodel.name

                save_name = Path(f"psd_spacetime_global_{ivariable.name}_{imodel.name}_t{lead_time+1}.nc")
                psd_iso_signal.to_netcdf(str(path.joinpath(save_name)))

## Visualization

### Zonal Spectrum

#### Forecast Dataset

In [None]:
# Plot the saved zonal PSDs: one figure per variable, one curve per model.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# ``models`` is a mapping (indexed by name and iterated via ``.values()`` in the
# compute cells), so iterate its values: ``models[0]`` looks up key 0, and
# iterating the mapping itself yields keys rather than model configs.
idepth = 0
lead_time = 0
data_path = Path("/home/onyxia/work/OceanBenchFigures/psd/zonal_lon/")
figure_path = Path("/home/onyxia/work/OceanBenchFigures/psd/zonal_lon/figures/")
figure_path.mkdir(parents=True, exist_ok=True)

pbar_variable = tqdm.tqdm(forecast_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")
        psd_iso_plot = PlotPSDIsotropic()
        psd_iso_plot.init_fig(figsize=(8, 6))
        results = {}
        pbar_model = tqdm.tqdm(forecast_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                save_name = Path(f"psd_zonallon_global_{ivariable.name}_{imodel.name}_t{lead_time+1}_z{idepth:.2f}.nc")
                # load into memory so the netCDF file handle is released
                with xr.open_dataset(str(data_path.joinpath(save_name)), engine="netcdf4") as ds:
                    model_results = ds.load()

                psd_iso_plot.plot_wavelength(
                    model_results[imodel.name],
                    freq_scale=1e3,
                    units="km",
                    label=imodel.name.upper(),
                    color=imodel.color,
                )

        # depth label from the last model's file; 0 when it has no depth coordinate
        try:
            depth = float(model_results.depth.values)
        except AttributeError:
            depth = 0
        psd_iso_plot.ax.invert_xaxis()
        psd_iso_plot.fig.set(dpi=300)
        save_name = Path(f"psd_zonallon_global_{ivariable.name}_t{lead_time+1}_z{depth:.2f}.png")
        psd_iso_plot.fig.savefig(figure_path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.show()

#### Diagnostic Dataset

In [None]:
# Plot the saved zonal PSDs of the diagnostic dataset: one figure per variable,
# one curve per model.
diagnostic_config = DiagnosticDataset()  # or ForecastDataset()

# ``models`` is a mapping: iterate its values (model configs), not the mapping
# itself. The unused ``paths`` / ``fn`` (which referenced an undefined
# ``idepth``) and ``model_config`` / ``variable_config`` were removed.
lead_time = 0
data_path = Path("/home/onyxia/work/OceanBenchFigures/psd/zonal_lon/")
figure_path = Path("/home/onyxia/work/OceanBenchFigures/psd/zonal_lon/figures/")
figure_path.mkdir(parents=True, exist_ok=True)

pbar_variable = tqdm.tqdm(diagnostic_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")
        psd_iso_plot = PlotPSDIsotropic()
        psd_iso_plot.init_fig(figsize=(8, 6))
        results = {}
        pbar_model = tqdm.tqdm(diagnostic_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                save_name = Path(f"psd_zonallon_global_{ivariable.name}_{imodel.name}_t{lead_time+1}.nc")
                # load into memory so the netCDF file handle is released
                with xr.open_dataset(str(data_path.joinpath(save_name)), engine="netcdf4") as ds:
                    model_results = ds.load()

                psd_iso_plot.plot_wavelength(
                    model_results[imodel.name],
                    freq_scale=1e3,
                    units="km",
                    label=imodel.name.upper(),
                    color=imodel.color,
                )

        psd_iso_plot.ax.invert_xaxis()
        psd_iso_plot.fig.set(dpi=300)
        save_name = Path(f"psd_zonallon_global_{ivariable.name}_t{lead_time+1}.png")
        psd_iso_plot.fig.savefig(figure_path.joinpath(save_name), bbox_inches='tight', transparent=True)
        plt.show()

### Space-Time Spectrum

#### Forecast Dataset

In [None]:
# Show the reciprocals of the shared colour limits.
# NOTE(review): vmin/vmax are only defined by the space-time visualization cell
# below; in a top-to-bottom run this cell raises NameError.
1/vmin, 1/vmax

In [None]:
# Plot the saved space-time PSDs of the forecast dataset: one figure per
# (model, variable).
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# ``models`` is a mapping: iterate its values (model configs), not the mapping itself
idepth = 0
lead_time = 0
data_path = Path("/home/onyxia/work/OceanBenchFigures/psd/space_time/")
figure_path = Path("/home/onyxia/work/OceanBenchFigures/psd/space_time/figures/")
figure_path.mkdir(parents=True, exist_ok=True)

pbar_variable = tqdm.tqdm(forecast_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")

        results = {}
        # load every model's spectrum once; reused for colour limits and plots
        spectra = {}
        pbar_model = tqdm.tqdm(forecast_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                save_name = Path(f"psd_spacetime_global_{ivariable.name}_{imodel.name}_t{lead_time+1}_z{idepth:.2f}.nc")
                with xr.open_dataset(str(data_path.joinpath(save_name)), engine="netcdf4") as ds:
                    spectra[imodel.name] = ds[imodel.name].load()

        # shared colour limits across models (0.1% / 99.9% quantiles);
        # currently not passed to the plot (see the commented vmin/vmax below)
        vmin = np.min([spectrum.quantile(0.001).values for spectrum in spectra.values()])
        vmax = np.max([spectrum.quantile(0.999).values for spectrum in spectra.values()])

        for imodel in tqdm.tqdm(forecast_config.models.values(), leave=False):
            spectrum = spectra[imodel.name]
            psd_st_plot = PlotPSDSpaceTime()
            psd_st_plot.init_fig(figsize=(8, 6))
            psd_st_plot.plot_wavelength(
                spectrum,
                space_scale=1e3,
                space_units="km",
                time_units="days",
                psd_units=f"{ivariable.name.upper()}",
                # vmin=vmin, vmax=vmax
            )
            psd_st_plot.fig.set(dpi=300)

            # depth label from this spectrum (``depth`` was never set in this
            # cell and leaked in from another one); 0 without a depth coordinate
            try:
                depth = float(spectrum.depth.values)
            except AttributeError:
                depth = 0
            save_name = Path(f"psd_spacetime_global_{imodel.name}_{ivariable.name}_t{lead_time+1}_z{depth:.2f}.png")
            psd_st_plot.fig.savefig(figure_path.joinpath(save_name), bbox_inches='tight', transparent=True)
            plt.close(psd_st_plot.fig)

#### Diagnostic Dataset

In [None]:
# Plot the saved space-time PSDs of the diagnostic dataset: one figure per
# (model, variable).
diagnostic_config = DiagnosticDataset()  # or ForecastDataset()

# ``models`` is a mapping: iterate its values (model configs), not the mapping itself
lead_time = 0
data_path = Path("/home/onyxia/work/OceanBenchFigures/psd/space_time/")
figure_path = Path("/home/onyxia/work/OceanBenchFigures/psd/space_time/figures/")
figure_path.mkdir(parents=True, exist_ok=True)

pbar_variable = tqdm.tqdm(diagnostic_config.variables)

with pbar_variable:
    for ivariable in pbar_variable:
        pbar_variable.set_description(f"Variable: {ivariable.long_name}")

        results = {}
        # load every model's spectrum once; reused for colour limits and plots
        spectra = {}
        pbar_model = tqdm.tqdm(diagnostic_config.models.values(), leave=False)
        with pbar_model:
            for imodel in pbar_model:
                pbar_model.set_description(f"Model: {imodel.name}")
                save_name = Path(f"psd_spacetime_global_{ivariable.name}_{imodel.name}_t{lead_time+1}.nc")
                with xr.open_dataset(str(data_path.joinpath(save_name)), engine="netcdf4") as ds:
                    spectra[imodel.name] = ds[imodel.name].load()

        # shared colour limits across models (0.1% / 99.9% quantiles);
        # currently not passed to the plot (see the commented vmin/vmax below)
        vmin = np.min([spectrum.quantile(0.001).values for spectrum in spectra.values()])
        vmax = np.max([spectrum.quantile(0.999).values for spectrum in spectra.values()])

        for imodel in tqdm.tqdm(diagnostic_config.models.values(), leave=False):
            psd_st_plot = PlotPSDSpaceTime()
            psd_st_plot.init_fig(figsize=(8, 6))
            psd_st_plot.plot_wavelength(
                spectra[imodel.name],
                space_scale=1e3,
                space_units="km",
                time_units="days",
                psd_units=f"{ivariable.name.upper()}",
                # vmin=vmin, vmax=vmax
            )
            psd_st_plot.fig.set(dpi=300)

            save_name = Path(f"psd_spacetime_global_{imodel.name}_{ivariable.name}_t{lead_time+1}.png")
            psd_st_plot.fig.savefig(figure_path.joinpath(save_name), bbox_inches='tight', transparent=True)
            plt.close(psd_st_plot.fig)

In [None]:
# Example of a saved zonal-PSD file (the bare, unquoted path was a SyntaxError)
Path("OceanBenchFigures/psd/zonal_lon/psd_zonallon_global_zos_glonet_t1_z0.00.nc")

In [None]:
# display ``model_results`` as left by whichever cell ran last
model_results

In [None]:
# Pick one forecast model / variable for the single-model cells below.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# ``models`` is a mapping (indexed by name elsewhere), so take the 4th model by
# position from its values; ``models[3]`` would look up the key 3
model_config = list(forecast_config.models.values())[3]
# NOTE(review): ``variables`` is indexed by name (``["zos"]``) elsewhere in this
# notebook; confirm it also supports positional indexing
variable_config = forecast_config.variables[0]

print(f"Model: {model_config.name}")
print(f"Variable: {variable_config.long_name}")

In [None]:

# one forecast store per sampling date for the selected model
paths = [str(model_config.forecast_path(idate)) for idate in dates]

In [None]:
# (``partial`` is already imported at the top of the notebook)
from functools import partial
idepth = 0
lead_time = 0

# keep only lead time ``lead_time`` (and depth ``idepth`` when present) of each store
fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)


In [None]:
# open file
# NOTE(review): ``paths[:-1]`` skips the last date; the reason is not stated -- confirm
model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

# select variable and apply its attribute corrections
da = model_results[variable_config.name]
da = variable_config.correct_real_attrs(da)

# geoprocess variables: regular metric grid, time in days, NaNs -> 0.0
fill_value = 0.0
da = geoprocess_fn(da, fill_value)

# float64 here, whereas the batch cells cast to float32
da = da.astype(np.float64)
da

## Space-Time Power Spectrum

In [None]:
# (both names are already imported at the top of the notebook)
import xrft
from dask.diagnostics import ProgressBar

with ProgressBar():

    # frequency dims: joint (time, lon) spectrum; time and lon each in one
    # chunk, one chunk per latitude row
    psd_st_signal = xrft.power_spectrum(
        da.chunk({
            "time": da.time.shape[0],
            "lon": da.lon.shape[0], 
            "lat": 1}),
        dim=["time", "lon"],
        detrend="linear",
        window="tukey",
        nfactor=2,
        window_correction=True,
        true_amplitude=True,
        truncate=True,
    )

    # average other dims (latitude), then compute eagerly
    psd_st_signal = xr_cond_average(psd_st_signal, dims=["lat"], drop=True,).compute()

In [None]:
# name the spectrum after the field and apply the variable's spectral attribute corrections
psd_st_signal.name = da.name
psd_st_signal = variable_config.correct_spectral_attrs(psd_st_signal)
psd_st_signal

In [None]:
# Depth label for file names: the spectrum's depth coordinate, or 0 when it has
# none (narrowed from a bare ``except`` that would hide unrelated errors).
try:
    depth = psd_st_signal.depth.values
except AttributeError:
    depth = 0

In [None]:
# output location for the space-time figure of the selected model/variable
path = Path("OceanBenchFigures/psd/space_time/")
save_name = Path(f"psd_spacetime_global_{model_config.name.lower()}_{variable_config.name}_t{lead_time+1}_z{depth:.2f}.png")
save_name

In [None]:
# Import from the same module used at the top of the notebook (``src.psd``);
# ``utils.psd`` is not used anywhere else in this notebook.
from src.psd import PlotPSDSpaceTime

psd_st_plot = PlotPSDSpaceTime()

psd_st_plot.init_fig(figsize=(8, 6))
psd_st_plot.plot_wavelength(
    psd_st_signal,
    space_scale=1e3,
    space_units="km",
    time_units="days",
    psd_units="SSH"
)
psd_st_plot.fig.set(dpi=300)
# make sure the output directory exists before saving
path.mkdir(parents=True, exist_ok=True)
psd_st_plot.fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# pglonet = fuse_base_coords(pglonet, "MLD")
# pglo12 = fuse_base_coords(pglo12, "MLD")
# pwenhai = fuse_base_coords(pwenhai, "MLD")
# pxihe = fuse_base_coords(pxihe, "MLD")
# pglonet

In [None]:
# Pick one forecast model / variable and open its first-lead-time fields.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# ``models`` is a mapping, so select the 2nd model by position from its values
model_config = list(forecast_config.models.values())[1]
# NOTE(review): ``variables`` is indexed by name (``["zos"]``) elsewhere in this
# notebook; confirm it also supports positional indexing
variable_config = forecast_config.variables[0]

paths = [str(model_config.forecast_path(idate)) for idate in dates]

print(f"Model: {model_config.name}")
print(f"Variable: {variable_config.long_name}")
idepth = 0
lead_time = 0

fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)

# NOTE(review): ``paths[:-1]`` skips the last date; the reason is not stated -- confirm
model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

In [None]:
# select variable and apply its attribute corrections
da = model_results[variable_config.name]
da = variable_config.correct_real_attrs(da)

# geoprocess variables: regular metric grid, time in days, NaNs -> 0.0
fill_value = 0.0
da = geoprocess_fn(da, fill_value)

# float64 here, whereas the batch cells cast to float32
da = da.astype(np.float64)

In [None]:
with ProgressBar():

    # frequency dims: spectrum along longitude only; one chunk per (time, lat) row
    psd_lon_signal = xrft.power_spectrum(
        da.chunk({
            "time": 1,
            "lon": da.lon.shape[0], 
            "lat": 1,}),
        dim=["lon",],
        detrend="linear",
        window="tukey",
        nfactor=2,
        window_correction=True,
        true_amplitude=True,
        truncate=True,
    )
        
    # average other dims (time and latitude), then compute eagerly
    psd_lon_signal = xr_cond_average(psd_lon_signal, dims=["time", "lat"], drop=True,).compute()
    

In [None]:



# Plot the zonal PSD of the single model computed in the previous cell.
psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig(figsize=(8, 6))

psd_iso_plot.plot_wavelength(
    psd_lon_signal,
    freq_scale=1e3,
    units="km",
    label=model_config.name.upper(),
    color=model_config.color,
)
# depth label from the spectrum actually plotted (previously read from the
# unrelated ``psd_iso_signal``); 0 when it has no depth coordinate
try:
    depth = psd_lon_signal.depth.values
except AttributeError:
    depth = 0

psd_iso_plot.ax.set(
    xticks=[10_000, 1_000, 100, 10],
    xlim=[None, 10],
    ylabel="Power Spectrum"
)
psd_iso_plot.fig.set(dpi=300)
# path = Path("OceanBenchFigures/psd/isotropic/")
# save_name = Path(f"psd_spacetime_global_{ivariable.name}_t{lead_time+1}_z{depth:.2f}.png")
# psd_iso_plot.fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
plt.show()

## All Models + Variables

In [None]:
# Compute and plot the space-time PSD for each variable and forecast model.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# loop invariants (previously re-assigned on every model iteration)
idepth = 0
lead_time = 0
fill_value = 0.0
fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)
path = Path("OceanBenchFigures/psd/space_time/")
path.mkdir(parents=True, exist_ok=True)

pbar_variables = tqdm.tqdm(forecast_config.variables)
with pbar_variables:
    for ivariable in pbar_variables:
        # ``models`` is a mapping: iterate its values (model configs). A fresh
        # bar per variable, since a closed tqdm bar stops displaying.
        pbar_models = tqdm.tqdm(forecast_config.models.values(), leave=False)
        with pbar_models:
            for imodel in pbar_models:
                paths = [str(imodel.forecast_path(idate)) for idate in dates]

                print(f"Model: {imodel.name}")
                print(f"Variable: {ivariable.long_name}")

                # NOTE(review): ``paths[:-1]`` skips the last date; the reason is not stated -- confirm
                model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

                # select variable and apply its attribute corrections
                da = model_results[ivariable.name]
                da = ivariable.correct_real_attrs(da)

                # regular metric grid, NaNs -> fill_value, float64
                da = geoprocess_fn(da, fill_value)
                da = da.astype(np.float64)

                with ProgressBar():
                    # joint (time, lon) spectrum; one chunk per latitude row
                    psd_signal = xrft.power_spectrum(
                        da.chunk({
                            "time": da.time.shape[0],
                            "lon": da.lon.shape[0],
                            "lat": 1,}),
                        dim=["time", "lon"],
                        detrend="linear",
                        window="tukey",
                        nfactor=2,
                        window_correction=True,
                        true_amplitude=True,
                        truncate=True,
                    ).compute()

                # average over latitude; the trailing argument-less ``.drop()``
                # was removed because it named nothing to drop
                psd_signal = xr_cond_average(psd_signal, dims=["lat"], drop=True)

                psd_st_plot = PlotPSDSpaceTime()
                psd_st_plot.init_fig(figsize=(8, 6))
                psd_st_plot.plot_wavelength(
                    psd_signal,
                    space_scale=1e3,
                    space_units="km",
                    time_units="days",
                    # label with the current variable instead of a fixed "SSH"
                    psd_units=f"{ivariable.name.upper()}",
                )
                # depth label from this spectrum (previously read from the
                # unrelated ``psd_iso_signal``); 0 without a depth coordinate
                try:
                    depth = float(psd_signal.depth.values)
                except AttributeError:
                    depth = 0
                psd_st_plot.fig.set(dpi=300)
                save_name = Path(f"psd_spacetime_global_{imodel.name.lower()}_{ivariable.name}_t{lead_time+1}_z{depth:.2f}.png")
                psd_st_plot.fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
                plt.show()

        # NOTE(review): stops after the first variable -- looks like a debugging
        # leftover; remove this ``break`` to process every variable
        break

## Isotropic Power Spectrum - Model Comparison

In [None]:
# Pick one forecast model / variable and open its first-lead-time fields.
forecast_config = ForecastDataset()  # or DiagnosticDataset()

# ``models`` is a mapping, so select the 1st model by position from its values
model_config = list(forecast_config.models.values())[0]
# NOTE(review): ``variables`` is indexed by name (``["zos"]``) elsewhere in this
# notebook; confirm it also supports positional indexing
variable_config = forecast_config.variables[0]

paths = [str(model_config.forecast_path(idate)) for idate in dates]

print(f"Model: {model_config.name}")
print(f"Variable: {variable_config.long_name}")
idepth = 0
lead_time = 0

fn = partial(preprocess_single_leadtime_fn, idepth=idepth, lead_time=lead_time)

# NOTE(review): ``paths[:-1]`` skips the last date; the reason is not stated -- confirm
model_results = xr.open_mfdataset(paths[:-1], preprocess=fn, combine="by_coords", engine="zarr")

In [None]:
# select variable and apply its attribute corrections
da = model_results[variable_config.name]
da = variable_config.correct_real_attrs(da)

# geoprocess variables: regular metric grid, time in days, NaNs -> 0.0
fill_value = 0.0
da = geoprocess_fn(da, fill_value)

# float64 here, whereas the batch cells cast to float32
da = da.astype(np.float64)

In [None]:
# Isotropic (2-D lon/lat) power spectrum of each time step, averaged over time.
with ProgressBar():

    # both transformed dims (lon, lat) must each be a single chunk: dask FFTs
    # reject an axis split into several chunks, and lat was chunked by 1 here
    psd_iso_signal = xrft.isotropic_power_spectrum(
        da.chunk({
            "time": 1,
            "lon": da.lon.shape[0],
            "lat": da.lat.shape[0],}),
        dim=["lon", "lat"],
        detrend="linear",
        window="tukey",
        nfactor=2,
        window_correction=True,
        true_amplitude=True,
        truncate=True,
    ).compute()

    # ``lat`` is consumed by the isotropic transform, so only time is averaged
    psd_iso_signal = xr_cond_average(psd_iso_signal, dims=["time"], drop=True,)
    

In [None]:



# Plot the isotropic PSD of the single model computed above.
psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig(figsize=(8, 6))

psd_iso_plot.plot_wavelength(
    psd_iso_signal,
    freq_scale=1e3,
    units="km",
    # label/colour from the analysed model (``i`` was an undefined loop
    # variable copied from the comparison cell)
    label=model_config.name.upper(),
    color=model_config.color,
)
# depth label; 0 when the spectrum has no depth coordinate
try:
    depth = psd_iso_signal.depth.values
except AttributeError:
    depth = 0
psd_iso_plot.ax.invert_xaxis()
psd_iso_plot.fig.set(dpi=300)
# path = Path("OceanBenchFigures/psd/isotropic/")
# save_name = Path(f"psd_spacetime_global_{ivariable.name}_t{lead_time+1}_z{depth:.2f}.png")
# psd_iso_plot.fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
plt.show()

### Model Comparison

In [None]:
# NOTE(review): every cell above assigns ``results = {}`` and never fills it
results

In [None]:
# Overlay the isotropic PSDs of several models on one figure.
# NOTE(review): ``results`` is only ever assigned ``{}`` above; this loop expects
# an iterable of (model_config, spectrum) pairs -- TODO populate it.
psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig(figsize=(8, 6))

for i in results:

    psd_iso_plot.plot_wavelength(
        i[1],
        freq_scale=1e3,
        units="km",
        label=i[0].name.upper(),
        color=i[0].color,
    )
# depth label from the single-model isotropic spectrum computed earlier
# (bare except: any error falls back to 0)
try:
    depth = psd_iso_signal.depth.values
except:
    depth = 0
psd_iso_plot.ax.invert_xaxis()
psd_iso_plot.fig.set(dpi=300)
path = Path("OceanBenchFigures/psd/isotropic/")
# NOTE(review): ``ivariable`` is left over from an earlier loop, and the
# "psd_spacetime_" prefix looks copied from the space-time cells although this
# figure goes to the isotropic folder -- confirm the intended name.
save_name = Path(f"psd_spacetime_global_{ivariable.name}_t{lead_time+1}_z{depth:.2f}.png")
psd_iso_plot.fig.savefig(path.joinpath(save_name), bbox_inches='tight', transparent=True)
plt.show()