In [None]:
import cf_xarray  # noqa: F401
import lonboard
import xarray as xr

from grid_indexing import infer_cell_geometries, infer_grid_type

# Keep attributes through arithmetic globally: cf_xarray identifies coordinates
# via attributes (standard_name, units, ...), so they must survive operations.
# The trailing ";" suppresses the repr of the returned set_options object.
xr.set_options(keep_attrs=True);

In [None]:
def center_longitude(ds):
    """Wrap the dataset's longitude coordinate into the range [-180, 180).

    The longitude variable is located through the cf_xarray accessor
    (``ds.cf.coordinates["longitude"]``); only the first match is rewritten.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with at least one longitude coordinate recognizable by cf_xarray.

    Returns
    -------
    xarray.Dataset
        New dataset whose longitude coordinate is wrapped. The original
        coordinate attributes are carried over.

    Raises
    ------
    KeyError
        From the ``coordinates`` lookup, if cf_xarray finds no longitude.
    """
    lon_name = ds.cf.coordinates["longitude"][0]
    original = ds[lon_name]
    # ``(x + 180) % 360 - 180`` maps e.g. [0, 360) onto [-180, 180).
    # The attrs are copied explicitly rather than relying on a global
    # ``keep_attrs`` option set elsewhere. cf_xarray finds coordinates through
    # these attributes, so dropping them would break later ``ds.cf`` lookups.
    longitude = ((original + 180) % 360 - 180).assign_attrs(original.attrs)
    return ds.assign_coords({lon_name: longitude})

In [None]:
def visualize_grid(geoms, data, cmap="viridis", alpha=0.8):
    """Build a lonboard polygon layer that fills each grid cell by its data value.

    Parameters
    ----------
    geoms
        Cell polygons; anything ``arro3.core.Array.from_arrow`` accepts
        (e.g. the output of ``infer_cell_geometries``).
    data
        xarray ``DataArray`` with one value per cell, presumably in the same
        order as ``geoms`` (TODO confirm against the stacking used by callers).
    cmap
        Name of a registered matplotlib colormap.
    alpha
        Fill opacity forwarded to ``apply_continuous_cmap``.

    Returns
    -------
    lonboard.SolidPolygonLayer
        Filled polygon layer backed by a ("geometry", "data") Arrow table.
    """
    from arro3.core import Array, ChunkedArray, Schema, Table
    from lonboard.colormap import apply_continuous_cmap
    from matplotlib import colormaps
    from matplotlib.colors import Normalize

    geometry_column = Array.from_arrow(geoms)
    value_column = ChunkedArray([Array.from_numpy(data)])
    columns = {"geometry": geometry_column, "data": value_column}

    # Reuse each column's own field (type and metadata) under the column name.
    schema = Schema(
        [column.field.with_name(column_name) for column_name, column in columns.items()]
    )
    table = Table.from_arrays(list(columns.values()), schema=schema)

    # Scale values over their NaN-ignoring min/max, then map through the colormap.
    lower = data.min(skipna=True)
    upper = data.max(skipna=True)
    scaled = Normalize(vmin=lower, vmax=upper)(data.data)
    fill_colors = apply_continuous_cmap(scaled, colormaps[cmap], alpha=alpha)

    return lonboard.SolidPolygonLayer(
        table=table, filled=True, get_fill_color=fill_colors
    )

In [None]:
def _flatten_first_slice(variable, cell_dims, **first_index):
    """Return a preprocessor that picks one slice of ``variable`` and flattens it.

    The returned callable takes a dataset, selects ``ds[variable]``, indexes it
    positionally with ``first_index``, and stacks ``cell_dims`` into a single
    ``cells`` dimension.
    """

    def preprocess(ds):
        return ds[variable].isel(**first_index).stack(cells=cell_dims)

    return preprocess


# One preprocessor per tutorial dataset: which variable to plot, which index to
# take along the non-plotted dimensions, and which dims form the "cells" axis.
preprocessors = {
    "air_temperature": _flatten_first_slice("air", ["lon", "lat"], time=0),
    "rasm": _flatten_first_slice("Tair", ["y", "x"], time=0),
    "ROMS_example": _flatten_first_slice(
        "salt", ["eta_rho", "xi_rho"], ocean_time=0, s_rho=0
    ),
}

In [None]:
# Colormap used for each dataset's layer (keys match ``preprocessors``).
cmaps = {"air_temperature": "plasma", "rasm": "cividis", "ROMS_example": "viridis"}

# Open each tutorial dataset (downloaded on first use) and wrap its longitudes
# to [-180, 180) so every grid uses the same longitude convention on the map.
dss = {
    name: xr.tutorial.open_dataset(name).pipe(center_longitude)
    for name in preprocessors
}

# Report the grid type that grid_indexing infers for each dataset.
print(
    "grid types:",
    *[f"{name}: {infer_grid_type(ds)}" for name, ds in dss.items()],
    sep="\n",
)

# One polygon layer per dataset: cell geometries inferred from the dataset,
# colored by its preprocessed (sliced and flattened) variable.
layers = [
    visualize_grid(
        infer_cell_geometries(ds), ds.pipe(preprocessors[name]), cmap=cmaps[name]
    )
    for name, ds in dss.items()
]

lonboard.Map(layers)