In [1]:
from collections import defaultdict
import glob
from datetime import datetime, timedelta
import os

import matplotlib.pyplot as plt # type: ignore
from matplotlib.colors import LogNorm # type: ignore
import numpy as np # type: ignore
import pandas as pd # type: ignore
import seaborn # type: ignore

In [2]:
def collect_sod_stats(idir, name, *args):
    """Aggregate per-day classification statistics for one target from ``idir``.

    Reads every ``{idir}/stats_{name}*npz`` file in sorted path order and skips
    files that contain a ``none`` key. From each remaining file it collects:

    * the date parsed (format ``%Y%m%d.npz``) from the text after the last ``_``
      in the path;
    * the scalar metrics ``'{NAME} accuracy'``, ``'{NAME} precision'``,
      ``'{NAME} recall'`` and ``'{NAME} f1-score'`` (``NAME`` = ``name.upper()``);
    * the per-label ``precision`` / ``recall`` / ``fscore`` arrays, stored under
      ``'{label} | {metric}'`` using the file's ``labels`` array;
    * the ``matrix`` array, summed over all files while ignoring NaNs.

    Parameters
    ----------
    idir : str
        Directory containing the ``stats_*.npz`` files.
    name : str
        Target name used in the glob pattern and (upper-cased) in metric keys.
    *args
        Ignored.

    Returns
    -------
    tuple
        ``(dates, stats, conf_mat, labels)``:
        ``stats`` is a ``defaultdict(list)`` mapping each metric key to a list
        aligned with ``dates``; ``conf_mat`` is the transpose of the NaN-ignoring
        sum of the per-file matrices; ``labels`` is the ``labels`` list of the
        last file that was used. Returns ``(None, None, None, None)`` when no
        usable file is found.
    """
    frmt = '%Y%m%d.npz'
    prefix = name.upper()
    var_names = [f'{prefix} accuracy', f'{prefix} precision', f'{prefix} recall', f'{prefix} f1-score']
    ice_var_names = ['precision', 'recall', 'fscore']

    dates = []
    sod_stats = defaultdict(list)
    conf_matrs = []
    sod_labels = []

    for ifile in sorted(glob.glob(f'{idir}/stats_{name}*npz')):
        # np.load on an .npz keeps the file handle open; read everything and close it.
        # NOTE(review): allow_pickle=True can execute pickled objects - only use on trusted files.
        with np.load(ifile, allow_pickle=True) as npz:
            d = dict(npz)
        # Check the skip marker first so skipped files need not carry 'labels',
        # and so the returned labels belong to a file that fed conf_mat.
        if 'none' in d:
            continue
        sod_labels = list(d['labels'])
        dates.append(datetime.strptime(ifile.split('_')[-1], frmt))
        for var_name in var_names:
            sod_stats[var_name].append(d[var_name].item())
        for ice_var_name in ice_var_names:
            for ice_name, ice_value in zip(sod_labels, d[ice_var_name]):
                sod_stats[f'{ice_name} | {ice_var_name}'].append(ice_value)
        conf_matrs.append(d['matrix'])
    if not dates:
        return None, None, None, None
    # Stack per-file matrices along a new third axis and sum them.
    conf_mat = np.nansum(np.dstack(conf_matrs), axis=2).T
    return dates, sod_stats, conf_mat, sod_labels

def collect_sic_stats(idir, *args):
    """Aggregate per-day sea-ice-concentration statistics from ``idir``.

    Reads every ``{idir}/stats_sic*npz`` file in sorted path order and skips
    files that contain a ``none`` key. For each remaining file, the date parsed
    (format ``%Y%m%d.npz``) from the text after the last ``_`` in the path is
    recorded. The scalar values ``'SIC {All|Avg} {Pearson|Bias|RMSE|DRMSE}'``
    are also collected.

    Parameters
    ----------
    idir : str
        Directory containing the ``stats_sic*.npz`` files.
    *args
        Ignored; present so the signature matches ``collect_sod_stats``.

    Returns
    -------
    tuple
        ``(dates, stats, None, None)``. ``stats`` is a ``defaultdict(list)``
        mapping each metric key to a list aligned with ``dates``. When no usable
        file is found, ``dates`` is an empty list (not ``None``).
    """
    frmt = '%Y%m%d.npz'
    metric_names = ['Pearson', 'Bias', 'RMSE', 'DRMSE']
    var_names = [f'SIC {name1} {metric_name}' for metric_name in metric_names for name1 in ['All', 'Avg']]
    dates = []
    sic_stats = defaultdict(list)
    for ifile in sorted(glob.glob(f'{idir}/stats_sic*npz')):
        # np.load on an .npz keeps the file handle open; read everything and close it.
        # NOTE(review): allow_pickle=True can execute pickled objects - only use on trusted files.
        with np.load(ifile, allow_pickle=True) as npz:
            d = dict(npz)
        if 'none' in d:
            continue
        dates.append(datetime.strptime(ifile.split('_')[-1], frmt))
        for var_name in var_names:
            sic_stats[var_name].append(d[var_name].item())
    return dates, sic_stats, None, None

def plot_confusion_matrix(idir, conf_mat, sod_labels, name):
    """Save a log-colour-scaled heat map of an aggregated confusion matrix.

    The figure is written to
    ``../figures/{basename(idir)}_confusion_matrix_{name}.png`` at 150 dpi and
    closed afterwards.

    Parameters
    ----------
    idir : str
        Source directory; only its base name is used in the output file name.
    conf_mat : array-like
        2-D matrix to display. Its rows run along the y axis ('Auto ice chart')
        and its columns along the x axis ('Manual ice chart').
    sod_labels : sequence of str
        Tick labels, used for both axes.
    name : str
        Target name appended to the output file name.
    """
    n_labels = len(sod_labels)
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))
    image = ax.imshow(conf_mat, norm=LogNorm())
    fig.colorbar(image, ax=ax, shrink=0.7)
    ax.set_xticks(range(n_labels), sod_labels, rotation=90)
    ax.set_yticks(range(n_labels), sod_labels)
    ax.set_xlabel('Manual ice chart')
    ax.set_ylabel('Auto ice chart')
    fig.tight_layout()
    out_path = f'../figures/{os.path.basename(idir)}_confusion_matrix_{name}.png'
    fig.savefig(out_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)

def collect_joined_stats(collect_func, ref_names, *args):
    """Collect statistics for several reference sources into one DataFrame.

    For each ``ref_name`` the directory ``../dmi_{ref_name}`` is passed to
    ``collect_func(idir, *args)``, which must return
    ``(dates, stats, _, _)`` with ``stats`` a mapping of column -> list aligned
    with ``dates``. Sources that return no dates (``None`` or empty) are skipped.

    Parameters
    ----------
    collect_func : callable
        ``collect_sic_stats``, ``collect_sod_stats`` or compatible.
    ref_names : iterable of str
        Reference source names; also stored in the ``data_source`` column.
    *args
        Extra positional arguments forwarded to ``collect_func``.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by a ``DatetimeIndex``, one column per statistic plus
        ``data_source``. Dates of the i-th source are shifted by ``i`` hours so
        timestamps from different sources do not collide. When no source yields
        data, an empty frame with only a ``data_source`` column is returned.
    """
    dfs = []
    # enumerate over all names (including skipped ones) so each source keeps a fixed hour offset
    for i, ref_name in enumerate(ref_names):
        idir = f'../dmi_{ref_name}'
        src_dates, src_stats, _, _ = collect_func(idir, *args)
        if not src_dates:
            # collect_sod_stats returns None when it finds no usable files
            continue
        index = pd.DatetimeIndex([src_date + timedelta(hours=i) for src_date in src_dates])
        df = pd.DataFrame(src_stats, index=index)
        df['data_source'] = ref_name
        dfs.append(df)
    if not dfs:
        return pd.DataFrame(columns=['data_source'], index=pd.DatetimeIndex([]))
    return pd.concat(dfs)

def joined_monthly_plots(df, odir='../figures'):
    """Save one monthly box plot per statistic column, split by data source.

    For every column of ``df`` other than ``data_source`` and ``month``, a
    figure is written to ``{odir}/{column}.png``. Spaces in the column name are
    replaced by ``_``, and ``/`` is also replaced so it is not treated as a
    sub-directory. Each figure shows the column's distribution per month
    (abbreviated month name from the index) with one box per ``data_source``.
    Outliers are not drawn.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame as returned by ``collect_joined_stats``. It needs a datetime-like
        index and a ``data_source`` column. It is not modified.
    odir : str
        Output directory for the PNG files (default ``'../figures'``).
    """
    # Work on a copy so the caller's frame does not gain a 'month' column.
    plot_df = df.assign(month=df.index.strftime('%b'))
    show_names = plot_df.columns.drop(['month', 'data_source'])
    for var_name in show_names:
        safe_name = var_name.replace(' ', '_').replace('/', '_')
        filename = f'{odir}/{safe_name}.png'
        fig, ax = plt.subplots(1, 1, figsize=(10, 3))
        seaborn.boxplot(x='month', y=var_name, hue='data_source', data=plot_df, showfliers=False, width=0.5, ax=ax)
        fig.savefig(filename, dpi=100, bbox_inches='tight', pad_inches=0.1)
        # close explicitly: one figure per column would otherwise accumulate
        plt.close(fig)

In [3]:
# Confusion-matrix figures for each target ('sod', 'flz') and each stats directory;
# directories without usable stats files are skipped.
idirs = ['../dmi_nic', '../dmi_dmi']
for name in ['sod', 'flz']:
    for idir in idirs:
        sod_dates, sod_stats, conf_mat, sod_labels = collect_sod_stats(idir, name)
        if sod_dates is None:
            continue
        plot_confusion_matrix(idir, conf_mat, sod_labels, name)

In [4]:
# SIC statistics for the 'nic', 'dmi' and 'osisaf' sources, plotted per month.
df = collect_joined_stats(collect_sic_stats, ref_names=['nic', 'dmi', 'osisaf'])
joined_monthly_plots(df=df)

In [5]:
# 'sod' classification statistics for the 'nic' and 'dmi' sources, plotted per month.
df = collect_joined_stats(collect_sod_stats, ('nic', 'dmi'), 'sod')
joined_monthly_plots(df=df)

In [6]:
# 'flz' classification statistics (only the 'dmi' source), plotted per month.
df = collect_joined_stats(collect_sod_stats, ('dmi',), 'flz')
joined_monthly_plots(df=df)