In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pymc3
import pandas as pd
import os
%matplotlib inline

In [None]:
# stack lengths came from the pickaxe output file for this run
def make_chains(s, stack_length_list, n_burn, n_thin):
    """Split a concatenated sample trace into equal-length per-chain rows.

    ``s`` holds the samples of several chains laid end to end. Segment ``i``
    is assumed to contain ``floor((stack_length_list[i] - n_burn) / n_thin)``
    samples (presumably post-burn-in and thinned -- TODO confirm against the
    pickaxe output).

    Parameters
    ----------
    s : array-like
        1-D trace of a single parameter with all chains concatenated.
    stack_length_list : sequence of int
        Raw length of each chain, in the order the chains appear in ``s``.
    n_burn : int
        Number of raw samples removed as burn-in from each chain.
    n_thin : int
        Thinning interval applied to each chain.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_chains, smallest_length - 1)``, one chain per row, where
        ``smallest_length`` is the shortest thinned segment length.

    Raises
    ------
    ValueError
        If the resulting chains would be empty, or ``s`` is too short to
        contain every segment implied by ``stack_length_list``.
    """
    s = np.asarray(s)
    # Slice bounds must be Python ints: NumPy rejects float slice indices,
    # which is what np.floor / np.cumsum on floats would produce.
    thinned_lengths = [int(np.floor((l - n_burn) / n_thin)) for l in stack_length_list]
    # NOTE(review): the final sample of every chain is dropped (``- 1``);
    # presumably a safety margin for off-by-one in the length bookkeeping.
    chain_length = min(thinned_lengths) - 1
    if chain_length < 1:
        raise ValueError(
            'chains would be empty: thinned lengths {}'.format(thinned_lengths))
    chains = []
    start = 0
    for length in thinned_lengths:
        end = start + length
        segment = s[start:end]
        if segment.shape[0] < chain_length:
            raise ValueError(
                'trace of length {} is too short for segment [{}, {})'.format(
                    s.shape[0], start, end))
        chains.append(segment[:chain_length])
        start = end
    # Stack chains as columns, then transpose so each row is one chain.
    return np.stack(chains, axis = 1).T

def make_stat_table(
    samples, stat_func, stat_name,
    stack_length_list, n_burn, n_thin,
    param_key_list, layer_title_list, param_idx_list, param_title_list,
    fpath_out = None, table_version = 1,
):
    """Tabulate ``stat_func`` over the per-chain array of every (layer, parameter).

    For each ``samples_key`` in ``param_key_list`` and each column index in
    ``param_idx_list``, the column ``samples[samples_key][:, param_idx]`` is
    split into chains with ``make_chains`` and passed to ``stat_func``.

    Parameters
    ----------
    samples : mapping
        Maps each key in ``param_key_list`` to a 2-D array indexed as
        ``[sample, param_idx]``.
    stat_func : callable
        Receives the ``(n_chains, chain_length)`` array from ``make_chains``
        and returns a scalar.
    stat_name : str
        Column name, used only when ``table_version == 2``.
    stack_length_list, n_burn, n_thin
        Forwarded to ``make_chains``.
    param_key_list, layer_title_list : sequences of equal length
        Sample keys and their display titles.
    param_idx_list, param_title_list : sequences of equal length
        Column indices within each sample array and their display titles.
    fpath_out : str, optional
        If given, the table is also written there as CSV.
    table_version : {1, 2}
        1: rows are layer titles and columns are parameter titles (unrounded).
        2: one row per ``"<layer> <param>"`` and a single ``stat_name`` column,
        rounded to 4 decimals.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        On an unknown ``table_version`` or mismatched key/title list lengths.
    """
    if len(param_key_list) != len(layer_title_list):
        raise ValueError('param_key_list and layer_title_list differ in length')
    if len(param_idx_list) != len(param_title_list):
        raise ValueError('param_idx_list and param_title_list differ in length')
    # Float fill so assigning float statistics never targets an int column
    # (pandas deprecates incompatible-dtype setitem).
    if table_version == 1:
        table = pd.DataFrame(0.0, index = layer_title_list, columns = param_title_list)
    elif table_version == 2:
        index = ['{} {}'.format(lt, pt) for lt in layer_title_list for pt in param_title_list]
        table = pd.DataFrame(0.0, index = index, columns = [stat_name])
    else:
        raise ValueError('table_version must be 1 or 2, got {!r}'.format(table_version))
    for samples_key, layer_title in zip(param_key_list, layer_title_list):
        # Fetch once per layer; lazy containers (e.g. NpzFile) re-read on every access.
        layer_samples = samples[samples_key]
        for param_idx, param_title in zip(param_idx_list, param_title_list):
            chain_arr = make_chains(layer_samples[:, param_idx], stack_length_list, n_burn, n_thin)
            stat = stat_func(chain_arr)
            if table_version == 1:
                table.loc[layer_title, param_title] = stat
            else:
                idx = '{} {}'.format(layer_title, param_title)
                table.loc[idx, stat_name] = np.around(stat, decimals = 4)
    if fpath_out:
        table.to_csv(fpath_out)
    return table
    
def plot_chain_geweke(chain_arr, name, out_dir = None):
    """Compute Geweke scores for every chain and save them on one figure.

    Parameters
    ----------
    chain_arr : array-like
        One chain per row, as returned by ``make_chains``.
    name : str
        Base file name; the figure is written to ``<out_dir>/<name>.eps``.
    out_dir : str, optional
        Output directory. Defaults to the notebook-level ``parent_dir``
        global, which keeps existing callers working.

    Raises
    ------
    ValueError
        If ``chain_arr`` contains no chains.
    """
    if len(chain_arr) == 0:
        raise ValueError('chain_arr contains no chains')
    if out_dir is None:
        out_dir = parent_dir  # notebook-level config global
    # One array of Geweke scores per chain (rows of chain_arr).
    score_list = [pymc3.diagnostics.geweke(chain) for chain in chain_arr]
    fpath_out = os.path.join(out_dir, '{}.eps'.format(name))
    plot_geweke(score_list, fpath_out)
        
def plot_geweke(
    data,
    fpath_out = None,
    plot_title = None,
    thin_amount = None,
    lw = 0.5,
    fontsize = 20,
    usetex = True,
):
    """Plot Geweke z-scores against the iteration at which each interval starts.

    Parameters
    ----------
    data : ndarray or list of ndarray
        A ``(k, 2)`` array with columns ``[start iteration, z-score]`` (the
        layout returned by ``pymc3.diagnostics.geweke``), or a list of such
        arrays with one per chain. A single array is drawn in black; a list
        uses the default colour cycle.
    fpath_out : str, optional
        Where to save the figure. Nothing is saved when omitted.
    plot_title : str, optional
        Figure title.
    thin_amount : optional
        Unused; kept for backward compatibility.
    lw : float
        Line width of the score curves.
    fontsize : int
        Font size for axis labels and title. This was previously read from
        the notebook's global ``fontsize``, which is set to 20.
    usetex : bool
        Render text with LaTeX, which requires a LaTeX installation.
    """
    is_multi = isinstance(data, list)
    series = data if is_multi else [data]
    line_style = {} if is_multi else {'color': 'black'}
    # Scope the text style to this figure instead of mutating global rcParams;
    # the figure is saved and closed inside the context.
    with plt.rc_context({'text.usetex': usetex, 'font.family': 'serif'}):
        fig, ax = plt.subplots(figsize = (5, 5))
        for scores in series:
            ax.plot(scores[:, 0], scores[:, 1], marker = 'o', lw = lw, **line_style)
        ax.set_xlabel('MCMC iteration', fontsize = fontsize)
        ax.set_ylabel('Geweke z-score', fontsize = fontsize)
        if plot_title:
            ax.set_title(plot_title, fontsize = fontsize)
        fig.tight_layout()
        if fpath_out:
            fig.savefig(fpath_out)
    # Close rather than clf so figures created in a loop do not accumulate.
    plt.close(fig)

def plot_heatmap(
    data,
    fig_width = 5, fig_height = 5,
    cmap = plt.get_cmap('gray'),
    fpath = None
):
    """Show a 2-D array as an image with a colour bar and a white cell grid.

    Major ticks sit on cell centres and are labelled 1-based. Minor ticks
    sit on cell edges and carry the grid lines.

    Parameters
    ----------
    data : array-like
        2-D array (rows are plotted top to bottom).
    fig_width, fig_height : float
        Figure size in inches.
    cmap : Colormap or str
        Colour map passed to ``imshow``.
    fpath : str, optional
        If given, the figure is also saved there.

    Side effect: sets global matplotlib rcParams for LaTeX text and serif fonts.
    """
    plt.rc('text', usetex = True)
    plt.rc('font', family = 'serif')
    data = np.asarray(data)
    n_rows, n_cols = data.shape[:2]

    # figsize is (width, height).
    fig, ax = plt.subplots(figsize = (fig_width, fig_height))
    image = ax.imshow(data, cmap = cmap)
    fig.colorbar(image, ax = ax)

    # Major ticks on cell centres, labelled 1-based.
    ax.set_xticks(np.arange(0, n_cols, 1))
    ax.set_yticks(np.arange(0, n_rows, 1))
    ax.set_xticklabels(np.arange(1, n_cols + 1, 1))
    ax.set_yticklabels(np.arange(1, n_rows + 1, 1))

    # Minor ticks on cell edges; the grid is drawn on these.
    ax.set_xticks(np.arange(-.5, n_cols, 1), minor = True)
    ax.set_yticks(np.arange(-.5, n_rows, 1), minor = True)
    ax.grid(which = 'minor', color = 'w', linestyle = '-', linewidth = 2)
    if fpath:
        fig.savefig(fpath)

In [None]:
# Rock-property configuration: which sample arrays to read and how to title them.
# Each key in rp_param_key_list pairs with the layer title at the same position.
rp_param_key_list = ['layer0rockProperties', 'layer1rockProperties']
rp_layer_title_list = ['Halfway Gneiss', 'Durlacher Supersuite']
# Column index within each rock-property array, paired with its display title.
rp_param_idx_list = [0, 1]
rp_param_title_list = ['Rock Density', 'Log Susceptibility']

In [None]:
# Control-point configuration: a single layer whose sample array has one
# column per control point (25 of them).
cp_param_key_list = ['layer1ctrlPoints']
cp_layer_title_list = ['Layer 1']

cp_param_idx_list = [*range(25)]
cp_param_title_list = list(map('Control point {}'.format, cp_param_idx_list))

In [None]:
# Sample-array keys the analysis below looks up in `samples` (the .npz file is loaded further down).
rp_param_key_list + cp_param_key_list

In [None]:
# Raw per-chain sample counts, copied from the pickaxe output file for this run.
#stack_length_list = [1160161, 1154009, 1129264, 1111255]
stack_length_list = [1592406, 1551447, 1567431, 1555443, 1566432, 1593405]

# Summary statistics to tabulate, as (function, CSV file-name tag, table column name).
_table_specs = [
    (np.mean, 'mean', 'Mean'),
    (lambda x: np.percentile(x, 5), '5th percentile', '5th Percentile'),
    (lambda x: np.percentile(x, 95), '95th percentile', '95th Percentile'),
    (pymc3.diagnostics.effective_n, 'effective_n', 'Effective n'),
    (pymc3.diagnostics.gelman_rubin, 'rhat', 'Rhat'),
]
table_func_list = [spec[0] for spec in _table_specs]
table_name_list = [spec[1] for spec in _table_specs]
table_fancy_name_list = [spec[2] for spec in _table_specs]

# Per-parameter plotting routines and their names.
plot_func_list, plot_func_name_list = [plot_chain_geweke], ['geweke']

# Burn-in and thinning interval passed to make_chains.
n_burn, n_thin = 1000, 1000

parent_dir = '/Volumes/david_hd/obsidian/output/experiments/11_15_2018/01'
#parent_dir = '/Volumes/david_hd/obsidian/output/experiments/11_10_2018/01'

# Figure font size (read as a global by the Geweke plotting helper).
fontsize = 20

In [None]:
#fpath = '/Volumes/david_hd/obsidian/output/experiments/gascoyne_v5_run03/gascoyne_v5-rs-run03-thin10000.npz'
fpath = '/Volumes/david_hd/obsidian/output/experiments/11_15_2018/01/output0.npz'
#fpath = '/Volumes/david_hd/obsidian/output/experiments/11_15_2018/01/output.npz'

# np.load on an .npz returns a lazy NpzFile. Every samples[key] re-reads and
# decompresses that array from disk, and the file handle stays open. Read
# each array once into a plain dict (same key-based access) and close the file.
with np.load(fpath) as npz:
    samples = {key: npz[key] for key in npz.files}
sorted(samples)

In [None]:
# Marginal histogram of layer 1's second rock property (column 1).
s2 = samples['layer1rockProperties'][:, 1]
plt.hist(s2);

In [None]:
# Overlay the column-0 rock-property marginals of both layers. Transparency
# and a legend keep the two otherwise-overlapping histograms distinguishable.
s1 = samples['layer0rockProperties'][:, 0]
s2 = samples['layer1rockProperties'][:, 0]
fig, ax = plt.subplots()
ax.hist(s1, alpha = 0.5, label = rp_layer_title_list[0])
ax.hist(s2, alpha = 0.5, label = rp_layer_title_list[1])
ax.set_xlabel(rp_param_title_list[0])
ax.set_ylabel('Count')
ax.legend();

In [None]:
# Number of draws at which the two traces hold exactly the same value.
np.sum(s1 == s2)

In [None]:
# Overlay the raw traces of both series.
for trace in (s1, s2):
    plt.plot(trace)

# Geweke diagnostic plots

In [None]:
# Sample-array keys the Geweke loop below iterates over.
rp_param_key_list

In [None]:
# Run every configured plotting routine on every rock-property parameter.
# The chains are built once per parameter and then shared by all plot functions.
for samples_key in rp_param_key_list:
    layer_samples = samples[samples_key]
    for param_idx in rp_param_idx_list:
        chain_arr = make_chains(layer_samples[:, param_idx], stack_length_list, n_burn, n_thin)
        name = '{}-param{}'.format(samples_key, param_idx)
        for plot_func in plot_func_list:
            plot_func(chain_arr, name)

## Rock property tables

In [None]:
# Directory that the diagnostics CSV files in the cells below are written to.
parent_dir

In [None]:
# Rock-property diagnostics: one CSV per statistic, then every statistic side by side.
table_version = 2
table_list = [
    make_stat_table(
        samples, table_func, fancy_name,
        stack_length_list, n_burn, n_thin,
        rp_param_key_list, rp_layer_title_list,
        rp_param_idx_list, rp_param_title_list,
        table_version = table_version,
        fpath_out = os.path.join(parent_dir, 'rp-' + table_name + '-table.csv'),
    )
    for table_func, table_name, fancy_name in zip(
        table_func_list, table_name_list, table_fancy_name_list
    )
]
big_table = pd.concat(table_list, axis = 1)
big_table.to_csv(os.path.join(parent_dir, 'rp-diagnostics.csv'))

## Control point tables

In [None]:
# Control-point diagnostics: one CSV per statistic, plus a combined table.
table_version = 2
table_list = []
for stat_func, short_name, column_name in zip(table_func_list, table_name_list, table_fancy_name_list):
    csv_path = os.path.join(parent_dir, 'cp-' + short_name + '-table.csv')
    table_list.append(make_stat_table(
        samples, stat_func, column_name,
        stack_length_list, n_burn, n_thin,
        cp_param_key_list, cp_layer_title_list,
        cp_param_idx_list, cp_param_title_list,
        table_version = table_version, fpath_out = csv_path,
    ))
big_table = pd.concat(table_list, axis = 1)
big_table_fpath = os.path.join(parent_dir, 'cp-diagnostics.csv')
big_table.to_csv(big_table_fpath)

## Assessing convergence of the control points

In [None]:
# Posterior mean and effective sample size of each control point, arranged on the control-point grid.
samples_key = 'layer1ctrlPoints'
data = samples[samples_key]
ctrlpoint_x = 5
ctrlpoint_y = 5
n_ctrl_points = ctrlpoint_x * ctrlpoint_y
no_samples = data.shape[0]
# Rows are samples and columns are control points (the tables above index it as [:, param_idx]).
assert data.shape == (no_samples, n_ctrl_points), data.shape
# Reshape each sample row onto the grid in C order, so control point k lands at
# (k // ctrlpoint_y, k % ctrlpoint_y), then average over samples. Reshaping straight
# to (x, y, no_samples) would interleave samples with control points.
# NOTE(review): confirm this grid orientation matches obsidian's control-point layout.
cp = np.reshape(data, (no_samples, ctrlpoint_x, ctrlpoint_y))
cp_mean = cp.mean(axis = 0)
# big_table must be the control-point diagnostics table (rows ordered by control point index).
assert len(big_table) == n_ctrl_points, 'run the control point tables cell first'
cp_effective_n = np.reshape(big_table['Effective n'].values, (ctrlpoint_x, ctrlpoint_y))

In [None]:
# Posterior mean of each control point on the grid.
plot_heatmap(data = cp_mean)

In [None]:
# Effective sample size of each control point on the grid.
plot_heatmap(data = cp_effective_n)