In [None]:
%reload_ext eradiate
%reload_ext eradiate.notebook.tutorials

import os
from itertools import zip_longest

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr

import eradiate
from eradiate import unit_registry as ureg

# Use Eradiate's "mono" mode; the experiment defined below requests a few
# discrete wavelengths (presumably simulated one at a time — TODO confirm).
eradiate.set_mode("mono")
# Enable progress reporting; what level 1 displays is not visible here —
# presumably a coarse progress indicator (TODO confirm).
eradiate.config.progress = 1

In [None]:
def exp(sza, geometry="spherical_shell"):
    """Build an ``AtmosphereExperiment`` imaging the sky for one solar zenith angle.

    A perspective camera sits 100 km along +x and 1 km above a reference
    point, aimed at the point 10 km directly above that reference point, with
    +z as the "up" direction. The scene holds a molecular atmosphere built from
    the ``ussa_1976`` construct and is lit by a directional source at
    azimuth 180.

    Parameters
    ----------
    sza : quantity
        Solar zenith angle, forwarded unchanged as the illumination ``zenith``.
    geometry : str, optional
        Eradiate scene geometry. For ``"spherical_shell"`` the reference point
        is placed at ``z = EARTH_RADIUS``; any other value uses the coordinate
        origin.

    Returns
    -------
    eradiate.experiments.AtmosphereExperiment
        Experiment with a single 320 x 160 perspective measure at 440, 550 and
        660 nm, using the ``ldsampler`` sampler and 64 samples per pixel.
    """
    if geometry == "spherical_shell":
        # Lift the reference point to z = EARTH_RADIUS, presumably ground level
        # when the planet is centred on the coordinate origin (TODO confirm).
        earth_radius_m = eradiate.constants.EARTH_RADIUS.m_as(ureg.m)
        ground = [0, 0, earth_radius_m] * ureg.m
    else:
        ground = [0, 0, 0] * ureg.m

    camera = {
        "type": "perspective",
        "origin": ground + [100, 0, 1] * ureg.km,
        "target": ground + [0, 0, 10] * ureg.km,
        "up": np.array([0, 0, 1]),
        "film_resolution": (320, 160),
        "spectral_cfg": {"wavelengths": [440, 550, 660] * ureg.nm},
        "sampler": "ldsampler",
        "spp": 64,
    }
    sun = {"type": "directional", "zenith": sza, "azimuth": 180}
    atmosphere = {"type": "molecular", "construct": "ussa_1976"}

    return eradiate.experiments.AtmosphereExperiment(
        geometry=geometry,
        atmosphere=atmosphere,
        illumination=sun,
        measures=camera,
    )

In [None]:
# Render every (geometry, SZA) combination; SZAs of 90 and 95 degrees put the
# sun at and below the horizon. Each run is tagged with a length-1 "geometry"
# dimension so all runs can later be merged by coordinates.
geometries = ["spherical_shell", "plane_parallel"]
solar_zeniths = np.array([0, 30, 60, 85, 90, 95]) * ureg.deg

result = []
for geometry in geometries:
    for sza in solar_zeniths:
        print(f"{geometry = }, {sza = :~}")
        experiment = exp(sza, geometry)
        # spp given here presumably takes precedence over the measure's own
        # spp=64 (TODO confirm against eradiate.run documentation).
        ds = eradiate.run(experiment, spp=256)
        labelled = ds.expand_dims({"geometry": [geometry]})
        result.append(labelled.copy())

In [None]:
# Merge the per-run datasets into one dataset arranged by their shared
# coordinates (including the "geometry" dimension added to each run);
# combine_attrs="drop" leaves the merged dataset without attributes.
ds = xr.combine_by_coords(result, combine_attrs="drop")
ds

In [None]:
# Plot one grid of RGB renderings per geometry (one panel per SZA) and save
# each grid to plots/below_horizon_<geometry>.png.
# The 320 x 160 film is twice as wide as it is tall, so the panels are laid
# out as 3 rows x 2 columns to fit the 8 x 6.5 inch figure.
nrows = 3
ncols = 2

for geometry in ds.geometry.values:
    print(f"{geometry = }")
    # plt.subplots takes (nrows, ncols) in that order. squeeze=False keeps
    # `axs` a 2-D array so `axs.flat` works for any grid size, including 1 x 1.
    fig, axs = plt.subplots(nrows, ncols, figsize=(8, 6.5), squeeze=False)

    # zip_longest pads the shorter sequence with None: a None SZA marks a
    # spare panel, a None axis means the grid is too small for all SZAs.
    for sza, ax in zip_longest(ds.sza.values, axs.flat, fillvalue=None):
        if ax is None:
            plt.close(fig)  # don't leak the partially drawn figure
            raise RuntimeError(f"Not enough axes to continue plotting, stopping at {sza = }")

        ax.axis("off")

        if sza is None:
            continue

        # Channel order (660, 550, 440 nm) presumably maps to R, G, B.
        ax.imshow(
            eradiate.xarray.interp.dataarray_to_rgb(
                ds.radiance.sel(geometry=geometry, sza=sza),
                channels=[("w", 660), ("w", 550), ("w", 440)],
            ),
        )
        # Coordinate values are unitless; attach degrees (presumably the stored
        # unit — TODO confirm) so the title shows the unit.
        sza *= ureg.deg
        ax.set_title(f"{sza = :~}")

    fname = f"plots/below_horizon_{geometry}.png"
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    # Save and close through the figure handle rather than pyplot's
    # "current figure" state.
    fig.savefig(fname, bbox_inches="tight")
    plt.show()
    plt.close(fig)