# Tracking options

Exploring the impact of some of the tracking options on the sample datasets.

In [None]:
import warnings
from string import ascii_lowercase

import cartopy.crs as ccrs
import matplotlib as mpl
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr

import tams

plt.rcParams.update(
    {
        "axes.formatter.use_mathtext": True,
    }
)

%matplotlib inline

xr.set_options(display_expand_data=False)

## In sample satellite data

In [None]:
# Load the sample satellite dataset bundled with tams (`tb`, presumably
# brightness temperature -- confirm) and display it.
tb = tams.load_example_tb()
tb

### Identify cloud elements (CEs)

The cases here differ only in tracking-stage options, so we only have to run {func}`tams.identify` once.

In [None]:
%%time

# Identify CEs once; every tracking case below reuses `ces`.
# The second value returned by `tams.identify` is not needed here.
ces, _ = tams.identify(tb, parallel=False)
# TODO: parallel not working on WSL with Win dir, getting `TypeError: cannot pickle '_thread.lock' object`

### Run cases

In [None]:
%%time

cases = {}
time = tb.time.values.tolist()

proj = -15

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="forward `look` considered experimental")
    cases["default"] = tams.track(ces, time)
    cases["link=f, norm=min"] = tams.track(ces, time, look="f", overlap_norm="min")
    cases["link=f, norm=max"] = tams.track(ces, time, look="f", overlap_norm="max")
    cases[f"u_proj={proj}"] = tams.track(ces, time, u_projection=proj)
    cases[f"u_proj={proj}, link=f, norm=min"] = tams.track(ces, time, u_projection=proj, look="f", overlap_norm="min")
    cases[f"u_proj={proj}, link=f, norm=max"] = tams.track(ces, time, u_projection=proj, look="f", overlap_norm="max")

for key, ce in cases.items():
    cases[key] = tams.classify(ce)

### Plot

In [None]:
m, n = 2, 3  # subplot grid (rows, cols) for the commented-out per-case figure below
assert len(cases) <= m * n

# Restrict range
# lon_min, lat_min, lon_max, lat_max = -10, 8, -3, 15
# lon_min, lat_min, lon_max, lat_max = 17, 1, 27, 8
# lon_min, lat_min, lon_max, lat_max = 20, 1, 27, 4
lon_min, lat_min, lon_max, lat_max = 30, 10, 38, 16

# Summary table of the options behind each case, parsed from the
# "option=value, option=value" case labels. Options not in the label keep the
# defaults below; a link starting with "f" and no explicit norm is recorded as
# norm "b". (Meaning of the "a"/"b" codes presumably matches the paper table --
# confirm.)
cases_full = []
for key in cases:
    d = {"u_proj": 0, "link": "b", "norm": "a"}
    updates = {}
    if key != "default":
        for kv in key.split(", "):
            k, v = kv.split("=")
            updates[k] = v
        if "norm" not in updates and updates.get("link", "b").startswith("f"):
            d["norm"] = "b"
    d.update(updates)
    cases_full.append(d)
tbl = pd.DataFrame(cases_full)

# "continued" is True when all CEs inside the box share a single MCS ID
# (1 unique ID), False when they are split across 2 IDs.
continued = []
for key, ce in cases.items():
    # Separate name so the subplot column count `n` above is not clobbered
    n_mcs = ce.cx[lon_min:lon_max, lat_min:lat_max].mcs_id.nunique()
    assert n_mcs in {1, 2}, f"{key!r}: expected 1 or 2 MCS IDs in the box, got {n_mcs}"
    continued.append(n_mcs == 1)
tbl = tbl.assign(continued=continued)

# Plot the last case in `cases` with the analysis box outlined.
# NOTE(review): confirm the last case (rather than "default") is the one intended here.
ce_last = list(cases.values())[-1]
tams.plot_tracked(ce_last, label="none", size=3.5, add_colorbar=True, cbar_kwargs=dict(fraction=0.05, shrink=0.3))

ax = plt.gca()
patch = mpl.patches.Polygon(
    [(lon_min, lat_min), (lon_min, lat_max), (lon_max, lat_max), (lon_max, lat_min)],
    ec="orangered",
    fill=False,
    transform=ccrs.PlateCarree(),
)
ax.add_patch(patch)

# proj = ccrs.PlateCarree()

# fig, axs = plt.subplots(m, n, figsize=(8, 5), sharex=True, sharey=True, subplot_kw=dict(projection=proj), constrained_layout=True)

# for i, ((key, ce), ax) in enumerate(zip(cases.items(), axs.flat)):
#     if i == 0:
#         assert key == "default"
#         tams.plot_tracked(ce)

#     tams.plot_tracked(ce.cx[lon_min:lon_max, lat_min:lat_max], ax=ax)
#     ax.text(0.03, 0.98, key, ha="left", va="top", transform=ax.transAxes, size=8)
#     gl = ax.gridlines(draw_labels=True, color="none")
#     if not i < n:
#         gl.top_labels = False
#     if not i % n == 0:
#         gl.left_labels = False

# LaTeX version of the options table, with "continued" rendered as y/n
with open("sat-tracking-options-table.txt", "w", encoding="utf-8") as f:
    f.write(tbl.assign(continued=tbl.continued.map({True: "y", False: "n"})).style.to_latex())

plt.savefig("sat-tracking-options-ces-box.pdf", bbox_inches="tight", pad_inches=0.05, transparent=False)

## In sample MPAS lat/lon dataset

In [None]:
# Sample MPAS lat/lon dataset: skip the first time step and rename the
# variables to the names used below (`tb` -> `ctt`, `precip` -> `pr`).
ds = (
    tams.load_example_mpas()
    .isel(time=slice(1, None))
    .rename_vars(tb="ctt", precip="pr")
)
ds

### Identify CEs using different CTT thresholds

In [None]:
%%time

ctt = ds.ctt

thresh = 235
thresh_core = 219

cases_ces = {}
for delta in [-15, 0, 5]:
    if delta == 0:
        key = rf"default ($T = {thresh}\,\mathrm{{K}}$, $T_\mathrm{{core}} = {thresh_core}\,\mathrm{{K}}$)"
    else:
        s_delta = str(delta) if delta <= 0 else f"+{delta}"
        key = rf"$(T, T_\mathrm{{core}}) {s_delta}$"
    cases_ces[key] = tams.identify(ctt, parallel=True,
        ctt_threshold=thresh + delta, ctt_core_threshold=thresh_core + delta)[0]

### Track for different $u$ projection values

In [None]:
%%time

cases = {}
time = ctt.time.values.tolist()

projs = [-5, -10, -12]
for proj in projs:
    for thresh_key, ces in cases_ces.items():
        key = rf"{thresh_key}, $u_\mathrm{{proj}} = {proj}\,\mathrm{{m}}\,\mathrm{{s}}^{{-1}}$"
        cases[key] = tams.track(ces, time, u_projection=proj)

In [None]:
# List the case labels (each is "<threshold label>, <u-projection label>")
cases.keys()

### Plot

In [None]:
m = len(projs)  # rows: one per u projection
n = 3  # columns: the CTT threshold cases (`cases` is ordered projection-major)
assert m * n == len(cases)

fig, axs = plt.subplots(m, n, figsize=(8.2, 7), sharex=True, sharey=True, constrained_layout=True)

for a, (key, ce), ax in zip(ascii_lowercase, cases.items(), axs.flat):
    # Number of tracked MCSs (separate name so the column count `n` is not clobbered)
    n_mcs = ce.mcs_id.nunique()

    # KDE of the per-MCS maximum area, computed two ways.
    # "CE": largest single CE area within each MCS
    ce_area_max = ce.groupby("mcs_id").area_km2.max()
    ce_area_max.plot.kde(ax=ax, label="CE")

    # "MCS": largest total CE area at any one time step (`itime`) within each MCS
    mcs_area = ce.groupby(["mcs_id", "itime"]).area_km2.sum()
    mcs_area_max = mcs_area.groupby("mcs_id").max()
    mcs_area_max.plot.kde(ax=ax, label="MCS")

    ax.text(0.99, 0.03, f"$N={n_mcs}$", size=9, ha="right", va="bottom", transform=ax.transAxes)

    # Show the case label on two lines: threshold part, then u-projection part
    i = key.index(", $u_")
    l1, l2 = key[:i], key[i + 2:]
    if ax.get_subplotspec().is_first_col():
        xt, yt = 0.02, 0.22
    else:
        xt, yt = 0.1, 0.97
    ax.text(xt, yt, l1, size=10, ha="left", va="top", transform=ax.transAxes)
    ax.text(xt, yt - 0.09, l2, size=10, ha="left", va="top", transform=ax.transAxes)

    ax.set_ylabel("")
    # Panel letter
    ax.text(0.02, 0.97, a, size=12, weight="bold", ha="left", va="top", transform=ax.transAxes)
    if a == "c":
        ax.legend(loc="center right")

fig.supylabel("Density", x=-0.03)
fig.supxlabel("Area [km$^2$]")

# Axes are shared (sharex/sharey), so limits/scale set on the last panel apply to all
ax.set_xlim(xmin=1000)
ax.set_xscale("log")
ax.set_ylim(ymin=0)

fig.savefig("mpas-tracking-options-area-kde.pdf", bbox_inches="tight", pad_inches=0.05, transparent=False)

In [None]:
# 2-D dist, duration and area

m = len(projs)
n = 3
assert m * n == len(cases)

fig, axs = plt.subplots(m, n, figsize=(8.2, 7), sharex=True, sharey=True, constrained_layout=True)

for a, (key, ce), ax in zip(ascii_lowercase, cases.items(), axs.flat):
    n = ce.mcs_id.nunique()
    gb = ce.groupby("mcs_id")
    area = gb.area_km2.max()
    dur = ((gb.time.max() - gb.time.min()).dt.total_seconds() / 3600).rename("duration_h")
    data = pd.concat([area, dur], axis="columns")
    
    # sns.jointplot(x="area_km2", y="duration_h", kind="kde", joint_kws=dict(fill=True), ax=ax, data=data)
    sns.kdeplot(x="area_km2", y="duration_h", fill=True, common_norm=True, clip=(0, None), hue_norm=(0, 0.5), ax=ax, data=data)
    # TODO: ensure same levels?
    # sns.rugplot(x="area_km2", y="duration_h", ax=ax, data=data)
    
    #ax.text(0.99, 0.98, f"$N={n}$", size=9, ha="right", va="top", transform=ax.transAxes)
    ax.text(0.99, 0.03, f"$N={n}$", size=9, ha="right", va="bottom",
            path_effects=[path_effects.withStroke(linewidth=1.5, foreground="w")],
            transform=ax.transAxes)
   
    # ax.set_title(key, size=10)
    i = key.index(", $u_")
    l1, l2 = key[:i], key[i+2:]
    xt, yt = 0.09, 0.975
    kws = dict(size=10, ha="left", va="top", transform=ax.transAxes, path_effects=[path_effects.withStroke(linewidth=1.5, foreground="w")])
    ax.text(xt, yt, l1, **kws)
    ax.text(xt, yt - 0.09, l2, **kws)
    
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.text(0.02, 0.97, f"{a}", size=12, weight="bold", ha="left", va="top", transform=ax.transAxes,
        path_effects=[path_effects.withStroke(linewidth=1.5, foreground="w")],
    )

fig.supylabel("Duration [h]", x=-0.03)
fig.supxlabel("Area [km$^2$]")

ax.set_xlim(xmin=0, xmax=0.5e7)
ax.set_ylim(ymin=0, ymax=30)

fig.savefig("mpas-tracking-options-duration-area-2d-kde.pdf", bbox_inches="tight", pad_inches=0.05, transparent=False)