# Metadata Decorator

*PROBLEM*
Some functions strip the xarray metadata and return plain numpy arrays, but we still want that metadata for further computation.

*GOAL*
Construct a decorator that replaces a plain numpy output with a DataArray by copying matching metadata from the inputs.

*STATUS*
The example below creates a decorator that checks whether the shape of the numpy output matches the shape of any input DataArray and, if so, copies that input's coordinates. This works, but only when the output has exactly the same shape as one of the inputs; any reduction or reshaping loses the metadata. _TODO:_ Construct a dict that matches each output dimension's length to a dimension of an input, so that reasonable changes in shape can still be labeled on output.
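The TODO above could be approached roughly as follows. This is a minimal sketch, not the notebook's method: the decorator name `wrap_by_dim` and the helper `zonal_mean` are hypothetical, and it assumes that when two input dimensions share a length, the first one seen wins. Only coordinate values are copied (not their attributes), and any output axis with no matching input length gets a generic `dim_i` name.

```python
import functools

import numpy as np
import xarray as xr


def wrap_by_dim(func):
    """Hypothetical decorator: label each axis of a numpy result by matching
    its length against the dimension lengths of the DataArray inputs."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if not isinstance(result, np.ndarray):
            return result  # scalars or already-labeled results pass through
        # dimension length -> (dimension name, coordinate values or None);
        # assumption: the first input dimension seen with a given length wins
        size_to_dim = {}
        for a in (*args, *kwargs.values()):
            if isinstance(a, xr.DataArray):
                for name, size in zip(a.dims, a.shape):
                    if size not in size_to_dim:
                        values = a[name].values if name in a.coords else None
                        size_to_dim[size] = (name, values)
        dims, coords = [], {}
        for i, size in enumerate(result.shape):
            match = size_to_dim.get(size)
            if match is None or match[0] in dims:
                # no input dimension has this length (or its name is already
                # used by an earlier axis): fall back to a generic name
                dims.append(f"dim_{i}")
            else:
                name, values = match
                dims.append(name)
                if values is not None:
                    coords[name] = values
        return xr.DataArray(result, dims=dims, coords=coords)
    return wrapper


# usage: a zonal mean drops "lon" but keeps the "time" and "lat" labels
@wrap_by_dim
def zonal_mean(data):
    return np.asarray(data).mean(axis=-1)


# small synthetic input (made-up values, not CERES data)
x_demo = xr.DataArray(
    np.ones((2, 3, 4)),
    dims=("time", "lat", "lon"),
    coords={"time": [0, 1], "lat": [-1.0, 0.0, 1.0], "lon": [0.0, 90.0, 180.0, 270.0]},
)
print(zonal_mean(x_demo).dims)  # ('time', 'lat')
```

Matching on length alone is ambiguous for grids where two dimensions have the same size (for example a square lat/lon grid), which is why the sketch refuses to reuse a dimension name and falls back to `dim_i` instead.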

In [1]:
import xarray as xr
import numpy as np

In [22]:
ds = xr.open_dataset("/Users/brianpm/Dropbox/Data/CERES/CERES_EBAF-TOA_Edition4.0_200003-201810.nc")
# print(ds.data_vars)

In [8]:
x = ds['toa_sw_all_mon']
print(x)
sol = ds['solar_mon']
print(sol)

<xarray.DataArray 'toa_sw_all_mon' (time: 224, lat: 180, lon: 360)>
[14515200 values with dtype=float32]
Coordinates:
  * lon      (lon) float32 0.5 1.5 2.5 3.5 4.5 ... 355.5 356.5 357.5 358.5 359.5
  * lat      (lat) float32 -89.5 -88.5 -87.5 -86.5 -85.5 ... 86.5 87.5 88.5 89.5
  * time     (time) datetime64[ns] 2000-03-15 2000-04-15 ... 2018-10-15
Attributes:
    long_name:      Top of The Atmosphere Shortwave Flux, Monthly Means, All-...
    standard_name:  TOA Shortwave Flux - All-Sky
    CF_name:        toa_outgoing_shortwave_flux
    units:          W m-2
    valid_min:            0.00000
    valid_max:            600.000
<xarray.DataArray 'solar_mon' (time: 224, lat: 180, lon: 360)>
[14515200 values with dtype=float32]
Coordinates:
  * lon      (lon) float32 0.5 1.5 2.5 3.5 4.5 ... 355.5 356.5 357.5 358.5 359.5
  * lat      (lat) float32 -89.5 -88.5 -87.5 -86.5 -85.5 ... 86.5 87.5 88.5 89.5
  * time     (time) datetime64[ns] 2000-03-15 2000-04-15 ... 2018-10-15
Attributes:
    l

In [10]:
# simplest approach: mask regions in polar night (no sunlight for the whole month)
lit = np.where(sol > 0.0, x, np.nan)
print(lit)  # plain numpy array: dims, coordinates, and attributes are gone
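For this particular mask, xarray's own `DataArray.where` already keeps the dimensions and coordinates, so the decorator matters mainly for numpy functions that have no xarray-aware equivalent. A small check on synthetic stand-ins for `x` and `sol` (the values are made up, not CERES data):

```python
import numpy as np
import xarray as xr

# small synthetic stand-ins for the CERES fields (values are made up)
coords = {"lat": [-60.0, 0.0, 60.0], "lon": [0.0, 180.0]}
x = xr.DataArray(np.arange(6.0).reshape(3, 2), dims=("lat", "lon"), coords=coords)
sol = xr.DataArray([[0.0, 0.0], [400.0, 410.0], [300.0, 0.0]],
                   dims=("lat", "lon"), coords=coords)

# np.where hands back a plain ndarray: dims and coordinates are gone
lit_np = np.where(sol > 0.0, x, np.nan)
print(type(lit_np))

# DataArray.where keeps dims and coordinates; points where the condition
# is False become NaN
lit_xr = x.where(sol > 0.0)
print(lit_xr.dims)
```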

In [19]:
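# decorator that re-attaches xarray metadata to a function's numpy output
import functools

def wrap(func):
    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        # look through positional and keyword arguments for a DataArray whose
        # shape matches the output; pass dims explicitly so they are not
        # inferred from the order of the coordinates
        for a in (*args, *kwargs.values()):
            if isinstance(a, xr.DataArray) and a.shape == np.shape(result):
                return xr.DataArray(result, coords=a.coords, dims=a.dims)
        print("Output shape does not match any input shape. No metadata copied.")
        return result
    return wrapper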


In [20]:
@wrap
def masker(data, light):
    return np.where(light > 0.0, data, np.nan)
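xarray also ships a general wrapper for this pattern, `xr.apply_ufunc`, which calls a numpy function on the underlying arrays and labels the result with the inputs' broadcast dimensions and coordinates. A minimal sketch on synthetic inputs; `masker_np` is a hypothetical undecorated copy of `masker`, and the data values are made up:

```python
import numpy as np
import xarray as xr

# synthetic inputs standing in for x and sol (values are made up)
coords = {"time": [0, 1], "lat": [-45.0, 45.0]}
data = xr.DataArray([[1.0, 2.0], [3.0, 4.0]], dims=("time", "lat"), coords=coords)
light = xr.DataArray([[0.0, 5.0], [6.0, 0.0]], dims=("time", "lat"), coords=coords)


def masker_np(d, l):
    # same numpy-only logic as masker, without the decorator
    return np.where(l > 0.0, d, np.nan)


# apply_ufunc aligns and broadcasts the inputs, passes plain arrays to
# masker_np, and wraps the result back into a DataArray with dims and coords
masked = xr.apply_ufunc(masker_np, data, light)
print(masked)
```

Unlike the shape-matching decorator, `apply_ufunc` relies on xarray's alignment and broadcasting rules; functions that add or remove dimensions need to declare them with `input_core_dims` and `output_core_dims`.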

In [21]:
masker(x, sol)

<xarray.DataArray (time: 224, lat: 180, lon: 360)>
array([[[ 43.85 ,  43.85 , ...,  43.85 ,  43.85 ],
        [ 43.66 ,  43.66 , ...,  43.66 ,  43.66 ],
        ...,
        [ 16.94 ,  16.94 , ...,  16.94 ,  16.94 ],
        [ 17.78 ,  17.78 , ...,  17.78 ,  17.78 ]],

       [[    nan,     nan, ...,     nan,     nan],
        [    nan,     nan, ...,     nan,     nan],
        ...,
        [148.8  , 148.8  , ..., 148.8  , 148.8  ],
        [150.1  , 150.1  , ..., 150.1  , 150.1  ]],

       ...,

       [[ 11.1  ,  11.1  , ...,  11.1  ,  11.1  ],
        [  9.739,   9.739, ...,   9.739,   9.739],
        ...,
        [ 50.21 ,  50.21 , ...,  50.21 ,  50.21 ],
        [ 49.54 ,  49.54 , ...,  49.54 ,  49.54 ]],

       [[156.5  , 156.5  , ..., 156.5  , 156.5  ],
        [149.4  , 149.4  , ..., 149.4  , 149.4  ],
        ...,
        [    nan,     nan, ...,     nan,     nan],
        [    nan,     nan, ...,     nan,     nan]]], dtype=float32)
Coordinates:
  * lon      (lon) float32 0.5 1