In [None]:
import numpy as np
import seaborn as sns
from IPython.display import JSON
from matplotlib import pyplot as plt
from odc.emit import (
    emit_load,
    mk_error_plot,
    open_zict_json,
    prep_s3_fs,
    review_gcp_sample,
    gxy,
)

In [None]:
from dask.distributed import Client
from dask.distributed import progress as dask_progress

# Create the Dask client only once: re-running this cell reuses the `client`
# that is already bound in the notebook namespace instead of starting another.
if "client" not in locals():
    client = Client(n_workers=1, threads_per_worker=None)
# Bare expression so the notebook renders the client summary.
client

In [None]:
# Filesystem handle that is later passed to `emit_load` when opening granule data
# (presumably an S3 filesystem, going by the helper's name — confirm in odc.emit).
fs = prep_s3_fs()

In [None]:
# Both archives are opened read-only ("r"); further down each one is indexed
# with the same granule id.
stacs = open_zict_json("Data/emit-stac.zip", "r")
samples = open_zict_json("Data/emit-xyz-samples.zip", "r")

# Iterating the samples store yields the ids used to index it.
all_ids = [*samples]
len(all_ids), all_ids[:5]

### Load XYZ Samples

In [None]:
# Granule under review; uncomment one of the alternatives to inspect another scene.
granule = "EMIT_L2A_RFL_001_20230316T045133_2307503_005"  # AU
# granule = "EMIT_L2A_RFL_001_20230531T133036_2315109_002"  # Gibraltar
# granule = "EMIT_L2A_RFL_001_20230804T142809_2321610_001"  # South America
# granule = all_ids[20_106]

# The XYZ sample and the STAC document are looked up under the same id.
sample, stac_doc = samples[granule], stacs[granule]
print(f"Granule {granule}, {sample['shape'][1]}x{sample['shape'][0]}px")

# Open the granule with chunks of 32 along "y".
xx = emit_load(stac_doc, fs, chunks=dict(y=32))

display(JSON(sample))
_ = review_gcp_sample(sample, s=90, figsize=(7, 5))

### GCP Error Analysis: whole image

In [None]:
from odc.emit.gcps import gcp_geobox, gcp_sample_error, rio_gcp_transformer, sub_sample

# Sub-sample the GCPs: `nside` is forwarded to sub_sample, and 4 * nside is
# taken off the overall budget of `n_total` points (never going below zero).
nside, n_total = 7, 100
_sample = sub_sample(sample, max(0, n_total - 4 * nside), nside=nside)
display(JSON(_sample))

# Two transforms built from the same sub-sample: an ODC geobox and a rasterio
# GCP transformer.
gbox = gcp_geobox(_sample)
rio_gcp_tr = rio_gcp_transformer(_sample)

# Passed as `max_err_axis` to both plots below.
err_max = 2.1

# Errors are evaluated against the full `sample`, not only the sub-sampled points.
rr1 = mk_error_plot(
    gcp_sample_error(sample, rio_gcp_tr),
    msg="RIO",
    max_err_axis=err_max,
)
rr2 = mk_error_plot(
    gcp_sample_error(sample, gbox),
    msg="ODC",
    max_err_axis=err_max,
)
_ = review_gcp_sample(_sample)

display(rr1.ee.std(0), rr2.ee.std(0))

# Disabled by default: switch to True to also evaluate `gbox.approx`.
if False:
    rr3 = mk_error_plot(gcp_sample_error(sample, gbox.approx), msg="APPROX")
    display(rr3.ee.std(0))

In [None]:
# Footprints (EPSG:4326) of the loaded dataset's geobox and of the GCP-derived geobox.
fp_loaded = xx.elev.odc.geobox.footprint(4326)
fp_gcp = gbox.footprint(4326)

# First output: the two outlines combined with `|`; second: loaded minus GCP footprint.
display(
    fp_loaded.exterior | fp_gcp.exterior,
    fp_loaded - fp_gcp,
)

In [None]:
# Show the raw STAC document of the selected granule as an interactive JSON tree.
JSON(stac_doc)

### Review GLT_X/Y

In [None]:
# Start computing the lazily loaded dataset on the Dask cluster and keep a handle
# to the persisted result; the progress display is this cell's output.
_xx = client.persist(xx)
dask_progress(_xx)

In [None]:
%%time
# Materialize the persisted dataset locally; the cells below read its `.data` arrays.
yy0 = _xx.compute()

In [None]:
# Pixel size is taken from element [1] of the ortho geotransform attribute.
pix_sz = yy0.attrs["ortho_geotransform"][1]

# GLT lookup grids as int32 so they can be used as array indices below.
gx = yy0.glt_x.data.astype("int32")
gy = yy0.glt_y.data.astype("int32")

# Mask of cells where both indices are non-zero (zero is treated as "no lookup").
mm = (gx != 0) & (gy != 0)

In [None]:
# Prepend one NaN row and one NaN column to lon/lat, then index with the GLT
# grids: index 0 lands on the NaN padding, index k lands on original element k - 1.
_lon, _lat = [
    np.pad(
        coord.data,
        ((1, 0), (1, 0)),
        mode="constant",
        constant_values=float("nan"),
    )[gy, gx]
    for coord in (yy0.lon, yy0.lat)
]

# The looked-up grid must have the same shape as the GLT index grid.
assert _lon.shape == yy0.glt_x.shape

In [None]:
# Difference between the looked-up lon/lat and the ortho grid coordinates:
# ortho_x is laid out as a single row, ortho_y as a single column, and both
# broadcast against the 2-D lookup result.
ex = _lon - yy0.ortho_x.data.reshape(1, -1)
ey = _lat - yy0.ortho_y.data.reshape(-1, 1)

# Error magnitude over the whole grid, in units of pix_sz.
ee = np.sqrt(np.square(ex) + np.square(ey)) / pix_sz

# Per-axis errors in the same units, restricted to the cells selected by `mm`.
px, py = ex[mm] / pix_sz, ey[mm] / pix_sz

In [None]:
# Scatter of every 10th per-axis error pair, viewed in a [-2, 2] x [-2, 2] window.
plt.scatter(px[::10], py[::10], s=0.1)
plt.axis([-2, 2, -2, 2])

# Yellow cross-hair marks zero error.
plt.vlines([0], -2, 2, colors="y")
plt.hlines([0], -2, 2, colors="y")

# Black cross-hair marks the mean error along each axis; the trailing ';'
# keeps the returned artist out of the cell output.
plt.vlines([np.mean(px)], -2, 2, colors="k")
plt.hlines([np.mean(py)], -2, 2, colors="k");

In [None]:
# 2-D density estimate of the decimated (every 10th) per-axis errors, drawn in
# the same [-2, 2] x [-2, 2] window as the scatter plot.
sns.kdeplot(
    {"x": px[::10], "y": py[::10]},
    x="x",
    y="y",
    gridsize=30,
    levels=100,
    fill=True,
    cmap="viridis",
)
plt.axis([-2, 2, -2, 2])

# Yellow cross-hair marks zero error, black cross-hair marks the mean error.
plt.vlines([0], -2, 2, "y")
plt.hlines([0], -2, 2, "y")
plt.vlines([np.mean(px)], -2, 2, "k")
# Trailing ';' stops the notebook from echoing the returned artist's repr
# above the figure.
plt.hlines([np.mean(py)], -2, 2, "k");

In [None]:
# Summary of the lookup error, all in units of pix_sz and ignoring NaN cells:
# (median magnitude, mean magnitude, mean x offset, mean y offset).
np.nanmedian(ee), np.nanmean(ee), np.nanmean(ex) / pix_sz, np.nanmean(ey) / pix_sz

In [None]:
# Show the ortho geotransform of the loaded dataset; `pix_sz` above is read from
# element [1] of the attribute with the same name on the computed copy.
xx.ortho_geotransform

In [None]:
# Same decimated errors with a fixed offset of +1/5 in x and +2/5 in y
# (units of pix_sz) applied before plotting.
# NOTE(review): the offsets look like a hand-picked trial correction for the
# bias seen in the previous plots — confirm where 1/5 and 2/5 come from.
_px = px[::10] + 1 / 5
_py = py[::10] + 2 / 5

sns.kdeplot(
    {"x": _px, "y": _py},
    x="x",
    y="y",
    gridsize=30,
    levels=100,
    fill=True,
    cmap="viridis",
)
plt.axis([-2, 2, -2, 2])

# Yellow cross-hair marks zero error, black cross-hair marks the shifted mean.
plt.vlines([0], -2, 2, "y")
plt.hlines([0], -2, 2, "y")
plt.vlines([np.mean(_px)], -2, 2, "k")
# Trailing ';' stops the notebook from echoing the returned artist's repr
# above the figure.
plt.hlines([np.mean(_py)], -2, 2, "k");

----------------------------------------------

### GCP Error Analysis: sub-image

In [None]:
from odc.emit.gcps import to_pandas

# Tabular view of the GCP sample. It gets its own names on purpose: `xx` and
# `_xx` hold the xarray dataset used by the cells above, and rebinding them to
# DataFrames here would break re-running any of those cells.
gcp_df = to_pandas(sample)

# Crop window size as (rows, cols); author's alternative noted: sample["shape"][1]
ny, nx = 256, 256

# Keep only the points with row < ny and col < nx.
gcp_df_crop = gcp_df[(gcp_df.row < ny) & (gcp_df.col < nx)]

# Rebuild a sample dict for the crop: new id, the crop shape, and every
# DataFrame column converted back to a plain list.
cropped_sample = {
    "id": gcp_df.attrs["id"] + "_cropped",
    "shape": (ny, nx),
    **{k: v.tolist() for k, v in gcp_df_crop.items()},
}

# Fit both transforms on a 100-point sub-sample of the crop...
_csample = sub_sample(cropped_sample, 100)
gbox = gcp_geobox(_csample)
rio_gcp_tr = rio_gcp_transformer(_csample)

err_max = 2.1

# ...and evaluate them against every point of the crop.
rr1 = gcp_sample_error(cropped_sample, rio_gcp_tr)
rr2 = gcp_sample_error(cropped_sample, gbox)

fig, ax = review_gcp_sample(_csample)

In [None]:
# Standard deviation along axis 0 of the error arrays on the cropped sample:
# rr1 comes from the rasterio GCP transformer, rr2 from the ODC geobox.
rr1.ee.std(0), rr2.ee.std(0)

### GDAL GCP with Z

In [None]:
from odc.emit.gcps import extract_rio_gcps
from rasterio.transform import GCPTransformer

# Build rasterio GCPs from the sub-sample with skip_z=False and show the
# count plus the first three.
gcps = extract_rio_gcps(_sample, skip_z=False)
display(len(gcps), gcps[:3])

tr = GCPTransformer(gcps)
x0, y0 = tr.xy(0, 0)

# Map a point offset by 0.01 in y from pixel (0, 0) back to row/col for several
# z values; the identity `op` returns the transformer's values unmodified.
[
    (tr.rowcol(x0, y0 + 0.01, zs=z, op=lambda v: v), z)
    for z in (None, 0, 100, 1000)
]

In [None]:
# Same experiment as the previous cell, but the GCPs are extracted with
# skip_z=True.
gcps = extract_rio_gcps(_sample, skip_z=True)
display(len(gcps), gcps[:3])

tr = GCPTransformer(gcps)
x0, y0 = tr.xy(0, 0)

# Row/col of a point offset by 0.01 in y from pixel (0, 0) for several z values;
# the identity `op` returns the transformer's values unmodified.
[
    (tr.rowcol(x0, y0 + 0.01, zs=z, op=lambda v: v), z)
    for z in (None, 0, 100, 1000)
]