In [1]:
import xarray as xr
import numpy as np
import glob
import os.path

from dask.distributed import Client
from dask.diagnostics import ProgressBar

from aggfly import dataset, regions, grid_weights
from aggfly.aggregate import TemporalAggregator, SpatialAggregator, get_time_dim

# Register a dask progress bar so that the computations triggered in later
# cells print their progress (the "% Completed" lines in the cell outputs).
ProgressBar().register()
# Left disabled: no dask.distributed Client is started by this notebook.
# client = Client()

In [2]:
# Set file output name/path
# The results are written to <output_path>/<output_name>.zarr and, when `csv`
# is True, additionally to <output_path>/<output_name>.csv.
output_path = "/home3/dth2133/data/aggregated/counties/"
output_name = "usa_counties_corn_monthly"
# NOTE(review): `output_varn` is not referenced in the visible cells — confirm
# whether it is still needed.
output_varn = "dd"
# Toggle for the CSV export cell (`if csv:`).
csv = True

In [4]:
# Open shapefile containing region features.
georegions = regions.from_name('counties')

# Open one example climate dataset (the 2019 store). In this cell it is only
# used as the grid passed to the grid-weight calculation below.
clim = dataset.from_path(
    f"/home3/dth2133/data/usa/usa-t2m_tempPrecLand2019.zarr", 
    't2m', 
    'zarr', 
    preprocess=dataset.timefix_era5l)

# Clip climate data to the extent of the georegions.
# NOTE(review): the original comment states the raw data are global — confirm.
clim.clip_data_to_georegions_extent(georegions)

# Rechunk dataset to optimize multithreading. Six chunk sizes are given, one
# per dimension of `clim.da` (see the displayed shape further down).
# NOTE(review): -1 presumably keeps a dimension in a single chunk — confirm
# against aggfly's `rechunk`.
clim.rechunk((5, 578, -1, -1, -1, -1))

# Build the area and crop layer (corn) weights object from the example grid
# and the regions. The weights themselves are calculated later by calling
# `weights.weights()`.
weights = grid_weights.from_objects(clim, georegions, crop='corn')

# Hour -> day step: one degree-day ('dd') aggregator per integer threshold
# from -3 through 35. The threshold is the first entry of `ddargs`; the
# remaining entries (999 and 0) are the same for every aggregator.
# NOTE(review): the meaning of the second and third `ddargs` entries is not
# visible here — confirm against aggfly's TemporalAggregator.
daily = [
    TemporalAggregator('dd', agg_from='hour', agg_to='day', ddargs=[threshold, 999, 0])
    for threshold in range(-3, 36)
]

# Day -> month step: sums the daily values. Despite the variable name, this
# aggregator targets the month level (agg_to='month'), not the year.
annual = TemporalAggregator('sum', agg_from='day', agg_to='month')

# Region step: averages the cells within a region. The weights are supplied
# when the aggregator is executed; in this notebook they are the area/corn
# crop-layer weights built from `weights`.
spatial = SpatialAggregator('avg')


In [6]:
# Display the example data array to inspect its shape and chunk layout.
clim.da

Unnamed: 0,Array,Chunk
Bytes,4.81 GiB,98.43 MiB
Shape,"(250, 578, 1, 12, 31, 24)","(5, 578, 1, 12, 31, 24)"
Count,3640 Tasks,50 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.81 GiB 98.43 MiB Shape (250, 578, 1, 12, 31, 24) (5, 578, 1, 12, 31, 24) Count 3640 Tasks 50 Chunks Type float32 numpy.ndarray",1  578  250  24  31  12,

Unnamed: 0,Array,Chunk
Bytes,4.81 GiB,98.43 MiB
Shape,"(250, 578, 1, 12, 31, 24)","(5, 578, 1, 12, 31, 24)"
Count,3640 Tasks,50 Chunks
Type,float32,numpy.ndarray


In [7]:
# Calculate the grid weights once; `w` is reused for every input file when the
# spatial aggregator is executed inside `aggregate_era5l_t2m_multi`.
w = weights.weights()

[########################################] | 100% Completed |  0.1s
[########################################] | 100% Completed | 23.1s
[########################################] | 100% Completed | 22.0s
[########################################] | 100% Completed | 42.7s
[########################################] | 100% Completed |  0.1s
[########################################] | 100% Completed |  0.1s


In [8]:
def aggregate_era5l_t2m_multi(path):
    """Build the aggregation pipeline for one t2m zarr store.

    The store at `path` is opened, clipped to the extent of the notebook-level
    `georegions` and rechunked. Every aggregator in the notebook-level `daily`
    list is then applied to the data; each result is passed through `annual`,
    rechunked with ``rechunk(-1)`` and finally passed through `spatial`
    together with the grid weights `w`.

    Relies on the notebook globals `dataset`, `georegions`, `daily`, `annual`,
    `spatial` and `w`.

    Parameters
    ----------
    path : str
        Location of the zarr store to process.

    Returns
    -------
    list
        One result of ``spatial.map_execute`` per entry of `daily`, in the
        same order.
    """
    # NOTE(review): the original comment on the `preprocess` argument read
    # "Kelvin to Celsius"; the function name suggests a time fix instead —
    # confirm what `dataset.timefix_era5l` actually does.
    clim = dataset.from_path(
        path,
        't2m',
        'zarr',
        preprocess=dataset.timefix_era5l,
    )

    # Restrict the data to the extent of the regions, then rechunk.
    clim.clip_data_to_georegions_extent(georegions)
    # NOTE(review): five chunk sizes are passed here, whereas the setup cell
    # passes six for the example dataset — confirm this is intended.
    clim.rechunk((5, 578, -1, -1, -1))

    # Stage 1: apply every hour -> day aggregator to the raw array.
    daily_list = []
    for hourly_to_daily in daily:
        daily_list.append(hourly_to_daily.map_execute(clim.da.data))

    # Stage 2: aggregate each daily result with `annual`.
    annual_list = []
    for daily_result in daily_list:
        annual_list.append(annual.map_execute(daily_result))

    # Stage 3: rechunk each result with a single -1 chunk specification.
    rc_list = []
    for aggregated in annual_list:
        rc_list.append(aggregated.rechunk(-1))

    # Stage 4: aggregate over space using the precomputed weights `w`.
    spatial_list = []
    for rechunked in rc_list:
        spatial_list.append(spatial.map_execute(rechunked, w))

    return spatial_list
    

In [9]:
# Collect every t2m input store in the data directory, sorted by path. The
# file names end in the year (e.g. "...tempPrecLand1951.zarr"), so the sorted
# order follows the years.
import numpy as np
import glob
from os.path import basename
files = np.sort(glob.glob('/home3/dth2133/data/usa/*t2m_*'))
# Build the aggregation pipeline for each store. `output` holds one list of
# results per input file, in the same order as `files`.
output = []
for zarr_path in files:
    print(zarr_path)
    output.append(aggregate_era5l_t2m_multi(zarr_path))

/home3/dth2133/data/usa/usa-t2m_tempPrecLand1951.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1952.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1953.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1954.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1955.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1956.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1957.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1958.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1960.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1961.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1962.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1963.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1964.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1965.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1966.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1967.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1968.zarr
/home3/dth2133/data/usa/usa-t2m_tempPrecLand1969.zarr
/home3/dth2133/data/usa/usa-

In [None]:
import dask
# Run all pipelines collected in `output` with a single dask.compute call.
# dask.compute returns a tuple with one entry per argument, so the computed
# list of per-file results is `result[0]`.
result = dask.compute(output)

[                                        ] | 1% Completed |  3min 12.3s

IOStream.flush timed out


[#                                       ] | 4% Completed | 12min  4.0s

In [10]:
# Open the existing output store. A later cell reads `temp.year` and
# `temp.month` from it to label the newly computed arrays, so this store must
# already exist before the notebook is run.
# NOTE(review): the year coordinate is taken from this store rather than from
# the processed input files — confirm that both cover the same years.
temp = xr.open_zarr(os.path.join(output_path, output_name+'.zarr'))

In [105]:
# Variable names for the thresholds -3 through 35: 'dd' followed by the
# threshold, with the minus sign of negative thresholds written as 'M'
# (for example 'ddM3', 'dd0', 'dd35').
lr = ['dd' + label.replace('-', 'M') for label in map(str, range(-3, 36))]
lr

In [106]:
# Assemble one single-variable Dataset per threshold. `result[0]` is indexed
# as [input file][threshold]; for each threshold the per-file arrays are
# joined along axis 1, which is labelled 'year' below.
# The year and month labels are read from the previously opened store `temp`.
dlist = []
for idx in range(len(result[0][0])):
    yearly_arrays = [file_result[idx] for file_result in result[0]]
    stacked = np.concatenate(yearly_arrays, axis=1)
    threshold_ds = xr.DataArray(
        data=stacked,
        dims=['region', 'year', 'month'],
        coords={
            'region': ('region', georegions.regions),
            'year': ('year', temp.year.values),
            'month': ('month', temp.month.values),
        },
    ).to_dataset(name=lr[idx])
    dlist.append(threshold_ds)

In [None]:
# Combine the per-threshold datasets in `dlist` into a single Dataset.
outds = xr.combine_by_coords(dlist)

In [None]:
# Re-open the existing output store and replace any previously written
# degree-day variables with the freshly computed ones in `outds`.
ds = xr.open_zarr(os.path.join(output_path, output_name+'.zarr'))
# Drop only the names that are actually present in the store. Dropping the
# full `lr` list raises as soon as one name is missing, e.g. when the store
# was written with a shorter threshold list.
ld = [x for x in ds.keys() if x in lr]
if len(ld) > 0:
    ds = ds.drop_vars(ld)
# Compute the combined dataset now: the same store is overwritten afterwards
# (to_zarr with mode='w'), so it must not still be read from at that point.
ds = xr.combine_by_coords([ds,outds]).compute()

In [None]:
# Workaround for https://github.com/pydata/xarray/issues/3476: cast every
# object-dtype coordinate and variable to "unicode" before writing the store.
for coord_name in list(ds.coords):
    coord = ds.coords[coord_name]
    if coord.dtype == object:
        ds.coords[coord_name] = coord.astype("unicode")

for var_name in list(ds.variables):
    variable = ds[var_name]
    if variable.dtype == object:
        ds[var_name] = variable.astype("unicode")

# mode='w' replaces the existing store at this location.
zarr_target = os.path.join(output_path, output_name+'.zarr')
ds.to_zarr(zarr_target, mode='w')

In [None]:
# Optional CSV export: re-open the written store and dump it as a flat table
# next to the zarr store.
if csv:
    zarr_store = os.path.join(output_path, output_name+'.zarr')
    csv_file = os.path.join(output_path, output_name+'.csv')
    ds = xr.open_zarr(zarr_store)
    ds.to_dataframe().to_csv(csv_file)