# Example: Predictions for HRES at 0.1 deg

In this example, we will download HRES data for 11 May 2022 from the [Research Data Archive](https://rda.ucar.edu/datasets/d113001/#) at 0.1 degree resolution and run Aurora on this data. We will use the version of Aurora that was fine-tuned on IFS HRES 0.1 degree.

Running this notebook requires additional Python packages. You can install these as follows:

```
pip install cdsapi xarray zarr netcdf4 matplotlib       #TODO: Maybe remove cdsapi later.
```

Installing `cfgrib` from conda-forge is the easiest option:
```
conda install -c conda-forge cfgrib
```

## Troubleshooting
If you get an error regarding `cfgrib`, then make sure you have the `eccodes` library installed correctly:
```
apt-get install libeccodes-tools
```


## Downloading the Data

To start, we download the data from the [Research Data Archive](https://rda.ucar.edu/datasets/d113001/#).

In [None]:
from pathlib import Path

import xarray as xr

from aurora.util import download_hres_rda_atmos, download_hres_rda_surf

# Local cache directory: every file fetched below is written here. `~` is
# expanded to the home directory and the folder is created if it is missing.
download_path = Path("~/downloads/hres_0.1").expanduser()
download_path.mkdir(parents=True, exist_ok=True)

# Date to download, given as zero-padded strings. All times available for
# this day are downloaded.
year, month, day = "2022", "05", "11"

# Short variable name -> number identifying that variable in the RDA request.
var_nums = {
    # Surface-level variables.
    "msl": "151",  # Mean sea level pressure
    "10u": "165",  # 10m u-component of wind
    "10v": "166",  # 10m v-component of wind
    "2t": "167",  # 2m temperature
    # Atmospheric variables.
    "z": "129",  # Geopotential
    "t": "130",  # Temperature
    "u": "131",  # u-component of wind (atmos)
    "v": "132",  # v-component of wind (atmos)
    "q": "133",  # Specific humidity (atmos)
}

# Which of the variables above are fetched by the surface downloader and
# which by the atmospheric downloader.
surface_vars = ["msl", "10u", "10v", "2t"]
atmos_vars = ["z", "t", "u", "v", "q"]
# Fetch every surface variable into the cache directory. A variable whose
# GRIB file is already present is reported and skipped.
for variable in surface_vars:
    surf_file = download_path / f"{variable}_{year}_{month}_{day}.grb"
    if surf_file.exists():
        print(f"{variable} already downloaded")
        continue
    download_hres_rda_surf(
        save_dir=download_path,
        year=year,
        month=month,
        day=day,
        variable=variable,
        var_dict=var_nums,
    )

# Download atmospheric variables. We write the downloaded data to cache.
# Each variable has 4 times per day, each of which is a separate file.
# This will take a few minutes.
for variable in atmos_vars:
    # One file per time of day; the same check-then-download step is applied
    # to each, so it is written once and looped over instead of repeated.
    for timeofday in ("00", "06", "12", "18"):
        atmos_file = download_path / f"{variable}_{year}_{month}_{day}_{timeofday}.grb"
        if atmos_file.exists():
            print(f"{variable} at time {timeofday} already downloaded")
            continue
        download_hres_rda_atmos(
            save_dir=download_path,
            year=year,
            month=month,
            day=day,
            variable=variable,
            var_dict=var_nums,
            timeofday=timeofday,
        )

In [None]:
import sys

# Merge the per-variable surface GRIB files into one NetCDF file. The merged
# file acts as a cache: if it already exists, nothing is read or written.
surf_nc_path = download_path / f"{year}_{month}_{day}-surface-level-0.1deg.nc"
if not surf_nc_path.exists():
    surf_parts = [
        xr.open_dataset(download_path / f"{name}_{year}_{month}_{day}.grb", engine="cfgrib")
        for name in ("msl", "10u", "10v", "2t")
    ]
    ds_surf = xr.merge(surf_parts)
    ds_surf.to_netcdf(surf_nc_path)

# Combine the atmospheric data into a single dataset. This will take a few minutes.
# TODO: state how many GB of free disk space are needed to store the data.
atmos_nc_path = download_path / f"{year}_{month}_{day}-atmospheric-0.1deg.nc"
if not atmos_nc_path.exists():
    per_variable = []
    for name in ("q", "t", "u", "v", "z"):
        # The four files of one variable hold the same field at different
        # times, so these (and only these) are stacked along `time`.
        snapshots = [
            xr.open_dataset(
                download_path / f"{name}_{year}_{month}_{day}_{timeofday}.grb",
                engine="cfgrib",
            )
            for timeofday in ("00", "06", "12", "18")
        ]
        per_variable.append(xr.concat(snapshots, dim="time"))

    # Different variables are combined side by side with `merge`. Concatenating
    # all 20 files along `time` would mix variables on one axis instead of
    # producing one 4-step time axis shared by q, t, u, v and z.
    ds_atmos = xr.merge(per_variable)
    ds_atmos.to_netcdf(atmos_nc_path)

# Re-open the combined files from disk so that the same datasets are used
# regardless of whether they were just created or were already cached.
ds_surf = xr.open_dataset(
    download_path / f"{year}_{month}_{day}-surface-level-0.1deg.nc", engine="netcdf4"
)
ds_atmos = xr.open_dataset(
    download_path / f"{year}_{month}_{day}-atmospheric-0.1deg.nc", engine="netcdf4"
)

# `Dataset.nbytes` adds up the bytes of all variables in the dataset, whereas
# `sys.getsizeof` only measures the Python wrapper object. The total is computed
# up front so the f-string stays on one line, which every Python 3 version with
# f-strings accepts.
total_gb = (ds_surf.nbytes + ds_atmos.nbytes) / 1e9
print(f"Total size of data for {year}-{month}-{day} is {total_gb:.2f} GB!")

## Downloading Static Variables from ERA5 Data

The static variables are not available from the Research Data Archive, so we need to download them from ERA5, just like we did [in the example for ERA5](example_era5.ipynb#downloading-the-data) and [the example for HRES T0](example_hres_t0.ipynb#downloading-the-data).
To do so, register an account with the [Climate Data Store](https://cds.climate.copernicus.eu/) and create `$HOME/.cdsapirc` with the following content:

```
url: https://cds.climate.copernicus.eu/api/v2
key: <UID>:<API key>
```

You can find your UID and API key on your account page.

In [None]:
from pathlib import Path

import cdsapi

# Data will be downloaded here.
download_path = Path("~/downloads/hres_0.1").expanduser()
download_path.mkdir(parents=True, exist_ok=True)

static_path = download_path / "static.nc"

# Download the static variables, unless they are already in the cache.
if not static_path.exists():
    # The client is created only when a request is actually made, so a cached
    # `static.nc` can be reused without Climate Data Store credentials.
    c = cdsapi.Client()
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "geopotential",
                "land_sea_mask",
                "soil_type",
            ],
            # The requested fields are static, so a single arbitrary time suffices.
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": "00:00",
            "format": "netcdf",
        },
        str(static_path),
    )
print("Static variables downloaded!")

## Preparing a Batch

We convert the downloaded data to an `aurora.Batch`, which is what the model requires.

In [None]:
import torch
import xarray as xr

from aurora import Batch, Metadata

static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
surf_vars_ds = xr.open_dataset(
    download_path / f"{year}_{month}_{day}-surface-level-0.1deg.nc", engine="netcdf4"
)
atmos_vars_ds = xr.open_dataset(
    download_path / f"{year}_{month}_{day}-atmospheric-0.1deg.nc", engine="netcdf4"
)

i = 1  # Select this time index in the downloaded data.

# The model is given the state at time index `i` together with the previous
# state at `i - 1`. Indexing with `[[i - 1, i]]` selects these two time points
# along the first axis (assumed to be time, as in the positional time indexing
# used for plotting) and `[None]` inserts a batch dimension of size one.
# Everything is converted to tensors, like the static variables below.
batch = Batch(
    surf_vars={
        "2t": torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]),
        "10u": torch.from_numpy(surf_vars_ds["u10"].values[[i - 1, i]][None]),
        "10v": torch.from_numpy(surf_vars_ds["v10"].values[[i - 1, i]][None]),
        "msl": torch.from_numpy(surf_vars_ds["msl"].values[[i - 1, i]][None]),
    },
    static_vars={
        # The static variables are constant, so we just get them for the first time.
        # NOTE(review): these come from the ERA5 download; confirm that they are on
        # the same grid as the HRES fields above, otherwise regrid them first.
        "z": torch.from_numpy(static_vars_ds["z"].values[0]),
        "slt": torch.from_numpy(static_vars_ds["slt"].values[0]),
        "lsm": torch.from_numpy(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": torch.from_numpy(atmos_vars_ds["t"].values[[i - 1, i]][None]),
        "u": torch.from_numpy(atmos_vars_ds["u"].values[[i - 1, i]][None]),
        "v": torch.from_numpy(atmos_vars_ds["v"].values[[i - 1, i]][None]),
        "q": torch.from_numpy(atmos_vars_ds["q"].values[[i - 1, i]][None]),
        "z": torch.from_numpy(atmos_vars_ds["z"].values[[i - 1, i]][None]),
    },
    metadata=Metadata(
        lat=torch.from_numpy(surf_vars_ds.latitude.values.copy()),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
        # `datetime.datetime`s. Note that this needs to be a tuple of length one:
        # one value for every batch element.
        time=(surf_vars_ds.time.values.astype("datetime64[s]").tolist()[i],),
        # NOTE(review): assumes the pressure-level coordinate is called `level`;
        # confirm against the coordinate name cfgrib gives the combined file.
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.level.values),
    ),
)

## Loading and Running the Model

Finally, we are ready to load and run the model and visualise the predictions. We perform a roll-out for two steps, which produces predictions for hours 12:00 and 18:00.

In [None]:
from aurora import Aurora, AuroraHighRes, rollout

# The 0.1 degree fine-tuned checkpoint belongs to the high-resolution
# architecture, so it is loaded into `AuroraHighRes` rather than the default
# `Aurora` model.
model = AuroraHighRes()
model.load_checkpoint("wbruinsma/aurora", "aurora-0.1-finetuned.ckpt")

model.eval()
model = model.to("cuda")

# Run two roll-out steps without gradient tracking. Each prediction is moved
# to the CPU as soon as it is produced.
with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]

# Move the model off the GPU again now that inference is done.
model = model.to("cpu")

In [None]:
import matplotlib.pyplot as plt

# One row per roll-out step: prediction on the left, downloaded HRES data for
# the same time on the right.
fig, ax = plt.subplots(2, 2, figsize=(12, 6.5))

# `step` is used instead of `i` so that the time index chosen when building
# the batch is not overwritten.
for step, (pred, (ax_pred, ax_ref)) in enumerate(zip(preds, ax)):
    # Kelvin -> degrees Celsius for display.
    ax_pred.imshow(pred.surf_vars["2t"][0, 0].numpy() - 273.15, vmin=-50, vmax=50)
    ax_pred.set_ylabel(str(pred.metadata.time[0]))
    if step == 0:
        ax_pred.set_title("Aurora Prediction")
    ax_pred.set_xticks([])
    ax_pred.set_yticks([])

    # The batch was built at time index 1, so roll-out step `step` corresponds
    # to time index `2 + step` of the downloaded data. The 2m temperature is
    # stored under "t2m", the key also used when building the batch. The batch
    # was built from this data without flipping the latitude axis, so the
    # reference is not flipped either and both panels share one orientation.
    ref = surf_vars_ds["t2m"][2 + step].values
    ax_ref.imshow(ref - 273.15, vmin=-50, vmax=50)
    if step == 0:
        ax_ref.set_title("HRES Analysis")
    ax_ref.set_xticks([])
    ax_ref.set_yticks([])

plt.tight_layout()