In [None]:
import sys
import os
# Put the parent of the kernel's current working directory at the front of
# sys.path so the local `yellowcard` package imported below can be found.
# NOTE(review): this is relative to the working directory, not to the notebook
# file itself — confirm the kernel is started from the notebook's folder.
_path = os.path.abspath(os.pardir)
if _path not in sys.path:
    sys.path.insert(0, _path)
print('done adding path')
    
import astropy.coordinates as coord
import numpy as np
from gala.units import UnitSystem
from yellowcard.model import TimingArgumentModel 
from yellowcard.model_select import *
from yellowcard.coordinates import fiducial_m31_c,LocalGroupHalocentric
import astropy.units as u
from scipy.optimize import minimize
from yellowcard.keplerianPlane import LGKepler
from numpy.linalg import norm
import emcee
import arviz as az
from schwimmbad import MultiPool 
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel
# session (including any raised by scipy.optimize, emcee or astropy), which can
# hide real problems; consider narrowing it with `category=` / `module=`.
warnings.filterwarnings('ignore')
import corner

import matplotlib.pyplot as plt


In [None]:
# Which data set / model to fit; read by the model-selection cell below.
# Options: None (test model), "vdm_rad", "vdm", "fiducial".
# modelChoice = None
# modelChoice = "vdm_rad"
modelChoice = "vdm"
# modelChoice = "fiducial"

In [None]:
#####################################################################
# Build the model selected by `modelChoice` and a matching plot title.
#   "vdm_rad"  -> van der Marel 2012 with zero transverse velocity
#   "vdm"      -> van der Marel 2012
#   "fiducial" -> fiducial model
#   None       -> test model (small distance error)
# Any other value is rejected so that a typo cannot silently run the
# (expensive) MCMC on the test model.
#####################################################################
if modelChoice == "vdm_rad":
    model = model_vdm_rad()
    title = "van der Marel 2012 - Radial Only"
elif modelChoice == "vdm":
    model = model_vdm()
    title = "van der Marel 2012"
elif modelChoice == "fiducial":
    model = model_fid()
    title = "Fiducial"
elif modelChoice is None:
    model = model_testing()
    title = "Testing model - small dist err"
    print('Warning: test model!!')
else:
    raise ValueError(
        f"Unknown modelChoice {modelChoice!r}; "
        "expected None, 'vdm_rad', 'vdm' or 'fiducial'"
    )
    

In [None]:
# Direction of M31's r x v in the model's Galactocentric frame, as a unit
# vector. Currently only used by the commented-out 'Lhatlg' entry of the
# initial-guess cell.
galcen_m31 = fiducial_m31_c.transform_to(model.galcen_frame)
galcen_m31_pos = galcen_m31.data.without_differentials()
galcen_m31_vel = galcen_m31.velocity
_r_cross_v = galcen_m31_pos.cross(galcen_m31_vel)
galcen_m31_L = _r_cross_v / _r_cross_v.norm()

In [None]:
# Starting guess for the optimizer, expressed in the sampler's parameters:
# log distance, eParam = ln(1 - e), (cos, sin) pairs for the angles eta and
# alpha, and log mass in the model's unit system.
e_init = 0.9
eta_init = 5*u.rad
alpha_init = 0*u.rad
M_init = 4e12*u.Msun

init_par = {
    # NOTE(review): uses the raw numeric distance, so it assumes
    # fiducial_m31_c.distance is already in the unit the model expects — confirm.
    'lnr': np.log(fiducial_m31_c.distance.value),
    # Derived from e_init (previously hard-coded to -3, i.e. e ~= 0.95,
    # which left e_init unused).
    'eParam': np.log(1 - e_init),
    'coseta': np.cos(eta_init),
    'sineta': np.sin(eta_init),
    'lnM': np.log(M_init.decompose(model.unit_system).value),
    # 'Lhatlg': galcen_m31_L.xyz,
    'cosalpha': np.cos(alpha_init),
    'sinalpha': np.sin(alpha_init),
}

In [None]:
# Display the starting guess that is packed and handed to the optimizer below.
init_par

---
## Initial optimization to find a starting point for the MCMC

In [None]:
def neg_log_prob(p):
    """Return the negative of the model's log-probability at packed parameters ``p``.

    Minimizing this maximizes the log-probability that emcee later samples.
    """
    return -model(p)


result = minimize(neg_log_prob, model.pack_pars(init_par), method='Powell')
# Warnings are globally ignored in this notebook, so report a failed
# optimization explicitly instead of silently seeding the walkers from it.
if not result.success:
    print(f"Warning: Powell optimization did not converge: {result.message}")

____
# MCMC

In [None]:
# Start the walkers in a small Gaussian ball around the optimizer's solution.
# SEED makes these starting positions reproducible (set it to None for fresh
# draws). NOTE(review): this does not seed the sampler's own moves — see the
# emcee docs if the chain itself must be reproducible.
SEED = 42
WALKER_SCATTER = 1e-1  # std. dev. of the ball in every sampled parameter
ndim = len(result.x)
nwalkers = 10*ndim
rng = np.random.default_rng(SEED)
sampler_x0 = rng.normal(result.x, WALKER_SCATTER, size=(nwalkers, ndim))

In [None]:
# Parallel ensemble sampling: a burn-in phase whose samples are discarded by
# sampler.reset(), then the production run that later cells read from `sampler`.
n_burn_steps = 1000
n_prod_steps = 2000
with MultiPool() as pool:
    sampler = emcee.EnsembleSampler(
        nwalkers=nwalkers,
        ndim=len(result.x),
        log_prob_fn=model,
        pool=pool,
    )
    burn_state = sampler.run_mcmc(sampler_x0, nsteps=n_burn_steps, progress=True)
    sampler.reset()
    state = sampler.run_mcmc(burn_state, nsteps=n_prod_steps, progress=True)

----
## Posterior means of each parameter

In [None]:
# Keep every THIN-th step of each walker and stack all walkers into one
# (n_samples, ndim) array. (The old name `everyHundo` suggested thinning by 100,
# but the stride was 10.)
# NOTE(review): assumes sampler.chain is (nwalkers, nsteps, ndim) as in emcee;
# `.chain` is deprecated there in favour of get_chain().
THIN = 10
thinned_samples = np.vstack(sampler.chain[:, ::THIN])

# Column order must match the model's parameter packing order — confirm against
# model.pack_pars if parameters are added or reordered.
col_names = ["lnr", "eParam", "coseta", "sineta", "lnM", "cosalpha", "sinalpha"]
lil = {name: thinned_samples[:, i] for i, name in enumerate(col_names)}

means = model.whats_this_mean(lil)
means

---
## Figures for the full set of samples

In [None]:
# ArviZ views of the results: the sampled parameters straight from the emcee
# chain, and the transformed parameters that model.whats_this() derives from
# the thinned samples in `lil`.
sampled_labels = ["ln r", "ln(1-e)", "coseta", "sineta", "ln M", "cosalpha", "sinalpha"]
tulips = az.from_emcee(sampler, var_names=sampled_labels)
transformed_pars = model.whats_this(lil)
lookout = az.convert_to_inference_data(transformed_pars)
looks = az.convert_to_dataset(lookout).to_array()

In [None]:
# Trace and pairwise plots of the sampled parameters (production run only,
# since the burn-in was discarded by sampler.reset()).
az.plot_trace(tulips);
az.plot_pair(tulips);

In [None]:
# Corner plot of the thinned samples, annotated with the posterior means
# computed earlier (`means`) and titled with the selected model's name.
corner_labels = ["ln r", "ln(1-e)", "coseta", "sineta", "ln M", "cosalpha", "sinalpha"]
fig = corner.corner(np.vstack(sampler.chain[:, ::10]),
                    labels=corner_labels,
                    show_titles=True)
fig.text(0.8, 0.63, "Means", fontsize=20)
for row, (key, val) in enumerate(means.items()):
    # Quantities carry .value/.unit; plain numbers are used as they are.
    # (The old `except AttributeError: val` was a no-op statement.)
    number = getattr(val, "value", val)
    unit = getattr(val, "unit", None)
    label = f"{key}={number:.2f}"
    if unit is not None and str(unit):
        # Show the unit so e.g. r and M are not ambiguous; dimensionless → nothing.
        label += f" {unit}"
    fig.text(0.8, 0.6 - 0.02*row, label, fontsize=20)
fig.suptitle(title, fontsize=20)
plt.show()

___
## Plots of the transformed variables

In [None]:
# Same diagnostics for the transformed parameters returned by
# model.whats_this(lil), i.e. computed from the thinned samples.
az.plot_trace(lookout);
az.plot_pair(lookout);

----
## Scratch cells for testing (not part of the analysis)

In [None]:
# meanies = np.mean(np.vstack(sampler.chain[:, ::100]), axis=0)
# whats_this_mean(model.unpack_pars(meanies))

In [None]:
# whats_this(par_dict)

In [None]:
# # def whats_this_mean(par_dict):
# #         ''' you can tell that i hard coded this function :) '''
# #         what_dict = {}
# #         what_dict['r'] = np.mean(np.exp(par_dict['lnr']))*u.kpc
# #         what_dict['e'] = np.mean(1 - np.exp(par_dict['eParam']))
# #         etta = np.arctan2(par_dict['sineta'],par_dict['coseta']) # *u.rad
# #         what_dict['eta'] = np.mean(etta%(2*np.pi))
# #         what_dict['M'] = np.mean(np.exp(par_dict['lnM'])*model.unit_system['mass'])
# #         allpha = np.arctan2(par_dict['sinalpha'],par_dict['cosalpha']) # *u.rad
# #         what_dict['alpha'] = np.mean(allpha%(2*np.pi))
# #         return what_dict



# def whats_this(par_dict):
#         ''' you can tell that i hard coded this function :) '''
#         what_dict = {}
#         what_dict['r'] = np.exp(par_dict['lnr'])
#         what_dict['e'] = 1 - np.exp(par_dict['eParam'])
#         etta = np.arctan2(par_dict['sineta'],par_dict['coseta']) # *u.rad
#         what_dict['eta'] = etta%(2*np.pi)
#         what_dict['M'] = np.exp(par_dict['lnM'])
#         allpha = np.arctan2(par_dict['sinalpha'],par_dict['cosalpha']) # *u.rad
#         what_dict['alpha'] = allpha%(2*np.pi)
#         return what_dict

In [None]:
# lil

In [None]:
# whats_this(lil)

In [None]:
# last_pars = model.unpack_pars(sampler.chain[:, -1].T)
# last_Lhatlg = coord.CartesianRepresentation(last_pars['Lhatlg'])
# last_Lhatlg_sph = last_Lhatlg.represent_as(coord.UnitSphericalRepresentation)
# plt.scatter(last_Lhatlg_sph.lon, last_Lhatlg_sph.lat)
# plt.xlim(0, 2*np.pi)
# plt.ylim(-np.pi, np.pi)