# Verification of PMCTRACK against ACCACIA and STARS cyclone datasets

In [None]:
# %matplotlib inline
import cartopy.crs as ccrs
import cf_units
from datetime import datetime
from IPython.display import clear_output
import iris
from ipywidgets import interact
import json
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
import matplotlib.colors as mcol
from matplotlib.ticker import FuncFormatter
import matplotlib.patheffects as PathEffects
# import matplotlib.cm as mcm
import numpy as np
import pandas as pd
from pathlib import Path
import random
import xarray as xr
import string
from tqdm import tqdm_notebook as tqdm

from arke.cart import lcc_map, lcc_map_grid

from common_defs import winters, nyr, winter_dates, datasets, cat_kw, aliases, conf_key_typeset, runs_grid_formatter
from plot_utils import LCC_KW, trans, clev101, abs_plt_kw, iletters, cc
import mypaths
from stars_api import read_tracks_file

from octant.core import TrackRun, OctantTrack, HOUR
from octant.misc import SUBSETS
import octant
octant.__version__

In [None]:
# Apply the project's matplotlib style file for all figures below
# (given as a relative file name, so it is looked up from the working directory).
plt.style.use('paperfig.mplstyle')

In [None]:
# ERA5 land-sea mask; squeeze() drops singleton dimensions.
lsm_path = mypaths.era5_dir / 'lsm.nc'
lsm = xr.open_dataarray(lsm_path).squeeze()
# 2D longitude/latitude grids built from the mask's coordinates.
lon2d, lat2d = np.meshgrid(lsm.longitude, lsm.latitude)

In [None]:
# Track subsets used for matching: every octant SUBSETS entry except the first.
subsets = SUBSETS[1:]  # only PMC and IC

### All PMCTRACK runs, split into two groups

In [None]:
# Display the dataset names imported from common_defs.
datasets

In [None]:
# Sort every PMCTRACK run of every dataset into one of two groups:
#   'vort_thresh' -- runs numbered >= 100 whose parameters contain 'zeta_max0',
#                    plus the empty (control) dict from that range, except for
#                    one specific threshold combination (EXCLUDED_VORT_RUN);
#   'diff_params' -- the control run ({}) numbered < 100 and every run whose
#                    parameters contain neither 'zeta_max0' nor are empty.
# Runs numbered < 100 that contain 'zeta_max0' are not used.
# Run ids count up from the offset encoded in each JSON grid file name.
EXCLUDED_VORT_RUN = {'zeta_max0': 0.0001, 'zeta_min0': 9e-05}


def _run_group_of(run_id, run_dict):
    """Return the RUNS group name for one run, or None if the run is skipped."""
    is_ctrl = not run_dict
    if is_ctrl or 'zeta_max0' in run_dict:
        if run_id >= 100:
            return None if run_dict == EXCLUDED_VORT_RUN else 'vort_thresh'
        return 'diff_params' if is_ctrl else None
    return 'diff_params'


RUNS = {'vort_thresh': {}, 'diff_params': {}}
for dataset in datasets:
    for group in RUNS.values():
        group[dataset] = []
    for run_id_start in (0, 100):
        grid_file = mypaths.trackresdir / f'{dataset}_{run_id_start:03d}_runs_grid.json'
        with grid_file.open('r') as f:
            run_grid = json.load(f)
        for run_id, run_dict in enumerate(run_grid, start=run_id_start):
            group_name = _run_group_of(run_id, run_dict)
            if group_name is not None:
                RUNS[group_name][dataset].append((run_id, run_dict))

## Verification against ACCACIA

### Load ACCACIA PMC tracks

In [None]:
# Read the ACCACIA PMC track file: tab-separated, no header, with columns
# track number N, time (YYYYmmddHHMM), lon and lat.
# The time column is read as text and parsed explicitly, because read_csv's
# `date_parser` argument is deprecated since pandas 2.0.
accacia_pmcs = pd.read_csv(mypaths.acctracks, delimiter='\t', names=['N', 'time', 'lon', 'lat'],
                           dtype={'time': str})
accacia_pmcs['time'] = pd.to_datetime(accacia_pmcs['time'], format='%Y%m%d%H%M')

In [None]:
# Reference tracks: one OctantTrack per ACCACIA track number N, keeping only
# tracks that last at least 6 hours.
acc_tracks = [
    track
    for track in (OctantTrack.from_df(track_df) for _, track_df in accacia_pmcs.groupby('N'))
    if track.lifetime_h >= 6
]
# Number of reference tracks; match counts are normalised by it.
n_ref = len(acc_tracks)
n_ref

In [None]:
# Display the first ACCACIA reference track.
acc_tracks[0]

## Load PMCTRACK data for the ACCACIA period

In [None]:
# Period label (used in figure file names) and the per-run sub-directory
# holding the PMCTRACK output for this period.
period = 'ACCACIA'
winter = 'accacia'

In [None]:
# Load the PMCTRACK output of every run for the ACCACIA period, mirroring the
# RUNS layout: TRACKS[run_group][dataset] is a list of TrackRun objects in the
# same order as RUNS[run_group][dataset].
TRACKS = {}
for group_name, runs_by_dataset in tqdm(RUNS.items(), desc='run_group', leave=False):
    group_tracks = TRACKS[group_name] = {}
    for dset, run_list in tqdm(runs_by_dataset.items(), desc='dataset', leave=False):
        loaded = group_tracks[dset] = []
        for rid, _ in tqdm(run_list, desc='run_id', leave=False):
            run_tracks = TrackRun(mypaths.trackresdir / dset / f'run{rid:03d}' / winter)
            # Categorise the tracks using the land-sea mask and the shared settings.
            run_tracks.categorise(lsm=lsm, **cat_kw)
            loaded.append(run_tracks)
clear_output()

In [None]:
# Track-matching method passed to TrackRun.match_tracks; also part of the figure file names.
method = 'bs2000'

In [None]:
# For every run, count how many ACCACIA reference tracks are matched in each
# subset, and keep the matched pairs of the control run ({}).
# MATCH_RATES[run_group] is a DataFrame with one row per run: a column named
# after the run group holds the row key (vorticity threshold or run number),
# and one '<subset>_<dataset>' column per subset and dataset holds the fraction
# of reference tracks matched.
MATCH_RATES = dict()
MATCH_PAIRS = dict()
for run_group, dataset_dicts in tqdm(RUNS.items(), desc='run_group', leave=False):
    results = dict()
    MATCH_PAIRS[run_group] = dict()
    for dataset, run_dicts in tqdm(dataset_dicts.items(), desc='dataset', leave=False):
        MATCH_PAIRS[run_group][dataset] = dict()
        perf_table = np.zeros((len(run_dicts), len(subsets)), dtype=np.int64)
        for irun, (run_id, run_dict) in tqdm(enumerate(run_dicts), desc='run_id', leave=False):
            for isub, subset in tqdm(enumerate(subsets), desc='subsets', leave=False):
                match_pairs = TRACKS[run_group][dataset][irun].match_tracks(acc_tracks,
                                                                            method=method, beta=50.,
                                                                            subset=subset)
                if len(run_dict) == 0:
                    # save id of matched vortices in the ctrl run ( {} )
                    MATCH_PAIRS[run_group][dataset][subset] = match_pairs
                perf_table[irun, isub] = len(match_pairs)
        if run_group == 'vort_thresh':
            # Runs without an explicit zeta_max0 (the control dict) are placed
            # at 2e-4 -- presumably the default threshold; confirm.
            index = [r[1].get('zeta_max0', 2e-4) for r in run_dicts]
        else:
            index = range(len(run_dicts))
        # Suffix the columns with the dataset name here instead of relying on
        # pd.merge(suffixes=...), which only handles exactly two frames.
        results[dataset] = pd.DataFrame(data=perf_table,
                                        columns=subsets,
                                        index=index).add_suffix(f'_{dataset}')
    # Outer-join all datasets on the row key; works for any number of datasets.
    res_df = None
    for frame in results.values():
        res_df = frame if res_df is None else pd.merge(res_df, frame, how='outer',
                                                       left_index=True, right_index=True)
    MATCH_RATES[run_group] = ((res_df / n_ref)
                              .reset_index(level=0)
                              .rename(columns=dict(index=run_group)))
clear_output()

## Vorticity thresholds

In [None]:
# Select the 'vort_thresh' run group (a key of RUNS / MATCH_RATES) for the figure below.
run_group = 'vort_thresh'

In [None]:
# Map settings for the ACCACIA map panel: a copy of the shared LCC settings
# with a custom extent and tick spacing.
lcc_kw_zoom = LCC_KW.copy()
lcc_kw_zoom['extent'] = [-10, 35, 65, 81]
lcc_kw_zoom['ticks'] = [5, 1]

In [None]:
fig = plt.figure(figsize=(18, 9))

width = 0.4

#
# Panel (a): fraction of ACCACIA cyclones matched for each vorticity threshold
#
ax = fig.add_subplot(121)
res_df = MATCH_RATES[run_group]
for j, (dataset, color) in enumerate(zip(datasets, cc)):
    for i, subset in enumerate(subsets):
        # All subsets of one dataset share an x position; later subsets are
        # drawn on top with a higher opacity.
        ax.bar(res_df.index.values + j*width, res_df[f'{subset}_{dataset}'],
               width=width,
               **color,
               alpha=0.5 * (i+1),
               edgecolor='#000000',
               linewidth=0.75,
               label=f'{dataset}, {aliases[subset]}')

ax.legend(loc=1, ncol=2, fontsize='x-large')

# Hide the top and right spines.
for spine in ax.spines.values():
    spine.set_linewidth(0 if spine.spine_type in ['top', 'right'] else 1)

ax.set_ylim(0, 1)
ax.set_xticks(res_df.index.values + width/2)
# The row-key column of this run group holds the vorticity threshold.
ax.set_xticklabels((1e4 * res_df[run_group].values).round(decimals=1))

percent_formatter = FuncFormatter(lambda x, position: f'{x*100:3.0f}%')
ax.yaxis.set_major_formatter(percent_formatter)

ax.tick_params(labelsize='large')

# Annotate bars with the number of matched cyclones. When two bars share an
# x position, the earlier label is lifted just above its own bar and the later
# label is drawn in a light colour. The previous patch/annotation are reset
# here so that state left over from another figure cell cannot leak in.
prev_patch, prev_annot = None, None
for p in sorted(ax.patches, key=lambda x: x.get_x()):
    if not p.get_height() > 0:  # also skips NaN heights from the outer merge
        continue
    if prev_patch is not None and np.allclose(p.get_x(), prev_patch.get_x()):
        fontcolor = '#EEEEEE'
        prev_annot.set_y(prev_patch.get_height() + 0.005)
    else:
        fontcolor = '#222222'
    # Round instead of truncating: fraction * n_ref is not always an exact
    # integer in floating point (e.g. 0.29 * 100 == 28.999999999999996).
    prev_annot = ax.annotate('{:d}'.format(int(round(p.get_height() * n_ref))),
                             (p.get_x() + width/2, p.get_height() - 0.035),
                             ha='center', fontweight='bold', color=fontcolor,
                             size='x-large')
    prev_patch = p
ax.set_xlabel(r'Vorticity threshold used for tracking ($\times10^{-4}$ $s^{-1}$)', fontsize='x-large')
ax.set_ylabel('Percentage of cyclones detected', fontsize='x-large')
ax.add_artist(AnchoredText('a', loc=2, prop=dict(size='large')))

#
# Panel (b): which ACCACIA tracks are matched in the ERA5 control run ({})
#
ax = lcc_map(fig, 122, **lcc_kw_zoom)

dataset = 'era5'

# Indices of the reference tracks matched in each subset (second item of each
# match pair), as sets for constant-time membership tests.
matched = {subset: {pair[1] for pair in MATCH_PAIRS[run_group][dataset][subset]}
           for subset in subsets}

labels = ['Missed',
          f'Matched only to {aliases[subsets[0]]}',
          f'Matched only to {aliases[subsets[1]]}',
          f'Matched to {aliases[subsets[0]]} and {aliases[subsets[1]]}']
styles = [dict(color='C2', linestyle='--', alpha=0.75),
          dict(color='#8DBAD7', linewidth=2),
          dict(color='#00035b', linewidth=2),
          dict(color='C0', linewidth=2)]
hs = [None] * 4
for idx, acc_df in enumerate(acc_tracks):
    in_first = idx in matched[subsets[0]]
    in_second = idx in matched[subsets[1]]
    # Category index into labels/styles. A track matched to both subsets must
    # not also register a "Matched only to ..." legend entry.
    if in_first and in_second:
        category = 3
    elif in_first:
        category = 1
    elif in_second:
        category = 2
    else:
        category = 0
    # Every track is drawn as "missed" first; matched tracks are drawn on top.
    for cat in sorted({0, category}):
        acc_df.plot_track(ax=ax, **styles[cat], **trans)
        # A one-point line draws nothing visible; it only provides a legend handle.
        hs[cat], = ax.plot(acc_df.lon.iloc[0], acc_df.lat.iloc[0], **styles[cat], **trans)

legend_items = [(h, lab) for h, lab in zip(hs, labels) if h is not None]
ax.legend([h for h, _ in legend_items], [lab for _, lab in legend_items],
          loc=1, fontsize='x-large')
ax.add_artist(AnchoredText('b', loc=2, prop=dict(size='large')))

fig.savefig(mypaths.plotdir / f'vs_{period.lower()}_vort_thresh_w_map_{method}')

## Other parameters

In [None]:
# Select the 'diff_params' run group (a key of RUNS / MATCH_RATES) for the figure below.
run_group = 'diff_params'

In [None]:
# Tick labels: runs_grid_formatter text for each ERA5 run's parameter dict.
# NOTE(review): assumes every dataset shares the ERA5 run grid -- confirm.
xlabels = [runs_grid_formatter(params).strip() for _, params in RUNS[run_group]['era5']]

In [None]:
fig = plt.figure(figsize=(9, 9))

width = 0.4

# Fraction of ACCACIA cyclones matched by each run of the 'diff_params' group.
ax = fig.add_subplot(111)
res_df = MATCH_RATES[run_group]
for j, (dataset, color) in enumerate(zip(datasets, cc)):
    for i, subset in enumerate(subsets):
        # All subsets of one dataset share an x position; later subsets are
        # drawn on top with a higher opacity.
        ax.bar(res_df.index.values + j*width, res_df[f'{subset}_{dataset}'],
               width=width,
               **color,
               alpha=0.5 * (i+1),
               edgecolor='#000000',
               linewidth=0.75,
               label=f'{dataset}, {aliases[subset]}')

ax.legend(loc=1, ncol=2, fontsize='x-large')

# Hide the top and right spines.
for spine in ax.spines.values():
    spine.set_linewidth(0 if spine.spine_type in ['top', 'right'] else 1)

ax.set_ylim(0, 1)
ax.set_xticks(res_df.index.values + width/2)
ax.set_xticklabels(xlabels, rotation=90)

percent_formatter = FuncFormatter(lambda x, position: f'{x*100:3.0f}%')
ax.yaxis.set_major_formatter(percent_formatter)

ax.tick_params(labelsize='large')

# Annotate bars with the number of matched cyclones. When two bars share an
# x position, the earlier label is lifted just above its own bar and the later
# label is drawn in a light colour. The previous patch/annotation are reset
# here so that state left over from another figure cell cannot leak in.
prev_patch, prev_annot = None, None
for p in sorted(ax.patches, key=lambda x: x.get_x()):
    if not p.get_height() > 0:  # also skips NaN heights from the outer merge
        continue
    if prev_patch is not None and np.allclose(p.get_x(), prev_patch.get_x()):
        fontcolor = '#EEEEEE'
        prev_annot.set_y(prev_patch.get_height() + 0.005)
    else:
        fontcolor = '#222222'
    # Round instead of truncating: fraction * n_ref is not always an exact
    # integer in floating point (e.g. 0.29 * 100 == 28.999999999999996).
    prev_annot = ax.annotate('{:d}'.format(int(round(p.get_height() * n_ref))),
                             (p.get_x() + width/2, p.get_height() - 0.035),
                             ha='center', fontweight='bold', color=fontcolor,
                             size='x-large')
    prev_patch = p
ax.set_xlabel(r'Tracking parameters', fontsize='x-large')
ax.set_ylabel('Percentage of cyclones detected', fontsize='x-large')
fig.savefig(mypaths.plotdir / f'vs_{period.lower()}_{run_group}_{method}')

## Verification against STARS

### Load STARS tracks

In [None]:
# Period label for the STARS comparison; used in the figure file names.
period = 'stars'

In [None]:
# Read the STARS track list; the code below uses its 'time' and 'N' columns.
stars = read_tracks_file()

In [None]:
# The STARS comparison uses the first three winters in `winters`.
stars_winters = winters[:3]
stars_winters

In [None]:
# Reference STARS tracks: within each selected winter's date range (inclusive),
# one OctantTrack per track number N, keeping tracks lasting at least 6 hours.
stars_tracks = []
for season in stars_winters:
    t_start, t_end = winter_dates[season]
    in_season = stars[stars['time'].between(t_start, t_end)]
    season_tracks = (OctantTrack.from_df(track_df) for _, track_df in in_season.groupby('N'))
    stars_tracks.extend(track for track in season_tracks if track.lifetime_h >= 6)
# Number of reference tracks; match counts are normalised by it.
n_ref = len(stars_tracks)
n_ref

### Load PMCTRACK data for the STARS winters (all runs, both run groups)

In [None]:
# Load every PMCTRACK run for the STARS winters. For each run, the TrackRun of
# every winter is categorised and the winters are combined with `+=` into one
# TrackRun. TRACKS[run_group][dataset] follows the order of RUNS[run_group][dataset].
TRACKS = {}
for group_name, runs_by_dataset in tqdm(RUNS.items(), desc='run_group', leave=False):
    TRACKS[group_name] = {}
    for dset, run_list in tqdm(runs_by_dataset.items(), desc='dataset', leave=False):
        combined_runs = []
        for rid, _ in tqdm(run_list, desc='run_id', leave=False):
            run_dir = mypaths.trackresdir / dset / f'run{rid:03d}'
            combined = TrackRun()
            for season in tqdm(stars_winters, desc='winter', leave=False):
                season_run = TrackRun(run_dir / season)
                season_run.categorise(lsm=lsm, **cat_kw)
                combined += season_run
            combined_runs.append(combined)
        TRACKS[group_name][dset] = combined_runs
clear_output()

In [None]:
# Track-matching method for the STARS comparison; also part of the figure file names.
method = 'simple'

In [None]:
# For every run, count how many STARS reference tracks are matched in each
# subset, and keep the matched pairs of the control run ({}).
# MATCH_RATES[run_group] is a DataFrame with one row per run: a column named
# after the run group holds the row key (vorticity threshold or run number),
# and one '<subset>_<dataset>' column per subset and dataset holds the fraction
# of reference tracks matched.
MATCH_RATES = dict()
MATCH_PAIRS = dict()
for run_group, dataset_dicts in tqdm(RUNS.items(), desc='run_group', leave=False):
    results = dict()
    MATCH_PAIRS[run_group] = dict()
    for dataset, run_dicts in tqdm(dataset_dicts.items(), desc='dataset', leave=False):
        MATCH_PAIRS[run_group][dataset] = dict()
        perf_table = np.zeros((len(run_dicts), len(subsets)), dtype=np.int64)
        for irun, (run_id, run_dict) in tqdm(enumerate(run_dicts), desc='run_id', leave=False):
            for isub, subset in tqdm(enumerate(subsets), desc='subsets', leave=False):
                match_pairs = TRACKS[run_group][dataset][irun].match_tracks(stars_tracks,
                                                                            method=method, beta=50.,
                                                                            subset=subset)
                if len(run_dict) == 0:
                    # save id of matched vortices in the ctrl run ( {} )
                    MATCH_PAIRS[run_group][dataset][subset] = match_pairs
                perf_table[irun, isub] = len(match_pairs)
        if run_group == 'vort_thresh':
            # Runs without an explicit zeta_max0 (the control dict) are placed
            # at 2e-4 -- presumably the default threshold; confirm.
            index = [r[1].get('zeta_max0', 2e-4) for r in run_dicts]
        else:
            index = range(len(run_dicts))
        # Suffix the columns with the dataset name here instead of relying on
        # pd.merge(suffixes=...), which only handles exactly two frames.
        results[dataset] = pd.DataFrame(data=perf_table,
                                        columns=subsets,
                                        index=index).add_suffix(f'_{dataset}')
    # Outer-join all datasets on the row key; works for any number of datasets.
    res_df = None
    for frame in results.values():
        res_df = frame if res_df is None else pd.merge(res_df, frame, how='outer',
                                                       left_index=True, right_index=True)
    MATCH_RATES[run_group] = ((res_df / n_ref)
                              .reset_index(level=0)
                              .rename(columns=dict(index=run_group)))
clear_output()

## Vorticity thresholds

In [None]:
# Select the 'vort_thresh' run group (a key of RUNS / MATCH_RATES) for the figure below.
run_group = 'vort_thresh'

In [None]:
# Map settings for the STARS map panel: a copy of the shared LCC settings
# with a custom extent and tick spacing.
lcc_kw_zoom = LCC_KW.copy()
lcc_kw_zoom['extent'] = [-15, 45, 63, 82]
lcc_kw_zoom['ticks'] = [5, 2]

In [None]:
fig = plt.figure(figsize=(18, 7))

width = 0.4

#
# Panel (a): fraction of STARS cyclones matched for each vorticity threshold
#
ax = fig.add_subplot(121)
res_df = MATCH_RATES[run_group]
for j, (dataset, color) in enumerate(zip(datasets, cc)):
    for i, subset in enumerate(subsets):
        # All subsets of one dataset share an x position; later subsets are
        # drawn on top with a higher opacity.
        ax.bar(res_df.index.values + j*width, res_df[f'{subset}_{dataset}'],
               width=width,
               **color,
               alpha=0.5 * (i+1),
               edgecolor='#000000',
               linewidth=0.75,
               label=f'{dataset}, {aliases[subset]}')

ax.legend(loc=1, ncol=1, fontsize='x-large')

# Hide the top and right spines.
for spine in ax.spines.values():
    spine.set_linewidth(0 if spine.spine_type in ['top', 'right'] else 1)

ax.set_ylim(0, 1)
ax.set_xticks(res_df.index.values + width/2)
# The row-key column of this run group holds the vorticity threshold.
ax.set_xticklabels((1e4 * res_df[run_group].values).round(decimals=1))

percent_formatter = FuncFormatter(lambda x, position: f'{x*100:3.0f}%')
ax.yaxis.set_major_formatter(percent_formatter)

ax.tick_params(labelsize='large')

# Annotate bars with the number of matched cyclones. When two bars share an
# x position, the earlier label is lifted just above its own bar and the later
# label is drawn in a light colour. The previous patch/annotation are reset
# here so that state left over from another figure cell cannot leak in.
prev_patch, prev_annot = None, None
for p in sorted(ax.patches, key=lambda x: x.get_x()):
    if not p.get_height() > 0:  # also skips NaN heights from the outer merge
        continue
    if prev_patch is not None and np.allclose(p.get_x(), prev_patch.get_x()):
        fontcolor = '#EEEEEE'
        prev_annot.set_y(prev_patch.get_height() + 0.005)
    else:
        fontcolor = '#222222'
    # Round instead of truncating: fraction * n_ref is not always an exact
    # integer in floating point (e.g. 0.29 * 100 == 28.999999999999996).
    prev_annot = ax.annotate('{:d}'.format(int(round(p.get_height() * n_ref))),
                             (p.get_x() + width/2, p.get_height() - 0.035),
                             ha='center', fontweight='bold', color=fontcolor,
                             size='small')
    prev_patch = p
ax.set_xlabel(r'Vorticity threshold used for tracking ($\times10^{-4}$ $s^{-1}$)', fontsize='x-large')
ax.set_ylabel('Percentage of cyclones detected', fontsize='x-large')
ax.add_artist(AnchoredText('a', loc=2, prop=dict(size='large')))

#
# Panel (b): which STARS tracks are matched in the ERA5 control run ({})
#
ax = lcc_map(fig, 122, **lcc_kw_zoom)

dataset = 'era5'

# Indices of the reference tracks matched in each subset (second item of each
# match pair), as sets for constant-time membership tests.
matched = {subset: {pair[1] for pair in MATCH_PAIRS[run_group][dataset][subset]}
           for subset in subsets}

labels = ['Missed',
          f'Matched only to {aliases[subsets[0]]}',
          f'Matched only to {aliases[subsets[1]]}',
          f'Matched to {aliases[subsets[0]]} and {aliases[subsets[1]]}']
styles = [dict(color='C2', linestyle='--', alpha=0.75),
          dict(color='#8DBAD7', linewidth=2),
          dict(color='#00035b', linewidth=2),
          dict(color='C0', linewidth=2)]
hs = [None] * 4
for idx, track in enumerate(stars_tracks):
    in_first = idx in matched[subsets[0]]
    in_second = idx in matched[subsets[1]]
    # Category index into labels/styles. A track matched to both subsets must
    # not also register a "Matched only to ..." legend entry.
    if in_first and in_second:
        category = 3
    elif in_first:
        category = 1
    elif in_second:
        category = 2
    else:
        category = 0
    # Every track is drawn as "missed" first; matched tracks are drawn on top.
    for cat in sorted({0, category}):
        track.plot_track(ax=ax, **styles[cat], **trans)
        # A one-point line draws nothing visible; it only provides a legend handle.
        hs[cat], = ax.plot(track.lon.iloc[0], track.lat.iloc[0], **styles[cat], **trans)

legend_items = [(h, lab) for h, lab in zip(hs, labels) if h is not None]
ax.legend([h for h, _ in legend_items], [lab for _, lab in legend_items],
          loc=1, fontsize='x-large')
ax.add_artist(AnchoredText('b', loc=2, prop=dict(size='large')))

fig.savefig(mypaths.plotdir / f'vs_{period.lower()}_vort_thresh_w_map_{method}')

## Other parameters

In [None]:
# Select the 'diff_params' run group (a key of RUNS / MATCH_RATES) for the figure below.
run_group = 'diff_params'

In [None]:
# Tick labels: runs_grid_formatter text for each ERA5 run's parameter dict.
# NOTE(review): assumes every dataset shares the ERA5 run grid -- confirm.
xlabels = [runs_grid_formatter(params).strip() for _, params in RUNS[run_group]['era5']]

In [None]:
fig = plt.figure(figsize=(9, 9))

width = 0.4

# Fraction of STARS cyclones matched by each run of the 'diff_params' group.
ax = fig.add_subplot(111)
res_df = MATCH_RATES[run_group]
for j, (dataset, color) in enumerate(zip(datasets, cc)):
    for i, subset in enumerate(subsets):
        # All subsets of one dataset share an x position; later subsets are
        # drawn on top with a higher opacity.
        ax.bar(res_df.index.values + j*width, res_df[f'{subset}_{dataset}'],
               width=width,
               **color,
               alpha=0.5 * (i+1),
               edgecolor='#000000',
               linewidth=0.75,
               label=f'{dataset}, {aliases[subset]}')

ax.legend(loc=1, ncol=2, fontsize='large')

# Hide the top and right spines.
for spine in ax.spines.values():
    spine.set_linewidth(0 if spine.spine_type in ['top', 'right'] else 1)

ax.set_ylim(0, 1)
ax.set_xticks(res_df.index.values + width/2)
ax.set_xticklabels(xlabels, rotation=90)

percent_formatter = FuncFormatter(lambda x, position: f'{x*100:3.0f}%')
ax.yaxis.set_major_formatter(percent_formatter)

ax.tick_params(labelsize='large')

# Annotate bars with the number of matched cyclones. When two bars share an
# x position, the earlier label is lifted just above its own bar and the later
# label is drawn in a light colour. The previous patch/annotation are reset
# here so that state left over from another figure cell cannot leak in.
prev_patch, prev_annot = None, None
for p in sorted(ax.patches, key=lambda x: x.get_x()):
    if not p.get_height() > 0:  # also skips NaN heights from the outer merge
        continue
    if prev_patch is not None and np.allclose(p.get_x(), prev_patch.get_x()):
        fontcolor = '#EEEEEE'
        prev_annot.set_y(prev_patch.get_height() + 0.005)
    else:
        fontcolor = '#222222'
    # Round instead of truncating: fraction * n_ref is not always an exact
    # integer in floating point (e.g. 0.29 * 100 == 28.999999999999996).
    prev_annot = ax.annotate('{:d}'.format(int(round(p.get_height() * n_ref))),
                             (p.get_x() + width/2, p.get_height() - 0.035),
                             ha='center', fontweight='bold', color=fontcolor,
                             size='small')
    prev_patch = p
ax.set_xlabel(r'Tracking parameters', fontsize='x-large')
ax.set_ylabel('Percentage of cyclones detected', fontsize='x-large')
fig.savefig(mypaths.plotdir / f'vs_{period.lower()}_{run_group}_{method}')