# Impact of the Prior Parameters
This notebook investigates the impact that the use of the gamma conjugate prior
with informed parameters has on the evaluation of heat flow anomalies. To this
end, "sample locations" are drawn from $R=80\,\mathrm{km}$ disks with a straight
line fault splitting the disk in half.

The signature of the anomaly follows Lachenbruch & Sass (1980) with a linear
increase in heat production with depth. The length of the fault is
$160\,\mathrm{km}$ and its depth is $14\,\mathrm{km}$ (currently hardcoded in the
backend C++ code `external/zeal2022hfpc/src/resilience.cpp`).

In [None]:
import json
import numpy as np
from pathlib import Path
from plotconfig import *
from cmcrameri.cm import *
from cache import cached_call
from itertools import product
from cmocean.tools import crop
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
from matplotlib.lines import Line2D
from zeal2022hf import get_cm_colors
from reheatfunq import GammaConjugatePrior
from scipy.ndimage import gaussian_filter
from reheatfunq.resilience import test_performance_cython, \
                                  test_performance_mixture_cython

### Results from previous notebooks

In [None]:
# Read the four gamma conjugate prior parameters from the results file
# (the first line of the file is skipped as a header):
PRIOR_P, PRIOR_S, PRIOR_N, PRIOR_V = np.loadtxt(
    'results/05-GCP-Parameters.txt',
    skiprows=1,
    delimiter=',',
)

In [None]:
# Gamma conjugate prior object built from the parameters loaded above:
gcp = GammaConjugatePrior(
    PRIOR_P,
    PRIOR_S,
    PRIOR_N,
    PRIOR_V,
)

In [None]:
# Collect the third element of every record in the JSON results file.
# The figure cell scatters these as points in the (alpha, beta) plane,
# labelled 'RGRDC MLE'.
with open('results/03-gamma-conjugate-prior-results.json', 'r') as f:
    GCP_MLE_AB = np.array(
        [record[2] for record in json.load(f)]
    )

## Test the capabilities on a synthetic example

In [None]:
# List the files in the results directory. This is a portable, pure-Python
# replacement for the shell escape `!ls results`; hidden files are skipped,
# as `ls` does by default.
sorted(entry.name for entry in Path('results').iterdir()
       if not entry.name.startswith('.'))

In [None]:
# Notebook-level random number generator. It is deliberately created without
# a seed (seed=None draws fresh entropy), so its draws are not reproducible.
# `evaluate_prior_vs_flat` creates its own seeded generator and does not
# use this one.
rng = np.random.default_rng(seed=None)

In [None]:
# Second positional argument handed to `test_performance_cython` for the
# grid evaluation (a smaller value such as 10 is useful for quick test runs):
M = 1000

# Anomaly powers to test (in MW):
PTEST_MW = [10]

# Logarithmically spaced grid of gamma distribution parameters on which
# the informed and the flat prior are compared:
BETA  = np.geomspace(start=2e-2, stop=15.0, num=100)
ALPHA = np.geomspace(start=1, stop=1000, num=101)

In [None]:
def evaluate_prior_vs_flat(PTEST_MW, N, M, quantile, ALPHA, BETA, prior_p, prior_s, prior_n, prior_v,
                           gcp_significance_level=1e-7, seed=289817):
    """
    Evaluate the performance of the informed prior vs. the flat prior
    on a grid of gamma distribution parameters.

    Parameters
    ----------
    PTEST_MW : sequence
        Anomaly powers to test (printed as MW), each passed on to
        `test_performance_cython`.
    N : int
        Sample size; passed on as the one-element array `[N]`.
    M : int
        Passed on unchanged to `test_performance_cython`.
    quantile : array_like
        Quantiles to evaluate; passed on unchanged.
    ALPHA, BETA : numpy.ndarray
        One-dimensional grids of the gamma distribution parameters.
        `test_performance_cython` receives `ALPHA[i]` and `1.0/BETA[j]`.
    prior_p, prior_s, prior_n, prior_v : float
        Parameters of the gamma conjugate prior. They are used both to
        select the grid cells to evaluate and as the informed prior in
        `test_performance_cython`.
    gcp_significance_level : float, optional
        Grid cells at which the prior's `probability` is below this value
        are skipped and filled with NaN.
    seed : int, optional
        Seed of the generator from which one seed per grid cell is drawn.

    Returns
    -------
    res : numpy.ndarray
        Array of shape
        `(len(PTEST_MW), ALPHA.size, BETA.size, len(quantile), 2)`.
        The last axis holds the medians (over the last axis of the
        `test_performance_cython` result) of result rows 0 and 1, which the
        figure cell of this notebook labels 'Informed' and 'Flat'.
    """
    # Build the prior from the parameters that were passed in, so that the
    # significance mask always belongs to the same prior that is tested below
    # (rather than to whatever global prior object happens to exist).
    prior = GammaConjugatePrior(prior_p, prior_s, prior_n, prior_v)

    # Find the significant region of alpha and beta:
    ag, bg = np.meshgrid(ALPHA, BETA, indexing='ij')
    Z = np.zeros((ALPHA.size, BETA.size))
    assert Z.shape == ag.shape
    Z.flat = prior.probability(ag.reshape(-1), bg.reshape(-1))
    gcp_significance_mask = Z >= gcp_significance_level

    res = np.zeros((len(PTEST_MW), ALPHA.size, BETA.size, len(quantile), 2))

    # One seed per (P, alpha, beta) cell. The counter `k` is advanced for
    # skipped cells as well, so a cell's seed does not depend on the mask.
    rng = np.random.default_rng(seed)
    seeds = rng.integers(2**63, size=len(PTEST_MW) * ALPHA.size * BETA.size)
    k = -1
    for p,PMW in enumerate(PTEST_MW):
        print("    ---- P =",PMW,"MW ----")
        for i,j in product(range(ALPHA.size),range(BETA.size)):
            k += 1
            print("alpha:",ALPHA[i],",  beta:",BETA[j])
            if not gcp_significance_mask[i,j]:
                # `np.nan` instead of the alias `np.NaN`, which was removed
                # in NumPy 2.0.
                res[p,i,j,:,:] = np.nan
                continue
            # Axis 1 of the result enumerates the sample sizes; only one
            # sample size is passed here, hence the index 0.
            res_ij = cached_call(test_performance_cython, np.array([N]), M, PMW,
                                 ALPHA[i], 1.0/BETA[j], quantile,
                                 prior_p, prior_s, prior_n, prior_v,
                                 seed=seeds[k])[:,0,:,:]
            res[p,i,j,:,0] = np.median(res_ij[0,:], axis=1)
            res[p,i,j,:,1] = np.median(res_ij[1,:], axis=1)

    return res

In [None]:
# Quantile(s) evaluated on the (alpha, beta) grid and the sample size used
# for the grid evaluation:
QUANTILES_GRID = np.array([1e-2])
N = 10

In [None]:
# Position of the entry with value 10 in PTEST_MW; used below to select the
# corresponding slice of the results (raises ValueError if 10 is not tested).
p_select = PTEST_MW.index(10)

In [None]:
# Sample sizes for the inlay plots: 25 logarithmically spaced values between
# 10 and 50, rounded to integers, with duplicates removed.
Nset = np.unique(
    np.round(np.geomspace(10, 50, 25)).astype(int)
)
# Second positional argument handed to `test_performance_cython` for the
# inlay plots:
M2 = 100_000

In [None]:
# Flag checked by the following cells before they load or compute results.
USECACHE = True

In [None]:
# Obtain `res`, the grid evaluation of the informed vs. the flat prior.
# If USECACHE is set and the intermediate pickle exists, it is loaded;
# otherwise the grid is evaluated. `res` is therefore defined on every code
# path, which the figure cell below relies on.
from pickle import Unpickler

res_pickle_path = Path('intermediate/A3-prior-vs-flat.pickle')
if USECACHE and res_pickle_path.is_file():
    # NOTE(security): unpickling can execute arbitrary code. Only load
    # intermediate files that were created locally by a trusted run.
    with open(res_pickle_path, 'rb') as f:
        res = Unpickler(f).load()
else:
    res = evaluate_prior_vs_flat(PTEST_MW, N, M, QUANTILES_GRID, ALPHA, BETA, PRIOR_P,
                                 PRIOR_S, PRIOR_N, PRIOR_V, 1e-6)

In [None]:
# Parameters of the two inlay examples. `k` is passed to
# `test_performance_cython` in the position of alpha and `t` in the position
# of 1/beta; the figure cell marks them at (k, 1/t) in the (alpha, beta) plane.
# They are defined unconditionally because the figure cell always needs them.
k0 = 7.0
t0 = 1.0/0.1
k1 = 110
t1 = 1.0/0.7
quants = np.array([0.01])

# Performance as a function of the sample size for both examples. Index 0 of
# the third result axis selects the single quantile in `quants`.
# `cached_call` is used regardless of USECACHE, so the results exist whenever
# the figure cell runs.
res_with_N_1 = cached_call(test_performance_cython, Nset, M2, PTEST_MW[p_select], k0, t0, quants,
                           PRIOR_P, PRIOR_S, PRIOR_N, PRIOR_V, seed=12409035)[:,:,0,:]

res_with_N_2 = cached_call(test_performance_cython, Nset, M2, PTEST_MW[p_select], k1, t1, quants,
                           PRIOR_P, PRIOR_S, PRIOR_N, PRIOR_V, seed=187579, nthread=12)[:,:,0,:]


In [None]:
# Seven colors from the `vik` color map; four of them get short names for
# use in the figure below.
colors = get_cm_colors(vik, 7)
color0, color1, color2, color3 = (colors[idx] for idx in (0, 1, 4, 5))

In [None]:
# Grid cells for which results exist (skipped cells are NaN):
mask = ~np.isnan(res[p_select,:,:,0,0])

# Color field: difference between the absolute deviations from the true power
# (PTEST_MW converted from MW to W) of result column 0 ('Informed' in the
# figure) and result column 1 ('Flat'), in percent of the true power.
color = (  np.abs(res[p_select,:,:,0,0] - PTEST_MW[p_select]*1e6)
         - np.abs(res[p_select,:,:,0,1] - PTEST_MW[p_select]*1e6)) / (PTEST_MW[p_select]*1e6)*1e2

# Contour field: absolute deviation of result column 1 from the true power,
# in percent of the true power.
# Fill the contour field with dummy values to achieve continuous contours:
# every missing cell takes the value of its nearest evaluated neighbour
# before smoothing and is masked out again afterwards.
ag, bg = np.meshgrid(ALPHA, BETA, indexing='ij')
assert ag.shape == color.shape
contour = np.abs(res[p_select,:,:,0,1] - PTEST_MW[p_select]*1e6) / (PTEST_MW[p_select]*1e6)*1e2
contour[~mask] = contour[mask][KDTree(np.stack((ag[mask], bg[mask]), axis=1))
                                  .query(np.stack((ag[~mask], bg[~mask]), axis=1), 1)[1]]
contour = gaussian_filter(contour, 1.1)
# `np.nan` instead of the alias `np.NaN`, which was removed in NumPy 2.0.
contour[~mask] = np.nan

# Color limits: the upper limit is capped at 100 %. If the negative range is
# smaller than that, the color map is cropped so that it stays centred on
# zero; otherwise the limits are made symmetric.
vmin = color[mask].min()
vmax = 100.

if vmin > -vmax:
    cmap = crop(broc, vmin, vmax, 0.0)
else:
    cmap = broc
    vmax = -vmin

    
# Figure: change of the 1 % tail quantile when using the informed instead of
# the flat prior on the (alpha, beta) grid, with two inlays showing the
# dependence on the sample size N at the two marked parameter combinations.
with plt.rc_context({'axes.labelpad': 2.5, 'xtick.major.pad': 1.2, 'ytick.major.pad': 1.2}):
    fig = plt.figure(figsize=(5.8, 4.9))

    # Invisible axes spanning the whole figure; used to draw the arrows that
    # connect the markers in the main plot with the inlays.
    ax0 = fig.add_axes((0,0,1,1), zorder=3, facecolor='none')
    ax0.set_xlim(0,1)
    ax0.set_ylim(0,1)
    ax0.set_axis_off()

    #
    # The main plot:
    #
    ax = fig.add_axes((0.09, 0.08, 0.79, 0.915))
    cax = fig.add_axes((0.89, 0.2, 0.02, 0.6))
    h = ax.pcolormesh(ALPHA, BETA, color.T, vmin=vmin, vmax=vmax, cmap=cmap, rasterized=True)
    cntr = ax.contour(ALPHA, BETA, contour.T, levels=[100, 200, 500, 1000, 2000, 5000],
                      colors='k', linewidths=1)
    ax.set_xscale('log')
    ax.set_yscale('log')
    h5 = ax.scatter(*GCP_MLE_AB.T, s=5, facecolor=colors[4], edgecolor='k', linewidth=0.5, zorder=3)
    cbar = fig.colorbar(h, cax=cax, extend='max')
    cbar.set_label('1 % tail quantile change when choosing informed prior,\nrelative to true $P_H$ (%)')
    h4 = ax.clabel(cntr, manual=[(1.3, 0.15), (1.3, 0.25), (1.3, 0.05), (1.3, 0.03), (4, 0.08), (3, 0.03)],
                   inline=True, colors='k',
                   fmt = "%d %%")

    for lbl in cbar.ax.get_yticklabels():
        lbl.set_rotation(90)
    ax.set_xlabel('$\\alpha$')
    ax.set_ylabel('$\\beta$')

    # The locations of the two inlay parameters:
    ax.scatter(k1, 1.0/t1, s=20, c='w', marker='x')
    ax.scatter(k0, 1.0/t0, s=20, c='k', marker='x')


    # The two inlays showing the posterior as a function of N:
    with plt.rc_context({**{key : 'k' for key in ['axes.labelcolor','axes.edgecolor','xtick.color',
                                                  'ytick.color']},
                         **{'axes.facecolor' : 'w', 'font.size' : 6}}):
        # Index of the grid sample size N within Nset (used by both inlays):
        i10 = int(np.argwhere(Nset == N).flat[0])

        ax1 = fig.add_axes((0.513, 0.135, 0.31, 0.27))
        h0 = ax1.plot(Nset, 1e-2*PTEST_MW[p_select] * np.ones_like(Nset), label='True', color='tab:red',
                      linewidth=1.0, linestyle=':')
        h1 = ax1.plot(Nset, 1e-8*np.median(res_with_N_1[0,:,:], axis=1), label='Informed 1%',
                      color=color0, linewidth=0.7)
        h2 = ax1.plot(Nset, 1e-8*np.median(res_with_N_1[1,:,:], axis=1), label='Flat 1%', linestyle='-',
                      color=color1, linewidth=0.7)
        # Circular markers marking the level at the grid sample size N:
        ax1.plot((N,N), (1e-8*np.median(res_with_N_1[1,i10,:]), 1e-8*np.median(res_with_N_1[0,i10,:])),
                 marker='o', linewidth=0.7, markerfacecolor='none', markeredgecolor='k',
                 markersize=4, linestyle='-', color='k')
        ax1.set_xlabel('$N$', labelpad=1)
        twinx = ax1.twinx()
        h3 = twinx.plot(Nset,
                        100*(np.median(res_with_N_1[0,:,:], axis=1) - np.median(res_with_N_1[1,:,:], axis=1))
                             /(np.median(res_with_N_1[1,:,:], axis=1) - 1e6*PTEST_MW[p_select]),
                        linestyle='-', color=color3, linewidth=0.7)
        for lbl in twinx.get_yticklabels():
            lbl.set_rotation(90)
            lbl.set_va('center')
        ax1.tick_params(axis='y', colors='w')
        ax1.get_yticklabels()[0].set_color('k')

        # Raw string: `\,` and `\m` are LaTeX, not Python escape sequences.
        ax1.set_ylabel(r'1 % t. q. ($100\,\mathrm{MW}$)', color='w')
        ax1.set_ylim(0, ax1.get_ylim()[1])
        ax1.text(30, 2.9, '(c)', fontsize=8, ha='center', va='center')
        twinx.set_ylabel('1 % t. q. change (%)')
        # Arrow 1:
        ax0.annotate('', (0.24, 0.943), (0.625, 0.57),
                  arrowprops=dict(arrowstyle='<|-',
                                  shrinkA=7,
                                  shrinkB=2,
                                  fc="k", ec="k",
                                  connectionstyle="arc3,rad=0.3",
                                  ),)

        ax2 = fig.add_axes((0.23, 0.765, 0.25, 0.22))
        ax2.plot(Nset, 1e-2*PTEST_MW[p_select] * np.ones_like(Nset), label='True', color='tab:red', linewidth=1.0,
                 linestyle=':')
        ax2.plot(Nset, np.median(1e-8*res_with_N_2[0,:,:], axis=1), label='Informed 1 %', linestyle='-',
                 color=color0, linewidth=0.7)
        ax2.plot(Nset, np.median(1e-8*res_with_N_2[1,:,:], axis=1), label='Flat 1 %', linestyle='-',
                 color=color1, linewidth=0.7)
        # Circular markers marking the level at the grid sample size N:
        ax2.plot((N,N), (1e-8*np.median(res_with_N_2[1,i10,:]), 1e-8*np.median(res_with_N_2[0,i10,:])),
                 marker='o', linewidth=0.7, markerfacecolor='none', markeredgecolor='k',
                 markersize=4, linestyle='-', color='k')
        # The newline is a Python escape; the LaTeX part is a raw string.
        ax2.set_ylabel('1 % tail quantile\n' r'($100\,\mathrm{MW}$)')
        ax2.set_xlabel('$N$', labelpad=-2.5)
        ax2.set_ylim(0, ax2.get_ylim()[1])
        ax2.text(49, 3.6, '(b)', fontsize=8, ha='center', va='center')

        # Arrow 2:
        ax0.annotate('', (0.525, 0.397), (0.31, 0.30), #(0.545, 0.42), (0.5, 0.595),
                  arrowprops=dict(arrowstyle='<|-',
                                  shrinkA=7,
                                  shrinkB=2,
                                  fc="k", ec="k",
                                  connectionstyle="arc3,rad=-0.3",
                                  ),)

        ax.legend(handles=(h0[0], h1[0], h2[0], h3[0], Line2D([],[], color='k', linewidth=1.0), h5),
                   labels=('Anomaly', 'Informed 1 %', 'Flat 1 %', '1 % t.q. change\n(relative to uninformed)',
                           '1 % t. q. overestimate of\nanomaly when using\nuninformed prior',
                           'RGRDC MLE'),
                   loc='center right', framealpha=1.0, handlelength=1)

    fig.savefig('figures/A3-Prior-Performance-vs-Uninformed-Synthetic-Gamma.pdf')

### References
> Lachenbruch, A. H., and Sass, J. H. (1980), Heat flow and energetics of the San Andreas Fault Zone, J. Geophys. Res., 85(B11), 6185–6222, [doi:10.1029/JB085iB11p06185](https://dx.doi.org/10.1029/JB085iB11p06185).

### License
```
A notebook to evaluate the impact that the default gamma conjugate
prior has on the constraining of heat flow anomalies.

This file is part of the REHEATFUNQ model.

Author: Malte J. Ziebarth (ziebarth@gfz-potsdam.de)

Copyright © 2019-2022 Deutsches GeoForschungsZentrum Potsdam,
            2022 Malte J. Ziebarth
            

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
```