In [240]:
from itertools import izip
from time import time
import numpy as np
import astropy
from pearce.mocks.customHODModels import *
from pearce.mocks import cat_dict
from scipy.optimize import minimize

In [241]:
from SloppyJoes import lazy_wrapper

In [242]:
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set()

In [243]:
# Galaxy-selection switch used by later cells: True -> cut on 'halo_vpeak_mag' and use the
# matching MAP HOD parameters; False -> cut on 'halo_vvir_mag' with its own MAP parameters.
AB = True

In [244]:
# Allowed (lower, upper) range for each HOD parameter.
PRIORS = {
    'f_c': (0, 1),
    'alpha': (0, 2),
    'logMmin': (10, 14),
    'logM1': (10, 15),
    'logM0': (9, 15),
    'sigma_logM': (0.3, 1.5),
    'logMcut': (9, 15),
    'logMlin': (9, 15),
    'f_cen': (0.0, 1.0),
}

# Assembly-bias HOD components; the satellites model is built from the centrals model.
cens_model = AssembiasRedMagicCens()
#cens_model = AssembiasReddick14Cens()
sats_model = AssembiasRedMagicSats(cens_model)
#sats_model = AssembiasReddick14Sats()

LBOX = 400.0
cosmo_params = {'simname': 'chinchilla', 'Lbox': LBOX, 'scale_factors': [0.658, 1.0]}

# Construct the catalog class registered for this simulation name, then load it with the HOD.
# The first argument (1.0) is presumably the scale factor (it is in scale_factors) — confirm.
catalog_class = cat_dict[cosmo_params['simname']]
cat = catalog_class(**cosmo_params)
cat.load(1.0, HOD=(cens_model, sats_model))

sats_model.modulate_with_cenocc = False

In [245]:
def resids(theta, params, cens_occs, sats_occs, mbc, mass_bins, sp_vals, sp_median):
    """Residual vector comparing predicted and measured occupations on the (mass, split) grid.

    Writes the trial parameters into the module-level ``cens_model``, ``sats_model`` and
    ``cat.model`` param_dicts (``f_c`` is first pinned to 1.0), predicts the central and
    satellite mean occupations at every (mass bin, secondary-property split) grid point,
    and returns the differences from the measured occupations.

    Parameters
    ----------
    theta : sequence of float
        Trial parameter values, in the same order as ``params``.
    params : sequence of str
        Names of the parameters being varied.
    cens_occs, sats_occs : np.ndarray, shape (n_mass_bins, n_splits)
        Measured central / satellite mean occupations.
    mbc : np.ndarray, shape (n_mass_bins,)
        Mass-bin centers used as the primary halo property.
    mass_bins : np.ndarray
        Mass-bin edges. Unused here; kept so existing ``args`` tuples still line up.
    sp_vals : np.ndarray, shape (n_mass_bins, n_splits)
        Secondary-property value representing each split in each mass bin.
    sp_median : np.ndarray, shape (n_mass_bins,)
        Median secondary-property value in each mass bin.

    Returns
    -------
    np.ndarray
        Concatenation of:
        (1) central residuals (n_mass_bins * n_splits, column-major);
        (2) log10 satellite residuals, same layout, with non-finite entries
            (zero prediction or zero measurement) set to 0;
        (3) one element, the difference of grid-summed mean occupations.
    """
    trial = dict(zip(params, theta))
    param_dicts = (cens_model.param_dict, sats_model.param_dict, cat.model.param_dict)
    # Pin f_c on every model first, then apply theta, so a fitted 'f_c' (if any) wins.
    for model_params in param_dicts:
        model_params['f_c'] = 1.0
    for model_params in param_dicts:
        model_params.update(trial)

    # Flatten the grid column-major: all mass bins for split 0, then split 1, ...
    n_splits = sp_vals.shape[1]
    arg1 = np.tile(mbc, n_splits)
    arg2 = sp_vals.reshape((-1,), order='F')
    arg3 = np.tile(sp_median, n_splits)

    cens_preds = cens_model.mean_occupation(prim_haloprop=arg1,
                                            sec_haloprop=arg2,
                                            sec_haloprop_percentile_values=arg3)
    sats_preds = sats_model.mean_occupation(prim_haloprop=arg1,
                                            sec_haloprop=arg2,
                                            sec_haloprop_percentile_values=arg3)

    # Weird edge cases can occur: clamp tiny predictions to exactly zero.
    cens_preds[cens_preds < 1e-9] = 0
    sats_preds[sats_preds < 1e-9] = 0

    # Sums of mean occupations over the grid (not galaxy counts).
    Ngal_pred = np.sum(cens_preds + sats_preds)
    Ngal_obs = np.sum(cens_occs + sats_occs)

    # log10(0) is expected wherever a prediction or measurement is zero; those entries are
    # dropped (set to 0) below, so silence the divide/invalid warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_sats_diff = np.log10(sats_preds) - np.log10(sats_occs).reshape((-1,), order='F')
    log_sats_diff[~np.isfinite(log_sats_diff)] = 0.0

    return np.r_[cens_preds - cens_occs.reshape((-1,), order='F'),
                 log_sats_diff,
                 (Ngal_pred - Ngal_obs)]

In [246]:
# `import astropy` alone does not load the astropy.table subpackage; import it explicitly
# instead of relying on some other module having imported it first.
from astropy.table import Table

# Halo/galaxy catalog (presumably abundance-matched, per the filename). Absolute path on the
# original analysis machine — edit here when running elsewhere.
CATALOG_PATH = '/u/ki/swmclau2/des/AB_tests/abmatched_halos.hdf5'
catalog = Table.read(CATALOG_PATH, format = 'hdf5')

In [247]:
mag_cut = -21
min_ptcl = 200

# Which magnitude column the cut uses depends on the AB switch.
mag_column = 'halo_vpeak_mag' if AB else 'halo_vvir_mag'
# Keep entries above the particle-mass threshold and at least as bright as mag_cut.
above_mass_floor = catalog['halo_mvir'] > min_ptcl*cat.pmass
passes_mag_cut = catalog[mag_column] <= mag_cut
catalog = catalog[np.logical_and(above_mass_floor, passes_mag_cut)]

In [248]:
# Best-fit HOD parameter values for each selection (presumably maximum a posteriori),
# ordered as in `names` below.
_MAP_BY_AB = {
    False: [12.64539386, 14.15396837, 0.52641264, 0.22234201,
            14.34871275, 1.07989646, 12.81902682],
    True: [12.72747382, 14.24964974, 0.55068739, 0.18672767,
           14.00597843, 1.06836772, 12.88931659],
}
MAP = np.array(_MAP_BY_AB[bool(AB)])

names = ['logMmin', 'logMlin', 'sigma_logM', 'f_cen', 'logM1', 'alpha', 'logMcut']
hod_params = dict(zip(names, MAP))

In [249]:
# Starting assembly-bias strengths (param1) and slopes (slope1) for centrals and satellites.
ab_params = dict(
    mean_occupation_centrals_assembias_param1=0.4,
    mean_occupation_satellites_assembias_slope1=3,
    mean_occupation_satellites_assembias_param1=-0.5,
    mean_occupation_centrals_assembias_slope1=3,
)

In [250]:
# Copy every entry of the centrals model's param_dict into the satellites model's
# param_dict, overwriting keys they share.
sats_model.param_dict.update(cens_model.param_dict)

In [251]:
# Copy rather than alias: updating an alias would also add the assembly-bias entries to
# hod_params (hidden state that makes re-running cells order-dependent).
param_dict = dict(hod_params)
param_dict.update(ab_params)
cens_model.param_dict.update(param_dict)
sats_model.param_dict.update(param_dict)

# Names of the parameters varied in the fit: only the assembly-bias parameters.
# A concrete list keeps indexing/printing the same under Python 2 and 3.
params = list(ab_params.keys())
ndim = len(params)

In [252]:
# Keep only halos more massive than min_ptcl particle masses.
_all_halos = cat.halocat.halo_table
halo_table = _all_halos[_all_halos['halo_mvir'] > min_ptcl*cat.pmass]

In [253]:
# IDs of selected catalog entries that are host halos (upid == -1).
_is_host_entry = catalog['halo_upid'] == -1
detected_central_ids = set(catalog['halo_id'][_is_host_entry])

In [254]:
from collections import Counter


def compute_occupations(halo_table, galaxy_catalog=None):
    """Per-host central flag and satellite count for every host halo in ``halo_table``.

    A host halo is a row with ``halo_upid == -1``. Its central occupation is 1.0 if its
    ``halo_id`` appears among the host entries (upid == -1) of the galaxy catalog, else 0.0.
    Its satellite occupation is the number of non-host catalog entries whose ``halo_upid``
    equals the host's ``halo_id``.

    Parameters
    ----------
    halo_table : table-like
        Must support column access by name and boolean-mask row selection
        (e.g. an astropy Table or a numpy structured array) with
        'halo_id' and 'halo_upid' columns.
    galaxy_catalog : table-like, optional
        Selected galaxies with the same columns. Defaults to the module-level ``catalog``
        (the original behavior).

    Returns
    -------
    cens_occ, sats_occ : np.ndarray of float, shape (n_hosts,)
        In the row order of the host halos in ``halo_table``.
    """
    if galaxy_catalog is None:
        galaxy_catalog = catalog

    host_ids = halo_table[halo_table['halo_upid'] == -1]['halo_id']

    gal_is_host = galaxy_catalog['halo_upid'] == -1
    detected_central_ids = set(galaxy_catalog[gal_is_host]['halo_id'])
    detected_satellite_upids = Counter(galaxy_catalog[galaxy_catalog['halo_upid'] != -1]['halo_upid'])

    cens_occ = np.array([1.0 if hid in detected_central_ids else 0.0 for hid in host_ids],
                        dtype=float)
    # Counter returns 0 for hosts with no detected satellites.
    sats_occ = np.array([detected_satellite_upids[hid] for hid in host_ids], dtype=float)
    return cens_occ, sats_occ

In [255]:
from halotools.utils.table_utils import compute_prim_haloprop_bins
def compute_hod(masses, centrals, satellites, mass_bins):
    """Mean central and satellite occupation per mass bin.

    Parameters
    ----------
    masses : array-like, shape (n,)
        Host-halo masses.
    centrals, satellites : np.ndarray, shape (n,)
        Per-halo central / satellite occupations, aligned with ``masses``.
    mass_bins : np.ndarray, shape (n_bins + 1,)
        Mass-bin edges.

    Returns
    -------
    cens_occ, sats_occ : np.ndarray, shape (n_bins,)
        Mean occupations per bin. Bins with no halos stay 0
        (TODO: decide how empty bins should be represented).
    """
    n_bins = mass_bins.shape[0] - 1
    bin_numbers = compute_prim_haloprop_bins(prim_haloprop_bin_boundaries=mass_bins,
                                             prim_haloprop=masses)

    cens_occ = np.zeros((n_bins,))
    sats_occ = np.zeros_like(cens_occ)
    for mb in set(bin_numbers):
        # Bin numbers are treated as 1-based (stored at mb - 1). Presumably (np.digitize
        # convention — confirm for halotools) 0 and n_bins + 1 flag masses outside
        # mass_bins. Without this guard 0 would wrap to index -1 (silently overwriting the
        # last bin) and n_bins + 1 would raise IndexError.
        if not 1 <= mb <= n_bins:
            continue
        members = np.where(bin_numbers == mb)[0]
        cens_occ[mb - 1] = np.mean(centrals[members])
        sats_occ[mb - 1] = np.mean(satellites[members])
    return cens_occ, sats_occ

In [None]:
mass_bin_range = (9,16)   # log10 mass range covered by the bins
mass_bin_size = 0.1       # bin width in dex
# round() rather than int(): float error in the division (e.g. 0.7/0.1 == 6.999...) would
# otherwise truncate and silently drop a bin.
n_mass_bins = int(round((mass_bin_range[1] - mass_bin_range[0]) / mass_bin_size))
mass_bins = np.logspace(mass_bin_range[0], mass_bin_range[1], n_mass_bins + 1)
# Arithmetic (not geometric) midpoint of each log-spaced bin.
mbc = (mass_bins[1:] + mass_bins[:-1]) / 2

In [None]:
cens_occ, sats_occ = compute_occupations(halo_table)

# Host halos (upid == -1), in the same order compute_occupations uses.
_host_halos = halo_table[halo_table['halo_upid'] == -1]
mock_masses = _host_halos['halo_mvir']
mock_concentrations = _host_halos['halo_nfw_conc']

In [None]:
from halotools.utils.table_utils import compute_conditional_percentiles
# Percentile of each host's concentration, conditioned on its mass bin.
mock_percentiles = compute_conditional_percentiles(
    prim_haloprop=mock_masses,
    sec_haloprop=mock_concentrations,
    prim_haloprop_bin_boundaries=mass_bins,
)

# Edges of the concentration-percentile splits: 0.0, 0.2, ..., 1.0
splits = np.arange(0, 1.1, 0.2)

In [None]:
# HOD of all host halos, ignoring concentration.
cen_hod, sat_hod = compute_hod(masses=mock_masses, centrals=cens_occ,
                               satellites=sats_occ, mass_bins=mass_bins)

In [None]:
# HOD measured separately in each concentration-percentile split.
cens_occs, sats_occs = [],[]

last_split = len(splits) - 2
for idx, p in enumerate(splits[:-1]):
    upper = splits[idx+1]
    # Splits are half-open [p, upper), except the last, which is closed so halos at
    # percentile exactly 1.0 are not dropped. Presumably halotools conditional
    # percentiles lie in (0, 1], making the top halo of each mass bin exactly 1.0 — confirm.
    if idx == last_split:
        below_upper = mock_percentiles <= upper
    else:
        below_upper = mock_percentiles < upper
    split_idxs = np.logical_and(p <= mock_percentiles, below_upper)

    _cens_occ, _sats_occ = compute_hod(mock_masses[split_idxs], cens_occ[split_idxs],
                                       sats_occ[split_idxs], mass_bins)
    cens_occs.append(_cens_occ)
    sats_occs.append(_sats_occ)

In [None]:
from halotools.utils.table_utils import compute_conditional_percentile_values
# For every mass bin: the concentration at each split's mid-percentile, and the median.
n_mass_bins_sp = len(mass_bins) - 1
sp_values = np.zeros((n_mass_bins_sp, (len(splits)-1)))
spv_median = np.zeros((n_mass_bins_sp,))

# Bin ALL host halos by mass. Binning only mock_masses[split_idxs] (the leftover last split
# of the previous loop) and then indexing the full mock_concentrations picks the wrong halos.
mass_bin_idxs = compute_prim_haloprop_bins(prim_haloprop_bin_boundaries=mass_bins,
                                           prim_haloprop=mock_masses)
mass_bin_nos = set(mass_bin_idxs)

# Mid-point of each split, in percent.
q = ((splits[1:]+splits[:-1])/2)*100

for mb in mass_bin_nos:
    # Bin numbers are stored at mb - 1; skip numbers outside 1..n_mass_bins_sp
    # (out-of-range masses), which would wrap or overflow.
    if not 1 <= mb <= n_mass_bins_sp:
        continue
    indices_of_mb = np.where(mass_bin_idxs == mb)[0]
    concs_in_bin = mock_concentrations[indices_of_mb]
    sp_values[mb-1, :] = np.percentile(concs_in_bin, q)
    spv_median[mb-1] = np.percentile(concs_in_bin, 50)

In [None]:
# Measured central occupation per concentration split, against the overall central HOD.
fig, ax = plt.subplots()
for co, upper in zip(cens_occs, splits[1:]):
    ax.plot(mbc, co, label=upper)
ax.plot(mbc, cen_hod, lw=2, label='all')

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim([1e11, 1e16])
ax.set_ylim([1e-3, 1.1])
ax.set_xlabel('halo_mvir')
ax.set_ylabel(r'$\langle N_{cen} \rangle$')
ax.legend(loc='best', title='split upper percentile')
plt.show();

In [None]:
# Lower both assembly-bias slopes from the ab_params value (3) to 1.2 and pin f_c to 1.
# Key spelling fixed ('centals' -> 'centrals'): the misspelled key only added an unused
# entry, leaving the centrals slope unchanged.
cens_model.param_dict['mean_occupation_centrals_assembias_slope1'] = 1.2
cens_model.param_dict['f_c'] = 1.0
sats_model.param_dict['f_c'] = 1.0
sats_model.param_dict['mean_occupation_satellites_assembias_slope1'] = 1.2

In [None]:
# Predicted occupations on the (mass bin, split) grid, flattened column-major.
arg1 = np.tile(mbc, sp_values.shape[1])
arg2 = sp_values.reshape((-1,), order = 'F')
arg3 = np.tile(spv_median, sp_values.shape[1])

cens_preds = cens_model.mean_occupation(prim_haloprop = arg1,
                                        sec_haloprop = arg2,
                                        sec_haloprop_percentile_values = arg3)
sats_preds = sats_model.mean_occupation(prim_haloprop = arg1,
                                        sec_haloprop = arg2,
                                        sec_haloprop_percentile_values = arg3)

# Back to (n_mass_bins, n_splits).
cens_preds = cens_preds.reshape((-1, sp_values.shape[1]), order = 'F')
sats_preds = sats_preds.reshape((-1, sp_values.shape[1]), order = 'F')

fig, ax = plt.subplots()
for lo, hi, cp, sp, co, so in zip(splits[:-1], splits[1:], cens_preds.T, sats_preds.T,
                                  cens_occs, sats_occs):
    # Empty mass bins have zero measured occupation; the resulting inf/nan are not drawn.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = (cp + sp) / (co + so)
    # Label by the split's mid-percentile (p + 0.25 lay outside the 0.2-wide split).
    ax.plot(mbc, ratio, label='{:.1f}'.format((lo + hi) / 2))

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim([1e11, 1e16])
ax.set_ylim([1e-3, 20])
ax.set_xlabel('halo_mvir')
ax.set_ylabel(r'predicted / measured $\langle N_{gal} \rangle$')
ax.legend(loc='best', title='split mid percentile')
plt.show();

In [None]:
# Starting point of the fit, in the same order as `params`.
vals = np.array([param_dict[key] for key in params])
cens_idxs = halo_table['halo_upid'] == -1
# Extra positional arguments for resids, after theta.
args = (
    params,
    np.array(cens_occs).T,   # (n_mass_bins, n_splits)
    np.array(sats_occs).T,   # (n_mass_bins, n_splits)
    mbc,
    mass_bins,
    sp_values,
    spv_median,
)
print(params)

In [None]:
# NOTE(review): positional override — which parameters sit at positions 2 and 3 depends on
# the order of `params` (dict-key order of ab_params); consider setting by name instead.
vals[2:4] = 2.0

In [None]:
# Spot check: mean central occupation of the first 100 halos in the full halo table.
n_check = 100
_spot_halos = cat.halocat.halo_table
test = cens_model.mean_occupation(prim_haloprop=_spot_halos['halo_mvir'][:n_check],
                                  sec_haloprop=_spot_halos['halo_nfw_conc'][:n_check])
print(np.mean(test))

In [None]:
# (n_splits, n_mass_bins) before the transpose applied in `args`.
np.shape(cens_occs)

In [None]:
# Leftover introspection: shows the centrals model's baseline_mean_occupation attribute as
# the cell output.
cens_model.baseline_mean_occupation

In [None]:
# Evaluate the residual vector once at the starting point `vals` (shown as the cell output).
resids(vals, *args)

In [None]:
# Fit the assembly-bias parameters by minimizing `resids` from `vals`.
lazy_wrapper(
    resids,
    vals,
    func_args=args,
    maxfev=500,
    print_level=1,
    artol=1e-6,
)

In [None]:
# Parameter names, in the order used by `vals`.
print(params)