In [None]:
import warnings
from pathlib import Path

# import dask
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# import ptitprince as pt
# import seaborn as sns
import xarray as xr
# from scipy import stats

import tams
from tams.mosa import BASE_DIR, load_wrf

## Load dataset

In [None]:
wrf_dir = BASE_DIR / "WY2011/WRF"
all_files = [fp for fp in wrf_dir.glob("*.nc")]
# Expect one file per hour over 365 days
assert len(all_files) == 365 * 24, "file for each hour"

In [None]:
# Use only a subset of the files for faster testing:
# all hourly files for September 2010, then just the first 3 days of them
sep_files = sorted((BASE_DIR / "WY2011/WRF").glob("tb_rainrate_2010-09-??_??:??.nc"))
assert len(sep_files) == 30 * 24  # one file per hour for each of the 30 days
files = sep_files[: 3 * 24]

print(files[0], "...", files[-1], sep="\n")

In [None]:
%%time

# Load the selected WRF files into `ds` with `tams.mosa.load_wrf`
# (later cells use its `ctt`, `pr`, and `time` variables)
ds = load_wrf(files)
ds

In [None]:
ctt0 = ds["ctt"].isel(time=0)  # first time step only
ctt0.plot(x="lon", y="lat")

## Identify CEs

In [None]:
%%time

ce_lists, _ = tams.identify(ds.ctt, parallel=True, ctt_threshold=241, ctt_core_threshold=225)

In [None]:
# CEs identified at time index 5
ces_t5 = ce_lists[5]
ces_t5.plot()

In [None]:
# Sanity check of `tams.overlap`: overlap of a single CE (first row at time index 5)
# with a copy of itself shifted by 0.1 in x (the geometry's x units,
# presumably degrees longitude -- confirm).
# (The unused `d = ce_lists[6].iloc[0:1]` from the original cell was removed.)
c = ce_lists[5].iloc[0:1]
tams.overlap(c, c.translate(xoff=0.1))

## Track

In [None]:
%%time

ce = tams.track(ce_lists, ds.time.values, overlap_threshold=0.5)
ce

In [None]:
# Plot only the first few hours of tracked CEs to keep the figure readable
first_hours = ce.query("time <= '2010-09-01 07'")
tams.plot_tracked(first_hours, size=12)

## Classify tracks (MCS or not?)

### First add precip

In [None]:
%%time

# Serial (non-parallel) version of the precip aggregation done with joblib in the
# next cell, kept commented out for reference. With the code commented out, this
# cell does nothing. For each time, it computes mean/max/count of `pr` within
# that time's CE contours and merges the results onto the CE rows.

# dfs = []
# for t in ds.time.values:
#     df = tams.data_in_contours(ds.pr.sel(time=t), ce.query("time == @t"),
#                                merge=True,
#                               agg=("mean", "max", "count"),)
#     dfs.append(df)
# ce0 = pd.concat(dfs)
# ce0.head(3)

In [None]:
%%time

def _agg_one(ds_t, g):
    df = tams.data_in_contours(ds_t, g, merge=True, agg=("mean", "max", "count"))
    return df
    

dfs = joblib.Parallel(n_jobs=-2, verbose=10, batch_size="auto")(
    joblib.delayed(_agg_one)(ds.pr.sel(time=t), g)
    for t, g in ce.drop(columns=["mean_pr", "max_pr", "count_pr"], errors="ignore").groupby("time")
)

ce = pd.concat(dfs)
ce.head(3)

### Now apply criteria

In [None]:
%%time

n = ce.mcs_id.max() + 1
is_mcs_list = [None] * n
reason_list = [None] * n
for mcs_id, g in ce.groupby("mcs_id"):
    # Compute time
    t = g.time.unique()
    tmin = t.min()
    tmax = t.max()
    duration = pd.Timedelta(tmax - tmin)
    
    # TODO: collect reasons

    # Assuming instantaneous times, need 5 h for the 4 continuous h criteria
    # but for accumulated (during previous time step), 4 is fine(?)
    n = 4
    if duration < pd.Timedelta(f"{n}H"):
        is_mcs_list[mcs_id] = False
        reason_list[mcs_id] = "duration"
        continue

    # Sum area over cloud elements
    area = g.groupby("itime")["area_km2"].sum()
    
    # 1. Assess area criterion
    # NOTE: rolling usage assuming data is hourly
    yes = (area >= 40_000).rolling(n, min_periods=0).count().eq(n).any()
    if not yes:
        is_mcs_list[mcs_id] = False
        reason_list[mcs_id] = "area"
        continue

    # Agg min precip over cloud elements
    maxpr = g.groupby("itime")["max_pr"].max()
    
    # 2. Assess minimum pixel-peak precip criterion
    yes = (maxpr >= 10).rolling(n, min_periods=0).count().eq(n).any()
    if not yes:
        is_mcs_list[mcs_id] = False
        reason_list[mcs_id] = "peak precip"
        continue
    
    # Compute rainfall volume
    g["prvol"] = g.area_km2 * g.mean_pr  # per CE
    prvol = g.groupby("itime")["prvol"].sum()
    
    # 3. Assess minimum rainfall volume criterion
    yes = (prvol >= 20_000).sum() >= 1
    if not yes:
        is_mcs_list[mcs_id] = False
        reason_list[mcs_id] = "rainfall volume"
        continue
    
    # 4. Overshoot threshold currently met for all due to TAMS approach
    
    # If make it to here, is MCS
    is_mcs_list[mcs_id] = True
    reason_list[mcs_id] = ""
    
assert len(is_mcs_list) == len(reason_list) == ce.mcs_id.max() + 1
assert not any(x is None for x in is_mcs_list)
assert not any(x is None for x in reason_list)
assert (ce.query("is_mcs == True").not_is_mcs_reason == "").all()
assert (ce.query("is_mcs == False").not_is_mcs_reason != "").all()
    
ce = ce.drop(columns=["is_mcs"], errors="ignore").merge(
    pd.Series(is_mcs_list, index=range(len(is_mcs_list)), name="is_mcs"),
    how="left", left_on="mcs_id", right_index=True,
)
ce = ce.drop(columns=["not_is_mcs_reason"], errors="ignore").merge(
    pd.Series(reason_list, index=range(len(is_mcs_list)), name="not_is_mcs_reason"),
    how="left", left_on="mcs_id", right_index=True,
)
ce.head(3)

In [None]:
ce["is_mcs"].value_counts()

In [None]:
ce["not_is_mcs_reason"].value_counts()

In [None]:
# Inspect the CEs of a single track
g = ce[ce["mcs_id"] == 5]
g

## Save

* We don't need the `*219` columns (`inds219`, `area219_km2`, `cs219`)
* We don't need all CE coordinates, just the centroid (and maybe ellipse parameters, e.g. eccentricity?)
* `itime` and `dtime` can be re-derived

### Clean up table

In [None]:
# Centroid of each CE polygon, computed in a projected equidistant cylindrical CRS
# (EPSG:32663) and converted back to lon/lat (EPSG:4326)
centroids = ce.geometry.to_crs("EPSG:32663").centroid.to_crs("EPSG:4326")

# Ellipse eccentricity of each CE polygon; "ellipse model failed" UserWarnings
# (presumably from tams.calc_ellipse_eccen when the fit fails) are silenced
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=UserWarning, message="ellipse model failed for POLYGON")
    eccen = ce.geometry.apply(tams.calc_ellipse_eccen)

# Output column order
col_order = [
    "time", "lat", "lon",
    "area_km2", "eccen",
    "mcs_id",
    "mean_pr", "max_pr", "count_pr",
    "is_mcs", "not_is_mcs_reason",
]

# Columns not needed in the saved table
unneeded = ["inds219", "area219_km2", "cs219", "itime", "dtime", "geometry"]
ce_ = ce.drop(columns=unneeded).assign(eccen=eccen, lat=centroids.y, lon=centroids.x)
assert set(ce_.columns) == set(col_order)

df = pd.DataFrame(ce_)[col_order]
df

### Choose filepath

In [None]:
# Output directory (e.g. Path("./") for the current directory)
out_dir = Path("/glade/scratch/knocasio/SAAG")

# If True, never overwrite an existing file; append `_1`, `_2`, ... to the stem instead
noclobber = True

# Filename stem from the first and last dataset times: ce_<YYYYmmddHH>-<YYYYmmddHH>
tfmt = "%Y%m%d%H"
ta = pd.Timestamp(ds.time.values[0])
tb = pd.Timestamp(ds.time.values[-1])
ofn_stem_desired = f"ce_{ta:{tfmt}}-{tb:{tfmt}}"
ofp = out_dir / f"{ofn_stem_desired}.csv.gz"

# With noclobber, bump the numeric suffix until the path is free
n_tries = 0
while noclobber and ofp.is_file():
    n_tries += 1
    ofp = out_dir / f"{ofn_stem_desired}_{n_tries}.csv.gz"

print(ofp)

### Write out

In [None]:
# Write the cleaned CE table; compression is inferred from the `.gz` suffix
df.to_csv(path_or_buf=ofp)