# analyze the trace obtained from notebook No. 4 for Table2 1993 QS  fe_13

## this uses the Markov chain Monte Carlo (MCMC) Bayesian inference package PyMC3, developed by Chris Fonnesbeck, Anand Patil, David Huard, John Salvatier

## the original analysis was done with pymc version 2

## pymc3 is what is available now

In [None]:
import os
import fnmatch
import json
import pickle
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import theano
from theano import tensor as tt
import arviz as az
import ChiantiPy.tools.util as chutil
import ChiantiPy.model.Maker as mm

In [None]:
autoreload 3

In [None]:
wd = os.getcwd()

In [None]:
wd

In [None]:
today = date.today()

In [None]:
thisday =today.strftime('%Y_%B_%d')

In [None]:
thisday

In [None]:
matplotlib qt

In [None]:
myIon = 'fe_13'

In [None]:
nameDict = chutil.convertName(myIon)

In [None]:
ls

## fe_13 data is in the .json file

In [None]:
# collect the names of all entries in the working directory that match *.json
dirList = os.listdir(wd)
jsonList = fnmatch.filter(dirList, '*.json')

In [None]:
# show every candidate with its position so that one can be picked by index
for position, candidate in enumerate(jsonList):
    print(' idx:  %i   fname:  %s' % (position, candidate))

In [None]:
jsonName = jsonList[0]

In [None]:
jsonName

In [None]:
with open(jsonName, 'r') as inpt:
    specData = json.load(inpt)

In [None]:
specData.keys()

In [None]:
specData['ref']

## the weighted chi-squared = sum( ((pred - obs)/(wghtFactor x obs))**2 )

In [None]:
wghtFactor = 0.2

## create the box of matches object

In [None]:
box = mm.maker(specData, wghtFactor = wghtFactor, verbose = True)

In [None]:
box.SpecData['filename']

In [None]:
ls

## open the pickled match file

In [None]:
# collect the names of all entries in the working directory that match *.pkl
dirList = os.listdir(wd)
matchList = fnmatch.filter(dirList, '*.pkl')

In [None]:
# show every pickle file with its position so that one can be picked by index
for position, candidate in enumerate(matchList):
    print(' idx:  %i   fname:  %s' % (position, candidate))

In [None]:
matchName = matchList[0]

In [None]:
matchName

In [None]:
with open(matchName,'rb') as inpt:
    match = pickle.load(inpt)

In [None]:
match.keys()

In [None]:
match['match'][0].keys()

## temperature and density are in the pickled match file

In [None]:
match['EDensity'].shape

In [None]:
' %10.2e'%(match['EDensity'][0])

In [None]:
dens = match['EDensity']

In [None]:
' density range = %10.2e to %10.2e'%(dens.min(), dens.max())

## load the matches saved in the pickle file

In [None]:
box.loadMatch(matchName)

## begin set up for mcmc sampling model

## only need it to restore the stored trace

## it will not be run

In [None]:
nDens = match['EDensity'].size
print(' # of densities %5i'%(nDens))

## tune and samples may need to be increased after the first run and then rerun

## the no. of cores depends on your machine

In [None]:
tune = 2000
samples = 100000
cores = 4

## the predicted intensity matrix

In [None]:
# predicted intensities: one row per density index, one column per entry of box.Match
nMatch = len(box.Match)
pred = np.zeros((nDens, nMatch), dtype=np.float64)
for column, entry in enumerate(box.Match):
    # fill the whole column with this entry's summed intensity
    pred[:, column] += entry['intensitySum']

In [None]:
' tune = %i  samples = %i cores = %i'%(tune, samples, cores)

## the observed intensities

In [None]:
# gather the observed intensity of every entry of box.Match into a float64 vector
nObs = len(box.Match)
intensity = np.zeros(nObs, dtype=np.float64)
for iobs, entry in enumerate(box.Match):
    intensity[iobs] = entry['obsIntensity']

## create the MCMC model and perform the sampling

In [None]:
# Rebuild the PyMC3 model.  No sampling is done in this cell (the pm.sample
# call at the bottom is commented out); the model object is needed later as
# the second argument of pm.load_trace.
with pm.Model() as model:
    # d0: discrete index in 0 .. nDens-1; it selects one row of pred below
    d0 = pm.DiscreteUniform('d0', lower = 0, upper = nDens - 1, dtype='int64')
    
    # em: exponent of the power of ten that scales the predicted intensities.
    # NOTE(review): no bounds are passed, so pm.Uniform's defaults apply
    # (lower=0, upper=1 per the PyMC3 docs) - confirm that range is intended.
    em = pm.Uniform('em')  #  sigma was 0.1


    # make the predicted-intensity matrix available to theano as a shared variable
    xpred = theano.shared(pred, name='p0')

    idx0 = tt.as_tensor_variable(d0)
    
    # row of pred picked by d0, multiplied by 10**em
    predicted = xpred[idx0]*10.**em
    

    # NOTE(review): a heading earlier in this notebook gives the weighting as
    # wghtFactor x obs, whereas this is sqrt(wghtFactor*intensity) - confirm
    # which is intended; presumably it has to agree with the run that produced
    # the stored trace.
    sigma = np.sqrt(wghtFactor*intensity)
    
    
    # Normal likelihood of the observed intensities around the scaled prediction
    Y_obs = pm.Normal('Y_obs', mu=predicted, sigma=sigma, observed=intensity)

    
    # step-method objects are constructed here but are neither bound to a name
    # nor handed to a sampler in this notebook
    pm.NUTS([em],target_accept=0.87)  # was 0.8 was 0.87
    pm.Metropolis([d0], target_accept=0.87)  # was 0.8 was 0.87
# the sampling call is kept for reference only; it is intentionally disabled
#    start = {'d0':Dindex}
#    start['em'] = emLog
#    trace = pm.sample(samples, tune=tune, cores=cores, start=start)
#    pm_data = az.from_pymc3(
#        trace=trace)
#pm_data


## the trace has already been saved for later analysis

## here, the trace will be used for analysis and not the pm_data

In [None]:
ls

## find the pickled results dict from the mcmc run in #4

In [None]:
# collect the names of all entries in the working directory that match *.pkl
dirList = os.listdir(wd)
pklList = fnmatch.filter(dirList, '*.pkl')

In [None]:
# show every pickle file with its position so that the results file can be picked by index
for position, candidate in enumerate(pklList):
    print(' idx:  %i   fname:  %s' % (position, candidate))

In [None]:
# Select the results file from pklList, the listing that was just built and
# printed for this purpose (the earlier matchList may be stale if the
# directory changed in between).
# NOTE(review): the index 5 depends on the contents and listing order of the
# working directory - check it against the printed listing before running.
resultsName = pklList[5]

In [None]:
resultsName

In [None]:
with open(resultsName,'rb') as inpt:
    resultsDict = pickle.load(inpt)

In [None]:
resultsDict.keys()

In [None]:
pth = resultsDict['pth']

In [None]:
pth

In [None]:
trace = pm.load_trace(pth, model)

## arviz can plot the traces and histograms, among other things

In [None]:
az.plot_trace(trace, var_names=["d0"], kind='trace')

In [None]:
az.plot_trace(trace, var_names=["em"], kind='trace')

In [None]:
trace.stat_names

In [None]:
trace.varnames

In [None]:
trace['d0'].shape

## get the mean values and std of the traces

In [None]:
# Mean and standard deviation of the sampled index d0.
# The mean of the samples is generally not an integer, so it is shown with
# decimals (same format as the em summary) rather than truncated by %i.
d0Mean = trace['d0'].mean()
d0Std = trace['d0'].std()
' d0Mean = %10.3f d0Std = %10.5f'%(d0Mean, d0Std)

In [None]:
# Histogram of the sampled d0 values, annotated with the run settings.
# d0 runs over 0 .. nDens-1 and is used to index the density array, so the
# x axis is a density index (the label previously said temperature).
plt.figure()
xyhist = plt.hist(trace['d0'])
plt.xlabel('Density Index', fontsize=14)
plt.ylabel('Frequency', fontsize=14)
plt.title('d0    mean = %8.3f  std = %8.5f \n tune = %i  samples = %i wF:  %6.3f'%(d0Mean, d0Std, tune, samples, wghtFactor), fontsize=14)
plt.tight_layout()

In [None]:
em0Mean = trace['em'].mean()
em0Std = trace['em'].std()
' em0Mean = %10.3f em0Std = %10.5f'%(em0Mean, em0Std)

In [None]:
# Histogram of the sampled em values, annotated with the run settings.
emTitle = 'EM    mean = %8.3f  std = %8.5f \n tune = %i  samples = %i wF: %6.3f'%(em0Mean, em0Std, tune, samples, wghtFactor)
plt.figure()
xyhist = plt.hist(trace['em'])
plt.xlabel('Log$_{10}$ Emission Measure', fontsize=14)
plt.ylabel('Frequency', fontsize=14)
plt.title(emTitle, fontsize=14)
plt.tight_layout()

## predict the intensities with the parameters derived from the MCMC sampling

In [None]:
newDindex = int(np.round(d0Mean))

In [None]:
newDindex

In [None]:
newEmLog = em0Mean

In [None]:
' new Emlog:  %10.3f '%(newEmLog)

In [None]:
box.emSetIndices([newDindex])
print('density  set to %10.2e '%(dens[newDindex]))

In [None]:
autoreload 3

In [None]:
box.EmIndices

In [None]:
box.emSet([newEmLog])

In [None]:
box.predict()

In [None]:
sort = 'wvl'

In [None]:
matchName

In [None]:
os.path.splitext(matchName)

In [None]:
printName = os.path.splitext(matchName)[0] + '_postPredictPrint_%i_%i_%s_%s.txt'%(tune, samples, today, sort)

In [None]:
printName

In [None]:
box.predictPrint(filename=printName, sort=sort)

In [None]:
diffName = os.path.splitext(matchName)[0] + '_postDiffPrint_%i_%i_%s_%s.txt'%(tune, samples, today, sort)

In [None]:
diffName

In [None]:
box.diffPrint(filename=diffName, sort=sort)

## let's plot the differences

In [None]:
wvl = box.Diff['wvl']
diff = box.Diff['diffOverInt']

In [None]:
diffMean = diff.mean()
diffStd = diff.std()

In [None]:
mytitle = 'diff Mean %10.3f  diff Std  %10.3f'%(diffMean, diffStd)

In [None]:
mytitle

In [None]:
plt.figure()
plt.plot(wvl, diff,'o')

In [None]:
# mean of the differences as a solid black line
plt.axhline(diffMean, color='k', lw=2, label='Mean')
# 1-, 2- and 3-std bands around the mean; only the upper line of each pair
# carries a label so that every band appears once in the legend
sigmaBands = [
    (1., 'r', '--', '1 std'),
    (2., 'b', 'dotted', '2 std'),
    (3., 'g', 'dotted', '3 std'),
]
for nSigma, bandColor, bandStyle, bandLabel in sigmaBands:
    offset = nSigma*diffStd
    plt.axhline(diffMean + offset, color=bandColor, lw=2, linestyle=bandStyle, label=bandLabel)
    plt.axhline(diffMean - offset, color=bandColor, lw=2, linestyle=bandStyle)

In [None]:
# Axis labels use mathtext; both are raw strings so the backslashes reach
# matplotlib unchanged (a backslash followed by 'A' in a normal string literal
# is an invalid escape sequence and triggers a warning on newer Pythons).
plt.xlabel(r'Wavelength ($\AA$)', fontsize=14)
plt.ylabel(r'(Obs - Pred)/(w $\times$ Obs)', fontsize=14)

In [None]:
plt.title(mytitle, fontsize=14)

In [None]:
plt.legend(loc='upper right', bbox_to_anchor=(0.99, 1.0), fontsize=12)

In [None]:
plt.tight_layout()

## this can also be done with diffPlot

In [None]:
box.diffPlot(title=True)

## the plot objects are saved in the DiffPlot attribute

In [None]:
box.DiffPlot.keys()

In [None]:
fig = box.DiffPlot['fig']
ax =  box.DiffPlot['ax']

## so, can do things like the following

In [None]:
ax.set_xlim([310., 330.])