# Satellite outgoing radiation

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply matplotlib's bundled "seaborn-v0_8-notebook" style to every figure
# produced further down in this notebook.
plt.style.use(style="seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Variable to analyse: outgoing "longwave" or "shortwave" radiation.
variable = "shortwave"
# Validate explicitly: an ``assert`` would be silently skipped under ``python -O``.
if variable not in ("longwave", "shortwave"):
    raise ValueError(f"{variable=} must be 'longwave' or 'shortwave'")

# (first year, last year) to request for each product; both ends are included
# when the request's year list is built.
year_ranges = {
    "CERES": (2000, 2022),
    "HIRS": (1979, 2022),
    "Sentinel 3A": (2017, 2021),
    "Sentinel 3B": (2018, 2021),
    "ESA ENVISAT": (2002, 2012),
    "ESA ERS2": (1995, 2003),
}

# Regions for the spatially weighted time series.
# NOTE(review): longitudes are expressed on a 0..360 grid; presumably
# utils.regionalise copes with products stored on -180..180 — confirm.
region_slices = {
    "global": {"lat_slice": slice(-90, 90), "lon_slice": slice(0, 360)},
    "northern hemisphere": {"lat_slice": slice(0, 90), "lon_slice": slice(0, 360)},
    "southern hemisphere": {"lat_slice": slice(-90, 0), "lon_slice": slice(0, 360)},
}

## Define request

In [None]:
collection_id = "satellite-earth-radiation-budget"
# Chunking passed to download_and_transform; presumably one chunk per
# requested year — confirm against the download helper.
chunks = {"year": 1}

# Candidate names of the radiation variable across products; the download
# loop keeps whichever single name each dataset actually contains.
if variable == "longwave":
    varnames = {"olr", "toa_lw_all_mon"}
else:
    varnames = {"rsf", "toa_sw_all_mon"}

# Products requested by a single named variable: product -> origin.
_per_variable = {"CERES": "nasa_ceres_ebaf", "HIRS": "noaa_ncei_hirs"}
# Products requested per sensor with every available variable:
# product -> (origin, sensor).
_per_sensor = {
    "Sentinel 3A": ("c3s_cci", "slstr_on_sentinel_3a"),
    "Sentinel 3B": ("c3s_cci", "slstr_on_sentinel_3b"),
    "ESA ENVISAT": ("esa_cloud_cci", "aatsr"),
    "ESA ERS2": ("esa_cloud_cci", "atsr2"),
}

requests = {
    product: {
        "format": "zip",
        "origin": origin,
        "variable": f"outgoing_{variable}_radiation",
    }
    for product, origin in _per_variable.items()
}
requests.update(
    {
        product: {
            "format": "zip",
            "origin": origin,
            "sensor": sensor,
            "variable": "all_available_variables",
        }
        for product, (origin, sensor) in _per_sensor.items()
    }
)
del _per_variable, _per_sensor

# Every product that has a year range must also have a request template.
invalid = set(year_ranges).difference(requests)
assert not invalid, f"{invalid} are invalid"

## Define transform functions

In [None]:
def convert_source_to_time(ds):
    """Replace a per-file ``source`` dimension with a ``time`` dimension.

    When ``ds`` has a ``source`` dimension, a timestamp is parsed from each
    ``source`` label and the labels are replaced by a ``time`` index. The
    timestamp comes from the second-to-last ``_``-separated token, which must
    be ``YYYYMM``, and is pinned to the 15th of that month.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset that may or may not carry a ``source`` dimension.

    Returns
    -------
    xarray.Dataset
        ``ds`` itself when there is no ``source`` dimension. Otherwise, a new
        dataset indexed by ``time``, with the ``source`` coordinate dropped.
        The input dataset is not modified.
    """
    if "source" not in ds.dims:
        return ds

    # "...<YYYYMM>_<last>" + "15" -> YYYYMM15, parsed as year/month/day.
    times = pd.to_datetime(
        [source.split("_")[-2] + "15" for source in ds["source"].values],
        format="%Y%m%d",
    )
    # assign_coords returns a new object, so the caller's dataset is left intact.
    ds = ds.assign_coords(time=("source", times, {"standard_name": "time"}))
    # drop_vars replaces the deprecated Dataset.drop.
    return ds.swap_dims(source="time").drop_vars("source")


def spatial_weighted_mean(ds, lon_slice, lat_slice):
    """Spatially weighted mean of ``ds`` over a longitude/latitude box.

    A ``source`` dimension, if present, is first turned into ``time`` by
    ``convert_source_to_time``. The dataset is then cut to ``lon_slice`` /
    ``lat_slice`` with ``utils.regionalise`` and reduced with
    ``diagnostics.spatial_weighted_mean``.
    """
    regional = utils.regionalise(
        convert_source_to_time(ds),
        lon_slice=lon_slice,
        lat_slice=lat_slice,
    )
    return diagnostics.spatial_weighted_mean(regional)


def time_weighted_mean(ds):
    """Time-weighted mean of ``ds`` via ``diagnostics.time_weighted_mean``.

    A ``source`` dimension, if present, is first turned into ``time`` by
    ``convert_source_to_time``.
    """
    return diagnostics.time_weighted_mean(convert_source_to_time(ds))

## Download and transform

In [None]:
# Per-product results:
# - time-weighted mean maps;
# - spatially weighted time series stacked along a new "region" dimension.
da_maps = {}
da_timeseries = {}
for product, (year_start, year_stop) in year_ranges.items():
    # HIRS is only processed for the longwave variable in this notebook.
    if product == "HIRS" and variable == "shortwave":
        print(f"{product=} skip")
        continue

    print(f"{product=}")
    # Build a fresh request instead of mutating the shared ``requests``
    # template in place. The year/month keys keep their original order.
    request = {
        **requests[product],
        "year": [f"{year}" for year in range(year_start, year_stop + 1)],
        "month": [f"{month:02d}" for month in range(1, 13)],
    }

    # Time-weighted mean map over the whole requested period.
    ds = download.download_and_transform(
        collection_id,
        request,
        transform_func=time_weighted_mean,
        chunks=chunks,
        transform_chunks=False,
    )
    # Exactly one candidate variable name must be present in this product.
    matches = set(ds.data_vars) & varnames
    if len(matches) != 1:
        raise ValueError(
            f"{product=}: expected exactly one of {sorted(varnames)}, "
            f"found {sorted(matches)}"
        )
    (varname,) = matches
    da_maps[product] = ds[varname]

    # One spatially weighted time series per region, concatenated on "region".
    dataarrays = []
    for region, slices in region_slices.items():
        ds = download.download_and_transform(
            collection_id,
            request,
            transform_func=spatial_weighted_mean,
            transform_func_kwargs=slices,
            chunks=chunks,
        )
        dataarrays.append(ds[varname].expand_dims(region=[region]))
    da_timeseries[product] = xr.concat(dataarrays, "region")

## Plot spatially weighted time series

In [None]:
for region in region_slices:
    # One line per product for this region, all on the same axes.
    selections = (
        (name, series.sel(region=region)) for name, series in da_timeseries.items()
    )
    for product, da_region in selections:
        da_region.plot(label=product)
    plt.grid()
    plt.title(region.title())
    plt.legend(bbox_to_anchor=(1, 1))
    plt.show()

## Plot time-weighted mean maps

In [None]:
# Contour levels depend only on the analysed variable, so pick them once.
if variable == "longwave":
    map_levels = range(150, 315, 15)
else:
    map_levels = range(100, 210, 10)

for product, da in da_maps.items():
    plot.projected_map(
        da,
        projection=ccrs.Robinson(),
        levels=map_levels,
        extend="both",
        cmap="RdBu_r",
    )
    plt.title(product)
    plt.show()

## Plot spatially weighted zonal means

In [None]:
# Zonal means: reduce each time-mean map over the longitude dimension and plot
# the resulting latitude profiles of all products on shared axes.
for product, da in da_maps.items():
    # Use a separate name rather than rebinding the loop variable.
    da_zonal = diagnostics.spatial_weighted_mean(da, dim="longitude")
    da_zonal.plot(y="latitude", label=product)
plt.legend(bbox_to_anchor=(1, 1))
plt.grid()
# Render explicitly, like the other plotting cells, so the figure also shows
# when this code runs outside an inline notebook backend.
plt.show()