# Cloud Type Determination by Altitude for Analyzing Cloud Radiative Effects

## Import libraries

In [None]:
import collections

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regionmask
import xarray as xr
import xhistogram
from c3s_eqc_automatic_quality_control import download, utils

plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Analysis period, given as "YYYY-MM" strings.
start, stop = "2011-01", "2011-12"

# Latitude bands for each region. Every value is a list of slices so that a
# region can be made of several bands (e.g. one per hemisphere).
latitude_slices = {
    "deep-tropics": [slice(-10, 10)],
    "sub-tropics": [slice(-30, -10), slice(10, 30)],
    "mid-latitudes": [slice(-60, -30), slice(30, 60)],
    "polar": [slice(-80, -60), slice(60, 80)],
}
# Guard against a bare slice being passed where a list of slices is expected.
assert all(isinstance(bands, list) for bands in latitude_slices.values())

# Line style used for each curve, keyed by its legend label.
plot_kwargs = {
    "CCI day": dict(color="dodgerblue", ls="--"),
    "CCI night": dict(color="darkblue", ls=":"),
    "CLARA-A2": dict(color="orange", ls="-"),
    "CLARA-A3": dict(color="red", ls="-"),
}

## Define request

In [None]:
collection_id = "satellite-cloud-properties"

# Settings shared by the two CLARA editions; only "product_family" differs.
request_clara = dict(
    origin="eumetsat",
    variable=["cloud_top_level"],
    climate_data_record_type="thematic_climate_data_record",
    time_aggregation="daily_mean",
)

# One base request per product, keyed by the product label used in the plots.
requests = {
    "CCI": dict(
        product_family="cci",
        time_aggregation="daily_mean",
        climate_data_record_type="thematic_climate_data_record",
        sensor_on_satellite="aatsr_on_envisat",
        variable="all_variables",
        origin="esa",
    ),
    "CLARA-A2": {**request_clara, "product_family": "clara_a2"},
    "CLARA-A3": {**request_clara, "product_family": "clara_a3"},
}

# Let the download helper add the date information for the start..stop period
# to every base request.
dated_requests = {}
for product, request in requests.items():
    dated_requests[product] = download.update_request_date(
        request, start, stop, stringify_dates=True
    )
requests = dated_requests

# Chunking passed to download.download_and_transform.
chunks = dict(year=1, month=1)

## Functions to cache

In [None]:
def preprocess(ds):
    """Give ``ds`` a length-one ``time`` coordinate taken from its metadata.

    The timestamp is parsed from the ``time_coverage_start`` global attribute
    after every "Z" character has been removed from it.
    """
    coverage_start = ds.attrs["time_coverage_start"]
    timestamp = pd.to_datetime(coverage_start.replace("Z", ""))
    return ds.assign_coords(time=[timestamp])


def get_variables(ds, variables):
    """Return ``ds[variables]``, i.e. only the requested variable(s) of ``ds``."""
    selection = ds[variables]
    return selection


def mask_and_regionalise(ds, lat_slices):
    """Restrict ``ds`` to latitude bands and split it into land and ocean.

    Parameters
    ----------
    ds
        Dataset to process. When it contains a ``record_status`` variable,
        only entries where that variable equals 0 are kept.
    lat_slices
        Iterable of latitude slices; the bands are extracted over longitudes
        -180..180 and combined by coordinates.

    Returns
    -------
    The regionalised dataset with a new leading ``mask`` dimension holding
    the labels "land" and "ocean" (Natural Earth v5.0.0 ``land_110`` mask).
    """
    record_status = ds.get("record_status")
    if record_status is not None:
        # The condition is computed eagerly before it is used with drop=True.
        is_valid = (record_status == 0).compute()
        ds = ds.where(is_valid, drop=True)

    bands = []
    for lat_slice in lat_slices:
        band = utils.regionalise(ds, lon_slice=slice(-180, 180), lat_slice=lat_slice)
        bands.append(band)
    ds = xr.combine_by_coords(bands)

    land_regions = regionmask.defined_regions.natural_earth_v5_0_0.land_110
    is_land = land_regions.mask_3D(ds).squeeze("region", drop=True)

    # Everything that is not flagged as land is treated as ocean.
    surfaces = {"land": is_land, "ocean": ~is_land}
    labelled = [
        ds.where(keep).expand_dims(mask=[surface]) for surface, keep in surfaces.items()
    ]
    return xr.concat(labelled, "mask")


def compute_cloud_top_distribution(ds, lat_slices):
    """Histogram every ``ctp*`` variable of ``ds`` over time and space.

    The dataset is first passed through ``mask_and_regionalise``. Each data
    variable whose name starts with "ctp" is then binned into 50 equal bins
    between 0 and 1000 and reduced over ``time``, ``latitude`` and
    ``longitude``. Counts are normalised so that each value is the percentage
    of samples in the bin divided by the bin width.

    Parameters
    ----------
    ds
        Dataset holding the cloud top pressure variables.
    lat_slices
        Latitude slices forwarded to ``mask_and_regionalise``.

    Returns
    -------
    Dataset with one normalised histogram per ``ctp*`` variable along a
    shared ``bin`` coordinate.

    Raises
    ------
    ValueError
        If ``ds`` has no data variable starting with "ctp".
    """
    ds = mask_and_regionalise(ds, lat_slices)

    # Histogram layout is loop-invariant: 50 bins of 20 hPa from 0 to 1000 hPa.
    n_bins = 50
    bins = np.linspace(0, 1_000, n_bins + 1)

    dataarrays = []
    for var, da in ds.data_vars.items():
        if not var.startswith("ctp"):
            continue
        da = xhistogram.xarray.histogram(
            da,
            bins=bins,
            dim=["time", "latitude", "longitude"],
        )
        da = da.rename(var, **{f"{var}_bin": "bin"})
        # fraction * 100 % / (1000 hPa / n_bins) == fraction * n_bins / 10,
        # i.e. percent of samples per hPa.
        da = da / da.sum("bin") * n_bins / 10
        da.attrs = {"long_name": "Normalized Density", "units": "% hPa$^{-1}$"}
        dataarrays.append(da)

    if not dataarrays:
        # Fail with a clear message instead of a KeyError on the "bin" lookup.
        raise ValueError("No data variable starting with 'ctp' found in dataset")

    merged = xr.merge(dataarrays)
    merged["bin"].attrs = {"long_name": "Cloud Top Pressure", "units": "hPa"}
    return merged


def compute_cloud_top_density(ds):
    """Return the density histogram of every ``ctp*`` variable of ``ds``.

    Each data variable whose name starts with "ctp" is binned with unit-width
    bins whose edges run from 0 to 1010, using ``density=True``. No ``dim`` is
    passed to the histogram call. The results are merged into one dataset
    sharing a ``bin`` coordinate.
    """
    edges = np.arange(0, 1_011)
    densities = []
    for name, values in ds.data_vars.items():
        if name.startswith("ctp"):
            density = xhistogram.xarray.histogram(values, density=True, bins=edges)
            density = density.rename(name, **{f"{name}_bin": "bin"})
            density.attrs = {"long_name": "Density", "units": "hPa$^{-1}$"}
            densities.append(density)
    merged = xr.merge(densities)
    merged["bin"].attrs = {"long_name": "Cloud Top Pressure", "units": "hPa"}
    return merged


def unc_distribution(ctp, unc):
    """Histogram ``unc`` separately for low and high clouds, in percent.

    Parameters
    ----------
    ctp
        Cloud top pressure used to classify the samples: values above 700 are
        labelled "low" clouds, values below 200 are labelled "high" clouds.
    unc
        Named DataArray with the uncertainty to be binned.

    Returns
    -------
    DataArray named like ``unc`` with a ``clouds`` dimension ("low", "high")
    and a ``bin`` dimension of 50 equal bins between 0 and 150. Values are the
    percentage of samples per bin.
    """
    low = unc.where(ctp > 700).expand_dims(clouds=["low"])
    high = unc.where(ctp < 200).expand_dims(clouds=["high"])
    by_cloud_type = xr.concat([low, high], "clouds")

    n_bins = 50
    edges = np.linspace(0, 150, n_bins + 1)
    counts = xhistogram.xarray.histogram(
        by_cloud_type, bins=edges, dim=["time", "latitude", "longitude"]
    )
    counts = counts.rename(unc.name, **{f"{unc.name}_bin": "bin"})

    # Convert counts to the percentage of samples falling in each bin.
    percent = counts / counts.sum("bin") * 100
    percent["bin"].attrs = {"long_name": "Cloud top pressure uncertainty"}
    percent.attrs = {"long_name": "Density", "units": "%"}
    return percent


def compute_uncertainty_distribution(ds, lat_slices):
    """Compute uncertainty histograms for the cloud top pressure in ``ds``.

    The dataset is first passed through ``mask_and_regionalise``. When it has
    a ``ctp`` variable, ``ctp_unc01`` is binned against it. Otherwise the
    ``unc`` and ``cor`` variables of both ``ctp_day`` and ``ctp_night`` are
    binned against their own pressure variable. All histograms are merged
    into a single dataset.
    """
    ds = mask_and_regionalise(ds, lat_slices)

    # (pressure variable, uncertainty variable) pairs to process.
    if "ctp" in ds:
        pairs = [("ctp", "ctp_unc01")]
    else:
        pairs = [
            (f"ctp_{prefix}", f"ctp_{prefix}_{suffix}")
            for prefix in ("day", "night")
            for suffix in ("unc", "cor")
        ]

    dataarrays = [
        unc_distribution(ds[ctp_name], ds[unc_name]) for ctp_name, unc_name in pairs
    ]
    assert dataarrays
    return xr.merge(dataarrays)

## Show CLARA-A3 quality

In [None]:
# Retrieve only the "record_status" variable of CLARA-A3.
status_kwargs = dict(
    preprocess=preprocess,
    chunks=chunks,
    transform_chunks=False,
    transform_func=get_variables,
    transform_func_kwargs={"variables": ["record_status"]},
)
ds = download.download_and_transform(
    collection_id, requests["CLARA-A3"], **status_kwargs
)
da = ds["record_status"]

# Report how many entries carry each of the status values 0, 1 and 2.
for status in range(3):
    n_records = (da == status).sum().values
    print(f"Number of records with status {status}: ", n_records)

## Download and transform

In [None]:
# Compute the cloud top distribution of every product in every region.
regional_distributions = collections.defaultdict(list)
for product, request in requests.items():
    for region, lat_slices in latitude_slices.items():
        print(f"{product=} {region=}")
        ds = download.download_and_transform(
            collection_id,
            request,
            preprocess=preprocess,
            chunks=chunks,
            transform_chunks=False,
            transform_func=compute_cloud_top_distribution,
            transform_func_kwargs={"lat_slices": lat_slices},
        )
        # Label the result with its region so results can be stacked below.
        labelled = ds.expand_dims(region=[region])
        regional_distributions[product].append(labelled)

# One dataset per product, with all regions along the "region" dimension.
distributions = {}
for product, per_region in regional_distributions.items():
    distributions[product] = xr.concat(per_region, "region")

In [None]:
# Compute the uncertainty distribution of every product except CLARA-A2, in
# every region.
regional_uncertainties = collections.defaultdict(list)
for product, request in requests.items():
    if product != "CLARA-A2":
        for region, lat_slices in latitude_slices.items():
            print(f"{product=} {region=}")
            ds = download.download_and_transform(
                collection_id,
                request,
                preprocess=preprocess,
                chunks=chunks,
                transform_chunks=False,
                transform_func=compute_uncertainty_distribution,
                transform_func_kwargs={"lat_slices": lat_slices},
            )
            labelled = ds.expand_dims(region=[region])
            regional_uncertainties[product].append(labelled)

# One dataset per product, with all regions along the "region" dimension.
# NOTE: the (misspelled) name is kept because other cells refer to it.
uncertanties = {}
for product, per_region in regional_uncertainties.items():
    uncertanties[product] = xr.concat(per_region, "region")

## Cloud top pressure

In [None]:
# Centre ("level") and half-width ("width") of the gray band drawn in each
# regional panel, on the same axis as the pressure bins (presumably hPa).
tropopause_level = {
    "deep-tropics": dict(level=100, width=10),
    "sub-tropics": dict(level=100, width=10),
    "mid-latitudes": dict(level=200, width=50),
    "polar": dict(level=250, width=40),
}

# One figure per surface type, one panel per region.
for mask in ("ocean", "land"):
    fig, axs = plt.subplots(1, len(latitude_slices), figsize=(14, 5), sharey=True)
    for ax, region in zip(axs, latitude_slices):
        for product, ds in distributions.items():
            # A product provides either a single "ctp" variable or separate
            # day/night ones; plot whichever of them are present.
            for suffix in ("", "day", "night"):
                if suffix:
                    name, label = f"ctp_{suffix}", f"{product} {suffix}"
                else:
                    name, label = "ctp", product
                if name in ds:
                    da = ds[name].sel(mask=mask, region=region)
                    da.plot(
                        ax=ax,
                        y="bin",
                        yincrease=False,
                        label=label,
                        **plot_kwargs[label],
                    )

        # Shade the band level +/- width across the current x-range.
        level = tropopause_level[region]["level"]
        width = tropopause_level[region]["width"]
        lower, upper = level - width, level + width
        ax.fill_between(
            ax.get_xlim(),
            [lower, lower],
            [upper, upper],
            color="gray",
            alpha=0.5,
        )

        # Panel title and surface-type annotation.
        ax.set_title(region)
        annotation_style = dict(fontsize=12, weight="bold", ha="center", va="center")
        ax.text(0.8, 0.96, mask, transform=ax.transAxes, **annotation_style)
    axs[0].legend()

## Cloud top pressure uncertainty

In [None]:
# Density histograms of all CCI "ctp*" variables, computed by
# compute_cloud_top_density on the downloaded data.
ds = download.download_and_transform(
    collection_id,
    requests["CCI"],
    chunks=chunks,
    preprocess=preprocess,
    transform_chunks=False,
    transform_func=compute_cloud_top_density,
)
# Only the "ctp_day_cor" histogram is shown here.
da = ds["ctp_day_cor"]
# Bar width of 1 matches the unit-width bins used for the histogram.
plt.bar(da["bin"], da, width=1)
# Limit the y-range before switching to a logarithmic axis.
plt.ylim(0.001, 1)
plt.yscale("log")
plt.xlabel("CCI cloud top pressure uncertainty during daytime [hPa]")
# Assigning to "_" keeps the returned Text object out of the cell output.
_ = plt.ylabel("Density [hPa$^{-1}$]")

In [None]:
# Variable to plot and legend label for each supported product.
# Alternative for CCI: "ctp_night_cor".
uncertainty_variables = {
    "CCI": ("ctp_night_unc", "CCI night"),
    "CLARA-A3": ("ctp_unc01", "CLARA-A3"),
}

# One figure per surface type and cloud class, one panel per region.
for mask in ["ocean", "land"]:
    for clouds in ["high", "low"]:
        fig, axs = plt.subplots(1, len(latitude_slices), figsize=(14, 3.5), sharey=True)
        for ax, region in zip(axs, latitude_slices):
            for product, ds in uncertanties.items():
                try:
                    name, label = uncertainty_variables[product]
                except KeyError:
                    # Same error as before for products without a mapping.
                    raise NotImplementedError(f"{product=}") from None

                da = ds[name].sel(mask=mask, region=region, clouds=clouds)
                da.plot(ax=ax, x="bin", label=label, **plot_kwargs[label])
            # Labels
            ax.set_title(region)
            ax.text(
                0.82,
                0.88,
                mask,
                transform=ax.transAxes,
                fontsize=12,
                weight="bold",
                ha="center",
                va="center",
            )
            ax.text(
                0.78,
                0.96,
                clouds + " clouds",
                transform=ax.transAxes,
                fontsize=12,
                color="b",
                ha="center",
                va="center",
            )
        # Build the legend once per figure, after all panels are drawn,
        # instead of rebuilding it for every region.
        axs[0].legend()