In [None]:
import datetime
import os
from typing import Dict

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import clear_output

from base.meta_simulator import MetaSimulator
from base.simulator import SimulatorResult

In [None]:
def load_results_from_dir(msim: MetaSimulator, target_dir: str) -> Dict[int, Dict[str, Dict[str, SimulatorResult]]]:
    """Load the results of all runs of a meta simulation (with ``lazy_loading=True``).

    Each run is read from ``<target_dir>/<stride>s/<metric>/<strategy>`` for every combination of
    ``msim.strides``, ``msim.threshold_metrics`` and ``msim.strategies``.

    :param msim: meta simulator describing which runs exist
    :param target_dir: root directory of the meta simulation
    :return: nested dict ``results[stride][metric][strategy]`` of ``SimulatorResult`` objects
    """
    # Count the runs here instead of reading a global ``n_runs`` that only exists (and is only
    # correct) after the driver cell has set it for the very same ``msim``.
    n_runs = len(msim.strides) * len(msim.threshold_metrics) * len(msim.strategies)
    results: Dict[int, Dict[str, Dict[str, SimulatorResult]]] = dict()
    i = 1
    for res in msim.strides:
        results[res] = dict()
        for metric in msim.threshold_metrics:
            results[res][metric] = dict()
            for strategy in msim.strategies:
                path = os.path.join(target_dir, f'{res}s', metric, strategy)
                print(f'Loading {i} of {n_runs} from {path}')
                sim_result = SimulatorResult.load(path, lazy_loading=True)
                # manually set the start date of the simulation for lazy loading
                sim_result.continual_df = pd.DataFrame(index=[datetime.datetime.fromisoformat('2020-01-01T00:00')])
                results[res][metric][strategy] = sim_result
                # replace the progress line instead of appending one line per run
                clear_output(wait=True)
                i += 1
    print(f'Loaded {n_runs} results')
    return results


# collect all high-level stats in a single dataframe
def compute_stats(msim: MetaSimulator, out_dir: str, print=True):
    """Collect the high-level statistics of all runs in one DataFrame and save it as CSV.

    Reads the module-level ``results`` (filled by ``load_results_from_dir``) and ``sim_id``; the
    table is written to ``<out_dir>/<sim_id>_results.csv``.

    :param msim: meta simulator whose strides, threshold metrics and strategies are evaluated
    :param out_dir: directory of the CSV file
    :param print: whether to print the complete table. This parameter shadows the builtin
        ``print`` inside the function, so the builtin is reached through ``builtins.print``.
    :return: DataFrame with one row per (stride, metric, strategy) run
    """
    import builtins

    # define the estimated sizes for the network packages
    packet_overhead = 20 + 20 + 32  # TCP/IP overhead 20 + 20 bytes, 802.11 overhead 32 bytes
    measurement_size = 16  # 8 bytes timestamp, 8 bytes float64
    horizon_length = 24  # NOTE(review): presumably the number of predicted steps per horizon — confirm

    rows = []
    for res in msim.strides:
        for metric in msim.threshold_metrics:
            for strategy in msim.strategies:
                sim_result = results[res][metric][strategy]
                d_t = sim_result.estimate_data_transferred(packet_overhead,
                                                           measurement_size,
                                                           horizon_length,
                                                           )
                row = {
                    'id': f'{res}s_{metric}_{strategy}',
                    'stride': res,
                    'steps': sim_result.steps,
                    'metric': metric,
                    'strategy': strategy,
                    'n_v': sim_result.num_threshold_violations,
                    'n_u': sim_result.num_horizon_updates,
                    'n_d': sim_result.num_deployments,
                    'n_m': sim_result.message_exchanges,  # the sum of n_v, n_u, n_d
                    'data_d': sim_result.deployments['size'].sum(),
                    'data': int(d_t),
                }
                # one column per feature and error measure, e.g. MAE_TL, MSE_TL, RMSE_TL
                row.update({f'MAE_{feature}': value for feature, value in sim_result.mae.items()})
                row.update({f'MSE_{feature}': value for feature, value in sim_result.mse.items()})
                row.update({f'RMSE_{feature}': value for feature, value in sim_result.rmse.items()})
                rows.append(row)
    stats = pd.DataFrame(rows)
    stats.to_csv(os.path.join(out_dir, f'{sim_id}_results.csv'), index=False)

    if print:
        pd.set_option('display.max_columns', None)
        pd.set_option('display.max_rows', None)
        pd.set_option('display.precision', 4)
        # `print` is the boolean flag here; calling it directly raised
        # "TypeError: 'bool' object is not callable"
        builtins.print(stats)

    return stats


def save_latex_table(stats: pd.DataFrame, out_dir: str):
    """Export a reduced version of ``stats`` as a LaTeX table.

    The columns ``id``, ``steps`` and ``n_m`` are dropped, only rows whose ``strategy`` is
    ``'repeat'`` are kept, and floats are formatted with four decimals. The table is written to
    ``<out_dir>/<sim_id>_results.tex``; ``sim_id`` is read from the module level.

    NOTE(review): the original comment said "we do not need the baseline in the table", yet this
    filter keeps *only* the 'repeat' rows — confirm whether ``!=`` was intended.
    """
    reduced = stats.drop(columns=['id', 'steps', 'n_m'])
    only_repeat = reduced.loc[reduced['strategy'] == 'repeat']
    target = os.path.join(out_dir, f'{sim_id}_results.tex')
    only_repeat.to_latex(buf=target, index=False, float_format='%.4f')

In [None]:
import matplotlib
from matplotlib.backends.backend_pgf import FigureCanvasPgf

# Use the PGF canvas for the 'pdf' format, so figures saved as PDF have their text typeset by LaTeX.
matplotlib.backend_bases.register_backend('pdf', FigureCanvasPgf)

sns.set_theme()

# LaTeX text rendering (pdflatex) with the Latin Modern font package.
latex_rc = {
    'pgf.texsystem': 'pdflatex',
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
    'pgf.preamble': '\\usepackage{lmodern}',
}
plt.rcParams.update(latex_rc)

# \textwidth of the LaTeX document (presumably in inches),
# cf. https://timodenk.com/blog/exporting-matplotlib-plots-to-latex/
textwidth = 5.78851


def _escape_underscores(text) -> str:
    """Return ``str(text)`` with underscores escaped, so it renders with ``text.usetex`` enabled."""
    return str(text).replace('_', r'\_')


def _add_threshold_metric_legend(fig, ax) -> None:
    """Add a figure-level legend for the ``metric`` hue levels drawn on ``ax``."""
    handles, labels = ax.get_legend_handles_labels()
    # Use the actual hue labels instead of a hard-coded ['TL\_high', 'TL\_medium', 'TL\_low'],
    # so the legend always matches the bars, whatever the order or names of the threshold metrics.
    fig.legend(handles,
               [_escape_underscores(label) for label in labels],
               loc='upper left',
               bbox_to_anchor=(0.001, 0.3),
               title='Threshold Metric',
               borderpad=0.5
               )


def create_plots(metrics: dict, suffixes: list, out_dir: str, groups=('strategy', 'stride'), show=True):
    """Create grouped bar plots of columns of the aggregated statistics and save them.

    Reads the module-level ``msim``, ``stats`` and ``sim_id`` set by the loading cell.

    :param metrics: maps a column of ``stats`` (e.g. ``'data'``) to its y-axis label
    :param suffixes: file extensions the figures are saved with, e.g. ``['pdf', 'png']``
    :param out_dir: directory the figures are written to
    :param groups: figure variants to create. ``'strategy'`` gives one panel per strategy with
        stride on the x axis. ``'stride'`` gives one panel per stride with strategy on the x axis;
        it is only drawn when there is more than one stride.
    :param show: whether to display the figures. They are closed afterwards either way.
    """
    for column, ylabel in metrics.items():
        if 'strategy' in groups:
            # squeeze=False keeps `axes` an array even for a single strategy
            fig, axes = plt.subplots(nrows=1, ncols=len(msim.strategies), figsize=(14, 3.5), sharey='all',
                                     squeeze=False)
            for strategy, ax in zip(msim.strategies, axes.flatten()):
                ax = sns.barplot(ax=ax, data=stats[stats['strategy'] == strategy],
                                 x='stride', y=column, hue='metric'
                                 )
                ax.set_title(_escape_underscores(strategy))
                ax.set_xlabel('')
                ax.set_ylabel('')
                # a single shared legend is added to the figure instead
                ax.get_legend().remove()

            _add_threshold_metric_legend(fig, ax)
            fig.supxlabel('Measurement Interval [s]', y=0.01)
            fig.supylabel(ylabel, x=0.06, y=0.6)
            fig.subplots_adjust(left=0.14, bottom=0.17)
            for suffix in suffixes:
                fig.savefig(os.path.join(out_dir, f'{sim_id}_{column}_grouped_by_strategy.{suffix}'),
                            bbox_inches='tight', pad_inches=0.05)
            if show:
                plt.show()
            # Close explicitly. With show=False the figure would otherwise stay open; the inline
            # backend would then display it at the end of the cell, and figures would pile up
            # when this is called in a loop.
            plt.close(fig)

        if len(msim.strides) <= 1 or 'stride' not in groups:
            continue
        fig, axes = plt.subplots(nrows=1, ncols=len(msim.strides), figsize=(15, 5), sharey='all', squeeze=False)
        for stride, ax in zip(msim.strides, axes.flatten()):
            # The strategy names become x tick labels; escape them like the titles and the legend,
            # because raw underscores break rendering with text.usetex.
            panel_data = stats[stats['stride'] == stride].assign(
                strategy=lambda df: df['strategy'].map(_escape_underscores)
            )
            ax = sns.barplot(ax=ax, data=panel_data, x='strategy', y=column, hue='metric')
            ax.set_title(stride)
            ax.set_xlabel('')
            ax.set_ylabel('')
            ax.get_legend().remove()
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

        _add_threshold_metric_legend(fig, ax)
        fig.supxlabel('Continual Strategy')
        fig.supylabel(ylabel, x=0.06)
        fig.subplots_adjust(left=0.14, bottom=0.25)
        for suffix in suffixes:
            # pad_inches only takes effect together with bbox_inches='tight'
            fig.savefig(os.path.join(out_dir, f'{sim_id}_{column}_grouped_by_stride.{suffix}'),
                        bbox_inches='tight', pad_inches=0.05)
        if show:
            plt.show()
        plt.close(fig)


In [None]:
# datasets and models whose meta simulations are analyzed (one MetaSimulator per combination)
datasets = [
    'vienna_2010_2019',
    'vienna_2019_2019',
    'vienna_201907_201912',
    'linz_2010_2019',
]
models = [
    'simple_dense',
    'simple_lstm',
    'conv_lstm',
]

suffixes = ['pdf']  # file formats of the exported figures, e.g., png, pgf, jpg
# stats column -> y-axis label of the bar plots
# (further candidates: 'MAE_TL': 'MAE [°C]', 'n_v': 'Threshold Violations', 'n_u': 'Horizon Updates')
metrics = {
    'data': 'Transferred Data [B]',
    'n_m': 'Message Exchanges',
}

out_dir = 'zamg/analysis'  # shared output directory of all simulations

for model in models:
    for data in datasets:
        sim_id = f'zamg_{data}_{model}'
        sim_dir = f'zamg/simulations/{sim_id}'
        os.makedirs(out_dir, exist_ok=True)

        # the helper functions read the module-level msim / results / stats / sim_id set here
        msim = MetaSimulator.load(sim_dir)
        n_runs = len(msim.strides) * len(msim.threshold_metrics) * len(msim.strategies)
        print(f'Directory contains {n_runs} simulation runs')
        results = load_results_from_dir(msim, sim_dir)

        stats = compute_stats(msim, out_dir, print=False)
        save_latex_table(stats, out_dir)
        create_plots(metrics, suffixes, out_dir, groups=['strategy'], show=False)

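The following cells analyze a single MetaSimulator run in more detail. Use the cell above to load only a single MetaSimulator (e.g., by reducing `models` and `datasets` to one entry each) so that the variables these cells rely on (`msim`, `results`, `sim_id`, `out_dir`, ...) are instantiated properly.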

In [None]:
# distribution of the absolute prediction errors per strategy (drawn as violin plots)
col = 'TL'
sns.set_style('whitegrid')

n_metrics = len(msim.threshold_metrics)
for resolution in msim.strides:
    fig, axes = plt.subplots(nrows=1, ncols=n_metrics, figsize=(21, 14), sharex='all', sharey='all')
    # plt.subplots returns a bare Axes (not an array) for a single column
    axes = axes.flatten() if n_metrics != 1 else [axes]
    for metric, ax in zip(msim.threshold_metrics, axes):
        per_strategy = [
            pd.DataFrame(
                data=results[resolution][metric][strategy].data.get_diff().loc[:, col].abs().values,
                columns=[strategy],
            )
            for strategy in msim.strategies
        ]

        print(f'Prediction Errors for {metric}:')
        prediction_errors = pd.concat(per_strategy, axis=1)
        print(prediction_errors.describe())

        ax = sns.violinplot(ax=ax, data=prediction_errors, orient='h', cut=0, scale='count', inner='quartile')
        ax.grid(visible=True)
        ax.set_title(metric)
        ax.set_xlabel('Absolute Error')
        ax.set_ylabel('Strategy')

    fig.suptitle(f'Prediction Error Distribution ({col}, {resolution}s)')
    for suffix in suffixes:
        target = os.path.join(out_dir, f'{sim_id}_prediction_error_boxplot_{col}_{resolution}s.{suffix}')
        fig.savefig(target, pad_inches=0.01)
    plt.show()

In [None]:
# compare the rolling MAE of different strategies
from datetime import timedelta

# Selection reused by the following cells. NOTE: this overwrites the `metrics` dict of the driver
# cell with a list, so re-run the driver cell before calling create_plots again.
resolutions = [3600]
metrics = ['TL_low']
strategies = ['repeat', 'static', 'retrain_short', 'retrain_long', 'transfer_short', 'transfer_long', 'fine_tune_short',
              'fine_tune_long']
col = 'TL'
mae_window_days = 30  # length of the rolling window

fig, ax = plt.subplots(figsize=(28, 7))
for resolution in resolutions:
    for metric in metrics:
        for strategy in strategies:
            result = results[resolution][metric][strategy]
            result.data.get_diff().loc[:, col].abs().rolling(timedelta(days=mae_window_days)).mean().plot(
                ax=ax, label=f'{resolution}s, {metric}, {strategy}'
            )

ax.set_ylim([0.25, 4])
ax.set_title(f'Rolling MAE ({mae_window_days} days)')
ax.legend()
# Save like the sibling cells: one file per suffix, prefixed with sim_id. Previously a fixed
# 'rolling_MAE.pgf' was written, which every simulation overwrote.
for suffix in suffixes:
    fig.savefig(os.path.join(out_dir, f'{sim_id}_rolling_MAE.{suffix}'), pad_inches=0.01)
plt.show()

In [None]:
from datetime import timedelta

# Reuses `resolutions`, `metrics` and `strategies` from the rolling-MAE cell above.
tv_window_days = 30  # length of the rolling window

fig: plt.Figure = plt.figure(figsize=(21, 7))
for resolution in resolutions:
    for metric in metrics:
        for strategy in strategies:
            result = results[resolution][metric][strategy]
            violations = result.threshold_violations.get_measurements()
            daily_counts = violations.iloc[:, 0].resample('D').count()
            ax = daily_counts.rolling(timedelta(days=tv_window_days)).mean().plot(
                label=f'{resolution}s, {metric}, {strategy}'
            )
ax.set_title(f'Rolling Threshold Violations per Day ({tv_window_days} days)')
plt.legend()
for suffix in suffixes:
    target = os.path.join(out_dir, f'{sim_id}_rolling_threshold_violations.{suffix}')
    fig.savefig(target, pad_inches=0.01)
plt.show()

In [None]:
# Number of threshold violations per quarter.
# Reuses `resolutions`, `metrics`, `strategies` and `col` from the cells above.
fig: plt.Figure = plt.figure(figsize=(21, 7))
for resolution in resolutions:
    for metric in metrics:
        for strategy in strategies:
            violations = results[resolution][metric][strategy].threshold_violations.get_measurements()
            quarterly_counts = violations.loc[:, col].resample('Q').count()
            ax = quarterly_counts.plot(label=f'{resolution}s, {metric}, {strategy}')
ax.set_title('Threshold Violations per Quarter')
plt.legend()
for suffix in suffixes:
    target = os.path.join(out_dir, f'{sim_id}_threshold_violations_quarterly.{suffix}')
    fig.savefig(target, pad_inches=0.01)
plt.show()

In [None]:
# visualize predictions, measurements, threshold violations and horizon updates in a specific range
# %matplotlib qt
date_range = slice('2020-07-01', '2020-07-07')
column = 'TL'

# resolutions = [3600]
# metrics = ['TL_low']
# strategies = ['repeat', 'static', 'retrain_short', 'retrain_long', 'transfer_short', 'transfer_long', 'fine_tune_short',
#               'fine_tune_long']

for resolution in resolutions:
    for metric in metrics:
        for strategy in strategies:
            node = results[resolution][metric][strategy].node_manager.get_node('SIM')

            fig: plt.Figure = plt.figure(figsize=(12, 4))
            ax = plt.subplot(111)
            plt.plot(node.data.get_measurements().loc[date_range, column], label='Measurement')
            plt.plot(node.data.get_predictions().loc[date_range, column], label='Prediction')
            plt.plot(node.threshold_violations.get_predictions().loc[date_range, column], 'rx',
                     label='Threshold Violation'
                     )
            plt.plot(node.data.get_predictions().loc[node.horizon_updates.to_series()[date_range].index, column], 'go',
                     label='Horizon Update'
                     )
            plt.plot(node.data.get_predictions().loc[node.model_deployments.loc[date_range].index], 'yv',
                     label='Model Deployment'
                     )
            plt.legend()
            ax.set_yticks([])  # remove the tick labels
            ax.set_yticklabels([])
            # plt.title(f'{column} with {resolution}s, {metric}, {strategy}')

            file = f'vis_{resolution}_{metric}_{strategy}_{date_range.start}_{date_range.stop}.pdf'
            fig.savefig(os.path.join(out_dir, file), pad_inches=0.01)
            plt.show()

% matplotlib inline

In [None]:
def analyze_result(
        result: SimulatorResult,
        resolution: int,
        metric: str,
        strategy: str,
        window: int,
        col: str,
        dir: str):
    """Create and save diagnostic plots for a single simulation run.

    Four PNG figures are written to ``dir``:

    - a histogram of the time until threshold violations, with the average in the title;
    - the same histogram grouped by month;
    - the rolling MAE of ``col`` over ``window`` days;
    - the rolling number of threshold violations per day over ``window`` days.

    :param result: result of a single simulation run
    :param resolution: stride of the run in seconds (used in titles and file names)
    :param metric: threshold metric of the run (used in titles and file names)
    :param strategy: continual strategy of the run (used in titles and file names)
    :param window: length of the rolling windows in days
    :param col: data column the rolling MAE is computed on
    :param dir: output directory (the name shadows the builtin ``dir`` but is kept for callers)
    """
    from datetime import timedelta

    # threshold violation distribution
    # time until each threshold violation; the month grouping below expects a datetime-like index
    tv_durations = result.compute_time_until_threshold_violations()
    # Floor to whole minutes / hours by dividing by a Timedelta. The former
    # ``astype('timedelta64[m]')`` / ``astype('timedelta64[h]')`` yielded the same floored values on
    # pandas < 2, but raises on pandas >= 2, which only supports s/ms/us/ns resolutions.
    tv_minutes = tv_durations // pd.Timedelta(minutes=1)
    tv_hours = tv_durations // pd.Timedelta(hours=1)
    avg_duration = tv_minutes.mean()

    fig: plt.Figure = plt.figure(figsize=(11, 7))
    # NOTE(review): the bin edges 1..23 leave out durations below 1 h and above 23 h — confirm intended
    ax = tv_hours.plot.hist(bins=range(1, 24))
    ax.set_title(
        f'Threshold Violations Distribution, avg. {int(avg_duration)} min ({resolution}s, {metric}, {strategy})'
    )
    ax.set_xlabel('Elapsed Time [h]')
    fig.savefig(os.path.join(dir, f'threshold_violations_histogram_{resolution}s_{metric}_{strategy}.png'))
    plt.show()

    # distribution grouped by month (3 x 4 grid, one panel per month present in the data)
    grouped = tv_hours.groupby(tv_durations.index.month)
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(16, 12), sharey='all', sharex='all')
    for (key, ax) in zip(grouped.groups.keys(), axes.flatten()):
        grouped.get_group(key).hist(ax=ax, bins=range(1, 24))
        ax.set_title(key)
    fig.suptitle(f'Threshold Violations Distribution ({resolution}s, {metric}, {strategy})')
    fig.savefig(os.path.join(dir, f'threshold_violations_grouped_{resolution}s_{metric}_{strategy}.png'))
    plt.show()

    # rolling MAE over column with window
    fig: plt.Figure = plt.figure(figsize=(21, 7))
    ax = result.data.get_diff().loc[:, col].abs().rolling(timedelta(days=window)).mean().plot()
    ax.set_title(f'Rolling MAE ({col}, {window} days)')
    fig.savefig(os.path.join(dir, f'rolling_MAE_{col}_{resolution}s_{metric}_{strategy}.png'))
    plt.show()

    # rolling number of threshold violations per day (counts on the first measurement column)
    tv_measurements = result.threshold_violations.get_measurements()
    fig: plt.Figure = plt.figure(figsize=(21, 7))
    ax = tv_measurements.iloc[:, 0].resample('D').count().rolling(timedelta(days=window)).mean().plot()
    ax.set_title(f'Rolling Threshold Violations per Day ({window} days)')
    fig.savefig(os.path.join(dir, f'rolling_threshold_violations_{resolution}s_{metric}_{strategy}.png'))
    plt.show()

In [None]:
# detailed analysis of one run: 3600 s stride, TL_low threshold metric, static strategy
resolution = 3600
metric = 'TL_low'
strategy = 'static'
selected_result = results[resolution][metric][strategy]
analyze_result(selected_result, resolution, metric, strategy, window=7, col='TL', dir=out_dir)