# [Experimental] Demonstrate accessing UM DYAMOND3 simulations from zarr on JASMIN object store

### 11/4/24

* **Second look at simulations for WCRP Hackathon UK Node. These should still be considered experimental/development**
* **Can only retrieve zoom level 10 (n2560/regional) or 9 (n1280)**
* **Stores are not complete: some stores only have the first 12 h; glm.n2560_RAL3p3 has most data, but processing issues mean its time axis is not contiguous**
* Shows the hierarchy of simulations that will be available at the UK node, from global to regional.
* The URLs of the active stores can be seen in the `cat` catalog.
* Contact mark.muetzelfeldt@reading.ac.uk for more info.

## Simulations

* glm: global model. n1280 is approx. 10 km res (stored at zoom 9), n2560 is 5 km (zoom 10). Regional simulations are at 4.4 km (zoom 10).
* Regional: Africa, South East Asia, South America, Cyclic Tropical Channel
* Settings:
    * CoMA9: CoMorph global
    * RAL3: Regional Atmosphere Land 3
    * GAL9: Global Atmosphere Land 9
    * RAL3p3: RAL3.3
    * CoMA9_TBv1: CoMA9 TrailBlazer v1
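The zoom levels above map onto HEALPix resolution as nside = 2**zoom (the relation that `plot_all_fields` inverts with `np.log2(...healpix_nside)`), and HEALPix cells are equal-area, so a rough cell spacing follows from dividing the sphere's area by 12 * nside**2 cells. A small sketch, assuming a spherical Earth of radius 6371 km; the helper names are illustrative only. It gives roughly 12.7 km at zoom 9 and 6.4 km at zoom 10, i.e. the stored grids are somewhat coarser than the native ~10 km (n1280) and ~5 km / 4.4 km (n2560 / regional) model grids.

```python
import math

# Assumption: a spherical Earth with this mean radius.
EARTH_RADIUS_KM = 6371.0


def healpix_nside(zoom: int) -> int:
    """HEALPix nside at a given zoom level: nside = 2**zoom."""
    return 2 ** zoom


def healpix_npix(zoom: int) -> int:
    """Number of HEALPix cells covering the sphere: 12 * nside**2."""
    return 12 * healpix_nside(zoom) ** 2


def approx_spacing_km(zoom: int) -> float:
    """Rough cell spacing: square root of the (equal) area of one cell."""
    cell_area_km2 = 4 * math.pi * EARTH_RADIUS_KM ** 2 / healpix_npix(zoom)
    return math.sqrt(cell_area_km2)


for zoom in (9, 10):
    print(f'zoom {zoom}: nside={healpix_nside(zoom)}, '
          f'npix={healpix_npix(zoom):,}, ~{approx_spacing_km(zoom):.1f} km')
```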

## Technical

* All data stored as healpix, including regional. Regional simulations only store active chunks.
* There are two stores for each zoom level: one for `PT1H` (hourly, 2D) variables and one for `PT3H` (3-hourly, 3D) variables. All simulation names are listed in the `sims` variable.
* Calling `ds = ds.compute()` downloads the data from JASMIN. This can be slow and/or fail with a server error. Try again if this happens.
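* Can be run on JASMIN or anywhere else: call `Catalog(on_jasmin=True)` when running on JASMIN, which uses the `http://hackathon-o.s3.jc.rl.ac.uk` endpoint instead of the external `https://hackathon-o.s3-ext.jc.rl.ac.uk` one.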
* Tested using this Python conda env: https://github.com/digital-earths-global-hackathon/tools/blob/main/python_envs/environment.yaml (with some extra packages).
    * You can install with:
    * `wget https://raw.githubusercontent.com/digital-earths-global-hackathon/tools/refs/heads/main/python_envs/environment.yaml`
    * <edit last line of environment.yaml to be the name of your new env, e.g. hackathon_env>
    * `conda env create -f environment.yaml`
* Not all variables in the standard protocol are present; I have included those that are available.
* I believe the artefact at lon=0 is a plotting issue and that the data are OK.
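Since `ds.compute()` can fail with a server error and simply needs retrying, the download can be wrapped in a small retry loop instead of re-running the cell by hand. A minimal sketch: `compute_with_retry` is a hypothetical helper, and catching every `Exception` is an assumption, because the exact error type raised by the object store is not documented here.

```python
import time


def compute_with_retry(obj, attempts=3, wait_s=10.0):
    """Call obj.compute() (e.g. on a lazy xarray Dataset), retrying on failure.

    Hypothetical helper. Any Exception triggers a retry after `wait_s` seconds;
    if every attempt fails, the last error is re-raised.
    """
    if attempts < 1:
        raise ValueError('attempts must be >= 1')
    for attempt in range(1, attempts + 1):
        try:
            return obj.compute()
        except Exception as err:  # assumption: server errors have no specific type
            if attempt == attempts:
                raise
            print(f'Attempt {attempt}/{attempts} failed ({err!r}); retrying in {wait_s} s')
            time.sleep(wait_s)
```

Usage would look like `ds2d = compute_with_retry(cat(sim='glm.n1280_GAL9_nest', freq='PT1H', zoom=9).to_dataset().sel(time=pd.Timestamp('2020-01-20 03:00')))`. A fixed wait keeps the sketch simple; an increasing back-off may be kinder to the server.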

## Issues

* CTC simulations: I think there is a genuine data issue at lon=0
* No data at zooms 9-0 (n2560/regional) or 8-0 (n1280), although empty zarr stores are present
* No data for most times
* No zarr store for glm.n1280_CoMA9
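Because most times have no data, it can help to find the populated time steps before plotting. A sketch under two assumptions: missing chunks are read back as NaN (the store's fill value), and `times_with_data` (a hypothetical helper) is applied to a single variable, ideally already subset, since on a lazily opened store it downloads that variable. The synthetic array below stands in for a real store.

```python
import numpy as np
import pandas as pd
import xarray as xr


def times_with_data(da):
    """Return the time values at which `da` has at least one non-NaN value.

    Assumes missing chunks are read back as NaN (the store's fill value).
    On a lazily opened store this downloads the variable, so subset first.
    """
    other_dims = [dim for dim in da.dims if dim != 'time']
    has_data = da.notnull().any(dim=other_dims)
    return da['time'].values[has_data.values]


# Synthetic example: 4 hourly steps on 12 cells, only the first 2 populated.
times = pd.date_range('2020-01-20', periods=4, freq=pd.Timedelta(hours=1))
values = np.full((4, 12), np.nan)
values[:2] = 280.0
demo = xr.DataArray(values, dims=('time', 'cell'), coords={'time': times})
print(times_with_data(demo))
```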


In [None]:
import math as maths

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import easygems.healpix as egh

In [None]:
sims = [
    # 'glm.n1280_CoMA9', # currently has no zarr store.
    'glm.n2560_RAL3p3',
    'glm.n1280_GAL9_nest',
    'SAmer_km4p4_RAL3P3.n1280_GAL9_nest',
    'Africa_km4p4_RAL3P3.n1280_GAL9_nest',
    'SEA_km4p4_RAL3P3.n1280_GAL9_nest',
    'SAmer_km4p4_CoMA9_TBv1.n1280_GAL9_nest',
    'Africa_km4p4_CoMA9_TBv1.n1280_GAL9_nest',
    'SEA_km4p4_CoMA9_TBv1.n1280_GAL9_nest',
    'CTC_km4p4_RAL3P3.n1280_GAL9_nest',
    'CTC_km4p4_CoMA9_TBv1.n1280_GAL9_nest'
]
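The simulation names follow a pattern that can be read off the `sims` list: global runs are `glm.<grid>_<config>`, and regional runs are `<region>_km<res>_<config>.<driving grid>_<driving config>`, with `p` standing for the decimal point. A small sketch that splits a name into those parts; `parse_sim_name` and its field names are hypothetical and inferred only from the names above, not part of any catalog API.

```python
def parse_sim_name(sim: str) -> dict:
    """Split a simulation name into its parts (naming convention inferred from `sims`)."""
    head, _, tail = sim.partition('.')
    if head == 'glm':
        # Global run, e.g. 'glm.n2560_RAL3p3' -> grid 'n2560', config 'RAL3p3'.
        grid, config = tail.split('_', 1)
        return {'regional': False, 'domain': 'glm', 'grid': grid, 'config': config}
    # Regional run, e.g. 'Africa_km4p4_CoMA9_TBv1.n1280_GAL9_nest'.
    domain, res, config = head.split('_', 2)
    driving_grid, driving_config = tail.split('_', 1)
    return {
        'regional': True,
        'domain': domain,
        'resolution_km': float(res.removeprefix('km').replace('p', '.')),
        'config': config,
        'driving_grid': driving_grid,
        'driving_config': driving_config,
    }


for name in ['glm.n2560_RAL3p3', 'Africa_km4p4_CoMA9_TBv1.n1280_GAL9_nest']:
    print(name, parse_sim_name(name))
```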

In [None]:
class Item:
    """A single zarr store in the catalog."""
    def __init__(self, url):
        self.url = url

    def to_dataset(self):
        # Lazily open the store: only metadata is read here.
        return xr.open_zarr(self.url)

class Catalog:
    """Really simple Catalog. Checks that args are valid/allowed."""
    allowed_params = {
        'sim': sims,
        'freq': ['PT1H', 'PT3H'],
        'zoom': list(range(11)),
    }

    def __init__(self, on_jasmin=False):
        if on_jasmin:
            # Endpoint for use on JASMIN.
            self.url_tpl = 'http://hackathon-o.s3.jc.rl.ac.uk/sim-data/dev/{sim}/v4/data.healpix.{freq}.z{zoom}.zarr'
        else:
            # External endpoint, for use anywhere else.
            self.url_tpl = 'https://hackathon-o.s3-ext.jc.rl.ac.uk/sim-data/dev/{sim}/v4/data.healpix.{freq}.z{zoom}.zarr'

    def __call__(self, **kwargs):
        # Validate param names and values before using them.
        missing = set(self.allowed_params) - set(kwargs)
        if missing:
            raise ValueError(f'Missing params: {sorted(missing)}')
        for k, v in kwargs.items():
            if k not in self.allowed_params:
                raise ValueError(f'Unknown param: {k}, must be one of {list(self.allowed_params)}')
            if v not in self.allowed_params[k]:
                raise ValueError(f'Unallowed param value: {v}, must be one of {self.allowed_params[k]}')

        # Warn (or fail) for zoom levels that have no data.
        zoom = kwargs['zoom']
        if 'glm.n1280' in kwargs['sim']:
            if zoom == 10:
                raise ValueError('n1280 has no zoom=10')
            elif zoom <= 8:
                print(f'WARNING: no data for zoom={zoom}')
        elif zoom <= 9:
            print(f'WARNING: no data for zoom={zoom}')

        return Item(self.url_tpl.format(**kwargs))

In [None]:
cat = Catalog()
# Show example URL
cat(sim='glm.n1280_GAL9_nest', freq='PT1H', zoom=9).url
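If the store URLs are needed outside this notebook (e.g. to pass to another tool), they can be listed without building a `Catalog`. A sketch: the template is copied from the external endpoint in `Catalog`, and the hypothetical `data_zoom` helper mirrors the zoom choice used in the loop that builds `dss` (zoom 9 for n1280 global, zoom 10 otherwise), since only those zooms currently contain data.

```python
# Copied from the external endpoint in Catalog.
URL_TPL = 'https://hackathon-o.s3-ext.jc.rl.ac.uk/sim-data/dev/{sim}/v4/data.healpix.{freq}.z{zoom}.zarr'


def data_zoom(sim):
    """Zoom level that currently holds data: 9 for glm.n1280, 10 otherwise."""
    return 9 if 'glm.n1280' in sim else 10


def store_urls(sim_names, freqs=('PT1H', 'PT3H')):
    """Map (sim, freq) to the URL of its populated store."""
    return {
        (sim, freq): URL_TPL.format(sim=sim, freq=freq, zoom=data_zoom(sim))
        for sim in sim_names
        for freq in freqs
    }


urls = store_urls(['glm.n1280_GAL9_nest', 'SEA_km4p4_RAL3P3.n1280_GAL9_nest'])
for (sim, freq), url in urls.items():
    print(f'{sim:40s} {freq}: {url}')
```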

In [None]:
# Open a dataset.
ds = cat(sim='glm.n1280_GAL9_nest', freq='PT1H', zoom=9).to_dataset()

In [None]:
# Explore dataset. No data downloaded at this point, only metadata.
ds
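Because `ds` is opened lazily, it is cheap to check how much data a selection would pull before calling `.compute()`. A sketch with a hypothetical helper: it reports the uncompressed in-memory size from shape and dtype (no data is read for dask-backed variables), so the actual download of compressed zarr chunks is likely to be smaller. The synthetic dataset stands in for something like `ds.sel(time=pd.Timestamp('2020-01-20 03:00'))`.

```python
import numpy as np
import xarray as xr


def variable_sizes_mb(ds):
    """Uncompressed size in MB of each data variable, from shape and dtype.

    For a lazily opened (dask-backed) store this reads no data.
    """
    return {name: da.nbytes / 1e6 for name, da in ds.data_vars.items()}


# Synthetic stand-in: one time step of float32 `tas` on the full zoom-9 grid.
npix_z9 = 12 * 4 ** 9
demo = xr.Dataset({'tas': (('time', 'cell'), np.zeros((1, npix_z9), dtype='float32'))})
sizes = variable_sizes_mb(demo)
print(sizes, f'total: {sum(sizes.values()):.1f} MB')
```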

In [None]:
# Quick plot of global 1.5 m air temperature (tas) at the first time step
egh.healpix_show(ds.isel(time=0).tas)

In [None]:
def plot_all_fields(ds_plot):
    """Plot all fields for a given dataset. Assumes that each field is 2D - i.e. sel(time=..., [pressure=...]) has been applied"""
    zoom = int(np.log2(ds_plot.crs.attrs['healpix_nside']))
    projection = ccrs.Robinson(central_longitude=0)
    rows = maths.ceil(len(ds_plot.data_vars) / 4)
    fig, axes = plt.subplots(rows, 4, figsize=(30, rows * 20 / 6), subplot_kw={'projection': projection}, layout='constrained')
    # Take the title metadata from ds_plot itself, not from the global `ds`.
    sim = ds_plot.attrs.get('simulation', '')
    # After sel(time=...), time is a scalar coordinate.
    time = pd.Timestamp(np.atleast_1d(ds_plot.time.values)[0])
    if 'pressure' in ds_plot.coords:
        fig.suptitle(f'{sim} z{zoom} @{float(ds_plot.pressure)}hPa')
    else:
        fig.suptitle(f'{sim} z{zoom}')

    for ax, (name, da) in zip(axes.flatten(), ds_plot.data_vars.items()):
        if abs(da.max() + da.min()) / (da.max() - da.min()) < 0.5:
            # data looks like it needs a diverging cmap.
            # figure out some nice bounds.
            pl, pu = np.percentile(da.values[~np.isnan(da.values)], [2, 98])
            vmax = np.abs([pl, pu]).max()
            kwargs = dict(
                cmap='bwr',
                vmin=-vmax,
                vmax=vmax,
            )
        else:
            kwargs = {}
        ax.set_title(f'time: {time} - {name}')
        ax.set_global()
        im = egh.healpix_show(da, ax=ax, **kwargs)
        long_name = da.attrs.get('long_name', name)
        plt.colorbar(im, ax=ax, label=f'{long_name} ({da.attrs.get("units", "-")})')
        ax.coastlines()

    # Hide unused panels in the last row.
    for ax in axes.flatten()[len(ds_plot.data_vars):]:
        ax.set_visible(False)
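The colour-scale choice in `plot_all_fields` (and `plot_var`) rests on one ratio: for data spanning [min, max], |max + min| / (max - min) is 0 when the range is centred on zero and at least 1 when the data are single-signed, so values below 0.5 are treated as diverging and given a symmetric `bwr` scale from the 2nd/98th percentiles. A sketch that pulls this into one testable function; `colour_kwargs` is a hypothetical refactoring, and it also guards the constant-field and all-NaN cases that the inline version does not.

```python
import numpy as np


def colour_kwargs(values, pct=(2, 98), threshold=0.5):
    """Pick healpix_show kwargs using the diverging-colormap heuristic.

    Returns a symmetric 'bwr' colour scale when the range of the finite values
    is roughly centred on zero, otherwise an empty dict (default colormap).
    """
    finite = np.asarray(values, dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        return {}
    vmin, vmax = finite.min(), finite.max()
    if vmax == vmin:
        return {}  # constant field: the ratio below would divide by zero
    if abs(vmax + vmin) / (vmax - vmin) < threshold:
        lo, hi = np.percentile(finite, pct)
        bound = max(abs(lo), abs(hi))
        return {'cmap': 'bwr', 'vmin': -bound, 'vmax': bound}
    return {}
```

Inside the plotting loop this would replace the `if`/`else` block with `kwargs = colour_kwargs(da.values)`.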

In [None]:
# Download the requested data for plotting.
ds3d = cat(sim='glm.n1280_GAL9_nest', freq='PT3H', zoom=9).to_dataset().sel(time=pd.Timestamp('2020-01-20 03:00'), pressure=500).compute()

In [None]:
plot_all_fields(ds3d)

In [None]:
ds2d = cat(sim='glm.n1280_GAL9_nest', freq='PT1H', zoom=9).to_dataset().sel(time=pd.Timestamp('2020-01-20 03:00')).compute()

In [None]:
plot_all_fields(ds2d)

In [None]:
ds_africa = cat(sim='Africa_km4p4_CoMA9_TBv1.n1280_GAL9_nest', freq='PT1H', zoom=10).to_dataset().sel(time=pd.Timestamp('2020-01-20 03:00')).compute()

In [None]:
plot_all_fields(ds_africa)

In [None]:
# Get dataset for all available sims.
dss = {}
for sim in sims:
    # Only zoom 9 (glm.n1280) or zoom 10 (n2560/regional) currently contain data.
    zoom = 9 if 'glm.n1280' in sim else 10
    try:
        dss[sim] = cat(sim=sim, freq='PT1H', zoom=zoom).to_dataset()
    except Exception as err:
        print(f'Could not load {sim}: {err}')

In [None]:
# Sort to nicer order for plotting.
def sorter(sim):
    if 'glm.n2560' in sim:
        return 'A'
    elif 'glm.n1280' in sim:
        return 'AA'
    else:
        return sim

dss = {s: dss[s] for s in sorted(dss.keys(), key=sorter)}
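`sorter` works because 'A' < 'AA' < 'Africa...' in string order, which relies on every regional name sorting after 'AA'. A tuple key states the intended order directly (global n2560, then global n1280, then regional runs alphabetically); `sim_sort_key` is a hypothetical drop-in alternative.

```python
def sim_sort_key(sim):
    """Sort key: global n2560 first, then global n1280, then regional runs A-Z."""
    if sim.startswith('glm.n2560'):
        rank = 0
    elif sim.startswith('glm.n1280'):
        rank = 1
    else:
        rank = 2
    return (rank, sim)


names = ['SEA_km4p4_RAL3P3.n1280_GAL9_nest', 'glm.n1280_GAL9_nest',
         'Africa_km4p4_RAL3P3.n1280_GAL9_nest', 'glm.n2560_RAL3p3']
print(sorted(names, key=sim_sort_key))
```

It would be used as `dss = {s: dss[s] for s in sorted(dss, key=sim_sort_key)}`.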

In [None]:
def plot_var(plot_dss, var, time, **plot_kwargs):
    """Plot given var from each dataset at the time step nearest to `time`."""
    rows = maths.ceil(len(plot_dss) / 3)
    projection = ccrs.Robinson(central_longitude=0)
    fig, axes = plt.subplots(rows, 3, figsize=(30, 5 * rows), subplot_kw={'projection': projection}, layout='constrained')

    for ax, (name, ds) in zip(axes.flatten(), plot_dss.items()):
        # Download the requested time (nearest available step) for this dataset.
        da = ds[var].sel(time=time, method='nearest').compute()
        plot_time = pd.Timestamp(np.atleast_1d(da.time.values)[0])

        if abs(da.max() + da.min()) / (da.max() - da.min()) < 0.5:
            # data looks like it needs a diverging cmap.
            # figure out some nice bounds.
            pl, pu = np.percentile(da.values[~np.isnan(da.values)], [2, 98])
            vmax = np.abs([pl, pu]).max()
            kwargs = dict(
                cmap='bwr',
                vmin=-vmax,
                vmax=vmax,
            )
        else:
            kwargs = {}
        kwargs.update(plot_kwargs)
        ax.set_title(f'time: {plot_time} - {name}')
        ax.set_global()
        if ds.attrs.get('regional', False):
            # Display the active chunks for any regional data.
            # np.ones only needs the number of cells, so no tas data is read.
            ds_ones = xr.Dataset({'ones': (['cell'], np.ones(ds.sizes['cell']))}, coords={'cell': ds.cell}).assign_coords(crs=ds.crs)
            egh.healpix_show(ds_ones.ones, ax=ax)
        im = egh.healpix_show(da, ax=ax, **kwargs)
        long_name = da.attrs.get('long_name', var)
        plt.colorbar(im, ax=ax, label=f'{long_name} ({da.attrs.get("units", "-")})')
        ax.coastlines()

    # Hide unused panels in the last row.
    for ax in axes.flatten()[len(plot_dss):]:
        ax.set_visible(False)
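The active chunks drawn for regional data can also be reasoned about numerically. In HEALPix nested ordering, the ancestor of cell i at k zoom levels coarser is i // 4**k (drop 2 bits per level), so a chunk of 4**k consecutive cells covers exactly one coarser cell. A sketch assuming the stores use nested ordering (worth checking in `ds.crs.attrs`) and chunks of 4**k cells along `cell` (the real chunk size could be read from the store's encoding); both helper names are hypothetical.

```python
import numpy as np


def parent_cells(cells, zoom, parent_zoom):
    """Nested-ordering ancestor of each cell at a coarser zoom (2 bits per level)."""
    return np.asarray(cells) >> (2 * (zoom - parent_zoom))


def active_chunks(cells, chunk_size):
    """Sorted indices of the chunks of `chunk_size` consecutive cells containing `cells`."""
    return np.unique(np.asarray(cells) // chunk_size)


cells = np.array([0, 1, 5, 4 ** 5, 4 ** 5 + 3])
print(parent_cells(cells, zoom=10, parent_zoom=5))  # ancestors at zoom 5
print(active_chunks(cells, chunk_size=4 ** 5))      # chunk indices along `cell`
```

With `chunk_size = 4**k`, the active chunks are exactly the distinct ancestors at zoom `zoom - k`, which is why the outline in the regional plots looks like a jagged patchwork of coarse HEALPix cells.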

In [None]:
# Display tas/air_temperature for all available sims.
# For regional data, this also shows the active chunks (jagged purple outline). Only active chunks are saved, to minimize memory requirements on the host computer when loading data.
plot_var(dss, 'tas', pd.Timestamp('2020-01-20 10:00'), vmin=215, vmax=310)