Skip to content

Commit

Permalink
Add validation plot after Bastos and O'Hagan (2009), along with examp…
Browse files Browse the repository at this point in the history
…le (#20)

* Add validation plot after Bastos and O'Hagan (2009), along with example

* Generalise the reshaping of the input parameters

* Delete Validation_emulating_using_GPs.ipynb

Remove example notebook in favour of adding the plot to the existing notebook

* Implement Duncan's suggested changes

Check for pd.DataFrame when setting labels in the validation plot
Let validation_plot_bastos use validation_plot to minimize redundant
code

* Reformat docstring

Co-authored-by: Duncan Watson-Parris <duncan@watson-parris.co.uk>
Co-authored-by: Duncan Watson-Parris <duncan.watson-parris@physics.ox.ac.uk>
  • Loading branch information
3 people committed Sep 30, 2021
1 parent 181d7ca commit 98de068
Showing 1 changed file with 85 additions and 2 deletions.
87 changes: 85 additions & 2 deletions esem/utils.py
@@ -1,6 +1,7 @@
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
import pandas as pd


def add_121_line(ax):
Expand Down Expand Up @@ -54,9 +55,10 @@ def prediction_within_ci(test_mean, pred_mean, pred_var, ci=0.95):
return lower, upper, within


def validation_plot(test_mean, pred_mean, pred_var, figsize=(7, 7), minx=None, miny=None, maxx=None, maxy=None):
def validation_plot(test_mean, pred_mean, pred_var, figsize=(7, 7), minx=None, miny=None, maxx=None, maxy=None, ax=None):
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, figsize=figsize)
if ax is None:
fig, ax = plt.subplots(1, figsize=figsize)

# Deal with input arrays that might be masked
if isinstance(test_mean, np.ma.MaskedArray):
Expand Down Expand Up @@ -90,6 +92,87 @@ def validation_plot(test_mean, pred_mean, pred_var, figsize=(7, 7), minx=None, m
ax.set_ylim([miny, maxy])


def validation_plot_bastos(X_test, Y_test, m_test, v_test):
    """
    Validation plot following Bastos and O'Hagan (2009)

    Draws a 2x2 panel of emulator diagnostics:

    * upper-left: standardized errors against the emulator output
    * upper-right: standardized errors against each input parameter
    * lower-left: Q-Q plot of the standardized errors against a fitted
      Student-t distribution
    * lower-right: the standard ESEm ``validation_plot()``

    Source:
        Bastos and O'Hagan (2009): Diagnostics for Gaussian Process Emulators, Technometrics,
        51, 425-438. https://doi.org/10.1198/TECH.2009.08019
        Code for the lower-right plot adapted from the ESEm package validation_plot() (see above)
    Author:
        Ulrike Proske (ulrike.proske@env.ethz.ch)

    Parameters
    ----------
    X_test : pd.DataFrame or array-like of shape (n_samples, n_features)
        Input data. DataFrame column names are used as legend labels;
        for any other array-like the column index is used instead.
    Y_test : array-like of shape (n_samples, ...)
        Simulated output
    m_test : array-like, same shape as Y_test
        Emulator output
    v_test : array-like, broadcastable against Y_test
        Variance of emulator
    """

    import matplotlib.pyplot as plt
    import statsmodels.api as sm
    from scipy import stats

    # Namelist
    c_black = 'black'
    c_blue = '#1f78b4'
    c_green = '#33a02c'
    c_orange = '#ff7f00'
    c_purple = '#6a3d9a'
    colors = [c_blue, c_green, c_orange, c_purple]
    alpha = 0.75
    std_error_label = r'$({Y_{\mathrm{sim}} - Y_{\mathrm{emu}})}/{\sqrt{V}}$'

    # Start plotting
    _, axs = plt.subplots(nrows=2, ncols=2, figsize=(4.5, 4.5),
                          gridspec_kw={'hspace': 0.35, 'wspace': 0.75})
    errors_std = (Y_test - m_test) / np.sqrt(v_test)  # standardized errors

    # Upper-left: standardized errors against the emulated values
    axs[0, 0].scatter(m_test, errors_std, c=c_black, marker='.', alpha=alpha)
    axs[0, 0].set_xlabel(r'$Y_{\mathrm{emu}}$')
    axs[0, 0].set_ylabel(std_error_label)

    # Lower-left: customize qq plot
    qq_ax = axs[1, 0]
    pp = sm.ProbPlot(errors_std.ravel(), stats.t, fit=True)
    pp.qqplot(marker='.', markerfacecolor='k',
              markeredgecolor='k', alpha=alpha, ax=qq_ax)
    # Remember the data-driven limits so that adding the 45-degree reference
    # line does not rescale the axes.
    xlim, ylim = qq_ax.get_xlim(), qq_ax.get_ylim()
    sm.qqline(qq_ax, line='45', fmt='k--')
    qq_ax.set_xlim([xlim[0], xlim[1]])
    qq_ax.set_ylim([ylim[0], ylim[1]])
    qq_ax.set_ylabel('Standardized quantiles')

    # Upper-right: standardized errors against every input parameter.
    # Convert once, outside the loop; np.asarray handles DataFrames as well as
    # plain arrays (the previous unconditional .to_numpy() call failed on the
    # latter).
    params = np.asarray(X_test)
    if params.ndim == 1:
        # A single parameter passed as a flat vector
        params = params[:, np.newaxis]

    if isinstance(X_test, pd.DataFrame):
        labels = list(X_test.columns)
    else:
        labels = [str(i) for i in range(params.shape[1])]

    # Give each parameter column trailing singleton axes so that it can be
    # broadcast along the sample axis to the shape of the outputs
    column_shape = (-1,) + (1,) * (len(errors_std.shape) - 1)
    for i, label in enumerate(labels):
        expanded_params = np.broadcast_to(params[:, i].reshape(column_shape),
                                          errors_std.shape)
        # Cycle through the palette so more than len(colors) parameters work
        axs[0, 1].scatter(expanded_params, errors_std, c=colors[i % len(colors)],
                          label=label, marker='.', alpha=alpha)

    axs[0, 1].legend()
    axs[0, 1].set_xlabel(r'$\eta_i$')
    axs[0, 1].set_ylabel(std_error_label)

    # add hlines at 0 and +/-2 standardized errors
    for ax in (axs[0, 1], axs[0, 0]):
        for y in (-2, 2, 0):
            ax.axhline(y=y, c=c_black, linestyle='--')

    # Lower-right: the standard one-to-one validation plot
    validation_plot(Y_test, m_test, v_test, ax=axs[1, 1])


def plot_parameter_space(df, nbins=100, target_df=None, smooth=True,
xmins=None, xmaxs=None, fig_size=(8, 6)):
from itertools import repeat
Expand Down

0 comments on commit 98de068

Please sign in to comment.