# StyleGAN image manifolds: result table

## Setup

In [None]:
import sys
import numpy as np
from itertools import product
from sklearn.neighbors import KernelDensity


## Load results

In [None]:
# Number of independent runs per model, MCMC chains per run, and true parameter points
n_runs, n_chains, n_trueparams = 10, 1, 1

In [None]:
# Parallel registries describing each algorithm in the table (one entry per add_algo call),
# plus the row indices before which a divider line is printed.
algo_filenames, algo_additionals, algo_labels, algo_dividers, algo_dims = [], [], [], [], []

def add_algo(filename, add, label, dim=None):
    """Register one algorithm for the result table.

    Args:
        filename: result-file prefix for this algorithm.
        add: extra tag inserted into the result filename after ``gan{dim}d_``.
        label: row label printed in the table.
        dim: model dimension used in result filenames; ``None`` means "use the dataset
            dimension" (see the loaders).
    """
    for registry, value in (
        (algo_filenames, filename),
        (algo_additionals, add),
        (algo_labels, label),
        (algo_dims, dim),
    ):
        registry.append(value)
    
    
def add_divider():
    """Request a divider line before the next algorithm that gets registered."""
    position = len(algo_filenames)
    algo_dividers.append(position)

# Algorithms shown in the table, in row order. The flow is always loaded with model
# dimension 2; the others use the dataset dimension.
add_algo(filename="flow", add="april", label=r"\af{}", dim=2)
add_algo(filename="pie", add="april", label=r"\pie{}")
add_algo(filename="mf", add="april", label=r"\mf{}")
add_algo(filename="emf", add="april", label=r"\mfe{}")

n_algos = len(algo_filenames)


In [None]:
def load(tag, dim, shape, numpyfy=True, chains=1, result_dir="../data/results"):
    """Load one result file per registered algorithm and run.

    Files are read from
    ``{result_dir}/{algo}_{model_dim}_gan{dim}d_{add}{_runR}_{tag}.npy``, where the
    ``_runR`` suffix is omitted for run 0 and ``model_dim`` is the algorithm's registered
    dimension, or ``dim`` if none was registered.

    Args:
        tag: result tag in the filename, e.g. ``"samples_fid"``.
        dim: dataset dimension (``gan{dim}d`` in the filename).
        shape: expected shape of one file's contents. Missing files, and (when
            ``numpyfy``) files whose element count differs, are replaced by NaN arrays of
            this shape. If ``None``, files are returned as loaded and missing ones as
            ``None``; ``numpyfy`` then cannot be used.
        numpyfy: if True, return a float array of shape ``(n_algos, n_runs, *shape)``;
            otherwise return nested lists ``[algo][run]``.
        chains: unused; kept for backward compatibility.
        result_dir: directory containing the ``.npy`` result files.
    """
    all_results = []

    for algo_filename, algo_add, algo_dim in zip(algo_filenames, algo_additionals, algo_dims):
        model_dim = dim if algo_dim is None else algo_dim
        algo_results = []

        for run in range(n_runs):
            run_str = "" if run == 0 else "_run{}".format(run)
            filename = f"{result_dir}/{algo_filename}_{model_dim}_gan{dim}d_{algo_add}{run_str}_{tag}.npy"

            try:
                this_result = np.load(filename)
            except FileNotFoundError:
                # Missing runs are expected (not all runs finished); fill with placeholders
                algo_results.append(None if shape is None else np.nan * np.ones(shape))
                continue

            if shape is None:
                algo_results.append(this_result)
            elif (not numpyfy) or np.prod(this_result.shape) == np.prod(shape):
                algo_results.append(this_result.reshape(shape))
            else:
                # Wrong size: treat as missing so the stacked array stays rectangular
                algo_results.append(np.nan * np.ones(shape))

        all_results.append(algo_results)

    if numpyfy:
        all_results = np.array(all_results, dtype=float)
        all_results = all_results.reshape([all_results.shape[0], n_runs] + list(shape))

    return all_results


# Test-set reconstructions, per-image reconstruction errors and sample FIDs for the
# dim=2 and dim=64 datasets; each array has shape (n_algos, n_runs, *shape)
# (the FID arrays are squeezed afterwards).
model_test_reco_xs_2 = load(tag="model_x_reco_test", dim=2, shape=(100, 3, 64, 64))
model_test_reco_errors_2 = load(tag="model_reco_error_test", dim=2, shape=(100,))
model_gen_fids_2 = load(tag="samples_fid", dim=2, shape=(1,)).squeeze()

model_test_reco_xs_64 = load(tag="model_x_reco_test", dim=64, shape=(100, 3, 64, 64))
model_test_reco_errors_64 = load(tag="model_reco_error_test", dim=64, shape=(100,))
model_gen_fids_64 = load(tag="samples_fid", dim=64, shape=(1,)).squeeze()


In [None]:
def load_mcmc(tag, dim, shape, numpyfy=True, result_dir="../data/results"):
    """Load MCMC results for every registered algorithm, run, true parameter and chain.

    Files are read from
    ``{result_dir}/{algo}_{model_dim}_gan{dim}d_{add}{_runR}_{tag}{_trueparamT}{_chainC}.npy``,
    where each numbered suffix is omitted for index 0 and ``model_dim`` is the
    algorithm's registered dimension, or ``dim`` if none was registered.

    Args:
        tag: result tag in the filename, e.g. ``"posterior_samples"``.
        dim: dataset dimension (``gan{dim}d`` in the filename).
        shape: expected shape of one file's contents. Missing files, and (when
            ``numpyfy``) files whose element count differs, are replaced by NaN arrays of
            this shape. If ``None``, files are returned as loaded and missing ones as
            ``None``; ``numpyfy`` then cannot be used.
        numpyfy: if True, return a float array of shape
            ``(n_algos, n_runs, n_trueparams, n_chains, *shape)``; otherwise return, per
            algorithm, a flat list of ``n_runs * n_trueparams * n_chains`` entries in
            (run, trueparam, chain) order.
        result_dir: directory containing the ``.npy`` result files.
    """
    all_results = []

    for algo_filename, algo_add, algo_dim in zip(algo_filenames, algo_additionals, algo_dims):
        model_dim = dim if algo_dim is None else algo_dim
        algo_results = []

        # Loop order (run, trueparam, chain) must match the final reshape
        for run, trueparam, chain in product(range(n_runs), range(n_trueparams), range(n_chains)):
            run_str = "" if run == 0 else "_run{}".format(run)
            trueparam_str = "" if trueparam == 0 else "_trueparam{}".format(trueparam)
            chain_str = "" if chain == 0 else "_chain{}".format(chain)
            filename = (
                f"{result_dir}/{algo_filename}_{model_dim}"
                + f"_gan{dim}d_{algo_add}{run_str}_{tag}{trueparam_str}{chain_str}.npy"
            )

            try:
                this_result = np.load(filename)
            except FileNotFoundError:
                # Missing runs/chains are expected; fill with placeholders
                algo_results.append(None if shape is None else np.nan * np.ones(shape))
                continue

            if shape is None:
                algo_results.append(this_result)
            elif (not numpyfy) or np.prod(this_result.shape) == np.prod(shape):
                algo_results.append(this_result.reshape(shape))
            else:
                # Wrong size: treat as missing so the stacked array stays rectangular
                algo_results.append(np.nan * np.ones(shape))

        all_results.append(algo_results)

    if numpyfy:
        all_results = np.array(all_results, dtype=float)
        all_results = all_results.reshape(
            [all_results.shape[0], n_runs, n_trueparams, n_chains] + list(shape)
        )

    return all_results


# MCMC posterior samples for the dim=64 dataset
model_posterior_samples_64 = load_mcmc(tag="posterior_samples", dim=64, shape=(400, 1))
model_posterior_samples_64.shape  # (algo, run, true param id, chain, sample, theta component)


## Compute metrics

In [None]:
# Clip each per-image reconstruction error to [0, max_reco_error] (presumably so that a few
# divergent reconstructions cannot dominate), then average over the 100 test images (axis 2)
# to get one value per (algorithm, run).
max_reco_error = 10000.
model_mean_reco_errors_2 = np.clip(model_test_reco_errors_2, 0., max_reco_error).mean(axis=2)
model_mean_reco_errors_64 = np.clip(model_test_reco_errors_64, 0., max_reco_error).mean(axis=2)


In [None]:
bandwidth = 0.05

# The sample array is laid out as (algo, run, true param id, chain, sample, theta component)
# (see where model_posterior_samples_64 is loaded). Take the parameter dimension from that
# last axis instead of hard-coding 2. With the (400, 1) shape used by the loader, a
# hard-coded reshape((-1, 2)) would pair consecutive 1-D samples into fake 2-D points.
param_dim = model_posterior_samples_64.shape[-1]

# True parameter point(s) at which the posterior density is evaluated (the origin)
true_param_points = np.zeros((1, param_dim))

model_true_log_posteriors_64 = []

for algo, run, trueparam in product(range(n_algos), range(n_runs), range(n_trueparams)):
    # Pool all chains for this (algo, run, true param) and drop non-finite samples
    mcmcs = model_posterior_samples_64[algo, run, trueparam].reshape((-1, param_dim))
    mcmcs = mcmcs[np.all(np.isfinite(mcmcs), axis=-1)]

    if len(mcmcs) == 0:
        model_true_log_posteriors_64.append(np.nan)
        continue

    # Gaussian KDE of the posterior samples, scored at the true parameter point
    kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth)
    kde.fit(mcmcs)
    model_true_log_posteriors_64.append(kde.score(true_param_points[trueparam].reshape((1, param_dim))))

# Average over true parameter points -> shape (n_algos, n_runs)
model_true_log_posteriors_64 = np.mean(
    np.array(model_true_log_posteriors_64).reshape((n_algos, n_runs, n_trueparams)), axis=-1
)
model_true_log_posteriors_64.shape


In [None]:
# Display the KDE log posterior at the true parameter, averaged over true parameter
# points, with shape (n_algos, n_runs)
model_true_log_posteriors_64

## Compute mean and standard error across runs

In [None]:
def mean_err_without_outliers(data, remove=0):
    """Mean and standard error over the last axis, ignoring non-finite entries.

    Args:
        data: array of shape ``(..., n)``; statistics are computed along the last axis.
        remove: number of smallest AND largest finite values to drop before averaging.
            Removal only happens when more than ``2 * remove + 1`` finite values remain.

    Returns:
        ``(means, errors)``, each of shape ``data.shape[:-1]``. ``errors`` is
        ``std / sqrt(count)`` (population std, ``ddof=0``). Slices with no finite values
        give NaN for both.
    """
    data = np.asarray(data)
    shape = data.shape[:-1]
    # Flatten all leading axes so each row is one independent set of values. (The
    # original discarded this reshape's result, which broke inputs with >2 dimensions.)
    rows = data.reshape((-1, data.shape[-1]))

    means, errors = [], []

    for values in rows:
        values = values[np.isfinite(values)]
        if len(values) == 0:
            means.append(np.nan)
            errors.append(np.nan)
            continue

        if len(values) > 2 * remove + 1:
            for _ in range(remove):
                values = np.delete(values, np.argmin(values))
                values = np.delete(values, np.argmax(values))

        means.append(np.mean(values))
        errors.append(np.std(values) / len(values) ** 0.5)

    return np.array(means).reshape(shape), np.array(errors).reshape(shape)
    
    
# Mean and standard error across runs (last axis) for every metric, without outlier removal
model_fid_mean_2, model_fid_std_2 = mean_err_without_outliers(model_gen_fids_2, remove=0)
model_reco_error_mean_2, model_reco_error_std_2 = mean_err_without_outliers(model_mean_reco_errors_2, remove=0)

model_fid_mean_64, model_fid_std_64 = mean_err_without_outliers(model_gen_fids_64, remove=0)
model_reco_error_mean_64, model_reco_error_std_64 = mean_err_without_outliers(model_mean_reco_errors_64, remove=0)
model_true_log_posteriors_mean_64, model_true_log_posteriors_std_64 = mean_err_without_outliers(
    model_true_log_posteriors_64, remove=0
)


## Best algorithm per metric

In [None]:
def _best_index(values, maximize=False):
    """Index of the smallest (or, if ``maximize``, largest) non-NaN entry of ``values``.

    Returns -1 when every entry is NaN (e.g. no results were found for a metric).
    np.nanargmin/np.nanargmax raise ValueError in that case. -1 never equals a row index
    in the printed table, so no row gets highlighted.
    """
    values = np.asarray(values, dtype=float)
    if np.all(np.isnan(values)):
        return -1
    return int(np.nanargmax(values)) if maximize else int(np.nanargmin(values))


# Reconstruction errors <= 1 are excluded from the "best" search.
# NOTE(review): presumably this skips models that reconstruct (near-)perfectly by
# construction — confirm.
best_fid_2 = _best_index(model_fid_mean_2)
best_reco_2 = _best_index(np.where(model_reco_error_mean_2 > 1., model_reco_error_mean_2, np.nan))

best_fid_64 = _best_index(model_fid_mean_64)
best_reco_64 = _best_index(np.where(model_reco_error_mean_64 > 1., model_reco_error_mean_64, np.nan))
best_posterior_64 = _best_index(model_true_log_posteriors_mean_64, maximize=True)

print("\n".join(
    algo_labels[best] if best >= 0 else "(no results)"
    for best in (best_fid_2, best_reco_2, best_fid_64, best_reco_64, best_posterior_64)
))


## Print result table

In [None]:
def print_results(
    l_label=max(len(label) for label in algo_labels),
    l_means=(5, 4, 5, 4, 5),
    l_errs=(3, 2, 3, 2, 4),
    latex=False,
    after_decs=(1, 0, 1, 0, 2),
    labels=("FID (2)", "RE (2)", "FID (64)", "RE (64)", "log p (64)"),
):
    r"""Print the result table, one row per registered algorithm.

    Columns, in order:
    - FID and reconstruction error for dim=2.
    - FID and reconstruction error for dim=64.
    - KDE log posterior at the true parameter for dim=64.

    Each cell shows the mean and its error. The best algorithm per column gets a trailing
    "*" (plain text) or is set in bold (LaTeX). Rows with a non-finite mean are left blank.

    Args:
        l_label: width of the label column. The default is the longest registered label,
            evaluated once when this function is defined.
        l_means, l_errs: per-column field widths of the mean and the error.
        latex: if True, print LaTeX table rows (``&`` separators, ``\\`` line ends,
            ``\midrule`` dividers) instead of a plain-text table.
        after_decs: per-column number of decimals.
        labels: per-column header labels.
    """
    # Width of one "<mean>(<err>)" field per column
    l_results = np.array(l_means) + 2 + np.array(l_errs)
    l_total = l_label + 1 + np.sum(3 + l_results)

    # Dividers
    col_divider = "&" if latex else "|"
    line_end = r"\\" if latex else ""
    block_divider = r"\midrule" if latex else "-" * int(l_total)

    def _f(val, err, after_dec, best, l_mean, l_err):
        """Format one mean/error cell; blank (padded in plain text) if ``val`` is not finite."""
        l_result = l_mean + 2 + l_err
        empty_result = "" if latex else " " * (l_result + 1)

        if not np.any(np.isfinite(val)):
            return empty_result

        result = "{:>{}.{}f}".format(val, l_mean, after_dec)
        if latex and best:
            result = r"\textbf{" + result + "}"

        if latex:
            # Pad the error with invisible zeros so the +- parts line up
            err_str = str.rjust("{:.{}f}".format(err, after_dec), l_err).replace(" ", r"\hphantom{0}")
            result += r"\;\textcolor{darkgray}{$\pm$\;" + err_str + "}"
        else:
            result += "({:>{}.{}f})".format(err, l_err, after_dec)

        result += "*" if not latex and best else " "

        if latex:
            # Typeset minus signs in math mode. The color name gets its hyphen only
            # afterwards, so the minus replacement cannot touch it.
            result = result.replace("-", "$-{}$")
            result = result.replace("darkgray", "dark-gray")
        return result

    # (means, errors, index of best algorithm) per column, in header order
    columns = (
        (model_fid_mean_2, model_fid_std_2, best_fid_2),
        (model_reco_error_mean_2, model_reco_error_std_2, best_reco_2),
        (model_fid_mean_64, model_fid_std_64, best_fid_64),
        (model_reco_error_mean_64, model_reco_error_std_64, best_reco_64),
        (model_true_log_posteriors_mean_64, model_true_log_posteriors_std_64, best_posterior_64),
    )

    # Header
    header_cells = [f"{label:>{width}.{width}s}" for label, width in zip(labels, l_results)]
    print(
        f"{'':<{l_label}.{l_label}s} {col_divider} "
        + f" {col_divider} ".join(header_cells)
        + f" {line_end}"
    )
    print(block_divider)

    # One row per method
    for i, label in enumerate(algo_labels):
        if i in algo_dividers:
            print(block_divider)

        cells = [
            _f(means[i], errs[i], after_dec, i == best, l_mean, l_err)
            for (means, errs, best), after_dec, l_mean, l_err in zip(columns, after_decs, l_means, l_errs)
        ]
        print(
            f"{label:<{l_label}.{l_label}s} {col_divider} "
            + f"{col_divider} ".join(cells)
            + f" {line_end}"
        )


In [None]:
# Plain-text table; "*" marks the best algorithm per column
print_results(latex=False)

In [None]:
# LaTeX table rows; the best algorithm per column is set in bold
print_results(latex=True)