# Tutorial 1 - Forecasting with AtmoRep

## Preamble

Training of AtmoRep is a computationally intensive process, typically requiring a couple of weeks on multiple compute nodes of Juwels Booster. 
However, a selection of pretrained AtmoRep models (singleformers and multiformers) has been made available, enabling immediate use in so-called zero-shot applications. <br>
Zero-shot applications leverage the intrinsic capabilities of the AtmoRep model and do not require task-specific finetuning or the addition of a specialized network to the core model.
In particular, AtmoRep has demonstrated skill for short-range forecasting/nowcasting, temporal interpolation and data decompression. <br>
In the following, we will explore the forecast application in a zero-shot setting.

### Objectives
The objectives of this tutorial are:
- Generation of a six-hour forecast with AtmoRep in inference mode
- Handling and plotting AtmoRep's model output
- Evaluation of AtmoRep's forecast capabilities

### Prerequisites
To run the following cells of the Jupyter Notebook, please ensure that you fulfill the following requirements:
- Run your JupyterLab on the GPU node of JURECA (use the `dc-gpu`-partition)
- Activate your hackathon kernel (see upper right corner of your Jupyter Notebook)

## Step 1 - Generate a forecast with AtmoRep

### Forecast Inference in AtmoRep

In this section, we will demonstrate how to perform a **global forecast** using AtmoRep. The global data is split into tiles (neighborhoods) on a spatiotemporal grid; the illustration below shows how a single tile is processed during a forecast. The data is represented across latitude, longitude, and time dimensions. Forecasting means predicting future tokens (y) within a subset of this grid based on past observations (x). This procedure allows us to evaluate the model's ability to extrapolate patterns and generate meaningful predictions for unseen temporal data.

![forecast](../img/forecast_1ml.jpg "Data processing during forecast, with x input, y target output")

We start by importing some Python packages and classes from the AtmoRep code...

In [None]:
import os
import time
from pathlib import Path
from atmorep.core.evaluator import Evaluator
from atmorep.config.config import UserConfig

Before continuing, we need to set up a user-specific configuration. <br>
This procedure provides symbolic links to the directories that hold the pretrained models and the ERA5 data (in zarr stores). It also configures the user-specific output directories for custom models, inference results (e.g. global forecasts) and other output.

In [None]:
project_dir = Path(f"/p/project/training2445/{os.environ['USER']}/atmorep/")
user_config = UserConfig.from_path(project_dir)

In the next step, we configure the inference run for the `Evaluator`-class. 
This class allows users to perform inference with pretrained models (either provided or custom). It loads a pretrained model, identified via its W&B-ID, and then performs the inference step, i.e. only the forward pass of the model is executed. The inference output is saved to the `results`-directory under the user-specific project directory. <br> 
Note that the `Evaluator`-class supports various inference modes (e.g. global forecasting), as outlined below. 

In [None]:
model_id='wc5e2i3t'                # pre-trained multiformer

mode, options = 'global_forecast', {
                                  'dates' : [[2021, 2, 10, 12]],    # this corresponds to the last time step of the prediction sequence
                                  'token_overlap' : [0, 0],         # no overlapping between tiles
                                  'forecast_num_tokens' : 2,        # corresponds to a 6h-forecast (=2x3h)
                                  'with_pytest' : False }

For each inference experiment, three key parameters are required:
- The `model_id`, which identifies the pre-trained model used for inference
- The `mode` of inference you wish to perform
- The required attributes, labeled as `options`, necessary to carry out the selected inference mode

In the case of `global_forecast`, AtmoRep requires the following options:
- `'dates'`: A list of dates; a single date must still be wrapped in a list. Each date follows the format `[year, month, day, hour]` and denotes the last time step to be forecast. AtmoRep loads data starting 36 hours before this timestamp (inclusive) and forecasts the final `forecast_num_tokens` $\times$ 3 hours leading up to it.
- `'token_overlap'`: Specifies the degree of overlap between tiles (or neighborhoods).
- `'forecast_num_tokens'`: The number of tokens to be forecasted.
- `'with_pytest'`: Enables or disables systematic testing processes. 
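
To make the relation between `'dates'` and `'forecast_num_tokens'` concrete, the following sketch computes the time window implied by the description above. It is only an illustration and not part of the AtmoRep code: the helper `forecast_window` and its parameter names are hypothetical, and whether the interval boundaries are inclusive or exclusive inside AtmoRep is not covered here.

```python
from datetime import datetime, timedelta

def forecast_window(date, forecast_num_tokens, hours_per_token=3, context_hours=36):
    """Return (context_start, forecast_start, end) for one entry of 'dates'.

    Hypothetical helper: it only mirrors the textual description of the
    'global_forecast' options and does not call any AtmoRep code.
    """
    end = datetime(*date)                                 # last forecast time step
    context_start = end - timedelta(hours=context_hours)  # start of the loaded data
    forecast_start = end - timedelta(hours=forecast_num_tokens * hours_per_token)
    return context_start, forecast_start, end

context_start, forecast_start, end = forecast_window([2021, 2, 10, 12], forecast_num_tokens=2)
print("data loaded from:", context_start)
print("forecast covers: ", forecast_start, "to", end)
```

With the options used in this tutorial, the last six hours before 12 UTC on 10 February 2021 are forecast.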

After choosing the appropriate mode for our application (global forecasting) and setting the configuration options, we can run the inference by calling the `evaluate`-method of the `Evaluator`-class. 
Note that this step takes a couple of minutes to complete.

In [None]:
os.environ["MASTER_ADDR"] = str(os.environ['SLURM_JOB_NODELIST'])       # AtmoRep needs to know the hostname of the master-node

# run inference
now = time.time()
Evaluator.evaluate(mode, model_id, options, user_config=user_config)
print("time", time.time() - now)

Make sure to note the W&B-ID that is printed while running the inference. This ID is required to run the subsequent validation (Step 2 of this tutorial).

### To-Dos @Asma and @Michael:
- [ ] Revise wording

### Other supported inference modes:
#### BERT Masked Token Mode
![BERT](../img/BERT_1ml.jpg "Data processing during BERT evaluation")

Neighborhoods are selected at random in time and space. Within each neighborhood, tokens are randomly masked according to a percentage specified in the configuration file.
```python
mode, options = 'BERT', {'years_val': [2021], 'num_samples_validate': 128, 'with_pytest': True}
```
Options to configure:
- `'years_val'`: Specifies the years from which samples will be drawn.
- `'num_samples_validate'`: Number of validation batches, which equals the number of selected neighborhoods. A neighborhood is a patch in space and time whose tokens are processed together.
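
The idea of random token masking can be sketched with NumPy as follows. This is a simplified illustration under stated assumptions and not AtmoRep's implementation: the helper `random_token_mask`, the number of tokens and the masking fraction are all hypothetical, and the actual masking strategy is controlled by the configuration file.

```python
import numpy as np

def random_token_mask(num_tokens, mask_fraction, seed=0):
    """Return a boolean mask that marks a random subset of tokens as masked.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    num_masked = int(round(mask_fraction * num_tokens))
    masked_idx = rng.choice(num_tokens, size=num_masked, replace=False)
    mask = np.zeros(num_tokens, dtype=bool)
    mask[masked_idx] = True
    return mask

# mask a quarter of 12 tokens (both numbers are made up for this sketch)
mask = random_token_mask(num_tokens=12, mask_fraction=0.25)
print("masked token indices:", np.flatnonzero(mask).tolist())
```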

#### BERT Forecast Mode
Random neighborhoods are sampled from the global dataset. For each neighborhood, forecasts are generated for a time period set by `forecast_num_tokens` $\times$ 3 hours.
```python
mode, options = 'forecast', {'forecast_num_tokens': 2, 'num_samples_validate': 128, 'with_pytest': True}
```
The options are the same as those described above.

#### Temporal Interpolation Mode
![temporal interpolation](../img/temporal_interpol_1ml.jpg "Data processing during temporal interpolation, with x input, y target output")

In this mode, AtmoRep fills in temporal gaps by predicting missing data for selected neighborhoods.

```python
mode, options = 'temporal_interpolation', {'idx_time_mask': [5, 6, 7], 'num_samples_validate': 128, 'with_pytest': True}
```
Key option explained:
- `'idx_time_mask'`: Specifies the indices of tokens in a batch of 12 that AtmoRep will attempt to predict.
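
The effect of `'idx_time_mask'` can be pictured as a boolean mask over the token indices. The sketch below is only an illustration with NumPy, assuming that the 12 tokens are ordered along the time axis; it does not reproduce how AtmoRep applies the mask internally.

```python
import numpy as np

num_tokens = 12                   # number of tokens, as stated above
idx_time_mask = [5, 6, 7]         # token indices to be predicted

time_mask = np.zeros(num_tokens, dtype=bool)
time_mask[idx_time_mask] = True

context_idx = np.flatnonzero(~time_mask)   # tokens given to the model
target_idx = np.flatnonzero(time_mask)     # tokens the model has to fill in
print("given:    ", context_idx.tolist())
print("predicted:", target_idx.tolist())
```

Since the masked indices lie in the middle of the sequence, the model can draw on context from both before and after the gap.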

## Step 2 - Process and validate the AtmoRep output

The AtmoRep output from the inference step is written to zarr stores. While the zarr Python package provides a straightforward way to read the data, its multi-dimensional nature makes subsequent processing challenging. For convenience, we therefore make use of a small data interface that reads the data and turns it into an xarray DataArray with labelled dimensions. <br>
For more information about xarray, please refer to the [docs](https://docs.xarray.dev/en/latest/getting-started-guide/index.html).
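
To illustrate why labelled dimensions are convenient, the following sketch builds a small synthetic DataArray and selects a two-dimensional field from it. The dimension names follow those used later in this notebook (`init_time`, `lead_time`, `ml`, `lat`, `lon`); the values, the grid size and the second level are made up, and the data handler may order or label its dimensions differently.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)

# synthetic temperature data in K with labelled dimensions
da = xr.DataArray(
    280.0 + rng.standard_normal((1, 2, 2, 4, 8)),
    dims=("init_time", "lead_time", "ml", "lat", "lon"),
    coords={
        "init_time": np.array(["2021-02-10T06:00"], dtype="datetime64[ns]"),
        "lead_time": [3, 6],
        "ml": [96, 105],
        "lat": np.linspace(90.0, -90.0, 4),
        "lon": np.linspace(0.0, 315.0, 8),
    },
    name="temperature",
)

# isel selects by position, sel selects by coordinate value
field = da.isel(init_time=0).sel(ml=96, lead_time=6) - 273.15   # K to °C
print(field.dims, field.shape)
```

Selecting by coordinate value (`sel`) avoids having to remember the position of each level or lead time in the underlying array.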

Again, we start by importing an auxiliary class and a method for data retrieval and plotting, respectively.

In [None]:
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
from jsc_scripts.utils_hackathon.read_atmorep_data import HandleAtmoRepData
from jsc_scripts.utils_hackathon.plotting import plot_global_data 

Next, we initialize the AtmoRep data handler with the W&B-ID of the inference step executed above and with the results directory of AtmoRep.

In [None]:
input_dir = user_config.__dict__["results"]
model_id = "idry906z4c"                     # adapt here
varname = "temperature"                     # can be changed

ar_data = HandleAtmoRepData(model_id, input_dir)

Then, we can read in the forecasted and the (ground truth) reference data as follows:

In [None]:
da_fcst = ar_data.read_data(varname, "pred")
da_ref = ar_data.read_data(varname, "target")

Let's have a look and check what we obtained:

In [None]:
da_fcst

In [None]:
da_ref

Since multi-dimensional arrays are hard to grasp, we will plot the data in the following to get more insight into AtmoRep's forecasts. 

In [None]:
%matplotlib inline

# parameters to select the data of interest
it_idx = 0                        # index for init_time-dimension
lead_time = 6                     # lead time in hours
vlevel = 96                       # model level (coordinate `ml`)
offset = -273.15                  # for converting temperature unit from K to °C

# set some parameters to customize our plots
proj = ccrs.Robinson(central_longitude=0.)      # The projection to display the (global) data
transform = ccrs.PlateCarree()                  # transformation-object to be used when processing the data in the plot-routine (don't change!)
cmap_name = "RdBu_r"                            # colormap 

# slice the data of interest and apply the unit offset
selection = {"ml": vlevel, "lead_time": lead_time}
fcst = da_fcst.isel({"init_time": it_idx}).sel(selection) + offset
ref = da_ref.isel({"init_time": it_idx}).sel(selection) + offset

In [None]:
# get auxiliary strings
init_time = pd.to_datetime(da_fcst["init_time"][it_idx].values)
fcst_str = f"{init_time.strftime('%Y%m%d_%H00')}+{lead_time:03d}"

# create the plots - this takes a while
fname_pred = user_config.__dict__["results"].joinpath(f"atmorep_pred_{varname}_ml{vlevel:03d}_{fcst_str}.png")
#plot_global_data(fcst, fname_pred,
#                 levels=np.arange(-45., 6.), cmap_name=cmap_name, projection=proj, transform=transform)

# derive the remaining file names from the file name of the prediction plot
fname_tar = str(fname_pred).replace("pred", "tar") 
#plot_global_data(ref, fname_tar,
#                 levels=np.arange(-45., 6.), cmap_name=cmap_name, projection=proj, transform=transform)

fname_diff = str(fname_pred).replace("pred", "diff")
plot_global_data(fcst - ref, fname_diff,
                 levels=np.arange(-3., 3.1, 0.1), cmap_name=cmap_name, projection=proj, transform=transform)

### Task 1:
- Plot different variables on different levels
- Describe what you see. Can you explain your findings?

<hr> 

In the following, we will make our analysis more quantitative. For this, we will compute some basic evaluation metrics such as the RMSE and investigate how the results change with lead time. <br>
Again, we don't need to code everything from scratch, but can make use of a `Scores`-engine that computes several metrics, including averaging over user-defined data dimensions. <br>
We start by initialising the `Scores`-engine.

In [None]:
from jsc_scripts.utils_hackathon.metrics import Scores

score_engine = Scores(da_fcst, da_ref, _, ["init_time", "lat", "lon",])

You may consult the doc-string to get further information on the engine. Alternatively, consult the source code (`/p/project/training2445/shared/atmorep/jsc_scripts/utils_hackathon/metrics.py`).

In [None]:
import inspect
print(Scores.__init__.__doc__)
print(f"Available scores: {list(score_engine.metrics_dict.keys())}")

Let's calculate the RMSE and plot it against lead time:

In [None]:
rmse = score_engine("rmse")
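
For reference, the RMSE is the square root of the mean squared difference between forecast and reference. The sketch below shows the computation with plain NumPy on synthetic data, averaging over all axes except the lead time. It is not the implementation of the `Scores`-engine; the helper `rmse_over_axes` and the array layout are assumptions made for this illustration.

```python
import numpy as np

def rmse_over_axes(pred, target, avg_axes):
    """Root mean squared error, averaged over the given axes (hypothetical helper)."""
    squared_error = (pred - target) ** 2
    return np.sqrt(squared_error.mean(axis=avg_axes))

# synthetic data with assumed layout (init_time, lead_time, lat, lon)
target = np.zeros((1, 2, 4, 8))
pred = target.copy()
pred[:, 0] += 0.5     # constant error of 0.5 K at the first lead time
pred[:, 1] += 1.0     # constant error of 1.0 K at the second lead time

# average over init_time, lat and lon; one value per lead time remains
rmse_demo = rmse_over_axes(pred, target, avg_axes=(0, 2, 3))
print(rmse_demo)
```

Averaging over `init_time`, `lat` and `lon` leaves one RMSE value per lead time, which is the quantity plotted against lead time.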

In [None]:
from jsc_scripts.utils_hackathon.plotting import plot_metric_line

fname_rmse = user_config.__dict__["output"].joinpath(f"atmorep_rmse_{varname}_ml{vlevel:03d}.png")

plot_metric_line(rmse.sel({"ml": vlevel}), metric={"RMSE": "K"}, value_range=(0., 1.), plt_fname=str(fname_rmse), 
                 x_coord="lead_time", xlabel="leadtime [h]")

### Task 2:
- Evaluate different variables on different levels
- What do you observe?
- Extra task: Go back to the top of the Jupyter Notebook and re-run the forecast with a modified maximum lead time (via the `forecast_num_tokens`-parameter). How does this affect the results?

<hr>

## Step 3 - Process and validate the AtmoRep ensemble

## To-Dos @Asma:
- [ ] Revise text and improve
- [ ] Improve instructions in tasks and check what happens when modifying the maximum leadtime via the `forecast_num_tokens`-parameter when evaluating
- [ ] Further/better tasks?
    - [ ] Uncover how `target` and `pred` data is being read?
- [ ] Create data for more samples -> to be provided in shared-directory
- [ ] Append Jupyter Notebook with evaluation on larger amount of samples (Step 3)

## To-Dos @Michael
- [ ] Revise data handler with separated lead time dimension and init time of forecast
- [ ] Ensure that lead time dimension is consecutive
- [ ] Fix and harmonize doc-strings and used methods