diff --git a/cubeaccess/storage/geotif.py b/cubeaccess/storage/geotif.py index 8771b5c6be..5384001074 100644 --- a/cubeaccess/storage/geotif.py +++ b/cubeaccess/storage/geotif.py @@ -32,6 +32,7 @@ def __init__(self, filepath, other=None): raise IOError("failed to open " + self._filepath) t = self._transform = dataset.GetGeoTransform() + self._projection = dataset.GetProjection() self.coordinates = { 'x': Coordinate(numpy.float32, t[0], t[0]+(dataset.RasterXSize-1)*t[1], dataset.RasterXSize), 'y': Coordinate(numpy.float32, t[3], t[3]+(dataset.RasterYSize-1)*t[5], dataset.RasterYSize) @@ -42,6 +43,7 @@ def band2var(band): self.variables = {str(i+1): band2var(dataset.GetRasterBand(i+1)) for i in xrange(dataset.RasterCount)} else: self._transform = other._transform + self._projection = other._projection self.coordinates = other.coordinates self.variables = other.variables diff --git a/scripts/__init__.py b/scripts/__init__.py deleted file mode 100644 index 1c90d7a065..0000000000 --- a/scripts/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2015 Geoscience Australia -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/scripts/band_stats_dask.py b/scripts/band_stats_dask.py index c73841ee44..38e584f511 100644 --- a/scripts/band_stats_dask.py +++ b/scripts/band_stats_dask.py @@ -16,10 +16,11 @@ from __future__ import absolute_import, division, print_function from builtins import * -import dask.array import dask.imperative import dask.multiprocessing +from distributed import Executor + import numpy import builtins @@ -28,25 +29,42 @@ builtins.__dict__['profile'] = lambda x: x from cubeaccess.indexing import Range -from .common import do_work, _get_dataset, write_file +from common import do_work, _get_dataset, write_files + + +def main(argv): + lon = int(argv[1]) + lat = int(argv[2]) + dt = numpy.datetime64(argv[3]) + + stack = _get_dataset(lon, lat) + pqa = _get_dataset(lon, lat, dataset='PQA') + + # TODO: this needs to propagate somehow from the input to the output + geotr = stack._storage_units[0]._storage_unit._transform + proj = stack._storage_units[0]._storage_unit._projection + qs = [10, 50, 90] + num_workers = 16 + N = 4000//num_workers -def main(): - stack = _get_dataset(146, -034) - pqa = _get_dataset(146, -034, dataset='PQA') - N = 250 - zzz = [] - for tidx, dt in enumerate(numpy.arange('1989', '1991', dtype='datetime64[Y]')): - data = [] - for yidx, yoff in enumerate(range(0, 4000, N)): - kwargs = dict(y=slice(yoff, yoff+N), t=Range(dt, dt+numpy.timedelta64(1, 'Y'))) - r = dask.imperative.do(do_work)(stack, pqa, **kwargs) - data.append(r) - r = dask.imperative.do(write_file)(str(dt), data) - zzz.append(r) - dask.imperative.compute(zzz, num_workers=16, get=dask.multiprocessing.get) + tasks = [] + #for tidx, dt in enumerate(numpy.arange('1990', '1991', dtype='datetime64[Y]')): + filename = '/g/data/u46/gxr547/%s_%s_%s'%(lon, lat, dt) + data = [] + for yidx, yoff in enumerate(range(0, 4000, N)): + kwargs = dict(y=slice(yoff, yoff+N), t=Range(dt, dt+numpy.timedelta64(1, 'Y'))) + r = dask.imperative.do(do_work)(stack, pqa, qs, **kwargs) + data.append(r) + r = dask.imperative.do(write_files)(filename, data, qs, N, geotr, proj) + tasks.append(r) + #executor = Executor('127.0.0.1:8787') + #dask.imperative.compute(tasks, get=executor.get) + #dask.imperative.compute(tasks[0], num_workers=16) + dask.imperative.compute(tasks, get=dask.multiprocessing.get, num_workers=num_workers) if __name__ == "__main__": - main() + import sys + main(sys.argv) diff --git a/scripts/common.py b/scripts/common.py index f15a6542b9..69794234ca 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -20,6 +20,7 @@ from cubeaccess.core import StorageUnitDimensionProxy, StorageUnitStack from cubeaccess.storage import GeoTifStorageUnit +from cubeaccess.indexing import make_index def argpercentile(a, q, axis=0): @@ -31,7 +32,7 @@ def argpercentile(a, q, axis=0): index = (q*(a.shape[axis]-1-nans) + 0.5).astype(numpy.int32) indices = numpy.indices(a.shape[:axis] + a.shape[axis+1:]) index = tuple(indices[:axis]) + (index,) + tuple(indices[axis:]) - return numpy.argsort(a, axis=axis)[index] + return numpy.argsort(a, axis=axis)[index], nans == a.shape[axis] def _time_from_filename(f): @@ -55,16 +56,22 @@ def _get_dataset(lat, lon, dataset='NBAR', sat='LS5_TM'): return stack -def write_file(name, data): +def write_files(name, data, qs, N, geotr, proj): driver = gdal.GetDriverByName("GTiff") - raster = driver.Create(name+'.tif', 4000, 4000, 3, gdal.GDT_Int16, - options=["INTERLEAVE=BAND", "COMPRESS=LZW", "TILED=YES"]) - for idx, y in enumerate(range(0, 4000, 250)): - raster.GetRasterBand(1).WriteArray(data[idx][0], 0, y) - raster.GetRasterBand(2).WriteArray(data[idx][1], 0, y) - raster.GetRasterBand(3).WriteArray(data[idx][2], 0, y) - raster.FlushCache() - del raster + nbands = len(data[0]) + for qidx, q in enumerate(qs): + print('writing', name+'_'+str(q)+'.tif') + raster = driver.Create(name+'_'+str(q)+'.tif', 4000, 4000, nbands, gdal.GDT_Int16, + options=["INTERLEAVE=BAND", "COMPRESS=LZW", "TILED=YES"]) + raster.SetProjection(proj) + raster.SetGeoTransform(geotr) + for band_num in range(nbands): + band = raster.GetRasterBand(band_num+1) + for idx, y in enumerate(range(0, 4000, N)): + band.WriteArray(data[idx][band_num][qidx], 0, y) + band.FlushCache() + raster.FlushCache() + del raster def ndv_to_nan(a, ndv=-999): @@ -73,32 +80,43 @@ def ndv_to_nan(a, ndv=-999): return a -def do_thing(nir, red, green, blue, pqa): +def do_work(stack, pq, qs, **kwargs): + print('starting', datetime.now(), kwargs) + pqa = pq.get('1', **kwargs).values + red = ndv_to_nan(stack.get('3', **kwargs).values) + nir = ndv_to_nan(stack.get('4', **kwargs).values) + masked = 255 | 256 | 15360 pqa_idx = ((pqa & masked) != masked) + del pqa - nir = ndv_to_nan(nir) nir[pqa_idx] = numpy.nan - red = ndv_to_nan(red) red[pqa_idx] = numpy.nan ndvi = (nir-red)/(nir+red) - index = argpercentile(ndvi, 90, axis=0) - index = (index,) + tuple(numpy.indices(index.shape)) - - red = red[index] - green = ndv_to_nan(green[index]) - blue = ndv_to_nan(blue[index]) - - return red, green, blue - - -def do_work(stack, pq, **kwargs): - print(datetime.now(), kwargs) - - nir = stack.get('4', **kwargs).values - red = stack.get('3', **kwargs).values - pqa = pq.get('1', **kwargs).values - green = stack.get('2', **kwargs).values - blue = stack.get('1', **kwargs).values - return do_thing(nir, red, green, blue, pqa) + index, mask = argpercentile(ndvi, qs, axis=0) + + # TODO: make slicing coordinates nicer + tcoord = stack._get_coord('t') + slice_ = make_index(tcoord, kwargs['t']) + tcoord = tcoord[slice_] + tcoord = tcoord[index] + months = tcoord.astype('datetime64[M]').astype(int) % 12 + 1 + months[..., mask] = 0 + + index = (index,) + tuple(numpy.indices(ndvi.shape[1:])) + + def index_data(data): + data = ndv_to_nan(data[index]) + data[..., mask] = numpy.nan + return data + + nir = index_data(nir) + red = index_data(red) + blue = index_data(stack.get('1', **kwargs).values) + green = index_data(stack.get('2', **kwargs).values) + ir1 = index_data(stack.get('5', **kwargs).values) + ir2 = index_data(stack.get('6', **kwargs).values) + + print('done', datetime.now(), kwargs) + return blue, green, red, nir, ir1, ir2, months