In [None]:
import warnings

import dask
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ptitprince as pt
import seaborn as sns
import xarray as xr
from scipy import stats

import tams

## Load dataset

In [None]:
# WRF brightness temperature (`tb`) and rain rate (`rainrate`) files for June 2010.
# NOTE(review): absolute GLADE path — only readable on NCAR systems; consider a configurable DATA_DIR.
DATA_GLOB = "/glade/campaign/mmm/c3we/prein/SouthAmerica/MCS-Tracking/WY2011/WRF/tb_rainrate_2010-06*.nc"
ds0 = xr.open_mfdataset(DATA_GLOB, combine="nested", concat_dim="time", parallel=True)

ds = (
    ds0
    # Short names used by the rest of the notebook (`ds.ctt`, `ds.pr`)
    .rename({"rainrate": "pr", "tb": "ctt"})
    .rename_dims({"rlat": "y", "rlon": "x"})
    # Skip the first time step — NOTE(review): reason is not documented here; confirm
    .isel(time=slice(1, None))
)
# Wrap longitudes into [-180, 180)
ds = ds.assign_coords(lon=(ds["lon"] + 180) % 360 - 180)

ds

## Identify cloud elements (CEs)

In [None]:
# Identify cloud elements (CEs) from the cloud-top temperature field.
# Thresholds are in the units of `ctt` (presumably K — confirm): per the parameter
# names, 241 bounds a CE and 225 marks its cold core.
ce_lists, _ = tams.identify(
    ds["ctt"],
    ctt_threshold=241,
    ctt_core_threshold=225,
    parallel=True,
)

In [None]:
# Quick look at the CEs identified at time index 5 (an arbitrary example step).
# The trailing `;` suppresses the Axes/Text repr so only the figure is shown.
ax = ce_lists[5].plot()
ax.set_title("CEs at time index 5");

In [None]:
# Sanity-check `tams.overlap`: compare the first CE at time index 5 with a copy of
# itself shifted by 0.1 in x (in the geometry's coordinate units).
# `d` (first CE at time index 6) is kept for later ad-hoc comparisons; it is unused here.
c, d = (ce_lists[i].iloc[:1] for i in (5, 6))
tams.overlap(c, c.translate(xoff=0.1))

## Track CEs through time

In [None]:
# Link the per-time-step CE lists into tracked systems; `overlap_threshold=0.5`
# controls how much CEs at consecutive steps must overlap to be linked
# (exact definition per `tams.track`).
ce = tams.track(
    ce_lists,
    ds["time"].values,
    overlap_threshold=0.5,
)
ce

In [None]:
# Plot only the tracks through 2010-06-01 07:00 to keep the figure readable
tams.plot_tracked(
    ce.query("time <= '2010-06-01 07'"),
    size=20,
)

## Classify tracked systems as MCSs (criteria still to be confirmed)

### Add precipitation statistics to each CE

In [None]:
# For every time step, aggregate `pr` (mean, max, pixel count) inside each CE
# present at that step, merge the stats onto the CE rows, and reassemble the table.
# NOTE(review): rebinding `ce` makes this cell non-idempotent — re-running it on the
# already-augmented frame may duplicate the precip columns.
ce = pd.concat(
    [
        tams.data_in_contours(
            ds.pr.sel(time=t),
            ce[ce["time"] == t],
            merge=True,
            agg=("mean", "max", "count"),
        )
        for t in ds.time.values
    ]
)
ce

In [None]:
def meets_for_n_consecutive(ok: pd.Series, n: int) -> bool:
    """Return True if the boolean Series ``ok`` is True for ``n`` consecutive entries.

    Consecutiveness is positional: this assumes ``ok`` has one entry per time step,
    in time order, with no missing steps (e.g. hourly data indexed by ``itime``).
    """
    # A length-n window of 0/1 values sums to n only when every entry in it is True.
    # (The previous ``.rolling(n, min_periods=0).count().gt(n)`` could never be True:
    # ``count`` counts non-null entries, which is at most n.)
    return bool(ok.astype(int).rolling(n).sum().ge(n).any())


def classify_mcs(g: pd.DataFrame, n: int = 5) -> bool:
    """Decide whether one tracked system (all CE rows sharing an ``mcs_id``) is an MCS.

    Parameters
    ----------
    g
        CE rows for a single ``mcs_id``; uses columns ``time``, ``itime``,
        ``area_km2``, ``pr_max`` and ``pr_mean``.
    n
        Number of consecutive time steps the area and peak-precip criteria must
        hold for. With instantaneous hourly samples, 5 steps span 4 continuous hours.

    Returns
    -------
    bool
        True only if every criterion below is met.
    """
    # TODO: collect the reason(s) a system fails

    # An n-step run is impossible with fewer than n distinct times. (The previous
    # ``duration < n hours`` check demanded n + 1 hourly samples, stricter than the
    # n-step windows used below.)
    if g["time"].nunique() < n:
        return False

    by_step = g.groupby("itime")

    # 1. Area criterion: total CE area >= 40,000 km² for n consecutive steps
    area = by_step["area_km2"].sum()
    if not meets_for_n_consecutive(area >= 40_000, n):
        return False

    # 2. Peak precip criterion: max pixel precip over the system's CEs >= 10
    #    (units of `pr`, presumably mm/h) for n consecutive steps
    maxpr = by_step["pr_max"].max()
    if not meets_for_n_consecutive(maxpr >= 10, n):
        return False

    # 3. Rainfall volume criterion: sum over CEs of (area × mean rate) >= 20,000
    #    at one or more steps. `assign` avoids writing into the groupby slice.
    prvol = (
        g.assign(prvol=g["area_km2"] * g["pr_mean"])
        .groupby("itime")["prvol"]
        .sum()
    )
    if not (prvol >= 20_000).any():
        return False

    # 4. Overshoot criterion: currently treated as met for all systems
    return True


# One flag per system, in `groupby` order (sorted, non-null `mcs_id` values)
is_mcs_list = [classify_mcs(g) for _, g in ce.groupby("mcs_id")]

In [None]:
# Attach each system's MCS flag to all of its CE rows.
# `is_mcs_list` follows `ce.groupby("mcs_id")` order (sorted, non-null ids), so key the
# flags by those ids instead of assuming the ids are exactly 0..N-1.
mcs_ids = ce.groupby("mcs_id").size().index
assert len(mcs_ids) == len(is_mcs_list), "is_mcs_list is not aligned with the mcs_id groups"
is_mcs = pd.Series(is_mcs_list, index=mcs_ids, name="is_mcs")

# `assign` overwrites any existing `is_mcs` column, so re-running this cell is safe
# (a second `merge` would have produced `is_mcs_x` / `is_mcs_y`). CEs without an
# `mcs_id` get NaN, as with the previous left merge.
ce = ce.assign(is_mcs=ce["mcs_id"].map(is_mcs))
ce