# CelebA: result table

## Setup

In [3]:
import sys
import numpy as np
from itertools import product


## Load results

In [4]:
# Number of runs loaded per algorithm. `load` looks up run 0 without a
# filename suffix and run k > 0 with the suffix "_run{k}".
n_runs = 3

In [5]:
# Parallel registries: entry i of each list describes the same algorithm.
# `load` combines filename, dim and additional tag into the result file name;
# the label is the row title of the result table.
algo_filenames, algo_additionals, algo_dims, algo_labels = [], [], [], []

# Row indices in front of which the result table prints a block divider.
algo_dividers = []


def add_algo(filename, add, label, dim=512):
    """Register one algorithm in the parallel ``algo_*`` lists.

    Args:
        filename: Leading part of the algorithm's result file names.
        add: Additional tag that follows ``celeba_`` in the file names.
        label: Row label used in the result table (may contain LaTeX macros).
        dim: Dimension value embedded in the file names (default 512).
    """
    for registry, entry in (
        (algo_filenames, filename),
        (algo_additionals, add),
        (algo_dims, dim),
        (algo_labels, label),
    ):
        registry.append(entry)


def add_divider():
    """Request a table divider before the next algorithm that gets registered."""
    algo_dividers.append(len(algo_filenames))


# Default-dimension (512) and n = 128 variants of each method.
add_algo("flow", "april", r"\af{}")
add_algo("pie", "april", r"\pie{}")
add_algo("pie", "may", r"\pie{} ($n = 128$)", dim=128)
add_algo("mf", "april", r"\mf{}")
add_algo("mf", "may", r"\mf{} ($n = 128$)", dim=128)
add_algo("emf", "april", r"\mfe{}")
add_algo("emf", "may", r"\mfe{} ($n = 128$)", dim=128)

n_algos = len(algo_filenames)


In [6]:
def load(tag, shape, numpyfy=True, chains=1, result_dir="../data/results"):
    """Load one result quantity for every registered algorithm and run.

    For each entry of the module-level ``algo_filenames`` / ``algo_dims`` /
    ``algo_additionals`` lists and each run in ``range(n_runs)``, the file
    ``{result_dir}/{filename}_{dim}_celeba_{add}{run_suffix}_{tag}.npy`` is
    read, where ``run_suffix`` is empty for run 0 and ``_run{k}`` otherwise.

    Args:
        tag: Trailing part of the file name identifying the quantity.
        shape: Expected shape of a single result, or ``None`` to keep each
            loaded array as stored.
        numpyfy: If true, stack everything into one float array.
        chains: Unused; kept so existing callers keep working.
        result_dir: Directory containing the ``.npy`` files.

    Returns:
        With ``numpyfy``: a float array of shape ``(n_algos, n_runs, *shape)``.
        Missing files, and files whose element count does not match
        ``shape``, are filled with NaN. Without ``numpyfy``: a nested list
        ``[algo][run]`` of arrays (``None`` for a missing file when ``shape``
        is ``None``).
    """
    all_results = []

    for algo_filename, algo_add, algo_dim in zip(algo_filenames, algo_additionals, algo_dims):
        algo_results = []

        for run in range(n_runs):
            run_str = "" if run == 0 else "_run{}".format(run)
            path = f"{result_dir}/{algo_filename}_{algo_dim}_celeba_{algo_add}{run_str}_{tag}.npy"

            try:
                this_result = np.load(path)
            except FileNotFoundError as e:
                # Only report missing reconstruction files; other quantities
                # are allowed to be absent silently.
                if "reco_test" in tag:
                    print(e)
                if shape is None:
                    algo_results.append(None)
                else:
                    algo_results.append(np.full(shape, np.nan))
            else:
                if shape is None:
                    # No target shape: keep the array exactly as stored
                    # (reshape(None) is not a valid no-op in current NumPy).
                    algo_results.append(this_result)
                elif (not numpyfy) or this_result.size == np.prod(shape):
                    algo_results.append(this_result.reshape(shape))
                else:
                    # Wrong number of elements: treat like a missing result.
                    algo_results.append(np.full(shape, np.nan))

        all_results.append(algo_results)

    if numpyfy:
        # The builtin float replaces the removed np.float alias.
        all_results = np.array(all_results, dtype=float)
        if shape is not None:
            all_results = all_results.reshape([all_results.shape[0], n_runs] + list(shape))

    return all_results


# Each array is indexed (algorithm, run, ...); entries for missing files are NaN.
# NOTE(review): (100, 3, 64, 64) presumably means 100 RGB 64x64 test images — confirm.
model_test_reco_xs = load("model_x_reco_test", shape=(100, 3, 64, 64))
model_test_reco_errors = load("model_reco_error_test", shape=(100,))
# Drop the trailing length-1 axis so the FIDs are indexed (algorithm, run).
model_gen_fids = np.squeeze(load("samples_fid", shape=(1,)))
model_gen_fids.shape


[Errno 2] No such file or directory: '../data/results/pie_128_celeba_may_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/pie_128_celeba_may_run1_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/pie_128_celeba_may_run2_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/mf_128_celeba_may_run1_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/mf_128_celeba_may_run2_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/emf_128_celeba_may_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/emf_128_celeba_may_run1_model_x_reco_test.npy'
[Errno 2] No such file or directory: '../data/results/emf_128_celeba_may_run2_model_x_reco_test.npy'


(7, 3)

In [7]:
# Upper bound applied to every individual reconstruction error before
# averaging, so single huge values cannot dominate the mean.
max_reco_error = 10000.
# Clip to [0, max_reco_error], then average over the last axis.
model_mean_reco_errors = np.clip(model_test_reco_errors, 0., max_reco_error).mean(axis=2)
model_mean_reco_errors.shape


(7, 3)

## Compute mean and error

In [8]:
def mean_err_without_outliers(data, remove=0):
    """Mean and standard error along the last axis, ignoring non-finite values.

    Args:
        data: Array-like whose last axis holds the repeated measurements.
        remove: Number of smallest and of largest finite values to drop per
            row before averaging. Trimming only happens when more than
            ``2 * remove + 1`` finite values are available; otherwise all
            finite values are used.

    Returns:
        Tuple ``(means, errors)`` of arrays with shape ``data.shape[:-1]``.
        ``errors`` is ``np.std(values) / sqrt(len(values))`` (NumPy's default
        ``ddof=0``). Rows without any finite value yield NaN in both arrays.
    """
    data = np.atleast_1d(np.asarray(data))
    shape = list(data.shape)[:-1]

    # Flatten all leading axes so that every row is one set of measurements.
    # The original discarded the result of reshape(), which broke any input
    # that was not already 2-D. The explicit row count (instead of -1) also
    # works when the last axis has length zero.
    n_rows = int(np.prod(shape, dtype=int))
    rows = data.reshape((n_rows, data.shape[-1]))

    means, errors = [], []

    for row in rows:
        values = row[np.isfinite(row)]
        if len(values) == 0:
            means.append(np.nan)
            errors.append(np.nan)
            continue

        if len(values) > 2 * remove + 1:
            for _ in range(remove):
                values = np.delete(values, np.argmin(values))
                values = np.delete(values, np.argmax(values))

        means.append(np.mean(values))
        errors.append(np.std(values) / len(values) ** 0.5)

    return np.array(means).reshape(shape), np.array(errors).reshape(shape)
    
    
# Aggregate over runs (last axis) without trimming any values: one mean and
# one standard error per algorithm.
(model_fid_mean, model_fid_std) = mean_err_without_outliers(model_gen_fids, remove=0)
(model_reco_error_mean, model_reco_error_std) = mean_err_without_outliers(
    model_mean_reco_errors, remove=0
)


## Best metrics

In [9]:
# Index of the best (lowest) mean per metric. -1 means "no best entry": it
# never equals a row index, so the result table highlights nothing.
try:
    best_fid = np.nanargmin(model_fid_mean)
except ValueError:
    # np.nanargmin raises when every mean is NaN (no results loaded at all).
    best_fid = -1
else:
    print(algo_labels[best_fid])

# Only mean reconstruction errors greater than 1 compete for "best"; in the
# recorded results this excludes the \af{} row, whose mean error is 0.
# Comparing NaN entries can emit an "invalid value" RuntimeWarning on some
# NumPy versions, which is silenced here because NaNs are handled explicitly.
with np.errstate(invalid="ignore"):
    _eligible_reco_mean = np.where(model_reco_error_mean > 1., model_reco_error_mean, np.nan)

try:
    best_reco = np.nanargmin(_eligible_reco_mean)
except ValueError:
    # No algorithm has an eligible reconstruction error.
    best_reco = -1
else:
    print(algo_labels[best_reco])


\af{}
\mf{}


  import sys


## Print result table

In [10]:
def print_results(
    l_label=None, l_means=(5, 4), l_errs=(3, 3), latex=False, after_decs=(1, 0), labels=("FID", "RE")
):
    """Print the result table, one row per registered algorithm.

    Reads the module-level ``algo_labels``, ``algo_dividers``,
    ``model_fid_mean`` / ``model_fid_std``, ``model_reco_error_mean`` /
    ``model_reco_error_std``, ``best_fid`` and ``best_reco``.

    Args:
        l_label: Width of the label column. ``None`` (default) uses the
            length of the longest entry in ``algo_labels`` at call time.
        l_means: Field widths of the mean value for the two metric columns.
        l_errs: Field widths of the error value for the two metric columns.
        latex: If true, emit LaTeX table rows; otherwise a plain-text table.
        after_decs: Number of decimals for the two metric columns.
        labels: Header titles of the two metric columns.

    Returns:
        None. The table is written to standard output. The best entry of a
        column is marked with ``*`` in plain text and set in bold in LaTeX.
    """
    # Resolve the label width at call time. A default computed in the
    # signature would be frozen when the function is defined and would miss
    # algorithms registered afterwards.
    if l_label is None:
        l_label = max((len(label) for label in algo_labels), default=0)

    # Number of digits
    l_results = np.array(l_means) + 2 + np.array(l_errs)
    l_total = l_label + 1 + np.sum(3 + l_results)

    # Divider
    col_divider = "&" if latex else "|"
    line_end = r"\\" if latex else ""
    block_divider = r"\midrule" if latex else "-"*l_total

    # Number formatting
    def _f(val, err, after_dec, best, l_mean, l_err):
        """Format one ``mean (err)`` cell; blank when the mean is not finite."""
        l_result = l_mean + 2 + l_err
        empty_result = "" if latex else " "*(l_result + 1)

        if not np.any(np.isfinite(val)):
            return empty_result

        result = "{:>{}.{}f}".format(val, l_mean, after_dec)
        if latex and best:
            result = r"\textbf{" + result + "}"

        if latex:
            # Pad with invisible digits so the numbers line up in the PDF.
            err_str = str.rjust("{:.{}f}".format(err, after_dec), l_err).replace(" ", r"\hphantom{0}")
            result += r"\;\textcolor{darkgray}{$\pm$\;" + err_str + "}"
        else:
            err_str = "({:>{}.{}f})".format(err, l_err, after_dec)
            result += err_str

        result += "*" if not latex and best else " "

        if latex:
            # Typeset minus signs in math mode. The colour is spelled
            # "darkgray" until this point so that the hyphen of "dark-gray"
            # is not caught by the minus-sign replacement.
            result = result.replace("-", "$-{}$")
            result = result.replace("darkgray", "dark-gray")
        return result

    # Header
    print(f"{'':<{l_label}.{l_label}s} {col_divider} {labels[0]:>{l_results[0]}.{l_results[0]}s} {col_divider} {labels[1]:>{l_results[1]}.{l_results[1]}s} {line_end}")
    print(block_divider)

    # Iterate over methods
    for i, (label, fid, fid_err, reco, reco_err) in enumerate(zip(
        algo_labels, model_fid_mean, model_fid_std, model_reco_error_mean, model_reco_error_std
    )):
        # Divider
        if i in algo_dividers:
            print(block_divider)

        # Print results
        print(
            f"{label:<{l_label}.{l_label}s} {col_divider} "
            + f"{_f(fid, fid_err, after_decs[0], i==best_fid, l_means[0], l_errs[0]):s}{col_divider} "
            + f"{_f(reco, reco_err, after_decs[1], i==best_reco, l_means[1], l_errs[1]):s} {line_end}"
        )


In [11]:
# Plain-text table: "|" column separators, errors in parentheses, and "*"
# marking the best entry of each column.
print_results()

                   |        FID |        RE 
--------------------------------------------
\af{}              |  33.6(0.2)*|    0(  0)  
\pie{}             |  75.7(5.1) | 6970( 97)  
\pie{} ($n = 128$) |            |            
\mf{}              |  37.4(0.2) |  830(  5)* 
\mf{} ($n = 128$)  |  37.2(0.0) | 1645(  0)  
\mfe{}             |  35.8(0.4) |  991(  4)  
\mfe{} ($n = 128$) |            |            


In [12]:
# LaTeX table rows: "&" column separators, a \midrule divider, and the best
# entry of each column in bold.
print_results(latex=True)

                   &        FID &        RE \\
\midrule
\af{}              & \textbf{ 33.6}\;\textcolor{dark-gray}{$\pm$\;0.2} &    0\;\textcolor{dark-gray}{$\pm$\;\hphantom{0}\hphantom{0}0}  \\
\pie{}             &  75.7\;\textcolor{dark-gray}{$\pm$\;5.1} & 6970\;\textcolor{dark-gray}{$\pm$\;\hphantom{0}97}  \\
\pie{} ($n = 128$) & &  \\
\mf{}              &  37.4\;\textcolor{dark-gray}{$\pm$\;0.2} & \textbf{ 830}\;\textcolor{dark-gray}{$\pm$\;\hphantom{0}\hphantom{0}5}  \\
\mf{} ($n = 128$)  &  37.2\;\textcolor{dark-gray}{$\pm$\;0.0} & 1645\;\textcolor{dark-gray}{$\pm$\;\hphantom{0}\hphantom{0}0}  \\
\mfe{}             &  35.8\;\textcolor{dark-gray}{$\pm$\;0.4} &  991\;\textcolor{dark-gray}{$\pm$\;\hphantom{0}\hphantom{0}4}  \\
\mfe{} ($n = 128$) & &  \\


In [13]:
# Show the raw per-run FIDs (rows: algorithms, columns: runs); NaN marks a
# run whose result file was not found.
model_gen_fids

array([[34.12582888, 33.56935076, 33.12149346],
       [72.60346539, 87.62884596, 66.76780555],
       [        nan,         nan,         nan],
       [37.89893445, 36.85327462, 37.34314453],
       [37.19593512,         nan,         nan],
       [35.78891746, 36.7025036 , 35.05215865],
       [        nan,         nan,         nan]])