In [1]:
import os

import fiona.transform
import fsspec
import xarray as xr
from affine import Affine

from utils.hls import catalog
from utils.hls import compute
from utils import get_logger

In [2]:
# Logger handed to compute.process_jobs below; its records show up in the cell
# output as "hls-point-sampling - ...".
logger = get_logger('hls-point-sampling')
# Dask cluster sizing forwarded to compute.process_jobs.
# NOTE(review): memory figures are presumably in GB — confirm in utils.hls.compute.
cluster_args = {
    'workers': 16,
    'worker_threads': 3,
    'worker_memory': 16,
    'scheduler_threads': 4,
    'scheduler_memory': 8,
}
# Helper-code directory handed to process_jobs, and the file of completed jobs
# (already-listed jobs are skipped on re-run, per the log output below).
code_path = './utils'
checkpoint_path = 'checkpoints/sampling.txt'

In [3]:
# Supply the Azure storage account key through the AZURE_ACCOUNT_KEY environment
# variable (e.g. export it before starting Jupyter) instead of pasting it here,
# so the secret is never saved in the notebook. setdefault keeps a key that is
# already set rather than overwriting it with an empty string; the empty
# fallback only keeps the later os.environ[...] lookups from raising KeyError.
os.environ.setdefault('AZURE_ACCOUNT_KEY', "")

In [4]:
# Open the FIA point catalog zarr in the "fia" container of the "usfs" Azure
# storage account and load it as an HLSCatalog.
catalog_path = fsspec.get_mapper(
    "az://fia/catalogs/fia_tiles.zarr",
    account_key=os.environ['AZURE_ACCOUNT_KEY'],
    account_name="usfs",
)
pt_catalog = catalog.HLSCatalog.from_zarr(catalog_path)

In [5]:
# Keep catalog points from 2015 onward and flatten them to a DataFrame.
# NOTE(review): where(..., drop=True) masks every variable, which appears to
# promote integer 'year' to float — the log below shows keys like
# ('10SDH', 2015.0). Checkpoints are keyed on these values, so don't change
# the dtype without migrating checkpoints/sampling.txt.
hls_pts = (
    pt_catalog.xr_ds
    .where(pt_catalog.xr_ds['year'] >= 2015, drop=True)
    .to_dataframe()
)
# One sampling job per (tile, year) group.
jobs = hls_pts.groupby(['tile', 'year'])
# HLS tile IDs that make up the western region processed by this notebook
# (presumably MGRS grid squares; the IDs span UTM zones 10-14). Used to
# select which (tile, year) jobs to run.
west_tiles = {'10SDH', '10SDJ', '10SEF', '10SEG', '10SEH', '10SEJ', '10SFE', '10SFF', '10SFG', '10SFH', '10SFJ', '10SGD', '10SGE', '10SGF', '10SGG', '10SGH', '10SGJ', '10TCK', '10TCL', '10TCM', '10TCN', '10TCP', '10TCQ', '10TCT', '10TDK', '10TDL', '10TDM', '10TDN', '10TDP', '10TDQ', '10TDR', '10TDS', '10TDT', '10TEK', '10TEL', '10TEM', '10TEN', '10TEP', '10TEQ', '10TER', '10TES', '10TET', '10TFK', '10TFL', '10TFM', '10TFN', '10TFP', '10TFQ', '10TFR', '10TFS', '10TFT', '10TGK', '10TGL', '10TGM', '10TGN', '10TGP', '10TGQ', '10TGR', '10TGS', '10TGT', '10UCU', '10UCV', '10UDU', '10UDV', '10UEU', '10UEV', '10UFU', '10UFV', '10UGU', '10UGV', '11SKA', '11SKB', '11SKC', '11SKD', '11SKT', '11SKU', '11SKV', '11SLA', '11SLB', '11SLC', '11SLD', '11SLT', '11SLU', '11SLV', '11SMA', '11SMB', '11SMC', '11SMD', '11SMR', '11SMS', '11SMT', '11SMU', '11SMV', '11SNA', '11SNB', '11SNC', '11SND', '11SNR', '11SNS', '11SNT', '11SNU', '11SNV', '11SPA', '11SPB', '11SPC', '11SPD', '11SPR', '11SPS', '11SPT', '11SPU', '11SPV', '11SQA', '11SQB', '11SQC', '11SQD', '11SQR', '11SQS', '11SQT', '11SQU', '11SQV', '11TKE', '11TKF', '11TKG', '11TLE', '11TLF', '11TLG', '11TLH', '11TLJ', '11TLK', '11TLL', '11TLM', '11TLN', '11TME', '11TMF', '11TMG', '11TMH', '11TMJ', '11TMK', '11TML', '11TMM', '11TMN', '11TNE', '11TNF', '11TNG', '11TNH', '11TNJ', '11TNK', '11TNL', '11TNM', '11TNN', '11TPE', '11TPF', '11TPG', '11TPH', '11TPJ', '11TPK', '11TPL', '11TPM', '11TPN', '11TQE', '11TQF', '11TQG', '11TQH', '11TQJ', '11TQK', '11TQL', '11TQM', '11TQN', '11ULP', '11ULQ', '11UMP', '11UMQ', '11UNP', '11UNQ', '11UPP', '11UPQ', '11UQP', '11UQQ', '12RTV', '12RUV', '12RVV', '12RWV', '12RXV', '12RYV', '12STA', '12STB', '12STC', '12STD', '12STE', '12STF', '12STG', '12STH', '12STJ', '12SUA', '12SUB', '12SUC', '12SUD', '12SUE', '12SUF', '12SUG', '12SUH', '12SUJ', '12SVA', '12SVB', '12SVC', '12SVD', '12SVE', '12SVF', '12SVG', '12SVH', '12SVJ', '12SWA', '12SWB', '12SWC', '12SWD', '12SWE', '12SWF', '12SWG', '12SWH', '12SWJ', 
'12SXA', '12SXB', '12SXC', '12SXD', '12SXE', '12SXF', '12SXG', '12SXH', '12SXJ', '12SYA', '12SYB', '12SYC', '12SYD', '12SYE', '12SYF', '12SYG', '12SYH', '12SYJ', '12TTK', '12TTL', '12TTM', '12TUK', '12TUL', '12TUM', '12TUN', '12TUP', '12TUQ', '12TUR', '12TUS', '12TUT', '12TVK', '12TVL', '12TVM', '12TVN', '12TVP', '12TVQ', '12TVR', '12TVS', '12TVT', '12TWK', '12TWL', '12TWM', '12TWN', '12TWP', '12TWQ', '12TWR', '12TWS', '12TWT', '12TXK', '12TXL', '12TXM', '12TXN', '12TXP', '12TXQ', '12TXR', '12TXS', '12TXT', '12TYK', '12TYL', '12TYM', '12TYN', '12TYP', '12TYQ', '12TYR', '12TYS', '12TYT', '12UUU', '12UUV', '12UVU', '12UVV', '12UWU', '12UWV', '12UXU', '12UXV', '12UYU', '12UYV', '13RBQ', '13RCQ', '13RDQ', '13REQ', '13RFQ', '13RGQ', '13SBA', '13SBB', '13SBC', '13SBD', '13SBR', '13SBS', '13SBT', '13SBU', '13SBV', '13SCA', '13SCB', '13SCC', '13SCD', '13SCR', '13SCS', '13SCT', '13SCU', '13SCV', '13SDA', '13SDB', '13SDC', '13SDD', '13SDR', '13SDS', '13SDT', '13SDU', '13SDV', '13SEA', '13SEB', '13SEC', '13SED', '13SER', '13SES', '13SET', '13SEU', '13SEV', '13SFA', '13SFB', '13SFC', '13SFD', '13SFR', '13SFS', '13SFT', '13SFU', '13SFV', '13SGA', '13SGB', '13SGC', '13SGD', '13SGR', '13SGS', '13SGT', '13SGU', '13SGV', '13TBE', '13TBF', '13TBG', '13TCE', '13TCF', '13TCG', '13TCH', '13TCJ', '13TCK', '13TCL', '13TCM', '13TCN', '13TDE', '13TDF', '13TDG', '13TDH', '13TDJ', '13TDK', '13TDL', '13TDM', '13TDN', '13TEE', '13TEF', '13TEG', '13TEH', '13TEJ', '13TEK', '13TEL', '13TEM', '13TEN', '13TFE', '13TFF', '13TFG', '13TFH', '13TFJ', '13TFK', '13TFL', '13TFM', '13TFN', '13TGE', '13TGF', '13TGG', '13TGH', '13TGJ', '13TGK', '13TGL', '13TGM', '13TGN', '13UCP', '13UCQ', '13UDP', '13UDQ', '13UEP', '13UEQ', '13UFP', '13UFQ', '13UGP', '13UGQ', '14RKV', '14SKA', '14SKB', '14SKC', '14SKD', '14SKE', '14SKF', '14SKG', '14SKH', '14SKJ', '14TKK', '14TKL', '14TKM'}
# Keep only the (tile, year) groups whose tile is in the western set;
# each entry is the ((tile, year), points_df) pair produced by the groupby.
west_jobs = [
    (tile_year, points_df)
    for tile_year, points_df in jobs
    if tile_year[0] in west_tiles
]

In [6]:
from dask.distributed import get_worker

def chip(ds, lat, lon, chip_size, metadata):
    """Cut a ``chip_size`` x ``chip_size`` pixel window out of ``ds`` around a point.

    The point is reprojected from EPSG:4326 into ``ds.attrs['crs']`` and mapped
    to pixel coordinates with the inverse of ``Affine(*ds.attrs['transform'])``;
    those coordinates are rounded to the nearest integer index and the window
    is centred there.

    Parameters
    ----------
    ds : xarray.Dataset
        Raster with ``x`` and ``y`` dimensions and ``crs``/``transform`` attrs.
    lat, lon : float
        Point location in EPSG:4326 degrees.
    chip_size : int
        Width and height of the chip, in pixels.
    metadata : dict
        Extra fields attached to the event logged when a point is skipped.

    Returns
    -------
    xarray.Dataset or None
        The chip, or ``None`` when the window does not lie fully inside ``ds``;
        in that case an ``{"type": "IndexError", **metadata}`` event is logged
        on the current Dask worker.
    """
    src_crs = "EPSG:4326"
    transform = Affine(*ds.attrs['transform'])
    ([x], [y]) = fiona.transform.transform(
        src_crs, ds.attrs['crs'], [lon], [lat]
    )
    x_idx, y_idx = [round(coord) for coord in ~transform * (x, y)]

    # stop = start + chip_size so the window is exactly chip_size pixels even
    # when chip_size is odd (int(chip_size/2) on both sides dropped a pixel).
    x_start = x_idx - chip_size // 2
    y_start = y_idx - chip_size // 2
    x_stop = x_start + chip_size
    y_stop = y_start + chip_size

    # Explicit bounds check: negative start indices do NOT raise IndexError
    # under positional indexing; they wrap to the opposite edge of the tile and
    # would silently produce a garbage chip for points near the left/top edge.
    if x_start < 0 or y_start < 0 or x_stop > ds.sizes['x'] or y_stop > ds.sizes['y']:
        get_worker().log_event("message", {"type": "IndexError", **metadata})
        return None
    return ds.isel(x=slice(x_start, x_stop), y=slice(y_start, y_stop))
        

In [7]:
import dask

def chip_tile_year(
    job_id, job_df, chip_size, bands, account_name, storage_container, account_key
):
    """Chip every point of one (tile, year) job out of its HLS zarr and write each chip.

    Parameters
    ----------
    job_id : tuple
        ``(tile, year)`` group key; selects the input store
        ``az://{storage_container}/hls/{year}/{tile}.zarr``.
    job_df : pandas.DataFrame
        Points of this job; each row must provide ``lat``, ``lon``, ``INDEX``,
        ``tile`` and ``year``.
    chip_size : int
        Chip width/height in pixels; also used as the x/y chunk size of each
        written chip.
    bands : iterable
        Band descriptors; their ``.name`` attributes select the variables read.
    account_name, storage_container, account_key :
        Azure Blob Storage account, container and key used for input and output.

    Returns
    -------
    tuple
        ``job_id``, unchanged.
    """
    def sample_and_write(tl, row):
        # Chip one point and, if it lies fully inside the tile, write it to
        # its own zarr store.
        sample = chip(
            tl,
            row['lat'],
            row['lon'],
            chip_size,
            metadata={'index': row['INDEX'], 'tile': row['tile'], 'year': row['year']}
        )
        # chip() returns None for points whose window leaves the tile.
        if sample is None:
            return
        # NOTE(review): the output name omits the year; if the same INDEX can
        # occur in several years for one tile, later jobs overwrite earlier chips.
        output_zarr = fsspec.get_mapper(
            f"az://{storage_container}/hls-testing/chips/{int(row['INDEX'])}-{row['tile']}.zarr",
            account_name=account_name,
            account_key=account_key
        )
        # One chunk per chip in x/y (previously hard-coded to 32).
        sample.chunk({'month': 12, 'x': chip_size, 'y': chip_size}).to_zarr(output_zarr, mode='w')

    band_names = [band.name for band in bands]
    tile, year = job_id
    # NOTE(review): job keys in this notebook carry float years (e.g. 2015.0),
    # so this formats as ".../hls/2015.0/..." — confirm it matches the stores.
    input_zarr = fsspec.get_mapper(
        f"az://{storage_container}/hls/{year}/{tile}.zarr",
        account_name=account_name,
        account_key=account_key
    )
    # Persist the selected bands once for the whole tile, presumably so the
    # per-point chips don't each re-read from blob storage.
    ds = xr.open_zarr(input_zarr)[band_names].persist()
    for _, row in job_df.iterrows():
        sample_and_write(ds, row)
    return job_id
    

In [8]:
# Run chip_tile_year over every western (tile, year) job on a Dask cluster;
# jobs already recorded in checkpoint_path are skipped (see log output).
compute.process_jobs(
    jobs=west_jobs,
    job_fn=chip_tile_year,
    checkpoint_path=checkpoint_path,
    logger=logger,
    cluster_args=cluster_args,
    code_path=code_path,
    concurrency=6,  # number of jobs run at once
    cluster_restart_freq=40,  # restart the cluster after this many jobs
    # chip_tile_year kwargs
    bands=pt_catalog.xr_ds.attrs['bands'],
    chip_size=32,
    account_name="usfs",
    storage_container="fia",
    account_key=os.environ['AZURE_ACCOUNT_KEY'],
)

2021-01-29 02:53:18,431 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDH', 2015.0)
2021-01-29 02:53:18,432 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDH', 2016.0)
2021-01-29 02:53:18,432 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDH', 2017.0)
2021-01-29 02:53:18,434 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDH', 2018.0)
2021-01-29 02:53:18,434 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDJ', 2015.0)
2021-01-29 02:53:18,435 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDJ', 2016.0)
2021-01-29 02:53:18,435 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDJ', 2017.0)
2021-01-29 02:53:18,436 [INFO] hls-point-sampling - Skipping checkpointed job ('10SDJ', 2018.0)
2021-01-29 02:53:18,436 [INFO] hls-point-sampling - Skipping checkpointed job ('10SEF', 2015.0)
2021-01-29 02:53:18,436 [INFO] hls-point-sampling - Skipping checkpointed job ('10SEF', 2016.0)
2021-01-29 02:53:18,437 [INFO] hls-point

KeyboardInterrupt: 