In [1]:
import dask
import dask.distributed

def worker_setup_auto():
    """Install default rasterio/GDAL settings and activate them.

    Intended for use as a dask worker setup callback. The datacube import
    is kept inside the function body, presumably so that it is resolved in
    whichever process actually runs the callback.

    Returns whatever ``activate_from_config()`` returns.
    """
    from datacube.utils.rio import set_default_rio_config, activate_from_config

    rio_defaults = {
        'aws': {'region_name': 'auto'},
        'GDAL_INGESTED_BYTES_AT_OPEN': 32 * 1024,  # 32 KiB
        'cloud_defaults': True,
    }
    set_default_rio_config(**rio_defaults)

    activation_result = activate_from_config()
    return activation_result


# Local cluster bound to the loopback interface:
# 4 worker processes, each running 12 threads.
client = dask.distributed.Client(
    ip='127.0.0.1',
    processes=True,
    n_workers=4,
    threads_per_worker=12,
)

# The setup callback executes once per worker *process* (not once per thread).
client.register_worker_callbacks(setup=worker_setup_auto)
client

0,1
Client  Scheduler: tcp://127.0.0.1:37211  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 48  Memory: 128.58 GB


In [2]:
import numpy as np
import pickle
from tqdm import tqdm_notebook
from timeit import default_timer as t_now
from datacube import Datacube
from datacube.utils.rio import set_default_rio_config
from datacube.testutils.io import get_raster_info
from odc.algo import dask_compute_stream


def load_grids(ds, grids):
    """Look up the geobox of each requested band of a single dataset.

    :param ds: dataset object; must expose ``.id`` and be accepted by
               ``get_raster_info``
    :param grids: mapping of grid label -> band name
    :return: ``(ds.id, {grid label: geobox})``; the mapping is empty when
             ``get_raster_info`` raised for this dataset
    """
    label_of_band = {band: label for label, band in grids.items()}
    band_names = list(grids.values())

    # Best effort: a dataset whose rasters cannot be inspected yields an
    # empty result instead of aborting the whole run.
    try:
        raster_info = get_raster_info(ds, band_names)
    except Exception:
        raster_info = {}

    geoboxes = {}
    for band, info in raster_info.items():
        label = label_of_band[band]
        geoboxes[label] = info.geobox

    return (ds.id, geoboxes)


# Grid label -> name of the band that load_grids() inspects for that grid.
# (Labels look like pixel resolutions -- TODO confirm against the product definition.)
grids = {
    '10m': 'nbar_blue',
    '20m': 'fmask',
    '60m': 'nbar_coastal_aerosol',
}

product = 's2a_nrt_granule'

# Same AWS / cloud defaults for this (client) process as the workers receive
# in worker_setup_auto().
set_default_rio_config(
    aws={'region_name': 'auto'},
    cloud_defaults=True,
)

dc = Datacube(env='s2')

In [3]:
# Number of datasets indexed under this product.
dss_total = dc.index.datasets.count(product=product)
# Cluster-wide core count: add up the per-worker values reported by the client.
ncores = sum(n for n in client.ncores().values())
dss_total, ncores

(22559, 48)

In [4]:
%%time

# Number of datasets to process in one go on one worker
#  - Larger number means less comms overhead
#  - Too large is problematic too however, aim for 5-20 seconds per task
lump = 40

t0 = t_now()
dss_all = dc.find_datasets_lazy(product=product)

results = dask_compute_stream(client, lambda ds: load_grids(ds, grids),
                              dss_all,
                              lump=lump,
                              max_in_flight=lump*ncores*2,
                              name='load_grids')

ngg_all = [v for v in tqdm_notebook(results, total=dss_total)]
t1 = t_now()

HBox(children=(IntProgress(value=0, max=22559), HTML(value='')))


CPU times: user 44.1 s, sys: 3.35 s, total: 47.4 s
Wall time: 2min 41s


In [5]:
# Wall-clock duration of the streaming run in the previous cell.
t_total = t1 - t0
# Counts one file per configured grid for every dataset in the product
# (datasets that failed to load are included in this figure).
files_processed = len(grids)*dss_total

# NOTE(review): "Num. Workers" and "FPS (worker)" are based on ncores, the
# summed per-worker core count -- not on the number of worker processes.
print('''
Num. Workers    : {nworkers:d}
Files processed : {nf:,d} f
Time took       : {t_total:5.1f} sec
FPS             : {fps_total:5.1f} f/sec
FPS (worker)    : {fps_per_worker:5.1f} f/sec
'''.format(nf=files_processed, 
           t_total=t_total,
           nworkers=ncores,
           fps_total=files_processed/t_total,
           fps_per_worker=files_processed/t_total/ncores
          ))


Num. Workers    : 48
Files processed : 67,677 f
Time took       : 161.3 sec
FPS             : 419.6 f/sec
FPS (worker)    :   8.7 f/sec



In [6]:
# Persist the {dataset id: grids} mapping. The context manager guarantees the
# file is flushed and closed even if pickling raises part-way through.
with open(product+'_grids.pickle', 'wb') as grids_file:
    pickle.dump(dict(ngg_all), grids_file)

In [7]:
# Ids of datasets whose grid mapping came back empty, i.e. the ones for which
# load_grids() could not obtain any raster info.
failures = [ds_id for ds_id, ds_grids in ngg_all if not ds_grids]

(len(failures), failures[:10])

(2,
 [UUID('0ce886bc-9ec5-4465-bc65-570ac2f9c370'),
  UUID('81abd3f6-1e42-4b87-9e25-ea0de1c09e88')])