In [None]:
%run ../tools/autoipy.py

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import collections
import numpy as np
import math
import emcee
import pickle
import os
import time
from pathlib import Path
import matplotlib.cm as cmx
import matplotlib.colors as colors
from matplotlib import patches
# from scipy.stats import circmean, circstd
import matplotlib as mpl
from matplotlib import rc
from matplotlib.font_manager import FontProperties
inline_rc = dict(mpl.rcParams)

cwd = os.getcwd()
# NOTE: the `random_planets` helper (imported as `plz`) is disabled below, but
# show_orbit(..., showaxes=True) calls plz.TrueAnom and needs it re-enabled.
# os.chdir('../falsepos')
# import random_planets as plz
# os.chdir(cwd)

# FIXED VALUES

mass_sun = 1  # solar mass
mass_pl = 1  # earth mass

mass_sun_kg = 1.989e30  # kg
mass_earth_kg = 5.972e24  # kg
G_const = 6.67408e-11  # m^3 kg^-1 s^-2
mu_earth = G_const * (mass_sun_kg + mass_earth_kg)  # G*(M_sun + M_earth), m^3 s^-2
AU2m = 1.496e11  # metres per AU

RAD2ARCSEC = 206265  # arcseconds per radian

D_tel = 10  # telescope diameter, m
wl = 0.5e-6  # wavelength, m (i.e. 0.5 micron)
IWA = 3 * wl/D_tel * RAD2ARCSEC  # inner working angle, 3 lambda/D, arcsec
OWA = 10 * wl/D_tel * RAD2ARCSEC  # outer working angle, 10 lambda/D, arcsec

In [1]:
# Read from file
def read_results(fname):
    """Load and return the single object pickled in ``fname``.

    The file is opened in a ``with`` block so it is closed even when unpickling fails.
    Only use on trusted files: ``pickle.load`` can execute arbitrary code.
    """
    with open(fname, 'rb') as f:
        return pickle.load(f)

def read_truths(direc):
    """Load the pickled true parameters stored in ``<cwd>/<direc>/truths_iters.dat``.

    ``direc`` is interpreted relative to the current working directory. The file is
    closed even if unpickling fails. Only use on trusted files (``pickle.load``).
    """
    path = str(os.getcwd()) + '/' + direc + '/truths_iters.dat'
    with open(path, 'rb') as f:
        return pickle.load(f)

def read_obvs(direc):
    """Load the pickled per-iteration observations stored in ``<cwd>/<direc>/obvs_iters.dat``.

    ``direc`` is interpreted relative to the current working directory. The file is
    closed even if unpickling fails. Only use on trusted files (``pickle.load``).
    """
    path = str(os.getcwd()) + '/' + direc + '/obvs_iters.dat'
    with open(path, 'rb') as f:
        return pickle.load(f)

def read_retrievals(direc, n_ep, n_it, ndims, d, iter_start=0, sigma=1, prints=False):
    """Load the per-iteration fitted percentiles for distance ``d`` and stack them.

    Reads ``<cwd>/<direc>/d<d>fitted<ii>.dat`` (``sigma=1``) or ``d<d>fitted<ii>_95.dat``
    (``sigma=2``) for ``ii`` in ``iter_start .. iter_start + n_it - 1``. Missing files are
    reported and skipped.

    Parameters
    ----------
    direc : str
        Run directory, relative to the current working directory.
    n_ep, n_it, ndims : int
        Target shape components (epochs, iterations, parameters).
    d : int or float
        Value embedded in the file names.
    iter_start : int
        First iteration index to read.
    sigma : {1, 2}
        Which saved interval to read.
    prints : bool
        Print the first parameter's entry of the last iteration read.

    Returns
    -------
    numpy.ndarray of shape (n_it, ndims, n_ep, 3)

    Raises
    ------
    ValueError
        If ``sigma`` is not 1 or 2, or if (e.g. because files were missing) the loaded
        data cannot be reshaped to ``(n_it, ndims, n_ep, 3)``.
    """
    suffixes = {1: '.dat', 2: '_95.dat'}
    if sigma not in suffixes:
        raise ValueError('sigma must be 1 or 2, got {!r}'.format(sigma))
    print('reading', sigma, 'sigma result from', direc)
    cwd = os.getcwd()
    last_it = iter_start + n_it - 1
    s = []
    for ii in range(iter_start, iter_start + n_it):
        path = str(cwd) + '/' + direc + '/d' + str(d) + 'fitted' + str(ii) + suffixes[sigma]
        try:
            with open(path, 'rb') as f:
                y_data = pickle.load(f)
        except FileNotFoundError:
            print(ii, 'not found, moving on -- also truths may no longer correspond')
            continue
        s.append(y_data)
        if prints and ii == last_it:
            print('\ni=', ii, '\n retrieved a=', y_data[0])

    return np.reshape(s, (n_it, ndims, n_ep, 3))  # shape (n_iter, ndims, n_epoch, 3)

In [7]:

def maxlikelihood(chain, lnprob, x_true=None, y_true=None, aproj_true=None, s=None, cadence=None, n=None, 
                  i_det=None, i_nondet=None, a_IWA=None, steps=None, returnmore=False, multiple=False):
    """Return the chain sample with the highest log-probability.

    Parameters
    ----------
    chain : array of shape (n_samples, ndim)
        Flattened chain (e.g. steps*walkers rows).
    lnprob : array of shape (n_samples,)
        Log-probability of each row of ``chain``.
    x_true, y_true, aproj_true, s, cadence, n, i_det, i_nondet, a_IWA :
        Only used when ``multiple=True``: passed to ``mc_orbit_good.lnlike`` to break
        ties between equally probable samples (``n + 1`` is passed, so ``n`` must be set).
    steps :
        Unused; kept for backward compatibility.
    returnmore : bool
        Also return the maximum log-probability and the index of the chosen row.
    multiple : bool
        Break ties in ``lnprob`` by the highest ``lnlike`` value.

    Returns
    -------
    max_prob, or (max_prob, value, index) when ``returnmore`` is True.
    """
    if multiple:
        # Imported lazily so the plain arg-max path works without mc_orbit_good.
        from mc_orbit_good import lnlike

        i_max = np.argwhere(lnprob == np.amax(lnprob))
        lnlikes = []
        for ii in i_max:
            theta = chain[ii][0]
            lnlikes.append(lnlike(theta, x_true, y_true, aproj_true, s, cadence, n+1,
                                  i_det, i_nondet, a_IWA))
        best = np.argmax(lnlikes)
        max_prob = chain[i_max[best]][0]
        val = np.amax(lnprob)
        i_max_prob = i_max[best]
    else:
        i_max_prob = np.argmax(lnprob)
        max_prob = chain[i_max_prob]
        val = lnprob[i_max_prob]

    if returnmore:
        return max_prob, val, i_max_prob
    return max_prob  # all dimensions

In [8]:
def get_orbits(direc, ndim, n_it, n_ep, d, ML=False, quantiles=[16, 50, 84], ep_start=0,
                    fname='f.dat', samples_epochs=None, i_it=0, burnin=1000):
    """Compute per-epoch percentile fits of every parameter and pickle them.

    If ``samples_epochs`` is None, the chains of iteration ``i_it`` are read from
    ``<direc>/d<d>chains_epochs<i_it>.dat``. For each of the first ``n_ep`` epochs, the
    first ``burnin`` steps (axis 1) are discarded and the rest is flattened to
    ``(-1, ndim)`` samples.

    The ``quantiles`` percentiles of each parameter are computed by ``retrieve_fits``.
    They are saved with ``save_fits`` under ``<cwd>/<direc>/d<d><fname>`` and returned.

    Parameters
    ----------
    burnin : int
        Number of leading steps to discard per chain (previously hard-coded to 1000).

    Returns
    -------
    list of length ``ndim``; entry ``idim`` has shape ``(n_it, n_ep - ep_start, len(quantiles))``.
    """
    fname_base = 'd' + str(d) + 'samples_epochs'

    if samples_epochs is None:
        chains_epochs = read_results(direc + '/d' + str(d) + 'chains_epochs' + str(i_it) + '.dat')
        print('chains_epochs', np.shape(chains_epochs))
        # drop burn-in along the step axis, then flatten to (n_samples, ndim)
        samples_epochs = [chains_epochs[i_ep][:, burnin:, :].reshape((-1, ndim))
                          for i_ep in range(n_ep)]

    y_data_dims = [retrieve_fits(n_it, n_ep, idim, direc, fname_base, d, fend='.dat', ep_start=ep_start,
                                 results=samples_epochs, s_percentile=quantiles, ML=ML, it_start=i_it,
                                 fname=fname)
                   for idim in range(ndim)]
    # save the `quantiles` percentiles for every dim, epoch and iteration
    save_fits(y_data_dims, direc, d, fname)

    return y_data_dims

def save_fits(data, direc, d, name):
    """Pickle ``data`` to ``<cwd>/<direc>/d<d><name>``, overwriting any existing file."""
    target = '{}/{}/d{}{}'.format(os.getcwd(), direc, d, name)
    print('pickling to ', target)
    with open(target, 'wb') as handle:
        pickle.dump(data, handle)
        
def retrieve_fits(n_iters, n_epochs, idim, direc, fname_base, d, ML=False, fend='.dat', steps=None,
                      results=None, s_percentile=[16, 50, 84], it_start=0, ep_start=0, fname=''):
    """Percentiles of parameter ``idim`` for every iteration and epoch, in physical units.

    Samples are mapped back from the MCMC parameterisation by ``_reparameterize``:
    ln(a) -> a, cos(i) -> i, and idims 3/5 folded into [0, pi).

    Parameters
    ----------
    n_iters : int
        Number of iterations, starting at ``it_start``.
    n_epochs : int
        One past the last epoch to use; epochs ``ep_start .. n_epochs-1`` are summarised.
    idim : int
        Parameter index.
    direc, fname_base, fend : str
        Samples are read from ``<direc>/<fname_base><ii><fend>`` when ``results`` is None.
    d : int or float
        Used in the chain file names when ``ML`` is True.
    ML : bool
        Replace the middle percentile with the maximum-probability sample. That sample is
        read from ``<direc>/d<d>chains_epochs<ii>.dat`` and ``..._lnprob.dat``.
    steps : slice, optional
        Which samples of each epoch to use. Defaults to all of them; the old default
        ``slice(0, -1)`` silently dropped the last sample.
    results : sequence, optional
        Pre-loaded samples, one ``(n_samples, ndim)`` array per epoch (ragged allowed).
    s_percentile : sequence of float
        Percentiles to compute. With ``ML``, entry 1 is overwritten.
    fname : str
        Unused; kept for backward compatibility.

    Returns
    -------
    numpy.ndarray of shape (n_iters, n_epochs - ep_start, len(s_percentile))
    """
    if steps is None:
        steps = slice(None)

    y_data = []
    for ii in range(it_start, n_iters + it_start):
        if results is None:
            print('reading from ', direc + '/' + fname_base + str(ii) + fend)
            samples_epochs = read_results(direc + '/' + fname_base + str(ii) + fend)
        else:
            samples_epochs = results
            # per-epoch shapes; np.shape on a ragged list would raise
            print('results', [np.shape(ep) for ep in samples_epochs])

        # One 1-D array per epoch; handles epochs with different sample counts.
        samples_1dim = [_reparameterize(np.asarray(samples_epochs[i_ep])[:, idim], idim)
                        for i_ep in range(ep_start, n_epochs)]

        if ML:
            # read once per iteration rather than once per epoch
            chain_file = direc + '/d' + str(d) + 'chains_epochs' + str(ii)
            lnprob_all = read_results(chain_file + '_lnprob.dat')
            chains_all = read_results(chain_file + '.dat')

        y_ep = []
        for iep, samples in enumerate(samples_1dim):
            q = np.percentile(samples[steps], q=s_percentile, axis=0)
            if ML:
                i_ep = ep_start + iep  # absolute epoch index in the chain files
                chains = np.asarray(chains_all[i_ep])
                flatchain = chains.reshape((-1, chains.shape[-1]))
                flatprob = np.asarray(lnprob_all[i_ep]).flatten()
                maxprob = maxlikelihood(flatchain, flatprob)[idim]
                print('maxprob', maxprob)
                q[1] = _reparameterize(maxprob, idim)
            y_ep.append(q)
        y_data.append(y_ep)

    return np.reshape(np.array(y_data), (n_iters, n_epochs - ep_start, len(s_percentile)))


def _reparameterize(values, idim):
    """Map MCMC-space values of parameter ``idim`` to the physical parameterisation."""
    if idim == 0:
        return np.exp(values)  # mc works in ln(a)
    if idim == 1:
        return np.arccos(values)  # invert cosine of i
    if idim in (3, 5):
        return values % np.pi  # fold by 180 deg
    return values


In [None]:
def semianalytic_hist(direc, datasets=None, labels=('uniform', 'beta', 'circular'),
                      hist_colors=('xkcd:charcoal', 'xkcd:forest green', 'xkcd:lavender'),
                      nbins=70, a_true=1, outname='aproj_dist.pdf'):
    """Overlay density histograms of projected separation against the true semi-major axis.

    Parameters
    ----------
    direc : str
        Unused; kept for backward compatibility. The original loaded
        ``<direc>/d10a_posterior_nd0.dat`` into a variable that was never used.
    datasets : sequence of array-like, optional
        One sample array of projected separations (AU) per curve. If None, this falls back
        to the notebook globals ``aproj_list_uni``, ``aproj_list_Beta`` and ``aproj_list``.
        Those must already exist; NameError otherwise.
    labels, hist_colors : sequence of str
        Legend label and colour for each dataset, in the same order as ``datasets``.
    nbins : int
        Number of histogram bins.
    a_true : float
        x position of the dashed "true a" reference line.
    outname : str
        File the figure is saved to.

    Returns
    -------
    matplotlib.figure.Figure
    """
    rc('font', **{'family': 'serif', 'serif': ['Computer Modern Roman']})
    rc('text', usetex=True)
    axissize = 20
    legsize = 18
    labelsize = 22

    if datasets is None:
        datasets = (aproj_list_uni, aproj_list_Beta, aproj_list)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    for data, colour, label in zip(datasets, hist_colors, labels):
        # filled translucent histogram plus a solid outline of the same data
        ax.hist(data, lw=1, alpha=0.4, density=True, bins=nbins, histtype='stepfilled',
                color=colour, label=label, zorder=2)
        ax.hist(data, lw=1, alpha=1, density=True, bins=nbins, histtype='step',
                color=colour, zorder=2)

        print('std =', np.std(data))
        print('mean = ', np.mean(data))

    ax.axvline(x=a_true, lw=1, color='xkcd:black', ls='--', label=r'true $a$', zorder=5)
    ax.legend(frameon=False, fontsize=legsize)
    ax.set_xlabel('Projected separation (AU)', fontsize=labelsize)
    ax.set_ylabel('Probability density function', fontsize=labelsize)
    ax.xaxis.set_tick_params(labelsize=axissize)
    ax.yaxis.set_tick_params(labelsize=axissize)
    ax.set_ylim(0, 4)
    ax.set_xlim(0, 2)
    fig.savefig(outname, bbox_inches='tight')
    return fig

In [None]:
def semianalytic_compare(direc, i_it, d, idim=0, i_ep=2, nbins=80, thin_mc=40000, ndims=6, ylim=3,
                        stitle='Posterior after 3 epochs', xrange=(0.4, 4.4), burnin=1000):
    """Compare semi-analytic and MCMC posteriors on semi-major axis for one iteration/epoch.

    Parameters
    ----------
    direc : str
        Run directory holding ``truths_iters.dat``, ``d<d>chains_epochs<i_it>.dat`` and
        ``d<d>a_posterior_ep<i_ep>_<i_it>.dat``.
    i_it, i_ep : int
        Iteration and epoch to compare.
    d : int or float
        Value embedded in the file names.
    idim : int
        Chain dimension holding ln(a). Both the truth and the MCMC samples are
        exponentiated, so only the ln(a) dimension is meaningful here.
    nbins : int
        Histogram bins within ``xrange``.
    thin_mc : int
        Keep every ``thin_mc``-th step after burn-in.
    ndims : int
        Number of parameters per sample.
    ylim : float
        Upper y limit.
    stitle : str
        Plot title.
    xrange : (float, float)
        Histogram range and x limits (previously hard-coded to (0.4, 4.4)).
    burnin : int
        Leading steps to discard (previously hard-coded to 1000).

    Returns
    -------
    matplotlib.figure.Figure
    """
    rc('font', **{'family': 'serif', 'serif': ['Computer Modern Roman']})
    rc('text', usetex=True)
    axissize = 20
    legsize = 18
    labelsize = 22

    fig, ax = plt.subplots(1, 1, figsize=(7, 7))

    truth = np.exp(read_truths(direc)[i_it][idim])

    # MCMC samples: drop burn-in along the step axis, thin, flatten to (n_samples, ndims)
    f = direc + '/d' + str(d) + 'chains_epochs' + str(i_it) + '.dat'
    print('reading file', f)
    chains = read_results(f)[i_ep]
    samples = chains[:, burnin::thin_mc, :].reshape((-1, ndims))
    sma_MCMC = np.exp(samples[:, idim])
    print('sma_MCMC', np.shape(sma_MCMC))

    # semi-analytic posterior samples
    sma_semianal = read_results(direc + '/d' + str(d) + 'a_posterior_ep' + str(i_ep) + '_' + str(i_it) + '.dat')
    print('sma_semianal', np.shape(sma_semianal))

    datasets = (sma_semianal, sma_MCMC)
    hist_colors = ('xkcd:forest green', 'xkcd:water blue')
    labels = ('semi-analytic', 'MCMC')
    for data, colour, label in zip(datasets, hist_colors, labels):
        ax.hist(data, lw=1, alpha=0.4, density=True, bins=nbins, range=xrange,
                histtype='stepfilled', color=colour, label=label, zorder=2)
        ax.hist(data, lw=1, alpha=1, density=True, bins=nbins, range=xrange,
                histtype='step', color=colour, zorder=2)

    ax.axvline(x=truth, color='xkcd:red', ls='-', lw=2, zorder=5, label='truth')
    leg = ax.legend(frameon=False, fontsize=legsize)
    plt.setp(leg.get_title(), fontsize=legsize)
    ax.set_xlabel('Semi-major axis (AU)', fontsize=labelsize)
    ax.set_ylabel('Probability density function', fontsize=labelsize)

    ax.xaxis.set_tick_params(labelsize=axissize)
    ax.yaxis.set_tick_params(labelsize=axissize)
    ax.set_ylim(0, ylim)
    ax.set_xlim(xrange)
    ax.set_title(stitle, fontsize=labelsize)

    return fig

In [9]:
def median_accuracy(direc, n_ep, n_it, idim, d, y_data=None, percentile=[16, 50, 84], form='percentiles',
                    truths=None, ep_start=0, iter_start=0, maxprob=False, read=False,
                    a_fixed=None, i_good=(), ndims=6):
    """Summarise the retrieval error of parameter ``idim`` across iterations, per epoch.

    The per-iteration error is ``retrieved - truth``: an absolute difference, not a
    relative one. ``retrieved`` is the median fit by default, or the saved
    maximum-probability fit with ``maxprob=True``. Truths are mapped to physical units
    (exp for idim 0, arccos for idim 1, folded by pi for idims 3/5).

    Parameters
    ----------
    direc : str
        Run directory.
    n_ep, n_it : int
        Number of epochs and iterations.
    idim : int
        Parameter index.
    d : int or float
        Value embedded in the file names.
    y_data : array of shape (n_it, n_ep, 3), optional
        Percentile fits. Loaded via ``read_retrievals`` if None (and not ``maxprob``/``read``).
    percentile : sequence of 3 floats
        Percentiles of the error distribution across iterations.
    form : {'percentiles', 'errorbars'}
        Output format.
    truths : array of shape (n_iter_total, ndim), optional
        True parameters; read with ``read_truths`` if None.
    ep_start : int
        First epoch to summarise. Earlier epochs stay zero in the output.
    iter_start : int
        Iteration index of the first row of ``y_data``; used to align the truths.
    maxprob : bool
        Use ``<direc>/d<d>ML<it>.dat`` for the iterations in ``i_good``.
    read : bool
        Read precomputed errors from ``<direc>/d<d>acc<it>.dat`` for ``i_good``.
    a_fixed :
        If not None, the truth is not subtracted. With ``read=True`` and ``a_fixed == 1``,
        the stored values are shifted by +1.
    i_good : iterable of int
        Iteration indices used by the ``maxprob`` and ``read`` paths.
    ndims : int
        Parameter count passed to ``read_retrievals`` when loading ``y_data``.

    Returns
    -------
    'percentiles' : array of shape (3, n_ep)
    'errorbars'   : (y50 of shape (n_ep,), yerr of shape (2, n_ep)) for ``plt.errorbar``

    Raises
    ------
    ValueError
        If ``form`` is not one of the two supported values.
    """
    if form not in ('percentiles', 'errorbars'):
        raise ValueError("form must be 'percentiles' or 'errorbars', got {!r}".format(form))
    i_good = list(i_good)

    if (y_data is None) and (not maxprob) and (not read):
        y_data_dims = read_retrievals(direc, n_ep, n_it, ndims, d, iter_start=iter_start)
        y_data = np.swapaxes(y_data_dims, 0, 1)[idim]

    if truths is None:
        truths = read_truths(direc)
    truth = np.swapaxes(truths, 0, 1)[idim]
    if idim == 0:
        truth = np.exp(truth)
    elif idim == 1:
        truth = np.arccos(truth)
    elif idim in (3, 5):  # fold by 180 deg
        truth = truth % np.pi

    yerr_data = np.zeros((2, n_ep))
    y50_data = np.zeros(n_ep)
    acc = np.zeros((3, n_ep))

    if read:
        # precomputed errors; if a=1 only the sma result needs shifting
        y_ep_median = np.zeros((len(i_good), n_ep))
        for row, it in enumerate(i_good):
            results = read_results(direc + '/d' + str(d) + 'acc' + str(it) + '.dat')
            if a_fixed == 1:
                results = results + 1
            y_ep_median[row][:] = results
    else:
        if not maxprob:
            y_ep_median = np.asarray(y_data)[:, :, 1]  # medians, shape (n_it, n_ep)
            truth_rows = truth[iter_start:iter_start + n_it]
        else:
            print('maxprob')
            y_ep_median = np.zeros((len(i_good), n_ep))
            for row, it in enumerate(i_good):
                MLs = read_results(direc + '/d' + str(d) + 'ML' + str(it) + '.dat')  # (n_ep, ndim)
                y_ep_median[row][:] = np.swapaxes(MLs, 0, 1)[idim]
            truth_rows = truth[i_good]  # align truths with the iterations actually read
        if a_fixed is None:
            y_ep_median = y_ep_median - truth_rows[:, np.newaxis]

    for jj in range(ep_start, n_ep):
        errors = y_ep_median[:, jj]  # one value per iteration
        error = np.percentile(errors, percentile, axis=0).flatten()

        acc[0][jj] = error[0]
        acc[1][jj] = error[1]
        acc[2][jj] = error[2]

        y50_data[jj] = error[1]
        yerr_data[0][jj] = error[1] - error[0]
        yerr_data[1][jj] = error[2] - error[1]

    if form == 'percentiles':
        return acc
    return y50_data, yerr_data

In [10]:
def plot_medians(direc, n_ep, n_it, idim, d, i_ep, ndims=6, percentile=[16, 50, 84], form='percentiles',
                       a_fixed=None,):
    """Plot the sorted per-iteration median retrievals of parameter ``idim`` at epoch ``i_ep``.

    Medians come from ``read_retrievals`` (1-sigma files, presumably already re-parameterised;
    confirm against the writer). Dashed horizontal lines mark the ``percentile`` values of
    those medians across iterations.

    The previous body could not run: ``ndarray.sort()`` returns None, ``ax.plot`` does not
    accept ``x=``/``y=`` keywords, and the accuracy call used ``n_ep=1`` together with
    ``ep_start=i_ep``.

    ``form`` and ``a_fixed`` are accepted for backward compatibility and are unused.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(1, 1)
    y_data_dims = np.asarray(read_retrievals(direc, n_ep, n_it, ndims, d, sigma=1))  # (n_it, ndims, n_ep, 3)
    medians = np.sort(y_data_dims[:, idim, i_ep, 1])
    print(medians)

    ax.plot(np.arange(len(medians)), medians, '.-')
    for level in np.percentile(medians, percentile):
        ax.axhline(y=level, ls='--')
    ax.set_xlabel('iteration (sorted by median)')
    ax.set_ylabel('median retrieved value')
    return fig

In [11]:
# analysis 

def posterior(samples, nbins=100):
    """Histogram ``samples`` into ``nbins`` bins, normalised as a probability density.

    Returns ``(density, bin_edges)`` exactly as ``numpy.histogram(..., density=True)``.
    """
    density, edges = np.histogram(samples, nbins, density=True)
    return density, edges

def extract1dim(samples_arr, idim, ep_start=0, n_ep=None):
    """Return the samples of parameter ``idim`` for epochs ``ep_start .. n_ep-1``.

    Parameters
    ----------
    samples_arr : sequence of array-like
        One ``(n_samples, ndim)`` array per epoch. Epochs may hold different sample counts.
    idim : int
        Parameter (column) index.
    ep_start : int
        First epoch to return.
    n_ep : int or None
        One past the last epoch to return. ``None`` (default) means all epochs. The old
        default of 5 truncated or crashed ragged input, and the homogeneous branch ignored
        both bounds.

    Returns
    -------
    numpy.ndarray of shape (n_selected_epochs, n_samples) when all epochs have equal sample
    counts; otherwise a list of 1-D arrays, one per epoch.
    """
    try:
        return np.asarray(samples_arr)[ep_start:n_ep, :, idim]
    except (ValueError, IndexError):
        # Ragged input (e.g. different walker numbers per epoch): build per-epoch arrays.
        stop = len(samples_arr) if n_ep is None else n_ep
        return [np.asarray(samples_arr[ii])[:, idim] for ii in range(ep_start, stop)]

def normalizeHist(n, bins):
    """Rescale histogram counts so the histogram integrates to 1 over its bins.

    Each count is weighted by its own bin width, so non-uniform bins are handled correctly.
    For uniform bins this matches the previous ``n / (dx * sum(n))``.

    Parameters
    ----------
    n : array-like of length N
        Bin counts.
    bins : array-like of length N + 1
        Bin edges.

    Returns
    -------
    (normalised_counts, bins)

    Raises
    ------
    ValueError
        If the histogram has zero total area.
    """
    counts = np.asarray(n, dtype=float)
    area = np.sum(counts * np.diff(bins))
    if area == 0:
        raise ValueError('cannot normalise a histogram with zero total area')
    return counts / area, bins

def integrateHist(n, bins):
    """Return the area under a histogram: sum of count times bin width.

    Each bin uses its own width, so non-uniform bins are handled correctly; for uniform
    bins this equals the previous ``dx * sum(n)``.
    """
    return np.sum(np.asarray(n, dtype=float) * np.diff(bins))

def plotPDF(ax, bins, pdf, truth=None, q_data=None, c='k', fc='xkcd:greyish purple', xlabel='', stitle=''):
    """Draw a binned PDF on ``ax`` with optional truth and percentile markers.

    Parameters
    ----------
    ax : matplotlib Axes
    bins : array of length N + 1
        Bin edges.
    pdf : array of length N
        Density per bin.
    truth : float, optional
        Draws a gold vertical line.
    q_data : sequence of 3 floats, optional
        (lower, median, upper) values drawn as vertical lines.
    c : colour
        Bar edge colour.
    fc : colour
        Bar face colour.
    xlabel, stitle : str
        Axis label and title.

    Returns
    -------
    ax
    """
    if truth is not None:
        ax.axvline(x=truth, c='xkcd:gold', lw=3, label='truth')

    has_q = (q_data is not None) and (len(q_data) == 3)
    if has_q:
        ax.axvline(x=q_data[1], c='xkcd:cerulean', lw=2, label=r'1$\sigma$')  # 50%
        ax.axvline(x=q_data[0], c='xkcd:cerulean', lw=1)
        ax.axvline(x=q_data[2], c='xkcd:cerulean', lw=1)

    if truth is not None or has_q:
        # only build a legend when a labelled artist exists
        ax.legend(frameon=False, fontsize=14)

    edges = np.asarray(bins, dtype=float)
    # centre each bar on its bin (previously centred on the right edge, half a bin off)
    centres = 0.5 * (edges[:-1] + edges[1:])
    ax.bar(centres, pdf, width=np.diff(edges), align='center', ec=c, fc=fc, alpha=0.9)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_title(stitle, fontsize=14)
    return ax
    
def getPDF_multiepoch(samples_arr, idim, nbins=100, i_ep=None):
    """Return the binned posterior ``(pdf, bin_edges)`` of parameter ``idim`` at epoch ``i_ep``.

    Each epoch was fitted simultaneously, so its samples are used as stored.
    ``i_ep`` defaults to the last epoch. ``len`` is used for the epoch count instead of
    ``np.shape``, which raises on ragged per-epoch arrays. All epochs are passed
    explicitly to ``extract1dim``.
    """
    n_epochs = len(samples_arr)
    if i_ep is None:
        i_ep = n_epochs - 1  # last epoch

    samples_thisdim = extract1dim(samples_arr, idim, 0, n_epochs)
    return posterior(samples_thisdim[i_ep], nbins)

def plotPDFs(direc, idim, i_iter, nbins=100, quantiles=None, n_ep=None, start_ep=0, cadence=30,
             fname_base='samples_epochs', fend='.dat', fc='xkcd:greyish purple',slides=False,
             fig=None, axes=None, xlim=[0, 5], ylim=[0, 4], ndim=6):
    """Plot one posterior panel per epoch for parameter ``idim`` of iteration ``i_iter``.

    Samples are read from ``<direc>/<fname_base><i_iter><fend>``. idim 0 is un-logged, and
    idims 3/5 are folded into [0, pi), for both samples and truth. Each panel shows the
    histogram, the truth, the 16/50/84 percentile lines, and the z-score and relative
    error of the median. The relative error is dimensionless; it was previously labelled
    "AU".

    ``quantiles``, ``cadence`` and ``ndim`` are currently unused. Returns ``(fig, axes)``.
    """
    if slides is False:
        mpl.rcParams.update(inline_rc)

    samples_arr = read_results(direc + '/' + fname_base + str(i_iter) + fend)
    n_saved = len(samples_arr)  # len() also works for ragged per-epoch arrays
    if n_ep is None:
        n_ep = n_saved
    truth = read_truths(direc)[i_iter][idim]

    # extract samples across all saved epochs for this parameter
    samples_thisdim = extract1dim(samples_arr, idim, 0, n_saved)

    if idim == 0:  # sma un-log
        samples_thisdim = np.exp(samples_thisdim)
        truth = np.exp(truth)
    elif idim in (3, 5):  # omegas: fold
        samples_thisdim = samples_thisdim % np.pi
        truth = truth % np.pi

    yerr_lower, y, yerr_upper = epochs_sigma(samples_thisdim, s_percentile=[16, 50, 84])
    ymin = y - yerr_lower
    ymax = y + yerr_upper

    if (fig is None) and (axes is None):
        fig, axes = plt.subplots(int(np.ceil(n_ep / 2)), 2, figsize=(10, 10))  # (10, 10) for 2x3, (10, 6) for 2x2

    xlabel = get_label(idim, form='None')

    for iep in range(start_ep, n_ep):
        pdf_at_ep, bins = posterior(samples_thisdim[iep], nbins)
        ax = axes.flatten()[iep]
        plotPDF(ax, bins, pdf_at_ep, truth=truth, q_data=[ymin[iep], y[iep], ymax[iep]], c='k', fc=fc,
                xlabel=xlabel, stitle='Epoch ' + str(iep + 1))
        median = np.median(samples_thisdim[iep])
        std = np.std(samples_thisdim[iep])
        zscore = (truth - median) / std
        accuracy = (truth - median) / truth  # relative, dimensionless

        ax.text(0.95, 0.2,
                'z-score = {0:.2f}'.format(zscore) + '\n accuracy = {0:.2f}'.format(accuracy),
                ha='right', va='bottom', fontsize=14, transform=ax.transAxes)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    if n_ep % 2 == 1:
        axes.flatten()[-1].set_visible(False)  # remove extra subplot if odd number
    plt.tight_layout()

    return fig, axes
  

In [12]:
import os
def concatenate_truths(direc):
    """Stack the truths of several run directories into ``<cwd>/<direc[0]>_joined/truths_iters.dat``.

    ``direc`` is a sequence of run directories relative to the current working directory.
    The truths of each directory are read once and stacked with ``np.vstack`` in order.
    With a single directory, its truths are written unchanged. The output directory is
    created if needed.
    """
    newdir = str(os.getcwd()) + '/' + direc[0] + '_joined'
    os.makedirs(newdir, exist_ok=True)

    truths_append = None
    for i_dir, dirval in enumerate(direc):
        newtruths = read_truths(dirval)
        print('truths', i_dir, ' shape', np.shape(newtruths))
        if truths_append is None:
            truths_append = newtruths
        else:
            truths_append = np.vstack((truths_append, newtruths))

    print('new truths shape', np.shape(truths_append))
    with open(newdir + '/truths_iters.dat', 'wb') as f:
        pickle.dump(truths_append, f)
                 
from shutil import copyfile                
def concatenate_samples(direc, n_iters, stem='samples_epochs',fend='.dat'):
    """Copy per-iteration files from several run directories into ``<direc[0]>_joined``.

    Files ``<direc[i]>/<stem><k><fend>`` for ``k < n_iters[i]`` are copied to consecutive
    indices ``<direc[0]>_joined/<stem><count><fend>``. The destination directory is created
    if needed; previously every copy failed and was misreported as a missing source. A
    missing source is reported, unless ``stem`` contains "ep2", and skipped. The running
    index still advances, presumably so indices stay aligned with concatenate_truths;
    confirm.
    """
    if not direc:
        return
    cwd = str(os.getcwd())
    joined = cwd + '/' + direc[0] + '_joined'
    os.makedirs(joined, exist_ok=True)

    count_tot = 0
    for i_dir, dirval in enumerate(direc):
        for count_it in range(n_iters[i_dir]):
            src = cwd + '/' + dirval + '/' + stem + str(count_it) + fend
            try:
                copyfile(src, joined + '/' + stem + str(count_tot) + fend)
            except FileNotFoundError:
                if stem.find("ep2") == -1:
                    print('file', src, 'not found, skipping...')
            count_tot += 1

In [13]:

def overplot_orbit_fits(fig, ax, direc, i_iter, epochs_to_show, d, epochs_tot=None, lw=1, cmap='autumn', 
                        y_data=None, ndims=6, showerr=False, s=5, maxprob=False, showobvs=False, cadence=None):
    """Overplot the fitted orbit ellipse for each requested epoch of one iteration.

    For each epoch in ``epochs_to_show``, the fitted (a, i, e, omega, xi, lan) come from one
    of two sources:
    - the median (percentile index 1) of the saved retrievals (default), or
    - with ``maxprob=True``, the pickled ``<direc>/d<d>ML<i_iter>.dat`` file.
    The ellipse is then drawn on ``ax`` via ``mc_orbit_good.EuclideanRotation``.

    Parameters
    ----------
    fig, ax : matplotlib Figure / Axes to draw on.
    direc : str
        Run directory, relative to the current working directory.
    i_iter : int
        Iteration whose truths, observations and fits are used.
    epochs_to_show : iterable of int
        Epoch indices to draw; also index the per-epoch colours.
    d : value used in the file names and in the ``s`` -> AU conversion
        (presumably distance in pc -- TODO confirm).
    epochs_tot : int, optional
        Total number of epochs; defaults to ``len(epochs_to_show)``.
    y_data : array, optional
        Percentile retrievals for this iteration, expected shape (ndims, n_epoch, 3);
        read with ``read_retrievals`` if None.
    ndims : int
        6 (a, i, e, omega, xi, lan) or 4 (a, i, e, xi).
    showerr : bool
        Draw a reference astrometric-error cross at a fixed position.
    s : float
        Astrometric error, presumably in mas (converted with ``s*1e-3*d``).
    maxprob : bool
        Use the saved maximum-probability fits instead of the medians.
    showobvs : bool
        Also plot the saved (noisy) observation of each shown epoch.
    cadence : passed to ``obv_at_epoch`` when ``showobvs`` is set.

    Returns
    -------
    fig, ax
    """
    
    from mc_orbit_good import EuclideanRotation, obv_at_epoch
    
    # astrometric error converted to AU (mas -> arcsec, times d)
    s_AU = s*1e-3 * d
    print('s_AU', s_AU)
    
    if epochs_tot is None:
        epochs_tot = len(epochs_to_show)
    
#     color = colorize(range(epochs_tot), cmap=cmap)[0]
    # `colorize` is defined elsewhere in the notebook
    color = colorize(range(epochs_tot), cmap=cmap, vmin=-1, vmax=epochs_tot)[0]
    
    x_obv, y_obv, _ = read_obvs(direc)[i_iter]
    
    # truths are stored as (ln a, cos i, e, omega, xi, lan); convert a and i
    # NOTE(review): this unpack assumes 6 parameters, so ndims=4 would fail here -- confirm
    truths = read_truths(direc)[i_iter]
    a_true, i_true, ecc_true, omega_true, xi_true, lan_true = truths
    a_true = np.exp(a_true)
    i_true = np.arccos(i_true)
   
    print('truths: ', a_true, i_true, ecc_true, omega_true, xi_true, lan_true)
    
    
    # Fitted omega/lan are folded into [0, pi) by retrieve_fits; add pi back when the
    # true angle lies above pi so the drawn ellipse matches the truth's orientation.
    add_to_omega=0
    if omega_true > np.pi:
        print('adding pi to omega')
        add_to_omega = np.pi
    
    add_to_lan=0
    if lan_true > np.pi:
        print('adding pi to lan')
        add_to_lan = np.pi
        
        
    if not maxprob: # read in percentiles
        if y_data is None:
            #(direc, n_ep, n_it, ndims, d, iter_start=0, sigma=1)
            y_data = read_retrievals(direc, epochs_tot, 1, ndims, d, iter_start=i_iter)[0] # (ndims, n_epoch, 3)
            # already in actual format
        y_data = np.swapaxes(y_data, 0, 2)  # (3, n_epoch, ndims)
        y_data = np.swapaxes(y_data, 1, 2) # (3, ndims, n_epoch)
        medians = y_data[1]  # (ndims, n_epoch)
        medians = np.swapaxes(medians, 0, 1)  # (n_epoch, ndims)
        # this is already re-parameterized
    else:
        # maximum-probability fits, indexed by epoch
        fname = direc+'/d'+str(d)+'ML'+str(i_iter)+'.dat'
        MLs = read_results(fname)
 


        
    for i_epoch in epochs_to_show:
        
        if maxprob:
            y = MLs[i_epoch]

        else:
            y = medians[i_epoch]
            
        if ndims==6:
            a_mcmc, i_mcmc, ecc_mcmc, omega_mcmc, xi_mcmc, lan_mcmc = y
#             lna_mcmc, cosi_mcmc, ecc_mcmc, omega_mcmc, xi_mcmc, lan_mcmc = np.percentile(samples, 50, axis=0)
        elif ndims==4:
#             lna_mcmc, cosi_mcmc, ecc_mcmc, xi_mcmc = np.percentile(samples, 50, axis=0)
            a_mcmc, i_mcmc, ecc_mcmc, xi_mcmc = y
            lan_mcmc=0
            omega_mcmc=0
        
        print('  ',i_epoch, 'fits: ', a_mcmc, i_mcmc, ecc_mcmc, omega_mcmc+add_to_omega, xi_mcmc, lan_mcmc+add_to_lan)
        # semi-major/minor axes and Euler angles (lan, i, omega) of the fitted ellipse
        a = a_mcmc
        b = a*np.sqrt(1-ecc_mcmc**2)
        alpha = lan_mcmc + add_to_lan
        beta = i_mcmc
        gamma = omega_mcmc + add_to_omega
       
        xs, ys, _ = EuclideanRotation(alpha, beta, gamma, a, b, np.linspace(start=0, stop=2*np.pi, num=50))
        ax.plot(xs, ys, '-', lw=lw, c=color[i_epoch], zorder=50)
        
            
        if showobvs:
            # NOTE(review): x_clean/y_clean are computed but only used in commented-out code
            x_clean, y_clean, _ = obv_at_epoch(truths, range(epochs_tot), cadence=cadence, noise=False)
            x = x_obv[i_epoch]
            y = y_obv[i_epoch]
            
#             print('(x_obv,y_obv)', (x,y))

#             print('(x_clean, y_clean)', (x_clean[i_epoch],y_clean[i_epoch]))
            ax.plot(x, y, '.', c='xkcd:sea green', zorder=100)
    
    if showerr:
        font = {'family': 'sans-serif',
#                     'color':  'darkred',
                'weight': 'light',
                    'size': 14,
                }
 
#         ax.errorbar(x=x_clean[1], y=y_clean[1], yerr=s_AU, xerr=s_AU, color='xkcd:purple', lw=1.5, capsize=1, capthick=0.5)
        
        # Reference cross at fixed data coordinates (1.1, 1.4); the per-axis half-length
        # s_AU/sqrt(2) makes the two bars combine in quadrature to s_AU.
        ax.errorbar(x=1.1, y=1.4, yerr=s_AU/np.sqrt(2), xerr=s_AU/np.sqrt(2), color='k', lw=0.5, capsize=3, capthick=0.5)
        ax.text(x=1.1, y=1.4-0.15, s='astrometric\nerror', va='top', ha='center', 
                fontdict=font
#                     bbox=dict(boxstyle='square,pad=1', fc='none', ec='none')
               )
            
        
    return fig, ax
       
from matplotlib import patches    
def show_orbit(truths, epochs_to_show, cadence, rotations=2, d=10, s_mas=5, ax=None, fig=None,
               showplane=True,
               slides=False,lims=[-1.5, 1.5],cmap='autumn',
               showIWA=True, savefig=False, showaxes=False, showxi=False, plane='xy',n_planets=None):
    
    mpl.rcParams.update(inline_rc)
    if slides:
        # matplotlib params for presentations
        plt.style.use('dark_background')
        axissize = 16
        legsize=18
        labelsize = 18
#         patch_width = 500*t # 10 points

    else:

        axissize = 20
        legsize=18
        labelsize = 22
#         patch_width = 500*t # 10 points
        rc('font',**{'family':'serif','serif':['Computer Modern Roman']})
        rc('text', usetex=True)
        
    
    from mc_orbit_good import getNextPhase, EuclideanRotation
    
    ndim=len(truths)

    if ndim==4:
        lna, cosi, omega_p, xi_0 = truths
        lan=0
        ecc=0
    else:
        lna, cosi, ecc, omega_p, xi_0, lan = truths
    a_IWA = d*IWA
    a_OWA = d*OWA
#     color = colorize(epochs_to_show, cmap=cmap)[0]
    color = colorize(epochs_to_show, cmap=cmap, vmin=-1, vmax=len(epochs_to_show))[0]
        
    a = np.exp(lna)
    eccentricity=ecc
    i = np.arccos(cosi)

    # xi is a list of phases to plot
    if n_planets is None:
        n_planets = len(epochs_to_show)
    
    if ax is None:
        fig1 = plt.figure(figsize=(5,5))
        ax1 = fig1.add_subplot(111, aspect='equal')
    
    else:
        ax1=ax
        ax1.axis('equal')
    
    if showIWA:
        if plane=='yz':
            print('Must be xy plane to show IWA')
        else:
#             ax1.add_patch(patches.Circle(
#                  (0, 0),   # (x,y)
#                  a_OWA,          # radius
#                  fc='ivory' # visible region colour
#                  )
#                  )

            covered='xkcd:light grey' #'khaki'
            ax1.add_patch(
                patches.Circle(
                    (0, 0),   # (x,y)
                    a_IWA,          # radius
                    fc=covered
                )
            )
#         ax1.set_facecolor(covered)
#         ax1.set_facecolor('xkcd:ivory')
    else:
        ax1.set_facecolor('xkcd:ivory')
    
        
    b = a*np.sqrt(1-eccentricity**2)
    c = np.sqrt(a**2 - b**2) # ellipse

    # plot different phases of planet
#     r_illus = a*0.08# d*(s_mas*1e-3) # same as astrometric error  #
    r_illus = s_mas*1e-3 * d 
#     phases_illus = np.linspace(0, np.pi, num=n_planets, endpoint=False)
    phases_illus = getNextPhase(epochs_to_show, cadence, a, xi_0)

    alpha = lan
    beta = i
    gamma = omega_p
    xs, ys, zs = EuclideanRotation(alpha, beta, gamma, a, b, phases_illus)
    aproj = np.sqrt(xs**2 + ys**2)

    for count in range(n_planets):

        # if obscured, plot as hollow circle:
        if aproj[count] <= a_IWA:
            fc_pl = 'none'
            ec='k' #'r'
            
        else:
            #fc_pl = str(phi[count])
            fc_pl = '0.4'
            ec='k'
        
        ax1.add_patch(
        patches.Circle(
            (xs[count], ys[count]),   # (x,y)
           r_illus,          # radius
           fc=fc_pl,
            ec=ec, zorder=15
            )
        )
        ax1.text(xs[count]+r_illus, ys[count]+r_illus+0.1, epochs_to_show[count]+1, color=color[count], fontsize=16, zorder=16, va='top', ha='left')


    plt.scatter(0, 0, c='xkcd:green blue', marker='*', s=300, zorder=25) # sun

    # show orbital plane
    from matplotlib import cm 
    num=9000
    x = np.zeros(num)
    y = np.zeros(num)
    z = np.zeros(num)
    t = np.linspace(0*np.pi, 100*np.pi, num)


    if showplane:
        r = np.linspace(0.97, 1.1, num)
        for ii in np.arange(0, num):
            x[ii], y[ii], z[ii] = EuclideanRotation(lan, i, omega_p, a*r[ii], b*r[ii], t[ii])
        if plane=='xy':    
            ax1.scatter(x, y, c=z, s = 0.5, cmap='PuBuGn_r',alpha=0.3, zorder=5)
        elif plane=='yz':
            ax1.scatter(z, y, c=x, s = 50, cmap='PuBuGn_r',alpha=0.3, zorder=5)
    else:
        for ii in np.arange(0, num):
            x[ii], y[ii], z[ii] = EuclideanRotation(lan, i, omega_p, a, b, t[ii])
        ax1.plot(x, y, c='k', lw=2, zorder=5)
        
        
    if showaxes:

        # show orbital plane axes
        corners_t = [0, np.pi/2, np.pi, 3*np.pi/2]
        corners_theta=[]
        for ii in range(4):
            corners_theta.append(plz.TrueAnom(eccentricity, corners_t[ii], 5))
            
        xhat, yhat, zhat = EuclideanRotation(lan, i, omega_p, 1.8, 1.8, corners_t) 
        xax = [xhat[0], xhat[2]]
        xax_y = [yhat[0], yhat[2]]
        xax_z = [zhat[0], zhat[2]]
        yax_x = [xhat[1], xhat[3]]
        yax_z = [zhat[1], zhat[3]]
        yax = [yhat[1], yhat[3]]

        # rotation angles
        x_ang, y_ang, z_ang = EuclideanRotation(lan, i, omega_p, 1.2, 1.2, corners_t) 
        
        # omega_p
        barlength=0.3
        
        if plane=='xy':
            x_angf, y_angf, z_angf = EuclideanRotation(lan, i, omega_p, 1.5, 1.5, corners_t) 
            # show prime axes
            plt.plot(xax, xax_y, 'r-')
            plt.plot(yax_x, yax, 'r-')
            
            # label
            plt.text(xax[0], xax_y[0], r'+x$^\prime$', color='r', ha='left', va='bottom', fontsize=12, zorder=26)
            plt.text(yax_x[0], yax[0], r'+y$^\prime$', color='r', ha='left', va='bottom', fontsize=12, zorder=26)

            
            # show arm of angle parallel to axes
            plt.plot([x_ang[0], x_ang[0]+barlength], [y_ang[0], y_ang[0]], 
                     '-', c="0.1", lw=1, zorder=0)
            

            ax1.annotate('',
                xytext=( x_ang[0]+barlength, y_ang[0]), xycoords='data',
                xy=(x_angf[0], y_angf[0]), textcoords='data',
                arrowprops=dict(arrowstyle="wedge,tail_width=0.4,shrink_factor=0.6",
                                fc="0.1", ec="none",
                                connectionstyle="angle3", shrinkA=0.09),
                )
            ax1.text(x_ang[0], y_ang[0], r"$\omega_p$  ", fontsize=18, va='bottom', ha='right')
            
        elif plane=='yz':
            x_angf, y_angf, z_angf = EuclideanRotation(lan, i, omega_p, 1.6, 1.6, corners_t) 
            
            # show prime y axis only
            plt.plot(xax_z, xax_y, 'r-') # x-axis -- cheat because this is longer
            #plt.plot(yax_z, yax, 'r-') # y-axis
            
            #label
            plt.text(xax_z[0], xax_y[0], r'+y$^\prime$', color='r', ha='left', va='bottom', fontsize=12, zorder=26)
            #plt.text(yax_z[0], yax[0], r'+y$^\prime$', color='r', ha='left', va='bottom', fontsize=12, zorder=26)

            
            # show arm of angle parallel to axes
            plt.plot([z_ang[0], z_ang[0]], [y_ang[0]+0.01, y_ang[0]+1.5*barlength], 
                     '-', c="0.1", lw=1, zorder=0)
            
            ax1.annotate('',
                xy=( z_angf[0], y_angf[0]), xycoords='data',
                xytext=(z_ang[0], y_ang[0]+1.5*barlength), textcoords='data',
                arrowprops=dict(arrowstyle="wedge,tail_width=0.2,shrink_factor=1",
                            fc="0.1", ec="none",
                            connectionstyle="angle3,angleA=0,angleB=90", shrinkA=0),zorder=0
                    )
            ax1.text(z_ang[0], y_ang[0], r"$i$   ", fontsize=18, va='bottom', ha='right')
            
    
            bbox_x = dict(boxstyle="larrow,pad=0.2", fc="xkcd:ice", ec="xkcd:marine", lw=1)
            ax1.text(0, 0.87, '    telescope', ha="left", va="top", color='k',  transform=ax1.transAxes,
                        size=16, bbox=bbox_x, zorder=20)
    
    
    ax1.set_xlim(lims[0], lims[1])
    ax1.set_ylim(lims[0], lims[1])
    
    # axes labels
    if plane=='xy':
        ax1.set_xlabel('x (AU)', fontsize=labelsize)
        ax1.set_ylabel('y (AU)', fontsize=labelsize)
    elif plane=='yz':
        ax1.set_xlabel('z (AU)', fontsize=labelsize, labelpad=1)
        ax1.set_ylabel('y (AU)', fontsize=labelsize, labelpad=1)
        
#     plt.yticks([])  
#     plt.xticks([])
   
    ax1.xaxis.set_tick_params(labelsize=axissize)
    ax1.yaxis.set_tick_params(labelsize=axissize)
    
    if savefig:
        fig1.savefig('orbit_demo.pdf')
        
    try:
        return fig1, ax1
    except:
        return ax1


In [14]:
def multiply_pdfs(n0, n1, bins, q=(0.16, 0.5, 0.84)):
    """Multiply two binned densities, renormalize, and read off quantiles.

    Parameters
    ----------
    n0, n1 : array_like
        Density values on the same binning (``len(bins) - 1`` entries each),
        e.g. the ``density=True`` output of ``np.histogram``.
    bins : array_like
        Bin edges shared by ``n0`` and ``n1``. Uniform or non-uniform widths
        are both supported.
    q : sequence of float
        Cumulative probabilities at which to evaluate quantiles of the product
        distribution. Defaults to the 16th/50th/84th percentiles.

    Returns
    -------
    pdf : ndarray
        The normalized product density, one value per bin.
    quants : tuple of float
        For each value in ``q`` (same order), the centre of the first bin whose
        cumulative probability reaches that value.

    Raises
    ------
    ValueError
        If the product of the two densities integrates to zero (or NaN), i.e.
        the distributions do not overlap and cannot be normalized.
    """
    bins = np.asarray(bins, dtype=float)
    mids = 0.5*(bins[1:] + bins[:-1])
    widths = np.diff(bins)
    joint = np.asarray(n0) * np.asarray(n1)
    norm = np.sum(widths*joint)
    if not norm > 0:
        raise ValueError('multiply_pdfs: the two densities do not overlap; '
                         'their product cannot be normalized')
    pdf = joint/norm
    cum = np.cumsum(pdf*widths)
    # cum[k] is the CDF at the right edge of bin k, so a quantile lies in the
    # first bin whose cum reaches it. The clip guards against round-off leaving
    # cum[-1] fractionally below 1 for q close to 1.
    idx = np.clip(np.searchsorted(cum, q, side='left'), 0, len(mids) - 1)
    quants = tuple(mids[idx])
    return pdf, quants

In [15]:
def _import_random_planets():
    """Import ``random_planets_fullorbit`` from the sibling ``../falsepos`` dir.

    The working directory is switched to ``../falsepos`` only for the import
    and is restored in a ``finally`` so a failed import cannot leave the
    kernel in the wrong directory.
    """
    cwd = os.getcwd()
    os.chdir('../falsepos')
    try:
        import random_planets_fullorbit as plz
    finally:
        os.chdir(cwd)
    return plz


def semianalytic_error(ep, a_proj, a_true, ecc_dist, nbins=100, sigma=1, store=True, n_pl=int(1e3),
                       i_it=None, direc=None, d=None, aproj_range=None, 
                      plot=False, name_ow=None, n1=None,
                      name='unknown', s_mas=5):
    """Semi-analytic semi-major-axis posterior for one epoch.

    Semi-major-axis samples are drawn from ``random_planets_fullorbit.start``
    for planets whose projected separation falls in a window around
    ``a_proj[ep]``. A non-detection (``a_proj[ep] <= d*IWA``) uses the window
    ``(0, d*IWA)`` instead. Samples are cached as a pickle in
    ``<cwd>/<direc>/<name>.dat`` and reused when that file can be read.

    Parameters
    ----------
    ep : int
        Index into ``a_proj`` of the epoch being processed.
    a_proj : sequence of float
        Projected separations of the epochs seen so far, compared against
        ``d*IWA``.
    a_true : float
        True semi-major axis, used for the accuracy.
    ecc_dist
        Eccentricity setting passed to ``plz.start`` as ``ecc``.
    nbins : int
        Number of histogram bins over the range (0, 5).
    sigma : {1, 2}
        Confidence level of the returned precision (16/84 or 2.5/97.5).
    store : bool
        Pickle freshly computed samples.
    n_pl : int
        Number of planets requested from ``plz.start``.
    i_it : int
        Iteration index. Used in the cache name and in messages.
    direc : str
        Data directory relative to the current working directory.
    d : float
        Distance. Multiplies ``IWA`` to give the IWA radius.
    aproj_range : (float, float), optional
        Separation window. Defaults to ``a_proj[ep] +/- 0.01``. It is forced
        to ``(0, d*IWA)`` for non-detections.
    plot : bool
        Show a diagnostic plot of the posterior.
    name_ow : str, optional
        If given, freshly computed samples are stored under this name. The
        cache read always uses the default name.
    n1 : ndarray, optional
        Binned density from previous epochs (same bins). When given, it is
        multiplied with this epoch's density.
    name, s_mas
        Ignored; kept for call compatibility.

    Returns
    -------
    prec : float
        Half the width of the ``sigma`` interval.
    acc : float
        Median minus ``a_true``.
    pdf : ndarray
        Binned posterior density. It is combined with ``n1`` when ``n1`` is
        given.

    Raises
    ------
    ValueError
        If ``sigma`` is not 1 or 2.
    """
    # ``q`` (fractions) feeds multiply_pdfs and ``pct`` (percent) feeds
    # np.percentile; both describe the same interval.
    if sigma == 1:
        q = [0.16, 0.5, 0.84]
        pct = [16, 50, 84]
    elif sigma == 2:
        q = [0.025, 0.5, 0.975]
        pct = [2.5, 50, 97.5]
    else:
        raise ValueError('sigma must be 1 or 2, got {!r}'.format(sigma))

    cwd = os.getcwd()
    a_IWA = d*IWA

    if a_proj[ep] <= a_IWA:
        # Non-detection: only constrain the planet to lie inside the IWA. The
        # cache name has no iteration index, so it is shared between iterations.
        aproj_range = (0, a_IWA)
        name = 'd'+str(d)+'a_posterior_nd0'
        print('          nondetection ep '+str(ep)+' @ it', i_it)
    else:
        name = 'd'+str(d)+'a_posterior_ep'+str(ep)+'_'+str(i_it)

    try:
        sma_list = read_results(str(cwd)+'/'+direc+'/'+name+'.dat')
        cached = True
    except Exception:
        # Best-effort cache: a missing or unreadable file means recompute.
        cached = False

    if not cached:
        print('it', i_it, ':    calculating semi-analytic posterior for epoch ',ep,'...')
        plz = _import_random_planets()
        tol = 0.01
        if aproj_range is None:
            aproj_lo = a_proj[ep]-tol
            aproj_hi = a_proj[ep]+tol
        else:
            aproj_lo = aproj_range[0]
            aproj_hi = aproj_range[1]
        planetlist = plz.start(d=d, n_pl=n_pl, ecc=ecc_dist, mode='observables',
                               epsilon_lo=1e-20, epsilon_hi=1,
                               aproj_lo=aproj_lo, aproj_hi=aproj_hi)
        print('          done!')
        sma_list = [o.a for o in planetlist]

        if store:
            store_name = name_ow if name_ow is not None else name
            fname = str(cwd)+'/'+direc+'/'+store_name+'.dat'
            print('          pickling to ', fname)
            with open(fname, 'wb') as f:
                pickle.dump(sma_list, f)

    n, bins = np.histogram(sma_list, bins=nbins, density=True, range=(0, 5))
    mids = 0.5*(bins[1:] + bins[:-1])

    if n1 is None:
        # First constraint: take quantiles directly from the samples.
        lower, med, upper = np.percentile(np.array(sma_list), pct)
        prec = (upper-lower)/2
        acc = med - a_true
        pdf = n
    else:
        pdf, quants = multiply_pdfs(n, n1, bins, q=q)
        prec = (quants[2] - quants[0])/2
        acc = quants[1] - a_true

    if plot:
        plt.figure()
        ax = plt.gca()
        ax.plot(mids, pdf, alpha=0.3)
        ax.set_xlabel('semi-major axis (AU)')
        ax.set_title('Epoch '+str(ep))
        ax.set_ylim(0, 5)
        # Keep the original four-colour scale, extended so every epoch in
        # a_proj has a colour.
        c = colorize(np.arange(max(4, len(a_proj))), cmap='autumn')[0]
        for ii, ap in enumerate(a_proj):
            ax.axvline(ap, color=c[ii], linestyle='--', lw=0.9)
        if n1 is not None:
            ax.plot(mids, n, 'k', alpha=0.3, lw=0.5)
        plt.show()
    return prec, acc, pdf


In [17]:
def precision(direc, y_data, ecc_dist, n_it, d, aproj_range=None, ML=False, save=True,
                  plot=False, it_start=0, s_mas=5, i_hide=[], n_det_needed=3):
    """Compute per-epoch precision and accuracy of the semi-major axis.

    Early epochs are filled from ``semianalytic_error``. This covers every
    epoch before the one at which ``n_det_needed`` detections (separation
    > ``d*IWA``) have been accumulated. Later epochs come from ``y_data``.

    Parameters
    ----------
    direc : str
        Data directory holding the truths, observations and outputs.
    y_data : array_like
        Retrieval summaries of shape (n_it, n_ep, 3). The last axis holds the
        lower / median / upper values.
    ecc_dist
        Passed through to ``semianalytic_error``.
    n_it : int
        Exclusive upper bound of the iteration loop ``range(it_start, n_it)``.
    d : float
        Distance, used for the IWA radius and in file names.
    aproj_range : optional
        Passed through to ``semianalytic_error``.
    ML : bool
        Use maximum-likelihood values from ``d<d>ML<iter>.dat`` instead of
        the medians for the accuracy.
    save : bool
        Pickle results to ``d<d>prec<iter>.dat`` and ``d<d>acc<iter>.dat``.
        If False, print them instead.
    plot : bool
        Passed through to ``semianalytic_error``.
    it_start : int
        First iteration to process.
    s_mas : float
        Astrometric error in mas, passed through to ``semianalytic_error``.
    i_hide : container of int
        Iterations to skip.
    n_det_needed : int
        Number of detections after which the retrieval results are used.

    Returns
    -------
    None
        Results are written to disk or printed.
    """
    y_arr = np.asarray(y_data)
    ymin = y_arr[:, :, 0] # shape (n_it, n_ep)
    ymed = y_arr[:, :, 1]
    ymax = y_arr[:, :, 2]

    a_IWA = d*IWA

    # Loop-invariant: read each pickle once rather than once per iteration.
    obvs = read_obvs(direc)
    truths = read_truths(direc)

    for i_it in range(it_start, n_it):
        if i_it in i_hide:
            continue
        print('\n \\(',i_it,')////////////\\ ( • ᴗ•)  /////////////////////////////////\\\\\\\\\\\\\\\\\\(•ᴗ • )/')
        prec = []
        acc = []
        _, _, a_proj = obvs[i_it]
        lna_true, _, _, _, _, _ = truths[i_it]
        a_true = np.exp(lna_true)
        n_ep = len(a_proj)

        n_det = 0
        if a_proj[0] > a_IWA:
            n_det += 1
            print('-->detection epoch 0')
        i_ep = 0
        n_old = None
        # Stop at the epoch of the n_det_needed-th detection, or when epochs run out.
        while n_det < n_det_needed and i_ep < n_ep:
            print('|  semi-analytic, i_ep=', i_ep)
            print('  n_det', n_det)
            p, a, n_old = semianalytic_error(i_ep, a_proj[:i_ep+1], a_true, ecc_dist=ecc_dist, 
                                             nbins=500, i_it=i_it, direc=direc, d=d, aproj_range=aproj_range,
                                             plot=plot, n1=n_old, s_mas=s_mas)
            prec.append(p)
            acc.append(a)
            i_ep += 1
            if i_ep < n_ep and a_proj[i_ep] > a_IWA:
                n_det += 1
                print('-->detection epoch', i_ep)
        if n_det < n_det_needed:
            print('  only', n_det, 'detections in', n_ep, 'epochs for iter', i_it)

        # precision: half-width of the retrieval interval, early epochs overridden
        prec_offish = (ymax[i_it]-ymin[i_it])/2
        prec_offish[:i_ep] = prec

        # accuracy: retrieved minus truth, early epochs overridden
        if ML:
            fname = direc+'/d'+str(d)+'ML'+str(i_it)+'.dat'
            # Keep the ML flag intact: rebinding it to an array would make the
            # next iteration's `if ML:` ambiguous.
            ml_values = np.swapaxes(read_results(fname), 0, 1)
            y_median = ml_values[0]
        else:
            y_median = ymed[i_it]
        acc_offish = y_median - a_true
        acc_offish[:i_ep] = acc

        if save:
            cwd = os.getcwd()
            fname = str(cwd)+'/'+direc+'/d'+str(d)+'prec'+str(i_it)+'.dat'
            print('pickling to ', fname)
            with open(fname, 'wb') as f:
                pickle.dump(prec_offish, f)
            fname = str(cwd)+'/'+direc+'/d'+str(d)+'acc'+str(i_it)+'.dat'
            print('pickling to ', fname)
            with open(fname, 'wb') as f:
                pickle.dump(acc_offish, f)
        else:
            print('prec for iter ', i_it, prec_offish)
            print('acc for iter ', i_it, acc_offish)

In [18]:
def plot_diminishing_returns(y_data, n_epochs, n_iters, idim, s=5, d=10, direc=None,
                        title='',alpha=1,lw=0.5, n_datasets=1,
                        cmap=None,labels=None, ylog=False, read=False,
                        ylim=[0.5, 1.5], legtitle=[],
                        showguide=False,showlegend=True,
                        figsize=(5,8),ax=None,showxlabel=True,
                        slides=False,iter_start=0, i_hide=[]):
    """Plot retrieval precision of parameter ``idim`` versus number of epochs.

    Precision is half the width of the retrieval interval,
    ``(y[..., 2] - y[..., 0])/2``. When ``read`` is True it is instead loaded
    from ``<direc>/d<d>prec<iter>.dat``. Note that ``y_data`` is still indexed
    in that case.

    Single-dataset mode (``n_datasets == 1``) draws one thin line per
    iteration in ``iter_start .. iter_start+n_iters-1`` that is not in
    ``i_hide``, plus a thick line for their mean. The median is printed.
    Multi-dataset mode indexes ``direc``, ``n_iters``, ``i_hide``, ``labels``
    and ``y_data`` per dataset and draws only each dataset's median curve.

    Side effects: ``slides=True`` applies the global 'dark_background' style.
    Otherwise the global rc font is set to serif with ``usetex``.

    Parameters
    ----------
    y_data : ndarray (or sequence of ndarrays in multi-dataset mode)
        Indexed by iteration; each entry is (n_ep, 3) lower/median/upper.
    s : float
        Astrometric error in mas. For ``idim == 0`` it is drawn as a
        horizontal reference line at ``s*1e-3*d`` (AU).
    lw, alpha : line style of the per-iteration curves.
    ylim : (float, float)
        y-limits. They also set the extent of the hatched box over epochs < 2.5.
    ax : Axes, optional
        Existing axes to draw into. When given, ``fig`` is returned as None.

    Returns
    -------
    fig : Figure or None
    ax : Axes
    """
    
    if slides:
        # matplotlib params for presentations
        plt.style.use('dark_background')
        axissize = 20
        legsize=18
        labelsize = 24
        medcolor="w"
        ms=7
       
    else:
        axissize = 20
        legsize=18
        labelsize = 22
        medcolor="k"
        ms=8
        # global rc change: affects all later figures in the session
        rc('font',**{'family':'serif','serif':['Computer Modern Roman']})
        rc('text', usetex=True)
    

        
    s_AU = d*(s*1e-3) # convert mas error to AU
    
    
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # caller owns the figure
        fig = None
        
    if ylog:
        ax.set_yscale('log')    
        
    if showxlabel:
        ax.set_xlabel('Number of epochs', fontsize=labelsize)
    ax.set_ylabel(get_label(idim, form='precision'), fontsize=labelsize)
    ax.set_xticks(np.arange(1, n_epochs+1))
    ax.set_title(title, fontsize=axissize)
    ax.xaxis.set_tick_params(labelsize=axissize)
    ax.yaxis.set_tick_params(labelsize=axissize)
    
    if (idim==0):
        # reference line at the astrometric measurement error (AU)
        c = 'xkcd:battleship grey'
        c = 'k'
        s_AU = s*1e-3*d 
        ax.axhline(y=s_AU,  ls='-', c=c, lw=0.5, zorder=0)
        ax.text(1, s_AU, r'$\sigma_{xy}$', va='bottom', ha='left', fontsize=legsize, color=c)
    
    
    # epochs are numbered from 1 on the x axis
    xs = np.arange(1, n_epochs+1)
    

  
    
    if n_datasets==1: # plot median and all iterations for one directory
        
        # iterations not listed in i_hide; these alone enter the statistics
        i_good = [x for x in range(iter_start, iter_start+n_iters) if x not in i_hide]
        color = colorize(np.arange(start=0, stop=len(i_good)+1), cmap="gist_ncar")[0]
    
        # requires y_data to be an ndarray (list indexing below)
        y_ep = np.swapaxes(y_data[i_good], 0, 1) # shape(n_ep, n_it, 3)
        ymin = np.swapaxes(y_ep, 0, 2)[0] # shape (n_it, n_ep)
        ymax = np.swapaxes(y_ep, 0, 2)[2] # shape (n_it, n_ep)
        
        # plot each iteration
        prec_iters = []
        for i_it in range(len(i_good)):

            label=''

            if showguide:
                label=i_good[i_it]

    
            if read:
                fname = direc+'/d'+str(d)+'prec'+str(i_good[i_it])+'.dat'
                prec = read_results(fname)
            
            else:
                prec = (ymax[i_it]-ymin[i_it])/2
                
            prec_iters.append(prec)
            ax.plot(xs, prec, lw=lw, alpha=alpha, label=label, 
                        color=color[i_it])

                

                
            
        median = np.median(np.array(prec_iters), axis=0)    
        print('median prec', median)
        # the thick summary curve is the mean (the median is only printed)
        mean = np.mean(np.array(prec_iters), axis=0)    
        ax.plot(xs, mean, lw=3, label='mean', c=medcolor)


    else: # plot medians for multiple directories
        
        
        datacolor = colorize(range(n_datasets), cmap=cmap, vmin=-0.5)[0]
        for ii_direc, val in enumerate(direc):
            # per-dataset iteration count and hidden list
            n_it = n_iters[ii_direc]
            i_hide_this = i_hide[ii_direc]
            i_good = [x for x in range(iter_start, iter_start+n_it) if x not in i_hide_this]
            color = colorize(np.arange(start=0, stop=len(i_good)+1), cmap="gist_ncar")[0]
            
            y_datai=y_data[ii_direc][i_good]
            y_ep = np.swapaxes(y_datai, 0, 1) # shape(n_ep, n_it, 3)
            ymin = np.swapaxes(y_ep, 0, 2)[0] # shape (n_it, n_ep)
            ymax = np.swapaxes(y_ep, 0, 2)[2] # shape (n_it, n_ep)
                           
                                   
            prec_iters = []
            for i_it in range(len(i_good)):

                label=''

                if showguide:
                    label=i_good[i_it]


                if read:
                    fname = val+'/d'+str(d)+'prec'+str(i_good[i_it])+'.dat'
                    prec = read_results(fname)
                    
                    

                else:
                    prec = (ymax[i_it]-ymin[i_it])/2

                prec_iters.append(prec)
            # only the per-dataset median is drawn in this mode
            median = np.median(np.array(prec_iters), axis=0)    
            ax.plot(xs, median, lw=3, label=labels[ii_direc], c=datacolor[ii_direc], alpha=1)
            
        
        
        
    
        
    if showguide:
        l = ax.legend(fontsize=legsize, loc=2, borderaxespad=0.5,
                  frameon=False, bbox_to_anchor=(1.05, 1), ncol=3)
    elif showlegend:
        l = ax.legend(fontsize=legsize, loc=1, borderaxespad=0.5, title=legtitle,
                  frameon=False)
        plt.setp(l.get_title(),fontsize=legsize)
    ax.set_ylim(ylim)
    # x-limits after all data are drawn, used for the hatched box below
    xlim = ax.get_xlim()
   
    
    # hatched area over underconstrained epochs
    bad = patches.Rectangle((xlim[0], ylim[0]), 2.5-xlim[0], ylim[1]-ylim[0], hatch='.', 
                            edgecolor='xkcd:light grey', facecolor='none', zorder=0)  
    ax.add_patch(bad)
    
    return fig, ax


In [19]:
def plot_diminishing_std(n_epochs, n_iters, idim, s=5, d=10, direc=None,
                        title='',alpha=1,n_datasets=1,
                        cmap=None,labels=None,
                        ylim=[0.5, 1.5], legtitle=[],
                        showguide=False,
                        figsize=(5,8),
                        slides=False,iter_start=0, i_hide=[], y_data=None):
    """Plot the uncertainty of parameter ``idim`` versus number of epochs.

    Single-dataset mode (``n_datasets == 1``) loads pickled arrays from
    ``<direc>/d<d>std<iter>.dat`` for iterations
    ``iter_start .. iter_start+n_iters-1``. It draws one thin line per
    iteration not listed in ``i_hide`` and a thick line for their median.
    Hidden iterations are excluded from the median too, matching
    ``plot_diminishing_returns``.

    Multi-dataset mode draws, for each entry of ``direc``, the median over
    iterations of half the retrieval-interval width from ``y_data[k]``
    (shape (n_it, n_ep, 3): lower / median / upper).

    Parameters
    ----------
    s : float
        Astrometric error in mas. In single-dataset mode with ``idim == 0``
        it is drawn as a dotted line at ``s*1e-3*d`` (AU).
    y_data : sequence, optional
        Per-dataset retrieval summaries for multi-dataset mode. If omitted,
        a notebook-global ``y_data`` is used, which is what earlier versions
        of this function read implicitly.
    ylim : unused
        Kept for call compatibility.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        In multi-dataset mode when no ``y_data`` is available.
    """
    if slides:
        # matplotlib params for presentations (global style change)
        plt.style.use('dark_background')
        axissize, legsize, labelsize, medcolor = 20, 18, 24, "w"
    else:
        axissize, legsize, labelsize, medcolor = 16, 14, 18, "k"

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.set_xlabel('Number of epochs', fontsize=labelsize)
    ax.set_ylabel(get_label(idim, form='uncertainty'), fontsize=labelsize)
    ax.set_xticks(np.arange(1, n_epochs+1))
    ax.set_title(title, fontsize=axissize)
    ax.tick_params(labelsize=axissize)

    if (idim == 0) and (n_datasets == 1):
        # astrometric measurement error converted from mas to AU
        s_AU = s*1e-3*d
        ax.axhline(y=s_AU, ls=':', c='xkcd:silver', lw=2, zorder=0, label='measurement error')

    xs = np.arange(1, n_epochs+1)

    if n_datasets == 1:
        # One colour per requested iteration, indexed by position in the run
        # so that iter_start > 0 stays within the colour array.
        color = colorize(np.arange(n_iters), cmap="gist_ncar")[0]
        std_all = []
        for pos, i_it in enumerate(range(iter_start, iter_start+n_iters)):
            if i_it in i_hide:
                continue
            std = read_results(direc+'/d'+str(d)+'std'+str(i_it)+'.dat')
            print('std', np.shape(std))
            # assumes std is laid out (n_ep, n_dim) so the swap gives one row per parameter — TODO confirm
            std_dim = np.swapaxes(std, 0, 1)[idim]
            std_all.append(std_dim)
            print('std_dim', np.shape(std_dim))
            label = i_it if showguide else ''
            ax.plot(xs, std_dim, lw=1, alpha=alpha, label=label, color=color[pos])

        median = np.median(std_all, axis=0)
        ax.plot(xs, median, lw=3, label='median uncertainty', c=medcolor)

    else:
        if y_data is None:
            # Backward compatibility: older call sites relied on the function
            # silently reading a notebook-global ``y_data``.
            y_data = globals().get('y_data')
        if y_data is None:
            raise ValueError('plot_diminishing_std: pass y_data for multi-dataset plots')
        datacolor = colorize(range(n_datasets), cmap=cmap, vmin=-0.5)[0]
        for ii_direc in range(len(direc)):
            y_ep = np.swapaxes(y_data[ii_direc], 0, 1) # shape(n_ep, n_it, 3)
            ymin = np.swapaxes(y_ep, 0, 2)[0] # shape (n_it, n_ep)
            ymax = np.swapaxes(y_ep, 0, 2)[2] # shape (n_it, n_ep)
            median = np.median((ymax-ymin)/2, axis=0)
            ax.plot(xs, median, lw=3, label=labels[ii_direc], c=datacolor[ii_direc], alpha=0.8)

    if showguide:
        l = ax.legend(fontsize=legsize, loc=2, borderaxespad=0.5,
                  frameon=False, bbox_to_anchor=(1.05, 1))
    else:
        l = ax.legend(fontsize=legsize, loc=1, borderaxespad=0.5, title=legtitle,
                  frameon=False)
    plt.setp(l.get_title(),fontsize=legsize)

    return fig

In [20]:
def get_label(idim, a_fixed=None, form='accuracy'):
    """Return an axis label for orbital-parameter index ``idim``.

    Parameters
    ----------
    idim : int
        Parameter index. 0 is semi-major axis, 1 inclination, 2 eccentricity,
        3 $\\omega_p$, 4 mean anomaly ($M_0$), 5 $\\Omega$, and -1 the mean
        over parameters.
    a_fixed : optional
        With ``form='relaccuracy'`` and ``idim == 0``, the value 1 selects the
        plain label 'retrieved semi-major axis (AU)'.
    form : str
        'relaccuracy' gives "retrieved / true" ratio labels. 'accuracy',
        'precision' and 'uncertainty' append that word to the parameter name.
        Any other value gives the bare name. Outside 'relaccuracy',
        semi-major-axis labels get an ' (AU)' suffix.

    Returns
    -------
    str

    Raises
    ------
    ValueError
        If ``idim`` is not one of the indices listed above.
    """
    if form == 'relaccuracy':
        if (idim == 0) and (a_fixed == 1):
            return r'retrieved semi-major axis (AU)'
        labels = {
            0: r'$a_{\rm retrieved}/a_{\rm true}$',
            1: r'$i_{\rm retrieved}/i_{\rm true}$',
            2: r'$e_{\rm retrieved}/e_{\rm true}$',
            3: r'$\omega_{p,{\rm retrieved}}/\omega_{p,{\rm true}}$',
            4: r'$M_{0,{\rm retrieved}}/M_{0,{\rm true}}$',
            5: r'$\Omega_{\rm retrieved}/\Omega_{\rm true}$',
            -1: 'mean parameter accuracy',
        }
        suffix = ''
    else:
        labels = {
            0: r'Semi-major axis',
            1: r'Inclination',
            2: r'Eccentricity',
            3: r'$\omega_{p}$',
            4: r'Mean anomaly',
            5: r'$\Omega$',
            -1: 'mean parameter',
        }
        suffix = {
            'uncertainty': ' uncertainty',
            'accuracy': ' accuracy',
            'precision': ' precision',
        }.get(form, '')
        if idim == 0:
            # add units
            suffix = suffix + ' (AU)'

    if idim not in labels:
        raise ValueError('get_label: unknown parameter index {!r}'.format(idim))
    return labels[idim] + suffix

In [21]:
def plot_scatter_epochs(y_data, n_epochs, n_iters, idim, truths, d=10, direc=None,sigma=1,maxprob=True,
                        title='', n_datasets=1,fname_base='samples_epochs',alpha=1,alphaHZ=1, showguide=False,
                        a_fixed=None, HZ_cmap="PuRd", ylim=[0.5, 1.5], legtitle=[],i_hide=[],read=False,
                        HZ_lines=None, HZ_inner=None, HZ_outer=None, HZ_labels=None,showxlabel=True,showylabel=True,
                        datalabel=None,conc_direcs=False,figsize=(5,8),iter_start=0,precisionbounds=False,
                        plotIWA=False, showlegend=True, slides=False, offset=None, cmap='Wistia', ax=None):
    """Plot per-epoch retrieval accuracy of parameter ``idim``.

    Accuracy is the retrieved value minus the truth. When ``a_fixed == 1``
    the raw retrieved value is plotted instead. With ``n_datasets == 1``,
    every retained iteration of ``direc`` is drawn as a row of dots and
    overlaid with per-epoch ``sigma``-level error bars from
    ``median_accuracy``. With ``n_datasets > 1`` only the error bars are
    drawn, one series per entry of ``direc``. In that mode ``n_iters``,
    ``i_hide``, ``y_data`` and ``datalabel`` are indexed per dataset.

    Side effects: rcParams are reset to ``inline_rc``. ``slides=True`` applies
    the global 'dark_background' style. Otherwise the rc font is set to
    serif with ``usetex``.

    Parameters
    ----------
    y_data : ndarray (or sequence of ndarrays in multi-dataset mode)
        Indexed by iteration. Each entry is (n_ep, 3), with column 1 the median.
    n_iters : int or sequence of int
        Exclusive upper bound of the iteration range starting at
        ``iter_start``. It is given per dataset in multi-dataset mode.
    idim : int
        Parameter index (see ``get_label``). Truths are transformed before
        comparison: exp for 0, arccos for 1, and mod pi for 3 and 5.
    truths : ignored
        Truths are re-read with ``read_truths`` for each directory.
    sigma : {1, 2}
        Interval shown by the error bars: 16/84 or 2.5/97.5 percentiles.
    maxprob : bool
        Use maximum-likelihood values from ``d<d>ML<iter>.dat`` instead of
        the medians.
    read : bool
        Load pre-computed accuracies from ``d<d>acc<iter>.dat``.
    a_fixed : optional
        If 1, plot raw retrieved values. Any non-None value enables the
        habitable-zone overlays from ``HZ_lines`` or ``HZ_inner``/``HZ_outer``.
    plotIWA : bool
        Shade separations inside the IWA. This is hard-coded to 10 pc and does
        not use ``d``.
    precisionbounds : bool
        Shade the band from -0.05 to 0.05.
    offset : float, optional
        Horizontal stagger between datasets in multi-dataset mode.
    ax : Axes, optional
        Existing axes to draw into. When given, ``fig`` is returned as None.
    fname_base, conc_direcs : unused
        Kept for call compatibility.

    Returns
    -------
    fig : Figure or None
    ax : Axes

    Raises
    ------
    ValueError
        If ``sigma`` is not 1 or 2.
    """
    if sigma == 1:
        percentiles = [16, 50, 84]
    elif sigma == 2:
        percentiles = [2.5, 50, 97.5]
    else:
        raise ValueError('sigma must be 1 or 2, got {!r}'.format(sigma))

    # reset rcParams to the snapshot taken at import time
    mpl.rcParams.update(inline_rc)
    if slides:
        # matplotlib params for presentations
        plt.style.use('dark_background')
        axissize = 20
        legsize = 18
        labelsize = 24
        ms = 7
        lw = 2
        err_fmt = 'ws'
        c_1line = 'w'
    else:
        axissize = 20
        legsize = 18
        labelsize = 22
        ms = 8
        lw = 2
        err_fmt = 'ks'
        c_1line = 'k'
        rc('font',**{'family':'serif','serif':['Computer Modern Roman']})
        rc('text', usetex=True)

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None
    if showxlabel:
        ax.set_xlabel('Number of epochs', fontsize=labelsize)
    ylabel = get_label(idim, a_fixed, 'accuracy')
    if a_fixed == 1:
        ylabel = 'Retrieved semi-major axis (AU)'
    if showylabel:
        ax.set_ylabel(ylabel, fontsize=labelsize)

    color_fixed = 'xkcd:bluey grey'
    mark = ['s', 'v', 'o', 'D']

    if n_datasets > 1: # plot errorbars for multiple directories
        datacolor = colorize(range(n_datasets), cmap=cmap, vmin=-0.5)[0]
        xs_base = np.arange(1, n_epochs+1)
        # symmetric stagger positions, one per dataset
        shifts = np.arange(-n_datasets/2, n_datasets/2)

        for ii_direc, val in enumerate(direc):
            n_iters_this = n_iters[ii_direc]
            i_hides = i_hide[ii_direc]
            i_good = [x for x in range(iter_start, n_iters_this) if x not in i_hides]

            truths = read_truths(val)
            truths_good = np.array(truths)[i_good]
            y50_data, yerr_data = median_accuracy(val, n_epochs, n_iters_this, idim, d=d, 
                                                  y_data=y_data[ii_direc][i_good], 
                                                  percentile=percentiles, form='errorbars',
                                                  iter_start=iter_start,
                                                  truths=truths_good, maxprob=maxprob, read=read, 
                                                  i_good=i_good)

            xs = xs_base if offset is None else xs_base + offset*shifts[ii_direc]
            ax.errorbar(xs, y50_data, yerr=yerr_data, capsize=8, capthick=2,
                        elinewidth=lw, alpha=1,
                        lw=0, color=datacolor[ii_direc],
                        # cycle markers so more than four datasets do not overrun
                        marker=mark[ii_direc % len(mark)], ms=ms,
                        label=None if datalabel is None else datalabel[ii_direc])

    else: # plot one dataset showing scatter
        i_good = [x for x in range(iter_start, n_iters) if x not in i_hide]
        color = colorize(np.arange(start=0, stop=len(i_good)+1), cmap="gist_ncar")[0]
        truths = read_truths(direc)

        y_data_good = y_data[i_good]
        truths_good = np.array(truths)[i_good]
        y50_data, yerr_data = median_accuracy(direc, n_epochs, n_iters, idim, d=d, 
                                              y_data=y_data_good, i_good=i_good,
                              percentile=percentiles, form='errorbars',iter_start=iter_start,
                              truths=truths_good, maxprob=maxprob, read=read, a_fixed=a_fixed)

        # truths in the units of the plotted parameter
        truth = np.swapaxes(truths_good, 0, 1)[idim]
        if idim == 0:
            truth = np.exp(truth)
        elif idim == 1:
            truth = np.arccos(truth)
        elif idim in (3, 5):
            truth = truth % np.pi

        epoch_x = np.arange(n_epochs) + 1

        for ii, i_iter in enumerate(i_good):
            if read:
                y = read_results(direc+'/d'+str(d)+'acc'+str(i_iter)+'.dat')
                if a_fixed == 1:
                    y = y+1 # reconvert to sma
            else:
                if maxprob:
                    fname = direc+'/d'+str(d)+'ML'+str(i_iter)+'.dat'
                    ml_values = np.swapaxes(read_results(fname), 0, 1)
                    y_median = ml_values[idim]
                else:
                    y_median = y_data_good[ii].T[1]

                if a_fixed == 1:
                    y = y_median # plot raw
                else:
                    y = y_median - truth[ii]

            if HZ_inner is None and HZ_lines is None:
                c = color[ii]
            else:
                c = color_fixed
            label = i_iter if showguide else ''
            # color= rather than c=: a single RGB(A) sequence passed as c= is
            # treated as colormap values when its length equals the number of
            # points (e.g. 4 epochs).
            ax.scatter(epoch_x, y, color=c, s=80, alpha=alpha, zorder=1, label=label)

        # for each epoch, overplot errorbars
        ax.errorbar(np.arange(1, n_epochs+1), y50_data, yerr=yerr_data, capsize=8, 
                    capthick=2,
                    lw=lw, fmt=err_fmt, ms=ms, alpha=1, label=str(sigma)+r'$\sigma$ confidence interval')
        print('median acc',  y50_data)

    xlim = ax.get_xlim()
    xvec = np.linspace(xlim[0], xlim[1], num=2)

    # show HZ boundaries?
    if a_fixed is not None:
        if HZ_lines is not None:
            color2=colorize (np.arange(len(HZ_lines)), cmap=HZ_cmap, vmin=-1, vmax=len(HZ_lines)+1 )[0]
            for ii, val in enumerate(HZ_lines):
                ax.axhline(y=val, label=HZ_labels[ii], c=color2[ii], lw=3, ls='-', zorder=0, alpha=1)

        elif (HZ_outer is not None) and (HZ_inner is not None):
            color2=colorize(np.arange(len(HZ_inner)), cmap=HZ_cmap )[0]

            ax.fill_between(x=xvec, y1=HZ_inner[0]*np.ones(len(xvec)), y2=HZ_outer[0]*np.ones(len(xvec)),
                            label=HZ_labels[0], facecolor=color2[0],  
                            lw=1, zorder=0,
                            alpha=alphaHZ
                           ) # center

            for ii in range(len(HZ_inner)-1):  
                ax.fill_between(x=xvec, y2=HZ_inner[ii]*np.ones(len(xvec)), y1=HZ_inner[ii+1]*np.ones(len(xvec)),
                                label=HZ_labels[ii+1], 
                                facecolor=color2[ii+1],  
                                lw=1, zorder=0,
                                alpha=alphaHZ
                               ) # lower rect
                ax.fill_between(x=xvec, y2=HZ_outer[ii+1]*np.ones(len(xvec)), y1=HZ_outer[ii]*np.ones(len(xvec)),
                                facecolor=color2[ii+1], 
                                lw=1, zorder=0,
                                alpha=alphaHZ
                               ) # upper rect

    if plotIWA:
        # LUVOIR IWA (3 lambda/D, lambda = 0.5 micron, D = 10 m, in arcsec) for
        # a target at 10 pc. Local names avoid shadowing the module-level IWA
        # and the ``d`` argument.
        iwa_luvoir = 3*0.5e-6/10*206265
        d_iwa = 10 # pc
        a_IWA = d_iwa*iwa_luvoir
        ax.fill_between(x=xvec, y1=ylim[0], y2=a_IWA,
                        facecolor='xkcd:greyish', edgecolor='xkcd:greyish', alpha=0.2, hatch='//')
        ax.text(5, a_IWA-0.05, 'IWA', fontsize=20, va='top', ha='right', alpha=0.4)

    if precisionbounds:
        xvec=np.arange(10)
        ax.fill_between(x=xvec, y1=-0.05*np.ones(len(xvec)), y2=0.05*np.ones(len(xvec)),
                facecolor='xkcd:grey',  
                lw=1, zorder=0,
                alpha=0.8
               ) # center

    # zero-offset reference line
    if HZ_inner is None:
        ax.axhline(y=0, ls='-', c=c_1line, lw=0.5, zorder=0)

    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.set_xticks(np.arange(1, n_epochs+1))
    ax.set_title(title, fontsize=axissize)
    ax.xaxis.set_tick_params(labelsize=axissize)
    ax.yaxis.set_tick_params(labelsize=axissize)

    # hatched area over underconstrained epochs
    bad = patches.Rectangle((xlim[0], ylim[0]), 2.5-xlim[0], ylim[1]-ylim[0], hatch='.', 
                            edgecolor='xkcd:light grey', facecolor='none', zorder=0)  
    ax.add_patch(bad)

    if showguide:
        l = ax.legend(fontsize=legsize, loc=2, borderaxespad=0.5,
                  frameon=False, bbox_to_anchor=(1.05, 1), ncol=3)
    elif ((HZ_inner is not None) or (HZ_lines is not None)) and showlegend:
        l = ax.legend(frameon=False, fontsize=legsize, bbox_to_anchor=(1.05, 1), loc=2, title=legtitle, 
                  borderaxespad=0.)
        plt.setp(l.get_title(),fontsize=legsize+2)
    elif showlegend:
        l = ax.legend(fontsize=legsize, loc=1, borderaxespad=0.5,title=legtitle,
                  frameon=False, markerfirst=False)
        plt.setp(l.get_title(),fontsize=legsize+2)

    return fig, ax

In [22]:
# # show earth twins with different HZ limits


# direc=['data/cadence15', 'data/cadence30', 'data/cadence180']
# datalabel=['15 days', '30 days', '180 days']

# #HZ_labels=['Moist greenhouse', 'Runaway greenhouse', 'Recent Venus', 'Early Mars', 'Maximum greenhouse']
# #HZ_cmap='PiYG'
# HZ_lines = [0.99, 0.97, 0.75, 1.77, 1.7] # Kopparapu et al. 2013

# a_fixed= 1


# HZ_scale=np.array([0.01, 0.1, 0.25, 0.5])
# HZ_labels=['1% HZ','10% HZ', '25% HZ', '50% HZ']
# HZ_cmap='spring'

# HZ_inner = a_fixed-a_fixed*HZ_scale
# HZ_outer= a_fixed+a_fixed*HZ_scale

# n_iters = 20
# idim = 0
# n_epochs=5
# stem='d10samples_epochs'

# # plot 
# fig3, y_data4, epochs_2_constrain = plot_scatter_epochs(n_iters, n_epochs, 0, None, direc, 
#                                                             fname_base=stem, datalabel=datalabel,
# #                                                             a_fixed=1, HZ_lines=None, 
# #                                                             HZ_inner=HZ_inner, HZ_outer=HZ_outer, 
# #                                                             HZ_labels=HZ_labels,HZ_cmap=HZ_cmap,
#                                                             plotIWA=False, slides=True, offset=0.09,
#                                                             ylim=[0.4, 1.6], alpha=0.5,figsize=(8,8),
# #                                                             y_data=ydata3
#                                                                )
# fig3.savefig('earths_hypotheticalHZ'+str(idim)+'.png', bbox_inches='tight', transparent=True)
# # fig3.savefig('earths_hypotheticalHZ_multi_0-4.png', bbox_inches='tight', transparent=True)


In [23]:
# # show earth twins with different HZ limits

# direc='data/cadence30'

# #HZ_labels=['Moist greenhouse', 'Runaway greenhouse', 'Recent Venus', 'Early Mars', 'Maximum greenhouse']
# #HZ_cmap='PiYG'
# HZ_lines = [0.99, 0.97, 0.75, 1.77, 1.7] # Kopparapu et al. 2013

# a_fixed= 1


# HZ_scale=np.array([0.5, 0.25, 0.1, 0.01])
# HZ_labels=['50% HZ', '25% HZ','10% HZ', '1% HZ']
# # HZ_cmap='spring'

# HZ_inner = a_fixed-a_fixed*HZ_scale
# HZ_outer= a_fixed+a_fixed*HZ_scale

# n_iters = 20
# idim = 0
# n_epochs=5
# truths_et = read_truths(direc)
# stem='d10samples_epochs'


# # plot 
# fig3, ydata3, epochs_2_constrain = plot_scatter_epochs(n_iters, n_epochs, idim, truths_et, direc, 
#                                                             fname_base=stem, 
# #                                                             title='', a_fixed=1, HZ_lines=None, 
# #                                                             HZ_inner=HZ_inner, HZ_outer=HZ_outer, 
# #                                                             HZ_labels=HZ_labels, HZ_cmap='PuRd_r',
#                                                             plotIWA=False, slides=True,
#                                                             ylim=[0.4, 1.8], alpha=0.5,
# #                                                              y_data=ydata3
#                                                                )
# fig3.savefig('earths_hypotheticalHZ_30_'+str(idim)+'.png', bbox_inches='tight')


In [24]:
def epochs_sigma(samples_1dim, s_percentile=[16, 50, 84], n_epochs=None):
    """Summarise a one-parameter posterior at each epoch as (lower error, centre, upper error).

    Parameters
    ----------
    samples_1dim : sequence
        Indexable by epoch; element ``k`` holds the samples of one parameter at epoch ``k``.
    s_percentile : list of float
        Percentiles to evaluate (default 16/50/84). Three values are expected; a single
        value such as ``[50]`` is broadcast into all three slots, so both errors are 0.
    n_epochs : int, optional
        Number of leading epochs to summarise; defaults to the first dimension of
        ``samples_1dim``.

    Returns
    -------
    yerr_lower, y, yerr_upper : ndarray, each of shape (n_epochs,)
        ``y`` is the middle percentile; the errors are its distances to the outer ones.
    """
    n_rows = np.shape(samples_1dim)[0] if n_epochs is None else n_epochs
    table = np.zeros((n_rows, 3))
    for row in range(n_rows):
        # a length-1 percentile result is broadcast across all three columns
        table[row, :] = np.percentile(samples_1dim[row], q=s_percentile, axis=0)
    low, centre, high = table.T
    return centre - low, centre, high - centre


def _median_ratio_per_epoch(samples_epochs, dim, n_epochs, truth=None, a_fixed=None):
    """Median of parameter ``dim`` at each epoch, divided by ``a_fixed`` or else by ``truth``.

    Dimension 0 is sampled as ln(a) and dimension 1 as cos(i); the median (and the truth,
    when used) are mapped back with exp / arccos before the ratio is taken. ``truth`` is
    only read when ``a_fixed`` is None.
    """
    samples_1dim = extract1dim(samples_epochs, dim)
    median = epochs_sigma(samples_1dim, s_percentile=[50], n_epochs=n_epochs)[1]  # median at each epoch
    if dim == 0:
        median = np.exp(median)  # because mc works in ln(a)
    elif dim == 1:  # invert cosine of i
        median = np.arccos(median)
    if a_fixed is not None:
        return median / a_fixed  # shape: (n_ep,)
    if dim == 0:
        truth = np.exp(truth)
    elif dim == 1:
        truth = np.arccos(truth)
    return median / truth


def medians_all_iters(n_iters, n_epochs, idim, truths, direc, fname_base, a_fixed=None, fend='.dat',
                      results=None):
    """Per-iteration median of the fit at each epoch, relative to the truth (or to ``a_fixed``).

    Parameters
    ----------
    n_iters : int
        Number of iterations; iteration ``ii`` is read from ``direc/fname_base{ii}{fend}``.
    n_epochs : int
        Number of leading epochs summarised per iteration.
    idim : int
        Parameter index; any value <= -1 averages the ratio over all 6 parameters.
    truths : sequence or None
        ``truths[ii][dim]`` is the true value of ``dim`` for iteration ``ii`` in the sampled
        space (ln(a), cos(i), ...). Only read when ``a_fixed`` is None, so it may be None
        when ``a_fixed`` is given.
    direc, fname_base, fend : str
        Location of the pickled per-iteration samples (ignored when ``results`` is given).
    a_fixed : float, optional
        If given, medians are divided by this value instead of by the truth.
    results : object, optional
        Pre-loaded samples used for every iteration instead of reading from disk.

    Returns
    -------
    list of ndarray
        One array of per-epoch ratios per iteration.
    """
    y_data = []
    for ii in range(n_iters):
        if results is None:
            # extract data
            print('reading from ',direc+'/'+fname_base+str(ii)+fend)
            samples_epochs = read_results(direc+'/'+fname_base+str(ii)+fend)
        else:
            samples_epochs = results
        # Truths are only needed when normalising by them, so callers may pass
        # truths=None together with a_fixed.
        truth_row = truths[ii] if a_fixed is None else None
        if idim > -1:
            y = _median_ratio_per_epoch(samples_epochs, idim, n_epochs,
                                        truth=None if truth_row is None else truth_row[idim],
                                        a_fixed=a_fixed)
        else:  # average over all dims
            y_temp = [np.array(_median_ratio_per_epoch(samples_epochs, ii_dim, n_epochs,
                                                       truth=None if truth_row is None else truth_row[ii_dim],
                                                       a_fixed=a_fixed))
                      for ii_dim in range(6)]
            y = np.mean(y_temp, axis=0)
        y_data.append(np.array(y))

    return y_data
       
def quantiles_all_iters(n_epochs, y_data, quants=[16, 50, 84], returnquants=None):
    """Summarise per-iteration values across iterations, epoch by epoch.

    Parameters
    ----------
    n_epochs : int
        Number of leading epochs to summarise.
    y_data : array_like
        Rows are iterations and columns epochs (e.g. the list returned by
        ``medians_all_iters``).
    quants : list of float
        Percentiles taken across iterations; the first three are used as (low, mid, high).
    returnquants : bool, optional
        When exactly ``True``, return the raw (3, n_epochs) percentile table instead.

    Returns
    -------
    y50_data : ndarray, shape (n_epochs,)
        Middle percentile per epoch.
    yerr_data : ndarray, shape (2, n_epochs)
        Distance from the middle percentile down to the low one (row 0) and up to the
        high one (row 1).
    """
    print('y_data', y_data)
    by_epoch = np.array(y_data).T  # row jj: every iteration's value at epoch jj
    table = np.zeros((3, n_epochs))
    for jj in range(n_epochs):
        pct = np.percentile(by_epoch[jj], quants, axis=0).flatten()
        table[0, jj], table[1, jj], table[2, jj] = pct[0], pct[1], pct[2]

    if returnquants is True:
        return table
    centre = table[1].copy()
    errors = np.vstack((table[1] - table[0], table[2] - table[1]))
    return centre, errors

def min_epoch(y50_data, yerr_data, min_a, max_a):
    """For each (min_a, max_a) pair, return the first epoch index meeting both thresholds.

    Meant for use with ydat from plot_scatter_epochs and various a_min and a_max.

    Parameters
    ----------
    y50_data : array_like
        Unused; kept for call compatibility.
    yerr_data : array_like, shape (2, n_epochs)
        Row 0 is the lower error bar and row 1 the upper error bar at each epoch
        (e.g. the ``yerr_data`` returned by ``quantiles_all_iters``).
    min_a, max_a : sequences of float of equal length
        Per-case thresholds. An epoch qualifies when its lower error exceeds ``min_a[ii]``
        and its upper error is below ``max_a[ii]``.

    Returns
    -------
    list of int
        One epoch index per threshold pair. NOTE: ``np.argmax`` returns 0 when no epoch
        qualifies, so a 0 is ambiguous; check the condition at that index if it matters.
    """
    lower = np.asarray(yerr_data[0])
    upper = np.asarray(yerr_data[1])
    numepochs = []

    for ii in range(len(min_a)):
        # Element-wise `&`: Python `and` on arrays with more than one element raises
        # "truth value of an array is ambiguous".
        qualifies = (lower > min_a[ii]) & (upper < max_a[ii])
        arg = np.argmax(qualifies)
        print('yerr lower=',yerr_data[0])
        print('yerr upper=',yerr_data[1])
        print('min_a', min_a[ii])
        print('argmin', arg)
        numepochs.append(arg)

    return(numepochs)

In [26]:
from matplotlib.ticker import ScalarFormatter
def plot_accuracy_distance(y_data, distances, n_iters_list, n_epochs, idim, direcs, fend='.dat', slides=False, 
                           ylims=[0.5,1.5], read=False, i_hide_list=[], iter_start=0,
                           a_fixed=None, percentile=[16, 50, 84], ep_start=0, figsize=(8,8), 
                           alpha=1, cadence=None, title='', maxprob=True, grid=False):
    """Plot retrieval accuracy against target distance, one bar per epoch, plus the obscured fraction.

    For each distance, ``median_accuracy`` summarises the retrievals in ``direcs[i_d]``.
    Each epoch's span from ``acc[0]`` to ``acc[2]`` is drawn as a rectangle centred on that
    distance. A twin y-axis shows the fraction of observations inside the IWA, computed
    by ``get_obscured``.

    Parameters
    ----------
    y_data : sequence
        Per-distance data forwarded to ``median_accuracy``.
    distances : sequence of float
        Distances in pc. Not modified.
    n_iters_list : sequence of int
        Number of iterations per distance.
    n_epochs : int
        Number of epochs; bars are drawn for epochs ``ep_start .. n_epochs-1``.
    idim : int
        Parameter index forwarded to ``median_accuracy`` and ``get_label``.
    direcs : sequence of str
        Result directory per distance.
    i_hide_list : sequence of sequences, optional
        Iteration indices to exclude, one list per distance. Empty (the default) means
        nothing is excluded at any distance.
    fend, alpha : unused; kept for call compatibility.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    xmax = np.max(distances)+5
    ax.set_xlim([1, xmax])
    ax.set_ylim(ylims)
    if grid:
        ax.yaxis.grid() # horizontal lines

    mpl.rcParams.update(inline_rc)
    if slides:
        # matplotlib params for presentations
        plt.style.use('dark_background')
        axissize = 20
        legsize = 18
        labelsize = 24
    else:
        axissize = 20
        legsize = 18
        labelsize = 22
        #rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
        rc('font',**{'family':'serif','serif':['Computer Modern Roman']})
        rc('text', usetex=True)

    if len(i_hide_list) == 0:
        # The empty default has no per-distance entry to index: hide nothing anywhere.
        i_hide_list = [[] for _ in distances]

    print('n_iters', np.shape(n_iters_list))
    print('i_hide', np.shape(i_hide_list))
    if len(i_hide_list) > 0:
        print('i_hide[0]', i_hide_list[0])

    fc = colorize(range(n_epochs), cmap='RdPu_r', vmin=-0.5+ep_start, vmax=n_epochs+0.5)[0]
    widths = np.linspace(100, 1000, num=n_epochs)

    # Scale factor for bar widths. It depends only on the axes size and the figure dpi,
    # so it is computed once rather than per bar.
    # NOTE(review): this uses the axes *height* in pixels and yields axes-fraction per
    # point rather than x data units; confirm this is the intended width scaling.
    t = ax.transAxes.transform([(0,0), (1,1)])
    t = fig.get_dpi() / (t[1,1] - t[0,1]) / 72

    for i_d, dist in enumerate(distances):
        print(i_d)
        n_iters = n_iters_list[i_d]
        print('n_iters', n_iters)
        i_hide = i_hide_list[i_d]
        print('i_hide' , i_hide)
        i_good = [x for x in range(iter_start, iter_start+n_iters) if x not in i_hide]
        direc = direcs[i_d]
        print('direc', direc)
        truths = read_truths(direc)
        label = ''
        acc = median_accuracy(direc, n_epochs, n_iters, idim, d=dist, y_data=y_data[i_d], 
                          percentile=percentile, form='percentiles',maxprob=maxprob,
                          truths=truths, read=read, i_good=i_good)

        for i_ep in range(ep_start, n_epochs):
            if i_d == len(distances)-1:
                # label only the last distance so each epoch appears once in the legend
                label = str(i_ep+1)

            patch_width = widths[i_ep]*t
            xy = (dist-patch_width/2, acc[0][i_ep])
            height = acc[2][i_ep] - acc[0][i_ep]

            rect = patches.Rectangle(xy,patch_width,height,zorder=10,
                                     label=label,
                                     linewidth=1,
                                     facecolor=fc[i_ep],
                                     alpha=0.7,
                                    )
            ax.add_patch(rect)

    # set 1:1 line 
    ax.axhline(y=0, ls='-', c='k', lw=0.5, zorder=0)
    l = plt.legend(frameon=False, fontsize=legsize, title='Epochs', ncol=1,
                   bbox_to_anchor=(0, 1), loc=1, borderaxespad=5 
                  )
    plt.setp(l.get_title(),fontsize=legsize+2)

    ax.set_title(title, fontsize=labelsize)
    ax.set_ylabel(get_label(idim, a_fixed), fontsize=labelsize)
    ax.set_xlabel('Distance (pc)', fontsize=labelsize)
    plt.tick_params(labelsize=axissize)
    ax.xaxis.set_major_formatter(ScalarFormatter())

    # NOTE(review): n_iters and direc are the values left over from the last distance in
    # the loop above, so every distance's obscured fraction uses that directory's truths.
    f_obscured = get_obscured(n_iters, n_epochs, cadence, direc, distances)[1]
    print('f obscured:', f_obscured)

    ax2 = ax.twinx()
    # Add endpoints (0 at distance 0; 1 at a hard-coded far distance) on copies, so the
    # caller's `distances` is left untouched and numpy arrays work too.
    dist_curve = [0] + list(distances) + [54.94549891321035]
    f_curve = [0] + list(f_obscured) + [1]
    ax2.plot(dist_curve, f_curve, '-', c='xkcd:greenish', zorder=0)
    ax2.set_ylabel('Fraction obscured', color='xkcd:greenish', fontsize=labelsize)
    ax2.tick_params('y', labelsize=axissize,  colors='xkcd:greenish')
    ax2.set_ylim(0, 1)
    plt.setp(ax2.get_yticklabels()[0], visible=False)  

    return fig

In [27]:
def get_obscured(n_it, n_ep, cadence, direc, distances, IWA=3*850e-9/10*206265):
    """Count observations falling inside the inner working angle (IWA) at each distance.

    Parameters
    ----------
    n_it : int
        Number of iterations (truth sets) taken from ``direc``.
    n_ep : int
        Number of epochs examined per iteration.
    cadence :
        Passed through to ``obv_at_epoch``.
    direc : str
        Directory containing ``truths_iters.dat``.
    distances : iterable of float
        Distances in pc.
    IWA : float
        Inner working angle in arcsec (default 3 lambda/D for 850 nm and a 10 m aperture).

    Returns
    -------
    n_obscured : list of int
        Per distance, the number of epochs with projected separation ``aproj < dist*IWA``.
    f_obscured : list of float
        ``n_obscured`` divided by the number of epochs examined (NaN if none were).
    """
    from mc_orbit_good import obv_at_epoch

    # Read the truths file once instead of once per (distance, iteration).
    all_truths = read_truths(direc)

    n_obscured = []
    f_obscured = []
    for dist in distances:
        count = 0
        tot = 0
        a_IWA = dist*IWA  # arcsec * pc -> AU
        print('a_IWA',a_IWA)
        for ii in range(n_it):
            _, _, aproj = obv_at_epoch(all_truths[ii], range(n_ep), cadence=cadence)
            for i_ep in range(n_ep):
                tot += 1
                if aproj[i_ep] < a_IWA:
                    count += 1

        n_obscured.append(count)
        # avoid ZeroDivisionError when there was nothing to examine
        f_obscured.append(count/tot if tot else float('nan'))

    return n_obscured, f_obscured


In [28]:
import matplotlib.patches as patches
import matplotlib.path as path
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import make_axes_locatable

def animatePDFs(direc, idim, i_iter, finterval, d=10, fname_base='samples_epochs', fend='.dat', slides=True,
                nbins=100, start_ep=0, n_ep=None, ymax=None, cadence=30, lims=[-1.5, 1.5]):
    """Animate one parameter's posterior histogram as epochs are added, beside the orbit.

    Frame 0 shows only the orbit panel. Frame ``i >= 1`` shows the histogram of the
    epoch ``i-1`` samples (left) and the orbit with fits for the first ``i`` epochs (right).

    Parameters
    ----------
    direc : str
        Directory holding ``fname_base + str(i_iter) + fend`` and ``truths_iters.dat``.
    idim : int
        Parameter index (0: ln a, 1: cos i, 2: e, 3: omega_p, 4: xi_0, 5: Omega).
    i_iter : int
        Iteration to animate.
    finterval : int
        Delay between frames in milliseconds (``FuncAnimation`` ``interval``).
    d : float
        Distance in pc, passed to ``show_orbit``.
    nbins : int
        Number of bins passed to ``posterior``.
    start_ep : int
        Unused; kept for call compatibility.
    n_ep : int, optional
        Number of epochs to animate; defaults to the number stored in the file.
    ymax : float, optional
        Upper y-limit of the histogram panel; defaults to the tallest bin over all
        animated epochs.
    cadence, slides, lims
        Passed to ``show_orbit``; ``lims`` also sets both limits of the orbit panel.

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    # Per-epoch samples, in the same form medians_all_iters passes to extract1dim.
    samples_arr = read_results(direc+'/'+fname_base+str(i_iter)+fend)

    truths = read_truths(direc)[i_iter]
    truth = truths[idim]

    if n_ep is None:
        n_ep = len(samples_arr)  # number of epochs/frames
    samples = extract1dim(samples_arr, idim)  # extract samples across epochs for this parameter

    xlabels = {0: 'retrieved semi-major axis (AU)', 1: r'$\cos(i)$', 2: r'$e$',
               3: r'$\omega_p$', 4: r'$\xi_0$', 5: r'$\Omega$'}
    xlabel = xlabels.get(idim, '')

    mpl.rcParams.update(inline_rc)
    plt.style.use('dark_background')
    fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))

    # Histogram of the first epoch; later frames rewrite the same vertex array.
    n, bins = posterior(samples[0], nbins)
    left = np.array(bins[:-1])
    right = np.array(bins[1:])
    bottom = np.zeros(len(left))
    nrects = len(left)

    # Each bar is one closed polygon: MOVETO, 3 x LINETO, CLOSEPOLY. The CLOSEPOLY vertex
    # is ignored but keeps the codes aligned with the vertices.
    nverts = nrects*(1 + 3 + 1)
    verts = np.zeros((nverts, 2))
    codes = np.ones(nverts, int) * path.Path.LINETO
    codes[0::5] = path.Path.MOVETO
    codes[4::5] = path.Path.CLOSEPOLY

    def _set_bars(bar_left, bar_right, bar_top):
        # Write bar corners into `verts` in place; animate() relies on these in-place
        # edits showing up in the PathPatch built on `verts` below.
        verts[0::5, 0] = bar_left
        verts[0::5, 1] = bottom
        verts[1::5, 0] = bar_left
        verts[1::5, 1] = bar_top
        verts[2::5, 0] = bar_right
        verts[2::5, 1] = bar_top
        verts[3::5, 0] = bar_right
        verts[3::5, 1] = bottom

    _set_bars(left, right, bottom + n)

    barpath = path.Path(verts, codes)
    patch = patches.PathPatch(
        barpath, facecolor='xkcd:beige', edgecolor='xkcd:beige')
    ax1.add_patch(patch)

    if ymax is None:
        # Tallest bin over every animated epoch. This is a height, not a sample value.
        ymax = max([np.max(n)] + [np.max(posterior(samples[k], nbins)[0]) for k in range(1, n_ep)])

    ax1.set_xlim(left[0], right[-1])
    ax1.set_ylim(0, ymax)
    ax1.set_xlabel(xlabel, fontsize=18)
    ax1.xaxis.set_tick_params(labelsize=16)
    ax1.set_yticks([])

    ax2.set_xlim(lims[0], lims[1])
    ax2.set_ylim(lims[0], lims[1])

    # truth line
    # NOTE(review): for idim == 0 the histogram is of the raw samples (ln a) while the
    # truth line is drawn at exp(truth) = a; confirm which space is intended.
    if idim==0:
        ax1.axvline(x=np.exp(truth), c='xkcd:green blue', lw=3, label='truth', zorder=10)
    else:
        ax1.axvline(x=truth, c='xkcd:green blue', lw=3, label='truth', zorder=10)
    ax1.legend(frameon=False, fontsize=20)

    color = colorize(range(n_ep), cmap='autumn')[0]

    def animate(i):
        if i==0:
            # first frame: orbit only, no posterior yet
            patch.set_visible(False)
            show_orbit(truths, range(n_ep), cadence, ax=ax2, d=d, showIWA=True,showaxes=False,slides=slides,
                       showplane=False,n_planets=i, lims=lims)
        else:
            patch.set_visible(True)
            n_i, bins_i = posterior(samples[i-1], nbins)
            # move the bar edges too, in case this epoch's bins differ from epoch 0's
            _set_bars(np.asarray(bins_i[:-1]), np.asarray(bins_i[1:]), bottom + n_i)
            patch.set_facecolor(color[i-1])

            show_orbit(truths, range(n_ep), cadence, ax=ax2, d=d, showIWA=True,showaxes=False,slides=slides,
                       showplane=False,n_planets=i,lims=lims)
            overplot_orbit_fits(ax2, direc, fname_base, i_iter, epochs_to_show=range(i), lw=1,
                                samples_epochs=samples_arr, epochs_tot=n_ep)
        return [patch, ]

    # NOTE(review): with blit=True only the returned patch is redrawn interactively; the
    # orbit panel drawn by show_orbit is not returned. Saving/jshtml re-renders full frames.
    ani = animation.FuncAnimation(fig, animate, frames=range(n_ep+1), interval=finterval, repeat=True, blit=True)
    plt.tight_layout()
    #plt.show()
    return ani

In [29]:
import corner
def plot_corner(direc, d, c_truth, c_quants, lw_truth=5, alpha_truth=0.5, save=False, slides=False, labelsize=22, i_epoch=0, i_iter=0,
                percentiles=(16, 50, 84)):
    """Corner plot of the MCMC samples of one epoch of one iteration, with truths and quantiles.

    Parameters
    ----------
    direc : str
        Directory holding ``d{d}chains_epochs{i_iter}.dat`` and ``truths_iters.dat``.
    d : float
        Distance in pc (part of the chain file name).
    c_truth, c_quants : color
        Colors of the truth lines and of the quantile lines.
    lw_truth, alpha_truth : float
        Line width / alpha of the truth lines.
    save : bool
        If True, save as ``corner-epoch{i_epoch}.png``.
    slides : bool
        If True, update the global ``plt.rcParams`` for a dark presentation style. These
        settings persist after the call.
    labelsize : int
        Axis label font size.
    i_epoch, i_iter : int
        Epoch and iteration to plot.
    percentiles : 3-sequence of float
        Lower / middle / upper percentiles drawn on the plot. The default 16/84 brackets
        the central 68% (1 sigma), matching the sigma==1 convention used elsewhere in this
        notebook; 32/68 would bracket only the central 36%.
    """
    if slides:
        plt.rcParams.update({
        "lines.color": "white",
        "patch.edgecolor": "white",
        "text.color": "black",
        "axes.facecolor": "white",
        "axes.edgecolor": "lightgray",
        "axes.labelcolor": "white",
        "xtick.color": "white",
        "ytick.color": "white",
        "figure.facecolor": "black",
        "figure.edgecolor": "black",
        "savefig.facecolor": "black",
        "savefig.edgecolor": "black"})

    fname = direc+'/d'+str(d)+'chains_epochs'+str(i_iter)+'.dat'
    chains_epochs = read_results(fname)[i_epoch]
    ndim = np.shape(chains_epochs)[-1]  # number of fitted parameters (last axis of the chain)

    samples = np.reshape(chains_epochs, (-1, ndim))
    truths = read_truths(direc)[i_iter]

    # quants[k] = (middle, upper - middle, middle - lower) for parameter k
    lo, mid, hi = np.percentile(samples, percentiles, axis=0)
    quants = np.column_stack((mid, hi - mid, mid - lo))

    all_labels = [r"$\ln(a)$", r"$\cos(i)$", r"$e$", r"$\omega_p$", r"$\xi_0$", r"$\Omega$"]
    labels = [all_labels[k] if k < len(all_labels) else r"$p_{%d}$" % k for k in range(ndim)]

    fig = corner.corner(samples, labels=labels,
                          #truths=truths,
                          scale_hist=False, verbose=True, hist_kwargs={'density':True},
                        label_kwargs ={'size':labelsize}
                       )

    # Extract the axes
    axes = np.array(fig.axes).reshape((ndim, ndim))

    # Loop over the diagonal
    for i in range(ndim):
        ax = axes[i, i]
        ax.axvline(truths[i], color=c_truth, lw=lw_truth, alpha=alpha_truth) # truth
        if i<3: # don't show median for bimodal angles..
            ax.axvline(quants[i, 0], color=c_quants) # median
            ax.axvline(quants[i, 0]+quants[i, 1], color=c_quants, ls='dashed') # upper percentile
            ax.axvline(quants[i, 0]-quants[i, 2], color=c_quants, ls='dashed') # lower percentile

    # Loop over the histograms
    for yi in range(ndim):
        for xi in range(yi):
            ax = axes[yi, xi]
            ax.axvline(truths[xi], color=c_truth,lw=lw_truth, alpha=alpha_truth) # truth
            ax.axhline(truths[yi], color=c_truth, lw=lw_truth, alpha=alpha_truth) # truth
            ax.plot(truths[xi], truths[yi], marker='s', color=c_truth) # truth
            ax.axvline(quants[xi, 0], color=c_quants) # median
            ax.axhline(quants[yi, 0], color=c_quants) # median
            ax.plot(quants[xi, 0], quants[yi, 0], marker='s', color=c_quants)

    if save:
        fig.savefig("corner-epoch"+str(i_epoch)+".png")

In [30]:
def verifylp(theta_ML, direc, d, i_iter, n_ep, cadence, s_mas=5, wl=500e-9, D_tel=10, ndims=6):
    """Print the log-probability, log-likelihood and log-prior of ``theta_ML`` for one iteration.

    The three functions ``lnprob_Beta``, ``lnlike`` and ``lnprior_Beta`` from
    ``mc_orbit_good`` are re-evaluated on the first ``n_ep`` stored observations of
    iteration ``i_iter`` in ``direc``. Epochs are split into detections and non-detections
    by the inner working angle.

    Parameters
    ----------
    theta_ML : sequence
        Parameter vector to evaluate.
    direc : str
        Directory containing ``obvs_iters.dat``.
    d : float
        Distance in pc.
    i_iter : int
        Iteration whose observations are used.
    n_ep : int
        Number of leading epochs to include.
    cadence :
        Passed through to the likelihood functions.
    s_mas : float
        Astrometric error in mas, converted to AU at distance ``d``.
    wl, D_tel : float
        Wavelength and telescope diameter defining IWA = 3 wl / D_tel (in arcsec).
    ndims : int
        Unused.
    """
    from mc_orbit_good import lnprob_Beta, lnlike, lnprior_Beta

    iwa_arcsec = 3 * wl/D_tel * 206265
    s = d*(s_mas*1e-3)  # convert mas error to AU
    a_IWA = d*iwa_arcsec

    x, y, aproj = read_obvs(direc)[i_iter]

    # epochs at or inside the IWA are non-detections; the rest are detections
    i_nondet = [k for k in range(n_ep) if aproj[k] <= a_IWA]
    i_det = [k for k in range(n_ep) if not aproj[k] <= a_IWA]

    def _call_args():
        # fresh slices of the first n_ep observations for each evaluation
        return (theta_ML, x[:n_ep], y[:n_ep], aproj[:n_ep], s,
                cadence, n_ep, i_det, i_nondet, a_IWA)

    lnprob_val = lnprob_Beta(*_call_args())
    lnlike_val = lnlike(*_call_args())
    lnprior_val = lnprior_Beta(theta_ML)

    print('\nlnprob', lnprob_val)
    print('     lnlike', lnlike_val)
    print('     lnprior', lnprior_val)

In [31]:
import collections

def plot_chain(direc, i_it, i_ep, idim, d, i_walker, steps=None, title='', sigma=1, showML=False,
               percentiles=None, lw=0.1, alpha=0.9, nwalkers=None, showIWA=True, showfit=True, 
               nsteps=None, ylim=None, fig=None, ax=None, n_ep=5):
    """Trace plot of MCMC walker(s) for one parameter, with truth, fit percentiles and IWA band.

    Parameters
    ----------
    direc : str
        Directory with ``d{d}chains_epochs{i_it}.dat`` (and ``..._lnprob.dat`` for ``showML``).
    i_it, i_ep : int
        Iteration and epoch to plot.
    idim : int
        Parameter index (0: ln a, 1: cos i, ...).
    d : float
        Distance in pc (used for the IWA band).
    i_walker : int or iterable of int
        Walker(s) to draw. Each walker of an iterable is drawn in black; a single int is
        drawn in light grey.
    steps : slice, optional
        Steps to plot and to use for the fit. The default keeps every step.
    sigma : {1, 2}
        Width of the fit band when ``percentiles`` is None.
    percentiles : 3-sequence, optional
        Explicit (lower, middle, upper) percentiles for the fit band.
    nwalkers : int, optional
        Only used to unravel a flat ln-probability array when ``showML`` is set.
    nsteps, n_ep : unused on input; kept for call compatibility.
    fig, ax : optional
        Existing figure/axes. A new 16x4 figure is created when both are None.

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError
        If ``showfit`` is set, ``percentiles`` is None and ``sigma`` is not 1 or 2.
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable

    ndim = 6
    print('i_ep', i_ep)

    labelsize = 16
    if steps is None:
        steps = slice(None)  # every step (slice(-1) would drop the final one)

    if (fig is None) and (ax is None):
        fig, ax = plt.subplots(1,1, figsize=(16, 4))
    fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'.dat'
    chains_epochs = read_results(fname)[i_ep]  # indexed below as [walker][step][param]

    if isinstance(i_walker, Iterable):
        for k in i_walker:
            samples = np.swapaxes(chains_epochs[k], 0, 1)[idim][steps]
            nsteps = len(samples)
            ax.plot(range(nsteps), samples, c='xkcd:black', lw=lw, alpha=alpha)
    else:
        chain_epochs = chains_epochs[i_walker]
        print('chain_epochs', np.shape(chain_epochs))
        samples = np.swapaxes(chain_epochs, 0, 1)[idim]
        print('nsteps available:', len(samples))
        samples = samples[steps]
        nsteps = len(samples)
        ax.plot(range(nsteps), samples, c='xkcd:light grey', lw=lw, alpha=alpha)

    print('nsteps', nsteps)
    print('samples 1dim', np.shape(samples) )
    ax.set_xlabel("Step", fontsize=labelsize)
    if idim==0:
        label='ln(a)'
    elif idim==1:
        label='cos(i)'
    else:
        label=get_label(idim, form=None)
    ax.set_ylabel(label, fontsize=labelsize)

    try:
        truth = read_truths(direc)[i_it][idim]
        ax.axhline(y=truth, c='xkcd:mint', lw=3, label='truth', zorder=20)
    except IndexError:
        print('lost truths')

    ax.set_title(title, fontsize=labelsize)
    ax.set_xlim(0, nsteps)

    # show retrievals
    if showfit:
        if percentiles is not None:
            percentile = percentiles
        elif sigma==1:
            percentile = [16, 50, 84]
        elif sigma==2:
            percentile = [2.5, 50, 97.5]
        else:
            raise ValueError('sigma must be 1 or 2 when percentiles is not given')

        flatchain = chains_epochs[:, steps, :].reshape((-1, ndim))
        # NOTE(review): flatchain is walker-major (all steps of walker 0 come first), so
        # dropping the first 1000 rows removes early steps of the first walker(s) only,
        # not a burn-in phase of every walker. chains_epochs[:, burn:, :] would do that.
        samples_1dim = flatchain[1000:,idim] # remove burn in
        fit = np.percentile(samples_1dim, q=percentile, axis=0)

        ax.axhline(y=fit[2], lw=1, c='xkcd:magenta', ls='--', label= '%d sigma upper = ln(%05.3f)' % (sigma, np.exp(fit[2])))
        ax.axhline(y=fit[1], lw=1, c='xkcd:magenta', label='median = ln(%07.5f)' % np.exp(fit[1]))
        ax.axhline(y=fit[0], lw=1, c='xkcd:magenta', ls='--', label= '%d sigma lower = ln(%05.3f)' % (sigma, np.exp(fit[0])))

        print('1-sigma half-width in linear-space=',(np.exp(fit[2])-np.exp(fit[0]))/2)

    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    ymin, ymax = ax.get_ylim()

    if showIWA:
        IWA = 3 * wl/D_tel * 206265 # arcsec; wl and D_tel are notebook-level globals
        lna_IWA = np.log(d*IWA)
        ax.fill_between(range(nsteps), y1=ymin, y2=lna_IWA, edgecolor='xkcd:poo', facecolor=None, alpha=0,
                        label='IWA', hatch='/')

    if showML:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'_lnprob.dat'
        lnprob = read_results(fname)[i_ep]
        # Unravel against the stored array's own layout: `nsteps` is the number of
        # *plotted* steps and `nwalkers` defaults to None, so neither is reliable here.
        if np.ndim(lnprob) == 2:
            prob_shape = np.shape(lnprob)
        else:
            prob_shape = (nwalkers, np.size(lnprob) // nwalkers)
        pos_ML = np.unravel_index(np.argmax(lnprob), prob_shape)
        print('best step', pos_ML[1])
        print('@ chain',pos_ML[0])
        ax.axvline(x=pos_ML[1], label='i_ML')

    ax.xaxis.set_tick_params(labelsize=14)
    ax.yaxis.set_tick_params(labelsize=14)

    ax.legend(fontsize=labelsize-2, frameon=False, ncol=2)

    return fig, ax

In [32]:
def plot_chain_dist(direc, i_it, i_ep, d, idim, n_ep, nbins=100, fig=None, ax=None, ylim=[-1, 1],
                   showfit=False, sigma=1,  percentiles=None, showML=False, cadence=None):
    """Horizontal histogram of one parameter's samples for one epoch, with diagnostics.

    The function reads the flattened chain ``d{d}chains_epochs{i_it}flat.dat``. If that file
    is missing, it flattens ``d{d}chains_epochs{i_it}.dat`` instead. It then draws on a
    shared vertical axis:
      * ln(aproj) for epochs 0..i_ep;
      * optionally, the stored fit (``showfit``);
      * optionally, maximum-likelihood diagnostics (``showML``). These also print
        log-probabilities via ``verifylp``.

    Parameters
    ----------
    direc : str
        Result directory.
    i_it, i_ep : int
        Iteration and epoch.
    d : float
        Distance in pc (part of the file names).
    idim : int
        Parameter index whose samples are histogrammed.
    n_ep : int
        Number of epochs passed to ``read_retrievals``.
    nbins : int
        Number of histogram bins.
    fig, ax : optional
        Axes to draw on. A new figure is created when ``ax`` is None.
    ylim : 2-sequence or None
        Vertical limits.
    showfit : bool
        Draw the stored fit percentiles read with ``read_retrievals`` at ``sigma``.
    percentiles : unused; kept for call compatibility.
    showML : bool
        Draw and print maximum-likelihood diagnostics.
    cadence :
        Passed to ``verifylp``.

    Returns
    -------
    fig, ax
        ``fig`` is None when ``ax`` was supplied without ``fig``.
    """
    ndim = 6

    print('truths', read_truths(direc)[i_it])

    try:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'flat.dat'
        flatchains = read_results(fname)
        print('reading flat chains; shape', np.shape(flatchains))
        flatchain = flatchains[i_ep]
    except FileNotFoundError:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'.dat'
        chains = read_results(fname)[i_ep]
        flatchain = np.array(chains.reshape((-1, ndim)))
        print('reading unflattened chain')

    a = np.swapaxes(flatchain, 0, 1)[idim]

    # Create the axes before anything is drawn on them.
    if ax is None:
        fig, ax = plt.subplots()

    # Projected separations of the epochs up to and including i_ep.
    # NOTE(review): ln(aproj) lines are only meaningful on the ln(a) axis (idim == 0).
    color = colorize(range(i_ep+1), cmap='Greys', vmin=-2)[0]
    aproj_all = read_obvs(direc)[i_it][2]  # read the file once, not once per epoch
    for ep in range(i_ep+1):
        ax.axhline(np.log(aproj_all[ep]), c=color[ep], zorder=0, lw=0.5, label=r'$a_{\rm proj}$ '+str(ep)) # show aproj

    # show retrievals
    fit = None
    if showfit:
        y_data_dims = read_retrievals(direc, n_ep, 1, ndim, d, iter_start=i_it, sigma=sigma)[0] # all dims
        y_data = y_data_dims[idim][i_ep]

        # three stored values; the middle one is drawn solid, the outer ones dashed
        fit = np.reshape(y_data, (3,))  # 1-D so each fit[k] is a scalar for axhline
        if idim==0: # to match with samples
            fit = np.log(fit)
        elif idim==1:
            fit = np.cos(fit)

        ax.axhline(y=fit[2], lw=1, c='xkcd:magenta', ls='--')
        ax.axhline(y=fit[1], lw=1, c='xkcd:magenta')
        ax.axhline(y=fit[0], lw=1, c='xkcd:magenta', ls='--')

    if showML:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'_lnprob.dat'
        lnprob = read_results(fname)[i_ep]
        flatprob = lnprob.flatten()

        print('\n overall ML')
        ML_theta, ML_val, i_ML = maxlikelihood(flatchain, flatprob, steps=None, returnmore=True)
        print('ML value', ML_val, '@', i_ML)
        ax.axhline(y=ML_theta[idim], label='ML', lw=1, c='xkcd:goldenrod')
        print('ML theta', ML_theta)

        verifylp(ML_theta, direc, d, i_it, i_ep+1, cadence)

        # get second peak: the best sample on the other side of the split point
        chain_dim = np.swapaxes(flatchain, 0, 1)[idim]
        # split at the stored middle value when available, otherwise at the sample median
        split = fit[1] if fit is not None else np.median(chain_dim)

        print('\nother peak')
        if ML_theta[idim] > split:
            i_subpeak = np.where(chain_dim < split)
        else:
            i_subpeak = np.where(chain_dim > split)
        flatprob_subpeak = flatprob[i_subpeak]
        i_ML_subpeak = np.argmax(flatprob_subpeak)
        ML_val_subpeak = flatprob_subpeak[i_ML_subpeak]
        print('ML value', ML_val_subpeak)

        # full orbit of this peak?
        print('flatchain subpeak', np.shape(flatchain[i_subpeak]))
        ML_theta_sub = maxlikelihood(flatchain[i_subpeak], flatprob_subpeak)
        verifylp(ML_theta_sub, direc, d, i_it, i_ep+1, cadence)
        print('ML theta sub', ML_theta_sub)

        print('\nread')
        fname = direc+'/d'+str(d)+'ML'+str(i_it)+'.dat'
        ML_read = read_results(fname)[i_ep]
        # map dims 0/1 into the sampled space (ln a, cos i), as done for the fit above
        ML_read[0] = np.log(ML_read[0])
        ML_read[1] = np.cos(ML_read[1])
        print('ML theta read', ML_read)
        ax.axhline(y=ML_read[idim], label='ML read', lw=1, c='xkcd:red')

        verifylp(ML_read, direc, d, i_it, i_ep+1, cadence)

    ax.hist(a, nbins, density=True, histtype='step', orientation='horizontal', color='k')
    ax.set_ylabel(r'$\ln (a)$', fontsize=18)  # NOTE(review): label assumes idim == 0
    ax.xaxis.set_tick_params(labelsize=14)
    ax.yaxis.set_tick_params(labelsize=14)
    ax.yaxis.tick_right()
    ax.xaxis.tick_top()
    ax.yaxis.set_label_position("right")
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    ax.legend(fontsize=14, frameon=False)

    return fig, ax

In [33]:
def _peak_samples(flatchain, chain_dim, center, tol, lnprior, verbose=False):
    """Collect theta[0] and prior values for samples near one posterior peak.

    A sample j is selected when ``abs(chain_dim[j] - center) < tol``.

    Returns
    -------
    lna_vals : list
        theta[0] of each selected sample.
    lp_vals : list
        ``lnprior(theta)`` of each selected sample.
    """
    idx = np.where(abs(chain_dim - center) < tol)[0]
    if verbose:
        print('iii', idx)
    lna_vals, lp_vals = [], []
    for j in idx:
        theta = flatchain[j]
        if verbose:
            print('iiii', j)
            print('the', theta)
        lna_vals.append(theta[0])
        lp_vals.append(lnprior(theta))
    return lna_vals, lp_vals


def comparepriors(direc, i_it, i_ep, d, idim, n_ep, nbins=100, fig=None, ax=None, ylim=[-1, 1],
                   showfit=False, sigma=1,  percentiles=None, showML=False, cadence=None):
    """Diagnostic plot comparing the prior values of samples in the two posterior peaks.

    Reads the MCMC chain for iteration `i_it`, epoch `i_ep`. The flat chain file is used when
    present; otherwise the unflattened chain is reshaped to (-1, 6). Optionally overplots the
    retrieved (lower, median, upper) values.

    With `showML`, it locates the global maximum-likelihood sample and the best sample on the
    other side of the fitted median (the "sub-peak"). For samples within 0.3 of each peak, in
    parameter `idim`, it evaluates ``lnprior_Beta`` and histograms their theta[0] values
    (red: sub-peak, blue: main peak).

    Parameters
    ----------
    direc : str
        Results directory (relative to the working directory).
    i_it, i_ep : int
        Iteration index and 0-based epoch index.
    d : number
        Distance tag used in the file names ('d<d>...').
    idim : int
        Parameter index. For idim 0 / 1 the retrieved values are passed through log / cos
        to match the chain's parameterization.
    n_ep : int
        Number of epochs passed to ``read_retrievals``.
    nbins : int
        Histogram bins.
    fig, ax : matplotlib Figure / Axes, optional
        Axes to draw on. A new figure is created when `ax` is None.
    ylim : sequence of two floats or None
        y-limits of the axes.
    showfit : bool
        Overplot the retrieved lower/median/upper values in magenta.
    sigma : int
        Forwarded to ``read_retrievals`` (1 or 2 sigma file).
    percentiles : ignored
        Kept for interface compatibility; it never influenced the plot.
    showML : bool
        Run the peak/prior comparison. Requires ``showfit=True``, because the fitted median
        is what separates the two peaks.
    cadence :
        Forwarded to ``verifylp``.

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError
        If ``showML`` is requested without ``showfit``.
    """
    if showML and not showfit:
        raise ValueError('showML=True requires showfit=True (the fitted median '
                         'separates the two posterior peaks)')

    ndim = 6

    try:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'flat.dat'
        flatchains = read_results(fname)
        print('reading flat chains; shape', np.shape(flatchains))
        flatchain = np.asarray(flatchains[i_ep])
    except FileNotFoundError:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'.dat'
        chains = read_results(fname)[i_ep]
        flatchain = np.array(chains.reshape((-1, ndim)))
        print('reading unflattened chain')

    if ax is None:
        fig, ax = plt.subplots()

    # retrieved (lower, median, upper) for this parameter / epoch
    fit = None
    if showfit:
        y_data_dims = read_retrievals(direc, n_ep, 1, ndim, d, iter_start=i_it, sigma=sigma)[0]  # all dims
        fit = np.reshape(y_data_dims[idim][i_ep], 3)
        if idim == 0:    # to match with samples
            fit = np.log(fit)
        elif idim == 1:
            fit = np.cos(fit)
        ax.axhline(y=fit[2], lw=1, c='xkcd:magenta', ls='--')
        ax.axhline(y=fit[1], lw=1, c='xkcd:magenta')
        ax.axhline(y=fit[0], lw=1, c='xkcd:magenta', ls='--')

    if showML:
        fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'_lnprob.dat'
        lnprob = read_results(fname)[i_ep]
        flatprob = lnprob.flatten()

        print('\n overall ML')
        ML_theta, ML_val, i_ML = maxlikelihood(flatchain, flatprob, steps=None, returnmore=True)
        print('ML value', ML_val, '@', i_ML)
        ax.axhline(y=ML_theta[idim], label='ML', lw=1, c='xkcd:goldenrod')
        print('ML theta', ML_theta)
        verifylp(ML_theta, direc, d, i_it, i_ep+1, cadence)

        # The fitted median splits the samples into the ML peak and the "other" peak.
        chain_dim = np.swapaxes(flatchain, 0, 1)[idim]
        print('\nother peak')
        if ML_theta[idim] > fit[1]:
            i_subpeak = np.where(chain_dim < fit[1])
        else:
            # An ML value exactly on the median (previously an UnboundLocalError)
            # is treated like one lying below it.
            i_subpeak = np.where(chain_dim > fit[1])
        flatprob_subpeak = flatprob[i_subpeak]

        ML_theta_sub = None
        if flatprob_subpeak.size == 0:
            print('no samples on the other side of the fitted median; skipping sub-peak')
        else:
            print('ML value', np.max(flatprob_subpeak))
            # full orbit of this peak
            print('flatchain subpeak', np.shape(flatchain[i_subpeak]))
            ML_theta_sub = maxlikelihood(flatchain[i_subpeak], flatprob_subpeak)
            verifylp(ML_theta_sub, direc, d, i_it, i_ep+1, cadence)
            print('ML theta sub', ML_theta_sub)

        print('\nread')
        fname = direc+'/d'+str(d)+'ML'+str(i_it)+'.dat'
        ML_read = read_results(fname)[i_ep]
        ML_read[0] = np.log(ML_read[0])   # stored as a; chain uses ln(a)
        ML_read[1] = np.cos(ML_read[1])   # stored as i; chain uses cos(i)
        print('ML theta read', ML_read)
        ax.axhline(y=ML_read[0], label='ML read', lw=1, c='xkcd:red')
        verifylp(ML_read, direc, d, i_it, i_ep+1, cadence)

        # element-wise search; the transformed ML values need not appear in the chain exactly
        matches2 = np.argwhere(flatchain == ML_read)
        if len(matches2) > 0:
            print('matches2', matches2[0])
            print('flatchain[matches2]', flatchain[matches2[0]])
        else:
            print('matches2: no element of ML read found in flatchain')

        # prior values of the samples around each peak
        from mc_orbit_good import lnprior_Beta
        tol = 0.3
        if ML_theta_sub is not None:
            lna_sub, lp_val_sub = _peak_samples(flatchain, chain_dim, ML_theta_sub[idim], tol,
                                                lnprior_Beta, verbose=True)
            if lp_val_sub:
                print('lp_val_sub max', np.max(lp_val_sub))
            if lna_sub:
                ax.hist(lna_sub, nbins, density=True, histtype='step', orientation='horizontal', color='r')

        lna, lp_val = _peak_samples(flatchain, chain_dim, ML_theta[idim], tol, lnprior_Beta)
        if lp_val:
            print('lp_val max', np.max(lp_val))
        if lna:
            ax.hist(lna, nbins, density=True, histtype='step', orientation='horizontal', color='b')

    ax.set_ylabel(r'$\ln (a)$', fontsize=18)
    ax.xaxis.set_tick_params(labelsize=14)
    ax.yaxis.set_tick_params(labelsize=14)
    ax.yaxis.tick_right()
    ax.xaxis.tick_top()
    ax.yaxis.set_label_position("right")
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    ax.legend(fontsize=14, frameon=False)
    return fig, ax

In [34]:
def plot_lnprob(direc, i_it, i_ep, idim, d, i_walker, steps=None, title='', showML=False, 
                nwalkers=None, nsteps=None, lw=0.1, alpha=0.9, ndim=6):
    """Plot the ln-probability trace of one or more walkers for iteration `i_it`, epoch `i_ep`.

    Reads 'd<d>chains_epochs<i_it>_lnprob.dat' and the matching chain file. Draws the walker
    trace(s) against step number, plus a horizontal line at their maximum. With `showML`,
    also draws a vertical line at the step holding the global maximum over all walkers.

    Parameters
    ----------
    direc : str
        Results directory.
    i_it, i_ep : int
        Iteration index and 0-based epoch index.
    idim : int
        Unused; kept for interface compatibility.
    d : number
        Distance tag used in the file names.
    i_walker : int or sequence of int
        Walker(s) to plot.
    steps : slice, optional
        Steps to plot. Defaults to all steps.
    title : str
        Axes title.
    showML : bool
        Mark the step of the global maximum ln-probability.
    nwalkers, nsteps :
        Unused. The step count is taken from the lnprob file; kept for compatibility.
    lw, alpha : float
        Line style of the traces.
    ndim : int
        Number of parameters per sample, used to flatten the chain.

    Returns
    -------
    fig, ax, max_prob
        `max_prob` is whatever ``maxlikelihood(flatchain, flatprob)`` returns.
    """
    labelsize = 16
    if steps is None:
        # Every step. (The previous default slice(0, -1) silently dropped the last step.)
        steps = slice(None)
    fig, ax = plt.subplots(1, 1, figsize=(16, 4))

    fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'_lnprob.dat'
    lnprob = np.asarray(read_results(fname)[i_ep])
    # Axis 0 of lnprob is the step (its length is the step count), axis 1 the walker.
    nsteps = len(lnprob)
    walker_prob = lnprob[:, i_walker]   # (nsteps,) for one walker, (nsteps, k) for several

    fname = direc+'/d'+str(d)+'chains_epochs'+str(i_it)+'.dat'
    chains = read_results(fname)[i_ep]
    print('nsteps available:', len(lnprob))

    # NOTE(review): assumes `chains` flattens in the same (step, walker) order as lnprob.
    flatchain = chains.reshape((-1, ndim))
    flatprob = lnprob.flatten()
    max_prob = maxlikelihood(flatchain, flatprob)

    ax.axhline(y=np.max(walker_prob), c='xkcd:teal', lw=2, label='max likelihood')

    stepvec = np.arange(nsteps)[steps]
    ax.plot(stepvec, walker_prob[steps], c='xkcd:black', lw=lw, alpha=alpha)

    ax.set_xlabel("Step", fontsize=labelsize)
    ax.set_ylabel('ln prob', fontsize=labelsize)
    ax.set_title(title, fontsize=labelsize)
    ax.set_xlim(stepvec[0], stepvec[-1])

    if showML:
        i_ML = np.argmax(flatprob)
        # flatten() is row-major over lnprob's (step, walker) shape
        step_ML, _ = np.unravel_index(i_ML, lnprob.shape)
        print(step_ML)
        ax.axvline(x=step_ML, label='i_ML', ls='--', alpha=0.5)

    ax.legend(fontsize=labelsize-2)
    # NOTE(review): a log y-axis can only display positive ln-prob values.
    ax.set_yscale('log')
    return fig, ax, max_prob

In [35]:
def get_zscore(direc, n_it, i_ep, idim, d, maxprob=True, iter_start=0, y_data=None, n_epochs=5, sigma=1,
              i_hide=[]):
    """Per-iteration z-scores of parameter `idim` at epoch `i_ep`, with their mean and std.

    For every iteration i in range(n_it) that is not listed in `i_hide`:

        z_i = (truth_i - fit_i) / sig_i

    `fit_i` is the maximum-likelihood value read from 'd<d>ML<i>.dat' when ``maxprob`` is
    set; otherwise it is the retrieved median.

    `sig_i` is the one-sided interval width on the side of the truth: ``fit - lower`` when
    the fit overshoots, ``upper - fit`` when it undershoots. A fit exactly equal to the
    truth gives z = 0, and a NaN fit or truth gives NaN.

    Truths for idim 0 / 1 are converted with exp / arccos before comparison. For idim 3 or 5,
    pi is added to the fit when the truth exceeds pi.

    Parameters
    ----------
    direc : str
        Results directory.
    n_it : int
        Number of iterations.
    i_ep : int
        0-based epoch index.
    idim : int
        Parameter index.
    d : number
        Distance tag used in file names.
    maxprob : bool
        Use ML values (True) or retrieved medians (False) as the fit.
    iter_start : int
        Forwarded to ``read_retrievals``.
    y_data : optional
        Pre-read retrievals, indexed [iteration][idim][epoch] -> (lower, median, upper).
    n_epochs : int
        Number of epochs passed to ``read_retrievals``.
    sigma : int
        Forwarded to ``read_retrievals``.
    i_hide : iterable of int
        Iterations to skip.

    Returns
    -------
    zscores : list
    zscore_mean, zscore_std : float
    """
    ndims = 6
    truths_iters = read_truths(direc)

    if y_data is None:
        y_read = read_retrievals(direc, n_epochs, n_it, ndims, d, iter_start=iter_start, sigma=sigma) # (n_iter, ndims, n_epoch, 3)
    else:
        y_read = y_data

    hidden = set(i_hide)
    # NOTE(review): y_read, the truths and the ML files are all indexed by the loop index,
    # while read_retrievals starts at iter_start. These line up only for iter_start=0; confirm.
    i_good = [i for i in range(n_it) if i not in hidden]

    zscores = []
    for i_iter in i_good:
        truth = truths_iters[i_iter][idim]
        interval = y_read[i_iter][idim][i_ep]   # (lower, median, upper)

        if maxprob:
            fname = direc+'/d'+str(d)+'ML'+str(i_iter)+'.dat'
            fit = read_results(fname)[i_ep][idim]
        else:
            # retrieved median; already re-parameterized
            # (the old np.swapaxes on this scalar always raised)
            fit = interval[1]

        if idim == 0:
            truth = np.exp(truth)
        elif idim == 1:
            truth = np.arccos(truth)
        elif ((idim == 3) or (idim == 5)) and (truth > np.pi):
            print('adding pi to fit')
            fit = fit + np.pi

        lower = interval[0]
        upper = interval[2]
        if fit > truth:
            sig = fit - lower
        elif fit < truth:
            sig = upper - fit
        else:
            # Exact hit -> z = 0. NaN fit or truth -> NaN.
            # Previously a stale `sig` from the prior iteration was reused here.
            zscores.append(0.0 if fit == truth else np.nan)
            continue

        zscores.append((truth - fit) / sig)

    zscore_mean = np.mean(zscores)
    zscore_std = np.std(zscores)
    print('n_it=', len(zscores))

    return zscores, zscore_mean, zscore_std
    

In [36]:
# make awesome diagnostic plot
from matplotlib import gridspec

def diagnostic_plots(direc, i_it, ecc_dist, d, cadence, n_it_tot, save=False, maxprob=False, showML=True,
                     nwalkers=30, nsteps=1e5):
    """Produce the set of diagnostic figures for one iteration `i_it` of a run in `direc`.

    Figures, in order:
      1. The true orbit (``show_orbit``), with the fitted orbits for epoch indices 2-4
         overplotted (``overplot_orbit_fits``).
      2. One small figure per epoch with the empirical semi-major-axis posterior. It is read
         from 'a_posterior_ep<i>_<i_it>.dat', or from the shared 'a_posterior_nd0.dat' when
         the projected separation is not outside `a_IWA`. The loop runs while the next file
         exists and fewer than 5 epochs have been shown.
      3. Chain traces and marginal histograms of parameter 0 for epoch indices 2, 3, 4
         (``plot_chain`` / ``plot_chain_dist``). Errors here are printed, not raised.
      4. Precision vs. epoch (``plot_diminishing_returns``).
      5. Accuracy vs. epoch (``plot_scatter_epochs``).
    Finally it calls ``read_retrievals(..., prints=True)`` for this iteration and discards
    the return value.

    Parameters
    ----------
    direc : str
        Results directory, relative to the module-level working directory `cwd`.
    i_it : int
        Iteration to plot.
    ecc_dist : unused
        NOTE(review): only referenced in commented-out code.
    d : number
        Distance tag used in file names; also multiplies IWA to give `a_IWA`.
    cadence :
        Forwarded to the orbit/chain plotting helpers.
    n_it_tot : int
        Number of iterations read by the first ``read_retrievals`` call.
    save : unused
    maxprob : bool
        Forwarded to ``plot_scatter_epochs``.
    showML : bool
        Forwarded to ``plot_chain_dist``.
    nwalkers : int
        Walkers per chain; doubled for epoch index 2.
    nsteps : number
        Only used to build `steps`, which is not passed anywhere.

    Returns
    -------
    None
    """
    
    truths = read_truths(direc)[i_it]
    # NOTE(review): `obvs` is never used; a_proj is re-read from the same file further down.
    obvs = read_obvs(direc)[i_it]
    n_ep=5
    ndims=6
    idim=0
    

    y_data_dims = read_retrievals(direc, n_ep, n_it_tot, ndims, d, iter_start=0, sigma=1, prints=True) # (n_iter, ndims, n_epoch, 3)
    print('\n\n', np.shape(y_data_dims), '\n\n')
    # retrievals for parameter `idim` across all iterations
    y_data_thisdim = np.swapaxes(y_data_dims, 0, 1)[idim]

    ####### ORBIT
    lna, cosi, ecc, omega_p, xi_0, lan = truths
    fig, ax = show_orbit(truths=[lna, cosi, ecc, omega_p, xi_0, lan], epochs_to_show=range(n_ep), 
                              cadence=cadence, d=d, s_mas=5, 
                              showplane=False,#lims=[-3, 3],
                              lims=[-1.7, 1.7],
                               showIWA=True, savefig=False, showaxes=False, cmap='autumn')
    fig, ax = overplot_orbit_fits(fig, ax, direc, showerr=True, epochs_tot=n_ep, i_iter=i_it, d=d, 
                                       epochs_to_show=range(2, n_ep), lw=1, cmap='autumn', 
                                       maxprob=True, cadence=cadence, showobvs=True)
    

    #########################################################
    ########### empirical posteriors
    #########################################################
    
    _, _, a_proj = read_obvs(direc)[i_it]
    # inner-working-angle limit in the same units as a_proj (d times IWA in arcsec)
    a_IWA = d*IWA
    print('a_proj', a_proj)
    
    i_ep=0
    # Detections use a per-epoch file; non-detections use one shared file.
    if a_proj[0] > a_IWA:
        fname = (str(cwd)+'/'+direc+'/d'+str(d)+'a_posterior_ep0_'+str(i_it)+'.dat')
        text=''
    else:
        fname = (str(cwd)+'/'+direc+'/d'+str(d)+'a_posterior_nd0.dat')
        text=' (nondetection)'
    while (os.path.isfile(fname)) and (i_ep<n_ep):   
        print('reading', fname)
        fig, ax = plt.subplots(1, 1, figsize=(2,2))
        sma_list = read_results(fname)
        n, bins = np.histogram(sma_list, bins=50, density=True, range=(0, 5))
        mids = 0.5*(bins[1:] + bins[:-1])
        if i_ep>0:
            # combine this epoch's histogram with the running `pdf` from earlier epochs
            # (presumably their product; see multiply_pdfs)
            pdf, quants = multiply_pdfs(n, pdf, bins)
        else:
            pdf=n
            quants = np.percentile(sma_list, [16, 50, 84])
        ax.plot(mids, pdf, 'k-')
        ax.set_xlabel('semi-major axis')
        ax.set_title('Epoch '+str(i_ep)+text)
        lower, fit, upper = quants
        ax.axvline(x=lower, color='xkcd:magenta', ls='--')
        ax.axvline(x=upper, color='xkcd:magenta', ls='--')
        ax.axvline(x=fit, color='xkcd:magenta', label='median')
        ax.axvline(x=np.exp(lna), color='xkcd:aquamarine', label='truth')
        # p: half-width of the (lower, upper) interval; a: median minus true a
        ax.text(4, 1, 'p='+str((upper-lower)/2), ha='right')
        ax.text(4, 0.5, 'a='+str(fit-np.exp(lna)), ha='right')
        ax.legend(frameon=False)
        i_ep = i_ep+1
        if i_ep==n_ep:
            break
        # pick the file for the next epoch
        if a_proj[i_ep] > a_IWA:
            fname = (str(cwd)+'/'+direc+'/d'+str(d)+'a_posterior_ep'+str(i_ep)+'_'+str(i_it)+'.dat')
            text=''
        else:
            fname = (str(cwd)+'/'+direc+'/d'+str(d)+'a_posterior_nd0.dat')
            text=' (nondetection)'

    #########################################################    
    ########################### chains
    #########################################################
    
    # Best-effort: any failure while plotting chains is printed and plotting continues.
    try:
        eps = [2, 3, 4]
        for i_ep in eps:

            fig = plt.figure(figsize=(16, 4)) 
            gs = gridspec.GridSpec(1, 2, width_ratios=[7, 1]) 
            ax0 = plt.subplot(gs[0])
            ax1 = plt.subplot(gs[1])

            # epoch index 2 gets twice as many walkers as later epochs
            # (presumably mirrors how the sampler was configured - confirm)
            if i_ep >= 3:
                k = nwalkers
            else:
                k=2*nwalkers
            # NOTE(review): `steps` is unused (plot_chain is called with steps=None)
            steps = slice(0, int(nsteps))


            fig, ax0 = plot_chain(direc, i_ep=i_ep, i_it=i_it, i_walker=range(k), idim=0, 
                                       d=d, showML=False, nwalkers=k,  ylim=[-2, 3],
                                       steps=None, showIWA=True, showfit=True,
                                       alpha=0.5, lw=0.5,
                                       title=str(i_ep+1)+' Epoch', fig=fig, ax=ax0)
            fig, ax1 = plot_chain_dist(direc, i_it, i_ep, d=d, idim=0, n_ep=5, nbins=25, ylim=[-2, 3], 
                                            fig=fig, ax=ax1, showML=showML, cadence=cadence,
                                           showfit=True)
            plt.tight_layout()
    except Exception as e:
        print(e)
        print('(chains lost!?)')


    ############# precision
    # Iterations before i_it are hidden so only this iteration is shown.
    if i_it==0:
        i_hide=[]
    else:
        i_hide = np.arange(i_it)
    fig, ax = plot_diminishing_returns(y_data_thisdim, n_ep, i_it+1, idim=0, s=5, d=d, direc=direc,
                    alpha=1, lw=0.5, showguide=False,
                                    iter_start=0, read=True,
                                ylim=[0, 1],
                                legtitle=[],
                                    n_datasets=1,
                                    showlegend=False,
                            i_hide=i_hide,
                        figsize=(2.5,4),                             
                    )
    
    ########## accuracy
    fig, ax = plot_scatter_epochs(y_data_thisdim, n_ep, i_it+1, idim, truths, d=d, direc=direc,
                alpha=1, read=True,  n_datasets=1, maxprob=maxprob,
                                    showlegend=False,
                            i_hide=i_hide, ylim=[-1, 1],
                        figsize=(2.5,4),                             
                    )
    
    ########print 
#     precision(direc, y_data_thisdim, ecc_dist=ecc_dist, n_it=i_it+1, d=d, aproj_range=None, ML=False, save=False,
#                   plot=False, it_start=i_it, s_mas=5, i_hide=[], n_det_needed=3)
    
#     retrieve_fits(n_iters, n_epochs, idim, direc, fname_base, d=d, ML=False, fend='.dat', steps=None,
#                       results=None, s_percentile=[16, 50, 84], it_start=0)
    # return value discarded; called with prints=True for its console output
    read_retrievals(direc, n_ep, n_it=i_it+1, ndims=ndims, d=d, iter_start=i_it, sigma=1, prints=True)

In [90]:
def _read_prec3(direc, i_it, distances=(5, 10, 20)):
    """Return element [2] of the first readable 'd<dist>prec<i_it>.dat' in `direc`.

    Distances are tried in order. If none can be read, the error from the last
    attempt is re-raised.
    """
    last_err = None
    for dist in distances:
        fname = direc+'/d'+str(dist)+'prec'+str(i_it)+'.dat'
        try:
            return read_results(fname)[2]
        except (OSError, EOFError, IndexError, pickle.UnpicklingError) as err:
            last_err = err
    raise last_err


def inclination_v_error(direcs, n_it, d):
    """Scatter the 3rd-epoch precision against true inclination.

    For each directory in `direcs` and each iteration in range(n_it):
    - the inclination is arccos(cos i) taken from the stored truths, in radians;
    - the precision is element [2] of the first readable precision file
      'd5prec<i>.dat', 'd10prec<i>.dat' or 'd20prec<i>.dat'.
    Only missing, unreadable or too-short files fall through to the next distance.
    Any other error propagates instead of being silently swallowed by a bare ``except``.

    Parameters
    ----------
    direcs : iterable of str
        Results directories.
    n_it : int
        Iterations per directory.
    d : unused
        NOTE(review): the candidate distances are hard-coded.

    Returns
    -------
    fig, ax
    """
    inc = []
    prec3 = []

    for direc in direcs:
        truths_iters = read_truths(direc)
        for i_it in range(n_it):
            _, cosi, _, _, _, _ = truths_iters[i_it]
            inc.append(np.arccos(cosi))
            prec3.append(_read_prec3(direc, i_it))

    print('inc', np.shape(inc))
    print('prec3', np.shape(prec3))
    fig, ax = plt.subplots(1, 1)
    ax.scatter(inc, prec3)
    ax.set_xlabel('Inclination')
    ax.set_ylabel('Precision at 3rd epoch (AU)')
    return fig, ax
        

In [91]:


def get_trueseparation_accuracy(direc, n_it, i_ep, idim, d, n_epochs, cadence, maxprob=True, iter_start=0, s=5):
    """Mean fractional error in the true (3-D) star-planet separation at epoch `i_ep`.

    For each iteration, the separation r = a(1 - e^2) / (1 + e cos nu) is evaluated twice:
    from the true parameters (a = exp(ln a)) and from the ML parameters (a, e, xi at indices
    0, 2, 4 of the ML file). Each uses the true anomaly at the phase ``getNextPhase`` gives
    for this epoch.

    Returns the mean of |r_ML - r_true| / r_true over iterations 0..n_it-1.
    `idim`, `n_epochs`, `maxprob`, `iter_start` and `s` are not used.
    """
    from mc_orbit_good import getNextPhase

    def _radius(a, ecc, nu):
        # orbital radius of a Keplerian ellipse at true anomaly nu
        return a * (1-ecc**2) / (1+ecc*np.cos(nu))

    truths_iters = read_truths(direc)
    frac_errors = []
    for it in range(n_it):
        ml_params = read_results(direc+'/d'+str(d)+'ML'+str(it)+'.dat')[i_ep]
        lna, _, ecc, _, xi_0, _ = truths_iters[it]

        a_true, ecc_true, xi_true = np.exp(lna), ecc, xi_0
        a_fit, ecc_fit, xi_fit = ml_params[0], ml_params[2], ml_params[4]

        mean_anom_true = getNextPhase(i_ep, cadence, a_true, xi_true)
        mean_anom_fit = getNextPhase(i_ep, cadence, a_fit, xi_fit)

        r_true = _radius(a_true, ecc_true, TrueAnom(ecc_true, mean_anom_true, 5))
        r_fit = _radius(a_fit, ecc_fit, TrueAnom(ecc_fit, mean_anom_fit, 5))

        frac_errors.append(abs((r_fit - r_true)/r_true))

    return np.mean(frac_errors)
        
        

In [92]:
from astropy import units as u
from astropy import constants as c
from astropy.coordinates import Angle


def get_pix_accuracy(direc, n_it,  idim, d, n_epochs, cadence, maxprob=True, iter_start=0, s=5):
    """Mean offset, in detector pixels, between ML-predicted and true projected separation.

    Evaluated at the final epoch (index ``n_epochs - 1``) for iterations 0..n_it-1. The pixel
    scale is lambda / (2 D) for lambda = 0.5 micron and D = 10 m, expressed in mas.

    The ML parameters are read from 'd<d>ML<i>.dat'. Entries 0 and 1 are converted with
    log / cos before being passed to ``obv_at_epoch``, like the truths.

    Parameters
    ----------
    direc : str
        Results directory.
    n_it : int
        Number of iterations.
    d : number
        Distance in pc; used for the AU -> mas conversion and the astrometric error in AU.
    n_epochs : int
        The last epoch (n_epochs - 1) is evaluated.
    cadence :
        Forwarded to ``obv_at_epoch``.
    s : number
        Astrometric error in mas.
    idim, maxprob, iter_start : unused

    Returns
    -------
    float
        Mean pixel offset across iterations.
    """
    from mc_orbit_good import obv_at_epoch

    s_xy = d*(s*1e-3)  # convert mas error to AU

    truths_iters = read_truths(direc)

    # Local names avoid shadowing the module-level float `wl` with a Quantity.
    wavelength = 0.5 * u.micron
    aperture = 10 * u.m
    s_px = (wavelength/aperture/2).to(u.mas, equivalencies=u.dimensionless_angles())  # pixel scale in mas

    i_ep = n_epochs-1
    n_pxs = []
    for ii in range(n_it):
        fname = direc+'/d'+str(d)+'ML'+str(ii)+'.dat'
        MLs = read_results(fname)
        truths = truths_iters[ii]

        fits = MLs[i_ep]
        fits[0] = np.log(fits[0])
        fits[1] = np.cos(fits[1])

        _, _, aproj_model = obv_at_epoch(fits, i_ep+1, cadence=cadence, s_AU=s_xy,
                                         noise=False)
        _, _, aproj_true = obv_at_epoch(truths, i_ep+1, cadence=cadence, s_AU=s_xy,
                                        noise=False)

        # projected-separation difference, AU -> mas -> pixels
        ds_AU = abs(aproj_model - aproj_true)
        ds = ((ds_AU*u.AU) / (d*u.pc)).to(u.mas, equivalencies=u.dimensionless_angles())
        n_pxs.append((ds/s_px).to_value(u.dimensionless_unscaled))

    return np.mean(n_pxs)




In [93]:
# -*- coding: utf-8 -*-
"""
Created on Thu Mar  3 21:44:34 2016

@author: jlustigy
"""


def colorize(vector,cmap='plasma', vmin=None, vmax=None):
    """Map each value of `vector` to an RGBA color.

    Parameters
    ----------
    vector : array
        Values to color.
    cmap : str (optional)
        Name of the Matplotlib colormap to use.
    vmin, vmax : float (optional)
        Limits of the linear normalization. They fall back to np.min(vector) /
        np.max(vector) when left as None.

    Returns
    -------
    vcolors : np.ndarray
        RGBA colors, one per entry of `vector`.
    scalarmap : matplotlib.cm.ScalarMappable
        Mapper from values to colors (usable for a colorbar).
    cNorm : matplotlib.colors.Normalize
        The normalization that was applied.
    """
    lower = np.min(vector) if vmin is None else vmin
    upper = np.max(vector) if vmax is None else vmax

    colormap = plt.get_cmap(cmap)
    norm = colors.Normalize(vmin=lower, vmax=upper)
    mapper = cmx.ScalarMappable(norm=norm, cmap=colormap)

    return mapper.to_rgba(vector), mapper, norm
        

In [94]:
def custom_colormap(C):
    """Build a ListedColormap from colors given on a 0-255 scale.

    Parameters
    ----------
    C : array-like, shape (ncolors, 3) or (ncolors, 4)
        RGB(A) channel values in [0, 255]. Lists and tuples are accepted as well as numpy
        arrays; previously ``C/255.0`` raised TypeError for a plain list.

    Returns
    -------
    matplotlib.colors.ListedColormap
    """
    return mpl.colors.ListedColormap(np.asarray(C, dtype=float) / 255.0)


In [95]:
def get_xy_accuracy(direc, n_it,  idim, d, cadence, n_epochs, maxprob=True, iter_start=0, s=5):
    """Mean fractional x and y errors of the ML-predicted sky position at the last epoch.

    For iterations iter_start..n_it-1, both the ML parameters and the truths are propagated
    with ``obv_at_epoch`` to epoch ``n_epochs``. The ML parameters are taken at index
    ``n_epochs - 1``, with entries 0 and 1 converted by log / cos. The function returns
    mean |(x_ML - x_true) / x_true| and the same for y. The intent is to call it for
    different `n_epochs`.

    NOTE(review): a true coordinate of exactly 0 gives an infinite fractional error.
    `idim` and `maxprob` are not used.

    Returns
    -------
    (float, float)
        Mean fractional x error, mean fractional y error.
    """
    from mc_orbit_good import obv_at_epoch

    s_xy = d*(s*1e-3)  # mas error -> AU
    truths_iters = read_truths(direc)
    last_ep = n_epochs - 1

    frac_x, frac_y = [], []
    for it in range(iter_start, n_it):
        ml_by_epoch = read_results(direc+'/d'+str(d)+'ML'+str(it)+'.dat')
        truth = truths_iters[it]

        theta = ml_by_epoch[last_ep]
        theta[0] = np.log(theta[0])
        theta[1] = np.cos(theta[1])

        x_fit, y_fit, _ = obv_at_epoch(theta, last_ep+1, cadence=cadence, s_AU=s_xy,
                                       noise=False)
        x_ref, y_ref, _ = obv_at_epoch(truth, last_ep+1, cadence=cadence, s_AU=s_xy,
                                       noise=False)

        frac_x.append(abs((x_fit - x_ref)/x_ref))
        frac_y.append(abs((y_fit - y_ref)/y_ref))

    return np.mean(frac_x), np.mean(frac_y)

In [96]:
def EccAnom(ec,m,dp):
    """Solve Kepler's equation E - ec*sin(E) = m for the eccentric anomaly E.

    Uses Newton-Raphson iteration. It stops once |E - ec*sin(E) - m| <= 10**-dp, or after
    30 iterations.

    Parameters
    ----------
    ec : float
        Eccentricity (elliptical orbit, scalar).
    m : float
        Mean anomaly in radians (scalar).
    dp : int
        Number of decimal places; the residual tolerance is 10**-dp.

    Returns
    -------
    float
        Eccentric anomaly in radians. For ec == 0 this is `m` itself.
    """
    if ec == 0:
        return m

    max_iter = 30
    delta = 10**-dp

    # Starting guess: E = m for moderate eccentricity, E = pi for high eccentricity.
    E = m if ec < 0.8 else np.pi
    # Residual at the starting guess. This previously used sin(m) instead of sin(E),
    # which is wrong whenever the guess is pi.
    F = E - ec*np.sin(E) - m

    i = 0
    while abs(F) > delta and i < max_iter:
        E = E - F/(1.0 - ec*np.cos(E))
        F = E - ec*np.sin(E) - m
        i += 1
    return E

def TrueAnom(ec,m,dp):
    """True anomaly for eccentricity `ec` and mean anomaly `m`, in radians.

    Solves Kepler's equation with ``EccAnom(ec, m, dp)``, then converts the eccentric
    anomaly E to nu = atan2(sqrt(1 - ec^2) * sin E, cos E - ec). The result lies in
    [-pi, pi] (``math.atan2``), so scalar inputs only.
    """
    ecc_anom = EccAnom(ec, m, dp)
    y_comp = np.sqrt(1.0 - ec*ec) * np.sin(ecc_anom)
    x_comp = np.cos(ecc_anom) - ec
    return math.atan2(y_comp, x_comp)



In [97]:
# d=10
# fact = 500e-9/10 * 206265 * d
# x0 = 2.5*fact
# b = 1
# c = 0.0000000000000001

# x = np.arange(5*fact, step=1e-5)
# y = 1 / (1 + b*c**(x0-x))

# fix, ax=plt.subplots(1, 1)
# ax.plot(x, y, 'k')
# print(x0)
# ax.axvline(x=1.5*fact, label='1.5L/D', c='r')
# ax.axvline(x=3.5*fact, label='3.5L/D', c='b')
# ax.legend()