# In [1]:
import arviz as az
import json
import numpy as np
import pickle

from emulators import SpaceTimeKron

def dump_default(obj):
    '''
    Fallback serializer for json.dump / json.dumps (pass it as
    the ``default=`` argument).

    Numpy arrays are converted to (nested) lists and Numpy scalars
    such as np.int64 or np.float32 are converted to the matching
    built-in Python scalar. Anything else raises TypeError, which
    is what the json module expects from a ``default`` hook.
    '''
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Numpy scalar types (np.int64, np.float32, np.bool_, ...) all derive
    # from np.generic and are not natively serializable either.
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Not serializable: {type(obj).__name__}')

def average_point(trace):
    '''
    Build a dict mapping every variable in ``trace.posterior`` to its
    mean over the first two axes, returned as a Numpy array.

    NOTE(review): for an Arviz InferenceData object the first two axes
    are presumably (chain, draw) -- confirm against the trace layout.
    '''
    posterior_means = {}
    for name in trace.posterior.keys():
        samples = trace.posterior[name]
        averaged = samples.mean(axis=(0, 1))
        posterior_means[name] = averaged.to_numpy()
    return posterior_means

def at_least_2d(array):
    '''
    Forces a Numpy array to have at least two dimensions.

    A 0-d array becomes shape (1, 1), a 1-d array of length n becomes
    a column of shape (n, 1), and arrays that already have two or more
    dimensions are returned unchanged.
    '''
    if array.ndim == 0:
        # A scalar array cannot be indexed with [:, None]; treat it as a
        # single point with a single feature.
        return array.reshape(1, 1)
    if array.ndim == 1:
        return array[:, None]
    return array
    
def constrain_unit_cube(array, reduce_axis=0):
    '''
    Subtract the minimum and divide by the (shifted) maximum taken
    over ``reduce_axis`` so that returned values are constrained to
    the unit cube.

    ``reduce_axis`` may be an int or a tuple of ints. Returns a tuple
    ``(unit_array, offset, scale)`` where ``offset`` is the minimum and
    ``scale`` is the range (max - min) over ``reduce_axis``, so that
    ``unit_array * scale + offset`` recovers the input when
    ``reduce_axis`` is 0 or covers all axes.

    Slices that are constant along ``reduce_axis`` have a scale of 0;
    they are mapped to 0 instead of producing NaN from 0 / 0. The
    returned ``scale`` still reports 0 for those slices.
    '''
    # keepdims=True keeps the reduced axes as length-1 axes so the
    # arithmetic below broadcasts correctly for any reduce_axis, not
    # only axis 0.
    offset_kd = array.min(axis=reduce_axis, keepdims=True)
    shifted = array - offset_kd
    scale_kd = shifted.max(axis=reduce_axis, keepdims=True)

    # Avoid dividing by zero for constant slices.
    divisor = np.where(scale_kd == 0, 1, scale_kd)
    unit_array = shifted / divisor

    # Drop the length-1 axes again so offset/scale have the same shape as
    # a plain reduction; indexing with () turns a 0-d result into a
    # Numpy scalar and leaves n-d arrays untouched.
    offset = offset_kd.squeeze(axis=reduce_axis)[()]
    scale = scale_kd.squeeze(axis=reduce_axis)[()]
    return unit_array, offset, scale



# In [2]:
# Run configuration for this notebook cell.

# Which surrogate to build and how to estimate its parameters.
model_type = 'time+process'
fit_method = 'mcmc'

# Where the input data is read from.
input_filepath = '../data/sir.json'

# Result file name encodes the fit method and model type.
output_filepath = f'../outputs/sir_{fit_method}_{model_type}.json'

def fit_model(input_filepath, output_filepath, model_type, fit_method, 
              response_transform='plus1log', split_char='+',vi_iter=100_000, mcmc_iter=2):
    '''
    Utility for determining type of surrogate model to construct, preprocessing input/response data, 
    and running a parameter estimation algorithm.

    Parameters
    ----------
    input_filepath : path of a JSON file whose values are array-like. It
        must contain one entry per coordinate named in ``model_type``, an
        entry named ``'<model_type>_response'`` and a ``'train_indices'``
        entry.
    output_filepath : path the resulting JSON is written to.
    model_type : coordinate names joined by ``split_char``
        (e.g. ``'time+process'``). The last coordinate is the one that is
        subset by ``train_indices``.
    fit_method, vi_iter, mcmc_iter : forwarded unchanged to
        ``SpaceTimeKron.fit``.
    response_transform : one of ``'plus1log'`` (log(x + 1)), ``'ihs'``
        (inverse hyperbolic sine) or ``'none'``.
    split_char : separator used to split ``model_type``.

    Returns
    -------
    dict
        The fitted model's trace converted with ``to_dict()`` plus the keys
        ``input_scales``, ``input_offsets``, ``response_offset`` and
        ``response_scale`` recording the preprocessing constants. The same
        dict is written to ``output_filepath`` as JSON.

    Raises
    ------
    NotImplementedError
        If ``response_transform`` is not a recognised name.
    KeyError
        If a bookkeeping key already exists in the trace dict.
    '''

    transforms = {
        'plus1log': np.log1p,      # log(x + 1), accurate for small x
        'ihs': np.arcsinh,         # log(x + sqrt(x**2 + 1)), numerically stable
        'none': lambda x: x,
    }
    try:
        transform = transforms[response_transform]
    except (KeyError, TypeError):
        raise NotImplementedError(
            f'Unknown response_transform {response_transform!r}; '
            f'expected one of {sorted(transforms)}'
        ) from None

    with open(input_filepath, 'r', encoding='utf-8') as src:
        data = {k: np.asarray(v) for k, v in json.load(src).items()}

    coord_keys = model_type.split(split_char)

    kron_Xs = [at_least_2d(data[k]) for k in coord_keys]

    input_scales  = []
    input_offsets = []

    # Shift and rescale all inputs to unit cube. The rescaled array must
    # be stored back, otherwise the model is fit on raw inputs while the
    # recorded scales/offsets claim otherwise.
    # NOTE(review): assumes SpaceTimeKron does not rescale inputs itself.
    for i, arr in enumerate(kron_Xs):
        unit_array, offset, scale = constrain_unit_cube(arr)
        kron_Xs[i] = unit_array

        input_scales.append(np.atleast_1d(scale).tolist())
        input_offsets.append(np.atleast_1d(offset).tolist())

    # Select only training input points in process variable
    # space (scaling above is computed over all points first).
    kron_Xs[-1] = kron_Xs[-1][data['train_indices']]

    # A power transform or similar preprocessing
    # technique may be applied at this stage before
    # scaling and subtracting an offset.
    response_array = transform(data[f'{model_type}_response'])
    # Reduce over every axis so a single offset/scale pair is used for
    # the whole response.
    response_axes  = tuple(range(response_array.ndim))
    response_array, offset, scale = constrain_unit_cube(response_array,
                                                        reduce_axis=response_axes)
    response_array = response_array[..., data['train_indices']]
    y = response_array.flatten()

    # Store plain Python numbers so the bookkeeping is JSON-serializable
    # regardless of the response dtype.
    extra_bookkeeping = {
        'input_scales'    : input_scales,
        'input_offsets'   : input_offsets,
        'response_offset' : np.asarray(offset).tolist(),
        'response_scale'  : np.asarray(scale).tolist(),
    }

    stpk = SpaceTimeKron()
    stpk.fit(kron_Xs, y, fit_method=fit_method, vi_iter=vi_iter, mcmc_iter=mcmc_iter)

    # Merge dictionaries and cast to list type so that target
    # arrays are serializable and human-readable
    trace = stpk.trace.to_dict()

    for k, v in extra_bookkeeping.items():
        if k in trace:
            raise KeyError(f'Bookkeeping key {k!r} collides with a key already in the trace')
        trace[k] = v

    with open(output_filepath, 'w', encoding='utf-8') as outfile:
        json.dump(trace, outfile, default=dump_default)

    return trace