In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [2]:
import os
import wandb
from tqdm.notebook import tqdm

def download_model_residuals():
    api = wandb.Api()

    src_datsets = ['src', 'bpm10_src', 'bike11_src']
    # fitler out the pre-training runs
    filters = {
        "$and": [
            {"config.train_dataset": {"$nin": src_datsets}}
        ]
    }

    runs = api.runs("transfer-learning-tcn/tcn_custom_loss_residuals",
                    filters=filters,
                    order="-created_at")

    config_list = []
    runs_df = pd.DataFrame()
    residuals = dict()
    if os.path.exists("tcn_custom_loss_residuals.npz"):
        residuals.update(**np.load("tcn_custom_loss_residuals.npz", allow_pickle=True))
        runs_df = pd.read_csv("tcn_custom_loss_residuals.csv")
    for run in tqdm(runs):
        global_id = run.name + '-' + run.id
        if global_id in residuals:
            continue
        # .summary contains the output keys/values for metrics like accuracy.
        #  We call ._json_dict to omit large files
        summary = {k: v for k, v in run.summary._json_dict.items()
                   if not k.startswith('_') and
                   not type(v) is list and
                   not type(v) is dict}
        # .config contains the hyperparameters.
        #  We remove special values that start with _.
        conf = {k: v for k,v in run.config.items()
                if not k.startswith('_')}
        conf.update(summary)
        conf['timestamp'] = run.summary['_timestamp']
        # .name is the human-readable name of the run.
        conf['wandb_name'] = run.name
        conf['wandb_id'] = run.id
        conf['global_id'] = global_id
        conf['tags'] = ",".join(run.tags)
        config_list.append(conf)
        # download the residuals
        res = wandb.restore('test_test_residuals.npy', run_path="/".join(run.path), replace=True)
        residuals[global_id] = np.load(res.name)

    runs_df = pd.concat([runs_df, pd.DataFrame(config_list)], axis=0)
    return runs_df, residuals

In [3]:
def download_sota_tl_results():
    api = wandb.Api()

    filters = {
        "$and": [
            {"state": 'finished'},
            {"tags": {"$in": ["normalized_labels"]}}
        ]
    }

    def parse_runs(wandb_proj, previous_ids):
        runs = api.runs(wandb_proj,
                    filters=filters,
                    order="-created_at")
        config_list = []
        for run in tqdm(runs):

            global_id = run.name + '-' + run.id
            if global_id in previous_ids:
                continue
            # .summary contains the output keys/values for metrics like accuracy.
            #  We call ._json_dict to omit large files
            summary = {k: v for k, v in run.summary._json_dict.items()
                       if not k.startswith('_') and
                       not type(v) is list and
                       not type(v) is dict}
            # .config contains the hyperparameters.
            #  We remove special values that start with _.
            conf = {k: v for k,v in run.config.items()
                    if not k.startswith('_')}
            conf.update(summary)
            conf['timestamp'] = run.summary['_timestamp']
            # .name is the human-readable name of the run.
            conf['wandb_name'] = run.name
            conf['wandb_id'] = run.id
            conf['global_id'] = global_id
            conf['tags'] = ",".join(run.tags)

            # store evaluation performance under default column test/mse(mae)
            if 'test/mse' not in conf:
                conf['test/mse'] = conf['test/loss']

            config_list.append(conf)
        return pd.DataFrame(config_list)

    runs_df = pd.DataFrame()
    prev_ids = []
    if os.path.exists("sota_results.csv"):
        runs_df = pd.read_csv("sota_results.csv")
        prev_ids = list(runs_df['global_id'])

    runs_wann = parse_runs("transfer-learning-tcn/wann", prev_ids)
    runs_trb = parse_runs("transfer-learning-tcn/tradaboostr2-nn", prev_ids)
    runs_wann['approach'] = 'WANN'
    runs_trb['approach'] = 'TRB'
    runs_df = pd.concat([runs_df, runs_wann, runs_trb], axis=0)
    return runs_df

In [4]:
def approach_rename(inputs):
    name_convertion = dict(
        MI='MI',
        MSE='MSL',
        MAE='MAL',
        MEE='MEE',
        HSIC='HSIC',
        NAN='baseline'
    )
    name = "-".join([name_convertion[str(n).upper()] for n in inputs])
    return name

## Download the results

In [16]:
# TODO: run the download again to combine all the results
download = False
if download:
    residual_df, residuals = download_model_residuals()
    residual_df['approach'] = residual_df[['loss_src', 'loss_function']].apply(approach_rename, axis=1)
    # append sota result to the existing df
    sota_df = download_sota_tl_results()

    # save wandb downloaded data
    np.savez("tcn_custom_loss_residuals", **residuals)
    residual_df.to_csv("tcn_custom_loss_residuals.csv")
    sota_df.to_csv("sota_results.csv")
else:
    residual_df = pd.read_csv("tcn_custom_loss_residuals.csv")
    residuals = np.load("tcn_custom_loss_residuals.npz", allow_pickle=True)
    sota_df = pd.read_csv("sota_results.csv")
residual_df = pd.concat([residual_df, sota_df], axis=0).reset_index(drop=True)

In [17]:
sum(~residual_df['tags'].isna())

505

In [18]:
# filter linear pruning results to separate dataframe
linear_prune_df = residual_df[~residual_df['tags'].isna() & residual_df['tags'].str.contains("linear_pruning")]
residual_df = residual_df.drop(linear_prune_df.index, axis="index")
# append baselines to the linear_prune df
linear_prune_df = pd.concat([linear_prune_df, residual_df[residual_df['approach'].str.contains('baseline')]])
linear_prune_df = linear_prune_df.reset_index(drop=True)

In [19]:
residual_df['approach'].value_counts()

baseline-MEE     105
baseline-HSIC    105
baseline-MAL     105
baseline-MSL     105
MSL-MAL          100
MAL-MSL          100
HSIC-MSL         100
MSL-MEE          100
HSIC-HSIC        100
MEE-MSL          100
MEE-MEE          100
MAL-MAL          100
MSL-HSIC         100
MSL-MSL          100
WANN              55
TRB               50
Name: approach, dtype: int64

In [20]:
linear_prune_df['approach'].value_counts()

baseline-MEE     105
baseline-HSIC    105
baseline-MAL     105
baseline-MSL     105
MSL-HSIC         100
MSL-MAL          100
MSL-MSL          100
MSL-MEE          100
Name: approach, dtype: int64

# Table overview of the results

In [35]:
from scipy.stats import wilcoxon
import re
def table_sota_comparison(result_df, methods, main, metric='test/mse'):
    all_methods = methods + [main]
    res = result_df[['approach', 'train_dataset', metric, 'seed']][result_df['approach'].isin(all_methods)]

    # renaming datasets
    new_ds_names = dict(
        bike11_tar="BKT",
        bpm10_tar="PMT",
        tar1="NT1",
        tar2="NT2",
        tar3="NT3",
    )
    res['train_dataset'] = res['train_dataset'].apply(lambda n: new_ds_names[n])
    res = res.sort_values('train_dataset')

    datasets = res['train_dataset'].unique()
    res = res[res['seed'] < 20] # remove extra seed of the baselines
    res = pd.pivot_table(res, values=metric, index=['train_dataset', 'seed'], columns=['approach'])

    # res = res.pivot(index='seed', columns='approach', values=['test/mse', 'train_dataset'])
    grouped = res.groupby('train_dataset').agg(['mean', 'std'])
    # latex_table = grouped[all_methods].to_latex(float_format="%.2g")
    latex_table = grouped[all_methods].style.format('{:.2g}').to_latex()

    latex_table = re.sub("&([\s0-9\-.+e]+)&([\s0-9\-.+e]+)", r"&\1$\\pm$\2", latex_table)

    print(latex_table)

    # methods = ['ols', 'dsft', 'dsft_nl', 'dp']
    # main = 'dp'
    # methods.remove(main)
    p_val_adjustment = len(methods)

    def wilcox(m1, m2):
        stat, p = wilcoxon(m1, m2)
        mean1 = np.mean(m1)
        mean2 = np.mean(m2)
        if p >= 0.05 / p_val_adjustment:
        # if p >= 0.05:
            test_result = '?'
        elif mean1 < mean2:
            test_result = 'v'
        else:
            test_result = 'x'
        return test_result

    def test_all_methods(results, test):
        return [test(results[main], results[m]) for m in methods]
    res = res.reset_index()
    tresults = [[d] + test_all_methods(res[res['train_dataset']==d], wilcox) for d in datasets]

    columns = ['dataset'] + [f"{main} x {m}" for m in methods]
    print(pd.DataFrame(tresults, columns=columns))

## Comparing methods for pre-training

In [28]:
table_sota_comparison(residual_df , ['MSL-MSL', 'MAL-MSL', 'HSIC-MSL'], main='MEE-MSL')

\begin{tabular}{lrrrrrrrr}
approach & \multicolumn{2}{r}{MSL-MSL} & \multicolumn{2}{r}{MAL-MSL} & \multicolumn{2}{r}{HSIC-MSL} & \multicolumn{2}{r}{MEE-MSL} \\
 & mean & std & mean & std & mean & std & mean & std \\
train_dataset &  $\pm$  &  $\pm$  &  $\pm$  &  $\pm$  \\
BKT & 0.16 $\pm$ 0.015 & 0.16 $\pm$ 0.014 & 0.25 $\pm$ 0.044 & 0.16 $\pm$ 0.012 \\
NT1 & 0.47 $\pm$ 0.0083 & 0.48 $\pm$ 0.0097 & 0.48 $\pm$ 0.0099 & 0.47 $\pm$ 0.012 \\
NT2 & 0.59 $\pm$ 0.013 & 0.59 $\pm$ 0.012 & 0.61 $\pm$ 0.017 & 0.59 $\pm$ 0.012 \\
NT3 & 0.45 $\pm$ 0.0072 & 0.45 $\pm$ 0.0062 & 0.46 $\pm$ 0.0054 & 0.45 $\pm$ 0.0073 \\
PMT & 0.44 $\pm$ 0.042 & 0.45 $\pm$ 0.046 & 0.46 $\pm$ 0.041 & 0.45 $\pm$ 0.033 \\
\end{tabular}

  dataset MEE-MSL x MSL-MSL MEE-MSL x MAL-MSL MEE-MSL x HSIC-MSL
0     BKT                 ?                 ?                  v
1     NT1                 ?                 ?                  ?
2     NT2                 ?                 ?                  v
3     NT3                 ?   

## Comparing methods for fine-tuning

In [29]:
table_sota_comparison(residual_df , ['MSL-MSL', 'MSL-MAL', 'MSL-HSIC'], main='MSL-MEE')

\begin{tabular}{lrrrrrrrr}
approach & \multicolumn{2}{r}{MSL-MSL} & \multicolumn{2}{r}{MSL-MAL} & \multicolumn{2}{r}{MSL-HSIC} & \multicolumn{2}{r}{MSL-MEE} \\
 & mean & std & mean & std & mean & std & mean & std \\
train_dataset &  $\pm$  &  $\pm$  &  $\pm$  &  $\pm$  \\
BKT & 0.16 $\pm$ 0.015 & 0.17 $\pm$ 0.018 & 0.16 $\pm$ 0.018 & 0.16 $\pm$ 0.012 \\
NT1 & 0.47 $\pm$ 0.0083 & 0.48 $\pm$ 0.0081 & 0.48 $\pm$ 0.008 & 0.46 $\pm$ 0.0074 \\
NT2 & 0.59 $\pm$ 0.013 & 0.58 $\pm$ 0.012 & 0.58 $\pm$ 0.012 & 0.59 $\pm$ 0.0089 \\
NT3 & 0.45 $\pm$ 0.0072 & 0.45 $\pm$ 0.0065 & 0.45 $\pm$ 0.0067 & 0.44 $\pm$ 0.0058 \\
PMT & 0.44 $\pm$ 0.042 & 0.44 $\pm$ 0.038 & 0.47 $\pm$ 0.034 & 0.46 $\pm$ 0.024 \\
\end{tabular}

  dataset MSL-MEE x MSL-MSL MSL-MEE x MSL-MAL MSL-MEE x MSL-HSIC
0     BKT                 ?                 ?                  ?
1     NT1                 v                 v                  v
2     NT2                 ?                 x                  x
3     NT3                 v  

## Comparing methods for pre-training + fine-tuning
### MSE

In [36]:
table_sota_comparison(residual_df , ['baseline-MEE', 'MSL-MSL', 'MAL-MAL', 'HSIC-HSIC'], main='MEE-MEE', metric='test/mse')

\begin{tabular}{lrrrrrrrrrr}
approach & \multicolumn{2}{r}{baseline-MEE} & \multicolumn{2}{r}{MSL-MSL} & \multicolumn{2}{r}{MAL-MAL} & \multicolumn{2}{r}{HSIC-HSIC} & \multicolumn{2}{r}{MEE-MEE} \\
 & mean & std & mean & std & mean & std & mean & std & mean & std \\
train_dataset &  $\pm$  &  $\pm$  &  $\pm$  &  $\pm$  &  $\pm$  \\
BKT & 0.41 $\pm$ 0.028 & 0.16 $\pm$ 0.015 & 0.16 $\pm$ 0.014 & 0.26 $\pm$ 0.065 & 0.16 $\pm$ 0.014 \\
NT1 & 0.46 $\pm$ 0.007 & 0.47 $\pm$ 0.0083 & 0.48 $\pm$ 0.01 & 0.48 $\pm$ 0.0085 & 0.45 $\pm$ 0.0075 \\
NT2 & 0.59 $\pm$ 0.019 & 0.59 $\pm$ 0.013 & 0.57 $\pm$ 0.012 & 0.59 $\pm$ 0.013 & 0.59 $\pm$ 0.013 \\
NT3 & 0.45 $\pm$ 0.0057 & 0.45 $\pm$ 0.0072 & 0.45 $\pm$ 0.0068 & 0.45 $\pm$ 0.0081 & 0.44 $\pm$ 0.0053 \\
PMT & 0.52 $\pm$ 0.036 & 0.44 $\pm$ 0.042 & 0.45 $\pm$ 0.049 & 0.52 $\pm$ 0.039 & 0.47 $\pm$ 0.021 \\
\end{tabular}

  dataset MEE-MEE x baseline-MEE MEE-MEE x MSL-MSL MEE-MEE x MAL-MAL  \
0     BKT                      v                 ?            

### MAE

In [None]:
table_sota_comparison(residual_df , ['baseline-MEE', 'MSL-MSL', 'MAL-MAL', 'HSIC-HSIC'], main='MEE-MEE', metric='test/mae')

## Comparison on target-only (no TL)
#### MSE

In [None]:
table_sota_comparison(residual_df , ['baseline-MSL', 'baseline-MAL', 'baseline-HSIC'], main='baseline-MEE', metric='test/mse')

### MAE

In [None]:
table_sota_comparison(residual_df , ['baseline-MSL', 'baseline-MAL', 'baseline-HSIC'], main='baseline-MEE', metric='test/mae')

## Linear probing comparison

In [38]:
table_sota_comparison(linear_prune_df, ['MSL-MSL', 'MSL-MAL', 'MSL-HSIC'], main='MSL-MEE', metric='test/mse')

\begin{tabular}{lrrrrrrrr}
approach & \multicolumn{2}{r}{MSL-MSL} & \multicolumn{2}{r}{MSL-MAL} & \multicolumn{2}{r}{MSL-HSIC} & \multicolumn{2}{r}{MSL-MEE} \\
 & mean & std & mean & std & mean & std & mean & std \\
train_dataset &  $\pm$  &  $\pm$  &  $\pm$  &  $\pm$  \\
BKT & 0.25 $\pm$ 0.029 & 0.26 $\pm$ 0.04 & 0.26 $\pm$ 0.031 & 0.24 $\pm$ 0.031 \\
NT1 & 0.87 $\pm$ 0.016 & 0.9 $\pm$ 0.0092 & 0.89 $\pm$ 0.025 & 0.87 $\pm$ 0.018 \\
NT2 & 0.48 $\pm$ 0.024 & 0.48 $\pm$ 0.029 & 0.54 $\pm$ 0.047 & 0.49 $\pm$ 0.025 \\
NT3 & 0.72 $\pm$ 0.022 & 0.73 $\pm$ 0.018 & 0.73 $\pm$ 0.015 & 0.7 $\pm$ 0.0092 \\
PMT & 0.57 $\pm$ 0.039 & 0.56 $\pm$ 0.036 & 0.63 $\pm$ 0.038 & 0.54 $\pm$ 0.034 \\
\end{tabular}

  dataset MSL-MEE x MSL-MSL MSL-MEE x MSL-MAL MSL-MEE x MSL-HSIC
0     BKT                 v                 v                  v
1     NT1                 v                 v                  v
2     NT2                 ?                 ?                  v
3     NT3                 v           

## Comparison with SOTA TL methods

In [None]:
table_sota_comparison(residual_df[residual_df['seed'] <= 10], ['MSL-MSL', 'TRB', 'WANN'], main='MEE-MEE')

In [None]:
residual_df[residual_df['approach']=='baseline-MAL']['train_dataset'].value_counts()

In [None]:
residual_df['approach'].value_counts().head(20)

In [None]:
len(residual_df)

In [None]:
# wandb_data = residual_df

In [None]:
def merge_residuals_and_metadata(df, res_dict, filter=None):
    if filter is not None:
        df = df[filter]
    res_list = [res_dict[name] for name in df['wandb_name']]
    lengths = [len(x) for x in res_list]
    meta = np.repeat(np.array(df[['approach', 'train_dataset']]), lengths, axis=0)
    res_np = np.concatenate(res_list)
    merged = pd.DataFrame(np.concatenate([meta, res_np], axis=1), columns=['approach', 'train_dataset', 'residuals'])
    merged.sort_values(by=['approach'])
    return merged

In [None]:
def filter_approaches(dataframe, dataset='tar1', approaches=[]):
    filter = dataframe['train_dataset'].str.contains(dataset)
    if len(approaches) == 0:
        return filter
    ap_filter = [False for _ in range(len(filter))]
    for ap in approaches:
        ap_filter = ap_filter | dataframe['approach'].str.contains(ap)
    return filter & ap_filter

In [None]:
sns.set_context('talk')
sns.set_style('whitegrid')
def plot_boxplots(wandb_df, dataset_name):

    filter_df = wandb_df[wandb_df['train_dataset'].str.contains(dataset_name)]
    filtered_approaches = ['HSIC-MSL', 'MEE-MSL']
    filter_df = filter_df[~filter_df['approach'].isin(filtered_approaches)]
    palette = sns.color_palette("Paired", len(filter_df['approach'].unique()))
    fig, axes = plt.subplots(1, 2, figsize=(18, 6))
    sns.boxplot(
        x='approach', y='test/mse',
        data=filter_df.sort_values('approach'),
        ax=axes[0],
        showmeans=True,
        meanprops={"marker": "x", "markeredgecolor": "black"},
        palette=palette)
    sns.boxplot(
        x='approach', y='test/mae',
        data=filter_df.sort_values('approach'),
        ax=axes[1],
        showmeans=True,
        meanprops={"marker": "x", "markeredgecolor": "black"},
        palette=palette)
    for i, ax in enumerate(axes):
        plt.sca(ax)
        plt.xticks(rotation=45, ha='right')
        plt.xlabel("Train/finetune loss function")
        if i == 0:
            plt.ylabel("Test MSE")
        else:
            plt.ylabel("Test MAE")
    plt.savefig(f"plots/{dataset_name}_boxplots.pdf", bbox_inches='tight', format='pdf', dpi=300)

def plot_histograms(wandb_df, res_dict, dataset_name):
    fig, axes = plt.subplots(1, 2, figsize=(18,6))

    approach_order = ['baseline-MSL', 'baseline-MEE']
    filter = filter_approaches(wandb_df, dataset=dataset_name, approaches=approach_order)
    merged = merge_residuals_and_metadata(wandb_df, res_dict, filter=filter)
    sns.histplot(data=merged, x='residuals', hue='approach', hue_order=approach_order, ax=axes[0])

    approach_order = ['MSL-MSL', 'MSL-MEE']
    filter = filter_approaches(wandb_df, dataset=dataset_name, approaches=approach_order)
    merged = merge_residuals_and_metadata(wandb_df, res_dict, filter=filter)
    sns.histplot(data=merged, x='residuals', hue='approach', hue_order=approach_order, ax=axes[1])

    plt.savefig(f"plots/{dataset_name}_histograms.png", bbox_inches='tight')

# CMAPSS target dataset 1

In [None]:
nasa_tar1_filter = residual_df['train_dataset'].str.contains('tar1')
nasa_tar1_results = residual_df[nasa_tar1_filter]
# wandb_data = wandb_data[wandb_data['tcn2'] == False]
# wandb_data = wandb_data[wandb_data['State'] == 'finished']
# nasa_tar1_results = nasa_tar1_results[~nasa_tar1_results['group'].str.contains('src')] # remove source network results

In [None]:
plot_boxplots(residual_df, "tar1")

In [None]:
# nasa_tar1_results.groupby('approach')['test/mse'].agg(['mean', 'std', 'median'])

In [None]:
plot_histograms(residual_df, residuals, 'tar1')

----------------------------------------------------------
# CMAPSS target dataset 2

In [None]:
# plot_boxplots(residual_df[residual_df['approach'] != 'baseline-MEE'], "tar2")
plot_boxplots(residual_df, "tar2")

In [None]:
# nasa_tar2_results.groupby('approach')['test/mse'].agg(['mean', 'std', 'median'])

In [None]:
plot_histograms(residual_df, residuals, 'tar2')

------------------
## CMAPSS target dataset 3

In [None]:
plot_boxplots(residual_df, "tar3")

In [None]:
# nasa_tar3_results = residual_df[residual_df['train_dataset'].str.contains('tar3')]
# nasa_tar3_results.groupby('approach')['test/mse'].agg(['mean', 'std', 'median'])

In [None]:
plot_histograms(residual_df, residuals, 'tar3')

------------------
# BPM10 dataset

In [None]:
# bpm10_tar_results = residual_df[residual_df['train_dataset'].str.contains('bpm10')]
# bpm10_tar_results.groupby('approach')['test/mse'].agg(['mean', 'std', 'median'])

In [None]:
plot_boxplots(residual_df, "bpm10_tar")

In [None]:
plot_histograms(residual_df, residuals, 'bpm10_tar')

-----------------------------------
# Bike rental 2011

In [None]:
bike11_results = residual_df[residual_df['train_dataset'].str.contains('bike11_tar')]
bike11_results.groupby('approach')['test/mse'].agg(['mean', 'std', 'median'])

In [None]:
plot_boxplots(residual_df, "bike11_tar")

In [None]:
plot_histograms(residual_df, residuals, 'bike11_tar')