In [None]:
# ------------------------------------------------------------------------
#
# TITLE - fit_stellar_bulge_disk.ipynb
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Fit the stellar bulge and disk at z=0
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, time, dill as pickle
from tqdm.notebook import tqdm

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Fitting
import emcee
import corner
import multiprocessing
import scipy.optimize

## galpy
from galpy import orbit
from galpy import potential

## Project-specific
sys.path.insert(0,'../../src/')
from tng_dfs import fitting as pfit
from tng_dfs import densprofile as pdens
from tng_dfs import util as putil
from tng_dfs import cutout as pcutout

### Scale parameters

# Default scale parameters. NOTE(review): ro, vo and zo are re-assigned from
# the project config (keywords 'RO','VO','ZO') in the setup cell that follows,
# so these literal values are overwritten before they are used.
ro = 8. # presumably kpc (galpy distance scale) - TODO confirm
vo = 220 # presumably km/s (galpy velocity scale) - TODO confirm
zo = 0.0208 # Bennett+ 2019

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Read the project-wide settings from the config file
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','MW_ANALOG_DIR','RO','VO','ZO','LITTLE_H',
            'MW_MASS_RANGE']
(data_dir, mw_analog_dir, ro, vo, zo, h,
    mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# Milky Way analog sample (subhalo selection + associated variables)
mwsubs, mwsubs_vars = putil.prepare_mwsubs(mw_analog_dir, h=h,
    mw_mass_range=mw_mass_range, return_vars=True, force_mwsubs=False,
    bulge_disk_fraction_cuts=True)

# Output locations for figures and fitting results, created if missing.
# NOTE(review): these are absolute, machine-specific paths.
# fig_dir = './fig/stellar_bulge_disk'
epsen_fig_dir = '/epsen_data/scr/lane/projects/tng-dfs/figs/notebooks/'+\
    'fit_density_profiles/stellar_bulge_disk/'
epsen_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'density_profile/stellar_bulge_disk/'
# os.makedirs(fig_dir,exist_ok=True)
for _out_dir in (epsen_fig_dir, epsen_fitting_dir):
    os.makedirs(_out_dir, exist_ok=True)
show_plots = False

def _load_pickle(filename):
    '''Load and return the single object stored in a (dill) pickle file.'''
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

# Merger-tree products: one primary (main progenitor branch) per analog
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_primaries = _load_pickle(tree_primary_filename)
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
tree_major_mergers = _load_pickle(tree_major_mergers_filename)
n_mw = len(tree_primaries)

### Functions

In [None]:
def plot_RZ_density(chain, R, z, masses, densfunc, other_params):
    '''Compare the binned R-z density of the data with two model densities.

    Three stacked panels share one grid and one colour scale:
    (1) the mass-weighted data histogram divided by each cell's volume,
    (2) ``densfunc`` evaluated at the column-wise median of ``chain``,
    (3) ``densfunc`` evaluated at ``other_params``.

    Args:
        chain (np.ndarray): Samples, one per row; the column-wise median is
            used for panel 2.
        R (np.ndarray): Cylindrical radii of the particles [kpc].
        z (np.ndarray): Heights of the particles [kpc].
        masses (np.ndarray): Particle masses (histogram weights).
        densfunc (callable): Called as ``densfunc(R, phi, z, params=...)``;
            expected to return one density per (R, z) pair.
        other_params (array-like): Parameters for panel 3 (e.g. an
            optimizer solution).

    Returns:
        fig (matplotlib.figure.Figure): The figure.
        axs (list): [data axis, median-chain axis, other_params axis].
    '''
    fig = plt.figure(figsize=(6,15))
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    x_range = [0,25]
    y_range = [-5,5]
    nbins = 20
    vmin, vmax = 4,10
    fontsize=15

    def _finish_panel(ax, img, label=None):
        '''Apply the labels, limits, colorbar and optional annotation
        shared by all three panels.'''
        ax.set_xlabel(r'$R$ [kpc]')
        ax.set_ylabel(r'$z$ [kpc]')
        ax.set_xlim(*x_range)
        ax.set_ylim(*y_range)
        ax.set_aspect('auto')
        cbar = fig.colorbar(img, ax=ax, fraction=0.03, pad=0.03)
        cbar.set_label(r'$\log_{10}(\rho)$')
        if label is not None:
            ax.annotate(label, xy=(0.3,0.95), xycoords='axes fraction',
                ha='left', va='top', fontsize=fontsize)

    # Mass-weighted histogram; H[i,j] is the mass in R bin i and z bin j
    H, xedges, yedges = np.histogram2d(R, z, weights=masses, bins=nbins,
        range=[x_range,y_range])
    _Rs = 0.5*(xedges[1:] + xedges[:-1])
    _zs = 0.5*(yedges[1:] + yedges[:-1])
    # Default ('xy') meshgrid indexing gives arrays ordered (z, R), i.e.
    # Rs[j,i] = _Rs[i]. That row=z layout is what imshow needs for the
    # model panels below.
    Rs, zs = np.meshgrid(_Rs, _zs)
    # Volume of each annular cell, 2 pi R dR dz, ordered (z, R) like Rs
    vol = 2*np.pi * Rs * (xedges[1]-xedges[0]) * (yedges[1]-yedges[0])
    # H is ordered (R, z), so vol must be transposed. Because both arrays are
    # square, dividing by vol directly would broadcast silently and give
    # each cell the volume of the wrong annulus.
    H = H / vol.T

    img1 = ax1.imshow(np.log10(H.T), extent=[*x_range,*y_range],
        origin='lower', cmap='Greys', vmin=vmin, vmax=vmax)
    _finish_panel(ax1, img1, r'Data')

    def _model_density(params):
        '''Evaluate densfunc on the (z, R)-ordered grid of bin centres.'''
        dens = densfunc(Rs.flatten(), 0, zs.flatten(), params=params)
        return dens.reshape((nbins,nbins))

    img2 = ax2.imshow(np.log10(_model_density(np.median(chain, axis=0))),
        extent=[*x_range,*y_range], origin='lower', cmap='Greys',
        vmin=vmin, vmax=vmax)
    _finish_panel(ax2, img2, r'Median Chain')

    img3 = ax3.imshow(np.log10(_model_density(other_params)),
        extent=[*x_range,*y_range], origin='lower', cmap='Greys',
        vmin=vmin, vmax=vmax)
    _finish_panel(ax3, img3)

    return fig, [ax1,ax2,ax3]

def plot_density_profiles(chain, R, z, masses, densfunc, Rmin_fit, 
    Rmax_fit, zmax_fit, other_params):
    '''Plot radial (midplane) and vertical density profiles of data and model.

    Top panel: mass density of particles with |z| < 0.1 kpc in 50
    log-spaced radial bins between 0.1 and 30 kpc, overplotted with model
    curves for up to 100 random chain samples (red) and for
    ``other_params`` (blue). Dashed lines mark ``Rmin_fit``/``Rmax_fit``.

    Bottom panel: density versus |z| (0-5 kpc, 10 bins) in 0.5 kpc-wide
    annuli centred at R = 1, 8 and 16 kpc, with up to 50 random chain
    samples (red) and ``other_params`` (blue). A dashed line marks
    ``zmax_fit``.

    Args:
        chain (np.ndarray): Samples, one per row.
        R (np.ndarray): Cylindrical radii of the particles [kpc].
        z (np.ndarray): Heights of the particles [kpc].
        masses (np.ndarray): Particle masses.
        densfunc (callable): Called as ``densfunc(R, phi, z, params)``.
        Rmin_fit, Rmax_fit (float): Radial fitting range [kpc]. Because the
            axis is log10(R), a limit is only drawn when it is positive and
            finite.
        zmax_fit (float): Vertical fitting limit [kpc].
        other_params (array-like): Parameters for the blue curves.

    Returns:
        fig (matplotlib.figure.Figure): The figure.
        axs (list): [radial-profile axis, vertical-profile axis].
    '''
    fig = plt.figure(figsize=(5,10))
    ax1 = fig.add_subplot(211)
    ax2 = fig.add_subplot(212)

    # Compute the density near the midplane
    zmax = 0.1
    nbin = 50
    log_Redges = np.linspace(np.log10(0.1), np.log10(30), nbin+1)
    log_Rcents = 0.5*(log_Redges[1:] + log_Redges[:-1])
    Redges = 10**log_Redges
    Rcents = 10**log_Rcents
    in_midplane = np.abs(z) < zmax
    mpdens = np.zeros(nbin)
    for i in range(nbin):
        vol = 2*np.pi*Rcents[i] * (Redges[i+1]-Redges[i]) * (2*zmax)
        mask = (R >= Redges[i]) & (R < Redges[i+1]) & in_midplane
        mpdens[i] = np.sum(masses[mask]) / vol

    ax1.plot(np.log10(Rcents), np.log10(mpdens), color='Grey', linewidth=4., 
        zorder=2)
    # Never request more distinct samples than the chain holds
    nt = min(100, len(chain))
    indx = np.random.choice(len(chain), size=nt, replace=False)
    for k in range(nt):
        ax1.plot(np.log10(Rcents), np.log10(densfunc(Rcents, 0, 0, chain[indx[k]])), 
            color='Red', alpha=0.1, linewidth=1., zorder=3, 
            linestyle='solid')
    ax1.plot(np.log10(Rcents), np.log10(densfunc(Rcents, 0, 0, other_params)),
        color='Blue', linewidth=2., zorder=4, linestyle='solid')
    ax1.set_xlabel(r'$\log_{10}(R/\mathrm{kpc})$')
    ax1.set_ylabel(r'$\log_{10}\rho$')
    for R_limit in (Rmin_fit, Rmax_fit):
        # log10 is undefined for R <= 0 and infinite for R = inf
        if 0 < R_limit < np.inf:
            ax1.axvline(np.log10(R_limit), color='Black', linestyle='dashed',
                linewidth=2.)
    ax1.annotate(r'Data ($\vert z \vert < $'+str(zmax)+' kpc)', xy=(0.3,0.95),
        xycoords='axes fraction', ha='left', va='top')

    # Compute the density moving away from the midplane
    Rcents = np.array([1., 8., 16.])
    Rbinsize = 0.5
    nbin = 10
    z_edges = np.linspace(0, 5, nbin+1)
    z_cents = 0.5*(z_edges[1:] + z_edges[:-1])
    dens = np.zeros((len(Rcents),nbin))
    for i in range(len(Rcents)):
        for j in range(len(z_cents)):
            # Factor 2 because both +z and -z slabs are counted
            vol = 2*np.pi*Rcents[i]*Rbinsize * 2*(z_edges[j+1]-z_edges[j])
            mask = (R >= Rcents[i]-Rbinsize/2) & (R < Rcents[i]+Rbinsize/2) & \
                (np.abs(z) >= z_edges[j]) & (np.abs(z) < z_edges[j+1])
            dens[i,j] = np.sum(masses[mask]) / vol

    nt = min(50, len(chain))
    indx = np.random.choice(len(chain), size=nt, replace=False)
    for i in range(len(Rcents)):
        _Rcents = np.ones_like(z_cents) * Rcents[i]
        ax2.plot(z_cents, np.log10(dens[i,:]), color='Grey', linewidth=4., 
            zorder=2)
        for k in range(nt):
            ax2.plot(z_cents, np.log10(densfunc(_Rcents, 0, z_cents, 
                chain[indx[k]])), color='Red', alpha=0.1, linewidth=1., zorder=3, 
                linestyle='solid')
        # Use the argument, not a global optimizer result
        ax2.plot(z_cents, np.log10(densfunc(_Rcents, 0, z_cents, other_params)), 
            color='Blue', linewidth=2., zorder=4, linestyle='solid')

    ax2.set_xlabel(r'$\vert z \vert$ [kpc]')
    ax2.set_ylabel(r'$\log_{10}\rho$')
    ax2.set_xlim(0,5)
    ax2.axvline(zmax_fit, color='Black', linestyle='dashed', linewidth=2.)
    # Raw string for the LaTeX part avoids the invalid '\D' escape sequence
    ax2.annotate(r'$R_{cents} = '+str(Rcents)+'$\n'+r' $\Delta R = '
        +str(Rbinsize)+'$ [kpc]',
        xy=(0.5,0.95), fontsize=15, xycoords='axes fraction', ha='left', va='top')
    
    return fig, [ax1,ax2]

def plot_projected_density(orbs, masses, chain, comp_densfunc, other_params):
    '''Compare face-on and edge-on projected densities of data and model.

    The 2x3 grid has columns [data, median chain, other_params] and rows
    [face-on (x-y), edge-on (x-z)]. Model densities are evaluated on a
    50^3 Cartesian grid spanning [-25, 25] kpc in each coordinate and summed
    along the line of sight.

    Args:
        orbs: Orbit object providing ``x()``, ``y()``, ``z()`` as astropy
            quantities.
        masses (np.ndarray): Particle masses (histogram weights).
        chain (np.ndarray): Samples, one per row; the column-wise median is
            used for the middle column.
        comp_densfunc (callable): Called as
            ``comp_densfunc(R, phi, z, params=...)``.
        other_params (array-like): Parameters for the right column.

    Returns:
        fig (matplotlib.figure.Figure): The figure.
        axs (np.ndarray): The six axes in column-major order:
            [data x-y, data x-z, median x-y, median x-z, other x-y, other x-z].
    '''
    fig = plt.figure(figsize=(15,10))
    # .T.flatten() orders the axes column by column (top row = face-on)
    axs = fig.subplots(2,3).T.flatten()

    x = orbs.x().to(apu.kpc).value
    y = orbs.y().to(apu.kpc).value
    z = orbs.z().to(apu.kpc).value

    drange = [-25,25]
    nbins = 50
    vmin, vmax = 4,10
    imshow_kwargs = dict(origin='lower', cmap=plt.cm.gray_r, vmin=vmin,
        vmax=vmax)

    # Data: mass per unit projected area, face-on then edge-on.
    # histogram2d returns (first coord, second coord) ordering, hence .T.
    N, xedges, yedges = np.histogram2d(x, y, bins=nbins, weights=masses,
        range=[drange,drange])
    N /= (xedges[1]-xedges[0]) * (yedges[1]-yedges[0])
    axs[0].imshow(np.log10(N.T),
        extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], **imshow_kwargs)

    N, xedges, zedges = np.histogram2d(x, z, bins=nbins, weights=masses,
        range=[drange,drange])
    N /= (xedges[1]-xedges[0]) * (zedges[1]-zedges[0])
    axs[1].imshow(np.log10(N.T),
        extent=[xedges[0], xedges[-1], zedges[0], zedges[-1]], **imshow_kwargs)

    # Model grid of cell centres (same range in x, y and z). 'ij' indexing
    # keeps the array axes ordered (x, y, z), so summing over axis 2 gives an
    # (x, y) map and summing over axis 1 gives an (x, z) map.
    edges = np.linspace(*drange, nbins+1)
    cents = 0.5*(edges[1:] + edges[:-1])
    width = edges[1]-edges[0]
    xs, ys, zs = np.meshgrid(cents, cents, cents, indexing='ij')
    Rs = np.sqrt(xs**2 + ys**2)
    phis = np.arctan2(ys, xs)
    model_extent = [edges[0], edges[-1], edges[0], edges[-1]]

    def _projected_model(params):
        '''Return the (x, y) and (x, z) line-of-sight projections.'''
        dens = comp_densfunc(Rs.flatten(), phis.flatten(), zs.flatten(),
            params=params).reshape((nbins,nbins,nbins))
        return np.sum(dens, axis=2)*width, np.sum(dens, axis=1)*width

    # Middle column: median chain; right column: other_params
    for col, params in ((1, np.median(chain, axis=0)), (2, other_params)):
        xydens, xzdens = _projected_model(params)
        axs[2*col].imshow(np.log10(xydens).T, extent=model_extent,
            **imshow_kwargs)
        axs[2*col+1].imshow(np.log10(xzdens).T, extent=model_extent,
            **imshow_kwargs)

    for k, ax in enumerate(axs):
        # Even indices are face-on (top row), odd indices edge-on (bottom row)
        ax.set_xlabel('X [kpc]')
        ax.set_ylabel('Y [kpc]' if k % 2 == 0 else 'Z [kpc]')

    axs[0].annotate(r'Data', xy=(0.05,0.95), xycoords='axes fraction', 
        ha='left', va='top')
    axs[2].annotate(r'Median Chain', xy=(0.05,0.95), xycoords='axes fraction',
        ha='left', va='top')

    fig.tight_layout()

    return fig, axs

### Fit the data with the Poisson likelihood and bulge/disk model

In [None]:
def usr_log_prior_ded_spcs(densfunc, params, rmax):
    '''Hard-bound flat prior for a disk + bulge composite density profile.

    The composite must be [DoubleExponentialDisk, SinglePowerCutoffSpherical],
    with ``params[:3]`` belonging to the disk and ``params[3:]`` to the bulge.

    Args:
        densfunc (pdens.CompositeDensityProfile): The composite profile.
        params (array-like): Concatenated disk and bulge parameters.
        rmax (float): Upper bound on the bulge cutoff radius and on the disk
            scale length and scale height.

    Returns:
        float: -np.inf if the bulge slope alpha is >= 3 or any of rc, hR, hz
        exceeds rmax; otherwise 0.
    '''
    assert isinstance(densfunc, pdens.CompositeDensityProfile)
    disk = densfunc.densprofiles[0]
    assert isinstance(disk, pdens.DoubleExponentialDisk)
    bulge = densfunc.densprofiles[1]
    assert isinstance(bulge, pdens.SinglePowerCutoffSpherical)

    # Amplitudes are not constrained here
    hR, hz, _ = disk._parse_params(params[:3])
    alpha, rc, _ = bulge._parse_params(params[3:])

    # Short-circuits in the same order as sequential checks would
    out_of_bounds = (alpha >= 3. or rc > rmax or hR > rmax or hz > rmax)
    return -np.inf if out_of_bounds else 0.

# Composite model: double-exponential disk + single power law with cutoff bulge
disk_densfunc = pdens.DoubleExponentialDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
comp_densfunc = pdens.CompositeDensityProfile([disk_densfunc, bulge_densfunc])

# Run options
verbose, show_plots, apply_fitting_mask = True, False, False

# emcee settings: walkers, total steps, burn-in steps discarded, pool size
nwalkers, nit, ncut, nprocs = 50, 4000, 2000, 15

In [None]:
version_prefix = 'Enorm_JzJcirc_cut_dexp_disk_pswc_bulge'

# For each MW analog: remove stars kinematically classified as halo, fit the
# remaining stars with the disk + bulge model (Nelder-Mead for a starting
# point, then emcee), save the results and make diagnostic figures.
for i in range(n_mw):
    # if i > 2: continue
    
    if verbose: print(f'Fitting stellar distribution of analog {i+1}/{n_mw}')

    # Get the primary and its z=0 cutout
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    print('z=0 SID: ', z0_sid)
    # n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    # Kinematic decomposition in normalized energy and circularity Jz/Jcirc
    vels = co.get_velocities('PartType4', physical=True)
    pot = co.get_potential_energy('PartType4', physical=True)
    kin = 0.5*np.sum(np.square(vels),axis=1)
    E = (pot+kin)
    J,Jz,Jp = co.get_J_Jz_Jp('PartType4',physical=True)
    co.get_E_Jcirc_spline('PartType4',angmom='J')
    Jcirc = co.Jcirc(E)
    Enorm = E/np.abs(E).max()
    Jz_Jcirc = Jz / Jcirc

    Jz_Jcirc_spheroid_bound = 0.5
    Jz_Jcirc_disk_bound = 0.8
    Enorm_bulge_bound = -0.75

    bulge_mask = (Jz_Jcirc < Jz_Jcirc_spheroid_bound) &\
                (Enorm < Enorm_bulge_bound)
    halo_mask = (Jz_Jcirc < Jz_Jcirc_spheroid_bound) &\
                (Enorm > Enorm_bulge_bound)
    thin_mask = (Jz_Jcirc > Jz_Jcirc_disk_bound)
    thick_mask = (Jz_Jcirc < Jz_Jcirc_disk_bound) &\
                (Jz_Jcirc > Jz_Jcirc_spheroid_bound)

    # Keep everything except the kinematic halo
    _orbs = co.get_orbs('stars')
    orbs = _orbs[~halo_mask]
    _masses = co.get_masses('stars').to(apu.Msun).value
    masses = _masses[~halo_mask]
    r, R, phi, z = orbs.r().to(apu.kpc).value, orbs.R().to(apu.kpc).value, \
        orbs.phi().value, orbs.z().to(apu.kpc).value
    npart = len(r)

    # Randomly select (at most) 1e4 stars and up-weight their masses so the
    # subsample carries the same total mass. Capping avoids a ValueError from
    # sampling without replacement when there are fewer stars.
    npts = int(min(1e4, len(masses)))
    inds = np.random.choice(len(masses), size=npts, replace=False)
    fac = len(masses)/npts
    r = r[inds]
    R = R[inds]
    phi = phi[inds]
    z = z[inds]
    orbs = orbs[inds]
    masses = masses[inds]
    masses *= fac # Increase the mass to compensate for the decrease in points

    # Effective volume and data parameters
    rmin = np.min(r)
    rmax = np.max(r)
    zmax = np.max(np.abs(z))
    effvol_params = [0, np.inf, np.inf] # [Rmin, Rmax, zmax]
    usr_log_prior_params = [rmax,]

    # Remove data outside the fitting range if requested. The mask uses
    # cylindrical R to match the [Rmin, Rmax, zmax] effective-volume bounds.
    Rmin_fit = 0.3
    Rmax_fit = 15.0
    zmax_fit = 5.0
    fitting_mask = (R >= Rmin_fit) & (R < Rmax_fit) & (np.abs(z) < zmax_fit)
    if apply_fitting_mask:
        r_fit = r[fitting_mask]
        R_fit = R[fitting_mask]
        phi_fit = phi[fitting_mask]
        z_fit = z[fitting_mask]
        masses_fit = masses[fitting_mask]
        effvol_params = [Rmin_fit, Rmax_fit, zmax_fit]
        usr_log_prior_params = [Rmax_fit,]
    else:
        r_fit = copy.deepcopy(r)
        R_fit = copy.deepcopy(R)
        phi_fit = copy.deepcopy(phi)
        z_fit = copy.deepcopy(z)
        masses_fit = copy.deepcopy(masses)
    
    # Initial parameters: [hR, hz, amp_disk, alpha, rc, amp_bulge]
    _hR = 4*apu.kpc
    _hz = 1*apu.kpc
    _alpha = 1.
    _rc = 1*apu.kpc
    _amp_bulge = 1e9
    _amp_disk = 1e9
    _init = [_hR.value, _hz.value, _amp_disk, _alpha, _rc.value, _amp_bulge]

    # Optimize to get a starting point. The *_fit arrays are used so that
    # apply_fitting_mask actually restricts the data being fit.
    print('Optimizing...')
    opt_fn = lambda params: pfit.mloglike_dens(params, comp_densfunc, R_fit, 
        phi_fit, z_fit, mass=masses_fit, usr_log_prior=usr_log_prior_ded_spcs, 
        usr_log_prior_params=usr_log_prior_params, 
        effvol_params=effvol_params)
    opt = scipy.optimize.minimize(opt_fn, _init, method='Nelder-Mead',
        options={'maxiter':3000,})
    print(opt)

    # MCMC
    # Make the likelihood function
    def llfunc(params):
        return pfit.loglike_dens(params, comp_densfunc, R_fit, phi_fit, z_fit,
            mass=masses_fit, effvol_params=effvol_params, 
            usr_log_prior=usr_log_prior_ded_spcs, 
            usr_log_prior_params=usr_log_prior_params)
    mcmc_init = np.array([
        opt.x+0.1*np.random.randn(len(opt.x)) for i in range(nwalkers)
        ])
    with multiprocessing.Pool(processes=nprocs) as pool:
        sampler = emcee.EnsembleSampler(nwalkers, len(mcmc_init[0]), llfunc, 
            pool=pool)
        sampler.run_mcmc(mcmc_init, nit, progress=True)
    _chain = sampler.get_chain(flat=True, discard=ncut)
    # Copy of the chain with amplitudes in log10 for storage / corner plots
    chain = copy.deepcopy(_chain)
    chain[:,2] = np.log10(chain[:,2])
    chain[:,5] = np.log10(chain[:,5])

    # Total model mass for every sample (effective volume over all space)
    ms = np.zeros(len(chain))
    for k in range(len(chain)):
        ms[k] = comp_densfunc.effective_volume(_chain[k,:], *[0, np.inf, np.inf])
    chain = np.append(chain, np.log10(ms[:,None]), axis=1)
    mtot_bd = np.sum(masses)
    mtot_all = np.sum(_masses)

    # Save the results
    this_fitting_dir = os.path.join(epsen_fitting_dir,version_prefix,str(z0_sid))
    os.makedirs(this_fitting_dir,exist_ok=True)
    opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
    with open(opt_filename,'wb') as handle:
        pickle.dump(opt,handle)
    sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
    with open(sampler_filename,'wb') as handle:
        pickle.dump(sampler,handle)
    chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
    with open(chain_filename,'wb') as handle:
        pickle.dump([chain,['hR','hz','log10(amp_disk)','alpha','rc','log10(amp)','mtot']],handle)

    # Plot the results
    this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)

    # Density profile
    fig, axes = plot_density_profiles(_chain, R, z, masses, comp_densfunc,
        Rmin_fit, Rmax_fit, zmax_fit, opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_density_profile.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # R-z density profiles
    fig, axes = plot_RZ_density(_chain, R, z, masses, comp_densfunc, opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_Rz_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Projected densities
    fig, axes = plot_projected_density(orbs, masses, _chain, comp_densfunc, 
        opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_projected_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Corner plot: solid red = bulge+disk mass, dashed red = all stellar mass
    fig = corner.corner(chain,
        labels=[r'$h_R$',r'$h_z$',r'$\log_{10}$(disk amp)',r'$\alpha$',
                r'$r_c$',r'$\log_{10}$(bulge amp)',r'$\log_{10}$(Mass)'],
        truths = [None,]*6+[np.log10(mtot_bd),],
        truth_color='Red', quantiles=[0.16,0.5,0.84])
    corner.overplot_lines(fig, [None,]*6+[np.log10(mtot_all),], 
        color='Red', linestyle='dashed')
    corner.overplot_lines(fig, np.median(chain, axis=0), color='Black')
    figname = os.path.join(this_fig_dir,version_prefix+'_corner.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

In [None]:
version_prefix = 'dexp_disk_pswc_bulge'

# For each MW analog: fit all stars with the disk + bulge model (Nelder-Mead
# for a starting point, then emcee), save results and make diagnostic figures.
for i in range(n_mw):
    # NOTE(review): debug limit - only the first 3 analogs are fit
    if i > 2: continue
    
    if verbose: print(f'Fitting stellar distribution of analog {i+1}/{n_mw}')

    # Get the primary and its z=0 cutout
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    print('z=0 SID: ', z0_sid)
    # n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    orbs = co.get_orbs('stars')
    # Keep the full mass array of THIS analog for the total-mass reference
    # (previously _masses was a stale leftover from an earlier cell)
    _masses = co.get_masses('stars').to(apu.Msun).value
    masses = _masses.copy()
    r, R, phi, z = orbs.r().to(apu.kpc).value, orbs.R().to(apu.kpc).value, \
        orbs.phi().value, orbs.z().to(apu.kpc).value
    npart = len(r)

    # Randomly select (at most) 1e4 stars and up-weight their masses so the
    # subsample carries the same total mass
    npts = int(min(1e4, len(masses)))
    inds = np.random.choice(len(masses), size=npts, replace=False)
    fac = len(masses)/npts
    r = r[inds]
    R = R[inds]
    phi = phi[inds]
    z = z[inds]
    orbs = orbs[inds]
    masses = masses[inds]
    masses *= fac # Increase the mass to compensate for the decrease in points

    # Effective volume and data parameters
    rmin = np.min(r)
    rmax = np.max(r)
    zmax = np.max(np.abs(z))
    effvol_params = [0, np.inf, np.inf] # [Rmin, Rmax, zmax]
    usr_log_prior_params = [rmax,]

    # Remove data outside the fitting range if requested. The mask uses
    # cylindrical R to match the [Rmin, Rmax, zmax] effective-volume bounds.
    Rmin_fit = 0.3
    Rmax_fit = 15.0
    zmax_fit = 5.0
    fitting_mask = (R >= Rmin_fit) & (R < Rmax_fit) & (np.abs(z) < zmax_fit)
    if apply_fitting_mask:
        r_fit = r[fitting_mask]
        R_fit = R[fitting_mask]
        phi_fit = phi[fitting_mask]
        z_fit = z[fitting_mask]
        masses_fit = masses[fitting_mask]
        effvol_params = [Rmin_fit, Rmax_fit, zmax_fit]
        usr_log_prior_params = [Rmax_fit,]
    else:
        r_fit = copy.deepcopy(r)
        R_fit = copy.deepcopy(R)
        phi_fit = copy.deepcopy(phi)
        z_fit = copy.deepcopy(z)
        masses_fit = copy.deepcopy(masses)
    
    # Initial parameters: [hR, hz, amp_disk, alpha, rc, amp_bulge]
    _hR = 4*apu.kpc
    _hz = 1*apu.kpc
    _alpha = 1.
    _rc = 1*apu.kpc
    _amp_bulge = 1e9
    _amp_disk = 1e9
    _init = [_hR.value, _hz.value, _amp_disk, _alpha, _rc.value, _amp_bulge]

    # Optimize to get a starting point. The *_fit arrays are used so that
    # apply_fitting_mask actually restricts the data being fit.
    print('Optimizing...')
    opt_fn = lambda params: pfit.mloglike_dens(params, comp_densfunc, R_fit, 
        phi_fit, z_fit, mass=masses_fit, usr_log_prior=usr_log_prior_ded_spcs, 
        usr_log_prior_params=usr_log_prior_params, 
        effvol_params=effvol_params)
    opt = scipy.optimize.minimize(opt_fn, _init, method='Nelder-Mead',
        options={'maxiter':3000,})
    print(opt)

    # MCMC
    # Make the likelihood function
    def llfunc(params):
        return pfit.loglike_dens(params, comp_densfunc, R_fit, phi_fit, z_fit,
            mass=masses_fit, effvol_params=effvol_params, 
            usr_log_prior=usr_log_prior_ded_spcs, 
            usr_log_prior_params=usr_log_prior_params)
    mcmc_init = np.array([
        opt.x+0.1*np.random.randn(len(opt.x)) for i in range(nwalkers)
        ])
    with multiprocessing.Pool(processes=nprocs) as pool:
        sampler = emcee.EnsembleSampler(nwalkers, len(mcmc_init[0]), llfunc, 
            pool=pool)
        sampler.run_mcmc(mcmc_init, nit, progress=True)
    _chain = sampler.get_chain(flat=True, discard=ncut)
    # Copy of the chain with amplitudes in log10 for storage / corner plots
    chain = copy.deepcopy(_chain)
    chain[:,2] = np.log10(chain[:,2])
    chain[:,5] = np.log10(chain[:,5])

    # Total model mass for every sample (effective volume over all space)
    ms = np.zeros(len(chain))
    for k in range(len(chain)):
        ms[k] = comp_densfunc.effective_volume(_chain[k,:], *[0, np.inf, np.inf])
    chain = np.append(chain, np.log10(ms[:,None]), axis=1)
    mtot_bd = np.sum(masses)
    mtot_all = np.sum(_masses)

    # Save the results
    this_fitting_dir = os.path.join(epsen_fitting_dir,version_prefix,str(z0_sid))
    os.makedirs(this_fitting_dir,exist_ok=True)
    opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
    with open(opt_filename,'wb') as handle:
        pickle.dump(opt,handle)
    sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
    with open(sampler_filename,'wb') as handle:
        pickle.dump(sampler,handle)
    chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
    with open(chain_filename,'wb') as handle:
        pickle.dump([chain,['hR','hz','log10(amp_disk)','alpha','rc','log10(amp)','mtot']],handle)

    # Plot the results
    this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)

    # Density profile
    fig, axes = plot_density_profiles(_chain, R, z, masses, comp_densfunc,
        Rmin_fit, Rmax_fit, zmax_fit, opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_density_profile.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # R-z density profiles
    fig, axes = plot_RZ_density(_chain, R, z, masses, comp_densfunc, opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_Rz_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Projected densities
    fig, axes = plot_projected_density(orbs, masses, _chain, comp_densfunc, 
        opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_projected_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Corner plot: solid red = fitted-sample mass, dashed red = all stellar mass
    fig = corner.corner(chain,
        labels=[r'$h_R$',r'$h_z$',r'$\log_{10}$(disk amp)',r'$\alpha$',
                r'$r_c$',r'$\log_{10}$(bulge amp)',r'$\log_{10}$(Mass)'],
        truths = [None,]*6+[np.log10(mtot_bd),],
        truth_color='Red', quantiles=[0.16,0.5,0.84])
    corner.overplot_lines(fig, [None,]*6+[np.log10(mtot_all),], 
        color='Red', linestyle='dashed')
    corner.overplot_lines(fig, np.median(chain, axis=0), color='Black')
    figname = os.path.join(this_fig_dir,version_prefix+'_corner.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

### Fit the data with the Poisson likelihood and bulge/disk + halo model

In [None]:
def usr_log_prior_ded_spcs_tps(densfunc, params, rmax):
    '''Hard-bound flat prior for a disk + bulge + halo composite profile.

    The composite must be [DoubleExponentialDisk, SinglePowerCutoffSpherical,
    TwoPowerSpherical] with parameters split as ``params[:3]`` (disk),
    ``params[3:6]`` (bulge) and ``params[6:]`` (halo).

    Args:
        densfunc (pdens.CompositeDensityProfile): The composite profile.
        params (array-like): Concatenated disk, bulge and halo parameters.
        rmax (float): Upper bound on the disk hR and hz and the bulge rc.

    Returns:
        float: -np.inf if hR or hz exceeds rmax, the bulge alpha is >= 3,
        the bulge rc exceeds rmax, or the halo alpha is >= 3; otherwise 0.
    '''
    assert isinstance(densfunc, pdens.CompositeDensityProfile)
    disk = densfunc.densprofiles[0]
    assert isinstance(disk, pdens.DoubleExponentialDisk)
    bulge = densfunc.densprofiles[1]
    assert isinstance(bulge, pdens.SinglePowerCutoffSpherical)
    halo = densfunc.densprofiles[2]
    assert isinstance(halo, pdens.TwoPowerSpherical)

    # Only the constrained parameters are kept
    disk_hR, disk_hz, _ = disk._parse_params(params[:3])
    bulge_alpha, bulge_rc, _ = bulge._parse_params(params[3:6])
    halo_alpha, _, _, _ = halo._parse_params(params[6:])

    # Short-circuits in the same order as sequential checks would
    out_of_bounds = (disk_hR > rmax or disk_hz > rmax or bulge_alpha >= 3.
                     or bulge_rc > rmax or halo_alpha >= 3.)
    return -np.inf if out_of_bounds else 0.

# Composite model: double-exponential disk + power law with cutoff bulge
# + two-power spherical halo
disk_densfunc = pdens.DoubleExponentialDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
halo_densfunc = pdens.TwoPowerSpherical()
comp_densfunc = pdens.CompositeDensityProfile([disk_densfunc, bulge_densfunc, halo_densfunc])

# Run options
verbose, show_plots, apply_fitting_mask = True, False, False

# emcee settings: walkers, total steps, burn-in steps discarded, pool size
nwalkers, nit, ncut, nprocs = 50, 10000, 5000, 15

In [None]:
version_prefix = 'dexp_disk_pswc_bulge_tps_halo'

# For each MW analog: fit all stars with the disk + bulge + halo model
# (Nelder-Mead for a starting point, then emcee), save results and plots.
for i in range(n_mw):
    # if i > 0: continue
    
    if verbose: print(f'Fitting stellar distribution of analog {i+1}/{n_mw}')

    # Get the primary and its z=0 cutout
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    print('z=0 SID: ', z0_sid)
    # n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    orbs = co.get_orbs('stars')
    masses = co.get_masses('stars').to(apu.Msun).value
    r, R, phi, z = orbs.r().to(apu.kpc).value, orbs.R().to(apu.kpc).value, \
        orbs.phi().value, orbs.z().to(apu.kpc).value
    npart = len(r)

    # Randomly select (at most) 1e4 stars and up-weight their masses so the
    # subsample carries the same total mass
    npts = int(min(1e4, len(masses)))
    inds = np.random.choice(len(masses), size=npts, replace=False)
    fac = len(masses)/npts
    r = r[inds]
    R = R[inds]
    phi = phi[inds]
    z = z[inds]
    orbs = orbs[inds]
    masses = masses[inds]
    masses *= fac # Increase the mass to compensate for the decrease in points

    # Effective volume and data parameters
    rmin = np.min(r)
    rmax = np.max(r)
    zmax = np.max(np.abs(z))
    effvol_params = [0, 1e6, 1e6] # [Rmin, Rmax, zmax]
    usr_log_prior_params = [rmax,]

    # Remove data outside the fitting range if requested
    Rmin_fit = 0.
    Rmax_fit = 200.
    zmax_fit = 200.
    fitting_mask = (R >= Rmin_fit) & (R < Rmax_fit) # & (np.abs(z) < zmax_fit)
    if apply_fitting_mask:
        r_fit = r[fitting_mask]
        R_fit = R[fitting_mask]
        phi_fit = phi[fitting_mask]
        z_fit = z[fitting_mask]
        masses_fit = masses[fitting_mask]
        effvol_params = [Rmin_fit, Rmax_fit, zmax_fit]
        usr_log_prior_params = [Rmax_fit,]
    else:
        r_fit = copy.deepcopy(r)
        R_fit = copy.deepcopy(R)
        phi_fit = copy.deepcopy(phi)
        z_fit = copy.deepcopy(z)
        masses_fit = copy.deepcopy(masses)
    
    # Initial parameters:
    # [hR, hz, amp_disk, alpha_bulge, rc_bulge, amp_bulge,
    #  alpha_halo, beta_halo, a_halo, amp_halo]
    _hR = 4*apu.kpc
    _hz = 1*apu.kpc
    _alpha_bulge = 1.
    _rc_bulge = 1*apu.kpc
    _alpha_halo = 1.
    _beta_halo = 4.
    _a_halo = 15*apu.kpc
    _amp_bulge = 1e9
    _amp_disk = 1e8
    _amp_halo = 1e5
    _init = [_hR.value, _hz.value, _amp_disk, _alpha_bulge, _rc_bulge.value, 
        _amp_bulge, _alpha_halo, _beta_halo, _a_halo.value, _amp_halo]

    # Optimize to get a starting point. The *_fit arrays are used so that
    # apply_fitting_mask actually restricts the data being fit.
    print('Optimizing...')
    opt_fn = lambda params: pfit.mloglike_dens(params, comp_densfunc, R_fit, 
        phi_fit, z_fit, mass=masses_fit, usr_log_prior=usr_log_prior_ded_spcs_tps, 
        usr_log_prior_params=usr_log_prior_params, 
        effvol_params=effvol_params)
    opt = scipy.optimize.minimize(opt_fn, _init, method='Nelder-Mead',
        options={'maxiter':3000,})
    print(opt)

    # MCMC
    # Make the likelihood function
    def llfunc(params):
        return pfit.loglike_dens(params, comp_densfunc, R_fit, phi_fit, z_fit,
            mass=masses_fit, effvol_params=effvol_params, 
            usr_log_prior=usr_log_prior_ded_spcs_tps, 
            usr_log_prior_params=usr_log_prior_params)
    mcmc_init = np.array([
        opt.x+0.1*np.random.randn(len(opt.x)) for i in range(nwalkers)
        ])
    with multiprocessing.Pool(processes=nprocs) as pool:
        sampler = emcee.EnsembleSampler(nwalkers, len(mcmc_init[0]), llfunc, 
            pool=pool)
        sampler.run_mcmc(mcmc_init, nit, progress=True)
    _chain = sampler.get_chain(flat=True, discard=ncut)
    # Copy of the chain with the three amplitudes in log10
    chain = copy.deepcopy(_chain)
    chain[:,2] = np.log10(chain[:,2])
    chain[:,5] = np.log10(chain[:,5])
    chain[:,9] = np.log10(chain[:,9])

    # Total model mass for every sample, using the same bounds as the fit
    ms = np.zeros(len(chain))
    for k in range(len(chain)):
        ms[k] = comp_densfunc.effective_volume(_chain[k,:], *[0, 1e6, 1e6])
    chain = np.append(chain, np.log10(ms[:,None]), axis=1)
    mtot_all = np.sum(masses)

    # Save the results (one label per chain column: 10 parameters + mass)
    this_fitting_dir = os.path.join(epsen_fitting_dir,version_prefix,str(z0_sid))
    os.makedirs(this_fitting_dir,exist_ok=True)
    opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
    with open(opt_filename,'wb') as handle:
        pickle.dump(opt,handle)
    sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
    with open(sampler_filename,'wb') as handle:
        pickle.dump(sampler,handle)
    chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
    with open(chain_filename,'wb') as handle:
        pickle.dump([chain,['hR','hz','log10(amp_disk)','alpha','rc',
            'log10(amp)','alpha_halo','beta_halo','a_halo',
            'log10(amp_halo)','mtot']],handle)

    # Plot the results
    this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)

    # Density profile
    fig, axes = plot_density_profiles(_chain, R, z, masses, comp_densfunc,
        Rmin_fit, Rmax_fit, zmax_fit, opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_density_profile.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # R-z density profiles
    fig, axes = plot_RZ_density(_chain, R, z, masses, comp_densfunc, opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_Rz_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Projected densities
    fig, axes = plot_projected_density(orbs, masses, _chain, comp_densfunc, 
        opt.x)
    figname = os.path.join(this_fig_dir,version_prefix+'_projected_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Corner plot (bulge and halo slopes labelled separately)
    fig = corner.corner(chain,
        labels=[r'$h_R$',r'$h_z$',r'$\log_{10}$(disk amp)',
                r'$\alpha_{\rm bulge}$',r'$r_c$',r'$\log_{10}$(bulge amp)',
                r'$\alpha_{\rm halo}$',r'$\beta_{\rm halo}$',r'$a_{\rm halo}$',
                r'$\log_{10}$(halo amp)',
                r'$\log_{10}$(Mass)'],
        truths = [None,]*10+[np.log10(mtot_all),],
        truth_color='Red', quantiles=[0.16,0.5,0.84])
    corner.overplot_lines(fig, np.median(chain, axis=0), color='Black')
    figname = os.path.join(this_fig_dir,version_prefix+'_corner.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

### Fit using a Miyamoto-Nagai disk instead of double exponential

In [None]:
def usr_log_prior_mn_spcs_tps(densfunc, params, rmax):
    '''Flat log-prior for a Miyamoto-Nagai disk + single power-law with
    cutoff bulge + two-power spherical halo composite density profile.

    The prior is 0 when all of the following hold, and -inf otherwise:
      - disk scale parameters a and b are <= rmax
      - bulge power-law slope alpha is < 3
      - bulge cutoff radius rc is <= rmax
      - halo inner slope alpha is < 3
      - halo scale radius a is <= rmax

    Args:
        densfunc (pdens.CompositeDensityProfile) - composite profile whose
            components are, in order, MiyamotoNagaiDisk,
            SinglePowerCutoffSpherical and TwoPowerSpherical
        params (array-like) - flat parameter vector: 3 disk parameters,
            then 3 bulge parameters, then the remaining halo parameters
        rmax (float) - upper limit on the scale radii

    Returns:
        log_prior (float) - 0. if params are allowed, -np.inf otherwise
    '''
    # Check the component layout the parameter slicing below relies on
    assert isinstance(densfunc, pdens.CompositeDensityProfile)
    disk_prof = densfunc.densprofiles[0]
    assert isinstance(disk_prof, pdens.MiyamotoNagaiDisk)
    bulge_prof = densfunc.densprofiles[1]
    assert isinstance(bulge_prof, pdens.SinglePowerCutoffSpherical)
    halo_prof = densfunc.densprofiles[2]
    assert isinstance(halo_prof, pdens.TwoPowerSpherical)

    # Split the flat parameter vector into per-component values; the
    # amplitudes and the halo beta are not constrained here
    a_disk, b_disk, _ = disk_prof._parse_params(params[:3])
    alpha_bulge, rc_bulge, _ = bulge_prof._parse_params(params[3:6])
    alpha_halo, _, a_halo, _ = halo_prof._parse_params(params[6:])

    # ``or`` short-circuits, so the checks run in the same order as listed
    # in the docstring and stop at the first violation
    out_of_bounds = (a_disk > rmax or b_disk > rmax
                     or alpha_bulge >= 3. or rc_bulge > rmax
                     or alpha_halo >= 3. or a_halo > rmax)
    return -np.inf if out_of_bounds else 0.

# Composite model, in the component order that usr_log_prior_mn_spcs_tps
# asserts: Miyamoto-Nagai disk, single power-law with cutoff bulge,
# two-power spherical halo
disk_densfunc = pdens.MiyamotoNagaiDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
halo_densfunc = pdens.TwoPowerSpherical()
comp_densfunc = pdens.CompositeDensityProfile(
    [disk_densfunc, bulge_densfunc, halo_densfunc])

# Run options
verbose, show_plots, apply_fitting_mask = True, False, False

# MCMC settings: number of walkers, steps per walker, burn-in steps
# discarded from the chain, and worker processes for the pool
nwalkers, nit, ncut, nprocs = 50, 10000, 5000, 15

In [None]:
version_prefix = 'miyamoto_disk_pswc_bulge_tps_halo'
force_fits = False

# Names stored with the chain in chain.pkl: the 10 model parameters (disk,
# bulge, halo; amplitudes converted to log10) followed by the log10 mass
# column appended after sampling. Exactly one entry per chain column.
chain_column_names = ['a','b','log10(amp_disk)','alpha','rc','log10(amp)',
                      'alpha','beta','a','log10(amp_halo)','mtot']
# Corner-plot labels, in the same column order as chain_column_names
corner_labels = [r'$a_\mathrm{disk}$',r'$b_\mathrm{disk}$',
                 r'$\log_{10}$(disk amp)',
                 r'$\alpha_\mathrm{bulge}$',r'$r_c$',r'$\log_{10}$(bulge amp)',
                 r'$\alpha_\mathrm{halo}$',r'$\beta_\mathrm{halo}$',
                 r'$a_\mathrm{halo}$',r'$\log_{10}$(halo amp)',
                 r'$\log_{10}$(Mass)']
# [Rmin, Rmax, zmax] covering effectively all space; used both as the
# default fitting volume and to convert posterior samples into total masses
full_effvol_params = [0, 1e6, 1e6]
# Maximum number of star particles used in each fit
max_fit_particles = 10000


def _savefig_and_close(fig, figname):
    '''Save fig to figname at 300 dpi, then close it so figures do not
    accumulate in memory across loop iterations.'''
    fig.savefig(figname, dpi=300)
    plt.close(fig)


for i in range(n_mw):
    if verbose: print(f'Fitting stellar distribution of analog {i+1}/{n_mw}')

    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]

    # A saved sampler marks a completed fit; skip unless refitting is forced
    this_fitting_dir = os.path.join(epsen_fitting_dir,version_prefix,str(z0_sid))
    if os.path.exists(os.path.join(this_fitting_dir,'sampler.pkl')) and not force_fits:
        print('Skipping fit for z=0 SID: ',str(z0_sid))
        continue

    # Load the z=0 cutout and center / rectify it
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    orbs = co.get_orbs('stars')
    masses = co.get_masses('stars').to(apu.Msun).value
    r, R, phi, z = orbs.r().to(apu.kpc).value, orbs.R().to(apu.kpc).value, \
        orbs.phi().value, orbs.z().to(apu.kpc).value
    npart = len(r)

    # Randomly subsample (without replacement) at most max_fit_particles
    # stars and scale their masses up by npart/npts to compensate for the
    # dropped particles. Capping npts at npart avoids a ValueError from
    # np.random.choice for galaxies with fewer particles than the cap.
    npts = min(max_fit_particles, npart)
    inds = np.random.choice(npart, size=npts, replace=False)
    fac = npart/npts
    r, R, phi, z = r[inds], R[inds], phi[inds], z[inds]
    orbs = orbs[inds]
    masses = masses[inds]*fac

    # Effective volume and prior parameters for the unrestricted fit
    rmax = np.max(r)
    effvol_params = list(full_effvol_params) # [Rmin, Rmax, zmax]
    usr_log_prior_params = [rmax,]

    # Restrict the fitted data to the fitting range if requested. The
    # optimizer and likelihood below use the *_fit arrays, so the data
    # always correspond to effvol_params.
    Rmin_fit = 0.
    Rmax_fit = 200.
    zmax_fit = 200.
    # NOTE(review): effvol_params uses zmax_fit when masking, but the data
    # are only cut in R (the |z| < zmax_fit cut was disabled) - confirm
    fitting_mask = (R >= Rmin_fit) & (R < Rmax_fit)
    if apply_fitting_mask:
        r_fit, R_fit = r[fitting_mask], R[fitting_mask]
        phi_fit, z_fit = phi[fitting_mask], z[fitting_mask]
        masses_fit = masses[fitting_mask]
        effvol_params = [Rmin_fit, Rmax_fit, zmax_fit]
        usr_log_prior_params = [Rmax_fit,]
    else:
        r_fit, R_fit, phi_fit, z_fit, masses_fit = r, R, phi, z, masses

    # Initial guesses, ordered as the composite parameter vector:
    # disk (a, b, amp), bulge (alpha, rc, amp), halo (alpha, beta, a, amp).
    # Lengths are in kpc; amplitudes are linear (log10'd only in the chain).
    _a_disk = 4. # kpc
    _b_disk = 1. # kpc
    _amp_disk = 5e10 # amplitude in Msol
    _alpha_bulge = 1.
    _rc_bulge = 1. # kpc
    _amp_bulge = 1e9
    _alpha_halo = 1.
    _beta_halo = 4.
    _a_halo = 15. # kpc
    _amp_halo = 1e5
    _init = [_a_disk, _b_disk, _amp_disk, _alpha_bulge, _rc_bulge,
        _amp_bulge, _alpha_halo, _beta_halo, _a_halo, _amp_halo]

    # Optimize to get a starting point for the MCMC
    print('Optimizing...')
    def opt_fn(params):
        return pfit.mloglike_dens(params, comp_densfunc, R_fit, phi_fit,
            z_fit, mass=masses_fit, usr_log_prior=usr_log_prior_mn_spcs_tps,
            usr_log_prior_params=usr_log_prior_params,
            effvol_params=effvol_params)
    opt = scipy.optimize.minimize(opt_fn, _init, method='Nelder-Mead',
        options={'maxiter':2000,})
    print(opt)

    # MCMC. llfunc is a module-level function so the pool workers can use it.
    def llfunc(params):
        return pfit.loglike_dens(params, comp_densfunc, R_fit, phi_fit, z_fit,
            mass=masses_fit, effvol_params=effvol_params,
            usr_log_prior=usr_log_prior_mn_spcs_tps,
            usr_log_prior_params=usr_log_prior_params)
    # Walkers start in a small Gaussian ball around the optimum
    ndim = len(opt.x)
    mcmc_init = opt.x + 0.1*np.random.randn(nwalkers, ndim)
    with multiprocessing.Pool(processes=nprocs) as pool:
        sampler = emcee.EnsembleSampler(nwalkers, ndim, llfunc, pool=pool)
        sampler.run_mcmc(mcmc_init, nit, progress=True)
    _chain = sampler.get_chain(flat=True, discard=ncut)
    # chain: copy of _chain with the disk, bulge and halo amplitudes in log10
    chain = copy.deepcopy(_chain)
    for col in (2, 5, 9):
        chain[:,col] = np.log10(chain[:,col])

    # Total mass implied by each posterior sample (effective volume over all
    # space), appended to the chain as a log10 column
    ms = np.zeros(len(_chain))
    for k, params in enumerate(_chain):
        ms[k] = comp_densfunc.effective_volume(params, *full_effvol_params)
    chain = np.append(chain, np.log10(ms[:,None]), axis=1)
    assert chain.shape[1] == len(chain_column_names) == len(corner_labels)
    mtot_all = np.sum(masses)

    # Save the results. sampler.pkl is written last because its existence
    # is what the skip check above treats as a completed fit.
    os.makedirs(this_fitting_dir,exist_ok=True)
    with open(os.path.join(this_fitting_dir,'opt.pkl'),'wb') as handle:
        pickle.dump(opt,handle)
    with open(os.path.join(this_fitting_dir,'chain.pkl'),'wb') as handle:
        pickle.dump([chain,chain_column_names],handle)
    with open(os.path.join(this_fitting_dir,'sampler.pkl'),'wb') as handle:
        pickle.dump(sampler,handle)

    # Plot the results
    this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)
    fig_prefix = os.path.join(this_fig_dir,version_prefix)

    # Density profile
    fig, axes = plot_density_profiles(_chain, R, z, masses, comp_densfunc,
        Rmin_fit, Rmax_fit, zmax_fit, opt.x)
    _savefig_and_close(fig, fig_prefix+'_density_profile.png')

    # R-z density profiles
    fig, axes = plot_RZ_density(_chain, R, z, masses, comp_densfunc, opt.x)
    _savefig_and_close(fig, fig_prefix+'_Rz_density.png')

    # Projected densities
    fig, axes = plot_projected_density(orbs, masses, _chain, comp_densfunc,
        opt.x)
    _savefig_and_close(fig, fig_prefix+'_projected_density.png')

    # Corner plot: red truth line marks the total particle mass, black lines
    # mark the posterior medians
    fig = corner.corner(chain, labels=corner_labels,
        truths=[None,]*(chain.shape[1]-1)+[np.log10(mtot_all),],
        truth_color='Red', quantiles=[0.16,0.5,0.84])
    corner.overplot_lines(fig, np.median(chain, axis=0), color='Black')
    _savefig_and_close(fig, fig_prefix+'_corner.png')