<br>
<a href="https://www.nvidia.com/en-us/training/">
    <div style="width: 55%; background-color: white; margin-top: 50px;">
    <img src="https://dli-lms.s3.amazonaws.com/assets/general/nvidia-logo.png"
         width="400"
         height="186"
         style="margin: 0px -25px -5px; width: 300px"/>
    </div>
</a>
<h1 style="line-height: 1.4;"><font color="#76b900"><b>Applying AI Weather Models With NVIDIA Earth-2</b></font></h1>
<h2><b>Part 3:</b> Downscaling</h2>
<br>

The weather data generated by global AI models like FourCastNet (SFNO) is usually confined to a 0.25° grid, which corresponds to roughly 25 km in the east-west direction at the Tropics of Cancer and Capricorn. The primary reason global AI weather models use this grid is that the vast ERA5 data archive is readily available at this resolution and can be leveraged for large-scale data-driven training. Numerical assimilation and forecast systems often operate at higher resolutions. For example, the global ECMWF HRES forecast is available on a 0.1° grid (approximately 10 km) and on a cubic octahedral (O1280) grid (approximately 9 km). Regional models, like the Weather Research and Forecasting (WRF) model or the Icosahedral Nonhydrostatic (ICON) model, are often run at resolutions between 1 km and 3 km. Many applications require such kilometer-scale or even sub-kilometer resolutions.
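
The conversion from degrees to kilometers follows from spherical geometry: one degree of latitude spans about 111 km everywhere, while one degree of longitude shrinks with the cosine of the latitude. A minimal sketch, assuming a spherical Earth with a mean radius of 6,371 km:

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius; assumes a spherical Earth


def grid_spacing_km(resolution_deg: float, lat_deg: float) -> tuple[float, float]:
    """Return (north-south, east-west) spacing in km of a regular lat-lon grid."""
    km_per_deg = 2 * math.pi * EARTH_RADIUS_KM / 360  # ~111.2 km per degree of arc
    dy = resolution_deg * km_per_deg  # meridional spacing does not depend on latitude
    dx = dy * math.cos(math.radians(lat_deg))  # zonal spacing shrinks toward the poles
    return dy, dx


for lat in (0.0, 23.44, 45.0):
    dy, dx = grid_spacing_km(0.25, lat)
    print(f"0.25° at {lat:5.2f}°: {dy:.1f} km N-S, {dx:.1f} km E-W")
```

At the Tropics (about 23.44°), a 0.25° cell measures roughly 27.8 km north-south and 25.5 km east-west.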

Training and running high-resolution models at a global scale is resource intensive. Alternatively, we can use statistical downscaling models, which are trained to map from lower resolutions (like 25 km) to higher resolutions (like 2 km). In numerical weather prediction, too, statistical downscaling is a common alternative to dynamical downscaling (i.e., running regional numerical models conditioned on global inputs). Recent advances in AI now make it possible to build much more powerful statistical downscaling models.

[CorrDiff](https://arxiv.org/abs/2309.15214) employs a two-step approach to simultaneously map from low- to high-resolution data and synthesize new variables not present in the input. The first step uses a UNet regression model to predict the conditional mean of the output field. This helps to deal with the large distribution shift between inputs and outputs, such as wind speed peaks hidden between coarse grid points. The second step then uses a diffusion model to generate the residual on top of this mean, recovering a physically realistic representation. Diffusion models are trained to iteratively remove noise and can therefore reveal fine details that the regression alone cannot capture.
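
The two-step structure can be summarized as "output = deterministic mean + stochastic residual". In the minimal sketch below, both networks are replaced by hypothetical stand-ins: nearest-neighbor upsampling plays the role of the UNet regression, and Gaussian noise plays the role of the residual that the diffusion model would generate. Only the decomposition itself reflects CorrDiff; nothing here is an actual diffusion model.

```python
import numpy as np


def regression_mean(x_coarse: np.ndarray, factor: int) -> np.ndarray:
    """Hypothetical stand-in for the UNet: nearest-neighbor upsampling."""
    return np.kron(x_coarse, np.ones((factor, factor)))


def sample_residual(shape, rng: np.random.Generator, scale: float = 0.5) -> np.ndarray:
    """Hypothetical stand-in for the diffusion model: a random residual field."""
    return scale * rng.standard_normal(shape)


def downscale(x_coarse: np.ndarray, factor: int, n_samples: int, seed: int = 0) -> np.ndarray:
    """Return n_samples high-resolution fields, each = shared mean + its own residual."""
    rng = np.random.default_rng(seed)
    mean = regression_mean(x_coarse, factor)  # step 1: computed once
    residuals = [sample_residual(mean.shape, rng) for _ in range(n_samples)]  # step 2: per sample
    return np.stack([mean + r for r in residuals])


coarse = np.array([[20.0, 22.0], [24.0, 26.0]])
samples = downscale(coarse, factor=4, n_samples=4)
print(samples.shape)  # (4, 8, 8)
```

Because the mean is shared, all samples agree on the large-scale pattern and differ only in the small-scale residual.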

![CorrDiff Taiwan](./images/corrdiff.jpg "CorrDiff Taiwan")

In the third part of our workshop, we will develop a workflow applying CorrDiff, trained over Taiwan, to a forecast produced by FourCastNet (SFNO). 

In [None]:
from dotenv import load_dotenv

_ = load_dotenv()

from collections import OrderedDict

import numpy as np
import torch
import xarray as xr
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.data import fetch_data, GFS
from earth2studio.lexicon import GFSLexicon
from earth2studio.io import KVBackend
from earth2studio.models.dx import CorrDiffTaiwan
from earth2studio.models.px import SFNO
from earth2studio.utils.time import to_time_array
from tqdm import tqdm

from plot import plot_downscaled_forecast, plot_downscaled_samples, plot_downscaling, plot_pop, plot_pop_t2m
from utils import make_quarter_degree

%load_ext autoreload
%autoreload 2
%matplotlib inline

## Inference

The version of CorrDiff covering Taiwan was developed together with the Taiwan Central Weather Authority (CWA), which also provided the high-resolution regional weather data used for training. The model takes 12 variables on a 25 km grid and produces four variables on a 2 km grid. It is available through Earth2Studio and can be loaded exactly like the models we have dealt with before.

In [None]:
corrdiff = CorrDiffTaiwan.load_model(CorrDiffTaiwan.load_default_package())

This time, we want to develop an inference workflow from scratch instead of using one of the predefined workflows in Earth2Studio. We will couple FourCastNet (SFNO) with CorrDiff Taiwan so that FourCastNet (SFNO) produces a forecast at 25 km and CorrDiff Taiwan downscales it to 2 km. We can load our forecast model as before.

In [None]:
fcn = SFNO.load_model(SFNO.load_default_package())

For accelerated inference, we move our models to the GPU.

In [None]:
device = torch.device("cuda")

fcn = fcn.to(device)
corrdiff = corrdiff.to(device)

We will initialize our forecast from GFS data instead of ERA5 data this time. Because GFS data is published in near real time, whereas ERA5 becomes available only after a delay of several days, this makes it easy to switch to a live forecasting setup later. Our forecast covers the period when a heatwave hit Taiwan in July 2024.

In [None]:
gfs = GFS()
start_time = np.datetime64("2024-07-01 12:00:00")

Data retrieval works as before, with the help of `fetch_data`. We pass our data source, the start time, input variables, and lead time (which is 0 hours for the input). We also specify our GPU as the target device so that the data will already be available to our models once we start inference.

In [None]:
input_coords = fcn.input_coords()

x, coords = fetch_data(
    source=gfs,
    time=to_time_array([start_time]),
    variable=input_coords["variable"],
    lead_time=input_coords["lead_time"],
    device=device
)

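Next, let’s define how far into the future we want to forecast. We will cover 12 time intervals of 6 hours each, i.e., 72 hours in total; remember that the pretrained version of FourCastNet (SFNO) produces forecasts in steps of 6 hours. The lead times also include the initial condition at 0 hours, so the array below has 13 entries.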

In [None]:
nsteps = 12
lead_time = np.array([np.timedelta64(6 * i, "h") for i in range(nsteps + 1)])

CorrDiff is a diffusion model, so it can produce a distribution of output scenarios. This allows us to create a high-resolution ensemble from a single low-resolution input and get a probabilistic view of the small-scale weather. We set the number of samples per time step to 4. Note that we will create a deterministic forecast and only produce an ensemble during downscaling. To take things further, we could use ensembles during both forecasting and downscaling.

In [None]:
corrdiff.number_of_samples = 4

Now, we can set up our data store.

In [None]:
io = KVBackend()

We tell the IO backend upfront what kind of data to expect so it can store the outputs efficiently during inference. For the coordinates, we supply the start times (a single start time in our case), the lead times (1 initial condition plus 12 forecast steps), the number of samples, and the geographic coordinates. Finally, we create one array per output variable.

In [None]:
output_coords = corrdiff.output_coords(corrdiff.input_coords())

In [None]:
io_coords = OrderedDict(
    {
        "time": to_time_array([start_time]),
        "lead_time": lead_time,
        "sample": output_coords["sample"],
        "lat": output_coords["lat"],
        "lon": output_coords["lon"],
    }
)
io.add_array(io_coords, output_coords["variable"])
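
The coordinates passed to `add_array` fix the shape of every stored array, so the memory footprint can be estimated up front. A minimal sketch, assuming float32 values and a 448 × 448 output grid (both are assumptions here; inspect `output_coords` for the actual grid size):

```python
import math


def store_size_bytes(dim_sizes: dict, n_variables: int, bytes_per_value: int = 4) -> int:
    """Size of n_variables arrays that all share the same coordinate dimensions."""
    return math.prod(dim_sizes.values()) * n_variables * bytes_per_value


# Assumed sizes: 1 start time, 13 lead times (0 h plus 12 steps), 4 samples,
# a 448 x 448 grid, and the 4 CorrDiff Taiwan output variables.
sizes = {"time": 1, "lead_time": 13, "sample": 4, "lat": 448, "lon": 448}
total = store_size_bytes(sizes, n_variables=4)
print(f"{total / 1e6:.0f} MB")  # 167 MB
```

The footprint scales linearly with the number of samples and forecast steps, which is worth keeping in mind when the store is held in memory.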

We are now ready to run our inference workflow. In the code below, we use `map_coords` to select the variables and grid points that FourCastNet (SFNO) and CorrDiff expect as input. The forecast is handled by an iterator that first returns the initial condition and then advances one 6-hour step at a time, and we apply CorrDiff immediately at each time step.

In [None]:
x, coords = map_coords(x, coords, fcn.input_coords())
fc_iterator = fcn.create_iterator(x, coords)

with tqdm(total=nsteps + 1, desc="Running inference") as pbar:
    for step, (x_i, coords_i) in enumerate(fc_iterator):
        x_i, coords_i = map_coords(x_i, coords_i, corrdiff.input_coords())
        x_i, coords_i = corrdiff(x_i, coords_i)
        io.write(*split_coords(x_i, coords_i))
        pbar.update(1)
        if step == nsteps:
            break

hi_res = io.to_xarray()  # load as xarray Dataset
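
The loop above relies on the iterator returning the initial condition first and then one state per 6-hour step, which is why the progress bar counts `nsteps + 1` iterations and the loop stops explicitly once `step == nsteps`. The pure-Python sketch below reproduces that control flow with hypothetical stand-ins (`toy_forecast_iterator`, `toy_downscale`) for `fcn.create_iterator` and `corrdiff`; it is not the Earth2Studio implementation.

```python
from itertools import count


def toy_forecast_iterator(x0: float, step_hours: int = 6):
    """Hypothetical stand-in for fcn.create_iterator; yields (state, lead time in hours)."""
    x = x0
    for k in count():  # endless rollout: the caller decides when to stop
        yield x, k * step_hours  # k == 0 is the initial condition itself
        x = x + 1.0  # stand-in for advancing the model by one step


def toy_downscale(x: float, n_samples: int = 4) -> list:
    """Hypothetical stand-in for corrdiff: one output per sample for each input state."""
    return [x] * n_samples


toy_nsteps = 12
written = []  # stand-in for io.write
for step, (state, lead_h) in enumerate(toy_forecast_iterator(0.0)):
    written.append((lead_h, toy_downscale(state)))
    if step == toy_nsteps:
        break

print(len(written), written[-1][0])  # 13 72
```

Because a rollout iterator like this one has no natural end, the explicit `break` is what limits the forecast to the requested number of steps.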

## Analysis

For comparison, we now also retrieve the corresponding low-resolution GFS forecast for the same start time and lead times. The code should look familiar from the previous part. We convert the output from `fetch_data` (a PyTorch tensor and the corresponding coordinates) to an xarray `Dataset` for convenience and limit it to the window covered by CorrDiff Taiwan.

In [None]:
lo_res = xr.DataArray(
    *fetch_data(
        source=gfs,
        time=to_time_array([start_time]),
        # Keep only the output variables that GFS also provides
        variable=[v for v in output_coords["variable"] if v in GFSLexicon.VOCAB],
        lead_time=hi_res.lead_time.values,
    )
).to_dataset("variable")

# Limit to the window covered by our model
lat_from, lat_to = corrdiff.input_coords()["lat"][[0, -1]]
lon_from, lon_to = corrdiff.input_coords()["lon"][[0, -1]]
lo_res = lo_res.sel(lat=make_quarter_degree(lat_from, lat_to), lon=make_quarter_degree(lon_from, lon_to))
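
`make_quarter_degree` comes from the workshop's `utils` module, and its implementation is not shown here. Judging from its use, it produces the 0.25° grid points between two bounds so that `sel` can pick exact GFS coordinates. A hypothetical equivalent (`quarter_degree_points` is an assumption, not the workshop helper) could look like this:

```python
import numpy as np


def quarter_degree_points(start: float, stop: float) -> np.ndarray:
    """Hypothetical helper: all multiples of 0.25 between two bounds, ascending."""
    lo, hi = sorted((float(start), float(stop)))  # bounds may come in either order
    first = np.ceil(lo * 4) / 4  # snap inward to the 0.25° grid
    last = np.floor(hi * 4) / 4
    # A half-step margin ensures `last` is included despite floating-point error.
    return np.arange(first, last + 0.125, 0.25)


pts = quarter_degree_points(25.3, 21.9)
print(pts[0], pts[-1], len(pts))  # 22.0 25.25 14
```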

Taiwan operates several offshore wind farms in the Taiwan Strait as well as onshore wind farms throughout the country. We may be interested in the expected energy production of these wind farms, and CorrDiff can give us a more detailed picture of local wind speeds. The plots below compare GFS in the top row with the downscaled CorrDiff result in the bottom row.

In [None]:
def _get_wind_speed(ds):
    return (ds.u10m ** 2 + ds.v10m ** 2) ** 0.5

plot_downscaling(lo_res.assign(s10m=_get_wind_speed), hi_res.assign(s10m=_get_wind_speed), "s10m", start_time, cb_label="Wind speed (m/s)")

As explained above, CorrDiff can not only generate higher-resolution data for input variables but also synthesize new variables. In addition to 2-meter temperature and the 10-meter wind components, which are present in the input data, CorrDiff Taiwan generates the 1-hour maximum radar reflectivity (`mrr`), an important proxy for rain intensity. The plots below show the results over several time steps. High humidity before rainfall made the heatwave Taiwan experienced in July 2024 especially uncomfortable.

In [None]:
plot_downscaled_forecast(hi_res, "mrr", start_time, float(lo_res.lon[len(lo_res.lon) // 2]), cb_label="MRR (dBZ)")
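
To translate reflectivity into a rough rain rate, a Z-R power law is commonly used. A minimal sketch, assuming the classic Marshall-Palmer coefficients (Z = 200 R^1.6, typical of stratiform rain); convective rain is often described with different coefficients, and applied to CorrDiff's hourly *maximum* reflectivity this gives at best a peak-intensity estimate, not an hourly total:

```python
import numpy as np


def dbz_to_rain_rate(dbz, a: float = 200.0, b: float = 1.6):
    """Invert Z = a * R**b, with Z in mm^6 m^-3 and R in mm/h; input in dBZ.

    Works element-wise on scalars, NumPy arrays, and xarray DataArrays.
    """
    z = 10.0 ** (dbz / 10.0)  # dBZ -> linear reflectivity factor
    return (z / a) ** (1.0 / b)


for dbz in (20, 30, 40, 50):
    print(f"{dbz} dBZ -> {dbz_to_rain_rate(dbz):.1f} mm/h")
```

With these coefficients, 40 dBZ corresponds to roughly 11.5 mm/h; `dbz_to_rain_rate(hi_res.mrr)` would apply the same conversion to the downscaled output.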

We can compare different ensemble members to distinguish areas of higher uncertainty, where the samples disagree, from areas of lower uncertainty, where they agree. The overall pattern of small-scale weather looks similar across samples, but each sample is a separate, physically realistic realization.

In [None]:
plot_downscaled_samples(hi_res, "mrr", start_time, float(lo_res.lon[len(lo_res.lon) // 2]), cb_label="MRR (dBZ)")
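
Beyond visual comparison, the spread across the `sample` dimension can be computed directly. A minimal sketch on synthetic data, assuming the same dimension names as `hi_res` (`sample`, `ilat`, `ilon`); with the real output, `hi_res.mrr.std(dim="sample")` would be the analogous call:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_samples, ny, nx = 4, 32, 32

# Synthetic "reflectivity": a shared pattern plus sample-specific noise that is
# deliberately stronger in the right half of the domain.
pattern = np.tile(np.linspace(0.0, 40.0, nx), (ny, 1))
noise_scale = np.where(np.arange(nx) < nx // 2, 1.0, 8.0)
data = pattern + noise_scale * rng.standard_normal((n_samples, ny, nx))

mrr = xr.DataArray(data, dims=("sample", "ilat", "ilon"), name="mrr")

ens_mean = mrr.mean(dim="sample")
ens_std = mrr.std(dim="sample")  # high values mark areas where the samples disagree

print(float(ens_std.isel(ilon=slice(0, nx // 2)).mean()),
      float(ens_std.isel(ilon=slice(nx // 2, None)).mean()))
```

With only four samples the spread estimate is noisy, so it is best read as a qualitative map of where the downscaled result is less robust.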

We will now take a closer look at temperature. To make the results more intuitive, choose whether you want to work with °C or °F (or keep the original unit, K).

In [None]:
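unit = "°C"  # one of "°C", "°F", or "K"

# Convert 2-meter temperature from Kelvin to the selected unit
t2m_converters = {
    "°C": lambda ds: ds.t2m - 273.15,
    "°F": lambda ds: (ds.t2m - 273.15) * 9/5 + 32,
    "K": lambda ds: ds.t2m,  # keep Kelvin; return the variable, not the whole Dataset
}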

lo_res = lo_res.assign(t2m=t2m_converters[unit])
hi_res = hi_res.assign(t2m=t2m_converters[unit])

The plots below again compare GFS with CorrDiff results, this time for 2-meter temperature. Reported temperatures reached up to 38°C (100°F) in the low-lying regions of the country. At the same time, the temperature drop high up in the mountainous regions of Taiwan becomes clearly visible after downscaling.

In [None]:
plot_downscaling(lo_res, hi_res, "t2m", start_time, normalize=True, cb_label=f"Temperature [{unit}]")
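
The mountain cooling largely follows from the decrease of temperature with altitude, which a 25 km grid cell averaging over peaks and valleys cannot resolve. A back-of-the-envelope sketch, assuming the International Standard Atmosphere lapse rate of 6.5 K per km (real lapse rates vary with humidity and weather):

```python
STANDARD_LAPSE_RATE_K_PER_KM = 6.5  # ISA troposphere value; an approximation


def temperature_at_height(t_surface: float, height_m: float,
                          lapse_rate: float = STANDARD_LAPSE_RATE_K_PER_KM) -> float:
    """Estimate temperature height_m above a reference point at t_surface.

    Valid for °C or K; for °F the lapse rate would have to be scaled by 9/5.
    """
    return t_surface - lapse_rate * height_m / 1000.0


# Yushan (Jade Mountain), Taiwan's highest peak, is about 3,952 m high.
print(temperature_at_height(35.0, 3952.0))
```

Starting from 35 °C at sea level, this estimate yields roughly 9 °C at the summit, a contrast of more than 25 K.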

As a first indicator of energy consumption, e.g., for air conditioning, we will now look at temperature weighted by the regional population. We have prepared population weights based on recent census results that match the GFS and CorrDiff Taiwan grids, respectively. For comparison, we also build masks that give equal weight to every populated grid cell.

In [None]:
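# GFS grid: population weights and an equal-weight mask of populated cells
pop_lo = np.load("./data/pop_tw_lo.npy")
msk_lo = (pop_lo > 0).astype(np.float32)
msk_lo /= msk_lo.sum()  # normalize so the mask weights sum to 1

# CorrDiff Taiwan grid: the same for the high-resolution data
pop_hi = np.load("./data/pop_tw_hi.npy")
msk_hi = (pop_hi > 0).astype(np.float32)
msk_hi /= msk_hi.sum()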

plot_pop(pop_lo, msk_lo, pop_hi, msk_hi, lo_res.lon, lo_res.lat, hi_res.lon, hi_res.lat, float(lo_res.lon[len(lo_res.lon) // 2]))

Let's multiply the weights by the temperature values and sum over the grid to calculate the average temperature experienced by the population of Taiwan.

In [None]:
# Weighted averages over the GFS grid (regular lat/lon dimensions)
pop_t2m_gfs = (lo_res.t2m * pop_lo).sum(dim=("lat", "lon"))
msk_t2m_gfs = (lo_res.t2m * msk_lo).sum(dim=("lat", "lon"))

# Weighted averages over the CorrDiff grid (index dimensions ilat/ilon)
pop_t2m_corrdiff = (hi_res.t2m * pop_hi).sum(dim=("ilat", "ilon"))
msk_t2m_corrdiff = (hi_res.t2m * msk_hi).sum(dim=("ilat", "ilon"))
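
The sums above yield weighted averages only if each set of weights sums to 1. The masks do so by construction, since the code divides them by their sum; for the population weights, this sketch assumes the prepared files are normalized the same way. The toy example below, with made-up values for three grid cells, shows why population weighting shifts the average toward the warm, densely populated lowlands:

```python
import numpy as np


def weighted_mean(values: np.ndarray, weights: np.ndarray) -> float:
    """Weighted average; dividing by the weight sum means weights need not sum to 1."""
    weights = np.asarray(weights, dtype=float)
    return float((values * weights).sum() / weights.sum())


# Made-up example: two populated lowland cells and one sparsely populated mountain cell.
t2m = np.array([33.0, 32.0, 18.0])  # temperature in °C
population = np.array([900.0, 600.0, 10.0])

mask = (population > 0).astype(float)  # every populated cell counts equally

print(weighted_mean(t2m, population))  # population-weighted, close to the lowland values
print(weighted_mean(t2m, mask))        # mask-weighted, pulled down by the mountain cell
```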

The plot below shows the results over the course of the forecast. The diurnal cycle, with higher temperatures during the day, is clearly visible. Still, temperatures stayed high at night, which likely caused considerable heat stress. The two lines based on the masks, which do not take population into account, run lower than the population-weighted lines. This is because, without population weighting, the cooler mountain regions of Taiwan, which are much less densely inhabited than the coastal regions, have a greater influence on the result. The results based on CorrDiff temperatures are also slightly lower than those based on GFS temperatures.

In [None]:
plot_pop_t2m(pop_t2m_gfs, msk_t2m_gfs, pop_t2m_corrdiff, msk_t2m_corrdiff, ylabel=f"Temperature [{unit}]")
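
When reading the diurnal cycle off the time axis, keep in mind that GFS initialization times are given in UTC, while Taiwan uses UTC+8 without daylight saving time. A small sketch that converts the valid times of this forecast to local time, reusing the same start time and 6-hourly lead times as above:

```python
import numpy as np

start_time = np.datetime64("2024-07-01 12:00:00")  # UTC, as in the forecast above
lead_time = np.array([np.timedelta64(6 * i, "h") for i in range(13)])

valid_utc = start_time + lead_time
valid_taiwan = valid_utc + np.timedelta64(8, "h")  # Taiwan: UTC+8, no DST

for utc, local in zip(valid_utc[:4], valid_taiwan[:4]):
    print(f"{utc} UTC -> {local} local")
```

The forecast therefore starts at 20:00 local time, and local early afternoon (around 14:00) corresponds to 06 UTC.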

We hope you have enjoyed our workshop on NVIDIA Earth-2. You are now prepared to build your own AI weather applications with [Earth2Studio](https://github.com/NVIDIA/earth2studio). You can find more hands-on examples in the [User Guide](https://nvidia.github.io/earth2studio/userguide/index.html). For training your own AI weather models, check out the [examples in the NVIDIA Modulus](https://github.com/NVIDIA/modulus/tree/main/examples) repository.