In [None]:
import sys
sys.path.append('../src')
import numpy as np
import os
from evaluation import load_dataset, load_results, compute_scores

In [None]:
def print_scores(folder_data, name_data, folder_result, filename_base='result_{}.h5', num_tests=5):
    """Print the mean and standard deviation of each score over several runs.

    The labels are loaded once via ``load_dataset(folder_data, name_data)``.
    For every run index ``i`` in ``range(num_tests)`` the result file
    ``filename_base.format(i)`` is read from ``folder_result`` with
    ``load_results`` and scored with ``compute_scores(..., is_ordered=False)``.

    Before aggregation, 'AMI' and 'ARI' are multiplied by 1e2 and 'RMSE' by
    1e1 (matching the ``(%)`` and ``(e-1)`` tags in the printed report);
    'LL_M' and 'LL_S' are used as returned.

    Args:
        folder_data: Folder passed to ``load_dataset``.
        name_data: Dataset name passed to ``load_dataset``.
        folder_result: Folder passed to ``load_results``.
        filename_base: Template with one ``{}`` placeholder that receives the
            run index.
        num_tests: Number of runs (result files) to aggregate.

    Returns:
        None. The report is written to stdout as a single line.
    """
    # Multiplier applied to each per-run score; None keeps the raw value.
    # Key order also fixes the order in which scores are read.
    scale_by_metric = {'LL_M': None, 'LL_S': None, 'AMI': 1e2, 'ARI': 1e2, 'RMSE': 1e1}
    # Metric -> format template for "mean std", in print order.
    report_layout = [
        ('LL_M', 'LL_M:{:7.5g} {:.0e}'),
        ('LL_S', 'LL_S:{:7.5g} {:.0e}'),
        ('AMI', 'AMI(%):{:5.3g} {:.0e}'),
        ('ARI', 'ARI(%):{:5.3g} {:.0e}'),
        ('RMSE', 'RMSE(e-1):{:4.2g} {:.0e}'),
    ]

    # Only the labels are needed for scoring; the images are discarded.
    _, labels = load_dataset(folder_data, name_data)

    collected = {metric: [] for metric in scale_by_metric}
    for run_idx in range(num_tests):
        run_results = load_results(folder_result, filename_base.format(run_idx))
        run_scores = compute_scores(run_results, labels, is_ordered=False)
        for metric, factor in scale_by_metric.items():
            value = run_scores[metric]
            collected[metric].append(value if factor is None else value * factor)

    # np.std uses its default settings here (population standard deviation).
    pieces = [
        template.format(np.mean(collected[metric]), np.std(collected[metric]))
        for metric, template in report_layout
    ]
    # Fields are separated by three spaces and terminated by one newline.
    print((' ' * 3).join(pieces))

# Dataset folder handed to print_scores (which forwards it to load_dataset)
# by the evaluation cells below.
folder_data = '../data'

In [None]:
# Multi-Shapes: report scores for the 20x20 and 28x28 shape experiments.
for image_size in ('20', '28'):
    folder_result = 'shapes_{}x{}'.format(image_size, image_size)
    # The 20x20 dataset shares the result folder's name; the 28x28 dataset
    # name carries an extra '_3' suffix.
    if image_size == '20':
        name_data = folder_result
    else:
        name_data = folder_result + '_3'
    print(folder_result)
    print_scores(folder_data, name_data, folder_result)

In [None]:
# Multi-MNIST: report scores for each MNIST variant setting.
for variants in ('20', '500', 'all'):
    # Dataset name and result folder are the same string for these runs.
    name_data = folder_result = 'mnist_{}'.format(variants)
    print(folder_result)
    print_scores(folder_data, name_data, folder_result)

In [None]:
# Generalization: score the 28x28 shapes results on datasets with 2, 3 and 4 objects.
folder_result = 'shapes_28x28'
for num_objects in (2, 3, 4):
    print('{} Objects'.format(num_objects))
    name_data = '{}_{}'.format(folder_result, num_objects)
    if num_objects == 3:
        # The 3-object case reuses the default result file names.
        filename_base = 'result_{}.h5'
    else:
        # Other object counts use 'general_<n>_result_{}.h5'; the doubled
        # braces leave a '{}' placeholder for the run index.
        filename_base = 'general_{}_result_{{}}.h5'.format(num_objects)
    print_scores(folder_data, name_data, folder_result, filename_base)