### Handling Data Produced by the AtmoRep Model
This notebook provides a convenient way to work with the output data that the AtmoRep model writes in the zarr format.

#### The AtmoRep model
The AtmoRep model uses a regional, patch-based approach: the global atmospheric field is divided into smaller, overlapping regional patches (also called tokens). This approach allows the model to capture local-scale features and dynamics more effectively while still maintaining a global perspective. The model is implemented as a transformer neural network with 3.5 billion parameters and is trained on the ERA5 reanalysis. It can be used for nowcasting (short-term forecasting), temporal interpolation, model correction, or counterfactuals. This notebook demonstrates the nowcasting case.
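The following sketch illustrates the idea of cutting a global (lat, lon) field into overlapping patches. It is not AtmoRep's tokenizer: the function name `extract_patches`, the patch size and the stride are hypothetical choices made for illustration. Longitude is treated as periodic so that patches can wrap around the 0°/360° meridian, while latitude is not.

```python
import numpy as np

def extract_patches(field, patch_size, stride):
    """Cut a 2-D (lat, lon) field into overlapping patches (illustrative, not AtmoRep's code).

    Longitude is periodic, so patches may wrap around the 0/360 meridian;
    latitude is not periodic, so patches never cross the poles.
    """
    nlat, nlon = field.shape
    ph, pw = patch_size
    sh, sw = stride
    patches = []
    for i0 in range(0, nlat - ph + 1, sh):
        for j0 in range(0, nlon, sw):
            lon_idx = np.arange(j0, j0 + pw) % nlon  # wrap longitude indices
            patches.append(field[i0:i0 + ph][:, lon_idx])
    return np.stack(patches)

# Toy 8 x 12 grid, 4 x 4 patches, stride 2: neighbouring patches overlap by half
toy = np.arange(8 * 12, dtype=float).reshape(8, 12)
patches = extract_patches(toy, patch_size=(4, 4), stride=(2, 2))
print(patches.shape)  # (18, 4, 4)
```

Because the stride is smaller than the patch size, neighbouring patches share grid points; this overlap is what has to be resolved when the patches are stitched back into a global field.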

#### ERA5 data
ERA5 is the fifth-generation global atmospheric reanalysis produced by the European Centre for Medium-Range Weather Forecasts (ECMWF). It provides the most consistent and coherent global estimate of the state of the atmosphere, land surface, and ocean waves that is currently available. For AtmoRep, the ERA5 dataset is used at hourly temporal resolution on ERA5's default equi-angular grid with 721 × 1440 grid points (0.25° spacing). In the vertical, the employed model levels are 96, 105, 114, 123 and 137, corresponding approximately to the pressure levels 546, 693, 850, 947 and 1012 hPa. The available variables are the zonal and meridional wind components, vorticity, divergence, vertical velocity, temperature, specific humidity, and total precipitation.
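As a quick check of the grid described above, the latitude and longitude axes can be built with NumPy. This mirrors the coordinates that the notebook's own code constructs; the dictionary `model_level_to_hpa` is only a convenience name introduced here for the approximate level-to-pressure mapping quoted in the text.

```python
import numpy as np

# ERA5 equi-angular 0.25 degree grid: 721 latitudes (both poles included) x 1440 longitudes
lat = np.linspace(-90.0, 90.0, num=721, endpoint=True)
lon = np.linspace(0.0, 360.0, num=1440, endpoint=False)

# Approximate pressure (hPa) of the model levels used by AtmoRep, as listed in the text
model_level_to_hpa = {96: 546, 105: 693, 114: 850, 123: 947, 137: 1012}

print(lat.size * lon.size)               # 1038240 grid points per level
print(lat[1] - lat[0], lon[1] - lon[0])  # 0.25 0.25
```

Note that the latitudes are ascending here (south to north), which is why the plotting routine flips the array before calling `imshow`.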

#### Handling AtmoRep data
This notebook contains a Python class, `HandleAtmoRepData`, which encapsulates the functionality to read, process, and analyze the data produced by the AtmoRep model. Its most important methods are:
- `get_hierarchical_sorted_files`: searches the specified directory (recursively) for zarr files and returns their paths sorted by training epoch. The files are selected by model ID (models with different architecture details are organized under different IDs), by data type (`source`, `target`, `pred` or `ens`) and by training epoch. One can select the file of a single epoch or pass `epoch=-1` to select the files of all epochs.
- `_get_config`: loads the JSON configuration file (`model_<model_id>.json`) of the given model ID; it is called when the object is created.
- `read_one_file`: reads data from a single output zarr file. It iterates over all the patches and returns a list of `xarray.DataArray` objects, one per patch.
- `read_data`: reads all the selected zarr files and collects their patches. In forecasting mode with overlapping tokens, it then calls `get_global_field`, which combines the individual data arrays, representing different spatial locations and time steps, into a single global array.
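The file selection and ordering done by `get_hierarchical_sorted_files` can be sketched without touching the file system. The helper names `results_pattern` and `epoch_of` and the example paths are hypothetical; `epoch_of` uses a regular expression instead of the string splitting in `get_number`, and both helpers look only at the file name so that directory names cannot affect the result.

```python
import fnmatch
import re
from pathlib import PurePath

def results_pattern(model_id, data_type, epoch=-1):
    """Glob pattern for AtmoRep output files, mirroring get_hierarchical_sorted_files."""
    epoch_str = "epoch*" if epoch == -1 else f"epoch{epoch:05d}"
    return f"results_{model_id}_{epoch_str}_{data_type}.zarr"

def epoch_of(path):
    """Return the training epoch encoded in an output file name (regex variant of get_number)."""
    match = re.search(r"_epoch(\d+)_", PurePath(path).name)
    if match is None:
        raise ValueError(f"no epoch number in {path!r}")
    return int(match.group(1))

# Hypothetical paths, as a recursive glob might return them
paths = ["run_a/results_idc96xrbip_epoch00010_pred.zarr",
         "run_b/results_idc96xrbip_epoch00002_pred.zarr",
         "run_c/results_idc96xrbip_epoch00000_target.zarr"]

pattern = results_pattern("idc96xrbip", "pred")
selected = [p for p in paths if fnmatch.fnmatch(PurePath(p).name, pattern)]
print([epoch_of(p) for p in sorted(selected, key=epoch_of)])  # [2, 10]
```

Sorting by the extracted epoch, rather than by the full path, keeps the epochs in order even when files of different epochs sit in different sub-directories; a plain `sorted(selected)` would put epoch 10 first here because `run_a` precedes `run_b`.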

The `read_data` method thus encapsulates the most important steps for retrieving the global data field. The call to `get_global_field` inside `read_data` is essential, since it stitches the regional data patches into a cohesive global field. It is only made in forecasting mode with overlapping tokens; otherwise `read_data` returns the list of individual patches.
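The stitching step can be illustrated on a toy grid. The sketch below follows the same strategy as `get_global_field` (a NaN-initialised canvas filled via label-based `.loc` assignment, with later patches overwriting overlapping points), but the function `stitch` and the purely two-dimensional (lat, lon) layout are simplifications made here for clarity.

```python
import numpy as np
import xarray as xr

def stitch(patches, lat, lon):
    """Paste (lat, lon) patches into a NaN-initialised global array.

    Points covered by several patches keep the value of the last patch.
    """
    canvas = xr.DataArray(np.full((lat.size, lon.size), np.nan),
                          coords={"lat": lat, "lon": lon}, dims=("lat", "lon"))
    for p in patches:
        # label-based assignment: the patch coordinates must exist on the global grid
        canvas.loc[{"lat": p["lat"].values, "lon": p["lon"].values}] = p.values
    if canvas.isnull().any():
        raise ValueError("the patches do not cover the whole grid")
    return canvas

# Toy example: a 4 x 6 "global" field split into two patches that overlap at lon = 2, 3
lat, lon = np.arange(4.0), np.arange(6.0)
truth = xr.DataArray(np.arange(24.0).reshape(4, 6), coords={"lat": lat, "lon": lon}, dims=("lat", "lon"))
patches = [truth.isel(lon=slice(0, 4)), truth.isel(lon=slice(2, 6))]
print(bool((stitch(patches, lat, lon) == truth).all()))  # True
```

Initialising with NaN rather than with `np.empty` is what makes the final coverage check meaningful: `np.empty` returns arbitrary values, so missing grid points would go unnoticed.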

The zarr format enables efficient I/O because arrays are stored in chunks, so only the chunks that overlap the requested region have to be read.
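To make the benefit of chunking concrete, the sketch below computes which chunks a rectangular read touches; only those chunks need to be read and decompressed. The chunk shape (181, 360) is a hypothetical value chosen for the 721 × 1440 ERA5 grid, not the chunking actually used in AtmoRep's output files, and `chunks_touched` is an illustrative helper, not part of the zarr API.

```python
import math
from itertools import product

def chunks_touched(shape, chunks, selection):
    """Return the chunk-grid indices overlapped by a rectangular selection.

    selection holds one (start, stop) pair per dimension, stop exclusive.
    """
    per_dim = []
    for n, c, (start, stop) in zip(shape, chunks, selection):
        stop = min(stop, n)
        if not 0 <= start < stop:
            raise ValueError("empty or out-of-range selection")
        per_dim.append(range(start // c, (stop - 1) // c + 1))
    return list(product(*per_dim))

shape, chunks = (721, 1440), (181, 360)  # hypothetical chunking of one 2-D level
n_total = math.prod(math.ceil(n / c) for n, c in zip(shape, chunks))
touched = chunks_touched(shape, chunks, [(100, 200), (0, 100)])
print(f"{len(touched)} of {n_total} chunks read")  # 2 of 16 chunks read
```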

#### Usage
This code reads the zarr data after you have downloaded them to your machine. One can experiment with extracting the ground truth and the prediction for various variables, such as specific humidity, temperature, total precipitation and velocity. The model prediction in the demonstrated case consists of five future time stamps. The last routine generates one global map of the selected variable for each time stamp.

#### Source Code

##### Required Python packages

In [None]:
import os
import json
from pathlib import Path

import numpy as np
import xarray as xr
import zarr
from tqdm import tqdm

# import the cartopy submodules explicitly so that cartopy.crs and cartopy.feature are available
import cartopy
import cartopy.crs
import cartopy.feature
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

matplotlib.rcParams['axes.linewidth'] = 0.1

import warnings
warnings.filterwarnings('ignore')

##### HandleAtmoRepData class 

In [None]:
class HandleAtmoRepData(object):
    """
    Handle output data of AtmoRep model.
    """
    known_data_types = ["source", "pred", "target", "ens"]
    
    def __init__(self, model_id: str, results_dir):
        """
        :param model_id: ID of the AtmoRep run to process (the prefix 'id' is added if missing)
        :param results_dir: top-level directory where the results are stored (str or Path)
        """
        self.model_id = model_id if model_id.startswith("id") else f"id{model_id}"
        # convert to Path, since _get_config relies on Path.joinpath
        self.results_dir = Path(results_dir)
        self.config_file, self.config = self._get_config()

    def _get_config(self) -> "tuple[Path, dict]":
        """
        Get the configuration dictionary of the trained AtmoRep model.
        """
        config_jsf = self.results_dir.joinpath(f"model_{self.model_id}.json")
        with open(config_jsf) as json_file:
            config = json.load(json_file)
        return config_jsf, config
        
    @staticmethod
    def get_number(file_name, split_arg):
        """
        Extract the number that follows split_arg in the file name, e.g. 3 from '..._epoch00003_pred.zarr'.
        Only the base name is used, so directory names cannot interfere.
        """
        return int(Path(file_name).name.split(split_arg)[1].split('_')[0])

    def get_hierarchical_sorted_files(self, data_type: str, epoch: int = -1):
        """
        Get sorted list of file names based on the specified data type and epoch.
        :param data_type: Type of data which should be retrieved (either 'source', 'target', 'ens' or 'pred')
        :param epoch: number of epoch for which data should be retrieved. Pass -1 to get all epochs.
        """
        epoch_str = "epoch*" if epoch == -1 else f"epoch{epoch:05d}"
        fpatt = f"results_{self.model_id}_{epoch_str}_{data_type}.zarr"
        file_list = list(Path(self.results_dir).glob(f"**/{fpatt}"))
        if not file_list:
            raise FileNotFoundError(f"No files found matching '{fpatt}' in '{self.results_dir}'.")            
        return sorted(file_list, key=lambda x: self.get_number(x, "_epoch"))
        
    @staticmethod
    def get_global_field(da_list):
        """
        Combine the individual data arrays, which represent data from different spatial locations or time steps,
        into a single global data array. Where patches overlap, later patches overwrite earlier ones.
        """
        # get unique time stamps in chronological order (a plain set is unordered)
        times_unique = sorted(set(time for da in da_list for time in da["datetime"].values))
        dx = abs(float(da_list[0]["lon"][1] - da_list[0]["lon"][0]))
        dy = abs(float(da_list[0]["lat"][1] - da_list[0]["lat"][0]))

        # initialize the global data array with NaN so that unfilled grid points can be detected
        dims = da_list[0].dims
        data_coords = {k: v for k, v in da_list[0].coords.items() if k not in ["lat", "lon"]}
        data_coords["lat"] = np.linspace(-90., 90., num=int(round(180 / dy)) + 1, endpoint=True)
        data_coords["lon"] = np.linspace(0, 360, num=int(round(360 / dx)), endpoint=False)
        data_coords["datetime"] = times_unique

        shape = tuple(len(data_coords[d]) for d in dims)
        da_global = xr.DataArray(np.full(shape, np.nan), coords=data_coords, dims=dims)
        # fill the global data array patch by patch
        for da in da_list:
            da_global.loc[{"datetime": da["datetime"].values, "lat": da["lat"].values,
                           "lon": da["lon"].values}] = da.values

        if np.any(da_global.isnull()):
            raise ValueError("Could not get global data field: some grid points are not covered by any patch.")
        return da_global

    def read_one_file(self, fname, varname: str, data_type: str):
        """
        Read data from a single output file of AtmoRep. A zarr.ZipStore is opened for fname and the zarr group
        inside it is traversed patch by patch. For each patch, the data and its coordinates are loaded into an
        xarray.DataArray.
        :param fname: name of the zarr file that should be read
        :param varname: name of the variable in the zarr file to be accessed
        :param data_type: type of data which should be retrieved (either 'source', 'target', 'ens' or 'pred')
        :return: list of DataArrays where each element holds one patch (sample)
        """
        dims = ["ml", "datetime", "lat", "lon"]
        if data_type == "ens":
            dims = ["ensemble"] + dims

        store = zarr.ZipStore(str(fname), mode='r')
        grouped_store = zarr.open_group(store=store, mode='r')

        da = []
        for patch in tqdm(grouped_store[varname]):
            # load coordinates into memory ([:]) so that the store can be closed afterwards
            coords = {dim: grouped_store[f"{varname}/{patch}/{dim}"][:] for dim in dims if dim != "ensemble"}
            if data_type == "ens":
                coords["ensemble"] = range(self.config["net_tail_num_nets"])
            da_p = xr.DataArray(grouped_store[f"{varname}/{patch}/data"][:], coords=coords, dims=dims,
                                name=f"{varname}_{patch.replace('=', '')}")
            da.append(da_p)
        store.close()
        return da
        
    def read_data(self, varname: str, data_type: str, epoch: int = -1):
        """
        Coordinate the reading of data from the output files. Further analysis can be done with the returned data.
        :param varname: name of the variable to read
        :param data_type: type of data which should be retrieved (either 'source', 'target', 'ens' or 'pred')
        :param epoch: training epoch of the requested output file (-1 for all epochs)
        :return: list of per-patch DataArrays, or a single global DataArray in forecasting mode with overlapping tokens
        """
        assert data_type in self.known_data_types, \
            f"Data type '{data_type}' is unknown. Choose one of the following: '{', '.join(self.known_data_types)}'"

        file_list = self.get_hierarchical_sorted_files(data_type, epoch)
        # defined for every BERT strategy (previously undefined unless the strategy was 'forecast')
        args = {"varname": varname, "data_type": data_type}

        print(f"Start reading {len(file_list)} files...")
        da = []
        for f in file_list:
            da += self.read_one_file(f, **args)

        # return global data if the global forecasting evaluation mode was chosen
        if self.config["BERT_strategy"] == "forecast" and self.config.get("token_overlap", False):
            da = self.get_global_field(da)
        return da

##### A routine that plots a global map for the selected variable
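This routine loads the prediction of epoch 0 directly from the zarr file and stitches the patches that share the time stamps of the first patch into a global canvas. Each patch contains data for several time stamps, and one map (of the first model level) is plotted per time stamp.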

In [None]:
def plot_global_map(input_dir, field, model_id):
    """Stitch the predicted patches of epoch 0 into a global field and plot one map per time stamp."""
    store = zarr.ZipStore(f'{input_dir}/results_id{model_id}_epoch00000_pred.zarr', mode='r')
    ds = zarr.open_group(store=store, mode='r')

    # create an empty global canvas (0.25 deg ERA5 grid) into which the local patches are filled
    first = f'{field}/sample=00000'
    times = ds[f'{first}/datetime'][:]
    levels = ds[f'{first}/ml'][:]
    ds_o = xr.Dataset(coords={'ml': levels,
                              'datetime': times,
                              'lat': np.linspace(-90., 90., num=180*4+1, endpoint=True),
                              'lon': np.linspace(0., 360., num=360*4, endpoint=False)})
    print("The plotted time stamps: ", times)
    ds_o[field] = (['ml', 'datetime', 'lat', 'lon'], np.zeros((len(levels), len(times), 721, 1440)))

    # fill in the local patches that share the time stamps of the first patch
    for sample in ds[field]:
        if not np.array_equal(ds[f'{field}/{sample}/datetime'][:], times):
            break
        ds_o[field].loc[dict(datetime=ds[f'{field}/{sample}/datetime'][:],
                             lat=ds[f'{field}/{sample}/lat'][:],
                             lon=ds[f'{field}/{sample}/lon'][:])] = ds[f'{field}/{sample}/data'][:]
    store.close()

    # plot one map per time stamp (first model level)
    cmap = 'RdBu_r'
    vmin, vmax = ds_o[field].values[0].min(), ds_o[field].values[0].max()
    for k in range(len(times)):
        fig = plt.figure(figsize=(10, 5), dpi=300)
        ax = plt.axes(projection=cartopy.crs.Robinson(central_longitude=0.))
        ax.add_feature(cartopy.feature.COASTLINE, linewidth=0.5, edgecolor='k', alpha=0.5)
        ax.set_global()
        date = ds_o['datetime'].values[k].astype('datetime64[m]')
        ax.set_title(f'{field} : {date}')
        # latitudes are ascending, so flip them to put north at the top of the image
        im = ax.imshow(np.flip(ds_o[field].values[0, k], 0), cmap=cmap, vmin=vmin, vmax=vmax,
                       transform=cartopy.crs.PlateCarree(central_longitude=180.))
        axins = inset_axes(ax, width="80%", height="5%", loc='lower center', borderpad=-2)
        fig.colorbar(im, cax=axins, orientation="horizontal")
        # plt.savefig(f'example_{k:03d}.png')  # uncomment to save each map
        plt.show()
        plt.close(fig)

##### Instantiating an object of the HandleAtmoRepData class and calling methods on that object
We extract the target and prediction data by calling the `read_data` method, which wraps the methods described in the **Handling AtmoRep data** section. We specify the model ID, the variable of interest and the path of the folder where the zarr files are located.


In [None]:
model_id = 'c96xrbip'
field = 'temperature'
input_dir = Path("/home/enxhi/Documents/data/atmorep_zarr2/")  # adjust to the folder holding your zarr files

ar_data    = HandleAtmoRepData(model_id, input_dir)
da_target  = ar_data.read_data(field, "target")
da_pred    = ar_data.read_data(field, "pred")
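Once target and prediction are loaded, a simple evaluation is the root-mean-square error for each time stamp. The sketch below assumes both inputs are `xarray.DataArray` objects with a `datetime` dimension and identical coordinates (as for the global fields returned by `read_data` in forecasting mode); if `read_data` returns lists of patches, it can be applied patch by patch. The function name `rmse_per_timestamp` is hypothetical and the toy data are synthetic. The mean is unweighted; on a regular latitude-longitude grid, a cos(latitude) area weighting is often applied instead.

```python
import numpy as np
import xarray as xr

def rmse_per_timestamp(pred, target):
    """Root-mean-square error over all dimensions except 'datetime' (unweighted)."""
    squared_error = (pred - target) ** 2
    other_dims = [d for d in squared_error.dims if d != "datetime"]
    return np.sqrt(squared_error.mean(dim=other_dims))

# Synthetic example: the prediction is off by 0, 1 and 2 at the three time stamps
target = xr.DataArray(np.zeros((1, 3, 2, 2)), dims=("ml", "datetime", "lat", "lon"))
pred = target + xr.DataArray([0.0, 1.0, 2.0], dims="datetime")
print(rmse_per_timestamp(pred, target).values)  # [0. 1. 2.]

# With the notebook's data (assuming both are global DataArrays on the same grid):
# rmse_per_timestamp(da_pred, da_target)
```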

##### Calling the routine which plots the global map 

In [None]:
plot_global_map(input_dir, field, model_id)