In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
%matplotlib inline
import numpy as np
import xarray as xr
import util

In [8]:
# Notebook-wide test settings (read by the dataset builder and ann_mean calls below).
center_time = True   # True: synthetic time coordinate sits at mid-month; False: at month end
calendar = 'noleap'  # CF calendar attribute written onto the synthetic time axis

In [9]:
def dset(center_time=False, calendar='noleap'):
    """Build a small synthetic 3-year monthly dataset for testing ``util.ann_mean``.

    The dataset has 36 monthly time steps on a 2x2 (lat, lon) grid. Month
    bounds are expressed in days since 0001-01-01 assuming 365-day years,
    so the result is only self-consistent for a ``noleap``/``365_day`` calendar.

    Parameters
    ----------
    center_time : bool
        If True, the time coordinate is the midpoint of each month's bounds;
        otherwise it is the end of the month.
    calendar : str
        Value written to ``time.attrs['calendar']`` before CF decoding.

    Returns
    -------
    xarray.Dataset
        The dataset after ``xr.decode_cf``. variable_3 and variable_4 carry a
        ``_FillValue`` of -1e36, so their "missing" years are masked on decode.
    """
    n_years = 3
    days_per_year = 365  # fixed-length years: the offsets below ignore leap days

    # Day-of-year at the beginning / end of each month in a 365-day year.
    bom = np.array([0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334], dtype=np.float64)
    eom = np.array([31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365], dtype=np.float64)

    # Repeat the month bounds for every year, shifted by whole years.
    year_offset = np.repeat(np.arange(n_years) * days_per_year, bom.size)
    start_date = np.tile(bom, n_years) + year_offset
    end_date = np.tile(eom, n_years) + year_offset

    def yearly_field(*year_values):
        """Return a (12 * len(year_values), 2, 2) float32 array; year i is filled with year_values[i]."""
        return np.concatenate(
            [np.full([12, 2, 2], value, dtype='float32') for value in year_values], axis=0
        )

    # Start empty: every dimension is defined by the coordinate assignments below.
    ds = xr.Dataset()
    if center_time:
        ds['time'] = xr.DataArray(0.5 * (start_date + end_date), dims='time')
    else:
        ds['time'] = xr.DataArray(end_date, dims='time')

    ds['lat'] = xr.DataArray([0, 1], dims='lat')
    ds['lon'] = xr.DataArray([0, 1], dims='lon')
    ds['d2'] = xr.DataArray([0, 1], dims='d2')
    # One (start, end) pair per time step.
    ds['time_bound'] = xr.DataArray(np.stack([start_date, end_date], axis=1), dims=['time', 'd2'])

    ds['variable_1'] = xr.DataArray(yearly_field(0., 1., 0.), dims=['time', 'lat', 'lon'])
    ds.variable_1.attrs['description'] = 'Y1: zeros, Y2: ones, Y3: zeroes'

    ds['variable_2'] = xr.DataArray(yearly_field(1., 0., 1.), dims=['time', 'lat', 'lon'])
    ds.variable_2.attrs['description'] = 'Y1: ones, Y2: zeros, Y3: ones'

    # -1e36 doubles as the _FillValue, so these years decode as missing.
    ds['variable_3'] = xr.DataArray(yearly_field(-1e36, 0.5, 20.), dims=['time', 'lat', 'lon'])
    ds.variable_3.attrs['description'] = 'Y1: missing, Y2: 0.5, Y3: 20'
    ds.variable_3.attrs['_FillValue'] = -1e36

    ds['variable_4'] = xr.DataArray(yearly_field(-1e36, -1e36, 20.), dims=['time', 'lat', 'lon'])
    ds.variable_4.attrs['description'] = 'Y1: missing, Y2: missing, Y3: 20'
    ds.variable_4.attrs['_FillValue'] = -1e36

    ds['non_time_variable_1'] = xr.DataArray(np.ones((2, 2)), dims=['lat', 'lon'])

    ds.time.attrs['units'] = 'days since 0001-01-01 00:00:00'
    ds.time.attrs['calendar'] = calendar
    ds.time.attrs['bounds'] = 'time_bound'

    return xr.decode_cf(ds.copy(deep=True))

ds_test = dset(center_time=center_time, calendar=calendar)
ds_test

In [10]:
ds_test_ann = util.ann_mean(
    ds_test, time_bnds_varname='time_bound', time_centered=center_time, n_req=12,
)

# Expected annual mean for each of the three years; None means "all NaN".
expected_annual = {
    'variable_1': [0., 1., 0.],
    'variable_2': [1., 0., 1.],
    'variable_3': [None, 0.5, 20.],
    'variable_4': [None, None, 20.],
}
for var_name, per_year in expected_annual.items():
    for year_idx, expected in enumerate(per_year):
        actual = ds_test_ann[var_name].isel(time=year_idx).values
        if expected is None:
            assert np.isnan(actual).all()
        else:
            np.testing.assert_almost_equal(actual, expected)


ds_test_ann

In [11]:
ds_test_djf = util.ann_mean(
    ds_test, season='DJF', time_bnds_varname='time_bound', time_centered=center_time, n_req=3
)

# DJF straddles the year boundary: Dec of year N with Jan/Feb of year N+1.
# With n_req=3 only the two 3-month seasons (D1-JF2, D2-JF3) are expected.
djf_days = np.array([31, 28, 31])  # noleap lengths of Dec, Jan, Feb (mean weights)

assert len(ds_test_djf.time) == 2

np.testing.assert_almost_equal(ds_test_djf.variable_1.isel(time=0).values,
                               np.average([0., 1., 1.], weights=djf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_1.isel(time=1).values,
                               np.average([1., 0., 0.], weights=djf_days))

np.testing.assert_almost_equal(ds_test_djf.variable_2.isel(time=0).values,
                               np.average([1., 0., 0.], weights=djf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_2.isel(time=1).values,
                               np.average([0., 1., 1.], weights=djf_days))

# variable_3: Dec of year 1 is missing -> presumably fewer than n_req valid months -> NaN.
assert np.isnan(ds_test_djf.variable_3.isel(time=0).values).all()
np.testing.assert_almost_equal(ds_test_djf.variable_3.isel(time=1).values,
                               np.average([0.5, 20., 20.], weights=djf_days))

# variable_4: first season fully missing; second season has a missing Dec -> both NaN.
assert np.isnan(ds_test_djf.variable_4.isel(time=0).values).all()
assert np.isnan(ds_test_djf.variable_4.isel(time=1).values).all()

ds_test_djf

In [12]:
ds_test_djf = util.ann_mean(
    ds_test, season='DJF', time_bnds_varname='time_bound', time_centered=center_time, n_req=2
)

# With n_req=2 the leading partial season (Jan/Feb of year 1 only) is kept too,
# giving JF1, D1-JF2, D2-JF3.
jf_days = np.array([31, 28])       # noleap lengths of Jan, Feb
djf_days = np.array([31, 28, 31])  # noleap lengths of Dec, Jan, Feb

assert len(ds_test_djf.time) == 3

np.testing.assert_almost_equal(ds_test_djf.variable_1.isel(time=0).values,
                               np.average([0., 0.], weights=jf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_1.isel(time=1).values,
                               np.average([0., 1., 1.], weights=djf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_1.isel(time=2).values,
                               np.average([1., 0., 0.], weights=djf_days))

np.testing.assert_almost_equal(ds_test_djf.variable_2.isel(time=0).values,
                               np.average([1., 1.], weights=jf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_2.isel(time=1).values,
                               np.average([1., 0., 0.], weights=djf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_2.isel(time=2).values,
                               np.average([0., 1., 1.], weights=djf_days))

# variable_3: JF1 is missing entirely; in D1-JF2 only Jan/Feb are valid.
assert np.isnan(ds_test_djf.variable_3.isel(time=0).values).all()
np.testing.assert_almost_equal(ds_test_djf.variable_3.isel(time=1).values,
                               np.average([0.5, 0.5], weights=jf_days))
np.testing.assert_almost_equal(ds_test_djf.variable_3.isel(time=2).values,
                               np.average([0.5, 20., 20.], weights=djf_days))

# variable_4: JF1 and D1-JF2 are fully missing; in D2-JF3 only Jan/Feb are valid.
assert np.isnan(ds_test_djf.variable_4.isel(time=0).values).all()
assert np.isnan(ds_test_djf.variable_4.isel(time=1).values).all()
np.testing.assert_almost_equal(ds_test_djf.variable_4.isel(time=2).values,
                               np.average([20., 20.], weights=jf_days))

ds_test_djf

In [13]:
# MAM, JJA and SON each fall entirely inside one calendar year, so every seasonal
# mean should reproduce that year's constant value (same expectations as the annual test).
for season in ['MAM', 'JJA', 'SON']:
    ds_test_seas = util.ann_mean(
        ds_test, season=season, time_bnds_varname='time_bound', time_centered=center_time
    )

    np.testing.assert_almost_equal(ds_test_seas.variable_1.isel(time=0).values, 0., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_1.isel(time=1).values, 1., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_1.isel(time=2).values, 0., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_2.isel(time=0).values, 1., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_2.isel(time=1).values, 0., err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_2.isel(time=2).values, 1., err_msg=season)
    assert np.isnan(ds_test_seas.variable_3.isel(time=0).values).all(), season
    np.testing.assert_almost_equal(ds_test_seas.variable_3.isel(time=1).values, 0.5, err_msg=season)
    np.testing.assert_almost_equal(ds_test_seas.variable_3.isel(time=2).values, 20., err_msg=season)
    assert np.isnan(ds_test_seas.variable_4.isel(time=0).values).all(), season
    assert np.isnan(ds_test_seas.variable_4.isel(time=1).values).all(), season
    np.testing.assert_almost_equal(ds_test_seas.variable_4.isel(time=2).values, 20., err_msg=season)

ds_test_seas  # result for the last season in the loop (SON)