# Using `uxarray` to regrid onto HEALPix

In [None]:
import astropy.coordinates
import cartopy.crs as ccrs
import cdshealpix
import cf_xarray  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import odc.geo
import odc.geo.xr  # noqa: F401
import pystac_client
import uxarray as ux
import xarray as xr
import xdggs  # noqa: F401

# Keep variable/coordinate attributes (such as the ``standard_name`` attrs assigned
# in later cells) through xarray operations for the rest of the session.
xr.set_options(keep_attrs=True)

In [None]:
from distributed import Client

# Start a local Dask cluster, which becomes the default scheduler for the chunked
# computations below. The handle gets its own name because ``client`` is rebound
# to the STAC catalog client in the next section, which would otherwise shadow it.
dask_client = Client()
dask_client

## Satellite data: Sentinel-2 L2A

In [None]:
# Open the EOPF STAC API catalog used to look up the satellite scene.
client = pystac_client.Client.open("https://stac.core.eopf.eodc.eu")
client

In [None]:
# Search the Sentinel-2 L2A collection. The ``ids`` filter pins the search to a
# single scene; ``bbox`` and ``max_items`` can only narrow the result further.
bbox = [-40, -40, 40, 40]
items = client.search(
    collections=["sentinel-2-l2a"],
    max_items=4,
    bbox=bbox,
    ids=["S2B_MSIL2A_20250424T100029_N0511_R122_T32SPJ_20250424T124939"],
).item_collection()
items

In [None]:
# Use the first matching item (presumably the only one, given the ``ids`` filter).
item = items[0]
item

In [None]:
# Open the ``measurements/reflectance/r60m`` group of the product through the
# "stac" engine as dask-backed arrays (``chunks={}``), then attach the CRS given
# by the ``proj:wkt2`` attribute of band ``b01`` to the whole dataset.
ds = xr.open_dataset(
    item.assets["product"],
    engine="stac",
    chunks={},
    group="measurements/reflectance/r60m",
).pipe(lambda ds: ds.odc.assign_crs(ds["b01"].attrs["proj:wkt2"]))
ds

In [None]:
# Interactive preview of band b01, loaded into memory first.
ds["b01"].squeeze().compute().odc.explore()

In [None]:
# Reproject all bands to geographic lon/lat (EPSG:4326) and tag the resulting
# coordinates with CF ``standard_name`` attributes. Presumably this lets
# ``ux.UxDataset.from_structured`` identify them as longitude/latitude -- confirm.
reprojected = ds.odc.reproject("epsg:4326").assign_coords(
    longitude=lambda ds: ds["longitude"].assign_attrs({"standard_name": "longitude"}),
    latitude=lambda ds: ds["latitude"].assign_attrs({"standard_name": "latitude"}),
)
reprojected

In [None]:
# Wrap the reprojected (structured lat/lon) dataset as an unstructured UxDataset.
input_grid = ux.UxDataset.from_structured(reprojected)
input_grid.to_xarray()

In [None]:
# Source field to regrid: band b01.
arr = input_grid["b01"]
arr.to_xarray()

In [None]:
# HEALPix depth (refinement level) of the target grid.
level = 15
# Footprint of the reprojected data. Its geobox is in EPSG:4326 (see the reproject
# cell above), so the exterior-ring coordinates are lon/lat in degrees.
geom = reprojected.odc.geobox.extent
lon = astropy.coordinates.Longitude(geom.exterior.xy[0], unit="degree")
lat = astropy.coordinates.Latitude(geom.exterior.xy[1], unit="degree")
# NOTE(review): the exterior ring is presumably closed (first vertex repeated);
# confirm that polygon_search handles the duplicated vertex as intended.
# Nested cell ids at ``level`` covering the footprint (other return values unused).
cell_ids, _, _ = cdshealpix.nested.polygon_search(lon, lat, depth=level, flat=True)

In [None]:
# Wrap the cell ids in a dataset and decode them as a nested HEALPix grid (xdggs).
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)}).dggs.decode(
    {"grid_name": "healpix", "level": level, "indexing_scheme": "nested"}
)
target_grid

In [None]:
def grid_from_xdggs(ds):
    """Convert an xdggs-decoded HEALPix dataset into a UGRID-style dataset for uxarray.

    The result is meant to be passed to
    ``ux.Grid.from_dataset(..., source_grid_spec="HEALPix")``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``cell_ids`` coordinate along the ``cells`` dimension that
        has already been decoded with ``.dggs.decode`` as a HEALPix grid.

    Returns
    -------
    xarray.Dataset
        ``ds`` with cell-centre coordinates named ``face_lat``/``face_lon`` along
        an ``n_face`` dimension, the HEALPix parameters (``zoom``, ``n_side``,
        ``n_pix``, ``nest``) stored as global attributes, and a scalar
        ``grid_topology`` coordinate carrying UGRID mesh metadata.
    """
    grid_info = ds.dggs.grid_info
    return (
        ds.dggs.assign_latlon_coords()
        .rename_dims({"cells": "n_face"})
        .rename_vars({"latitude": "face_lat", "longitude": "face_lon"})
        .assign_attrs(
            {
                "zoom": grid_info.level,
                "n_side": grid_info.nside,
                # Count the cells of ``ds`` itself, so the function does not
                # depend on any notebook-global grid object.
                "n_pix": ds.sizes["cells"],
                "nest": grid_info.nest,
            }
        )
        .assign_coords(
            # UGRID mesh-topology variable: a scalar placeholder whose attributes
            # carry the metadata. The node variables it names are not created
            # here -- presumably uxarray derives them from the HEALPix
            # attributes (TODO confirm).
            grid_topology=(
                (),
                -1,
                {
                    "topology_dimension": 2,
                    "face_dimension": "n_face",
                    "node_dimension": "n_node",
                    "node_coordinates": "node_lon node_lat",
                    "face_node_connectivity": "face_node_connectivity",
                    "face_coordinates": "face_lon face_lat",
                },
            )
        )
    )

In [None]:
# Convert the HEALPix target cells into a uxarray Grid.
grid_ds = target_grid.pipe(grid_from_xdggs)
uxgrid = ux.Grid.from_dataset(grid_ds, source_grid_spec="HEALPix")
uxgrid.to_xarray()

In [None]:
# Empty UxDataset on the target grid, carrying the HEALPix ``cell_ids`` along
# ``n_face`` so they can be re-attached to the remapped result.
uds = ux.UxDataset(uxgrid=uxgrid).assign_coords(
    target_grid["cell_ids"].rename({"cells": "n_face"}).coords,
)
uds.to_xarray()

In [None]:
%%time
# Inverse-distance-weighted remap of b01 onto the HEALPix grid. Re-attach the cell
# ids, convert back to xarray and decode with xdggs (grid parameters presumably
# read from the ``cell_ids`` attributes).
remapped = (
    arr.remap.inverse_distance_weighted(uds.uxgrid)
    .assign_coords(cell_ids=uds["cell_ids"])
    .to_xarray()
    .dggs.decode()
)
remapped

In [None]:
# Interactive map of the remapped HEALPix cells.
remapped.dggs.explore()

## Rectilinear grid: `air_temperature`

In [None]:
# Tutorial air temperature dataset, chunked along time.
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 20})
ds

In [None]:
# Interpolate onto a finer grid (1060 lon x 500 lat points over 200-330 lon and
# 15-75 lat), then wrap longitudes from [0, 360) to [-180, 180).
upscaled = ds.interp(
    lon=np.linspace(200, 330, 1060), lat=np.linspace(15, 75, 500)
).assign_coords(lon=lambda ds: (ds["lon"] + 180) % 360 - 180)
upscaled

In [None]:
# HEALPix depth of the target grid.
level = 7
# Polygon around the interpolated lon/lat box (200-330 lon, 15-75 lat). The extra
# vertices along the 15° and 75° edges presumably keep the polygon close to the
# box, since cdshealpix joins vertices with great-circle arcs -- TODO confirm.
lon = astropy.coordinates.Longitude(
    [200, 225, 250, 275, 300, 330, 330, 300, 275, 250, 225, 200], unit="degree"
)
lat = astropy.coordinates.Latitude(
    [15, 15, 15, 15, 15, 15, 75, 75, 75, 75, 75, 75], unit="degree"
)
# Nested cell ids at ``level`` covering the polygon (other return values unused).
cell_ids, _, _ = cdshealpix.nested.polygon_search(lon, lat, depth=level, flat=True)

# Decode as a nested HEALPix grid and add the cell lat/lon coordinates.
target_grid = (
    xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
    .dggs.decode({"grid_name": "healpix", "level": level, "indexing_scheme": "nested"})
    .dggs.assign_latlon_coords()
)
target_grid

In [None]:
# Wrap the interpolated (structured) dataset as an unstructured UxDataset.
input_ds = ux.UxDataset.from_structured(upscaled)
input_ds.to_xarray()

In [None]:
# uxarray Grid for the HEALPix target cells.
uxgrid = ux.Grid.from_dataset(
    target_grid.pipe(grid_from_xdggs), source_grid_spec="HEALPix"
)
uxgrid.to_xarray()

In [None]:
# Inspect the source field.
arr = input_ds["air"]
arr.to_xarray()

In [None]:
# Inverse-distance-weighted remap of the first 100 time steps onto the HEALPix
# grid, renamed back to ``cells`` and given the xdggs target coordinates.
regridded = (
    input_ds["air"]
    .isel(time=slice(None, 100))
    .remap.inverse_distance_weighted(uxgrid)
    .rename({"n_face": "cells"})
    .assign_coords(target_grid.coords)
    .to_xarray()
    # NOTE(review): the remapped data is not contiguous; rechunking and computing
    # eagerly works around that -- confirm what exactly requires it.
    .chunk()
    .compute()
)
regridded

In [None]:
# Interactive map of the regridded cells, drawn semi-transparent.
regridded.dggs.explore(alpha=0.8)

## Curvilinear grid: `rasm`

In [None]:
# RASM tutorial dataset, chunked along time; mark the ``xc``/``yc`` coordinates as
# longitude/latitude through CF ``standard_name`` attributes.
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 8}).assign_coords(
    xc=lambda ds: ds["xc"].assign_attrs(standard_name="longitude"),
    yc=lambda ds: ds["yc"].assign_attrs(standard_name="latitude"),
)
ds

In [None]:
def curvilinear_to_grid(ds):
    """Flatten a curvilinear ``(y, x)`` dataset into a point-based UxDataset.

    The ``y``/``x`` dimensions are stacked into a single ``n_node`` dimension, and
    a uxarray grid is built from the ``xc``/``yc`` values of those points.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset on dimensions ``y`` and ``x`` with ``xc`` and ``yc`` coordinate
        variables (longitude and latitude in this notebook).

    Returns
    -------
    uxarray.UxDataset
        The data variables of ``ds`` along ``n_node``, attached to a grid from
        ``ux.Grid.from_points``; every other coordinate except ``xc``/``yc`` is kept.
    """
    flat = ds.stack(n_node=["y", "x"])
    # Remove the stacked index and the original dimension coordinates so that
    # ``n_node`` becomes a plain positional dimension.
    flat = flat.drop_indexes(["n_node", "x", "y"]).drop_vars(["x", "y"])
    # Load the point locations eagerly (presumably Grid.from_points needs
    # in-memory arrays -- confirm).
    flat = flat.merge(flat[["xc", "yc"]].compute())

    lon_lat = [flat["xc"].data, flat["yc"].data]
    grid = ux.Grid.from_points(lon_lat)
    variables = {name: da.variable for name, da in flat.data_vars.items()}
    remaining_coords = flat.drop_vars(["xc", "yc"]).coords

    return ux.UxDataset(uxgrid=grid, data_vars=variables, coords=remaining_coords)

In [None]:
# Flatten the curvilinear dataset into a point-based UxDataset.
input_ds = curvilinear_to_grid(ds)
input_ds.to_xarray()

In [None]:
# Reference plot of the source field (time index 1) on a North Polar Stereographic
# map; ``transform=ccrs.PlateCarree()`` declares xc/yc as lon/lat in degrees.
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.NorthPolarStereo()})
ds["Tair"].isel(time=1).plot.pcolormesh(
    x="xc", y="yc", ax=ax, transform=ccrs.PlateCarree()
)

In [None]:
# HEALPix depth of the target grid.
level = 10
# Southern edge (latitude, degrees) of the polar cap to cover.
min_lat = 16.5
# Cone centred on the North Pole whose angular radius reaches down to ``min_lat``.
lon = astropy.coordinates.Longitude(0, unit="degree")
lat = astropy.coordinates.Latitude(90, unit="degree")
# Build the radius as an ``Angle`` (the radius type cdshealpix documents), in the
# same ``unit="degree"`` style as lon/lat. This relies only on the explicitly
# imported ``astropy.coordinates``, not on ``astropy.units``, which is not
# imported explicitly at the top of the notebook.
radius = astropy.coordinates.Angle(90 - min_lat, unit="degree")
cell_ids, _, _ = cdshealpix.nested.cone_search(
    lon, lat, radius=radius, depth=level, flat=True
)

# Decode as a nested HEALPix grid and add the cell lat/lon coordinates.
target_grid = (
    xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
    .dggs.decode({"grid_name": "healpix", "level": level, "indexing_scheme": "nested"})
    .dggs.assign_latlon_coords()
)
target_grid

In [None]:
# uxarray Grid for the HEALPix polar-cap cells.
uxgrid = ux.Grid.from_dataset(
    target_grid.pipe(grid_from_xdggs), source_grid_spec="HEALPix"
)
uxgrid.to_xarray()

In [None]:
# Nearest-neighbour remap of the first 100 time steps onto the HEALPix grid.
# ``curvilinear_to_grid`` puts the data on ``n_node``; it is relabelled as
# ``n_face`` before remapping -- presumably because the remapper expects
# face-located source data (TODO confirm).
regridded = (
    input_ds["Tair"]
    .isel(time=slice(None, 100))
    .rename({"n_node": "n_face"})
    .remap.nearest_neighbor(uxgrid)
    .rename({"n_face": "cells"})
    .assign_coords(target_grid.coords)
    .to_xarray()
    # NOTE(review): the remapped data is not contiguous; rechunking and computing
    # eagerly works around that -- confirm what exactly requires it.
    .chunk()
    .compute()
)
regridded

In [None]:
# Interactive map of the regridded cells, drawn semi-transparent.
regridded.dggs.explore(alpha=0.8)