In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
from scipy.stats import trim_mean
from sklearn.metrics import mean_squared_error

sys.path.append('..')
from higgs_inference import settings
from higgs_inference.various.utils import format_number

In [2]:
# Root directory of the saved result arrays, relative to the notebook's working
# directory. Subfolders such as 'parameterized', 'histo' and 'score_regression'
# are read from it by the metric functions.
result_dir = '../results/'

# TablePrinter class

In [3]:
class TablePrinter:
    r"""Collect per-method metric values and render them as LaTeX table rows.

    Rows are added with :meth:`add` and grouped into blocks. :meth:`new_block`
    closes the current block, and consecutive non-empty blocks are separated by a
    ``\midrule`` line. Within each block, the row holding the smallest finite value
    of each metric column is formatted with ``format_number(..., emphasize=True)``.
    :meth:`print` returns the accumulated table body as a string.

    Parameters
    ----------
    metric_fns : list of callable, optional
        Each is called as ``fn(filename, folder)``. It returns either a number or a
        ``(number, bracket)`` pair; with ``bracket=True`` the formatted value is
        wrapped in parentheses. An ``IOError`` or ``ValueError`` raised by a metric,
        or a value that cannot be converted to ``float``, is treated as missing (NaN).
    header : str, optional
        Header row, written once at the top and terminated with ``\\``.
    precisions : list of int, optional
        Precision argument passed to ``format_number`` for each metric. It is used
        only if it has exactly one entry per metric; otherwise every metric uses 2.
    """

    def __init__(self, metric_fns=None, header=None, precisions=None):

        # Functions for metrics (None defaults avoid shared mutable default lists)
        self.metric_fns = list(metric_fns) if metric_fns is not None else []
        self.n_metrics = len(self.metric_fns)
        if precisions is not None and len(precisions) == self.n_metrics:
            self.precisions = list(precisions)
        else:
            self.precisions = [2] * self.n_metrics

        # Total table and current block
        self.table = ''
        self.block_entries = []   # rows: [col1, col2, metric_0, ..., metric_{n-1}]
        self.block_brackets = []  # parallel to block_entries: bracket flag per cell
        self.content_in_last_block = False
        # Set by new_block() after a block that produced rows. The separating
        # midrule is written only when a later block actually produces rows, so an
        # empty final block never leaves a dangling midrule at the end of the table.
        self._midrule_pending = False

        # Formatting options
        self.indent = '   '
        self.col_sep = ' & '
        self.end_row = r'\\'
        self.midrule = r'\midrule'
        self.end_line = '\n'
        self.emphasis_begin = r'\mathbf{'
        self.emphasis_end = r'}'

        # Header
        if header is not None:
            self.table += self.indent + header + self.end_row + self.end_line

    def finalise_block(self):
        """Render the current block into ``self.table`` and start a new, empty block.

        Rows whose metrics are all non-finite are skipped. If the block produced rows
        and a previous non-empty block was closed with :meth:`new_block`, a midrule
        is written before them. Sets ``content_in_last_block`` to whether any row
        was written.
        """
        self.content_in_last_block = False

        # Nothing to render
        if not self.block_entries:
            return

        best_rows = self._best_rows()

        text = ''
        for i, (line, brackets) in enumerate(zip(self.block_entries, self.block_brackets)):
            # Skip rows without a single finite metric value
            if not np.any(np.isfinite(line[2:])):
                continue
            text += self._format_row(line, brackets, best_rows, i)

        if text:
            if self._midrule_pending:
                self.table += self.indent + self.midrule + self.end_line
                self._midrule_pending = False
            self.table += text
            self.content_in_last_block = True

        # Reset for next block
        self.block_entries = []
        self.block_brackets = []

    def _best_rows(self):
        """Per metric, return the block-row index of the smallest value ignoring NaN (-1 if all NaN)."""
        block_metrics = np.array([line[2:] for line in self.block_entries], dtype=float)
        best_rows = []
        for j in range(self.n_metrics):
            try:
                best_rows.append(np.nanargmin(block_metrics[:, j]))
            except ValueError:
                # np.nanargmin raises on an all-NaN column: nothing to emphasise
                best_rows.append(-1)
        return best_rows

    def _format_row(self, line, brackets, best_rows, row_index):
        """Format one row: two label cells followed by one cell per metric."""
        text = self.indent + line[0] + self.col_sep + line[1] + self.col_sep
        for j in range(self.n_metrics):
            value = line[j + 2]
            # Non-finite values leave the cell empty
            if np.isfinite(value):
                cell = format_number(value, self.precisions[j], latex_math_mode=True,
                                     emphasize=(row_index == best_rows[j]))
                text += ('(' + cell + ')') if brackets[j + 2] else cell
            # The last metric closes the row; the others are followed by a separator
            if j == self.n_metrics - 1:
                text += self.end_row + self.end_line
            else:
                text += self.col_sep
        return text

    def new_block(self):
        """Close the current block; a midrule will precede the next block that has rows."""
        self.finalise_block()
        if self.content_in_last_block:
            self._midrule_pending = True
            self.content_in_last_block = False

    def add(self, col1, col2, filename, folder='parameterized'):
        """Evaluate every metric for one result and append it as a row of the current block.

        Parameters
        ----------
        col1, col2 : str
            Label columns. ``col1`` is left blank if an earlier row in the current
            block has the same first label, so repeated labels print once per block.
        filename, folder : str
            Passed unchanged to each metric function.

        Rows in which no metric produced a finite value are dropped.
        """
        # Label columns
        if any(entry[0] == col1 for entry in self.block_entries):
            line = ['', col2]
        else:
            line = [col1, col2]
        brackets = [False, False]

        # Metrics
        for fn in self.metric_fns:
            bracket = False
            try:
                value = fn(filename, folder)
            except (IOError, ValueError):
                # Missing or unreadable results are reported as empty cells
                value = np.nan

            if isinstance(value, (list, tuple)):
                value, bracket = value

            line.append(self._as_float(value))
            brackets.append(bracket)

        if np.any(np.isfinite(line[2:])):
            self.block_entries.append(line)
            self.block_brackets.append(brackets)

    @staticmethod
    def _as_float(value):
        """Convert a metric value to float, mapping non-numeric values (e.g. None) to NaN."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    def print(self):
        """Finalise the open block and return the accumulated LaTeX table body."""
        self.finalise_block()
        return self.table

# Metrics

In [4]:
def _prior_weighted_metric(prefix, filename, folder):
    """Return ``np.sum(settings.theta_prior * values)`` for a saved result array.

    ``values`` is loaded from ``<result_dir>/<folder>/<prefix><filename>.npy``.
    The path is built with ``os.path.join``, so ``result_dir`` may or may not end
    with a separator. Errors from ``np.load`` (e.g. a missing file) propagate to
    the caller. Assumes the stored array is element-wise aligned with
    ``settings.theta_prior`` (NOTE(review): shapes are not checked here; confirm).
    """
    values = np.load(os.path.join(result_dir, folder, prefix + filename + '.npy'))
    return np.sum(settings.theta_prior * values)


def expected_mse_log_r(filename, folder='parameterized'):
    """Prior-weighted sum of the stored ``mse_logr_<filename>.npy`` array in ``folder``."""
    return _prior_weighted_metric('mse_logr_', filename, folder)


def expected_trimmed_mse_log_r(filename, folder='parameterized'):
    """Prior-weighted sum of the stored ``trimmed_mse_logr_<filename>.npy`` array in ``folder``."""
    return _prior_weighted_metric('trimmed_mse_logr_', filename, folder)

# Result table

In [8]:
# Rows of the result table, grouped into blocks that are separated by a midrule.
# Each row: (method label, second label column, result file stem, result folder).
table_blocks = [
    [
        ('Histogram', '', 'histo_2d_asymmetricbinning', 'histo'),
        ('CARL', '', 'carl_calibrated_shallow', 'parameterized'),
    ],
    [
        ('ROLR', '', 'regression_calibrated', 'parameterized'),
        ('CASCAL', '', 'combined_calibrated_deep', 'parameterized'),
        ('RASCAL', '', 'combinedregression_calibrated_deep', 'parameterized'),
        ('Modified XE', '', 'mxe_calibrated_deep', 'parameterized'),
        ('Modified XE + score', '', 'combinedmxe_calibrated_deep', 'parameterized'),
    ],
    [
        ('SALLY', '', 'scoreregression_rotatedscore_deep', 'score_regression'),
        ('SALLINO', '', 'scoreregression_scoretheta_deep', 'score_regression'),
    ],
]

table = TablePrinter([expected_mse_log_r, expected_trimmed_mse_log_r], precisions=[4, 5])
for block_index, block_rows in enumerate(table_blocks):
    # A new block is only opened between groups, never before the first one
    if block_index > 0:
        table.new_block()
    for label, sublabel, filename, folder in block_rows:
        table.add(label, sublabel, filename, folder)

print(table.print())


   Histogram &  & $0.0561$ & $0.01057$\\
   CARL &  & $\mathbf{0.0124}$ & $\mathbf{0.00259}$\\
   \midrule
   ROLR &  & $0.0032$ & $0.00166$\\
   CASCAL &  & $0.0008$ & $0.00024$\\
   RASCAL &  & $0.0009$ & $0.00037$\\
   Modified XE &  & $\mathbf{0.0004}$ & $\mathbf{0.00008}$\\
   Modified XE + score &  & $0.0012$ & $0.00029$\\
   \midrule
   SALLY &  & $\mathbf{0.0132}$ & $\mathbf{0.00025}$\\
   SALLINO &  & $0.0213$ & $0.00063$\\

