In [1]:
# standard library
import os
from os import path
import sys
_path = path.abspath('../pkg/')
if _path not in sys.path:
    sys.path.append(_path)
import pickle

# Third-party
import astropy.coordinates as coord
import astropy.units as u
from astropy.table import Table, join
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
%matplotlib inline
import corner
from schwimmbad import MultiPool, SerialPool

# Custom
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import UnitSystem

import emcee
from pyia import GaiaData

from chemtrails.potential import Sech2Potential, UniformPotential
from chemtrails.likelihood import Model
from chemtrails.data import (load_nominal_galah, load_nominal_apogee, 
                             get_abundance_data, get_label_from_abundancename)

In [2]:
# Unit system handed to the potential objects built later in this notebook:
# pc, Myr, Msun and radian, plus km/s.
usys = UnitSystem(
    u.pc,
    u.Myr,
    u.Msun,
    u.radian,
    u.km / u.s,
)

# GALAH

In [3]:
# Location of the GALAH x Gaia DR2 cross-match table.
# The historical absolute path is kept as the default so existing setups keep
# working; set the GALAH_XMATCH_PATH environment variable to point the
# notebook at a copy of the file elsewhere.
galah_xmatch_path = os.environ.get(
    'GALAH_XMATCH_PATH',
    '/Users/adrian/data/GaiaDR2/GALAH-GaiaDR2-xmatch.fits')

# `g` and `galcen` are indexed with the same boolean masks further down, so
# they are expected to be row-aligned.
g, galcen = load_nominal_galah(galah_xmatch_path)
len(g)

21733

### Quick model test:

In [5]:
# Smoke test: build a model for a single abundance ('fe_h') with four
# parameters held fixed, then evaluate it once at a hand-picked point.
_frozen_pars = {
    'sun_z': 0.,
    'sun_vz': 0.,
    'lnsigma': np.log(65),
    'lnhz': np.log(250),
}
model = Model(galcen, g, ['fe_h'],
              frozen_pars=_frozen_pars,
              marginalize=False,
              metals_deg=3)

# Five free values remain once the parameters above are frozen
# (presumably ln(variance) followed by the polynomial coefficients -- confirm
# against chemtrails.likelihood.Model).
_test_pars = [np.log(0.2**2), 0., 0., -0.1, 0]
model(_test_pars)

-3689.1771340636783

---

## MCMC sampling with emcee

In [6]:
nwalkers = 64

# Centre of the initial walker ball. The first three entries are logarithms
# (they are exponentiated again when the chains are summarised later in this
# notebook); the remaining four start near zero.
p0_center = np.array([
    np.log(65), np.log(250),
    np.log(0.2**2),
    0., 0., -0.2, 0.05,
])
ndim = len(p0_center)

# Per-parameter scatter of the ball around the centre.
p0_scatter = [1e-2, 1e-2,
              1e-3,
              1e-4, 1e-4, 1e-4, 1e-4]

# One starting position per walker.
p0 = emcee.utils.sample_ball(p0_center, std=p0_scatter, size=nwalkers)

In [7]:
# Build the list of abundance names to fit. Every '<x>_fe' column is turned
# into '<x>_h', skipping columns whose names start with 'e_', 'flag' or
# 'alpha'. 'fe_h' and 'alpha_fe' are always fitted first.
_skip_prefixes = ('e_', 'flag', 'alpha')
_converted = sorted(
    '{}_h'.format(name.split('_')[0])
    for name in g.data.colnames
    if name.endswith('_fe') and not name.startswith(_skip_prefixes)
)

all_elems = ['fe_h', 'alpha_fe'] + _converted

In [16]:
nburn = 1024   # steps run first and discarded as burn-in
nsteps = 1024  # steps kept after burn-in (0 keeps the burn-in chain instead)

# For every abundance: fit the model with emcee (or reuse a cached sampler),
# save a corner plot of the first three parameters, and save a plot of the
# abundance against centred ln(Ez) with the mean fitted polynomial on top.
for elem in all_elems:
    print(elem)

    # Only stars with a finite measurement of this abundance are used. The
    # same selection is applied to the fit and to the diagnostic plot below.
    abun = get_abundance_data(g, elem)
    mask = np.isfinite(abun)
    galcen_fit = galcen[mask]
    abun_fit = abun[mask]

    model = Model(galcen_fit, g[mask], [elem],
                  frozen_pars=dict(sun_z=0., sun_vz=0.),
                  marginalize=False, metals_deg=3)

    cache_filename = 'sampler-unmarginalized-{0}.pkl'.format(elem)
    if path.exists(cache_filename):
        # NOTE: unpickling can execute arbitrary code -- only load cache
        # files that were written by this notebook.
        with open(cache_filename, 'rb') as f:
            sampler = pickle.load(f)
    else:
        with SerialPool() as pool:
            sampler = emcee.EnsembleSampler(nwalkers, ndim, model,
                                            pool=pool)

            print("burn-in")
            pos, prob, state = sampler.run_mcmc(p0.copy(), nburn,
                                                progress=True)

            if nsteps > 0:
                print("sampling")
                # Drop the burn-in samples and continue from where the
                # walkers ended up.
                sampler.reset()
                _ = sampler.run_mcmc(pos, nsteps, progress=True)

        # Write to a temporary file and move it into place afterwards, so an
        # interrupted run never leaves a truncated pickle that a later run
        # would mistake for a valid cache.
        tmp_filename = cache_filename + '.tmp'
        with open(tmp_filename, 'wb') as f:
            pickle.dump(sampler, f)
        os.replace(tmp_filename, cache_filename)

    # The first three chain columns are sampled in log space.
    Sigma, hz, var = np.exp(sampler.flatchain[:, :3]).T
    fig = corner.corner(np.vstack((Sigma, hz, np.sqrt(var))).T, bins=128,
                        range=[(16, 256), (32, 512), (0.04, 0.5)],
                        labels=[r'$\Sigma$', r'$h_z$', r'$\sigma$'])
    fig.savefig('corner-unmarginalized-{0}.png'.format(elem), dpi=250)

    # ---------------

    # Posterior-mean polynomial coefficients and posterior-mean potential.
    alpha = np.mean(sampler.flatchain[:, 3:], axis=0)
    pot0 = Sech2Potential(Sigma=np.mean(Sigma)*u.Msun/u.pc**2,
                          hz=np.mean(hz)*u.pc, units=usys)

    # Vertical energy of the stars that entered the fit. Earlier versions used
    # the full, unmasked sample here, so the centring of ln(Ez) and the grid
    # range came from a different set of stars than the one the model saw.
    # NOTE(review): assumes Model centres ln(Ez) on its own input sample, and
    # that the /1000 rescaling matches what Model does -- confirm.
    Ez = (0.5*galcen_fit.v_z**2 + pot0.energy(galcen_fit.z[None])) / 1000.

    # NOTE(review): np.poly1d expects the highest-order coefficient first;
    # confirm this matches the coefficient ordering used by Model.
    mu_func = np.poly1d(alpha)

    fig, ax = plt.subplots(1)

    x = np.log(Ez.value)
    x = x - np.mean(x)
    ax.plot(x, abun_fit,
            marker='o', ls='none', color='k',
            alpha=0.25, ms=1, mew=0)

    grid = np.linspace(x.min(), x.max(), 128)
    ax.plot(grid, mu_func(grid), marker='')

    ax.set_xlabel(r'$\ln E_z-\rm{mean}(\ln E_z)$')
    ax.set_ylabel(get_label_from_abundancename(elem))
    ax.set_xlim(-4, 3)
    ax.set_ylim(-2, 2)
    fig.tight_layout()
    fig.savefig('elem-Ez-unmarginalized-{0}.png'.format(elem), dpi=250)

    # Figures are only written to disk; close them so they do not pile up in
    # memory across elements.
    plt.close('all')

fe_h


  0%|          | 0/1024 [00:00<?, ?it/s]

alpha_fe
burn-in


100%|██████████| 1024/1024 [00:49<00:00, 20.61it/s]
  0%|          | 2/1024 [00:00<00:59, 17.15it/s]

sampling


100%|██████████| 1024/1024 [00:48<00:00, 21.16it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

al_h
burn-in


100%|██████████| 1024/1024 [00:52<00:00, 19.65it/s]
  0%|          | 2/1024 [00:00<00:57, 17.91it/s]

sampling


100%|██████████| 1024/1024 [00:49<00:00, 20.59it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

ba_h
burn-in


100%|██████████| 1024/1024 [00:47<00:00, 21.94it/s]
  0%|          | 2/1024 [00:00<00:57, 17.66it/s]

sampling


100%|██████████| 1024/1024 [00:47<00:00, 22.15it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

c_h
burn-in


100%|██████████| 1024/1024 [00:54<00:00, 18.74it/s]
  0%|          | 2/1024 [00:00<01:02, 16.48it/s]

sampling


100%|██████████| 1024/1024 [00:54<00:00, 18.89it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

ca_h
burn-in


100%|██████████| 1024/1024 [00:58<00:00, 17.58it/s]
  0%|          | 2/1024 [00:00<00:56, 18.21it/s]

sampling


100%|██████████| 1024/1024 [00:49<00:00, 20.59it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

co_h
burn-in


100%|██████████| 1024/1024 [00:51<00:00, 19.29it/s]
  0%|          | 2/1024 [00:00<01:03, 16.16it/s]

sampling


100%|██████████| 1024/1024 [00:50<00:00, 20.29it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

cr_h
burn-in


100%|██████████| 1024/1024 [00:50<00:00, 20.47it/s]
  0%|          | 2/1024 [00:00<01:09, 14.80it/s]

sampling


100%|██████████| 1024/1024 [00:51<00:00, 20.06it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

cu_h
burn-in


100%|██████████| 1024/1024 [00:52<00:00, 19.51it/s]
  0%|          | 2/1024 [00:00<00:59, 17.21it/s]

sampling


100%|██████████| 1024/1024 [00:53<00:00, 19.23it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

eu_h
burn-in


100%|██████████| 1024/1024 [00:52<00:00, 19.59it/s]
  0%|          | 2/1024 [00:00<00:58, 17.43it/s]

sampling


100%|██████████| 1024/1024 [00:49<00:00, 21.42it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

k_h
burn-in


100%|██████████| 1024/1024 [00:48<00:00, 20.96it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sampling


100%|██████████| 1024/1024 [00:47<00:00, 21.47it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

la_h
burn-in


100%|██████████| 1024/1024 [00:52<00:00, 19.61it/s]
  0%|          | 2/1024 [00:00<00:54, 18.86it/s]

sampling


100%|██████████| 1024/1024 [00:55<00:00, 18.54it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

li_h
burn-in


100%|██████████| 1024/1024 [00:50<00:00, 20.12it/s]
  0%|          | 2/1024 [00:00<00:51, 19.94it/s]

sampling


100%|██████████| 1024/1024 [00:47<00:00, 21.34it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

mg_h
burn-in


100%|██████████| 1024/1024 [00:45<00:00, 22.45it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sampling


100%|██████████| 1024/1024 [00:44<00:00, 22.47it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

mn_h
burn-in


100%|██████████| 1024/1024 [01:03<00:00, 16.01it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sampling


100%|██████████| 1024/1024 [00:58<00:00, 16.97it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

na_h
burn-in


100%|██████████| 1024/1024 [00:55<00:00, 19.23it/s]
  0%|          | 2/1024 [00:00<00:55, 18.48it/s]

sampling


100%|██████████| 1024/1024 [00:58<00:00, 16.69it/s]


ni_h
burn-in


100%|██████████| 1024/1024 [01:06<00:00, 17.18it/s]
  0%|          | 2/1024 [00:00<01:03, 16.22it/s]

sampling


100%|██████████| 1024/1024 [01:05<00:00, 16.21it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

o_h
burn-in


100%|██████████| 1024/1024 [00:53<00:00, 20.43it/s]
  0%|          | 3/1024 [00:00<00:44, 22.79it/s]

sampling


100%|██████████| 1024/1024 [00:47<00:00, 22.74it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sc_h
burn-in


100%|██████████| 1024/1024 [00:52<00:00, 19.47it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sampling


100%|██████████| 1024/1024 [00:52<00:00, 19.12it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

si_h
burn-in


100%|██████████| 1024/1024 [00:53<00:00, 19.26it/s]
  0%|          | 2/1024 [00:00<00:53, 19.06it/s]

sampling


100%|██████████| 1024/1024 [00:55<00:00, 18.57it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

ti_h
burn-in


100%|██████████| 1024/1024 [00:54<00:00, 18.96it/s]
  0%|          | 2/1024 [00:00<00:53, 19.23it/s]

sampling


100%|██████████| 1024/1024 [00:54<00:00, 18.61it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

v_h
burn-in


100%|██████████| 1024/1024 [00:50<00:00, 20.14it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sampling


100%|██████████| 1024/1024 [00:50<00:00, 20.32it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

y_h
burn-in


100%|██████████| 1024/1024 [00:55<00:00, 18.56it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

sampling


100%|██████████| 1024/1024 [00:59<00:00, 18.61it/s]
  0%|          | 0/1024 [00:00<?, ?it/s]

zn_h
burn-in


100%|██████████| 1024/1024 [00:54<00:00, 16.80it/s]
  0%|          | 2/1024 [00:00<00:56, 18.03it/s]

sampling


100%|██████████| 1024/1024 [00:55<00:00, 18.51it/s]


In [44]:
# Disabled diagnostic: one stacked panel per sampled parameter, drawing every
# walker's trace from `sampler.chain`. It depends on `sampler` and `ndim`
# being defined by earlier cells; uncomment to inspect the chains.
# fig, axes = plt.subplots(ndim, 1, figsize=(8, 4*ndim),
#                          sharex=True)
# for k in range(ndim):
#     ax = axes[k]
#     for walker in sampler.chain[..., k]:
#         ax.plot(walker, marker='', drawstyle='steps-mid', 
#                 color='k', alpha=0.2)

In [27]:
# Posterior means of the flattened chain. `sampler` is not created in this
# cell: it is whatever object an earlier cell left in the namespace, so the
# numbers shown depend on which cell ran last.
_chain_mean = np.mean(sampler.flatchain, axis=0)
lnsigma, lnhz, lnvar = _chain_mean[:3]
alpha = list(_chain_mean[3:])

# Undo the logs; the third parameter is a log-variance, so take the square
# root to display it as a dispersion.
np.exp(lnsigma), np.exp(lnhz), np.sqrt(np.exp(lnvar))

(56.898550995024742, 89.578923683913203, 0.24710017182021057)