In [None]:
import os
from pathlib import Path

import eradiate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from eradiate.spectral.response import make_gaussian
from eradiate.pipelines.logic import apply_spectral_response

# Activate Eradiate's "ckd" mode (presumably correlated-k distribution). This is
# done first, before any scene/experiment object below is created.
eradiate.set_mode("ckd")
# Apply Eradiate's default plotting style (assumed to affect all later
# matplotlib figures in this notebook).
eradiate.plot.set_style()

In [None]:
# How to use the RPV data sample
rpv = xr.load_dataset("data/sample_rpv_desert.nc")

# Every RPV parameter is an "interpolated" spec built from one variable of the
# sample dataset. Note that the BSDF's ``g`` parameter is read from the
# dataset's ``theta`` variable.
rpv_bsdf = eradiate.scenes.bsdfs.RPVBSDF(
    **{
        param: {
            "type": "interpolated",
            "construct": "from_dataarray",
            "dataarray": rpv[variable],
        }
        for param, variable in [
            ("rho_0", "rho_0"),
            ("g", "theta"),
            ("k", "k"),
            ("rho_c", "rho_c"),
        ]
    }
)
rpv_bsdf

In [None]:
# How to use the time series sample
from eradiate.units import to_quantity, unit_registry as ureg

timeseries = xr.load_dataset("data/sample_timeseries_prisma.nc")
display(timeseries)

# Solar zenith angle along the series, taken from the first band
first_band = timeseries.isel(band_name=0)
first_band["solar_zenith_angle"].plot()

In [None]:
# How to use the climatology sample: load it and plot its ``aod550`` variable
climatology = xr.load_dataset("data/sample_climatology.nc")
climatology["aod550"].plot()

In [None]:
# Putting it together
def _interpolated(dataarray):
    """Return an Eradiate "interpolated" spec dict built from ``dataarray``."""
    return {
        "type": "interpolated",
        "construct": "from_dataarray",
        "dataarray": dataarray,
    }


def _make_rpv_bsdf(rpv: xr.Dataset):
    """Build an RPV BSDF from the ``rho_0``, ``theta``, ``k`` and ``rho_c`` variables of ``rpv``.

    The BSDF's ``g`` parameter is read from the dataset's ``theta`` variable.
    """
    return eradiate.scenes.bsdfs.RPVBSDF(
        rho_0=_interpolated(rpv.rho_0),
        g=_interpolated(rpv.theta),
        k=_interpolated(rpv.k),
        rho_c=_interpolated(rpv.rho_c),
    )


def make_experiment(
    timestep: xr.Dataset, climatology: xr.Dataset, rpv: xr.Dataset, spectral_range
):
    """Build a plane-parallel atmosphere experiment for one time step.

    Parameters
    ----------
    timestep : xr.Dataset
        Single-time slice of the time series. Its ``solar_zenith_angle``,
        ``solar_azimuth_angle``, ``view_zenith_angle`` and ``view_azimuth_angle``
        variables set the illumination and viewing directions.
    climatology : xr.Dataset
        Climatology matching ``timestep``. It is currently unused, because the
        atmosphere is not configured yet (see TODO below).
    rpv : xr.Dataset
        Surface parameters with ``rho_0``, ``theta``, ``k`` and ``rho_c``
        variables, used to build an RPV BSDF.
    spectral_range : sequence of two items
        ``(wmin, wmax)`` of the uniform spectral response function. The caller
        passes a length-2 pint quantity in nm.

    Returns
    -------
    eradiate.experiments.AtmosphereExperiment
        Experiment with a directional illumination and a single ``mdistant``
        measure whose viewing direction is given in degrees.
    """
    # Configure surface reflectance
    rpv_bsdf = _make_rpv_bsdf(rpv)

    # TODO: Add here atmosphere configuration (use climatology data array)
    atmosphere = None

    # NOTE(review): angles are forwarded as stored in ``timestep``. Confirm that
    # the data's azimuth convention matches the one Eradiate expects.
    viewing_angles = [
        float(to_quantity(timestep.view_zenith_angle).m_as("deg")),
        float(to_quantity(timestep.view_azimuth_angle).m_as("deg")),
    ]

    # Configure the experiment
    return eradiate.experiments.AtmosphereExperiment(
        geometry="plane_parallel",
        atmosphere=atmosphere,
        surface=rpv_bsdf,
        illumination={
            "type": "directional",
            "zenith": to_quantity(timestep.solar_zenith_angle),
            "azimuth": to_quantity(timestep.solar_azimuth_angle),
        },
        measures={
            "type": "mdistant",
            "construct": "from_angles",
            "angles": viewing_angles,
            "srf": {
                "type": "uniform",
                "wmin": spectral_range[0],
                "wmax": spectral_range[1],
            },
        },
    )


# Generate the experiment sequence using the time series, surface parameters and
# climatology: one experiment per acquisition time, keyed by that time.
times = timeseries["time"].values
spectral_range = [400.0, 2400.0] * ureg.nm

exps = {
    t: make_experiment(
        # Illumination/viewing angles are read from the first band only
        timestep=timeseries.isel(band_name=0).sel(time=t),
        # Climatology entry nearest in time, within a 1 h tolerance
        climatology=climatology.sel(time=t, method="nearest", tolerance="1h"),
        rpv=rpv,
        spectral_range=spectral_range,
    )
    for t in times
}

In [None]:
# Run all experiments (with basic caching system): each result is stored under
# results/ as "<time, ':' replaced by '-'>.nc" and reloaded from there on later
# runs instead of being recomputed.
results_dir = Path("results")
os.makedirs(results_dir, exist_ok=True)

fnames = {
    t: np.datetime_as_string(t, unit="s").replace(":", "-")
    for t in timeseries.time.values
}

raw_results = {}
for t in times:
    cache_file = results_dir / f"{fnames[t]}.nc"
    try:
        raw_results[t] = xr.load_dataset(cache_file)
    except FileNotFoundError:
        # Cache miss: run, persist, then reload from disk (presumably so cached
        # and freshly computed results go through the same code path).
        eradiate.run(exps[t]).to_netcdf(cache_file)
        raw_results[t] = xr.load_dataset(cache_file)

# Drop singleton pixel dims, turn the solar angles into data variables, and
# stack all runs along a new ``time`` dimension.
results = xr.concat(
    [
        ds.squeeze(["x_index", "y_index"], drop=True)
        .squeeze(["sza", "saa"])
        .reset_coords(["sza", "saa"])
        .expand_dims({"time": 1})
        .assign_coords({"time": ("time", [t])})
        for t, ds in raw_results.items()
    ],
    dim="time",
)
results

In [None]:
# Apply spectral response functions
# Per-band (central wavelength, width) parameters, read from the time series.
# NOTE(review): confirm ``band_width`` is what ``make_gaussian`` expects as its
# second argument (e.g. FWHM).
srf_params = {
    band_name: (
        float(timeseries["band_wavelength"].sel(band_name=band_name).values),
        float(timeseries["band_width"].sel(band_name=band_name).values),
    )
    for band_name in timeseries.band_name.values
}
srfs = {
    band_name: eradiate.spectral.BandSRF.from_dataarray(
        make_gaussian(band_wavelength, band_width).srf
    )
    for band_name, (band_wavelength, band_width) in srf_params.items()
}


def _apply_band_srfs(da: xr.DataArray) -> xr.DataArray:
    """Apply every SRF in ``srfs`` to ``da`` and stack the results along ``band_name``.

    Each band slice gets a ``band_wavelength`` coordinate (nm). The output is
    sorted by ``time`` and then by ``band_wavelength``.
    """
    per_band = [
        apply_spectral_response(da, srf)
        .expand_dims({"band_name": 1})
        .assign_coords(
            {
                "band_name": ("band_name", [band_name]),
                "band_wavelength": (
                    "band_name",
                    [srf_params[band_name][0]],
                    {"long_name": "band wavelength", "units": "nm"},
                ),
            }
        )
        for band_name, srf in srfs.items()
    ]
    # ``sortby`` takes a *list* of keys. Passing "band_wavelength" as a second
    # positional argument binds it to ``ascending``, and the bands are then
    # never sorted by wavelength.
    return xr.concat(per_band, "band_name").sortby(["time", "band_wavelength"])


datasets = xr.Dataset(
    {var: _apply_band_srfs(results[var]) for var in ["radiance", "irradiance"]}
)
# NOTE(review): radiance / irradiance is a BRDF (sr^-1); a BRF is conventionally
# pi * BRDF. Confirm whether a factor pi is intended here.
datasets["brf"] = datasets["radiance"] / datasets["irradiance"]
datasets

In [None]:
# Band BRF spectra, one step curve per acquisition time, drawn without axes or
# legend and saved as an image.
output_path = Path("../_images/hyperspectral_timeseries.png")
output_path.parent.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(1, 1, figsize=(6, 3))
datasets["brf"].plot.step(
    hue="time", x="band_wavelength", where="mid", add_legend=False, ax=ax
)
ax.axis("off")
fig.savefig(output_path, bbox_inches="tight")
plt.show()
plt.close(fig)

In [None]:
# Display the time series dataset again (a bare last expression is rendered
# with the notebook's rich display).
timeseries

In [None]:
# Artistic: SRF plot
import seaborn as sns

# Long-form table of the sampled SRFs: one row per wavelength sample per band.
# NOTE(review): ``dfs`` is not used by the plotting code in this cell; remove
# it if nothing else needs it.
dfs = pd.concat(
    [
        pd.DataFrame(
            data={
                "wavelength": srf.wavelengths.m,
                "values": srf.values.m,
                "band_name": name,
                "band_wavelength": float(
                    timeseries.sel(band_name=name).band_wavelength
                ),
            }
        )
        for name, srf in srfs.items()
    ]
).reset_index(drop=True)

# Number of bands to draw.
# NOTE(review): the slices below keep ``n_bands + 1`` bands (51 for 50). Confirm
# whether the extra band is intended.
n_bands = 50

# Take the first bands in dataset order, then order them by increasing
# wavelength. Names and wavelengths are sorted together so they stay paired.
# The sort is stable on wavelength, like the original key-based sort.
band_pairs = sorted(
    zip(
        timeseries.band_wavelength.values.tolist()[: n_bands + 1],
        timeseries.band_name.values.tolist()[: n_bands + 1],
    ),
    key=lambda pair: pair[0],
)
band_wavelengths = [wavelength for wavelength, _ in band_pairs]
band_names = [name for _, name in band_pairs]

# Wavelength extent of the drawn SRFs.
# NOTE(review): currently unused; presumably meant for ``ax.set_xlim``.
xmin = srfs[band_names[0]].wavelengths.m.min()
xmax = srfs[band_names[-1]].wavelengths.m.max()

# One colour per band, sampled along a cubehelix colormap.
pal = sns.cubehelix_palette(len(band_wavelengths), rot=-0.25, light=0.7, as_cmap=True)
colors = [pal(i / len(band_wavelengths)) for i in range(len(band_wavelengths))]

fig, ax = plt.subplots(1, 1, figsize=(8, 4 * n_bands / 50), layout="constrained")

# Ridge plot: each SRF is a filled curve on its own baseline, with the shortest
# wavelength at the top. The increasing ``zorder=i`` draws each later ridge over
# the previous ones.
for i, (band_name, band_wavelength, color) in enumerate(
    zip(band_names, band_wavelengths, colors)
):
    srf = srfs[band_name]
    yoffset = (len(band_names) - i) / (n_bands / 5)

    x = srf.wavelengths.m
    y = yoffset + srf.values.m
    ax.plot(x, y, color="w", zorder=i)
    ax.axhline(yoffset, color=color, zorder=i)
    ax.fill_between(x, yoffset, y, color=color, zorder=i)

# Axis styling does not depend on the band, so apply it once after the loop.
ax.set_xlabel("Wavelength [nm]")
ax.set_yticks([])
ax.patch.set_alpha(0)
for spine in ["top", "right", "left"]:
    ax.spines[spine].set_visible(False)

Path("../_images").mkdir(parents=True, exist_ok=True)
fig.savefig("../_images/prisma_bands_visible.png")
plt.show()
plt.close(fig)