# Introduction to Earth2Studio - Continued

In this notebook, we will look at how to run a deterministic inference workflow that couples a prognostic model with a diagnostic model. After that, we will run ensemble inference with a prognostic model, perturbing the initial conditions to generate the ensemble members.

#### Contents of the Notebook

- [Running Diagnostic and Ensemble Inference with Earth2Studio](#Running-Diagnostic-and-Ensemble-Inference-with-Earth2Studio)
    - [Running Diagnostic Inference](#Running-Diagnostic-Inference)
    - [Execute the Workflow](#Execute-the-Workflow)
    - [Post Processing](#Post-Processing)
- [Important: Free up GPU Memory!](#Important:-Free-up-GPU-Memory!)
- [Running Ensemble Inference](#Running-Ensemble-Inference)
    - [Set Up](#Set-Up)
    - [Execute the Workflow](#Execute-the-Workflow)
    - [Post Processing](#Post-Processing)
    
#### Learning Outcomes

- Select a perturbation method
- Run the built-in diagnostic workflow
- Run a simple built-in ensemble workflow
- Post-process the results


## Running Diagnostic Inference

In this section, a diagnostic model predicts a new atmospheric quantity from the fields forecast by the prognostic model.
<center><img src="images/diagnostic.png" alt="Drawing" style="center"/></center>

### **Set Up**
For this example, we will use the built-in diagnostic workflow, `earth2studio.run.diagnostic`.


In the previous notebook, we looked at the available models, data sources, and IO backends. Now let us look at the diagnostic models available in Earth2Studio.

**Diagnostic models**: 
Diagnostic models are a class of models that do not perform time integration. They may be used to map weather/climate variables to other quantities of interest, to enable additional analysis, to improve prediction accuracy, to downscale, etc.

The list of Diagnostic Models available as of `0.7.0` are:

- **models.dx.CorrDiffTaiwan** : CorrDiff is a Corrector Diffusion model that learns mappings between low- and high-resolution weather data with high fidelity.
- **models.dx.ClimateNet** : Climate Net diagnostic model, built into Earth2Studio.
- **models.dx.DerivedRH** : Calculates the relative humidity (RH) from specific humidity and temperature for specified pressure levels.
- **models.dx.DerivedRHDewpoint** : Calculates the surface relative humidity (RH) from dewpoint temperature and air temperature.
- **models.dx.DerivedVPD** : Calculates the Vapor Pressure Deficit (VPD) in hPa from relative humidity and temperature fields.
- **models.dx.DerivedWS** : Calculates the Wind Speed (WS) magnitude from eastward and northward wind components for specified levels.
- **models.dx.PrecipitationAFNO** : Precipitation AFNO diagnostic model.
- **models.dx.PrecipitationAFNOv2** : Improved Precipitation AFNO diagnostic model.
- **models.dx.TCTrackerWuDuan** : Finds a list of tropical cyclone (TC) centers using an adaptation of the method described in Wu and Duan (2023).
- **models.dx.TCTrackerVitart** : Finds a list of tropical cyclone centers using the conditions in Vitart (1997).
- **models.dx.WindgustAFNO** : Wind gust AFNO diagnostic model.
- **models.dx.Identity** : Identity diagnostic that is coordinate insensitive.
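Several of the derived diagnostics above are closed-form formulas. The sketch below shows, in plain NumPy, the kind of computation `DerivedWS` and `DerivedVPD` describe. It is not Earth2Studio's implementation: the library classes operate on PyTorch tensors with coordinate metadata, and the saturation vapor pressure formula used here (Tetens/Magnus form, temperature in °C) is an assumption that may differ from what the library uses.

```python
import numpy as np


def wind_speed(u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Wind speed magnitude from eastward (u) and northward (v) wind components."""
    return np.sqrt(u**2 + v**2)


def vapor_pressure_deficit(rh_percent, t_celsius):
    """Vapor pressure deficit in hPa from relative humidity (%) and air temperature (°C).

    Assumption: saturation vapor pressure follows the Tetens/Magnus form.
    """
    e_sat = 6.1078 * np.exp(17.27 * t_celsius / (t_celsius + 237.3))  # hPa
    return e_sat * (1.0 - rh_percent / 100.0)


u = np.array([3.0, 0.0, -6.0])
v = np.array([4.0, 5.0, 8.0])
print(wind_speed(u, v))  # [ 5.  5. 10.]
print(vapor_pressure_deficit(np.array([50.0, 100.0]), np.array([25.0, 25.0])))
```

Like the Earth2Studio diagnostics, these functions map fields at a single time to a new quantity; nothing is integrated forward in time.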

For this example, we will be using the following:

- **Prognostic Model**: Use the built-in FourCastNet model `earth2studio.models.px.FCN`.
- **Diagnostic Model**: Use the built-in precipitation AFNO model `earth2studio.models.dx.PrecipitationAFNO`.
- **Datasource**: Pull data from the GFS data API `earth2studio.data.GFS`.
- **IO Backend**: Save the outputs into a Zarr store `earth2studio.io.ZarrBackend`.

#### Precipitation AFNO Model: 

Precipitation AFNO is a FourCastNet diagnostic model that predicts total precipitation from 20 atmospheric variables. Total precipitation, sourced from the ERA5 reanalysis dataset, is the accumulated liquid and frozen water that falls to the Earth's surface as rain and snow. It is expressed in units of length: the depth of water that would accumulate if it were spread evenly over a model grid box. The figure below shows how the diagnostic inference pipeline fits together.
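Because total precipitation is stored as a depth in meters, a quick conversion helps when reading the plot later. The sketch below converts a depth to millimeters and estimates the corresponding water volume over one grid box. The 0.25° grid spacing and the spherical Earth with a 6371 km radius are assumptions chosen for illustration.

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius, spherical approximation (assumed)


def grid_box_area_m2(lat_south_deg: float, lat_north_deg: float, dlon_deg: float) -> float:
    """Area of a lat/lon grid box on a sphere: R^2 * dlon * (sin(lat_n) - sin(lat_s))."""
    dlon = math.radians(dlon_deg)
    return EARTH_RADIUS_M**2 * dlon * (
        math.sin(math.radians(lat_north_deg)) - math.sin(math.radians(lat_south_deg))
    )


tp_m = 0.005  # total precipitation depth in meters
tp_mm = tp_m * 1000.0  # 1 m = 1000 mm
area = grid_box_area_m2(-0.125, 0.125, 0.25)  # 0.25-degree box centered on the equator
volume_m3 = tp_m * area  # depth spread evenly over the box

print(f"{tp_mm:.1f} mm over {area / 1e6:.0f} km^2 = {volume_m3:.3g} m^3")
```

So the 0 to 0.01 m colorbar range used later corresponds to 0 to 10 mm of accumulated precipitation.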

<center><img src="images/precipafno.png" alt="Drawing" style="center" width="600px"/></center>

In [None]:
import os

from dotenv import load_dotenv

# Load environment variables from a local .env file (if present) before importing Earth2Studio
load_dotenv()
os.makedirs("outputs", exist_ok=True)

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.dx import PrecipitationAFNO
from earth2studio.models.px import FCN

# Prognostic Model - Load the default model package, which downloads the checkpoint from NGC
package = FCN.load_default_package()
prognostic_model = FCN.load_model(package)

# Diagnostic Model - Load the default model package, which downloads the checkpoint from NGC
package = PrecipitationAFNO.load_default_package()
diagnostic_model = PrecipitationAFNO.load_model(package)

# Data Source - Create the data source
data = GFS()

# IO Backend - Create the IO handler, store in memory
io = ZarrBackend()

## Execute the Workflow
With all components initialized, running the workflow is a single line of Python code.
The workflow returns the provided IO object, which can then be used for
post-processing. The signature of the diagnostic workflow is:

```python
def diagnostic(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    prognostic: PrognosticModel,
    diagnostic: DiagnosticModel,
    data: DataSource,
    io: IOBackend,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in diagnostic workflow.
    This workflow creates a deterministic inference pipeline that couples a prognostic
    model with a diagnostic model.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps
    prognostic : PrognosticModel
        Prognostic model
    diagnostic: DiagnosticModel
        Diagnostic model, must be on same coordinate axis as prognostic
    data : DataSource
        Data source
    io : IOBackend
        IO object
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object
    """
```
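Conceptually, the workflow couples the two models in a loop: take the initial condition, step the prognostic model forward, and apply the diagnostic to each predicted state before writing it out. The toy sketch below mimics that pattern with NumPy and hypothetical stand-in models (`toy_prognostic`, `toy_diagnostic`); it is not Earth2Studio's code. It assumes the initial condition is also written as lead time 0, so `nsteps` steps produce `nsteps + 1` outputs, which is consistent with the plotting cell below indexing step 8 after an 8-step run.

```python
import numpy as np


def toy_prognostic(state: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for one prognostic time step."""
    return 0.9 * state + 0.1 * np.roll(state, 1, axis=-1)


def toy_diagnostic(state: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in diagnostic: derive one new non-negative field from the state."""
    return np.clip(state.mean(axis=0), 0.0, None)


def diagnostic_workflow(initial_state: np.ndarray, nsteps: int) -> np.ndarray:
    outputs = []
    state = initial_state
    for step in range(nsteps + 1):
        if step > 0:
            state = toy_prognostic(state)  # time-integrate one step
        outputs.append(toy_diagnostic(state))  # map the state to a derived quantity
    # Stack along a new leading "lead_time" axis
    return np.stack(outputs, axis=0)


rng = np.random.default_rng(0)
x0 = rng.random((4, 16, 32))  # (variable, lat, lon)
result = diagnostic_workflow(x0, nsteps=8)
print(result.shape)  # (9, 16, 32)
```

The diagnostic only ever sees one state at a time, which is why it must share the prognostic model's coordinate axes.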



In [None]:
import earth2studio.run as run

nsteps = 8
io = run.diagnostic(
    ["2021-06-01"], nsteps, prognostic_model, diagnostic_model, data, io
)

print(io.root.tree())

## Post Processing
The last step is to plot the predicted total precipitation. This illustrates the value of
diagnostic models: they let us predict additional variables, such as precipitation, that
the pre-trained prognostic model does not output itself.
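The plotting cell indexes the output by lead-time step. Assuming FourCastNet's 6-hour time step (the same assumption behind the `6*step` in the plot title), a step index maps to a lead time and a valid time as sketched below; `valid_time` is a hypothetical helper, not part of Earth2Studio.

```python
from datetime import datetime, timedelta

STEP_HOURS = 6  # assumed FCN time step in hours


def valid_time(init_time: datetime, step: int) -> datetime:
    """Valid time of forecast step `step` for a forecast initialized at `init_time`."""
    return init_time + timedelta(hours=STEP_HOURS * step)


init = datetime(2021, 6, 1)
for step in (0, 4, 8):
    print(step, STEP_HOURS * step, valid_time(init, step).isoformat())
```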


In [None]:
from datetime import datetime

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np

forecast = datetime(2021, 6, 1)
variable = "tp"
step = 8  # lead time = 48 hrs

plt.close("all")
# Create an Orthographic projection centered on the USA
projection = ccrs.Orthographic(-100, 40)

# Create a figure and axes with the specified projection
fig, ax = plt.subplots(subplot_kw={"projection": projection}, figsize=(10, 6))

# Plot the field using filled contours
levels = np.arange(0.0, 0.01, 0.001)
im = ax.contourf(
    io["lon"][:],
    io["lat"][:],
    io[variable][0, step],
    levels,
    transform=ccrs.PlateCarree(),
    vmax=0.01,
    vmin=0.00,
    cmap="terrain",
)

# Set title
ax.set_title(f"{forecast.strftime('%Y-%m-%d')} - Lead time: {6*step}hrs")

# Set the map extent, then add coastlines and gridlines
ax.set_extent([220, 340, 20, 70])  # [lon min, lon max, lat min, lat max]
ax.coastlines()
ax.gridlines()
plt.colorbar(
    im, ax=ax, ticks=levels, shrink=0.75, pad=0.04, label="Total precipitation (m)"
)

plt.savefig("outputs/02_tp_prediction.jpg")

Let us now clean up the GPU memory and look at another built-in workflow.

# Important: Free up GPU Memory!

Run the cell below to free up GPU memory before moving on to the ensemble example. `os._exit` terminates the kernel process, so the next section re-imports everything it needs.

In [None]:
import os
os._exit(0)  # Terminate the kernel process to release GPU memory

# Running Ensemble Inference

<center><img src="images/ensemble.png" alt="Drawing" style="center"/></center>


## Set Up
All workflows inside Earth2Studio require constructed components to be
handed to them. In this example, we will use the built-in ensemble workflow,
`earth2studio.run.ensemble`.



**Ensemble Inference**: Ensemble inference with perturbation in weather forecasting involves generating multiple forecasts with slight variations in initial conditions or model parameters. This approach is crucial because the atmosphere is a chaotic system where small changes in initial conditions can lead to significant differences in outcomes. By using ensemble methods, we can quantify the uncertainty in predictions and provide a range of possible weather scenarios, enhancing the reliability and accuracy of forecasts, especially for extreme weather events. This method is particularly useful in deep learning models, which traditionally focus on deterministic outputs, by allowing them to incorporate probabilistic elements and better reflect the inherent uncertainties in weather prediction.
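The sensitivity to initial conditions described above is easy to see in a toy system. The sketch below uses the Lorenz-63 equations (a classic chaotic model, not a weather model and not part of Earth2Studio) to run a small ensemble from slightly perturbed initial conditions and track how the ensemble spread grows. The parameters σ = 10, ρ = 28, β = 8/3 are the standard chaotic settings; the perturbation size, time step, and run length are arbitrary choices.

```python
import numpy as np


def lorenz63(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Time derivative of the Lorenz-63 system; works on (3,) or (members, 3) arrays."""
    x, y, z = state[..., 0], state[..., 1], state[..., 2]
    return np.stack([sigma * (y - x), x * (rho - z) - y, x * y - beta * z], axis=-1)


def rk4_step(state, dt):
    """One fourth-order Runge-Kutta step."""
    k1 = lorenz63(state)
    k2 = lorenz63(state + 0.5 * dt * k1)
    k3 = lorenz63(state + 0.5 * dt * k2)
    k4 = lorenz63(state + dt * k3)
    return state + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)


def run_ensemble(nensemble=8, noise_amplitude=1e-3, nsteps=1500, dt=0.01, seed=0):
    rng = np.random.default_rng(seed)
    # Spin up a single control state so it sits on the attractor
    control = np.array([1.0, 1.0, 1.0])
    for _ in range(500):
        control = rk4_step(control, dt)
    # Perturb the initial condition: one member per row
    members = control + noise_amplitude * rng.standard_normal((nensemble, 3))
    spread = [members.std(axis=0).mean()]
    for _ in range(nsteps):
        members = rk4_step(members, dt)
        spread.append(members.std(axis=0).mean())
    return members, np.array(spread)


members, spread = run_ensemble()
print(f"initial spread {spread[0]:.2e}, final spread {spread[-1]:.2f}")
```

Perturbations of order 0.001 grow by several orders of magnitude, so a single deterministic run says little about which outcome is likely; the ensemble spread quantifies that uncertainty.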

Ensemble inference needs a way to make the members differ; here we do this by perturbing the initial conditions. Some of the perturbation methods available in version `0.7.0` are:

- **perturbation.Brown** : Lat/lon 2D brown noise
- **perturbation.BredVector** : Bred vector perturbation method, a classical technique for generating perturbations in ensemble forecasting.
- **perturbation.CorrelatedSphericalGaussian** : Produces a Gaussian random field on the sphere with Matérn covariance, output to a lat/lon grid
- **perturbation.Gaussian** : Standard Gaussian perturbation
- **perturbation.HemisphericCentredBredVector** : Bred vector perturbation method following the approach introduced in 'Huge Ensembles Part I: Design of Ensemble Weather Forecasts using Spherical Fourier Neural Operators'.
- **perturbation.LaggedEnsemble** : Lagged ensemble perturbation method.
- **perturbation.SphericalGaussian** : Gaussian random field on the sphere with Matérn covariance, output to a lat/lon grid
- **perturbation.Zero** : No perturbation
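The simplest of these ideas, additive Gaussian noise, can be sketched in a few lines. The hypothetical `gaussian_perturb` helper below adds independent `noise_amplitude * N(0, 1)` noise at every grid point to create an ensemble from one initial condition. This is an assumption-laden illustration, not Earth2Studio's `Gaussian` class, which works on PyTorch tensors with coordinates and may scale the noise differently; the spherical and correlated variants additionally give the noise spatial structure, which this sketch omits.

```python
import numpy as np


def gaussian_perturb(x: np.ndarray, noise_amplitude: float, nensemble: int, rng) -> np.ndarray:
    """Return an ensemble of perturbed copies of x with a new leading ensemble axis.

    Each member is x + noise_amplitude * N(0, 1), drawn independently per grid point.
    """
    noise = rng.standard_normal((nensemble, *x.shape))
    return x[None, ...] + noise_amplitude * noise


rng = np.random.default_rng(42)
x0 = np.zeros((2, 90, 180))  # (variable, lat, lon) initial condition
ensemble = gaussian_perturb(x0, noise_amplitude=0.15, nensemble=8, rng=rng)
print(ensemble.shape)  # (8, 2, 90, 180)
```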

We will use the following:

- **Prognostic Model**: Use the built-in FourCastNet model `earth2studio.models.px.FCN`.
- **Perturbation Method**: Use the spherical Gaussian method `earth2studio.perturbation.SphericalGaussian`.
- **Datasource**: Pull data from the GFS data API `earth2studio.data.GFS`.
- **IO Backend**: Save the outputs into a Zarr store `earth2studio.io.ZarrBackend`.


In [None]:
import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv
load_dotenv()


import numpy as np

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.px import FCN
from earth2studio.perturbation import SphericalGaussian
from earth2studio.run import ensemble

# Prognostic Model - Load the default model package, which downloads the checkpoint from NGC
package = FCN.load_default_package()
model = FCN.load_model(package)

# Perturbation Method - Instantiate the perturbation method
sg = SphericalGaussian(noise_amplitude=0.15)

# Data Source - Create the data source
data = GFS()

# IO Backend - Create the IO handler, writing to a Zarr store on disk
chunks = {"ensemble": 1, "time": 1}
io = ZarrBackend(file_name="outputs/02_ensemble_sg.zarr", chunks=chunks)

## Execute the Workflow
With all components initialized, running the workflow is a single line of Python code.
The workflow returns the provided IO object, which can then be used for
post-processing. The signature of the ensemble workflow is:

```python

def ensemble(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    nensemble: int,
    prognostic: PrognosticModel,
    data: DataSource,
    io: IOBackend,
    perturbation: Perturbation,
    batch_size: int | None = None,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in ensemble workflow.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps
    nensemble : int
        Number of ensemble members to run inference for.
    prognostic : PrognosticModel
        Prognostic model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    perturbation : Perturbation
        Method to perturb the initial condition to create an ensemble.
    batch_size: int, optional
        Number of ensemble members to run in a single batch,
        by default None.
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object
    """
```



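For the forecast, we will predict 10 steps (60 hours for FCN, which uses 6-hour steps) with 8 ensemble
members, run in 4 batches of 2 members each (`batch_size = 2`). The `output_coords` argument
restricts the saved variables to `t2m` and `tcwv` to keep the output small.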
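A toy version of the ensemble loop makes the batching concrete. The sketch below uses the values from the cell below (`nsteps = 10`, `nensemble = 8`, `batch_size = 2`) with a hypothetical stand-in model `toy_step` and plain Gaussian noise; it is not Earth2Studio's implementation. It assumes the perturbed initial condition is written as lead time 0, giving an output layout of (ensemble, lead_time, ...).

```python
import numpy as np


def toy_step(state: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for one prognostic step (not FCN)."""
    return state + 0.1 * np.sin(state)


def ensemble_workflow(x0, nsteps, nensemble, batch_size, noise_amplitude, seed=0):
    rng = np.random.default_rng(seed)
    out = np.empty((nensemble, nsteps + 1, *x0.shape))  # (ensemble, lead_time, ...)
    for start in range(0, nensemble, batch_size):
        stop = min(start + batch_size, nensemble)
        # Replicate the initial condition for this batch and perturb each member
        batch = x0[None] + noise_amplitude * rng.standard_normal((stop - start, *x0.shape))
        out[start:stop, 0] = batch
        for step in range(1, nsteps + 1):
            batch = toy_step(batch)  # all members in the batch advance together
            out[start:stop, step] = batch
    return out


nsteps, nensemble, batch_size = 10, 8, 2
print("batches:", -(-nensemble // batch_size))  # ceiling division
print("lead time (h):", 6 * nsteps)  # assuming 6-hour steps
x0 = np.linspace(0.0, 1.0, 5)
result = ensemble_workflow(x0, nsteps, nensemble, batch_size, noise_amplitude=0.15)
print(result.shape)  # (8, 11, 5)
```

A larger batch size runs more members per forward pass at the cost of more GPU memory; the final output is the same shape either way.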



In [None]:
nsteps = 10
nensemble = 8
batch_size = 2
io = ensemble(
    ["2024-01-01"],
    nsteps,
    nensemble,
    model,
    data,
    io,
    sg,
    batch_size=batch_size,
    output_coords={"variable": np.array(["t2m", "tcwv"])},
)

## Post Processing
The last step is to post-process our results. Cartopy is a convenient library for plotting
fields on map projections of the sphere.

Note that the Zarr IO backend can be indexed by array name (for example `io["tcwv"]` or `io["lon"]`), which is how the plotting code below reads the stored data.
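The third panel in the plotting cell shows the ensemble standard deviation. The sketch below computes it alongside the ensemble mean and range on a synthetic array. The dimension order (ensemble, time, lead_time, lat, lon) is inferred from the indexing `io[variable][member, 0, step]` in the plotting cell, and the data are random placeholders rather than model output; with the real store, `io["tcwv"][:, 0, step]` would replace the synthetic slice.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in with the ensemble output layout (ensemble, time, lead_time, lat, lon)
tcwv = 20.0 + rng.standard_normal((8, 1, 11, 45, 90))

step = 4  # lead time = 24 hrs, assuming 6-hour steps
members = tcwv[:, 0, step]  # (ensemble, lat, lon) at one lead time
ens_mean = members.mean(axis=0)  # (lat, lon)
ens_std = members.std(axis=0)  # (lat, lon); population std (ddof=0), as np.std uses by default
ens_range = members.max(axis=0) - members.min(axis=0)  # (lat, lon)

print(ens_mean.shape, ens_std.shape, float(ens_std.mean()))
```

With only 8 members the spread estimate is noisy; passing `ddof=1` to `np.std` gives the sample standard deviation instead.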



In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np

forecast = "2024-01-01"


def plot_(axi, data, title, cmap):
    """Convenience function for plotting pcolormesh."""
    # Plot the field using pcolormesh
    im = axi.pcolormesh(
        io["lon"][:],
        io["lat"][:],
        data,
        transform=ccrs.PlateCarree(),
        cmap=cmap,
    )
    plt.colorbar(im, ax=axi, shrink=0.6, pad=0.04)
    # Set title
    axi.set_title(title)
    # Add coastlines and gridlines
    axi.coastlines()
    axi.gridlines()


for variable, cmap in zip(["tcwv"], ["Blues"]):
    step = 4  # lead time = 24 hrs

    plt.close("all")
    # Create a Robinson projection
    projection = ccrs.Robinson()

    # Create a figure and axes with the specified projection
    fig, (ax1, ax2, ax3) = plt.subplots(
        nrows=1, ncols=3, subplot_kw={"projection": projection}, figsize=(16, 3)
    )

    plot_(
        ax1,
        io[variable][0, 0, step],
        f"{forecast} - Lead time: {6*step}hrs - Member: {0}",
        cmap,
    )
    plot_(
        ax2,
        io[variable][1, 0, step],
        f"{forecast} - Lead time: {6*step}hrs - Member: {1}",
        cmap,
    )
    plot_(
        ax3,
        np.std(io[variable][:, 0, step], axis=0),
        f"{forecast} - Lead time: {6*step}hrs - Std",
        cmap,
    )

    plt.savefig(f"outputs/03_{forecast}_{variable}_{step}_ensemble.jpg")

### Additional Resources

We have looked at three workflows with Earth2Studio, but Earth2Studio also supports custom workflows, which gives researchers the flexibility to build on them. Here are some resources for Earth2Studio:

- [Earth2Studio GitHub](https://github.com/NVIDIA/earth2studio)
- [Documentation](https://nvidia.github.io/earth2studio/index.html)
- [API Reference](https://nvidia.github.io/earth2studio/modules/index.html)
- [Additional Examples](https://nvidia.github.io/earth2studio/examples/index.html)

--- 

Don't forget to check out additional [Open Hackathons Resources](https://www.openhackathons.org/s/technical-resources) and join our [OpenACC and Hackathons Slack Channel](https://www.openacc.org/community#slack) to share your experience and get more help from the community.

---

# Licensing

Copyright © 2023 OpenACC-Standard.org.  This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply.
