Skip to content

Commit

Permalink
grid processing script that's computing SPI/Gamma using xarray GroupB…
Browse files Browse the repository at this point in the history
…y, various code cleanups

#191
  • Loading branch information
monocongo committed Oct 17, 2018
1 parent 30ed9c6 commit acb766d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 24 deletions.
9 changes: 9 additions & 0 deletions climate_indices/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ class Periodicity(Enum):
monthly = 12
daily = 366

def __str__(self):
    """
    Render the member as its bare name (e.g. ``monthly`` or ``daily``).

    :return: the ``name`` attribute of this enumeration member
    """
    member_name = self.name
    return member_name

@staticmethod
def from_string(s):
    """
    Look up a Periodicity member by its name.

    The script scripts/process_grid_groupby.py passes this function as the
    ``type=`` converter of its ``--periodicity`` command line argument, so the
    exception type raised for an unknown name is kept as ValueError.

    :param s: member name to look up, e.g. "monthly" or "daily"
    :return: the Periodicity member whose name equals ``s``
    :raises ValueError: if ``s`` is not the name of a Periodicity member
    """
    try:
        return Periodicity[s]
    except KeyError:
        # report the rejected value and the accepted names instead of raising
        # an empty ValueError; the KeyError itself adds nothing, so suppress it
        valid_names = ", ".join(member.name for member in Periodicity)
        message = "Unsupported periodicity {value!r}, expected one of: {names}".format(
            value=s, names=valid_names)
        raise ValueError(message) from None

# ----------------------------------------------------------------------------------------------------------------------
@numba.jit
Expand Down
3 changes: 2 additions & 1 deletion scripts/process_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import logging
import math
import multiprocessing

import netCDF4
import netcdf_utils
import numpy as np

from climate_indices import indices, utils
from scripts import netcdf_utils

#-----------------------------------------------------------------------------------------------------------------------
# static constants
Expand Down
86 changes: 63 additions & 23 deletions scripts/process_grid_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import xarray as xr

from climate_indices import indices
from climate_indices import compute, indices

# ----------------------------------------------------------------------------------------------------------------------
# set up a basic, global _logger which will write to the console as standard error
Expand Down Expand Up @@ -227,6 +227,22 @@ def _validate_args(args):
raise ValueError(message)


def spi_gamma(data_array,
              scale,
              data_start_year,
              calibration_year_initial,
              calibration_year_final,
              periodicity):
    """
    Compute SPI using the gamma distribution for a single array of values.

    Thin wrapper around indices.spi() which fixes the distribution argument to
    indices.Distribution.gamma and forwards every other argument unchanged, in
    the positional order that indices.spi() is called with here. The main
    section of this script hands this function to a GroupBy apply() call with
    the remaining arguments supplied by keyword.

    :param data_array: values forwarded as the first argument of indices.spi()
        (presumably the precipitation time series of one lat/lon point --
        confirm against indices.spi)
    :param scale: forwarded as the second argument of indices.spi()
    :param data_start_year: forwarded unchanged to indices.spi()
    :param calibration_year_initial: forwarded unchanged to indices.spi()
    :param calibration_year_final: forwarded unchanged to indices.spi()
    :param periodicity: forwarded unchanged to indices.spi()
    :return: whatever indices.spi() returns for these arguments
    """
    compute_spi = indices.spi
    gamma_distribution = indices.Distribution.gamma

    return compute_spi(
        data_array,
        scale,
        gamma_distribution,
        data_start_year,
        calibration_year_initial,
        calibration_year_final,
        periodicity,
    )


# ----------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
"""
Expand Down Expand Up @@ -271,7 +287,8 @@ def _validate_args(args):
required=True)
parser.add_argument("--periodicity",
help="Process input as either monthly or daily values",
choices=['monthly', 'daily'],
choices=list(compute.Periodicity),
type=compute.Periodicity.from_string,
required=True)
parser.add_argument("--netcdf_temp",
help="Temperature NetCDF file to be used as input for indices computations")
Expand All @@ -297,11 +314,29 @@ def _validate_args(args):
# compute SPI if specified
if arguments.index in ['spi', 'scaled']:

# open the precipitation NetCDF as an xarray DataSet object
dataset = xr.open_dataset(arguments.netcdf_precip)

# trim out all data variables from the dataset except the precipitation
for var in dataset.data_vars:
if var not in arguments.var_name_precip:
dataset = dataset.drop(var)

# get the precipitation variable as an xarray DataArray object
da_precip = dataset[arguments.var_name_precip]

# get the initial year of the data
data_start_year = int(str(da_precip['time'].values[0])[0:4])

# stack the lat and lon dimensions into a new dimension named point, so at each lat/lon
# we'll have a time series for the geospatial point
da_precip = da_precip.stack(point=('lat', 'lon'))

for timestep_scale in arguments.scales:

if arguments.periodicity == 'daily':
if arguments.periodicity is compute.Periodicity.daily:
scale_increment = 'day'
elif arguments.periodicity == 'monthly':
elif arguments.periodicity is compute.Periodicity.monthly:
scale_increment = 'month'
else:
raise ValueError("Invalid periodicity argument: {}".format(arguments.periodicity))
Expand All @@ -310,27 +345,32 @@ def _validate_args(args):
incr=scale_increment,
index='SPI'))

# open the precipitation NetCDF, getting the precipitation variable as an xarray DataArray object
da_precip = xr.open_dataset(arguments.netcdf_precip)[arguments.var_name_precip]

# get the initial year of the data
data_start_year = da_precip['time'][0].year

# stack the lat and lon dimensions into a new dimension named point, so at each lat/lon
# we'll have a time series for the geospatial point
da_precip = da_precip.stack(point=('lat', 'lon'))

# group the data by geospatial point and apply the SPI/Gamma function to each
da_precip = da_precip.groupby('point').apply(indices.spi,
timestep_scale,
indices.Distribution.gamma,
data_start_year,
arguments.calibration_start_year,
arguments.calibration_end_year,
arguments.periodicity)
# group the data by geospatial point and apply the SPI/Gamma function to each time series group
da_spi = da_precip.groupby('point').apply(spi_gamma,
scale=timestep_scale,
data_start_year=data_start_year,
calibration_year_initial=arguments.calibration_start_year,
calibration_year_final=arguments.calibration_end_year,
periodicity=arguments.periodicity)

# unstack the array back into original dimensions
da_precip = da_precip.unstack('point')
da_spi = da_spi.unstack('point')

# create a new variable to contain the SPI for the scale, assign into the dataset
long_name = "Standardized Precipitation Index (Gamma distribution), "\
"{scale}-{increment}".format(scale=timestep_scale, increment=scale_increment)
spi_var = xr.Variable(dims=da_spi.dims,
data=da_spi,
attrs={'long_name' : long_name,
'valid_min' : -3.09,
'valid_max' : 3.09})
dataset["spi_gamma_" + str(timestep_scale).zfill(2)] = spi_var

# trim out the precipitation variable since it won't be needed again
dataset = dataset.drop(arguments.var_name_precip)

# write the dataset as NetCDF
dataset.to_netcdf(arguments.output_file_base + "_spi_gamma.nc")

# report on the elapsed time
end_datetime = datetime.now()
Expand Down

0 comments on commit acb766d

Please sign in to comment.