In [1]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import random
import json
from math import ceil
%matplotlib inline
import os
import matplotlib as mpl
from PIL import Image
from io import BytesIO


from bayescmd.abc import import_actual_data
from bayescmd.abc import inputParse
import scipy.stats as stats
import statsmodels.api as sm


import os
from pathlib import Path
from distutils import dir_util
from pprint import pprint
import pickle

# BayesCMD packages 

from bayescmd.bcmdModel import ModelBCMD
from bayescmd.abc import import_actual_data
from bayescmd.abc import priors_creator

from bayescmd.results_handling import plot_repeated_outputs


# Google BigQuery
from google.cloud import bigquery
%load_ext google.cloud.bigquery


# Notebook-wide matplotlib defaults: 300 dpi, 7.5 x 8 inch figures,
# 8 pt tick labels, 10 pt axis labels and 12 pt figure titles.
mpl.rc('figure', dpi=300, figsize=(7.5, 8), titlesize=12)
mpl.rc('xtick', labelsize=8)
mpl.rc('ytick', labelsize=8)
mpl.rc('axes', labelsize=10)

STARTING AT: /home/buck06191/repos/Github/BayesCMD/bayescmd
 Looking for: BayesCMD
STARTING AT: /home/buck06191/repos/Github/BayesCMD/bayescmd
 Looking for: BayesCMD
STARTING AT: /home/buck06191/repos/Github/BayesCMD/bayescmd
 Looking for: BayesCMD


In [2]:
def TIFF_exporter(fig, fname, fig_dir='.', extra_artists=(), dpi=300):
    """Save a matplotlib figure as a deflate-compressed TIFF file.

    The figure is rendered to an in-memory PNG with a tight bounding box,
    loaded into PIL and re-encoded as ``<fig_dir>/<fname>.tiff``.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to export.
    fname : str
        Output file name without extension; ``.tiff`` is appended.
    fig_dir : str, optional
        Directory the TIFF is written to (must already exist).
        Defaults to the current directory.
    extra_artists : tuple, optional
        Extra artists passed to ``savefig`` as ``bbox_extra_artists`` so they
        are included in the tight bounding box.
    dpi : int, optional
        Resolution used when rasterising the figure. Defaults to 300.

    Returns
    -------
    bool
        Always ``True`` once the file has been written; errors propagate.
    """
    out_path = os.path.join(fig_dir, '{}.tiff'.format(fname))

    # Context managers guarantee the PNG buffer and the PIL image are released
    # even if savefig or the TIFF encoder raises.
    with BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png', bbox_inches='tight',
                    bbox_extra_artists=extra_artists, dpi=dpi,
                    transparent=False)
        png_buffer.seek(0)
        with Image.open(png_buffer) as png_image:
            png_image.save(out_path, compression='tiff_deflate')
    return True

In [3]:
# Explicitly use service account credentials by specifying the private
# key file. All clients in google-cloud-python have this helper.
# The key location can be overridden with the HYPOTHERMIA_GCLOUD_KEY
# environment variable so the notebook is not tied to one directory layout;
# the default is the original relative path.
GCLOUD_KEY_PATH = os.environ.get("HYPOTHERMIA_GCLOUD_KEY",
                                 "../../gcloud/hypothermia-auth.json")
client = bigquery.Client.from_service_account_json(GCLOUD_KEY_PATH)

In [4]:
def generate_posterior_query(project, dataset, model, distance, parameters, limit=50000):
    """Build a BigQuery SQL string for the best-ranked posterior samples.

    The query selects every column in ``parameters``, the ``distance`` column
    and ``idx`` from the table `` `project.dataset.model` ``, ordered by
    ``distance`` ascending and truncated to ``limit`` rows.

    All arguments are interpolated verbatim (no escaping), so they must come
    from trusted configuration, never from user input.

    Parameters
    ----------
    project, dataset, model : str
        Components of the fully-qualified BigQuery table name.
    distance : str
        Name of the distance column to rank by (smallest first).
    parameters : iterable of str
        Parameter column names to select.
    limit : int, optional
        Maximum number of rows returned. Defaults to 50000.

    Returns
    -------
    str
        The SQL query text.
    """
    selected_columns = ",\n".join(parameters)
    sql_lines = [
        "",
        "SELECT",
        "    {unpacked_params},",
        "    {distance},",
        "    idx",
        "FROM",
        "  `{project}.{dataset}.{model}`",
        "ORDER BY",
        "  {distance} ASC",
        "LIMIT",
        "  {limit}",
        "    ",
    ]
    template = "\n".join(sql_lines)
    return template.format(
        unpacked_params=selected_columns,
        distance=distance,
        project=project,
        dataset=dataset,
        model=model,
        limit=limit,
    )

In [5]:
def load_configuration(model_version, dataset, verbose=False):
    """Assemble the BayesCMD run configuration for one model version.

    Both the JSON config
    ``config_files/abc/bp_hypothermia_<v>/bp_hypothermia_<v>_config.json``
    and the simulated desaturation data
    ``data/simulated_desat/sim_sao2_desat.csv`` are resolved relative to
    ``Path(os.path.abspath('')).parents[2]``, i.e. the third ancestor of the
    current working directory.

    Parameters
    ----------
    model_version : str
        Suffix of the ``bp_hypothermia_*`` configuration, e.g. ``"2"``.
    dataset : str
        Currently unused: the same simulated desaturation file is loaded
        regardless of its value.
    verbose : bool, optional
        If True, pretty-print the resulting configuration.

    Returns
    -------
    tuple
        ``(config, d0)``. ``config`` is a dict with keys ``model_name``,
        ``targets``, ``times``, ``inputs``, ``parameters``, ``input_path``
        and ``zero_flag``; ``d0`` is whatever ``import_actual_data`` returns
        for the simulated data file.
    """
    base_dir = Path(os.path.abspath(''))
    version_tag = 'bp_hypothermia_{}'.format(model_version)
    config_path = os.path.join(base_dir.parents[2], 'config_files', 'abc',
                               version_tag, '{}_config.json'.format(version_tag))

    with open(config_path, 'r') as json_handle:
        raw_conf = json.load(json_handle)

    priors = raw_conf['priors']

    data_path = os.path.join(base_dir.parents[2], 'data', 'simulated_desat',
                             'sim_sao2_desat.csv')
    observed = import_actual_data(data_path)

    target_names = raw_conf['targets']
    name_of_model = raw_conf['model_name']

    run_config = {}
    run_config["model_name"] = name_of_model
    run_config["targets"] = target_names
    run_config["times"] = observed['t']
    # Model inputs driven from the data file.
    run_config["inputs"] = ['SaO2sup', 'temp']
    run_config["parameters"] = priors
    run_config["input_path"] = data_path
    run_config["zero_flag"] = raw_conf['zero_flag']

    if verbose:
        pprint(run_config)

    return run_config, observed

In [6]:
# Axis label for each signal, using matplotlib mathtext. Raw strings keep the
# backslashes in "\Delta" and "\mu" from being parsed as (invalid) escape
# sequences, which newer Python versions warn about.
labels = {
    "t": "Time (sec)",
    "HbO2": r"$\Delta$HbO2 $(\mu M)$",
    "HHb": r"$\Delta$HHb $(\mu M)$",
    "CCO": r"$\Delta$CCO $(\mu M)$",
    "SaO2sup": "SaO$_{2}$ (%)",
}

# Signals of interest and one seaborn colour per signal.
signals = "HbO2 HHb CCO".split()
# Presumably per-panel tick spacing; not referenced in the cells shown here.
ticker_step = [20, 10, 10, 10, 0.5]
colpal = sns.color_palette(n_colors=len(signals))

In [7]:
def run_model(model):
    """Execute a BCMD model and return its parsed output.

    Calls, in order, ``create_initialised_input``, ``run_from_buffer`` and
    ``output_parse`` on ``model``.

    Parameters
    ----------
    model : :obj:`bayescmd.bcmdModel.ModelBCMD`
        An initialised instance of a ModelBCMD class.

    Returns
    -------
    :obj:`dict`
        Whatever ``model.output_parse()`` returns.
    """
    # The return value of the input-building step is not needed here.
    model.create_initialised_input()
    model.run_from_buffer()
    return model.output_parse()

def get_output(model_name,
               p,
               times,
               input_data,
               d0,
               targets,
               distance='euclidean',
               zero_flag=None):
    """Run the model once for one parameter set and return its output.

    Parameters
    ----------
    model_name : :obj:`str`
        Name of model
    p : :obj:`dict`
        Dict of form {'parameter': value} for which posteriors are being
        investigated.
    times : :obj:`list` of :obj:`float`
        List of times at which the data was collected.
    input_data : :obj:`dict`
        Dictionary of input data as generated by :obj:`abc.inputParse`.
    d0 : :obj:`dict`
        Observed data. Not used by this function; kept so existing callers
        continue to work.
    targets : :obj:`list` of :obj:`str`
        Model outputs requested from the model.
    distance : :obj:`str`
        Not used by this function (no distance is computed here); kept for
        call compatibility.
    zero_flag : dict, optional
        Dictionary of form target(:obj:`str`): bool. For each key whose value
        is True, that output series is shifted so its first value is zero.

        Note: zero_flag keys must be present in the model output, otherwise
        a KeyError is raised.

    Returns
    -------
    :obj:`dict`
        Parsed model output (as returned by :func:`run_model`), with flagged
        series zeroed to their first value.
    """
    model = ModelBCMD(
        model_name, inputs=input_data, params=p, times=times, outputs=targets)

    output = run_model(model)

    if zero_flag:
        for k, boolean in zero_flag.items():
            if not boolean:
                continue
            series = output[k]
            # Look up the baseline once instead of once per element; an empty
            # series has no baseline and simply stays empty.
            baseline = series[0] if len(series) else 0
            output[k] = [x - baseline for x in series]
    return output

In [8]:
def get_runs(posterior, conf, n_repeats=50, seed=None, verbose=True):
    """Simulate the model for randomly drawn rows of the posterior.

    Parameters
    ----------
    posterior : pandas.DataFrame
        Accepted parameter sets; must contain a column for every key of
        ``conf['parameters']``.
    conf : dict
        Configuration as returned by ``load_configuration``; the keys
        ``'parameters'``, ``'input_path'``, ``'inputs'``, ``'model_name'``,
        ``'times'``, ``'targets'`` and ``'zero_flag'`` are used.
    n_repeats : int, optional
        Number of distinct posterior rows to simulate (drawn without
        replacement). Defaults to 50.
    seed : int, optional
        If given, rows are drawn with a private ``random.Random(seed)`` so the
        selection is reproducible. ``None`` (default) uses the global
        ``random`` module state, as before.
    verbose : bool, optional
        Print the positional index of each sampled row. Defaults to True.

    Returns
    -------
    list of dict
        One model-output dict (from ``get_output``) per sampled row.

    Raises
    ------
    ValueError
        If ``n_repeats`` is negative or larger than the number of posterior
        rows.
    """
    n_rows = posterior.shape[0]
    if not 0 <= n_repeats <= n_rows:
        raise ValueError(
            "n_repeats={} must be between 0 and the number of posterior "
            "rows ({})".format(n_repeats, n_rows))

    sampler = random if seed is None else random.Random(seed)
    rand_selection = sampler.sample(range(n_rows), n_repeats)

    p_names = list(conf['parameters'].keys())
    # Positional (row-number) access, matching the positions sampled above.
    posteriors = posterior[p_names].values
    d0 = import_actual_data(conf['input_path'])
    input_data = inputParse(d0, conf['inputs'])

    outputs_list = []
    # Reverse the selection to reproduce the original pop()-from-the-end order.
    for sample_no, idx in enumerate(reversed(rand_selection)):
        if verbose:
            print("\tSample {}, idx:{}".format(sample_no, idx))
        p = dict(zip(p_names, posteriors[idx]))
        output = get_output(
            conf['model_name'],
            p,
            conf['times'],
            input_data,
            d0,
            conf['targets'],
            distance="NRMSE",
            zero_flag=conf['zero_flag'])
        outputs_list.append(output)
    return outputs_list


In [9]:
def _set_axis_fontsize(axis, size):
    """Set the title, axis-label and tick-label font size of one Axes."""
    axis.title.set_fontsize(size)
    for item in ([axis.xaxis.label, axis.yaxis.label] +
                 axis.get_xticklabels() + axis.get_yticklabels()):
        item.set_fontsize(size)


def plot_desat(outputs_list, targets, times, title, labels):
    """Plot the SaO2 input above a posterior-predictive panel per target.

    The top panel shows the ``'SaO2sup'`` trace of the first run. Each
    following panel shows, for one target, the median across runs with a
    95% interval as drawn by ``seaborn.lineplot``.

    Parameters
    ----------
    outputs_list : list of dict
        Model runs (e.g. from ``get_runs``); each dict maps an output name to
        its values at ``times``.
    targets : list of str
        Output names to plot, one panel each.
    times : sequence of float
        Time points shared by every run.
    title : str
        Figure suptitle.
    labels : dict
        Maps ``'SaO2sup'`` and every target to its y-axis label.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : numpy.ndarray of Axes
        ``len(targets) + 1`` axes; ``ax[0]`` is the SaO2 panel.

    Raises
    ------
    ValueError
        If ``outputs_list`` is empty.
    """
    if not outputs_list:
        raise ValueError("outputs_list is empty; nothing to plot")

    # Gather each target's runs before creating the figure, so a missing key
    # fails without leaving an open figure behind. Assumes every run has
    # len(times) values, giving an (n_runs, n_times) array.
    runs_by_target = {target: np.array([o[target] for o in outputs_list])
                      for target in targets}
    n_runs = len(outputs_list)

    fig, ax = plt.subplots(len(targets) + 1, sharex=True,
                           dpi=250, figsize=(4, 5))
    # plt.subplots returns a bare Axes (not an array) for a single row.
    if not isinstance(ax, np.ndarray):
        ax = np.asarray([ax])

    ax[0].plot(times, outputs_list[0]['SaO2sup'], 'r-')
    ax[0].set_xlabel('')
    ax[0].set_ylabel(labels['SaO2sup'])
    _set_axis_fontsize(ax[0], 11)

    for panel, target in zip(ax[1:], targets):
        # Long format: each time point repeated once per run, with values
        # ordered time-major to match.
        long_df = pd.DataFrame({
            "Time": np.repeat(times, n_runs),
            "Posterior": runs_by_target[target].transpose().flatten(),
        })
        # NOTE(review): `ci` is deprecated in newer seaborn in favour of
        # errorbar=("ci", 95); confirm against the installed version.
        sns.lineplot(
            y="Posterior",
            x="Time",
            data=long_df,
            estimator=np.median,
            ci=95,
            ax=panel)
        panel.set_ylabel(labels[target])
        panel.set_xlabel('Time (sec)')
        _set_axis_fontsize(panel, 11)

    # Rotate the bottom panel's tick labels (x is shared across panels).
    # ax[-1] also works when targets is empty, unlike a leaked loop index.
    plt.setp(ax[-1].get_xticklabels(), rotation=30)

    fig.suptitle(title, y=0.9)
    fig.tight_layout(rect=[0, 0, 1, 0.875])
    return fig, ax

## Generating posterior predictive ##

We can sample directly from the posterior to generate our posterior predictive distribution. We then generate a variety of potentially useful summary statistics, as well as the residuals, the autocorrelation of the signals and the autocorrelation of the residuals for each signal.

We also generate each summary statistic for the observed data so as to compare this with the posterior predictive distribution of these statistics.

In [10]:
configuration = {}
# Dataset (piglet ID) -> model versions fitted to it.
model_data_combos = {"LWP475": ["2"],
                     "LWP479": ["2_1"]}
# One title per (model, dataset) pair, in the iteration order below.
titles = ["Simulated desaturation in a mild HIE piglet",
          "Simulated desaturation in a severe HIE piglet"]
# Number of posterior rows simulated for each (model, dataset) pair.
N = 500
# Root directory for saved figures. Override with the INSILICO_DESAT_FIG_ROOT
# environment variable rather than editing this user-specific path.
FIG_ROOT = os.environ.get(
    "INSILICO_DESAT_FIG_ROOT",
    "/home/buck06191/Dropbox/phd/hypothermia/insilico_desat/Figures")

combos = [(m, d) for d, versions in model_data_combos.items() for m in versions]
for ii, (model_number, DATASET) in enumerate(combos):
    print("Working on (bph{}, {})".format(model_number, DATASET))
    model_name = 'bph{}'.format(model_number)

    # setdefault (instead of resetting to {}) keeps entries for earlier
    # datasets when several datasets share the same model.
    entry = {}
    configuration.setdefault(model_name, {})[DATASET] = entry
    config, d0 = load_configuration(model_number, DATASET)
    entry['bayescmd_config'] = config
    entry['original_data'] = d0
    entry['posterior_query'] = generate_posterior_query(
        'hypothermia-bayescmd',
        DATASET,
        model_name,
        'NRMSE',
        list(config['parameters'].keys()),
        limit=5000)

    figPath = os.path.join(FIG_ROOT, model_name, DATASET, 'NRMSE')
    # os.makedirs replaces distutils.dir_util.mkpath (distutils is removed in
    # Python 3.12).
    os.makedirs(figPath, exist_ok=True)

    # Get posterior
    print("\tRunning SQL query")
    df_post = client.query(entry['posterior_query']).to_dataframe()
    print("\tSampling from the posterior {} times.".format(N))
    outputs_list = get_runs(df_post, config, n_repeats=N)

    print("\n")
    fig, ax = plot_desat(outputs_list, config['targets'], config['times'],
                         titles[ii], labels)

    fig.savefig(
        os.path.join(figPath, 'posterior_predictive_{}_{}.png'
                     .format(model_name, DATASET)),
        bbox_inches='tight', dpi=250)
    plt.close('all')
    

Working on (bph2, LWP475)
	Running SQL query
	Sampling from the posterior 500 times.
	Sample 0, idx:3340
	Sample 1, idx:3575
	Sample 2, idx:4389
	Sample 3, idx:2136
	Sample 4, idx:1358
	Sample 5, idx:715
	Sample 6, idx:4800
	Sample 7, idx:3689
	Sample 8, idx:3542
	Sample 9, idx:2932
	Sample 10, idx:2729
	Sample 11, idx:335
	Sample 12, idx:3114
	Sample 13, idx:2265
	Sample 14, idx:4706
	Sample 15, idx:437
	Sample 16, idx:1311
	Sample 17, idx:1219
	Sample 18, idx:873
	Sample 19, idx:4605
	Sample 20, idx:1011
	Sample 21, idx:3432
	Sample 22, idx:872
	Sample 23, idx:2909
	Sample 24, idx:3039
	Sample 25, idx:1185
	Sample 26, idx:3268
	Sample 27, idx:4022
	Sample 28, idx:1759
	Sample 29, idx:1212
	Sample 30, idx:1983
	Sample 31, idx:4385
	Sample 32, idx:708
	Sample 33, idx:2742
	Sample 34, idx:1519
	Sample 35, idx:897
	Sample 36, idx:2803
	Sample 37, idx:1418
	Sample 38, idx:379
	Sample 39, idx:4413
	Sample 40, idx:3936
	Sample 41, idx:3456
	Sample 42, idx:1284
	Sample 43, idx:1000
	Sample 4

	Sample 378, idx:4511
	Sample 379, idx:828
	Sample 380, idx:1707
	Sample 381, idx:1654
	Sample 382, idx:2842
	Sample 383, idx:927
	Sample 384, idx:1681
	Sample 385, idx:82
	Sample 386, idx:3512
	Sample 387, idx:2790
	Sample 388, idx:4693
	Sample 389, idx:3032
	Sample 390, idx:3072
	Sample 391, idx:3734
	Sample 392, idx:4527
	Sample 393, idx:1909
	Sample 394, idx:1428
	Sample 395, idx:3608
	Sample 396, idx:1945
	Sample 397, idx:4462
	Sample 398, idx:3356
	Sample 399, idx:693
	Sample 400, idx:582
	Sample 401, idx:227
	Sample 402, idx:576
	Sample 403, idx:4192
	Sample 404, idx:2369
	Sample 405, idx:4127
	Sample 406, idx:2987
	Sample 407, idx:981
	Sample 408, idx:1191
	Sample 409, idx:2377
	Sample 410, idx:4731
	Sample 411, idx:4122
	Sample 412, idx:1206
	Sample 413, idx:4288
	Sample 414, idx:3367
	Sample 415, idx:637
	Sample 416, idx:4215
	Sample 417, idx:1886
	Sample 418, idx:3650
	Sample 419, idx:452
	Sample 420, idx:2888
	Sample 421, idx:2768
	Sample 422, idx:4388
	Sample 423, idx:1198

	Sample 256, idx:51
	Sample 257, idx:4430
	Sample 258, idx:4534
	Sample 259, idx:241
	Sample 260, idx:4731
	Sample 261, idx:4231
	Sample 262, idx:3178
	Sample 263, idx:4509
	Sample 264, idx:4821
	Sample 265, idx:1074
	Sample 266, idx:2897
	Sample 267, idx:790
	Sample 268, idx:3662
	Sample 269, idx:1145
	Sample 270, idx:4407
	Sample 271, idx:2024
	Sample 272, idx:728
	Sample 273, idx:1804
	Sample 274, idx:3195
	Sample 275, idx:4873
	Sample 276, idx:4265
	Sample 277, idx:1917
	Sample 278, idx:2309
	Sample 279, idx:4093
	Sample 280, idx:2308
	Sample 281, idx:4209
	Sample 282, idx:4285
	Sample 283, idx:3382
	Sample 284, idx:4732
	Sample 285, idx:1714
	Sample 286, idx:2430
	Sample 287, idx:952
	Sample 288, idx:3584
	Sample 289, idx:1429
	Sample 290, idx:3434
	Sample 291, idx:2914
	Sample 292, idx:1928
	Sample 293, idx:1211
	Sample 294, idx:944
	Sample 295, idx:4052
	Sample 296, idx:3098
	Sample 297, idx:874
	Sample 298, idx:2347
	Sample 299, idx:860
	Sample 300, idx:3163
	Sample 301, idx:42

In [11]:
# Show which outputs one posterior-predictive run contains (uses the last
# df_post/config from the loop above).
get_runs(n_repeats=1, conf=config, posterior=df_post)[0].keys()

	Sample 0, idx:391


dict_keys(['t', 'CCO', 'HbO2', 'HHb', 'SaO2sup', 'temp'])