# Implementation of normalization class

### Background
To improve data normalization in MAELSTROM, a dedicated class for this job is introduced. <br>
However, the implementation reveals some low-level bugs with xarray that are documented here.

### Problem statement
Let's start with the *erroneous* approach, in which the class is allowed to handle datasets. According to [xarray's documentation](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.std.html), this approach should work. Technically it does (i.e. it does not throw an error), but it produces strange results because the `std`-method returns unrealistic values on the dataset.

The (verbose) source-code looks as follows:

In [1]:
import os

from abc import ABC, abstractmethod
from typing import Union, List
import xarray as xr
import numpy as np

da_or_ds = Union[xr.DataArray, xr.Dataset]

class Normalize(ABC):
    """
    Abstract class for normalizing data.
    """
    def __init__(self, method: str, norm_dims: List):
        self.method = method
        self.norm_dims = norm_dims
        self.norm_stats = None

    def normalize(self, data: xr.DataArray, **stats):
        
        
        norm_stats = self.get_required_stats(data, **stats)
        data_norm = self.normalize_data(data, *norm_stats)

        return data_norm

    def denormalize(self, data: da_or_ds, **stats):
        norm_stats = self.get_required_stats(data, **stats)
        data_denorm = self.denormalize_data(data, *norm_stats)

        return data_denorm

    @abstractmethod
    def get_required_stats(self, data, **stats):
        """
        Function to retrieve either normalization parameters from data or from keyword arguments
        """
        pass

    @abstractmethod
    def normalize_data(data, *norm_param):
        """
        Function to normalize data.
        """
        pass

    @abstractmethod
    def denormalize_data(data, *norm_param):
        """
        Function to denormalize data.
        """
        pass

The child class for z-score normalization then looks as follows:

In [3]:
class ZScore(Normalize):

    def __init__(self, norm_dims: List):
        super().__init__("z_score", norm_dims)
        self.norm_stats = {"mu": None, "sigma": None}

    def get_required_stats(self, data: da_or_ds, **stats):

        mu, std = stats.get("mu", self.norm_stats["mu"]), stats.get("sigma", self.norm_stats["sigma"])

        if mu is None or std is None:
            print("Retrieve mu and sigma from data...")
            mu, std = data.mean(self.norm_dims), data.std(self.norm_dims)
            self.norm_stats = {"mu": mu, "sigma": std}
            print(self.norm_stats)
        else:
            print("Mu and sigma are parsed for (de-)normalization.")
            print(mu)
            print(std)
            
        return mu, std
            
    @staticmethod     
    def normalize_data(data, mu, std):
        data_norm = (data - mu) / std
        
        return data_norm
    
    @staticmethod     
    def denormalize_data(data, mu, std):
        data_denorm = data * std + mu
        
        return data_denorm
    

We load the validation data of the downscaling Tier-2 dataset for testing purposes.

In [4]:
datadir = "/p/project/deepacf/maelstrom/data/ap5/tier2"
datafile = os.path.join(datadir, "maelstrom-downscaling-tier2_test.nc")

ds = xr.open_dataset(datafile)

Then we apply the z-score normalization on the dataset:

In [6]:
norm_dims = ["time", "rlat", "rlon"]
zscore_norm = ZScore(norm_dims)

ds_norm1 = zscore_norm.normalize(ds)

Retrieve mu and sigma from data...
{'mu': <xarray.Dataset>
Dimensions:       ()
Data variables:
    rotated_pole  float64 1.0
    2t_in         float32 282.8
    sshf_in       float32 -7.078e+04
    slhf_in       float32 -1.565e+05
    blh_in        float32 529.6
    10u_in        float32 0.1829
    10v_in        float32 0.1448
    z_in          float32 5.686e+03
    t850_in       float32 278.6
    t925_in       float32 282.7
    hsurf_tar     float32 571.6
    t_2m_tar      float32 283.2, 'sigma': <xarray.Dataset>
Dimensions:       ()
Data variables:
    rotated_pole  float64 0.0
    2t_in         float64 104.5
    sshf_in       float64 1.973e+05
    slhf_in       float64 2.177e+05
    blh_in        float64 486.9
    10u_in        float64 2.089
    10v_in        float64 1.616
    z_in          float64 4.809e+03
    t850_in       float64 104.5
    t925_in       float64 104.5
    hsurf_tar     float64 511.0
    t_2m_tar      float64 104.5}


As we can see, the calculated standard deviation of the temperature variables (and possibly of others) is implausibly large, which would result in incorrect normalization. To verify this, we calculate the standard deviation manually:

In [7]:
t2m = ds["2t_in"]
sigma_t2m = np.sqrt(((t2m - t2m.mean(dim=norm_dims))**2).mean(dim=norm_dims))
print(sigma_t2m)              # note that omitting the parameter-setting dim=norm_dims would give the same result

<xarray.DataArray '2t_in' ()>
array(9.052816, dtype=float32)
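
The manual check above still runs through xarray. As an independent cross-check, the same two-pass formula can be evaluated on the raw NumPy values in `float64`. The following is a minimal sketch on synthetic data; the helper name `reference_std` and the temperature-like sample values are assumptions for illustration and not part of the MAELSTROM code.

```python
import numpy as np


def reference_std(values) -> float:
    """Two-pass population standard deviation, accumulated in float64."""
    x = np.asarray(values, dtype=np.float64).ravel()
    x = x[~np.isnan(x)]                 # mimic skipna=True
    mu = x.mean()
    return float(np.sqrt(np.mean((x - mu) ** 2)))


# synthetic, temperature-like float32 data (hypothetical values)
rng = np.random.default_rng(42)
sample = (283.0 + 9.0 * rng.standard_normal((24, 16, 16))).astype(np.float32)

sigma_ref = reference_std(sample)
sigma_np = float(np.std(sample.astype(np.float64)))
print(sigma_ref, sigma_np)
```

If a reduction on the real data deviates strongly from such a reference, this would point to the reduction itself rather than to the data.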


By contrast, converting the dataset to an `xr.DataArray`, where the variables are assigned to a new dimension, gives the correct result.

In [8]:
da = ds.to_array(dim="variables")

zscore_norm2 = ZScore(norm_dims)     # get fresh class instance to avoid precomputed normalization parameters
da_norm2 = zscore_norm2.normalize(da)

Retrieve mu and sigma from data...
{'mu': <xarray.DataArray (variables: 12)>
array([ 1.00000000e+00,  2.82832158e+02, -7.07827281e+04, -1.56503764e+05,
        5.29617935e+02,  1.82885170e-01,  1.44832186e-01,  5.68644335e+03,
        2.78623048e+02,  2.82670052e+02,  5.71572380e+02,  2.83245770e+02])
Coordinates:
  * variables  (variables) <U12 'rotated_pole' '2t_in' ... 't_2m_tar', 'sigma': <xarray.DataArray (variables: 12)>
array([0.00000000e+00, 9.05282455e+00, 2.24066973e+05, 2.36301865e+05,
       5.11373825e+02, 2.43462441e+00, 1.87755217e+00, 4.48457913e+03,
       7.45010479e+00, 8.21746109e+00, 4.97467674e+02, 9.03267566e+00])
Coordinates:
  * variables  (variables) <U12 'rotated_pole' '2t_in' ... 't_2m_tar'}
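
To illustrate what `to_array` does, here is a small sketch on a toy dataset (the variable names `a` and `b`, the shapes and the values are made up). The variables are stacked along the new `variables` dimension, so reducing over the remaining dimensions yields one statistic per variable. Variables of different precision are promoted to a common dtype in the process.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
dims = ("time", "rlat", "rlon")

# toy dataset with two variables of different precision (hypothetical)
ds_toy = xr.Dataset(
    {
        "a": (dims, rng.normal(280.0, 9.0, size=(10, 4, 5)).astype(np.float32)),
        "b": (dims, rng.normal(0.0, 2.0, size=(10, 4, 5))),   # float64
    }
)

da_toy = ds_toy.to_array(dim="variables")     # dims: (variables, time, rlat, rlon)
sigma_toy = da_toy.std(dim=list(dims))        # one value per variable

print(da_toy.dims, da_toy.dtype)
print(sigma_toy)
```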


Let's perform some further tests.

In [9]:
vardict = {"variables": "2t_in"}
t2m_norm1 = da_norm2.sel(vardict)
mu, std = zscore_norm2.norm_stats["mu"].sel(vardict), zscore_norm2.norm_stats["sigma"].sel(vardict)

# normalize 2t_in with precomputed normalization parameters
t2m_norm2 = zscore_norm2.normalize(t2m, mu=mu.values, sigma=std.values)
# reset norm_stats manually...
zscore_norm2.norm_stats = {"mu": None, "sigma": None}
# ... to trigger computation
t2m_norm3 = zscore_norm2.normalize(t2m)

Mu and sigma are parsed for (de-)normalization.
282.8321579222078
9.052824545594893
Retrieve mu and sigma from data...
{'mu': <xarray.DataArray '2t_in' ()>
array(282.83246, dtype=float32), 'sigma': <xarray.DataArray '2t_in' ()>
array(9.052816, dtype=float32)}


In [10]:
assert np.all(np.isclose(t2m_norm1, t2m_norm2, atol=1.e-04)), "t2m_norm1 and t2m_norm2 differ!"
assert np.all(np.isclose(t2m_norm2, t2m_norm3, atol=1.e-04)), "t2m_norm2 and t2m_norm3 differ!"
assert np.all(np.isclose(ds_norm1["2t_in"], t2m_norm2, atol=1.e-04)), "ds_norm1 and t2m_norm2 differ!"

AssertionError: ds_norm1 and t2m_norm2 differ!

As expected, the normalized values differ substantially due to the erroneous value for the standard deviation.
We can trace the error further: it even occurs when we apply the `std`-method to a DataArray with default behaviour (i.e. without setting `dim`).

In [31]:
print(norm_dims)
print(ds["2t_in"].coords)

['time', 'rlat', 'rlon']
Coordinates:
  * time     (time) datetime64[ns] 2018-01-01T01:00:00 ... 2018-12-31T23:00:00
  * rlon     (rlon) float64 -8.273 -8.218 -8.163 -8.108 ... -1.838 -1.783 -1.728
  * rlat     (rlat) float64 -3.933 -3.878 -3.823 -3.768 ... 1.182 1.237 1.292


In [32]:
print(ds.std(skipna=True))                     # as above
print(ds["2t_in"].std(skipna=True))            # same error when retrieving 2t_in from the dataset only and calling .std() with defaults
print(ds["2t_in"].std(norm_dims, skipna=True))   # setting dimensions for .std()

<xarray.Dataset>
Dimensions:       ()
Data variables:
    rotated_pole  float64 0.0
    2t_in         float64 104.5
    sshf_in       float64 1.973e+05
    slhf_in       float64 2.177e+05
    blh_in        float64 486.9
    10u_in        float64 2.089
    10v_in        float64 1.616
    z_in          float64 4.809e+03
    t850_in       float64 104.5
    t925_in       float64 104.5
    hsurf_tar     float64 511.0
    t_2m_tar      float64 104.5
<xarray.DataArray '2t_in' ()>
array(104.4526062)
<xarray.DataArray '2t_in' ()>
array(9.052816, dtype=float32)


### The workaround

In the following, we ensure that the normalization class does not operate on `xr.Dataset` objects. Furthermore, `norm_dims=None` is rejected as well, since the default behaviour may run into the same issue:

In [13]:
import json as js
import dask

class Normalize(ABC):
    """
    Abstract class for normalizing data.
    """
    def __init__(self, method: str, norm_dims: List):
        self.method = method
        self.norm_dims = norm_dims
        self.norm_stats = None

    def normalize(self, data: xr.DataArray, **stats):
        """
        Normalize data.
        :param data: The DataArray to be normalized.
        :param **stats: Parameters to perform normalization. Must fit to normalization type!
        :return: DataArray with normalized data.
        """
        # sanity checks
        if not isinstance(data, xr.DataArray):
            raise TypeError(f"Passed data must be a xarray.DataArray, but is of type {str(type(data))}.")
            
        _ = self._check_norm_dims(data)
        # do the computation            
        norm_stats = self.get_required_stats(data, **stats)
        data_norm = self.normalize_data(data, *norm_stats)

        return data_norm

    def denormalize(self, data: da_or_ds, **stats):
        """
        Denormalize data.
        :param data: The DataArray to be denormalized.
        :param **stats: Parameters to perform denormalization. Must fit to normalization type!
        :return: DataArray with denormalized data.
        """
        # sanity checks
        if not isinstance(data, xr.DataArray):
            raise TypeError(f"Passed data must be a xarray.DataArray, but is of type {str(type(data))}.")
            
        _ = self._check_norm_dims(data)
        # do the computation       
        norm_stats = self.get_required_stats(data, **stats)
        data_denorm = self.denormalize_data(data, *norm_stats)

        return data_denorm
    
    @property
    def norm_dims(self):
        return self._norm_dims
    
    @norm_dims.setter
    def norm_dims(self, norm_dims):
        if norm_dims is None:
            raise AttributeError("norm_dims must not be None. Please pass a list of dimensions " +
                                 "over which normalization should be applied.")
        
        self._norm_dims = list(norm_dims)
        
    def _check_norm_dims(self, data):
        """
        Check if dimension for normalization reside in dimensions of data.
        :param data: the data (xr.DataArray) to be normalized
        :return True: if the check is passed; a ValueError is raised otherwise
        """
        data_dims = list(data.dims)
        norm_dims_check = [norm_dim in data_dims for norm_dim in self.norm_dims]
        if not all(norm_dims_check):
            imiss = np.where(~np.array(norm_dims_check))[0]
            miss_dims = list(np.array(self.norm_dims)[imiss])
            raise ValueError("The following dimensions do not reside in the data: " +
                             f"{', '.join(miss_dims)}")

        return True
    
    def save_norm_to_file(self, js_file):
        """
        Write normalization parameters to file.
        :param js_file: Path to JSON-file to be created.
        :return: -
        """
        if self.norm_stats is None:
            raise AttributeError("norm_stats is still None. Please run (de-)normalization to get parameters.")
        
        if any([stat is None for stat in self.norm_stats.values()]):
            raise AttributeError("Some parameters of norm_stats are None.")
            
        norm_serialized = {key: da.to_dict() for key, da in self.norm_stats.items()}
        
        with open(js_file, "w") as jsf:
            js.dump(norm_serialized, jsf)
        
    def read_norm_from_file(self, js_file):
        """
        Read normalization parameters from file. Inverse function to save_norm_to_file.
        :param js_file: Path to JSON-file to be read.
        :return: Parameters set to self.norm_stats
        """
        with open(js_file, "r") as jsf:
            norm_data = js.load(jsf)
            
        norm_dict_restored = {key: xr.DataArray.from_dict(da_dict) for key, da_dict in norm_data.items()}
        
        self.norm_stats = norm_dict_restored    
        

    @abstractmethod
    def get_required_stats(self, data, **stats):
        """
        Function to retrieve either normalization parameters from data or from keyword arguments
        """
        pass

    @abstractmethod
    def normalize_data(data, *norm_param):
        """
        Function to normalize data.
        """
        pass

    @abstractmethod
    def denormalize_data(data, *norm_param):
        """
        Function to denormalize data.
        """
        pass
    
class ZScore(Normalize):

    def __init__(self, norm_dims: List):
        super().__init__("z_score", norm_dims)
        self.norm_stats = {"mu": None, "sigma": None}

    def get_required_stats(self, data: da_or_ds, **stats):
        """
        Get required parameters for z-score normalization. They are either computed from the data 
        or can be parsed as keyword arguments.
        :param data: the data to be (de-)normalized
        :param mu: keyword argument for mean used for normalization
        :param sigma: keyword argument for standard deviation for normalization
        :return (mu, sigma): Parameters for normalization
        """
        mu, std = stats.get("mu", self.norm_stats["mu"]), stats.get("sigma", self.norm_stats["sigma"])

        if mu is None or std is None:
            print("Retrieve mu and sigma from data...")
            mu, std = data.mean(self.norm_dims), data.std(self.norm_dims)
            self.norm_stats = {"mu": mu, "sigma": std}
            print(self.norm_stats)
        else:
            print("Mu and sigma are parsed for (de-)normalization.")
            print(mu)
            print(std)
            
        return mu, std
            
    @staticmethod     
    def normalize_data(data, mu, std):
        """
        Perform z-score normalization on data
        :param data: Data array of interest
        :param mu: mean of data for normalization
        :param std: standard deviation of data for normalization
        :return data_norm: normalized data
        """
        data_norm = (data - mu) / std
        
        return data_norm
    
    @staticmethod     
    def denormalize_data(data, mu, std):
        """
        Perform z-score denormalization on data.
        :param data: Data array of interest
        :param mu: mean of data for denormalization
        :param std: standard deviation of data for denormalization
        :return data_norm: denormalized data
        """
        data_denorm = data * std + mu
        
        return data_denorm
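
The dimension check in `_check_norm_dims` can also be expressed without NumPy index arithmetic. The sketch below is a simplified, standalone variant; the helper names `find_missing_dims` and `check_norm_dims` are hypothetical and not part of the class above.

```python
from typing import Iterable, List


def find_missing_dims(norm_dims: Iterable[str], data_dims: Iterable[str]) -> List[str]:
    """Return the normalization dimensions that are absent from the data, in order."""
    available = set(data_dims)
    return [dim for dim in norm_dims if dim not in available]


def check_norm_dims(norm_dims: Iterable[str], data_dims: Iterable[str]) -> bool:
    """Raise a ValueError if any normalization dimension is missing."""
    missing = find_missing_dims(norm_dims, data_dims)
    if missing:
        raise ValueError(
            "The following dimensions do not reside in the data: " + ", ".join(missing)
        )
    return True


print(find_missing_dims(["time", "lat", "lon"], ("time", "rlat", "rlon")))
# prints ['lat', 'lon']
```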

To enable data processing, we start by converting the dataset to a data array.

In [110]:
da = ds.to_array(dim="variables")

Let's perform various tests:

In [121]:
zscore_norm = ZScore(["rlat", "rlon", "time"])

# normalize the resulting Data Array...
da_norm = zscore_norm.normalize(da)

Retrieve mu and sigma from data...
{'mu': <xarray.DataArray (variables: 12)>
array([ 1.00000000e+00,  2.82832158e+02, -7.07827281e+04, -1.56503764e+05,
        5.29617935e+02,  1.82885170e-01,  1.44832186e-01,  5.68644335e+03,
        2.78623048e+02,  2.82670052e+02,  5.71572380e+02,  2.83245770e+02])
Coordinates:
  * variables  (variables) <U12 'rotated_pole' '2t_in' ... 't_2m_tar', 'sigma': <xarray.DataArray (variables: 12)>
array([0.00000000e+00, 9.05282455e+00, 2.24066973e+05, 2.36301865e+05,
       5.11373825e+02, 2.43462441e+00, 1.87755217e+00, 4.48457913e+03,
       7.45010479e+00, 8.21746109e+00, 4.97467674e+02, 9.03267566e+00])
Coordinates:
  * variables  (variables) <U12 'rotated_pole' '2t_in' ... 't_2m_tar'}


In [124]:
# ... and compare:
assert np.all(np.isclose(da_norm.sel({"variables": "2t_in"}), t2m_norm1, atol=1.e-06)), "da_norm and t2m_norm1 differ!"
assert np.all(np.isclose(t2m_norm1, t2m_norm2, atol=1.e-04)), "t2m_norm1 and t2m_norm2 differ!"

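Apart from small deviations caused by slightly different normalization parameters, the results now coincide. These deviations presumably stem from floating-point precision: the parameters computed from the single variable are `float32`, whereas those computed from the converted DataArray are `float64` (compare the printed values above). For our purposes, differences smaller than 1.e-04 in normalized space are negligible, since normalization only aims to project the values into a data range suitable for backpropagation in a neural network.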
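
The effect of reduced precision on the normalization parameters can be reproduced with plain NumPy. The following sketch uses synthetic values (not the Tier-2 data) and compares a z-score computed entirely in `float32` with one computed in `float64`; whether this fully explains the deviations observed above is an assumption.

```python
import numpy as np

rng = np.random.default_rng(1)
x32 = (283.0 + 9.0 * rng.standard_normal(100_000)).astype(np.float32)
x64 = x32.astype(np.float64)          # identical values, higher precision

# statistics and z-scores in single and double precision
mu32, sigma32 = x32.mean(), x32.std()
mu64, sigma64 = x64.mean(), x64.std()
z32 = (x32 - mu32) / sigma32
z64 = (x64 - mu64) / sigma64

max_diff = float(np.max(np.abs(z32.astype(np.float64) - z64)))
print(f"mu32={mu32!r}, mu64={mu64!r}, max. z-score difference={max_diff:.2e}")
```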

In the following, we perform some further tests:

In [126]:
js_file = "./test.json"

# save normalization parameters to file
zscore_norm.save_norm_to_file(js_file)

# instantiate fresh instance and get normalization parameters from file
zscore_norm_new = ZScore(["rlat", "rlon", "time"])
zscore_norm_new.read_norm_from_file(js_file)

# apply normalization without retrieval from data (see print-statement!)
da_norm = zscore_norm_new.normalize(da)

Mu and sigma are parsed for (de-)normalization.
<xarray.DataArray (variables: 12)>
array([ 1.00000000e+00,  2.82832158e+02, -7.07827281e+04, -1.56503764e+05,
        5.29617935e+02,  1.82885170e-01,  1.44832186e-01,  5.68644335e+03,
        2.78623048e+02,  2.82670052e+02,  5.71572380e+02,  2.83245770e+02])
Coordinates:
  * variables  (variables) <U12 'rotated_pole' '2t_in' ... 't_2m_tar'
<xarray.DataArray (variables: 12)>
array([0.00000000e+00, 9.05282455e+00, 2.24066973e+05, 2.36301865e+05,
       5.11373825e+02, 2.43462441e+00, 1.87755217e+00, 4.48457913e+03,
       7.45010479e+00, 8.21746109e+00, 4.97467674e+02, 9.03267566e+00])
Coordinates:
  * variables  (variables) <U12 'rotated_pole' '2t_in' ... 't_2m_tar'


Finally, we check whether we still obtain the expected result:

In [128]:
assert np.all(np.isclose(da_norm.sel({"variables": "2t_in"}), t2m_norm1, atol=1.e-06)), "da_norm and t2m_norm1 differ!"
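
The round trip through JSON relies on `DataArray.to_dict` and `DataArray.from_dict`. Below is a self-contained sketch with toy statistics; the values, the variable labels and the temporary file name are placeholders, not the Tier-2 parameters.

```python
import json
import os
import tempfile

import numpy as np
import xarray as xr

# toy normalization parameters with one value per variable (hypothetical)
coords = {"variables": ["a", "b"]}
stats = {
    "mu": xr.DataArray([282.8, 0.18], dims="variables", coords=coords),
    "sigma": xr.DataArray([9.05, 2.43], dims="variables", coords=coords),
}

with tempfile.TemporaryDirectory() as tmpdir:
    js_file = os.path.join(tmpdir, "norm_toy.json")

    # serialize: DataArray -> nested dict of plain Python types -> JSON
    with open(js_file, "w") as jsf:
        json.dump({key: da.to_dict() for key, da in stats.items()}, jsf)

    # deserialize: JSON -> dict -> DataArray
    with open(js_file, "r") as jsf:
        restored = {key: xr.DataArray.from_dict(d) for key, d in json.load(jsf).items()}

print(restored["mu"])
```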

## 2023-04-04: Further tests 

### Problem statement
When the normalization parameters are retrieved from an xarray `Dataset` and then applied to an xarray `DataArray` (or vice versa), an unwanted conversion of the data object or even a failure can be observed. The code below therefore investigates how to ensure proper 'cross-application' of the (de-)normalization procedure. <br><br>
The current base class as obtained from branch #067 (commit ???):

In [14]:
import json as js
import dask

class Normalize(ABC):
    """
    Abstract class for normalizing data.
    """

    def __init__(self, method: str, norm_dims: List):
        self.method = method
        self.norm_dims = norm_dims
        self.norm_stats = None

    def normalize(self, data: xr.DataArray, **stats):
        """
        Normalize data.
        :param data: The DataArray to be normalized.
        :param **stats: Parameters to perform normalization. Must fit to normalization type!
        :return: DataArray with normalized data.
        """
        # sanity checks
        # if not isinstance(data, xr.DataArray):
        #    raise TypeError(f"Passed data must be a xarray.DataArray, but is of type {str(type(data))}.")

        _ = self._check_norm_dims(data)
        # do the computation
        norm_stats = self.get_required_stats(data, **stats)
        data_norm = self.normalize_data(data, *norm_stats)

        return data_norm

    def denormalize(self, data: da_or_ds, **stats):
        """
        Denormalize data.
        :param data: The DataArray to be denormalized.
        :param **stats: Parameters to perform denormalization. Must fit to normalization type!
        :return: DataArray with denormalized data.
        """
        # sanity checks
        # if not isinstance(data, xr.DataArray):
        #    raise TypeError(f"Passed data must be a xarray.DataArray, but is of type {str(type(data))}.")

        _ = self._check_norm_dims(data)
        # do the computation
        norm_stats = self.get_required_stats(data, **stats)
        data_denorm = self.denormalize_data(data, *norm_stats)

        return data_denorm

    @property
    def norm_dims(self):
        return self._norm_dims

    @norm_dims.setter
    def norm_dims(self, norm_dims):
        if norm_dims is None:
            raise AttributeError("norm_dims must not be None. Please pass a list of dimensions " +
                                 "over which normalization should be applied.")

        self._norm_dims = list(norm_dims)

    def _check_norm_dims(self, data):
        """
        Check if dimension for normalization reside in dimensions of data.
        :param data: the data (xr.DataArray) to be normalized
        :return True: if the check is passed; a ValueError is raised otherwise
        """
        data_dims = list(data.dims)
        norm_dims_check = [norm_dim in data_dims for norm_dim in self.norm_dims]
        if not all(norm_dims_check):
            imiss = np.where(~np.array(norm_dims_check))[0]
            miss_dims = list(np.array(self.norm_dims)[imiss])
            raise ValueError("The following dimensions do not reside in the data: " +
                             f"{', '.join(miss_dims)}")

        return True

    def save_norm_to_file(self, js_file, missdir_ok: bool = True):
        """
        Write normalization parameters to file.
        :param js_file: Path to JSON-file to be created.
        :param missdir_ok: If True, base-directory of JSON-file can be missing and will be created then.
        :return: -
        """
        if self.norm_stats is None:
            raise AttributeError("norm_stats is still None. Please run (de-)normalization to get parameters.")

        if any([stat is None for stat in self.norm_stats.values()]):
            raise AttributeError("Some parameters of norm_stats are None.")

        norm_serialized = {key: da.to_dict() for key, da in self.norm_stats.items()}

        # serialization and (later) deserialization depends on data type.
        # Thus, we have to save it to the dictionary
        d0 = list(self.norm_stats.values())[0]
        if isinstance(d0, xr.DataArray):
            norm_serialized["data_type"] = "data_array"
        elif isinstance(d0, xr.Dataset):
            norm_serialized["data_type"] = "data_set"

        if missdir_ok:
            base_dir = os.path.dirname(js_file)
            if base_dir:  # os.makedirs("") would raise for a bare file name
                os.makedirs(base_dir, exist_ok=True)

        with open(js_file, "w") as jsf:
            js.dump(norm_serialized, jsf)

    def read_norm_from_file(self, js_file):
        """
        Read normalization parameters from file. Inverse function to save_norm_to_file.
        :param js_file: Path to JSON-file to be read.
        :return: Parameters set to self.norm_stats
        """
        with open(js_file, "r") as jsf:
            norm_data = js.load(jsf)

        data_type = norm_data.pop('data_type', None)

        if data_type == "data_array":
            xr_obj = xr.DataArray
        elif data_type == "data_set":
            xr_obj = xr.Dataset
        else:
            raise ValueError(
                f"Unknown data_type {data_type} in {js_file}. Only 'data_array' or 'data_set' are allowed.")

        norm_data.pop('data_type', None)

        norm_dict_restored = {key: xr_obj.from_dict(da_dict) for key, da_dict in norm_data.items()}

        self.norm_stats = norm_dict_restored

    @abstractmethod
    def get_required_stats(self, data, **stats):
        """
        Function to retrieve either normalization parameters from data or from keyword arguments
        """
        pass

    @staticmethod
    @abstractmethod
    def normalize_data(data, *norm_param):
        """
        Function to normalize data.
        """
        pass

    @staticmethod
    @abstractmethod
    def denormalize_data(data, *norm_param):
        """
        Function to denormalize data.
        """
        pass


The child class for zscore-normalization:

In [15]:
class ZScore(Normalize):
    def __init__(self, norm_dims: List):
        super().__init__("z_score", norm_dims)
        self.norm_stats = {"mu": None, "sigma": None}

    def get_required_stats(self, data: da_or_ds, **stats):
        """
        Get required parameters for z-score normalization. They are either computed from the data
        or can be parsed as keyword arguments.
        :param data: the data to be (de-)normalized
        :param stats: keyword arguments for mean (mu) and standard deviation (sigma) used for normalization
        :return (mu, sigma): Parameters for normalization
        """
        mu, std = stats.get("mu", self.norm_stats["mu"]), stats.get("sigma", self.norm_stats["sigma"])

        if mu is None or std is None:
            print("Retrieve mu and sigma from data...")
            mu, std = data.mean(self.norm_dims), data.std(self.norm_dims)
            # the following ensures that both parameters are computed in one graph!
            # This significantly reduces the memory footprint as we don't end up having data duplicates
            # in memory due to multiple graphs (and also seems to enforce usage of data chunks as well)
            mu, std = dask.compute(mu, std)
            self.norm_stats = {"mu": mu, "sigma": std}
        # else:
        #    print("Mu and sigma are parsed for (de-)normalization.")

        return mu, std

    @staticmethod
    def normalize_data(data, mu, std):
        """
        Perform z-score normalization on data
        :param data: Data array of interest
        :param mu: mean of data for normalization
        :param std: standard deviation of data for normalization
        :return data_norm: normalized data
        """
        data = (data - mu) / std

        return data

    @staticmethod
    def denormalize_data(data, mu, std):
        """
        Perform z-score denormalization on data.
        :param data: Data array of interest
        :param mu: mean of data for denormalization
        :param std: standard deviation of data for denormalization
        :return data_norm: denormalized data
        """
        data = data * std + mu

        return data

In [10]:
def reshape_ds(ds):
    """
    Convert a xarray dataset to a data-array where the variables will constitute the last dimension (channel last)
    :param ds: the xarray dataset with dimensions (dims)
    :return da: the data-array with dimensions (dims, variables)
    """
    da = ds.to_array(dim="variables")
    print(da)
    da = da.transpose(..., "variables")
    return da
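
A quick way to confirm the channel-last layout is to apply the same two steps to a toy dataset. The sketch below repeats the logic of `reshape_ds` without the `print` call; the shapes and values are made up.

```python
import numpy as np
import xarray as xr

dims = ("time", "lat", "lon")
ds_toy = xr.Dataset(
    {
        "t2m_in": (dims, np.zeros((3, 4, 5), dtype=np.float32)),
        "z_in": (dims, np.ones((3, 4, 5), dtype=np.float32)),
    }
)

da_first = ds_toy.to_array(dim="variables")        # variables come first
da_last = da_first.transpose(..., "variables")     # move variables to the end

print(da_first.dims, "->", da_last.dims)
```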


Next, we set the path to some test data (namely the augmented Tier-1 dataset)...

In [114]:
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier1/netcdf_data/workdir/"

fname_train, fname_val = os.path.join(datadir, "downscaling_tier1_train_aug.nc"), os.path.join(datadir, "downscaling_tier1_val_aug.nc")

... and load the training data while also calculating the normalization parameters with the `xr.DataArray` after applying `reshape_ds`.

In [115]:
data_norm = ZScore(["time", "lat", "lon"])

print(f"Start loading the training data from file '{fname_train}'...")
ds_train = xr.open_dataset(fname_train)
da_train = reshape_ds(ds_train.astype("float32", copy=False))
da_train = data_norm.normalize(da_train)

Start loading the training data from file '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier1/netcdf_data/workdir/downscaling_tier1_train_aug.nc'...
<xarray.DataArray (variables: 4, time: 20496, lat: 96, lon: 128)>
array([[[[ 2.7890393e+02,  2.7890341e+02,  2.7890289e+02, ...,
           2.7840961e+02,  2.7842529e+02,  2.7844098e+02],
         [ 2.7893362e+02,  2.7893417e+02,  2.7893475e+02, ...,
           2.7864780e+02,  2.7866364e+02,  2.7867950e+02],
         [ 2.7896332e+02,  2.7896497e+02,  2.7896661e+02, ...,
           2.7888596e+02,  2.7890198e+02,  2.7891803e+02],
         ...,
         [ 2.7672256e+02,  2.7678378e+02,  2.7684503e+02, ...,
           2.9391086e+02,  2.9382962e+02,  2.9374838e+02],
         [ 2.7706863e+02,  2.7717160e+02,  2.7727454e+02, ...,
           2.9351395e+02,  2.9345609e+02,  2.9339825e+02],
         [ 2.7741473e+02,  2.7755939e+02,  2.7770404e+02, ...,
           2.9311707e+02,  2.9308258e+02,  2.9304810e+02]],

        [[ 2.7

Next, we load the validation data, but keep it as an `xr.Dataset`. Note that we have retrieved the normalization parameters from an `xr.DataArray`-object.

In [119]:
ds_val = xr.open_dataset(fname_val)
print(ds_val)

<xarray.Dataset>
Dimensions:  (time: 2604, lon: 128, lat: 96)
Coordinates:
  * time     (time) datetime64[ns] 2020-05-01T10:00:00 ... 2020-08-31T16:03:00
  * lon      (lon) float64 4.0 4.1 4.2 4.3 4.4 4.5 ... 16.3 16.4 16.5 16.6 16.7
  * lat      (lat) float64 54.5 54.4 54.3 54.2 54.1 ... 45.4 45.3 45.2 45.1 45.0
Data variables:
    t2m_in   (time, lat, lon) float64 ...
    z_in     (time, lat, lon) float64 ...
    z_tar    (time, lat, lon) float64 ...
    t2m_tar  (time, lat, lon) float64 ...
Attributes:
    CDI:                        Climate Data Interface version 2.0.2 (https:/...
    Conventions:                CF-1.6
    history:                    Sun Mar 06 21:57:45 2022: cdo mergetime 2016/...
    NCO:                        netCDF Operators version 4.9.5 (Homepage = ht...
    history_of_appended_files:  Sun Mar  6 21:09:55 2022: Appended file /p/sc...
    CDO:                        Climate Data Operators version 2.0.2 (https:/...


To allow flexible coercing between different `xarray` data types, we introduce the following auxiliary function:

In [111]:
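def match_datatype(data, *args, var_dim="variables"):
    """
    Coerce xarray objects to the type (xr.Dataset or xr.DataArray) of a reference object.
    :param data: reference xarray object
    :param args: xarray objects (e.g. normalization parameters) to be coerced; at least one is expected
    :param var_dim: name of the dimension that holds the variables in the DataArray representation
    :return: tuple of the coerced objects
    """
    # sanity check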
    ds_or_da = (xr.Dataset, xr.DataArray)
    all_args = [data] + list(args)
    if not all(isinstance(arg, ds_or_da) for arg in all_args):
        flags = [not isinstance(arg, ds_or_da) for arg in all_args]
        inds = np.nonzero(flags)[0].tolist()
        if len(inds) == 1:
            err_str = f"The parsed argument at position {inds[0]} is"
        else:
            err_str = f"The parsed arguments at positions {inds} are"
        raise ValueError(f"{err_str} not an xarray.DataArray or xarray.Dataset.")
    
    # align type of arguments if required
    if isinstance(data, type(args[0])):
        args_new = args
    elif isinstance(data, xr.Dataset) and isinstance(args[0], xr.DataArray):
        args_new = tuple(arg.to_dataset(dim=var_dim) for arg in args)
    elif isinstance(data, xr.DataArray) and isinstance(args[0], xr.Dataset):
        args_new = tuple(arg.to_array(dim=var_dim) for arg in args)
    else:
        raise ValueError("Unknown error occurred. Please check all input parameters.")
        
    return args_new
                         

Finally, we check if this works for the given test case by actually running the normalization manually. 

In [117]:
mu, std = match_datatype(ds_val, data_norm.norm_stats["mu"], data_norm.norm_stats["sigma"])
#mu, std = match_datatype(ds_val, np.arange(4), np.arange(4))
data = (ds_val - mu) / std
print(data)

<xarray.Dataset>
Dimensions:  (time: 2604, lon: 128, lat: 96)
Coordinates:
  * time     (time) datetime64[ns] 2020-05-01T10:00:00 ... 2020-08-31T16:03:00
  * lon      (lon) float64 4.0 4.1 4.2 4.3 4.4 4.5 ... 16.3 16.4 16.5 16.6 16.7
  * lat      (lat) float64 54.5 54.4 54.3 54.2 54.1 ... 45.4 45.3 45.2 45.1 45.0
Data variables:
    t2m_in   (time, lat, lon) float64 -1.783 -1.785 -1.787 ... -0.5563 -0.5673
    z_in     (time, lat, lon) float64 -0.9632 -0.9632 ... -0.9632 -0.9632
    z_tar    (time, lat, lon) float64 -0.8164 -0.8165 ... -0.8165 -0.8164
    t2m_tar  (time, lat, lon) float64 -1.708 -1.713 -1.715 ... -0.4749 -0.4918
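
The same cross-application can be verified on a toy example where the expected result is known. The sketch below converts per-variable statistics stored in a `DataArray` into a `Dataset` via `to_dataset(dim=...)`, which is the conversion used in `match_datatype`; the shapes and values are hypothetical.

```python
import numpy as np
import xarray as xr

dims = ("time", "lat", "lon")
rng = np.random.default_rng(3)
ds_toy = xr.Dataset(
    {
        "t2m_in": (dims, rng.normal(283.0, 9.0, size=(6, 3, 4))),
        "z_in": (dims, rng.normal(500.0, 400.0, size=(6, 3, 4))),
    }
)

# statistics derived from the DataArray representation (dimension: variables)
da_toy = ds_toy.to_array(dim="variables")
mu_da = da_toy.mean(dim=list(dims))
std_da = da_toy.std(dim=list(dims))

# coerce the statistics to Datasets so that they match the type of ds_toy
mu_ds = mu_da.to_dataset(dim="variables")
std_ds = std_da.to_dataset(dim="variables")

ds_norm = (ds_toy - mu_ds) / std_ds      # stays an xr.Dataset
print(ds_norm)
```

After normalization, each variable of the toy dataset should have a mean close to zero and a standard deviation close to one, which makes the result easy to check.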
