# Satellite surface radiation budget

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply the "seaborn-v0_8-notebook" Matplotlib style sheet to every figure below.
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Variables to analyse. The names are lower-case because the download loop
# lower-cases every data variable; each plot loop skips products lacking one.
variables = tuple("srs sis sdl sol".split())

## Define requests

In [None]:
# Dataset identifier passed to every download.download_and_transform call below.
collection_id = "satellite-surface-radiation-budget"
# Passed as ``chunks`` to download.download_and_transform; presumably splits the
# requests into one-year pieces (TODO confirm against that helper).
chunks = {"year": 1}
# One request per product, keyed by the label used in plot legends and titles.
# "start"/"stop" (YYYY-MM) are not request fields: the download loop pops them
# and expands them into dates with download.update_request_date.
# NOTE(review): the "cci" products carry no "version" key — confirm the
# server-side default is the intended one.
request_dict = {
    "CLARA C3S": {
        "start": "2016-01",
        "stop": "2017-12",
        "climate_data_record_type": "thematic_climate_data_record",
        "format": "zip",
        "origin": "c3s",
        "product_family": "clara",
        "sensor_on_satellite": "avhrr_on_multiple_satellites",
        "time_aggregation": "monthly_mean",
        "variable": "surface_upwelling_shortwave_flux",
        "version": "v2_0_1",
    },
    "CLARA EUMETSAT": {
        "start": "2014-01",
        "stop": "2017-12",
        "climate_data_record_type": "thematic_climate_data_record",
        "format": "zip",
        "origin": "eumetsat",
        "product_family": "clara",
        "sensor_on_satellite": "avhrr_on_multiple_satellites",
        "time_aggregation": "monthly_mean",
        "variable": [
            "surface_downwelling_longwave_flux",
            "surface_downwelling_shortwave_flux",
            "surface_upwelling_longwave_flux",
        ],
        "version": "v2_0",
    },
    "ESA ENVISAT": {
        "start": "2007-01",
        "stop": "2010-12",
        "climate_data_record_type": "thematic_climate_data_record",
        "format": "zip",
        "origin": "esa",
        "product_family": "cci",
        "sensor_on_satellite": "aatsr_on_envisat",
        "time_aggregation": "monthly_mean",
        "variable": "all_variables",
    },
    "ESA ERS2": {
        "start": "2000-01",
        "stop": "2002-12",
        "climate_data_record_type": "thematic_climate_data_record",
        "format": "zip",
        "origin": "esa",
        "product_family": "cci",
        "sensor_on_satellite": "atsr2_on_ers2",
        "time_aggregation": "monthly_mean",
        "variable": "all_variables",
    },
    "Sentinel 3A": {
        "start": "2019-01",
        "stop": "2020-12",
        "climate_data_record_type": "interim_climate_data_record",
        "format": "zip",
        "origin": "c3s",
        "product_family": "cci",
        "sensor_on_satellite": "slstr_on_sentinel_3a_is_under_investigation",
        "time_aggregation": "monthly_mean",
        "variable": "all_variables",
    },
    "Sentinel 3B": {
        "start": "2019-01",
        "stop": "2020-12",
        "climate_data_record_type": "interim_climate_data_record",
        "format": "zip",
        "origin": "c3s",
        "product_family": "cci",
        "sensor_on_satellite": "slstr_on_sentinel_3b_is_under_investigation",
        "time_aggregation": "monthly_mean",
        "variable": "all_variables",
    },
}

## Functions to cache

In [None]:
def convert_source_to_time(ds):
    """Replace a ``source`` dimension with a datetime ``time`` dimension.

    If ``ds`` has no ``source`` dimension it is returned unchanged.
    Otherwise, the second-to-last ``_``-separated field of each ``source``
    label is read as a ``YYYYMM`` token. This follows from the ``+ "15"`` and
    ``%Y%m%d`` parsing below. NOTE(review): confirm against the actual file
    names.

    The token is turned into a timestamp on the 15th of that month. The
    timestamps become the ``time`` index coordinate (with
    ``standard_name="time"``), and ``source`` is dropped.

    The input dataset is not modified; a new dataset is returned.
    """
    if "source" not in ds.dims:
        return ds
    # Presumably mid-month so each monthly mean is stamped inside its month.
    times = pd.to_datetime(
        [source.split("_")[-2] + "15" for source in ds["source"].values],
        format="%Y%m%d",
    )
    # assign_coords returns a new object, so the caller's dataset is untouched.
    ds = ds.assign_coords(time=("source", times, {"standard_name": "time"}))
    # ``Dataset.drop`` is deprecated; ``drop_vars`` is the supported equivalent.
    return ds.swap_dims(source="time").drop_vars("source")


def spatial_weighted_mean(ds, lon_slice, lat_slice):
    """Reduce ``ds`` with ``diagnostics.spatial_weighted_mean`` over a lon/lat box.

    The dataset is first normalised with ``convert_source_to_time``. It is
    then cropped to ``lon_slice``/``lat_slice`` via ``utils.regionalise``
    before being reduced.
    """
    regional = utils.regionalise(
        convert_source_to_time(ds), lon_slice=lon_slice, lat_slice=lat_slice
    )
    return diagnostics.spatial_weighted_mean(regional)


def time_weighted_mean(ds):
    """Return ``diagnostics.time_weighted_mean`` of ``ds`` after normalising its time axis."""
    return diagnostics.time_weighted_mean(convert_source_to_time(ds))

## Download and transform

In [None]:
# Download every product twice: once reduced to a time-mean map, once reduced
# to a spatial-mean time series. Variable names are lower-cased so the plot
# loops can match them against ``variables``.
ds_maps = {}
ds_timeseries = {}
for product, request in request_dict.items():
    print(product)
    # Pop from a shallow copy. Popping from ``request`` itself would strip
    # "start"/"stop" from ``request_dict``, so re-running this cell would
    # raise KeyError.
    request = dict(request)
    start = request.pop("start")
    stop = request.pop("stop")
    requests = download.update_request_date(
        request, start=start, stop=stop, stringify_dates=True
    )

    # Maps
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=time_weighted_mean,
        chunks=chunks,
        transform_chunks=False,
        drop_variables="time_bounds",
    )
    # Keep the period on the dataset so the map titles can show it.
    ds.attrs.update({"start": start, "stop": stop})
    ds_maps[product] = ds.rename({var: var.lower() for var in ds.data_vars})

    # Timeseries
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=spatial_weighted_mean,
        transform_func_kwargs={
            "lon_slice": slice(-180, 180),
            "lat_slice": slice(-90, 90),
        },
        chunks=chunks,
        drop_variables="time_bounds",
    )
    ds_timeseries[product] = ds.rename({var: var.lower() for var in ds.data_vars})

## Plot spatially weighted time series

In [None]:
# One figure per variable, with one line per product that provides it.
for var in variables:
    # The colour comes from the product's position in ds_timeseries, so a
    # product keeps the same colour in every figure.
    for colour_index, (product, ds) in enumerate(ds_timeseries.items()):
        if var in ds.data_vars:
            ds[var].plot(label=product, color=f"C{colour_index}")
    plt.legend(bbox_to_anchor=(1, 1))
    plt.grid()
    plt.show()

## Plot time-weighted mean maps

In [None]:
# One map per (variable, product). All maps of a variable share one colour
# range so products can be compared directly.
for var in variables:
    # Products that actually provide this variable, in ds_maps order.
    products_with_var = {
        product: ds for product, ds in ds_maps.items() if var in ds.data_vars
    }
    if not products_with_var:
        # Nothing to plot; min()/max() of an empty sequence would raise ValueError.
        continue
    # Shared, integer-rounded colour limits across all products for this variable.
    vmin = np.floor(min(ds[var].min().values for ds in products_with_var.values()))
    vmax = np.ceil(max(ds[var].max().values for ds in products_with_var.values()))
    for product, ds in products_with_var.items():
        plot.projected_map(
            ds[var],
            projection=ccrs.Robinson(),
            levels=11,
            vmin=vmin,
            vmax=vmax,
            cmap="RdBu_r",
        )
        plt.title(f"{product} ({ds.attrs['start']}, {ds.attrs['stop']})")
        plt.show()

## Plot spatially weighted zonal means

In [None]:
# One figure per variable: the time-mean map of each product averaged over
# longitude and plotted against latitude.
for var in variables:
    for colour_index, (product, ds) in enumerate(ds_maps.items()):
        if var in ds.data_vars:
            zonal_mean = diagnostics.spatial_weighted_mean(ds[var], dim="longitude")
            zonal_mean.plot(y="latitude", color=f"C{colour_index}", label=product)
    plt.legend(bbox_to_anchor=(1, 1))
    plt.grid()
    plt.show()