### Extract ROMS boundary files from ROMS climatology files

In [None]:
import glob
import os
from dataclasses import dataclass, field

import xarray as xr

In [None]:
@dataclass
class RomsVar:
    name: str
    eta_name: str
    xi_name: str
    lat_name: str
    lon_name: str
    s_name: None | str = None

def _roms_var_field(*names: str):
    """Return a dataclass field whose default is a fresh ``RomsVar(*names)``."""
    return field(default_factory=lambda: RomsVar(*names))


@dataclass
class RomsVars:
    """
    Catalogue of the ROMS variables handled here, one attribute per variable.

    Each attribute holds the ``RomsVar`` naming that variable's dimensions
    and lat/lon arrays. Defaults are created through ``default_factory``:
    ``RomsVar`` instances are unhashable, and Python 3.11+ refuses
    unhashable objects as plain dataclass defaults (ValueError at class
    creation). A factory also gives every ``RomsVars`` its own copies.
    """

    temp: RomsVar = _roms_var_field("temp", "eta_rho", "xi_rho", "lat_rho", "lon_rho", "s_rho")
    salt: RomsVar = _roms_var_field("salt", "eta_rho", "xi_rho", "lat_rho", "lon_rho", "s_rho")
    zeta: RomsVar = _roms_var_field("zeta", "eta_rho", "xi_rho", "lat_rho", "lon_rho")
    u: RomsVar = _roms_var_field("u", "eta_u", "xi_u", "lat_u", "lon_u", "s_rho")
    v: RomsVar = _roms_var_field("v", "eta_v", "xi_v", "lat_v", "lon_v", "s_rho")
    ubar: RomsVar = _roms_var_field("ubar", "eta_u", "xi_u", "lat_u", "lon_u")
    vbar: RomsVar = _roms_var_field("vbar", "eta_v", "xi_v", "lat_v", "lon_v")

In [None]:
# Climatology input files, in sorted path order.
filepaths = sorted(
    glob.iglob('/cluster/projects/nn9297k/OF160/Clm/*_OF160_clm_*.nc')
)
# Lookup table: variable name -> its dimension/coordinate names.
roms_vars = RomsVars()

In [None]:
def parse_filepath(filepath: str):
    """
    Split a climatology file path into its file number and variable name.

    The file number is the part of the base name before the first
    underscore; the variable is the part after the last underscore with
    the extension removed.
    For example: '/.../098_OF160_clm_v.nc' -> '098', 'v'
    Returns:
        file_number: str
        var: str
    """
    basename = os.path.basename(filepath)
    file_number = basename.partition("_")[0]
    last_token = basename.rpartition("_")[2]
    var_name, _extension = os.path.splitext(last_token)
    return file_number, var_name

In [None]:
def extract_boundaries(ds: xr.Dataset, var: RomsVar):
    """
    Slice the four edges of one variable out of a dataset.

    Args:
        ds: dataset holding the variable named ``var.name``
        var: dimension names of that variable
    Returns:
        tuple of xr bry data arrays (west, east, south, north): the first
        and last index along ``var.xi_name``, then the first and last
        index along ``var.eta_name``
    """
    field_da = ds[var.name]
    edges = (
        (var.xi_name, 0),    # west
        (var.xi_name, -1),   # east
        (var.eta_name, 0),   # south
        (var.eta_name, -1),  # north
    )
    return tuple(field_da.isel({dim: index}) for dim, index in edges)

In [None]:
# Exploratory cell: run the extraction steps by hand on the first file.
filepath = filepaths[0]
fnum_str, var_str = parse_filepath(filepath)
ds = xr.open_dataset(filepath)
# Look up the dimension/coordinate names for the variable in this file.
var = getattr(roms_vars, var_str)

# Edge slices, in the order returned by extract_boundaries.
da_west, da_east, da_south, da_north = extract_boundaries(ds, var)

In [None]:
# Display the west-boundary slice extracted above.
da_west

In [None]:
# Reorder the west-boundary slice to (time, [vertical,] eta). The dimension
# names come from `var`, so this also works when the first file holds a
# u/v-point variable or one without a vertical dimension.
da_west.transpose(
    "ocean_time",
    *(dim for dim in (var.s_name, var.eta_name) if dim is not None),
)

In [None]:
def get_dataset(filepath: str):
    """
    Extract the boundaries from one ROMS climatology file and collect them,
    together with auxiliary variables, in a new xarray dataset.

    The result holds ``<var>_west``, ``<var>_east``, ``<var>_south`` and
    ``<var>_north`` (time first, then the vertical dimension when the
    variable has one, then the along-boundary dimension), the variable's
    lat/lon arrays and ``ocean_time``.

    Args:
        filepath: path to a climatology file; the file number and variable
            name are parsed from its base name by ``parse_filepath``
    Returns:
        fnum_str: str = A number of input clm file
        var_str: str = A name of the variable
        result: xr.Dataset
    Raises:
        AttributeError: if the variable parsed from the file name has no
            entry in ``roms_vars``
    """
    fnum_str, var_str = parse_filepath(filepath)
    # Resolve the variable before opening the file so an unknown name
    # fails without leaving a file handle behind.
    var = getattr(roms_vars, var_str)

    # Dimensions that precede the along-boundary one in every output array.
    lead_dims = ["ocean_time"]
    if var.s_name is not None:
        lead_dims.append(var.s_name)

    # The context manager closes the input file on exit. Everything needed
    # is copied into memory through `.values` while it is still open.
    with xr.open_dataset(filepath) as ds:
        da_west, da_east, da_south, da_north = extract_boundaries(ds, var)
        sides = (
            ("west", da_west, var.eta_name),
            ("east", da_east, var.eta_name),
            ("south", da_south, var.xi_name),
            ("north", da_north, var.xi_name),
        )

        data_vars = {}
        for side, da, along_dim in sides:
            dims = [*lead_dims, along_dim]
            data_vars[f"{var.name}_{side}"] = (
                dims,
                da.transpose(*dims).values,
                {"time": "ocean_time"},
            )
        data_vars[var.lat_name] = (
            [var.eta_name, var.xi_name],
            ds[var.lat_name].values,
        )
        data_vars[var.lon_name] = (
            [var.eta_name, var.xi_name],
            ds[var.lon_name].values,
        )
        data_vars["ocean_time"] = (["ocean_time"], da_west.ocean_time.values)

    result = xr.Dataset(data_vars)
    return fnum_str, var_str, result

In [None]:
# Merge the per-variable boundary datasets that share a file number and
# write one boundary file per number. `filepaths` is sorted, so all files
# of one number are adjacent.
prev_fnumber = None
ds_merged = None
for filepath in filepaths:
    fnumber, variable, ds = get_dataset(filepath)
    if fnumber != prev_fnumber:
        # A new file number starts: flush the finished group under ITS OWN
        # number (prev_fnumber), not the number of the group just begun.
        if ds_merged is not None:
            ds_merged.to_netcdf(
                f'/cluster/projects/nn9297k/OF160/Bry/{prev_fnumber}_OF160_bry.nc'
            )
            print(f"File number {prev_fnumber} saved")
        ds_merged = ds
    else:
        ds_merged = ds_merged.merge(ds)
    prev_fnumber = fnumber
# save the last step (nothing to save when no input files were found)
if ds_merged is not None:
    ds_merged.to_netcdf(
        f'/cluster/projects/nn9297k/OF160/Bry/{prev_fnumber}_OF160_bry.nc'
    )
    print(f"File number {prev_fnumber} saved")