In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
import numpy as np
import xarray as xr
import util

In [3]:
# Time-axis convention for the synthetic data and for util.ann_mean:
# True  -> dset() stamps each month at the midpoint of its bounds and
#          ann_mean is called with time_centered=True;
# False -> each time stamp is the end of its monthly interval.
center_time: bool = False

In [4]:
def dset(center_time=False):
    """Generate a small synthetic 3-year monthly dataset for testing ``util.ann_mean``.

    The time axis holds 36 monthly intervals on a 365-day ("noleap") calendar,
    stored as day offsets since 0001-01-01 and CF-decoded before returning.

    Parameters
    ----------
    center_time : bool, default False
        If True, each ``time`` value is the midpoint of its ``time_bound``
        interval. Otherwise it is the interval's end, i.e. the day offset of
        the start of the following month.

    Returns
    -------
    xarray.Dataset
        Dimensions ``time`` (36), ``lat`` (2), ``lon`` (2), ``d2`` (2) and data
        variables:

        - ``time_bound`` (time, d2): [start, end] of each monthly interval.
        - ``variable_1`` (time, lat, lon), float32: 0 in year 1, 1 in year 2,
          0 in year 3.
        - ``variable_2`` (time, lat, lon), float32: 1 in year 1, 0 in year 2,
          1 in year 3.
        - ``non_time_variable_1`` (lat, lon): all ones, no time dimension.
    """
    n_years = 3
    days_per_year = 365  # 'noleap' calendar: every year has exactly 365 days

    # Day offset (within a 365-day year) of the first day of each month (bom)
    # and of the first day of the following month (eom).
    bom = np.array([0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334], dtype=np.float64)
    eom = np.array([31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365], dtype=np.float64)
    start_date = np.concatenate([bom + days_per_year * year for year in range(n_years)])
    end_date = np.concatenate([eom + days_per_year * year for year in range(n_years)])

    time = (start_date + end_date) / 2 if center_time else end_date

    # Build the dimension coordinates with their actual values up front, so the
    # dataset is never in a state where 'time' disagrees with the 36-step data.
    ds = xr.Dataset(
        coords={
            'time': ('time', time),
            'lat': ('lat', [0, 1]),
            'lon': ('lon', [0, 1]),
            'd2': ('d2', [0, 1]),
        }
    )
    ds['time_bound'] = xr.DataArray(
        np.column_stack((start_date, end_date)), dims=['time', 'd2']
    )

    # One constant 12-month block per year; np.concatenate copies, so the
    # blocks can be reused safely.
    zeros = np.zeros((12, 2, 2), dtype='float32')
    ones = np.ones((12, 2, 2), dtype='float32')

    ds['variable_1'] = xr.DataArray(
        np.concatenate((zeros, ones, zeros), axis=0), dims=['time', 'lat', 'lon']
    )
    ds.variable_1.attrs['description'] = 'Y1: zeros, Y2: ones, Y3: zeroes'

    ds['variable_2'] = xr.DataArray(
        np.concatenate((ones, zeros, ones), axis=0), dims=['time', 'lat', 'lon']
    )
    ds.variable_2.attrs['description'] = 'Y1: ones, Y2: zeros, Y3: ones'

    ds['non_time_variable_1'] = xr.DataArray(np.ones((2, 2)), dims=['lat', 'lon'])

    ds.time.attrs['units'] = 'days since 0001-01-01 00:00:00'
    ds.time.attrs['calendar'] = 'noleap'
    # CF 'bounds' attribute naming time_bound as the bounds variable of time.
    ds.time.attrs['bounds'] = 'time_bound'

    return xr.decode_cf(ds.copy(True))

# Build the synthetic dataset once (time convention set by center_time);
# every check below reads from it.
ds_test = dset(center_time)
ds_test

In [5]:
ds_test_ann = util.ann_mean(
    ds_test, time_bnds_varname='time_bound', time_centered=center_time
)

# Each test variable is constant within a year (see the descriptions set in
# dset), so each annual mean must equal that year's constant.
# (variable name, year index) -> expected annual mean
expected_ann_means = {
    ('variable_1', 0): 0.,
    ('variable_1', 1): 1.,
    ('variable_2', 0): 1.,
    ('variable_2', 1): 0.,
}
for (var_name, year_index), expected in expected_ann_means.items():
    np.testing.assert_almost_equal(
        ds_test_ann[var_name].isel(time=year_index).values, expected
    )

ds_test_ann

In [6]:
ds_test_djf = util.ann_mean(
    ds_test, season='DJF', time_bnds_varname='time_bound', time_centered=center_time
)

# Month lengths used as weights for the DJF day-weighted mean.
djf_days = [31, 28, 31]

# (variable name, season index, per-month values) for each check.
# The expected values imply that each DJF season pairs a December with the
# following January and February. For example, season 0 is Dec of year 1
# plus Jan/Feb of year 2, so only the 28-day (Feb) weight is pinned to a
# specific month; the two 31-day entries are interchangeable.
djf_cases = [
    ('variable_1', 0, [0., 1., 1.]),
    ('variable_1', 1, [1., 0., 0.]),
    ('variable_2', 0, [1., 0., 0.]),
    ('variable_2', 1, [0., 1., 1.]),
]
for var_name, season_index, month_values in djf_cases:
    np.testing.assert_almost_equal(
        ds_test_djf[var_name].isel(time=season_index).values,
        np.average(month_values, weights=djf_days),
    )

ds_test_djf

In [7]:
# MAM, JJA and SON each lie entirely within one calendar year, so each
# seasonal mean equals that year's constant value. variable_1 is 0 in year 1
# and 1 in year 2; variable_2 is the reverse (see the descriptions set in dset).
for season in ['MAM', 'JJA', 'SON']:
    # Pass the loop's season; previously 'MAM' was hard-coded, so JJA and SON
    # were never actually exercised.
    ds_test_seas = util.ann_mean(
        ds_test, season=season, time_bnds_varname='time_bound', time_centered=center_time
    )

    np.testing.assert_almost_equal(ds_test_seas.variable_1.isel(time=0).values, 0., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_1.isel(time=1).values, 1., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_2.isel(time=0).values, 1., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_2.isel(time=1).values, 0., err_msg=season)

# Shows the result for the last season in the loop (SON).
ds_test_seas