In [None]:
from pathlib import Path

import dask.array as da
import numpy as np
import s3fs
from dask import delayed
from distributed import Client
from numcodecs import JSON, Zstd
from tqdm.auto import tqdm
from utils.txt import from_njson

from odc.emit import (
    SampleLoader,
    fetch_s3_creds,
    gen_sample,
    open_zict_json,
    review_gcp_sample,
    stac_store,
)

stacs_path = Path("/tmp/emit.zip")
stac_njson = Path("Data/emit-stac.njson.gz")

# Build the local id -> STAC document archive from the gzipped newline-JSON
# dump; generation is skipped when the archive file already exists.
if not stacs_path.exists():
    print(f"Generating cache: {stacs_path}")
    # Reuse `stacs_path` rather than repeating the literal so the existence
    # guard and the writer always refer to the same file.
    stacs = open_zict_json(str(stacs_path), "w")
    try:
        stacs.update((doc["id"], doc) for doc in tqdm(from_njson(stac_njson)))
    except BaseException:
        # A half-written archive would pass the `exists()` guard on the next
        # run and be mistaken for a complete cache, so remove it. BaseException
        # also covers interrupting this long-running loop from the notebook.
        stacs.close()
        stacs_path.unlink(missing_ok=True)
        raise
    stacs.close()

In [None]:
creds = fetch_s3_creds()

# Session options for s3fs built from the fetched temporary credentials.
s3_opts = {
    "key": creds["accessKeyId"],
    "secret": creds["secretAccessKey"],
    "token": creds["sessionToken"],
    "anon": False,
}

fs = s3fs.S3FileSystem(**s3_opts)
# Compatibility shim: `s3fs < 2023.10.0` exposes `protocol` as a list;
# normalize it to a tuple.
if isinstance(fs.protocol, list):
    fs.protocol = tuple(fs.protocol)

In [None]:
n = 5000  # number of sample points requested from `gen_sample`
pts = gen_sample(
    n,
    pad="auto",
)

In [None]:
# Candidate granules for the single-granule check, keyed by region.
# Switch regions by changing the key instead of (un)commenting lines.
GRANULES = {
    "AU": "EMIT_L2A_RFL_001_20230316T045133_2307503_005",
    "Gibraltar": "EMIT_L2A_RFL_001_20230531T133036_2315109_002",
    "South America": "EMIT_L2A_RFL_001_20230804T142809_2321610_001",
}
granule = GRANULES["AU"]

sampler = SampleLoader(pts, s3=fs)

sample = sampler.get(granule)
sample["shape"]

In [None]:
# Review plot for the sample loaded above; presumably returns a figure and
# its axes (dict-like, judging by the name) — confirm against odc.emit.
review = review_gcp_sample(sample)
fig, axd = review

In [None]:
# Local dask cluster: 16 workers, one thread each.
client = Client(threads_per_worker=1, n_workers=16)
client

In [None]:
# Every entry of the STAC store as an object array, split into blocks of 10
# so each dask task samples 10 granules.
stac_ids = np.asarray(list(stac_store()), dtype="O")
samples = da.from_array(stac_ids, chunks=(10,))
samples

In [None]:
def _extractor(_ids, sampler):
    xx = [sampler.get(_id) for _id in _ids]
    return np.asarray(xx, dtype="O")

# Construct the SampleLoader lazily inside the dask graph. It gets its own
# name so the eager `sampler` from the single-granule cell is not clobbered
# (re-running that cell after this one would otherwise get a Delayed).
lazy_sampler = delayed(SampleLoader)(delayed(pts), s3=fs)
# One loaded sample per id in each input block; the output is an object array
# (stated explicitly rather than inferred), regrouped into 50-item chunks.
zzz = samples.map_blocks(_extractor, lazy_sampler, dtype="O").rechunk(50)
zzz

In [None]:

# Zarr array options: Zstd compression plus a JSON codec for the object items.
zarr_opts = dict(compressor=Zstd(), object_codec=JSON())
# `compute=False` only builds the write task; it is executed in the next cell.
rr = zzz.to_zarr("/tmp/emit-xyz.samples.zarr", compute=False, overwrite=True, **zarr_opts)
rr

In [None]:
%%time
_ = client.gather(client.compute(rr))

```python

def _item_stream(zz, step=100):
    N = zz.shape[0]
    chunks = (slice(off, min(off+step, N)) for off in range(0, N, step))
    for roi in chunks:
        yield from zz[roi].copy()

# `zarr` and `ZipStore` are used below but were never imported in this notebook.
import zarr
from zarr.storage import ZipStore

# NOTE(review): the array is written to the directory store
# "/tmp/emit-xyz.samples.zarr" earlier; this presumably reads a zipped copy
# of it — confirm the path.
# Input is only read, so open the store and the array read-only.
store = ZipStore("/tmp/emit-xyz.zarr.zip", mode="r")
zz = zarr.open_array(store=store, mode="r")
samples = _item_stream(zz)

sink = open_zict_json("/tmp/emit-zxy-samples.zip", "w")
try:
    for sample in tqdm(samples, total=zz.shape[0]):
        sink[sample["id"]] = sample
finally:
    # Close the output archive (and release the input zip) even if a write
    # fails partway through.
    sink.close()
    store.close()
```

```python
from tqdm.auto import tqdm

def patch_sample(sample):
    """Return a copy of ``sample`` with a ``"shape"`` entry added.

    The shape is ``(rows, cols)`` = ``(max(row) + 1, max(col) + 1)``. If
    ``sample`` already carries a ``"shape"``, that value wins. The input dict
    is not modified.
    """
    cols, rows = sample["col"], sample["row"]
    ncols = max(cols) + 1
    nrows = max(rows) + 1
    patched = {"shape": (nrows, ncols)}
    patched.update(sample)
    return patched


# NOTE(review): `samples` must be a mapping of sample id -> sample dict here.
# The `samples` bound earlier in this notebook (a dask array, then a
# generator from `_item_stream`) has no `.items()` — confirm which store is meant.
dst = open_zict_json("/tmp/emit-xyz-samples.zip", "w")
try:
    for k, sample in tqdm(samples.items()):
        dst[k] = patch_sample(sample)
finally:
    # Always close the output archive, even if patching a sample raises.
    dst.close()
```
