## yt_xarray

linking yt & xarray

* https://github.com/data-exp-lab/yt_xarray/
* https://yt-xarray.readthedocs.io/en/latest/

this presentation: https://github.com/chrishavlin/yt_xarray_walkthrough_dxl 
built with: https://github.com/deathbeds/jupyterlab-deck

## xarray

Multidimensional array IO:

* self-describing data formats (netcdf, ...)

* arbitrary dimension names

* distributed support (chunks to files): 
    * dask arrays 
    * zarr arrays

Load in a [GEOS](https://gmao.gsfc.nasa.gov/GEOS_systems/) dataset (~2 GB, NASA Global Modeling and Assimilation Office):

In [None]:
import xarray as xr 
import os 

fname_geos = os.path.expanduser("~/hdd/data/yt_data/yt_sample_sets/geos/GEOS.fp.asm.inst3_3d_aer_Nv.20180822_0900.V01.nc4")
ds = xr.open_dataset(fname_geos)
ds

data variable access:

In [None]:
ds.data_vars["AIRDENS"]

extract ordered dimension names:

In [None]:
ds.AIRDENS.dims

## Data selection with xarray 

### NumPy-style array access and slicing

In [None]:
ds.AIRDENS[0, 0, :, :]

extracting raw NumPy arrays:

In [None]:
ds.AIRDENS[0,0,:,:].values

need to remember axis ordering!

### selection by coordinate **name**

by index (`isel`):

In [None]:
ds.AIRDENS.isel(time=0, lev=1, lat=4, lon=3)

by **exact** value (`sel`):

In [None]:
ds.AIRDENS.sel(lev=2.0, lat=-89.0)

with some fuzziness: 

In [None]:
ds.AIRDENS.sel(lev=2.0, lat=-89.013, method="nearest")

finally, with dictionary:

In [None]:
ds.AIRDENS.sel({"lev":2.0, "lat":-89.0})  # important for yt_xarray!
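
The selection styles above can be compared side by side on a small in-memory array. This is a minimal sketch: the array `da`, its coordinate values, and its name are made up for illustration and are not taken from the GEOS file.

```python
import numpy as np
import xarray as xr

# Small synthetic stand-in for a data variable; sizes and values are illustrative.
lev = np.array([1.0, 2.0, 3.0])
lat = np.linspace(-90.0, 90.0, 5)    # -90, -45, 0, 45, 90
lon = np.linspace(-180.0, 90.0, 4)   # -180, -90, 0, 90
values = np.arange(lev.size * lat.size * lon.size, dtype=float).reshape(
    lev.size, lat.size, lon.size
)
da = xr.DataArray(
    values,
    coords={"lev": lev, "lat": lat, "lon": lon},
    dims=("lev", "lat", "lon"),
    name="demo",
)

by_position = da[1, 0, :]                        # relies on knowing the axis order
by_index = da.isel(lev=1, lat=0)                 # integer index, by dimension name
by_value = da.sel(lev=2.0, lat=-90.0)            # exact coordinate values
by_nearest = da.sel(lev=2.0, lat=-89.013, method="nearest")
by_dict = da.sel({"lev": 2.0, "lat": -90.0})     # same selection, as a dictionary

# All five selections return the same 1D slice along lon.
for other in (by_index, by_value, by_nearest, by_dict):
    assert np.array_equal(by_position.values, other.values)
```

The dictionary form is handy when the selection is built programmatically, since dimension names can then be plain strings stored in a variable.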

## xarray & dask 

In [None]:
ds.close()
del ds

Start dask client

In [None]:
from dask.distributed import Client
n_workers = max(1, (os.cpu_count() or 1) - 2)  # leave two cores free, but keep at least one worker
c = Client(n_workers=n_workers, threads_per_worker=1)

Test data set ([generated from here](https://github.com/chrishavlin/yt-xarray-dask-sandbox/blob/main/example.ipynb)):
* random field data 
* 1000 chunks
* 1 chunk = 1 .nc file

In [None]:
data_dir = os.path.expanduser("~/hdd/data/yt_data/yt_sample_sets/yt_xarray_test_data/dask_mf/data")
dask_test_ds = os.path.join(data_dir, "*.nc")
ds = xr.open_mfdataset(dask_test_ds)
ds

In [None]:
ds.temperature

* **Coordinates** are in memory and span all chunks!
* **Data variables** are dask arrays

Returning in-memory values:

In [None]:
ds.temperature.mean()

In [None]:
ds.temperature.mean().values  # triggers the computation, returns a NumPy array

In [None]:
ds.temperature.mean().load()  # to preserve xarray-ness

**selections are also delayed (important for yt_xarray!):**

In [None]:
vals = ds.temperature.isel(z=range(10)).sel(x=1, y=2, method="nearest")
vals

In [None]:
vals.load()

## what about yt?



previously:

1. load in arrays
2. use yt generic data loader (`yt.load_uniform_grid(...)`)


**yt_xarray** v0.1.1: yt datasets from xarray datasets

automates steps 1 & 2 (as much as possible)!
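
For reference, the manual workflow amounts to collecting arrays, a grid shape, and a bounding box by hand. The sketch below uses NumPy only; the helper `uniform_grid_args` is hypothetical (not part of yt or yt_xarray) and assumes evenly spaced cell-center coordinates. The final `yt.load_uniform_grid` call is left as a comment so the example runs without yt installed.

```python
import numpy as np

def uniform_grid_args(coords, fields):
    """Collect the data dict, grid shape, and bounding box for a uniform grid.

    coords: dict of 1D, evenly spaced cell-center coordinates, in axis order.
    fields: dict of field name -> ndarray shaped like the coordinate grid.
    """
    shape = tuple(len(c) for c in coords.values())
    bbox = []
    for c in coords.values():
        c = np.asarray(c, dtype=float)
        half = 0.5 * (c[1] - c[0])           # assumes uniform spacing
        bbox.append([c[0] - half, c[-1] + half])
    for name, arr in fields.items():
        if arr.shape != shape:
            raise ValueError(f"{name} has shape {arr.shape}, expected {shape}")
    return fields, shape, np.array(bbox)

coords = {
    "x": np.linspace(0.5, 3.5, 4),
    "y": np.linspace(0.5, 2.5, 3),
    "z": np.linspace(0.5, 1.5, 2),
}
data = {"temperature": np.random.default_rng(0).random((4, 3, 2))}
data, shape, bbox = uniform_grid_args(coords, data)

# With yt installed, the pieces would then be handed over as:
# import yt
# ds_yt = yt.load_uniform_grid(data, shape, length_unit="km", bbox=bbox)
```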

## **yt_xarray** usage overview

yt_xarray provides a `yt` "accessor object":

In [None]:
import yt_xarray

In [None]:
ds.yt

In [None]:
[name for name in dir(ds.yt) if not name.startswith("_")]  # or tab-complete on `ds.yt.`

### Loading all data (not always possible):

In [None]:
ds_yt = ds.yt.load_grid(length_unit="km")

In [None]:
ds_yt.field_list

In [None]:
import yt
yt.SlicePlot(ds_yt, "x", ("stream", "gauss"))

### not always so easy...

[**wrf**](https://www.mmm.ucar.edu/models/wrf): "weather research and forecasting model" 

cf (Climate and Forecast) compliance of netcdf files: https://cfconventions.org/

WRF output is not CF-compliant...

In [None]:
ds = yt_xarray.open_dataset('wrf/wrfout_d03_2016-06-01.nc')  # checks yt paths

In [None]:
import xwrf  

In [None]:
ds_x = ds.xwrf.postprocess() # make it cf-compliant-ish
ds_x

1. fields differ in dimensionality (some include a time dimension)
2. yt expects specific coordinate names: (latitude, longitude, altitude), (x, y, z), (r, theta, phi), etc.

### choose a subset of fields

In [None]:
ds_x.yt.load_grid()
# ds_yt = ds_x.yt.load_grid(
#     fields=('geopotential', 'geopotential_height')
# )

### choose a time to load

In [None]:
ds_yt = ds_x.yt.load_grid(
    fields=('geopotential', 'geopotential_height'),                      
    sel_dict={'Time':0})

### coordinate aliasing

In [None]:
yt_xarray.known_coord_aliases

In [None]:
yt_xarray.known_coord_aliases["z_stag"] = "z"
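
The idea behind aliasing can be sketched with a plain dictionary. Everything here is a stand-in: `VALID_YT_NAMES`, `resolve_coord_names`, and the alias entries are hypothetical and do not reproduce yt_xarray's actual table or lookup logic.

```python
# Hypothetical subset of coordinate names that the target geometry accepts.
VALID_YT_NAMES = {"x", "y", "z", "latitude", "longitude", "altitude"}

def resolve_coord_names(dims, aliases):
    """Map dataset dimension names onto accepted coordinate names."""
    resolved = []
    for dim in dims:
        name = dim if dim in VALID_YT_NAMES else aliases.get(dim)
        if name not in VALID_YT_NAMES:
            raise ValueError(f"no yt coordinate name known for {dim!r}")
        resolved.append(name)
    return tuple(resolved)

aliases = {"lat": "latitude", "lon": "longitude"}   # illustrative entries only
dims = ("z_stag", "y", "x")

try:
    resolve_coord_names(dims, aliases)
except ValueError as err:
    print(err)                                      # z_stag is not known yet

aliases["z_stag"] = "z"                             # register the missing alias
print(resolve_coord_names(dims, aliases))
```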

In [None]:
ds_yt = ds_x.yt.load_grid(fields=('geopotential', 'geopotential_height'),
                          sel_dict={'Time':0},
                          length_unit='m',
                          use_callable=False)

a separate problem affects the 3D data (a bug: the interpolation goes wrong), so select a single `z_stag` level:

In [None]:
ds_yt = ds_x.yt.load_grid(fields=('geopotential', 'geopotential_height'),
                          sel_dict={'Time':0, 'z_stag':4},
                          length_unit='m')   

finally ... 

In [None]:
slc = yt.SlicePlot(ds_yt, "z", ("stream", "geopotential_height"))
slc.set_log("all", False)

**Note**: need to use yt coordinate names for yt functions

**What is [geopotential height](https://legacy.climate.ncsu.edu/images/climate/enso/geo_heights.php)?**: 

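* cold air is denser than warm air
* pressure at a point in the atmosphere comes from the weight of the overlying air

geopotential height = the altitude at which a given pressure level is reached (lower in cold, dense air; higher in warm air)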


### yt_xarray chunking

create a test dataset with a dask array:

In [None]:
ds = xr.open_mfdataset(dask_test_ds)
ds

In [None]:
import yt_xarray
yt_ds = ds.yt.load_grid(fields=("gauss",), length_unit='m', chunksizes=51)

In [None]:
# index

In [None]:
import yt
slc = yt.SlicePlot(yt_ds, "z", ("stream", "gauss"))
slc.annotate_grids()
slc.show()

**each yt grid = dask chunk = on disk .nc file**

but chunk alignment is not guaranteed...

In [None]:
import yt_xarray
yt_ds = ds.yt.load_grid(fields=("gauss",), length_unit='m', chunksizes=102)
slc = yt.SlicePlot(yt_ds, "z", ("stream", "gauss"))
slc.annotate_grids()
slc.show()

possible feature? 

but working to auto-align... 

([ChunkWalker prototype](https://github.com/chrishavlin/yt-xarray-dask-sandbox/blob/main/daxryt/chunk_inspector.py)), `._recursive_chonker` walks dask-xr chunks:


In [None]:
ds.gauss.chunksizes

`chunksizes` only gives the number of elements per chunk along each dimension; yt needs the physical left and right edges (and potentially the cell widths) of each chunk.
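
One way to get from chunk sizes to physical edges is sketched below. It is not the ChunkWalker prototype: `chunk_edges` is a hypothetical helper that assumes evenly spaced cell-center coordinates and takes a mapping shaped like the output of `DataArray.chunksizes`.

```python
from itertools import product

import numpy as np

def chunk_edges(chunksizes, coords):
    """Return (left_edge, right_edge, shape) for every chunk.

    chunksizes: dict of dim -> tuple of chunk lengths along that dim.
    coords: dict of dim -> 1D array of evenly spaced cell centers.
    """
    per_dim = []
    for dim, sizes in chunksizes.items():
        centers = np.asarray(coords[dim], dtype=float)
        half = 0.5 * (centers[1] - centers[0])       # assumes uniform spacing
        stops = np.cumsum(sizes)                     # exclusive end index of each chunk
        starts = stops - np.asarray(sizes)           # start index of each chunk
        per_dim.append(
            [
                (centers[i0] - half, centers[i1 - 1] + half, int(i1 - i0))
                for i0, i1 in zip(starts, stops)
            ]
        )
    chunks = []
    for combo in product(*per_dim):                  # one entry per dimension
        left = np.array([entry[0] for entry in combo])
        right = np.array([entry[1] for entry in combo])
        shape = tuple(entry[2] for entry in combo)
        chunks.append((left, right, shape))
    return chunks

chunksizes = {"x": (2, 2), "y": (3,), "z": (1, 1)}
coords = {
    "x": np.array([0.5, 1.5, 2.5, 3.5]),
    "y": np.array([0.5, 1.5, 2.5]),
    "z": np.array([0.5, 1.5]),
}
for left, right, shape in chunk_edges(chunksizes, coords):
    print(left, right, shape)
```

Stretched grids would need per-cell widths instead of a single half-width per dimension.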

## yt_xarray code tour

yt_xarray loads data into the yt Stream frontend via `load_amr_grids`:


```python
yt.load_amr_grids(
    grid_data,  # the data OR FUNCTION for the grid(s)
    data_shp,  # global grid shape, (Nx, Ny, Nz)
    geometry=geom,  # e.g., ('cartesian', ('x', 'z', 'y'))
    bbox=bbox,  # the bounding box
    length_unit=length_unit,
    **kwargs,
)
```    

Form of `grid_data` depends on:

* grid type (uniform, stretched)
* memory management: delayed reads (`use_callable`) vs in memory
* chunking

We'll look at:

* `yt_xarray.accessor.accessor.YtAccessor` : the top level accessor object
* `yt_xarray.accessor._xr_to_yt.Selection` : yt-xr translation, mapping of selections
* `yt_xarray.accessor._readers._get_xarray_reader`: building a function to load the data when needed
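
The "function instead of data" idea can be illustrated with a closure that defers the read until it is called. This is a sketch of the pattern only: `make_reader`, its `(start, shape)` signature, and the read counter are hypothetical and do not mirror the signature that yt or `_get_xarray_reader` actually use.

```python
import numpy as np

def make_reader(source, counter):
    """Build a function that reads one block of `source` only when called.

    `source` stands in for a lazily loaded variable; `counter` records reads
    so the delay is observable. Both are illustrative objects.
    """
    def reader(start, shape):
        index = tuple(slice(i0, i0 + n) for i0, n in zip(start, shape))
        counter["reads"] += 1
        return np.asarray(source[index])
    return reader

source = np.arange(4 * 4 * 2).reshape(4, 4, 2)
counter = {"reads": 0}
read_block = make_reader(source, counter)      # nothing has been read yet

block = read_block(start=(2, 0, 0), shape=(2, 4, 2))   # read happens here
```

Building one such function per grid means each chunk is only pulled from disk when a plot or selection actually touches it.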
