# Satellite outgoing radiation

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply matplotlib's "seaborn-v0_8-notebook" style sheet to every figure drawn below.
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Variable to analyse: "longwave" or "shortwave" outgoing radiation.
variable = "shortwave"
if variable not in ("longwave", "shortwave"):
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    raise ValueError(
        f"variable must be 'longwave' or 'shortwave', got {variable!r}"
    )

# Regions for the spatially weighted time series. Each entry holds the
# keyword arguments (lat/lon bounds) forwarded to `spatial_weighted_mean`.
# Longitudes use the 0-360 convention.
region_slices = {
    "global": {"lat_slice": slice(-90, 90), "lon_slice": slice(0, 360)},
    "northern hemisphere": {"lat_slice": slice(0, 90), "lon_slice": slice(0, 360)},
    "southern hemisphere": {"lat_slice": slice(-90, 0), "lon_slice": slice(0, 360)},
}

## Define requests

In [None]:
collection_id = "satellite-earth-radiation-budget"
# Chunk sizes handed to download.download_and_transform.
chunks = {"year": 1}

# Names under which the selected outgoing flux is stored in the various
# products; each downloaded dataset is expected to contain exactly one of them.
if variable == "longwave":
    varnames = {"olr", "toa_lw_all_mon", "LW_flux"}
else:
    varnames = {"rsf", "toa_sw_all_mon", "SW_flux"}

# One request template per product. "start"/"stop" (YYYY-MM) delimit the
# period and are removed from the template before the remaining keys are
# used as the base request for download.update_request_date.
request_dict = {
    "CERES": dict(
        start="2000-03",
        stop="2024-02",
        product_family="ceres_ebaf",
        climate_data_record_type="thematic_climate_data_record",
        time_aggregation="monthly_mean",
        format="zip",
        origin="nasa",
        variable=f"outgoing_{variable}_radiation",
    ),
    "Sentinel 3A": dict(
        start="2017-01",
        stop="2022-06",
        format="zip",
        origin="c3s",
        sensor_on_satellite="slstr_on_sentinel_3a",
        variable="all_variables",
        product_family="cci",
        time_aggregation="monthly_mean",
        climate_data_record_type="interim_climate_data_record",
    ),
    "Sentinel 3B": dict(
        start="2018-10",
        stop="2022-06",
        format="zip",
        origin="c3s",
        sensor_on_satellite="slstr_on_sentinel_3b",
        variable="all_variables",
        product_family="cci",
        time_aggregation="monthly_mean",
        climate_data_record_type="interim_climate_data_record",
    ),
    "Sentinel 3A_3B": dict(
        start="2018-10",
        stop="2022-06",
        format="zip",
        origin="c3s",
        sensor_on_satellite="slstr_on_sentinel_3a_3b",
        variable="all_variables",
        product_family="cci",
        time_aggregation="monthly_mean",
        climate_data_record_type="interim_climate_data_record",
    ),
    "ESA ENVISAT": dict(
        start="2002-05",
        stop="2012-04",
        format="zip",
        origin="esa",
        product_family="cci",
        climate_data_record_type="thematic_climate_data_record",
        time_aggregation="monthly_mean",
        sensor_on_satellite="aatsr",
        variable="all_variables",
    ),
    "ESA ERS2": dict(
        start="1995-06",
        stop="2002-12",
        format="zip",
        origin="esa",
        product_family="cci",
        climate_data_record_type="thematic_climate_data_record",
        time_aggregation="monthly_mean",
        sensor_on_satellite="atsr2",
        variable="all_variables",
    ),
    "HIRS": dict(
        start="1979-01",
        stop="2024-04",
        format="zip",
        origin="noaa_ncei",
        product_family="hirs",
        climate_data_record_type="thematic_climate_data_record",
        time_aggregation="monthly_mean",
        version="2_7_reprocessed",
        variable=f"outgoing_{variable}_radiation",
    ),
    # NOTE(review): unlike the other templates this one sets no "format" key —
    # confirm that is intended.
    "CLARA_A3": dict(
        start="1979-01",
        stop="2020-12",
        product_family="clara_a3",
        origin="eumetsat",
        variable=f"outgoing_{variable}_radiation",
        climate_data_record_type="thematic_climate_data_record",
        time_aggregation="monthly_mean",
    ),
}

## Functions to cache

In [None]:
def preprocess_time(ds):
    """Make sure ``ds`` ends up with a ``time`` coordinate.

    A ``time`` variable that still carries a ``units`` attribute could not be
    decoded; it is squeezed out of the dataset (``drop=True``), which requires
    that dimension to have length 1. If no ``time`` is left afterwards, a
    scalar ``time`` coordinate is built from the ``time_coverage_start``
    global attribute.

    Presumably applied per file through the ``preprocess`` entry of
    ``xarray_kwargs`` — confirm in ``download.download_and_transform``.

    Raises:
        KeyError: if ``time`` is missing and ``ds.attrs`` has no
            ``time_coverage_start``.
    """
    time_is_undecoded = "time" in ds and "units" in ds["time"].attrs
    if time_is_undecoded:
        ds = ds.squeeze("time", drop=True)

    if "time" in ds:
        return ds

    coverage_start = pd.to_datetime(ds.attrs["time_coverage_start"])
    return ds.assign_coords(time=coverage_start)


def spatial_weighted_mean(ds, lon_slice, lat_slice):
    """Restrict ``ds`` to a lon/lat box and reduce it to a spatial mean.

    The box is cut with ``utils.regionalise``; the reduction is delegated to
    ``diagnostics.spatial_weighted_mean``.
    """
    regional_ds = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)
    return diagnostics.spatial_weighted_mean(regional_ds)

## Download and transform

In [None]:
# Options forwarded to download.download_and_transform for opening the files.
xarray_kwargs = {
    "drop_variables": ["time_bounds", "record_status"],
    "preprocess": preprocess_time,
}

# product -> time-weighted mean (plotted as maps) and
# product -> per-region spatially weighted time series.
da_maps = {}
da_timeseries = {}
for product, request in request_dict.items():
    if product == "HIRS" and variable == "shortwave":
        # NOTE(review): presumably HIRS offers no shortwave product — confirm.
        print(f"{product=} skip")
        continue
    print(f"{product=}")

    # Work on a copy: popping "start"/"stop" from the shared template would
    # mutate request_dict, and re-running this cell would then fail with KeyError.
    request = dict(request)
    start = request.pop("start")
    stop = request.pop("stop")
    requests = download.update_request_date(
        request, start=start, stop=stop, stringify_dates=True
    )

    # Time-weighted mean over the whole requested period.
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=diagnostics.time_weighted_mean,
        chunks=chunks,
        transform_chunks=False,
        **xarray_kwargs,
    )
    # Each product names the flux differently; exactly one candidate must match.
    matches = set(ds.data_vars) & varnames
    if len(matches) != 1:
        raise ValueError(
            f"{product}: expected exactly one of {sorted(varnames)} among data "
            f"variables {list(ds.data_vars)}, found {sorted(matches)}"
        )
    (varname,) = matches
    da = ds[varname]
    # Keep the requested period on the result for the plot titles.
    da.attrs.update({"start": start, "stop": stop})
    da_maps[product] = da

    # One spatially weighted time series per region, stacked along "region".
    dataarrays = []
    for region, slices in region_slices.items():
        ds = download.download_and_transform(
            collection_id,
            requests,
            transform_func=spatial_weighted_mean,
            transform_func_kwargs=slices,
            chunks=chunks,
            **xarray_kwargs,
        )
        dataarrays.append(ds[varname].expand_dims(region=[region]))
    da = xr.concat(dataarrays, "region")
    da_timeseries[product] = da

## Plot spatially weighted time series

In [None]:
# One figure per region, overlaying the time series of every product.
for region_name in region_slices:
    for product_name, timeseries in da_timeseries.items():
        timeseries.sel(region=region_name).plot(label=product_name)
    ax = plt.gca()
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_title(region_name.title())
    ax.grid()
    plt.show()

## Plot time-weighted means

In [None]:
# Contour levels depend only on the analysed variable, so pick them once.
map_levels = range(150, 315, 15) if variable == "longwave" else range(100, 210, 10)

# One Robinson-projection map per product, titled with its requested period.
for product_name, time_mean in da_maps.items():
    plot.projected_map(
        time_mean,
        projection=ccrs.Robinson(),
        levels=map_levels,
        extend="both",
        cmap="RdBu_r",
    )
    period_start = time_mean.attrs["start"]
    period_end = time_mean.attrs["stop"]
    plt.title(f"{product_name} ({period_start}, {period_end})")
    plt.show()

## Plot spatially weighted zonal means

In [None]:
# Longitude-averaged (zonal) profile of each product's time-weighted mean,
# all products overlaid on a single figure with latitude on the y-axis.
for product, da in da_maps.items():
    da = diagnostics.spatial_weighted_mean(da, dim="longitude")
    da.plot(y="latitude", label=product)
plt.legend(bbox_to_anchor=(1, 1))
plt.grid()
# Explicit show, like the other plotting cells, so the figure is rendered
# outside an inline notebook backend too.
plt.show()