# Comparisons of Models

This notebook contains functions for loading the metric logs and printing them as readable, aligned tables.

In [1]:
import os
import errno
import warnings
import numpy as np

In [2]:
def load_metric(filepath, start_step=0):
    """Load a metric file and return its mean score matrices.

    The file is expected to be a ``.npy`` file holding a single pickled
    dictionary (stored by NumPy as a 0-d object array) that contains at
    least the keys ``'score_matrix_mean'`` and ``'score_pair_matrix_mean'``.

    Parameters
    ----------
    filepath : str
        Path to the ``.npy`` metric file.
    start_step : int, optional
        Currently unused; kept so that existing callers keep working.

    Returns
    -------
    tuple
        ``(score_matrix_mean, score_pair_matrix_mean)`` exactly as stored in
        the file.

    Raises
    ------
    KeyError
        If either expected key is missing from the stored dictionary.

    Notes
    -----
    Recent NumPy releases refuse to load object arrays unless
    ``allow_pickle=True`` is passed explicitly. Unpickling can execute
    arbitrary code, so only load metric files from a trusted source.
    """
    data = np.load(filepath, allow_pickle=True)
    # `data` is a 0-d object array; indexing with `()` unwraps the dict.
    metrics = data[()]
    return (metrics['score_matrix_mean'], metrics['score_pair_matrix_mean'])

In [3]:
# Display names of the instrument tracks.
track_names = (
    'Drums',
    'Piano',
    'Guitar',
    'Bass',
    'Ensemble',
    'Reed',
    'Synth Lead',
    'Synth Pad',
)

# Display names of the intratrack metrics. A name is matched to a row of the
# score matrix purely by position, so the order here matters.
# NOTE(review): confirm this list covers every row of the score matrix.
metric_names = (
    'empty bar rate',
    '# of pitch used',
    'qualified note rate',
    'polyphonicity',
    'note in scale',
    'drum in pattern rate',
)

In [4]:
# Directory that holds one evaluation result file (.npy) per model.
data_dir = './data/eval_test/'

# Evaluation result file for the training data; it is shown as the
# 'training data' row of every table.
traing_data_eval_path = ('./data/training_data/'
                         'lastfm_alternative_8b_phrase.npy')

In [5]:
# (row label, result file name) pairs, in the order the rows are printed.
# Each file name is relative to `data_dir`. Labels ending in (BS)/(+SBNs) pair
# with the `*_bernoulli.npy` files, labels ending in (HT)/(+DBNs) with the
# `*_round.npy` files.
filenames = [
    # Pretrained models.
    ('pretrained (BS)',
     "lastfm_alternative_pretrain_g_proposed_d_proposed_test_time_bernoulli.npy"),
    ('pretrained (HT)',
     "lastfm_alternative_pretrain_g_proposed_d_proposed_test_time_round.npy"),
    # Proposed models.
    ('proposed (+SBNs)',
     "lastfm_alternative_train_g_proposed_d_proposed_r_proposed_bernoulli.npy"),
    ('proposed (+DBNs)',
     "lastfm_alternative_train_g_proposed_d_proposed_r_proposed_round.npy"),
    # Jointly trained models.
    ('joint (+SBNs)',
     "lastfm_alternative_train_joint_g_proposed_d_proposed_r_proposed_bernoulli.npy"),
    ('joint (+DBNs)',
     "lastfm_alternative_train_joint_g_proposed_d_proposed_r_proposed_round.npy"),
    # End-to-end models.
    ('end2end (+SBNs)',
     "lastfm_alternative_end2end_g_proposed_small_d_proposed_r_proposed_bernoulli.npy"),
    ('end2end (+DBNs)',
     "lastfm_alternative_end2end_g_proposed_small_d_proposed_r_proposed_round.npy"),
    # Ablated discriminator.
    ('ablated (BS)',
     "lastfm_alternative_pretrain_g_proposed_d_ablated_test_time_bernoulli.npy"),
    ('ablated (HT)',
     "lastfm_alternative_pretrain_g_proposed_d_ablated_test_time_round.npy"),
    # Baseline discriminator.
    ('baseline (BS)',
     "lastfm_alternative_pretrain_g_proposed_d_baseline_test_time_bernoulli.npy"),
    ('baseline (HT)',
     "lastfm_alternative_pretrain_g_proposed_d_baseline_test_time_round.npy"),
]

In [6]:
# Build the list of (row label, (intratrack scores, intertrack scores)) pairs.
# The training data comes first, followed by every model in `filenames` order.
score_list = [('training data', load_metric(traing_data_eval_path))]
score_list.extend(
    (label, load_metric(os.path.join(data_dir, basename)))
    for label, basename in filenames
)

In [7]:
def print_metric_table(m):
    """Print a table with one row per entry of `score_list` for metric `m`.

    A three-line banner is printed first, with `metric_names[m]` centered in
    the middle line. Each following row shows the entry's label and the mean
    of row `m` of its intratrack score matrix, with NaN values ignored.

    Parameters
    ----------
    m : int
        Index of the metric; used both for `metric_names` and for the first
        axis of each score matrix.
    """
    rule = '=' * 30
    title = "{:=^30}".format(' ' + metric_names[m] + ' ')
    print(rule + "\n" + title + "\n" + rule)
    for label, scores in score_list:
        intratrack_scores = scores[0]
        with warnings.catch_warnings():
            # np.nanmean emits a RuntimeWarning when every value is NaN.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean = np.nanmean(intratrack_scores[m, :])
        print("{:24} {:5.2f}".format(label, mean))

## Intratrack Evaluation metrics

In [8]:
# Print one table per named intratrack metric. The range is derived from
# `metric_names` rather than hard-coded, so the loop can never index past the
# last name (a fixed count larger than the number of names raises IndexError).
# Rows of the score matrix without a matching name are not printed.
for m in range(len(metric_names)):
    print_metric_table(m)

training data             0.57
pretrained (BS)           0.53
pretrained (HT)           0.53
proposed (+SBNs)          0.32
proposed (+DBNs)          0.62
joint (+SBNs)             0.02
joint (+DBNs)             0.26
end2end (+SBNs)           0.49
end2end (+DBNs)           0.51
ablated (BS)              0.51
ablated (HT)              0.51
baseline (BS)             0.33
baseline (HT)             0.33
training data             4.66
pretrained (BS)           3.52
pretrained (HT)           3.45
proposed (+SBNs)         14.75
proposed (+DBNs)          9.73
joint (+SBNs)            15.60
joint (+DBNs)             7.56
end2end (+SBNs)           5.21
end2end (+DBNs)           5.93
ablated (BS)              3.13
ablated (HT)              3.10
baseline (BS)             3.92
baseline (HT)             3.73
==== qualified note rate =====
training data             0.88
pretrained (BS)           0.67
pretrained (HT)           0.72
proposed (+SBNs)          0.42
proposed (+DBNs)          0.78
joint (+

## Intertrack Evaluation metrics

In [9]:
# Print the intertrack table: for every entry of `score_list`, the mean over
# the whole pairwise score matrix, with NaN values ignored.
rule = '=' * 30
title = "{:=^30}".format(' tonal distance ')
print(rule + "\n" + title + "\n" + rule)
for label, scores in score_list:
    pair_scores = scores[1]
    with warnings.catch_warnings():
        # np.nanmean emits a RuntimeWarning when every value is NaN.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        tonal_distance = np.nanmean(pair_scores)
    print("{:24} {:5.2f}".format(label, tonal_distance))

training data             0.96
pretrained (BS)           0.98
pretrained (HT)           1.00
proposed (+SBNs)          0.99
proposed (+DBNs)          0.87
joint (+SBNs)             0.95
joint (+DBNs)             1.03
end2end (+SBNs)           1.41
end2end (+DBNs)           1.10
ablated (BS)              1.00
ablated (HT)              1.01
baseline (BS)             1.33
baseline (HT)             1.35
