In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gtsa

from pathlib import Path
import shutil
import psutil
import pandas as pd

# Raster stacking

Stacks single band rasters and chunks along the time dimension (on disk) for memory-efficient data retrieval.

#### Prerequesites
- Download DEM data with `00_download_dem_data.py` or `00_download_dem_data.ipynb`

## Start dask cluster
- For parallel read/write

In [3]:
workers = psutil.cpu_count(logical=True)-1
client = gtsa.io.dask_start_cluster(workers,
                                    ip_addres=None, # replace with address if working on remote machine
                                    port=':8787', # if occupies a different port will automatically be assigned
                                   )


Dask dashboard at: http://127.0.0.1:8787/status
Workers: 9
Threads per worker: 1 



## Get DEM file paths and time stamps

In [4]:
data_dir = '../../data/dems/south-cascade/' # small test dataset
# data_dir = '../../data/dems/mount-baker' # large dataset

In [5]:
dems = [x.as_posix() for x in sorted(Path(data_dir).glob('*.tif'))]
date_strings = [x[1:-1] for x in gtsa.io.parse_timestamps(dems,date_string_pattern='_........_')]
date_strings, dems = list(zip(*sorted(zip(date_strings, dems)))) # ensure chronological sorting 
date_times = [pd.to_datetime(x, format="%Y%m%d") for x in date_strings]

In [6]:
ref_dem = dems[-1] # always last after chronological sorting
ref_dem

'../../data/dems/south-cascade/WV_south-cascade_20151014_1m_dem.tif'

## Reproject to reference DEM grid
- Create a reprojected NetCDF file for each DEM
- Loads all NetCDF files lazily

In [None]:
ds = gtsa.io.xr_stack_geotifs(dems,
                              date_times,
                              ref_dem,
                              resampling="bilinear",
                              save_to_nc = True,
                              nc_out_dir = Path(data_dir,'nc_files').as_posix(),
                              overwrite = False)

## Examine current chunk shape
- Each time stamped DEM is a single chunk

In [None]:
ds['band1']

In [None]:
ds['band1'].sel(time = ds.time.values[0])

## Rechunk along time dimension
- Creates temporary zarr file for efficient rechunking
- Saves a zarr file chunked along full time dimension to disk
- Significantly improves dask worker occupation and processing time for computations along the time dimension

In [None]:
ds_zarr = gtsa.io.create_zarr_stack(ds,
                                    output_directory = Path(data_dir,'stack').as_posix(),
                                    variable_name='band1',
                                    zarr_stack_file_name='stack.zarr',
                                    overwrite = False,
                                    cleanup=True)

In [None]:
ds_zarr['band1']

In [None]:
ds_zarr['band1'].sel(time = ds_zarr.time.values[0])

## Why did we do this?
To compare performance we will compute the per-pixel count in the raster stack along the time dimension using the NetCDF files aligned to the same grid vs the zarr stack that has been spatially chunked, but includes all values in the time series for a given spatial chunk.
- watch your dask dashboard as you run the computations below
- note that it takes more steps, memory, and time to lazily compute

In [26]:
import xarray as xr

In [27]:
data_dir = '../../data/dems/south-cascade/'
data_dir = '../../data/dems/mount-baker/'

#### zarr per-pixel count computation

In [28]:
zarr_stack_fn = Path(data_dir,'stack','stack.zarr')
ds_zarr = xr.open_dataset(zarr_stack_fn,chunks='auto',engine='zarr')

# tc,yc,xc = gtsa.io.determine_optimal_chuck_size(ds_zarr,print_info = True)
# ds_zarr = xr.open_dataset(zarr_stack_fn,chunks={'time': tc, 'y': yc, 'x':xc},engine='zarr')

In [29]:
ds_zarr['band1']

Unnamed: 0,Array,Chunk
Bytes,22.43 GiB,89.83 MiB
Shape,"(11, 28392, 19282)","(11, 1775, 1206)"
Dask graph,256 chunks in 2 graph layers,256 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 22.43 GiB 89.83 MiB Shape (11, 28392, 19282) (11, 1775, 1206) Dask graph 256 chunks in 2 graph layers Data type float32 numpy.ndarray",19282  28392  11,

Unnamed: 0,Array,Chunk
Bytes,22.43 GiB,89.83 MiB
Shape,"(11, 28392, 19282)","(11, 1775, 1206)"
Dask graph,256 chunks in 2 graph layers,256 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [30]:
a = ds_zarr['band1'].count(axis=0).compute()

type: [Errno 22] Invalid argument

In [22]:
# a.plot()

#### NetCDF per-pixel count computation

In [23]:
nc_files = sorted(Path(data_dir,'nc_files').glob('*.nc'))
ds_nc = xr.open_mfdataset(nc_files)

In [24]:
%%time
ds_nc['band1'].count(axis=0).compute()

CPU times: user 136 ms, sys: 56.9 ms, total: 193 ms
Wall time: 509 ms
