Skip to content

Commit

Permalink
do all bands + month band (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
v0lat1le committed Nov 6, 2015
1 parent eb67682 commit 1fe74e8
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 62 deletions.
2 changes: 2 additions & 0 deletions cubeaccess/storage/geotif.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, filepath, other=None):
raise IOError("failed to open " + self._filepath)

t = self._transform = dataset.GetGeoTransform()
self._projection = dataset.GetProjection()
self.coordinates = {
'x': Coordinate(numpy.float32, t[0], t[0]+(dataset.RasterXSize-1)*t[1], dataset.RasterXSize),
'y': Coordinate(numpy.float32, t[3], t[3]+(dataset.RasterYSize-1)*t[5], dataset.RasterYSize)
Expand All @@ -42,6 +43,7 @@ def band2var(band):
self.variables = {str(i+1): band2var(dataset.GetRasterBand(i+1)) for i in xrange(dataset.RasterCount)}
else:
self._transform = other._transform
self._projection = other._projection
self.coordinates = other.coordinates
self.variables = other.variables

Expand Down
13 changes: 0 additions & 13 deletions scripts/__init__.py

This file was deleted.

52 changes: 35 additions & 17 deletions scripts/band_stats_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from __future__ import absolute_import, division, print_function
from builtins import *

import dask.array
import dask.imperative
import dask.multiprocessing

from distributed import Executor

import numpy

import builtins
Expand All @@ -28,25 +29,42 @@
builtins.__dict__['profile'] = lambda x: x

from cubeaccess.indexing import Range
from .common import do_work, _get_dataset, write_file
from common import do_work, _get_dataset, write_files


def main(argv):
    """Compute per-quantile NDVI composites for one cell and write them out.

    Reconstructed post-commit version: the scraped diff had lost its +/-
    markers, leaving the old and new bodies of this function interleaved.

    argv layout (sys.argv style): argv[1] = longitude (int), argv[2] =
    latitude (int), argv[3] = start date parseable by numpy.datetime64.
    Side effects only: schedules dask tasks that write GeoTIFFs under
    /g/data/u46/gxr547/.
    """
    lon = int(argv[1])
    lat = int(argv[2])
    dt = numpy.datetime64(argv[3])

    stack = _get_dataset(lon, lat)
    pqa = _get_dataset(lon, lat, dataset='PQA')

    # TODO: this needs to propagate somehow from the input to the output
    # NOTE(review): reaches into private attrs of the first storage unit —
    # presumably all units in the stack share one grid; verify.
    geotr = stack._storage_units[0]._storage_unit._transform
    proj = stack._storage_units[0]._storage_unit._projection

    qs = [10, 50, 90]            # percentiles to composite at
    num_workers = 16
    N = 4000 // num_workers      # y-strip height; 4000 is the cell size in pixels

    tasks = []
    filename = '/g/data/u46/gxr547/%s_%s_%s' % (lon, lat, dt)
    data = []
    # Split the 4000-row cell into horizontal strips so each worker
    # processes one strip lazily via dask.imperative.
    for yidx, yoff in enumerate(range(0, 4000, N)):
        kwargs = dict(y=slice(yoff, yoff + N), t=Range(dt, dt + numpy.timedelta64(1, 'Y')))
        r = dask.imperative.do(do_work)(stack, pqa, qs, **kwargs)
        data.append(r)
    r = dask.imperative.do(write_files)(filename, data, qs, N, geotr, proj)
    tasks.append(r)

    dask.imperative.compute(tasks, get=dask.multiprocessing.get, num_workers=num_workers)


if __name__ == "__main__":
    import sys
    main(sys.argv)
82 changes: 50 additions & 32 deletions scripts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cubeaccess.core import StorageUnitDimensionProxy, StorageUnitStack
from cubeaccess.storage import GeoTifStorageUnit
from cubeaccess.indexing import make_index


def argpercentile(a, q, axis=0):
Expand All @@ -31,7 +32,7 @@ def argpercentile(a, q, axis=0):
index = (q*(a.shape[axis]-1-nans) + 0.5).astype(numpy.int32)
indices = numpy.indices(a.shape[:axis] + a.shape[axis+1:])
index = tuple(indices[:axis]) + (index,) + tuple(indices[axis:])
return numpy.argsort(a, axis=axis)[index]
return numpy.argsort(a, axis=axis)[index], nans == a.shape[axis]


def _time_from_filename(f):
Expand All @@ -55,16 +56,22 @@ def _get_dataset(lat, lon, dataset='NBAR', sat='LS5_TM'):
return stack


def write_files(name, data, qs, N, geotr, proj):
    """Write one multi-band GeoTIFF per requested percentile.

    Reconstructed post-commit version: the scraped diff had lost its +/-
    markers, leaving the old single-file writer interleaved with this body.

    name  -- output path prefix; files are written as <name>_<q>.tif
    data  -- list of per-strip results, indexed data[strip][band][qidx];
             each entry is a 2-D array covering N rows of the 4000-wide cell
    qs    -- percentile values, used only to label the output files
    N     -- strip height in rows (rows written at offsets 0, N, 2N, ...)
    geotr -- GDAL geotransform tuple applied to every output raster
    proj  -- projection WKT applied to every output raster

    Side effects only (writes files); returns None.
    """
    driver = gdal.GetDriverByName("GTiff")
    nbands = len(data[0])
    for qidx, q in enumerate(qs):
        print('writing', name+'_'+str(q)+'.tif')
        # NOTE(review): bands carry NaN-masked floats but the raster is
        # Int16 — NaNs get an implementation-defined cast; confirm intended.
        raster = driver.Create(name+'_'+str(q)+'.tif', 4000, 4000, nbands, gdal.GDT_Int16,
                               options=["INTERLEAVE=BAND", "COMPRESS=LZW", "TILED=YES"])
        raster.SetProjection(proj)
        raster.SetGeoTransform(geotr)
        for band_num in range(nbands):
            band = raster.GetRasterBand(band_num+1)
            for idx, y in enumerate(range(0, 4000, N)):
                # Each strip's array is written at its own row offset.
                band.WriteArray(data[idx][band_num][qidx], 0, y)
            band.FlushCache()
        raster.FlushCache()
        del raster  # dropping the last reference closes the GDAL dataset


def ndv_to_nan(a, ndv=-999):
Expand All @@ -73,32 +80,43 @@ def ndv_to_nan(a, ndv=-999):
return a


def do_work(stack, pq, qs, **kwargs):
    """Build NDVI-percentile composites of all six bands plus an acquisition-month band.

    Reconstructed post-commit version: the scraped diff had lost its +/-
    markers, leaving the old do_thing/do_work pair interleaved with this body.

    stack  -- NBAR storage-unit stack; bands fetched by name '1'..'6'
    pq     -- matching PQA stack; band '1' holds the pixel-quality bitmask
    qs     -- percentile(s) passed through to argpercentile
    kwargs -- slicing arguments (y slice, t Range) forwarded to .get()

    Returns (blue, green, red, nir, ir1, ir2, months): each band composited
    at the NDVI percentile(s), NaN where fully masked; months holds the
    1-12 acquisition month of the selected observation, 0 where masked.
    """
    print('starting', datetime.now(), kwargs)
    pqa = pq.get('1', **kwargs).values
    red = ndv_to_nan(stack.get('3', **kwargs).values)
    nir = ndv_to_nan(stack.get('4', **kwargs).values)

    # PQA bits 0-7 (saturation), 8 (contiguity) and 10-13 (cloud/shadow)
    # must all be set for a good pixel — presumably; confirm against PQA spec.
    masked = 255 | 256 | 15360
    pqa_idx = ((pqa & masked) != masked)
    del pqa  # free the bitmask array early; strips are large

    nir[pqa_idx] = numpy.nan
    red[pqa_idx] = numpy.nan

    ndvi = (nir-red)/(nir+red)
    # index: time-axis position of each percentile; mask: pixels with no
    # valid observations at all.
    index, mask = argpercentile(ndvi, qs, axis=0)

    # TODO: make slicing coordinates nicer
    tcoord = stack._get_coord('t')
    slice_ = make_index(tcoord, kwargs['t'])
    tcoord = tcoord[slice_]
    tcoord = tcoord[index]
    months = tcoord.astype('datetime64[M]').astype(int) % 12 + 1
    months[..., mask] = 0  # 0 marks "no valid observation"

    # Turn the time-axis indices into a full fancy index over (t, y, x).
    index = (index,) + tuple(numpy.indices(ndvi.shape[1:]))

    def index_data(data):
        # Select the chosen observation per pixel and blank fully-masked pixels.
        data = ndv_to_nan(data[index])
        data[..., mask] = numpy.nan
        return data

    nir = index_data(nir)
    red = index_data(red)
    blue = index_data(stack.get('1', **kwargs).values)
    green = index_data(stack.get('2', **kwargs).values)
    ir1 = index_data(stack.get('5', **kwargs).values)
    ir2 = index_data(stack.get('6', **kwargs).values)

    print('done', datetime.now(), kwargs)
    return blue, green, red, nir, ir1, ir2, months

0 comments on commit 1fe74e8

Please sign in to comment.