# Floating-Point Unit (FPU) Experiments

In [1]:
import sys
sys.path.insert(1, '../')
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import iqr

import primitive_sets

# Directory into which the result files are written.
root_dir = './results'

# Fix the global NumPy RNG seed so the random test cases are reproducible.
random_seed = 0
np.random.seed(random_seed)

## Wrapper function for FPU

Wrapper function to emulate the hardware-based floating-point unit.

In [2]:
@np.errstate(all='ignore')
def fpu(primitive_set, sel, args, primal=None, differential=None):
    """Emulate the hardware floating-point unit for a single operation.

    `primitive_set` is an indexable container of `(function, arity)`
    tuples. `sel` is the integer index of the entry to evaluate. Only the
    first `arity` items of `args` are forwarded to the selected function
    as positional operands. `primal` and `differential` are always passed
    through as keyword arguments.

    NumPy floating-point warnings are suppressed during the call.
    """
    function, n_operands = primitive_set[sel]
    operands = args[:n_operands]
    return function(*operands, primal=primal, differential=differential)

## Evaluating function error

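We use two error measures (absolute and relative error) to compare each function within each primitive set against the corresponding function within the standard primitive set. We compare both the function values and the function derivatives.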

We test various input domains. For each domain, every primitive set is evaluated on the same set of one million test cases. The first 49 cases are all ordered pairs (the Cartesian product, including repeated pairs) of the special inputs `{NaN, -Inf, +Inf, -0, +0, -1, +1}`. The remaining cases are drawn uniformly at random from the input domain.

### Helper functions for computing various error measures

In [3]:
@np.errstate(all='ignore')
def absolute_error(Y_true, Y_est):
    """Return the elementwise absolute error |Y_true - Y_est|.

    Special cases:
    * both values NaN             -> 0
    * exactly one value NaN       -> +Inf
    * equal infinities            -> 0 (instead of the NaN from Inf - Inf)
    * infinities of opposite sign -> +Inf (falls out of the subtraction)

    Both arguments are expected to be NumPy arrays, because the result
    is modified in place through boolean masks.
    """
    true_nan = np.isnan(Y_true)
    est_nan = np.isnan(Y_est)

    err = np.abs(Y_true - Y_est)
    err[true_nan & est_nan] = 0
    # XOR selects the entries where exactly one side is NaN.
    err[true_nan ^ est_nan] = np.inf
    err[np.isinf(Y_true) & (Y_true == Y_est)] = 0
    return err

@np.errstate(all='ignore')
def relative_error(Y_true, Y_est):
    """Return the elementwise relative error |Y_true - Y_est| / |Y_true|.

    The result is a fraction, not a percentage. Special cases:
    * both values NaN       -> 0
    * exactly one value NaN -> +Inf
    * Y_true is +/-Inf or zero (where the plain quotient is undefined or
      meaningless) -> 0 if Y_est compares equal to Y_true, else +Inf.
      Here -0 and +0 compare equal.

    Both arguments are expected to be NumPy arrays, because the result
    is modified in place through boolean masks.
    """
    true_nan = np.isnan(Y_true)
    est_nan = np.isnan(Y_est)
    same = Y_true == Y_est
    # References for which the quotient cannot be used directly.
    degenerate = np.isinf(Y_true) | (Y_true == 0)

    err = np.abs(Y_true - Y_est) / np.abs(Y_true)
    err[true_nan & est_nan] = 0
    # XOR selects the entries where exactly one side is NaN.
    err[true_nan ^ est_nan] = np.inf
    err[degenerate & same] = 0
    err[degenerate & ~same] = np.inf
    return err

def round_sig(x, n):
    """Round `x` to `n` significant digits.

    Zero and non-finite values (NaN, +/-Inf) are returned unchanged.
    Negative values are rounded by magnitude and keep their sign. The
    original code took log10 of the signed value, which is NaN for
    negatives, so they were silently left unrounded.
    """
    if not np.isfinite(x) or x == 0:
        return x
    # Decimal exponent of the leading significant digit.
    exponent = int(np.floor(np.log10(np.abs(x))))
    return round(x, -exponent + (n - 1))

In [4]:
import os

# Function opcode to name mapping.
function_names = {
    0 : 'add',
    1 : 'sub',
    2 : 'mul',
    3 : 'div',
    4 : 'sin',
    5 : 'cos',
    6 : 'exp',
    7 : 'log',
    8 : 'sqrt',
    9 : 'tanh',
}

# Primitive sets.
primitive_sets_ = (
    (primitive_sets.mad_16, 'mad_16'),
    (primitive_sets.mad_10, 'mad_10'),
    (primitive_sets.mad_04, 'mad_04'),
)

# Input domains.
low = [-10,]
high = [+10,]
# low = [-10, -10000, -31875756.0, -3.4028235e+38]
# high = [+10, +10000, +31875756.0, +3.4028235e+38]

# Total number of input cases.
n = 1000000
# n = 1000
# n = 10000

# Number of significant digits for numeric statistics.
d = 9


def _format_outputs(outputs, idx):
    """Format the value(s) at index `idx` of one or more output arrays.

    A single array yields its bare value. Several arrays (both partial
    derivatives of a binary function) yield a parenthesized tuple.
    """
    if len(outputs) == 1:
        return f'{outputs[0][idx]}'
    return '(' + ', '.join(f'{y[idx]}' for y in outputs) + ')'


def _summarize_error(error, prefix, X0, X1, Y_true, Y_est, n_digits):
    """Return summary statistics of `error` keyed as `f'{prefix}_<STAT>'`.

    `Y_true`/`Y_est` are sequences of output arrays. They are used only
    to describe the test case with the largest error. Numeric statistics
    are rounded to `n_digits` significant digits.
    """
    argmax = np.argmax(error)
    stats = {
        f'{prefix}_ARGMAX': argmax,
        f'{prefix}_MAX_TEST_CASE': (
            f'x=({X0[argmax]}, {X1[argmax]}), '
            f'y_true={_format_outputs(Y_true, argmax)}, '
            f'y_est={_format_outputs(Y_est, argmax)}'),
        f'{prefix}_MAX': round_sig(np.max(error), n_digits),
        f'{prefix}_MED': round_sig(np.median(error), n_digits),
        f'{prefix}_MEAN': round_sig(np.mean(error), n_digits),
        f'{prefix}_MIN': round_sig(np.min(error), n_digits),
    }
    # Errors may contain +Inf, which can make std/IQR emit RuntimeWarnings;
    # those warnings are deliberately silenced.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        stats[f'{prefix}_STD'] = round_sig(np.std(error), n_digits)
        stats[f'{prefix}_IQR'] = round_sig(iqr(error), n_digits)
    return stats


def _record_statistics(dest, X0, X1, Y_true, Y_est, abs_error, rel_error,
                       n_digits):
    """Append absolute and relative (percent) error statistics to `dest`.

    Each key of `dest` maps to a list that grows by one entry per call.
    REL_* statistics are inserted before ABS_* so that the output table
    keeps the column order REL_* followed by ABS_*.
    """
    stats = {
        **_summarize_error(rel_error * 100, 'REL_ERROR',
                           X0, X1, Y_true, Y_est, n_digits),
        **_summarize_error(abs_error, 'ABS_ERROR',
                           X0, X1, Y_true, Y_est, n_digits),
    }
    for key, value in stats.items():
        dest.setdefault(key, []).append(value)


# Build the test inputs once per input domain, so that every primitive set
# is evaluated on exactly the same cases. The first cases are all ordered
# pairs of special inputs; the rest are drawn uniformly from the domain.
# NOTE: use float literals. The integer -0 equals +0 and would silently
# drop negative zero from the special cases.
special_inputs = np.asarray(
    [np.nan, -np.inf, +np.inf, -0.0, +0.0, -1.0, +1.0], dtype=np.float32)
S0, S1 = np.meshgrid(special_inputs, special_inputs, indexing='ij')
S0, S1 = S0.ravel(), S1.ravel()
n_sample = max(n - S0.size, 0)

test_inputs = []
for low_, high_ in zip(low, high):
    X0 = np.concatenate((S0, np.random.uniform(
        low=low_, high=high_, size=n_sample).astype(np.float32)))
    X1 = np.concatenate((S1, np.random.uniform(
        low=low_, high=high_, size=n_sample).astype(np.float32)))
    test_inputs.append((X0, X1))

# error_statistics[ps_name][f_name][STAT] is a list holding, for each
# input domain in order, the function-value entry followed by the
# derivative entry.
error_statistics = {}

for primitive_set, ps_name in primitive_sets_:
    ps_stats = error_statistics[ps_name] = {
        name: {} for name in function_names.values()}

    for X0, X1 in test_inputs:
        for sel, f_name in function_names.items():
            fn_stats = ps_stats[f_name]
            _, arity = primitive_set[sel]

            # Function values: desired (standard set) vs. actual.
            Y_true = fpu(primitive_sets.standard, sel, [X0, X1])
            Y_est = fpu(primitive_set, sel, [X0, X1])
            _record_statistics(
                fn_stats, X0, X1, [Y_true], [Y_est],
                absolute_error(Y_true, Y_est),
                relative_error(Y_true, Y_est), d)

            # Derivatives. Each set is evaluated at its own primal output.
            # For binary functions both partial derivatives are computed,
            # and their errors are summed as independent contributions.
            differentials = (0, 1) if arity == 2 else (0,)
            Y_true_p = [
                fpu(primitive_sets.standard, sel, [X0, X1],
                    primal=Y_true, differential=k)
                for k in differentials]
            Y_est_p = [
                fpu(primitive_set, sel, [X0, X1],
                    primal=Y_est, differential=k)
                for k in differentials]
            _record_statistics(
                fn_stats, X0, X1, Y_true_p, Y_est_p,
                sum(absolute_error(t, e) for t, e in zip(Y_true_p, Y_est_p)),
                sum(relative_error(t, e) for t, e in zip(Y_true_p, Y_est_p)),
                d)


# Write the relevant statistics to Excel and CSV files.
os.makedirs(root_dir, exist_ok=True)
df = pd.DataFrame.from_dict(
    {(ps, fn): stats
     for ps, per_function in error_statistics.items()
     for fn, stats in per_function.items()},
    orient='index')
df.to_excel(f'{root_dir}/error_statistics.xlsx', na_rep='NaN')
df.to_csv(f'{root_dir}/error_statistics.csv', na_rep='NaN')