In [1]:
from IPython.display import display, HTML
# Widen the notebook's main container to 98% of the window. This is a
# display-only CSS tweak; it does not affect any computation below.
display(HTML("<style>.container { width:98% !important; }</style>"))

In [13]:
# Standard modules
import pdb
import sys
import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
from lmfit import Parameters, minimize, fit_report
from scipy.ndimage.filters import gaussian_filter
from astropy.io import fits
from astropy.wcs import WCS

c = 299792458.0 # m/s

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

sys.path.append("..")

from simstackwrapper import SimstackWrapper
from simstackresults import SimstackResults
from simstackplots import SimstackPlots
from simstacktoolbox import SimstackToolbox



In [None]:
L_sun = 3.839e26  # W
c = 299792458.0  # m/s
def fast_LIR(self, theta, zed, dzin=None):
    '''Integrate a graybody SED over the 8-1000 grid for each redshift in ``zed``.

    This calls ``self.graybody_fn`` instead of ``fast_sed``.

    Parameters
    ----------
    theta : pair of sequences
        ``(theta[0][iz], theta[1][iz])`` is passed to ``self.graybody_fn`` as
        the SED parameters for redshift ``zed[iz]``.
    zed : sequence of float
        Redshifts at which each SED is converted to a luminosity.
    dzin : sequence of float, optional
        Extra redshifts at which every SED is also converted.

    Returns
    -------
    Lrf_array : np.ndarray of float, shape (len(zed),)
        Luminosities, using the L_sun / (Jy x Hz) conversion below.
    (Lrf_array, dLrf_array) when ``dzin`` is given
        ``dLrf_array`` has shape (len(zed), len(dzin)). The previous version
        raised ValueError here by assigning a tuple into a float array slot.

    Note: a later cell defines a single-redshift ``fast_LIR`` with the same
    name; when the notebook is run top-to-bottom that later one is in effect.
    '''
    def _jyhz_to_lsun(z):
        # 4 * pi * D_L^2 with D_L converted by 3.08568025E22 (m per Mpc); the
        # 1e-26 Jy -> W m^-2 Hz^-1 factor enters as (1e-13)^2.
        # Units are L_sun / (Jy x Hz).
        d_l = 1.0E-13 * self.config_dict['cosmology_dict']['cosmology'].luminosity_distance(z) * 3.08568025E22
        return (4.0 * np.pi * d_l ** 2.0 / L_sun).value

    # The wavelength grid and its frequency widths are the same for every z.
    wavelength_range = self.loggen(8, 1000, 1000)
    nu_in = c * 1.e6 / wavelength_range
    dnu = nu_in[:-1] - nu_in[1:]
    dnu = np.append(dnu[0], dnu)

    # Explicit float dtype: zeros_like(int redshifts) would truncate results.
    Lrf_array = np.zeros(len(zed), dtype=float)
    dLrf_array = None if dzin is None else np.zeros([len(zed), len(dzin)])

    for iz, zin in enumerate(zed):
        theta_in = theta[0][iz], theta[1][iz]
        model_sed = self.graybody_fn(theta_in, wavelength_range)
        Lir = np.sum(model_sed * dnu, axis=1)  # Jy x Hz

        Lrf_array[iz] = (Lir * _jyhz_to_lsun(zin))[0]
        if dzin is not None:
            for idz, dz in enumerate(dzin):
                dLrf_array[iz, idz] = (Lir * _jyhz_to_lsun(dz))[0]

    if dzin is not None:
        return Lrf_array, dLrf_array
    return Lrf_array

In [None]:
L_sun = 3.839e26  # W
c = 299792458.0  # m/s
T_hot = 80
T_cold = 18
def fast_LIR2(self, theta, zed, dzin=None, T_cold_rf=T_cold, T_hot_rf=T_hot):
    '''Integrate a two-component (cold + hot) graybody SED for each redshift.

    This calls ``self.graybody_fn`` instead of ``fast_sed``. For redshift
    ``zed[iz]`` the cold and hot amplitudes are ``theta[0][iz]`` and
    ``theta[1][iz]``. Their temperatures are the fixed values ``T_cold_rf``
    and ``T_hot_rf`` divided by ``(1 + z)``.

    Parameters
    ----------
    theta : pair of sequences
        (cold amplitudes, hot amplitudes), indexed like ``zed``.
    zed : sequence of float
        Redshifts.
    dzin : sequence of float, optional
        Extra redshifts at which every summed SED is also converted.
    T_cold_rf, T_hot_rf : float
        Temperatures before the ``1 / (1 + z)`` scaling. They default to the
        module constants ``T_cold`` / ``T_hot``, which were previously
        hard-coded.

    Returns
    -------
    Lrf_array : np.ndarray of float, shape (len(zed),)
    (Lrf_array, dLrf_array) when ``dzin`` is given, where ``dLrf_array`` has
    shape (len(zed), len(dzin)).

    The previous version raised NameError (it summed the undefined
    ``model_cold`` / ``model_hot``).
    '''
    def _jyhz_to_lsun(z):
        # 4 * pi * D_L^2 with D_L converted by 3.08568025E22 (m per Mpc);
        # units are L_sun / (Jy x Hz).
        d_l = 1.0E-13 * self.config_dict['cosmology_dict']['cosmology'].luminosity_distance(z) * 3.08568025E22
        return (4.0 * np.pi * d_l ** 2.0 / L_sun).value

    wavelength_range = self.loggen(8, 1000, 1000)
    nu_in = c * 1.e6 / wavelength_range
    dnu = nu_in[:-1] - nu_in[1:]
    dnu = np.append(dnu[0], dnu)

    Lrf_array = np.zeros(len(zed), dtype=float)
    dLrf_array = None if dzin is None else np.zeros([len(zed), len(dzin)])

    for iz, zin in enumerate(zed):
        A_cold, A_hot = theta[0][iz], theta[1][iz]
        theta_cold = A_cold, T_cold_rf / (1 + zin)
        theta_hot = A_hot, T_hot_rf / (1 + zin)
        model_cold_sed = self.graybody_fn(theta_cold, wavelength_range)
        model_hot_sed = self.graybody_fn(theta_hot, wavelength_range)

        Lir = np.sum((model_cold_sed + model_hot_sed) * dnu, axis=1)  # Jy x Hz

        Lrf_array[iz] = (Lir * _jyhz_to_lsun(zin))[0]
        if dzin is not None:
            for idz, dz in enumerate(dzin):
                dLrf_array[iz, idz] = (Lir * _jyhz_to_lsun(dz))[0]

    if dzin is not None:
        return Lrf_array, dLrf_array
    return Lrf_array

In [None]:
def model_A_or_Tdust(params, X, qt=False):
    '''Evaluate ``offset + sum_i X[i] * slope_i`` from a set of fit parameters.

    Offsets are searched in the groups '', '_sf', '_qt', '_agn'. In each group
    ``A_offset<suffix>`` takes priority over ``T_offset<suffix>``. Every offset
    found is removed so it is not treated as a slope. If several groups
    provide one, the last group found becomes the offset and earlier ones are
    discarded, which preserves the original behaviour.
    NOTE(review): confirm whether multiple offsets were meant to be summed.

    Parameters
    ----------
    params : object with ``valuesdict()`` (e.g. lmfit.Parameters)
    X : sequence
        ``X[i]`` multiplies the i-th remaining (non-offset) parameter, in
        insertion order.
    qt : bool
        Unused; kept for call compatibility.

    Returns
    -------
    Model value (scalar or array, following the broadcasting of ``X``). When
    no offset parameter is present the offset is 0; previously this raised
    UnboundLocalError.
    '''
    v = params.valuesdict().copy()
    model = 0
    for suffix in ('', '_sf', '_qt', '_agn'):
        for prefix in ('A_offset', 'T_offset'):
            key = prefix + suffix
            if key in v:
                model = v.pop(key)
                break
    for i, ival in enumerate(v):
        model += X[i] * v[ival]
    return model

In [None]:
def direct_fit_A_Tdust(params, X, y):
    '''Sky-minus-model residuals at source pixels, without beam convolution.

    ``params`` is split into two copies:
      * an amplitude set (every name containing 'T_' removed), and
      * a temperature set (every name containing 'A_' removed).
    Each set is turned into per-source values with ``model_A_or_Tdust``.

    ``y`` maps map names to dicts providing 'wavelength', 'map' and
    'map_coords' (row indices, column indices). For each map, the
    ``get_map_flux_mJy`` model fluxes are subtracted from the sky pixels at
    the source coordinates. The residuals of all maps are returned as one
    flat list, in map order.
    '''
    amp_params = params.copy()
    temp_params = params.copy()
    amp_names = [name for name in amp_params if 'A_' in name]
    temp_names = [name for name in temp_params if 'T_' in name]
    for name in temp_names:
        amp_params.pop(name)
    for name in amp_names:
        temp_params.pop(name)

    amp_model = model_A_or_Tdust(amp_params, X)
    temp_model = model_A_or_Tdust(temp_params, X)

    residuals = []
    for map_name in y:
        entry = y[map_name]
        frequency = c * 1.e6 / entry['wavelength']
        sky = entry['map']
        coords = entry['map_coords']
        source_flux = get_map_flux_mJy(np.array([frequency]), amp_model, temp_model)
        # TODO: place the fluxes in a model map and convolve with the beam
        # before differencing.
        residuals.extend(sky[coords[0], coords[1]] - source_flux)
    return residuals

In [None]:
def direct_convolved_fit_A_Tdust(params, X, y):
    '''Sky-minus-model residuals after convolving per-source model fluxes with the beam.

    ``params`` is split into an amplitude set (names containing 'T_' removed)
    and a temperature set (names containing 'A_' removed). Each set is
    evaluated with ``model_A_or_Tdust`` against ``X``.

    ``y`` maps map names to dicts providing 'wavelength', 'map', 'map_coords'
    (row indices, column indices), 'fwhm' and 'pixel_size'. For each map:
      1. model fluxes from ``get_map_flux_mJy`` are accumulated on a blank map
         at the source pixels;
      2. the model map is convolved with a Gaussian beam (``gauss_kern`` +
         ``smooth_psf``);
      3. ``sky - model`` is kept only at pixels that host at least one source.
    Residuals of all maps are returned as one flat list, in map order.

    Note: a later cell redefines this function name; when the notebook is
    run top-to-bottom, that later definition is the one in effect.
    '''
    amp_params = params.copy()
    temp_params = params.copy()
    amp_names = [name for name in amp_params if 'A_' in name]
    temp_names = [name for name in temp_params if 'T_' in name]
    for name in temp_names:
        amp_params.pop(name)
    for name in amp_names:
        temp_params.pop(name)
    A_model = model_A_or_Tdust(amp_params, X)
    T_model = model_A_or_Tdust(temp_params, X)

    out_model = []
    for map_name in y:
        map_info = y[map_name]
        map_nu = c * 1.e6 / map_info['wavelength']
        map_sky = map_info['map']
        rows, cols = map_info['map_coords'][0], map_info['map_coords'][1]

        S_model = get_map_flux_mJy(np.array([map_nu]), A_model, T_model)
        map_model = np.zeros_like(map_sky)
        # np.add.at sums sources that fall on the same pixel; plain
        # fancy-index assignment kept only the last one and lost flux.
        np.add.at(map_model, (rows, cols), S_model)
        map_pixels = np.zeros_like(map_sky)
        map_pixels[rows, cols] = 1

        fwhm = map_info['fwhm']
        pix = map_info['pixel_size']
        kern = gauss_kern(fwhm, np.floor(fwhm * 10) / pix, pix)
        tmap = smooth_psf(map_model, kern)

        # Fit only pixels that contain at least one source.
        idx_fit = map_pixels != 0
        diff = map_sky - tmap
        out_model.extend(np.ravel(diff[idx_fit]))
    return out_model

In [None]:
def direct_convolved_fit_A_Tdust(params, X, y):
    '''Sky-minus-model residuals after convolving per-source model fluxes with the beam.

    ``params`` must contain 'A_offset' and 'T_offset', which start the
    amplitude and temperature models. The remaining parameters are consumed
    in insertion order with a shared index ``i`` into ``X``:
      * a name containing 'A_' adds ``X[i] * value`` to the amplitude model;
      * any other name adds ``X[i] * value`` to the temperature model and
        then advances ``i``.
    The slopes are therefore expected in the order (A_x0, T_x0, A_x1, T_x1,
    ...). NOTE(review): confirm that parameters are always added in that
    order.

    ``y`` maps map names to dicts providing 'wavelength', 'map', 'map_coords'
    (row indices, column indices), 'fwhm' and 'pixel_size'. Model fluxes are
    accumulated on a blank map, beam-convolved, and ``sky - model`` is
    returned at source-hosting pixels only, as one flat list over all maps.
    '''
    # valuesdict() builds a new dict, so popping does not modify ``params``.
    v = params.valuesdict()
    A_model = v.pop('A_offset')
    T_model = v.pop('T_offset')
    i = 0
    for ival in v:
        if 'A_' in ival:
            A_model += X[i] * v[ival]
        else:
            T_model += X[i] * v[ival]
            i += 1

    out_model = []
    for map_name in y:
        map_info = y[map_name]
        map_nu = c * 1.e6 / map_info['wavelength']
        map_sky = map_info['map']
        rows, cols = map_info['map_coords'][0], map_info['map_coords'][1]

        S_model = get_map_flux_mJy(np.array([map_nu]), A_model, T_model)
        map_model = np.zeros_like(map_sky)
        # np.add.at sums sources that fall on the same pixel; plain
        # fancy-index assignment kept only the last one and lost flux.
        np.add.at(map_model, (rows, cols), S_model)
        map_pixels = np.zeros_like(map_sky)
        map_pixels[rows, cols] = 1

        fwhm = map_info['fwhm']
        pix = map_info['pixel_size']
        kern = gauss_kern(fwhm, np.floor(fwhm * 10) / pix, pix)
        tmap = smooth_psf(map_model, kern)

        # Fit only pixels that contain at least one source.
        idx_fit = map_pixels != 0
        diff = map_sky - tmap
        out_model.extend(np.ravel(diff[idx_fit]))
    return out_model

In [2]:
def model_A_Tdust(params, X, qt=False):
    '''Linear model ``offset + sum_i X[i] * slope_i`` built from fit parameters.

    - 'offset' (optional, 0 when absent) is the intercept.
    - 'slope_split_params' (optional) is always excluded from the slopes.
      When ``qt`` is True it contributes ``X[-1] * value``.
    - Every other parameter is a slope, paired with ``X`` by position in
      insertion order.
    '''
    values = params.valuesdict().copy()
    model = values.pop('offset') if 'offset' in values else 0

    if 'slope_split_params' in values:
        split_slope = values.pop('slope_split_params')
        if qt:
            model += X[-1] * split_slope

    for position, (_, slope) in enumerate(values.items()):
        model += X[position] * slope
    return model

In [3]:
def slope_A_Tdust(params, X, y):
    '''Residuals ``y - model_A_Tdust(params, X)``, skipping entries where the model is NaN.'''
    prediction = model_A_Tdust(params, X)
    # NaN is the only value that is not equal to itself, so this mask drops it.
    keep = prediction == prediction
    residual = y[keep] - prediction[keep]
    return residual

In [None]:
def black(nu_in, T):
    '''Planck function scaled by 1e29, evaluated on the outer grid of T and nu.

        2 h nu^3 / c^2 / (exp(h nu / (k T)) - 1) * 1e29

    ``np.outer`` flattens its inputs, so for 1-D (or scalar) ``T`` and ``nu_in``
    the result has shape (size(T), size(nu_in)).
    '''
    two_h_over_c2 = 1.4718e-21  # 2*h*10^29/c^2
    h_over_k = 4.7993e-11  # h/k
    exponent = h_over_k * np.outer(1.0 / T, nu_in)
    return two_h_over_c2 * nu_in ** 3.0 / (np.exp(exponent) - 1.0)

In [5]:
def get_map_flux_mJy(nu_in, Ain, T, betain=1.8, alphain=2.0):
    '''Graybody flux with a power-law tail above a cut frequency, one value per (A, T).

    Below ``nu_cut = (3 + beta + alpha) * (k/h) * T`` the SED is
    ``A * nu^beta * black(nu, T) / 1000``. Above it, the power law
    ``w_div * nu^-alpha`` is used. Its constant K is built from h, k and c so
    that the two pieces meet at ``nu_cut``.

    Parameters
    ----------
    nu_in : float or 1-element array
        Observing frequency in Hz. Only the first frequency's blackbody is
        used (``black(...)[:, 0]``), so call this once per map/frequency.
    Ain : float or array
        log10 of the amplitude.
    T : float or array
        Dust temperature in K, broadcast against ``Ain``.
    betain, alphain : float
        Emissivity index and power-law slope above ``nu_cut``.

    Returns
    -------
    np.ndarray of model fluxes, one per (A, T) pair.
    NOTE(review): ``black`` carries a 1e29 factor and this divides by 1000
    (net 1e26, i.e. W m^-2 Hz^-1 -> Jy). Confirm whether the result is Jy or
    mJy given the units of A.
    '''
    A = 10**Ain
    # 3 + beta + alpha is the power-law normalisation exponent and sets nu_cut.
    p = 3.0 + betain + alphain

    base = 2.0 * (6.626)**(-2.0 - betain - alphain) * (1.38)**p / (2.99792458)**2.0
    expo = 34.0 * (2.0 + betain + alphain) - 23.0 * p - 16.0 + 26.0
    K = base * 10.0**expo
    w_div = A * K * (T * p)**p / (np.exp(p) - 1.0)
    nu_cut = p * 0.208367e11 * T  # 0.208367e11 ~ k/h in Hz/K

    graybody = A * nu_in**betain * black(nu_in, T)[:, 0] / 1000.
    powerlaw = w_div * nu_in**(-1.0 * alphain)
    # Above nu_cut the power law replaces the graybody.
    return np.where(nu_in > nu_cut, powerlaw, graybody)

In [5]:
def get_flux_mJy(self, nu_in, Ain, T, betain=1.8, alphain=2.0):
    '''Graybody flux with a power-law tail above a cut frequency (uses ``self.black``).

    Below ``nu_cut = (3 + beta + alpha) * (k/h) * T`` the SED is
    ``A * nu^beta * self.black(nu, T)[:, 0] / 1000``. Above it, the power law
    ``w_div * nu^-alpha`` is used. Its constant K is built from h, k and c so
    that the two pieces meet at ``nu_cut``.

    Parameters
    ----------
    nu_in : float or 1-element array
        Observing frequency in Hz. Only the first frequency's blackbody is
        used, so call once per frequency. A scalar is now accepted; the
        unused ``len(nu_in)`` previously rejected it.
    Ain : float or array
        log10 of the amplitude.
    T : float or array
        Dust temperature in K, broadcast against ``Ain``.
    betain, alphain : float
        Emissivity index and power-law slope above ``nu_cut``.

    Returns
    -------
    np.ndarray of model fluxes, one per (A, T) pair.
    NOTE(review): net scaling from black()'s 1e29 and the /1000 is 1e26;
    confirm Jy vs mJy given the units of A.
    '''
    A = 10**Ain
    # 3 + beta + alpha is the power-law normalisation exponent and sets nu_cut.
    p = 3.0 + betain + alphain

    base = 2.0 * (6.626)**(-2.0 - betain - alphain) * (1.38)**p / (2.99792458)**2.0
    expo = 34.0 * (2.0 + betain + alphain) - 23.0 * p - 16.0 + 26.0
    K = base * 10.0**expo
    w_div = A * K * (T * p)**p / (np.exp(p) - 1.0)
    nu_cut = p * 0.208367e11 * T  # 0.208367e11 ~ k/h in Hz/K

    graybody = A * nu_in**betain * self.black(nu_in, T)[:, 0] / 1000.
    powerlaw = w_div * nu_in**(-1.0 * alphain)
    # Above nu_cut the power law replaces the graybody.
    return np.where(nu_in > nu_cut, powerlaw, graybody)

In [None]:
def fast_LIR(self, theta, zin, dzin=None):
    '''Integrate one graybody SED over the 8-1000 grid and convert it at redshift ``zin``.

    This calls ``self.graybody_fn`` instead of ``fast_sed``.

    Parameters
    ----------
    theta : SED parameters passed straight to ``self.graybody_fn``.
    zin : float
        Redshift used for the luminosity-distance conversion.
    dzin : sequence of float, optional
        Extra redshifts at which the same SED is also converted.

    Returns
    -------
    Lrf, or (Lrf, dLrf) when ``dzin`` is given. ``dLrf`` has one entry per
    element of ``dzin``; it was previously hard-coded to length 2.

    Note: this redefines the multi-redshift ``fast_LIR`` from an earlier cell.
    '''
    wavelength_range = self.loggen(8, 1000, 1000)
    model_sed = self.graybody_fn(theta, wavelength_range)

    nu_in = c * 1.e6 / wavelength_range
    dnu = nu_in[:-1] - nu_in[1:]
    dnu = np.append(dnu[0], dnu)
    Lir = np.sum(model_sed * dnu, axis=1)  # Jy x Hz

    def _luminosity_at(z):
        # 4 * pi * D_L^2    units are L_sun/(Jy x Hz)
        conversion = 4.0 * np.pi * (
            1.0E-13 * self.config_dict['cosmology_dict']['cosmology'].luminosity_distance(z)
            * 3.08568025E22) ** 2.0 / L_sun
        return (Lir * conversion.value)[0]

    Lrf = _luminosity_at(zin)

    if dzin is not None:
        dLrf = np.array([_luminosity_at(dz) for dz in dzin], dtype=float)
        return Lrf, dLrf

    return Lrf

In [6]:
def get_x_y_from_ra_dec(self, ra_series, dec_series, ind_src=None):
    '''Convert RA/Dec to integer map indices, dropping sources that fall outside the map.

    ``self`` is used as a mapping with 'map' (2-D array) and 'header' (passed
    to ``WCS``). ``ra_series`` / ``dec_series`` are pandas-like (``.values``).

    Parameters
    ----------
    ind_src : optional
        Selector applied to both series (boolean mask or index labels). It is
        now tested with ``is not None``. Truth-testing a boolean Series raises
        ValueError, and an empty selector previously meant "all sources".

    Returns
    -------
    real_x, real_y : int arrays
        ``real_x`` is bounded by the map's axis-0 size and ``real_y`` by its
        axis-1 size.
    '''
    smap = self['map']
    hd = self['header']
    wmap = WCS(hd)
    cms = np.shape(smap)

    if ind_src is not None:
        ra = ra_series[ind_src].values
        dec = dec_series[ind_src].values
    else:
        ra = ra_series.values
        dec = dec_series.values

    # CONVERT FROM RA/DEC to X/Y. The first world2pix output is compared with
    # axis 1 of the map (cms[1]) below, hence the (ty, tx) unpacking order.
    ty, tx = wmap.wcs_world2pix(ra, dec, 0)
    # CHECK FOR SOURCES THAT FALL OUTSIDE MAP.
    # NOTE(review): the lower bound uses the unrounded coordinate (>= 0) while
    # the upper bound uses the rounded one; sources in [-0.5, 0) are dropped.
    ind_keep = np.where((tx >= 0) & (np.round(tx) < cms[0]) & (ty >= 0) & (np.round(ty) < cms[1]))
    real_x = np.round(tx[ind_keep]).astype(int)
    real_y = np.round(ty[ind_keep]).astype(int)

    return real_x, real_y

In [7]:
def gauss_kern(fwhm, side, pixsize):
    ''' Create a 2D Gaussian (size= side x side)

    Parameters
    ----------
    fwhm : float
        Beam full width at half maximum, in the same units as ``pixsize``.
    side : float
        Kernel width in pixels (truncated with ``int``).
    pixsize : float
        Pixel size.

    Returns
    -------
    np.ndarray of shape (int(side), int(side)), peaked at index
    ``int(side / 2)`` on both axes and normalised so its maximum is 1 (not
    unit sum).
    '''
    sig = fwhm / 2.355 / pixsize  # FWHM -> sigma, in pixels
    n = int(side)
    delt = np.zeros([n, n])
    delt[0, 0] = 1.0
    ms = np.shape(delt)
    # Move the unit impulse to the centre pixel.
    delt = np.roll(np.roll(delt, int(ms[0] / 2), axis=1), int(ms[1] / 2), axis=0)
    # Filter into a fresh array. The previous code passed output=delt, so the
    # input buffer doubled as the output buffer.
    kern = gaussian_filter(delt, sig)
    kern /= np.max(kern)

    return kern

def smooth_psf(mapin, psfin):
    '''Convolve a 2-D map with a PSF via FFT (circular, i.e. periodic edges).

    The PSF is zero-padded to the map shape and rolled so that its centre
    pixel (index ``n // 2`` along each axis) sits at (0, 0), so the output is
    not shifted relative to the input.

    Parameters
    ----------
    mapin : 2-D array
    psfin : 2-D array, no larger than ``mapin`` on either axis.

    Returns
    -------
    Real part of the circular convolution, same shape as ``mapin``.
    '''
    mnx, mny = np.shape(mapin)[0], np.shape(mapin)[1]
    pnx, pny = np.shape(psfin)[0], np.shape(psfin)[1]

    # pad psf
    psfpad = np.zeros([mnx, mny])
    psfpad[0:pnx, 0:pny] = psfin

    # Shift the PSF centre to (0, 0). Each offset is applied along its own
    # axis; the old code swapped them, which mis-centred non-square PSFs.
    psfpad = np.roll(np.roll(psfpad, int(-pnx / 2), axis=0), int(-pny / 2), axis=1)

    return np.real(np.fft.ifft2(np.fft.fft2(mapin) * np.fft.fft2(psfpad)))

In [8]:
def write_fits(self,
               mapin,
               map_name,
               path_map=r'C:\Users\viero\Desktop',
               prefix='model_',
               overwrite=True,
               show=False):
    '''Beam-smooth a model map and write it to ``<path_map>/<prefix><map_name>.fits``.

    The kernel is built with ``gauss_kern`` from ``fwhm`` and ``pixel_size``
    of ``self.maps_dict[map_name]``. The smoothed map is written with that
    map's 'header'.

    Parameters
    ----------
    mapin : 2-D array
        Unsmoothed model map.
    map_name : key into ``self.maps_dict``.
    path_map : str
        Output directory. NOTE(review): the default is a machine-specific
        absolute path; pass this explicitly.
    prefix : str
        File-name prefix.
    overwrite : bool
        Forwarded to ``HDUList.writeto``.
    show : bool
        If True, display ``mapin`` with ``plt.imshow`` before writing.
        Previously this referenced an undefined global ``model_250``.
    '''
    if show:
        plt.imshow(mapin)

    map_object = self.maps_dict[map_name]
    name_map = prefix + str(map_name) + '.fits'

    hd = map_object['header']
    fwhm = map_object['fwhm']
    pix = map_object['pixel_size']
    kern = gauss_kern(fwhm, np.floor(fwhm * 10) / pix, pix)
    tmap = smooth_psf(mapin, kern)

    hdul = fits.HDUList([fits.PrimaryHDU(tmap, header=hd)])
    hdul.writeto(os.path.join(path_map, name_map), overwrite=overwrite)
    print('{0} written to {1}'.format(name_map, path_map))

In [9]:
def get_fast_sed_dict(self, fluxes_dict, catalog_object):
    '''Fit one graybody SED to the stacked fluxes of every catalogue bin.

    Bins are the Cartesian product of the label lists
    ``self.config_dict['parameter_names'][key]``, one list per key of
    ``self.config_dict['catalog']['classification']``, with the first key
    outermost. For each bin this:
      * stores, for every classification key, the median of that key's 'id'
        column of ``catalog_object``'s full table over the bin's rows (the bin
        mask comes from ``self.catalog_dict``'s split table);
      * fits ``self.fast_sed_fitter`` to ``fluxes_dict['flux'][:, i0, i1, ...]``
        with errors ``sqrt(|flux|)``;
      * evaluates ``self.fast_sed`` on ``self.loggen(8, 1000, 100)`` and the
        luminosity via ``self.fast_LIR`` at the bin's median of the first key.

    Parameters
    ----------
    fluxes_dict : dict
        'wavelengths' (passed to the fitter) and 'flux' (first axis per
        wavelength, then one axis per classification key).
    catalog_object
        Provides ``catalog_dict['tables']['full_table']``.

    Returns
    -------
    dict with 'wv_array', 'sed_params', 'graybody' and 'lir' (keyed by bin
    label), plus one dict per classification key mapping bin label -> median.
    Bin labels are the '__'-joined label strings with '.' replaced by 'p'.

    This replaces five hand-unrolled, copy-pasted nesting levels (which
    capped the number of keys at five) with a single loop over bin indices.
    '''
    split_dict = self.config_dict['catalog']['classification']
    label_keys = list(split_dict.keys())
    catalog_keys = [split_dict[i]['id'] for i in label_keys]
    label_dict = self.config_dict['parameter_names']

    x = fluxes_dict['wavelengths']

    wv_array = self.loggen(8, 1000, 100)
    sed_params_dict = {}
    graybody_dict = {}
    lir_dict = {}
    return_dict = {'wv_array': wv_array,
                   'sed_params': sed_params_dict,
                   'graybody': graybody_dict,
                   'lir': lir_dict}

    for ilabel in label_keys:
        return_dict[ilabel] = {}

    split_table = self.catalog_dict['tables']['split_table']
    full_table = catalog_object.catalog_dict['tables']['full_table']
    label_lists = [list(label_dict[key]) for key in label_keys]

    # np.ndindex yields (z, m, j, ...) in the same order as nested loops over
    # the label lists, first key outermost.
    for bin_idx in np.ndindex(*[len(labels) for labels in label_lists]):
        label = "__".join(labels[i] for labels, i in zip(label_lists, bin_idx)).replace('.', 'p')

        idx_bin = split_table[label_keys[0]] == bin_idx[0]
        for key, i in zip(label_keys[1:], bin_idx[1:]):
            idx_bin = idx_bin & (split_table[key] == i)

        for ilab, lab in enumerate(label_keys):
            return_dict[lab][label] = np.median(full_table.loc[idx_bin][catalog_keys[ilab]].dropna())

        y = fluxes_dict['flux'][(slice(None),) + bin_idx]
        yerr = np.sqrt(abs(y))

        sed_params_dict[label] = self.fast_sed_fitter(x, y, yerr)
        graybody_dict[label] = self.fast_sed(sed_params_dict[label], wv_array)[0]
        theta = sed_params_dict[label]['A'].value, sed_params_dict[label]['T_observed'].value
        lir_dict[label] = self.fast_LIR(theta, return_dict[label_keys[0]][label])

    return return_dict

In [9]:
def get_fast_two_seds_dict(self, fluxes_dict, catalog_object, T_cold_rf=18, T_hot_rf=60):
    '''Fit cold and hot fixed-temperature graybodies to the stacked fluxes of every bin.

    Bins are the Cartesian product of the label lists
    ``self.config_dict['parameter_names'][key]``, one list per classification
    key, with the first key outermost. For redshift label ``z_val`` the fixed
    temperatures are ``T_cold_rf / (1 + z_mean)`` and ``T_hot_rf / (1 + z_mean)``,
    where ``z_mean`` is the mean of the last two '_'-separated numbers in
    ``z_val`` (presumably the bin edges).

    For each bin this:
      * stores the median of every classification key's 'id' column over the
        bin's rows;
      * runs ``self.forced_sed_fitter`` at the cold and at the hot temperature
        on ``fluxes_dict['flux'][:, i0, i1, ...]`` with errors ``sqrt(|flux|)``;
      * evaluates both fits with ``self.fast_sed``;
      * calls ``self.fast_LIR2((A_cold, A_hot), [T_cold_rf, T_hot_rf], z_median)``.
        NOTE(review): confirm this argument order matches ``self.fast_LIR2``.

    Returns
    -------
    dict with 'wv_array', 'sed_params' and 'graybody' (each with 'cold'/'hot'
    sub-dicts keyed by bin label), 'lir' (keyed by bin label), plus one dict
    per classification key mapping bin label -> median.

    The hand-unrolled version never fitted the hot component in its 4-key
    branch, so it raised KeyError there. It also capped the number of keys
    at five.
    '''
    split_dict = self.config_dict['catalog']['classification']
    label_keys = list(split_dict.keys())
    catalog_keys = [split_dict[i]['id'] for i in label_keys]
    label_dict = self.config_dict['parameter_names']

    x = fluxes_dict['wavelengths']

    wv_array = self.loggen(8, 1000, 100)
    sed_params_dict = {'hot': {}, 'cold': {}}
    graybody_dict = {'hot': {}, 'cold': {}}
    lir_dict = {}
    return_dict = {'wv_array': wv_array,
                   'sed_params': sed_params_dict,
                   'graybody': graybody_dict,
                   'lir': lir_dict}

    for ilabel in label_keys:
        return_dict[ilabel] = {}

    label_lists = [list(label_dict[key]) for key in label_keys]

    # Fixed temperatures per redshift label, scaled by 1 / (1 + z_mean).
    fixed_temps = []
    for z_val in label_lists[0]:
        z_mean = np.mean([float(i) for i in z_val.split('_')[-2:]])
        fixed_temps.append((T_cold_rf / (1 + z_mean), T_hot_rf / (1 + z_mean)))

    split_table = self.catalog_dict['tables']['split_table']
    full_table = catalog_object.catalog_dict['tables']['full_table']

    # np.ndindex yields (z, m, j, ...) in nested-loop order, first key outermost.
    for bin_idx in np.ndindex(*[len(labels) for labels in label_lists]):
        label = "__".join(labels[i] for labels, i in zip(label_lists, bin_idx)).replace('.', 'p')

        idx_bin = split_table[label_keys[0]] == bin_idx[0]
        for key, i in zip(label_keys[1:], bin_idx[1:]):
            idx_bin = idx_bin & (split_table[key] == i)

        for ilab, lab in enumerate(label_keys):
            return_dict[lab][label] = np.median(full_table.loc[idx_bin][catalog_keys[ilab]].dropna())

        y = fluxes_dict['flux'][(slice(None),) + bin_idx]
        yerr = np.sqrt(abs(y))

        T_cold, T_hot = fixed_temps[bin_idx[0]]
        sed_params_dict['cold'][label] = self.forced_sed_fitter(x, y, yerr, T_cold)
        sed_params_dict['hot'][label] = self.forced_sed_fitter(x, y, yerr, T_hot)
        graybody_dict['cold'][label] = self.fast_sed(sed_params_dict['cold'][label], wv_array)[0]
        graybody_dict['hot'][label] = self.fast_sed(sed_params_dict['hot'][label], wv_array)[0]
        theta = sed_params_dict['cold'][label]['A'].value, sed_params_dict['hot'][label]['A'].value
        lir_dict[label] = self.fast_LIR2(theta, [T_cold_rf, T_hot_rf], return_dict[label_keys[0]][label])

    return return_dict

In [11]:
def parse_fluxes(self):
    '''Collect the stacked flux densities of every band into SED-ordered arrays.

    For each band in ``self.results_dict['band_results_dict']`` the stacked fit
    ('stacked_flux_densities') and any bootstrap fits
    ('bootstrap_flux_densities_<n>', n = 1, 2, ...) are read. Each bin label is
        "__".join([distance_label, <one value per classification key after the first>])
    with every '.' replaced by 'p'. When '<label>__bootstrap2' exists in a result
    object, its value is used as the flux in place of '<label>'. Bins whose label is
    absent from a result object stay at 0.

    Side effect: (re)sets ``band_results_dict[key]['raw_fluxes_dict']`` to ``{}``.

    Returns:
        dict with
          'flux'        -- array [n_bands, *ds]: fluxes from the stacked (first) result;
          'error'       -- array [n_bands, *ds]: standard deviation over the stacked and
                           bootstrap fluxes or, when only the stacked result exists,
                           the fit ``stderr`` of each bin (NaN where it is None);
          'wavelengths' -- ``self.config_dict['maps'][key]['wavelength']`` per band.
        ``ds`` holds the length of every list in ``self.config_dict['parameter_names']``.

    Raises:
        ValueError: if a band has no '*flux_densities*' result entries.
    '''
    band_results = self.results_dict['band_results_dict']
    wavelength_keys = list(band_results.keys())
    label_keys = list(self.config_dict['catalog']['classification'].keys())
    label_dict = self.config_dict['parameter_names']
    distance_labels = self.config_dict['catalog']['distance_labels']
    ds = [len(label_dict[k]) for k in label_dict]

    # Bin values for every classification level after the first (redshift) one;
    # np.ndindex walks all combinations for any number of levels.
    level_values = [label_dict[k] for k in label_keys[1:]]
    level_shape = [len(v) for v in level_values]

    wavelengths = []
    sed_flux_array = np.zeros([len(wavelength_keys), *ds])
    sed_error_array = np.zeros([len(wavelength_keys), *ds])

    for ik, key in enumerate(wavelength_keys):
        band = band_results[key]
        band['raw_fluxes_dict'] = {}
        wavelengths.append(self.config_dict['maps'][key]['wavelength'])

        # Builtin sum keeps this an int even when there are no matching keys.
        n_results = sum('flux_densities' in k for k in band.keys())
        if n_results == 0:
            raise ValueError("No flux_densities results found for band {0}".format(key))

        flux_array = np.zeros([n_results, *ds])
        error_array = np.zeros(ds)

        for iboot in range(n_results):
            if iboot == 0:
                boot_label = 'stacked_flux_densities'
            else:
                boot_label = 'bootstrap_flux_densities_' + str(iboot)
            results_object = band[boot_label]

            for z, zval in enumerate(distance_labels):
                for idx in np.ndindex(*level_shape):
                    names = [vals[n] for vals, n in zip(level_values, idx)]
                    label = "__".join([zval, *names]).replace('.', 'p')
                    if label not in results_object:
                        continue

                    value_label = label + '__bootstrap2'
                    if value_label not in results_object:
                        value_label = label
                    flux_array[(iboot, z, *idx)] = results_object[value_label].value

                    # Without bootstraps the only uncertainty available is the fit stderr.
                    if n_results == 1:
                        stderr = results_object[label].stderr
                        error_array[(z, *idx)] = np.nan if stderr is None else stderr

        sed_flux_array[ik] = flux_array[0]
        if n_results == 1:
            sed_error_array[ik] = error_array
        else:
            sed_error_array[ik] = np.std(flux_array, axis=0)

    return {'flux': sed_flux_array, 'error': sed_error_array, 'wavelengths': wavelengths}

In [None]:
def plot_pops(self, sed_dict=None, sed_model_params=None):
    fluxes_dict = self.parse_fluxes()
    #fluxes_dict = parse_fluxes(self)
    wavelengths = fluxes_dict['wavelengths'] #[24, 100, 160, 250, 350, 500, 850]
    wavelength_keys = list(self.results_dict['band_results_dict'].keys())
    split_dict = self.config_dict['catalog']['classification']
    # split_type = split_dict.pop('split_type')
    label_keys = list(split_dict.keys())
    label_dict = self.config_dict['parameter_names']
    ds = [len(label_dict[k]) for k in label_dict]

    if len(ds) > 3:
        fig, axs = plt.subplots(ds[-2], ds[0], figsize=(25,9))
    else:
        fig, axs = plt.subplots(ds[-1], ds[0], figsize=(25,6))
    
    ls = [ ':','-','--']
    pop = ['qt', 'sf']
    color = ['r', 'b', 'g', 'k']
    for z, zval in enumerate(self.config_dict['catalog']['distance_labels']):
        if len(label_keys) == 5:
            for l, lval in enumerate(label_dict[label_keys[4]]):
                for iagn, agn_val in enumerate(label_dict[label_keys[3]]):
                    for isb, sb_val in enumerate(label_dict[label_keys[2]]):
                        for imass, m_val in enumerate(label_dict[label_keys[1]]):
                            id_label = "__".join([zval, m_val, sb_val, agn_val, lval]).replace('.', 'p')

                            label = None
                            if l and not iagn and z == len(label_dict[label_keys[0]])-1:
                                label = "$M_{"+pop[l]+"}=$"+"-".join(m_val.split('_')[-2:])+", F24="+"-".join(sb_val.split('_')[-2:])
                            if l and sed_dict is None:
                                axs[iagn, z].plot(wavelengths, 1e3*fluxes_dict['flux'][:,z,imass,isb,iagn,l], ls=ls[l], color=color[isb], label=label)
                                
                            if l and (sed_dict is not None):
                                sed_params = sed_dict['sed_params'][id_label]
                                lir_mod = np.log10(sed_dict['lir'][id_label])
                                wv_mod = sed_dict['wv_array']
                                gb_mod = sed_dict['graybody'][id_label]
                                axs[iagn, z].scatter(wavelengths, 1e3*fluxes_dict['flux'][:,z,imass,isb, iagn,l], color=color[isb], label=label)
                                if sed_model_params is None:
                                    axs[iagn, z].plot(wv_mod, 1e3*gb_mod, ls=ls[l], color=color[isb])

                            if l and (sed_model_params is not None):
                                wv_mod = sed_dict['wv_array']
                                model_cube = np.array([sed_dict['redshift'][id_label],
                                                       sed_dict['stellar_mass'][id_label],
                                                       sed_dict['agn_fraction'][id_label],
                                                       sed_dict['starburst'][id_label]]).T
                                try:
                                    A_model = model_A_Tdust(sed_model_params['A'], model_cube)
                                    T_model = model_A_Tdust(sed_model_params['Tdust'], model_cube)
                                except:
                                    A_model = model_A_or_Tdust(sed_model_params['A'], model_cube)
                                    T_model = model_A_or_Tdust(sed_model_params['Tdust'], model_cube)
                                theta_model = A_model, T_model
                                gb_model = self.graybody_fn(theta_model, wv_mod)
                                axs[iagn, z].plot(wv_mod, 1e3*gb_model[0], ls=ls[isb], color=color[isb])
                                
                            axs[iagn, z].set_ylim([5e-2, 5e1])
                            axs[iagn, z].set_xscale('log')
                            axs[iagn, z].set_yscale('log')
                            if z:
                                axs[iagn, z].set_yticklabels([])
                            if not iagn:
                                axs[iagn, z].set_title(zval)
                            if z == len(label_dict[label_keys[0]])-1:
                                ylabel = "Ah="+"-".join(agn_val.split('_')[-2:])
                                axs[iagn, z].set_ylabel(ylabel)
                                axs[iagn, z].yaxis.set_label_position("right")

                            if l and not iagn and z == len(label_dict[label_keys[0]])-1:
                                axs[iagn, z].legend(bbox_to_anchor=(1.1,1), loc="upper left")  
                                
        elif len(label_keys) == 4:
            for l, lval in enumerate(label_dict[label_keys[3]]):
                for iagn, agn_val in enumerate(label_dict[label_keys[2]]):
                    for imass, m_val in enumerate(label_dict[label_keys[1]]):
                        id_label = "__".join([zval, m_val, agn_val, lval]).replace('.', 'p')
                        #print(label)

                        label = None
                        if l and not iagn and z == len(label_dict[label_keys[0]])-1:
                            label = "$M_{"+pop[l]+"}=$"+"-".join(m_val.split('_')[-2:])
                        if l and sed_dict is None:
                            axs[iagn, z].plot(wavelengths, 1e3*fluxes_dict['flux'][:,z,imass,iagn,l], ls=ls[l], label=label)
                            
                        if l and (sed_dict is not None):
                            sed_params = sed_dict['sed_params'][id_label]
                            lir_mod = np.log10(sed_dict['lir'][id_label])
                            wv_mod = sed_dict['wv_array']
                            gb_mod = sed_dict['graybody'][id_label]
                            axs[iagn, z].scatter(wavelengths, 1e3*fluxes_dict['flux'][:,z,imass,iagn,l], label=label)
                            if sed_model_params is None:
                                axs[iagn, z].plot(wv_mod, 1e3*gb_mod, ls=ls[l])
                            
                        if l and (sed_model_params is not None):
                            wv_mod = sed_dict['wv_array']
                            if 'agn' in label_keys[2]:
                                model_cube = np.array([sed_dict['redshift'][id_label],sed_dict['stellar_mass'][id_label],sed_dict['agn_fraction'][id_label]]).T
                            else:
                                model_cube = np.array([sed_dict['redshift'][id_label],sed_dict['stellar_mass'][id_label],sed_dict['starburst'][id_label]]).T
                            try:
                                A_model = model_A_Tdust(sed_model_params['A'], model_cube)
                                T_model = model_A_Tdust(sed_model_params['Tdust'], model_cube)
                            except:
                                A_model = model_A_or_Tdust(sed_model_params['A'], model_cube)
                                T_model = model_A_or_Tdust(sed_model_params['Tdust'], model_cube)
                            theta_model = A_model, T_model
                            gb_model = self.graybody_fn(theta_model, wv_mod)
                            axs[iagn, z].plot(wv_mod, 1e3*gb_model[0], ls='--')

                        axs[iagn, z].set_ylim([5e-2, 5e1])
                        axs[iagn, z].set_xscale('log')
                        axs[iagn, z].set_yscale('log')
                        if z:
                            axs[iagn, z].set_yticklabels([])
                        if not iagn:
                            axs[iagn, z].set_title(zval)
                        if z == len(label_dict[label_keys[0]])-1:
                            ylabel = agn_val.split('_')[0]+"="+"-".join(agn_val.split('_')[-2:])
                            axs[iagn, z].set_ylabel(ylabel)
                            axs[iagn, z].yaxis.set_label_position("right")

                        if l and not iagn and z == len(label_dict[label_keys[0]])-1:
                            axs[iagn, z].legend(bbox_to_anchor=(1.1,1), loc="upper left")
                            
        elif len(label_keys) == 3:
            for l, lval in enumerate(label_dict[label_keys[2]]):
                for imass, m_val in enumerate(label_dict[label_keys[1]]):
                    id_label = "__".join([zval, m_val, lval]).replace('.', 'p')

                    label = None
                    if z == len(label_dict[label_keys[0]])-1:
                        label = "$M_{"+pop[l]+"}=$"+"-".join(m_val.split('_')[-2:])
                    if sed_dict is None:
                        axs[l, z].plot(wavelengths, 1e3*fluxes_dict['flux'][:,z,imass,l], ls=ls[l], label=label)
                    
                    if sed_dict is not None:
                        sed_params = sed_dict['sed_params'][id_label]
                        lir_mod = np.log10(sed_dict['lir'][id_label])
                        wv_mod = sed_dict['wv_array']
                        gb_mod = sed_dict['graybody'][id_label]
                        axs[l, z].scatter(wavelengths, 1e3*fluxes_dict['flux'][:,z,imass,l], label=label)
                        if sed_model_params is None:
                            axs[l, z].plot(wv_mod, 1e3*gb_mod, ls=ls[l])
                        
                    if (sed_model_params is not None):
                        wv_mod = sed_dict['wv_array']
                        model_cube = np.array([sed_dict['redshift'][id_label],sed_dict['stellar_mass'][id_label]]).T
                        try:
                            A_model = model_A_Tdust(sed_model_params['A'], model_cube)
                            T_model = model_A_Tdust(sed_model_params['Tdust'], model_cube)
                        except:
                            A_model = model_A_or_Tdust(sed_model_params['A'], model_cube)
                            T_model = model_A_or_Tdust(sed_model_params['Tdust'], model_cube)
                        theta_model = A_model, T_model
                        gb_model = self.graybody_fn(theta_model, wv_mod)
                        axs[l, z].plot(wv_mod, 1e3*gb_model[0], ls='--')

                    axs[l, z].set_ylim([5e-2, 5e1])
                    axs[l, z].set_xscale('log')
                    axs[l, z].set_yscale('log')
                    if z:
                        axs[l, z].set_yticklabels([])
                    if l:
                        axs[l, z].set_title(zval)
                    if z == len(label_dict[label_keys[0]])-1:
                        ylabel = pop[l]
                        axs[l, z].set_ylabel(ylabel)
                        axs[l, z].yaxis.set_label_position("right")

                    if z == len(label_dict[label_keys[0]])-1:
                        axs[l, z].legend(bbox_to_anchor=(1.1,1), loc="upper left")