# ALMA-IMF+SPICY: SED Fitting

Workspace for importing the FITS tables generated by the SED Table Prep notebook, running the SED fitting, and saving the results to a pickle file (`fitinfo.p`). Search for "!filepath!" to locate file paths that may need to be changed.

In [None]:
# basics
import os
import numpy as np
from tqdm.notebook import tqdm_notebook

# astropy
from astropy import table
from astropy.table import Table
from astropy import units as u
from astropy.io import fits
from astropy.modeling.models import BlackBody

from astroquery.svo_fps import SvoFps

# sed fitting
from sedfitter.filter import Filter
from sedfitter.extinction import Extinction
from dust_extinction.parameter_averages import F19
from dust_extinction.averages import CT06_MWLoc
from sedfitter.source import Source
from sedfitter import fit, Fitter
from sedfitter.sed import SEDCube

# analysis/plotting
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.image as mpimg

# writing to file
import pickle

# Model geometries to fit, per the Robitaille paper. Each name is also the
# sub-directory of the model root that get_fitter loads the grid from, and
# fits are run (and stored) in this order.
geometries = [
    's-pbhmi', 's-pbsmi', 'sp--h-i', 's-p-hmi', 'sp--hmi', 'sp--s-i',
    's-p-smi', 'sp--smi', 'spubhmi', 'spubsmi', 'spu-hmi', 'spu-smi',
    's---s-i', 's---smi', 's-ubhmi', 's-ubsmi', 's-u-hmi', 's-u-smi',
]

## Function setup

In [None]:
def _alma_imf_filter(name, bandends, central_frequency):
    """Return a top-hat sedfitter ``Filter`` covering the given ALMA spectral windows.

    ``bandends`` is an (N, 2) frequency Quantity of [start, stop] pairs. The
    response is 1 inside any window and 0 elsewhere, sampled on 5000 points
    between the lowest and highest band edge. The filter is not normalized here.
    """
    nu = np.linspace(bandends.min(), bandends.max(), 5000)
    response = np.zeros(nu.size, dtype='bool')
    for start, stop in bandends:
        response |= (nu > start) & (nu < stop)
    return Filter(name=name,
                  central_wavelength=central_frequency.to(u.mm, u.spectral()),
                  nu=nu,
                  response=response.astype(float))


def get_filters(hemisphere='south'):
    """Build the set of filters used for SED fitting.

    Parameters
    ----------
    hemisphere : {'north', 'south'}
        Selects the near-IR survey filters: 'north' uses UKIRT/UKIDSS J, H, K
        and 'south' uses Paranal/VISTA Y, Z, J, H, Ks. Spitzer IRAC/MIPS 24,
        Herschel PACS/SPIRE and the ALMA-IMF 1 mm / 3 mm filters are shared.

    Returns
    -------
    sed_filters : list of sedfitter.filter.Filter
        One filter per SVO ID in ``filternames``, followed by 'ALMA-IMF_1mm'
        and 'ALMA-IMF_3mm'. All of them are normalized.
    wavelength_dict : dict
        Transmission-weighted mean wavelength of each SVO filter (Angstrom)
        plus the ALMA-IMF central wavelengths (micron).
    filternames : list of str
        The SVO filter IDs (the ALMA-IMF filters are not included).
    zpts : dict
        The SVO 'ZeroPoint' entry for each SVO filter ID.

    Raises
    ------
    ValueError
        If ``hemisphere`` is neither 'north' nor 'south'.
    """
    # Near-IR survey per hemisphere: (SVO filter IDs, SVO facility, SVO instrument).
    nir_surveys = {
        'north': (['UKIRT/UKIDSS.J', 'UKIRT/UKIDSS.H', 'UKIRT/UKIDSS.K'],
                  'UKIRT', 'WFCAM'),
        'south': (['Paranal/VISTA.Y', 'Paranal/VISTA.Z', 'Paranal/VISTA.J',
                   'Paranal/VISTA.H', 'Paranal/VISTA.Ks'],
                  'Paranal', 'VIRCAM'),
    }
    if hemisphere not in nir_surveys:
        raise ValueError(f"hemisphere must be 'north' or 'south', got {hemisphere!r}")
    nir_filternames, nir_facility, nir_instrument = nir_surveys[hemisphere]

    filternames = nir_filternames + [
        'Spitzer/IRAC.I1', 'Spitzer/IRAC.I2', 'Spitzer/IRAC.I3', 'Spitzer/IRAC.I4',
        'Spitzer/MIPS.24mu',
        'Herschel/Pacs.blue', 'Herschel/Pacs.red',
        'Herschel/SPIRE.PSW', 'Herschel/SPIRE.PMW', 'Herschel/SPIRE.PLW',
    ]

    # keep only the non "_ext" SPIRE filters (but we should look up which is more appropriate)
    spire_filters = SvoFps.get_filter_list(facility='Herschel', instrument='Spire')
    spire_filters = spire_filters[['_ext' not in fid for fid in spire_filters['filterID']]]

    # Zero points are selected by filterID below, so extra rows are harmless.
    # All MIPS rows are kept instead of assuming row 0 is the 24 micron filter.
    filter_meta = table.vstack([
        SvoFps.get_filter_list(facility=nir_facility, instrument=nir_instrument),
        SvoFps.get_filter_list(facility='Spitzer', instrument='IRAC'),
        SvoFps.get_filter_list(facility='Spitzer', instrument='MIPS'),
        SvoFps.get_filter_list(facility='Herschel', instrument='Pacs'),
        spire_filters,
    ])
    zpts = {name: filter_meta[filter_meta['filterID'] == name]['ZeroPoint']
            for name in filternames}

    sed_filters = []
    wavelength_dict = {}
    for name in filternames:
        curve = SvoFps.get_transmission_data(name)
        # transmission-weighted mean wavelength; SVO wavelengths are taken as Angstrom
        mean_wl = np.average(curve['Wavelength'], weights=curve['Transmission']) * u.AA
        wavelength_dict[name] = mean_wl
        # hand the response to sedfitter on an ascending frequency grid,
        # reordering the transmission with the same permutation
        freqs = u.Quantity(curve['Wavelength'], u.AA).to(u.Hz, u.spectral())
        order = np.argsort(freqs)
        sed_filters.append(Filter(name=name,
                                  central_wavelength=mean_wl,
                                  nu=freqs[order],
                                  response=np.array(curve['Transmission'])[order]))

    # Add in the custom ALMA-IMF filters (top hats over the spectral windows)
    alma_1mm_center = 228.15802*u.GHz
    alma_1mm_bandends = [[216.10085679, 216.36181569],
                         [217.05104378, 217.31175857],
                         [219.90488464, 220.04866835],
                         [218.13102322, 218.39222624],
                         [219.51976276, 219.66379059],
                         [230.31532951, 230.81137113],
                         [231.06503709, 231.56181105],
                         [231.52507012, 233.42623749]]*u.GHz
    sed_filters.append(_alma_imf_filter('ALMA-IMF_1mm', alma_1mm_bandends, alma_1mm_center))

    alma_3mm_center = 99.68314596*u.GHz
    alma_3mm_bandends = [[ 93.13410936,  93.25141259],
                         [ 91.75059068,  92.68755174],
                         [102.15273354, 103.0896946 ],
                         [104.55323851, 105.49019957]]*u.GHz
    sed_filters.append(_alma_imf_filter('ALMA-IMF_3mm', alma_3mm_bandends, alma_3mm_center))

    # Normalize only after every filter has been built, so that both ALMA-IMF
    # filters are normalized exactly like the SVO ones.
    for filterfunc in sed_filters:
        filterfunc.normalize()

    wavelength_dict['ALMA-IMF_1mm'] = alma_1mm_center.to(u.um, u.spectral())
    wavelength_dict['ALMA-IMF_3mm'] = alma_3mm_center.to(u.um, u.spectral())

    return sed_filters, wavelength_dict, filternames, zpts

# Pull a single source (row) out of an SED table as an array, e.g. for plotting.
def getrow(tbl, rownum, keys):
    """Return the values of columns ``keys`` in row ``rownum`` of ``tbl`` as a numpy array."""
    record = tbl[rownum]
    return np.array([record[name] for name in keys])

def get_data_to_fit(rownumber, tbl, filters):
    """Extract fluxes, errors and sedfitter validity flags for one source.

    Any ``<filter>_flux`` or ``<filter>_eflux`` column missing from ``tbl`` is
    added to it *in place*, filled with NaN, so that every requested filter is
    present (and ends up flagged as discarded).

    Parameters
    ----------
    rownumber : int
        Row index of the source in ``tbl``.
    tbl : table-like
        Table with ``<filter>_flux`` / ``<filter>_eflux`` columns.
    filters : sequence of str
        Filter names, in the order the fitter expects them.

    Returns
    -------
    flx, error : numpy.ndarray of float
        Flux and error per filter. For upper limits (``valid == 3``), ``flx``
        holds 3 times the error and ``error`` holds the confidence 0.997.
    valid : numpy.ndarray of int
        0 = discarded, 1 = fitted directly, 3 = upper limit.
    """
    # Make sure every filter has both columns. They are checked independently,
    # because a table may carry a flux column without a matching error column.
    for key in filters:
        for colname in (key + "_flux", key + "_eflux"):
            if colname not in tbl.keys():
                tbl[colname] = np.full(len(tbl), np.nan)

    # Extract fluxes and errors as float arrays. Float dtype keeps the
    # finiteness tests meaningful and stops the 0.997 confidence below from
    # being truncated when a column happens to be integer.
    row = tbl[rownumber]
    flx = np.array([row[key + "_flux"] for key in filters], dtype=float)
    error = np.array([row[key + "_eflux"] for key in filters], dtype=float)

    finite_flux = np.isfinite(flx)
    finite_error = np.isfinite(error)

    # set flags based on validity of data
    valid = np.zeros(flx.size, dtype='int')
    # both the flux and error are valid: data is fitted directly
    valid[finite_flux & finite_error] = 1
    # flux is not specified but the error is: treated as an upper limit
    valid[~finite_flux & finite_error] = 3
    # anything else (both missing, or a flux without a usable error) stays 0

    # error-proofing: toss any data points which measure exactly 0
    # (NOTE(review): masked table entries may arrive here as 0 and be dropped by this rule)
    valid[(flx == 0) | (error == 0)] = 0

    # set the "flux" to be the 3-sigma error wherever we're treating it as an upper limit
    upper = valid == 3
    flx[upper] = error[upper] * 3
    # then, set the confidence associated with that upper limit, AKA 3-sigma
    error[upper] = 0.997

    return flx, error, valid

def get_fitter(geometry, aperture_size,
               distance_range,
               robitaille_modeldir,
               filters, extinction,
               av_range):
    """Create a sedfitter ``Fitter`` for one model geometry.

    Parameters
    ----------
    geometry : str
        Geometry name; models are read from ``{robitaille_modeldir}/{geometry}``.
    aperture_size : Quantity, float or sequence
        Either a single aperture (scalar or length-1), used for every filter,
        or exactly one aperture per filter. Values without units are taken as
        arcsec; angular Quantities are converted to arcsec.
    distance_range : Quantity
        Distance range passed to the fitter.
    robitaille_modeldir : str
        Root directory containing one sub-directory per geometry.
    filters : sequence of str
        Filter names passed to the fitter.
    extinction : sedfitter.extinction.Extinction
        Extinction law passed to the fitter.
    av_range : sequence
        A_V range passed to the fitter.

    Returns
    -------
    sedfitter.Fitter
        Built with ``use_memmap=True``.

    Raises
    ------
    ValueError
        If the number of apertures is neither 1 nor ``len(filters)``.
    """
    # Define path to models
    model_dir = f'{robitaille_modeldir}/{geometry}'

    filters = np.asarray(filters)

    # Normalize apertures to a flat (n_filters,) arcsec Quantity. A single
    # aperture is broadcast to every filter instead of being nested, which
    # would produce an (n, 1) array.
    apertures = np.atleast_1d(u.Quantity(aperture_size, u.arcsec))
    if apertures.size == 1:
        apertures = np.ones(len(filters)) * apertures.ravel()[0]
    elif apertures.size != len(filters):
        raise ValueError(f"Got {apertures.size} apertures for {len(filters)} filters; "
                         "pass one aperture or one per filter")

    fitter = Fitter(filter_names=filters,
                    apertures=apertures,
                    model_dir=model_dir,
                    extinction_law=extinction,
                    distance_range=distance_range,
                    av_range=av_range,
                    use_memmap=True
                   )

    return fitter

def make_extinction(rv=3.1):
    """Build a sedfitter ``Extinction`` law from ~0.115 to 27 micron.

    The short-wavelength part is Fitzpatrick et al. (2019) (``F19``), sampled
    just inside its 0.3-8.7 inverse-micron range. It is extended into the mid-IR
    with the ``CT06_MWLoc`` average curve. The A(lambda)/A(V) values are
    converted to an opacity per unit gas mass by dividing by the gas mass
    column per magnitude of A_V (Guver & Ozel 2009 N_H/A_V ratio).

    Parameters
    ----------
    rv : float, optional
        R(V) used for the F19 curve (default 3.1).

    Returns
    -------
    sedfitter.extinction.Extinction
        With ``wav`` in micron and ``chi`` in cm^2/g.
    """
    ext = F19(rv)
    ext2 = CT06_MWLoc()

    # https://arxiv.org/abs/0903.2057 : N_H = 2.21e21 cm^-2 per mag of A_V.
    # 1.34 Da is the gas mass per H atom used here (quoted "from memory" - TODO verify).
    guyver2009_avtocol = (2.21e21 * u.cm**-2 * (1.34*u.Da)).to(u.g/u.cm**2)

    # sample F19 uniformly in log wavenumber, then convert to ascending wavelength
    ext_wav = np.sort((np.geomspace(0.301, 8.699, 1000)/u.um).to(u.um, u.spectral()))
    ext_vals = ext.evaluate(ext_wav, Rv=rv)

    # Extend the extinction curve out to 27 micron. The extension starts one
    # geometric step past the last F19 wavelength, so the combined grid never
    # repeats a wavelength with two different values.
    # NOTE(review): confirm CT06_MWLoc.evaluate is normalized to A(V) like F19.
    ext_wav2 = np.geomspace(ext_wav.max(), 27*u.um, 101)[1:]
    ext_vals2 = ext2.evaluate(ext_wav2)

    extinction = Extinction()
    extinction.wav = np.hstack([ext_wav, ext_wav2])
    extinction.chi = np.hstack([ext_vals, ext_vals2]) / guyver2009_avtocol

    return extinction

# Fit one source against one model geometry
def fit_a_source(data, error, valid,
                 geometry, robitaille_modeldir,
                 extinction, filters, aperture_size,
                 distance_range, av_range,
                 fitter=None, stash_to_mmap=False,
                ):
    """Run sedfitter on a single source for a single model geometry.

    Parameters
    ----------
    data, error : array-like or Quantity
        Per-filter fluxes and errors. Plain arrays are taken to already be in
        mJy, and Jy-equivalent Quantities are converted to mJy. sedfitter
        expects mJy (https://sedfitter.readthedocs.io/en/stable/data.html).
    valid : array-like of int
        sedfitter validity flag for each filter.
    geometry, robitaille_modeldir, extinction, filters, aperture_size, distance_range, av_range
        Forwarded to ``get_fitter`` when ``fitter`` is not supplied.
        ``geometry`` also names the memmap file when ``stash_to_mmap`` is set.
    fitter : optional
        Pre-built fitter to reuse instead of constructing a new one.
    stash_to_mmap : bool
        If True, copy the result's ``model_fluxes`` into a float32 ``np.memmap``
        file ``<geometry>.dat`` inside a new temporary directory, which is not
        removed afterwards.

    Returns
    -------
    The object returned by ``fitter.fit``.
    """
    src = Source()
    src.valid = valid
    src.flux = u.Quantity(data, u.mJy).value
    src.error = u.Quantity(error, u.mJy).value

    active_fitter = fitter if fitter is not None else get_fitter(
        geometry=geometry, aperture_size=aperture_size,
        distance_range=distance_range, av_range=av_range,
        robitaille_modeldir=robitaille_modeldir,
        filters=filters, extinction=extinction,
    )

    result = active_fitter.fit(src)

    if not stash_to_mmap:
        return result

    from tempfile import mkdtemp
    import os.path as ospath

    mmap_path = ospath.join(mkdtemp(), f'{geometry}.dat')
    mm = np.memmap(mmap_path, dtype='float32', mode='w+', shape=result.model_fluxes.shape)
    mm[:] = result.model_fluxes[:]
    mm.flush()
    result.model_fluxes = mm
    print(f"Moved array with size {result.model_fluxes.shape} to {mm.filename}")
    return result

# Convenience wrapper: fit one table row against every model geometry
def full_source_fit(rownum, filternames, apertures, robitaille_modeldir, extinction, distance_range, av_range,
                    source_table=None, geometry_list=None):
    """Fit one source against each model geometry.

    Parameters
    ----------
    rownum : int
        Row index of the source in the table.
    filternames : sequence of str
        SVO filter IDs. The two ALMA-IMF filters are appended automatically.
    apertures : Quantity
        Aperture sizes, forwarded to ``fit_a_source``.
    robitaille_modeldir : str
        Root directory of the model grids.
    extinction : sedfitter.extinction.Extinction
        Extinction law, forwarded to ``fit_a_source``.
    distance_range, av_range
        Forwarded to ``fit_a_source``.
    source_table : table-like, optional
        Photometry table. Defaults to the notebook-level ``tbl``.
    geometry_list : sequence of str, optional
        Geometries to fit. Defaults to the module-level ``geometries``.

    Returns
    -------
    dict
        Maps each geometry name to the result of ``fit_a_source``.
    """
    if source_table is None:
        source_table = tbl
    if geometry_list is None:
        geometry_list = geometries
    filternames = list(filternames)

    flx, error, valid = get_data_to_fit(rownum, source_table,
                                        filters=filternames + ["ALMA-IMF_1mm", "ALMA-IMF_3mm"])
    # optional: inspect the inputs before fitting with print(Table([flx, error, valid]))

    # NOTE(review): the fitter refers to the custom ALMA-IMF filters under
    # "user_filters/", while the table columns use the bare names.
    fit_filters = filternames + ["user_filters/ALMA-IMF_1mm", "user_filters/ALMA-IMF_3mm"]

    fits_by_geometry = {}
    for geom in tqdm_notebook(geometry_list, desc=f'Fitting source {rownum+1}/{len(source_table)}'):
        fits_by_geometry[geom] = fit_a_source(data=flx, error=error, valid=valid,
                                              geometry=geom, robitaille_modeldir=robitaille_modeldir,
                                              extinction=extinction,
                                              filters=fit_filters,
                                              aperture_size=apertures,
                                              distance_range=distance_range,
                                              av_range=av_range)
    return fits_by_geometry

## Fitting

In [None]:
# ALMA-IMF field to fit; it also names the input table file loaded below.
fieldid = 'W43MM1'

# Set True to blank out the ALMA-IMF 1 mm / 3 mm points,
# e.g. to compare fits made with and without them.
almaoverride = False

In [None]:
# load table from Table Prep notebook
tbl = Table.read(f'/blue/adamginsburg/adamginsburg/SPICY_ALMAIMF/BriceTingle/Region_tables/Unfitted/{fieldid}', format='fits') # !filepath!
# print("Field: "+fieldid)
# print("NIR data: "+str(list(tbl['NIR data'])))
print("Sources: ",len(tbl))

if almaoverride:
    # Blank out the ALMA-IMF photometry by replacing each column with all-NaN.
    # get_data_to_fit flags points with a NaN flux and a NaN error as valid = 0,
    # so they are excluded from the fit. Replacing the whole column also works
    # for integer or masked columns.
    for x in ['ALMA-IMF_3mm_flux','ALMA-IMF_3mm_eflux','ALMA-IMF_1mm_flux','ALMA-IMF_1mm_eflux']:
        tbl[x] = np.full(len(tbl), np.nan)

# define constant aperture sizes, based on literature values, in the order the
# filters are fitted: IRAC I1-I4, MIPS 24mu, PACS blue/red, SPIRE PSW/PMW/PLW,
# ALMA-IMF 1mm/3mm
apertures = [2.4, 2.4, 2.4, 2.4, 6, 10, 13.5, 23, 30, 41, 3, 3]*u.arcsec

# determine filters and prepend the near-IR aperture sizes, based on whether
# we're using UKIDSS (north) or VIRAC (south) data
ukidss_fields = ['G10','G12','W43MM1','W43MM2','W43MM3','W51-E','W51IRS2']
virac_fields = ['G008','G327','G328','G333','G337','G338','G351','G353']

if any(x in virac_fields for x in tbl['ALMAIMF_FIELDID']):
    print("Grabbing VIRAC filters")
    sed_filters, wavelength_dict, filternames, zpts = get_filters("south")
    # one aperture per VISTA band (Y, Z, J, H, Ks)
    apertures = apertures.insert(0,[1.415, 1.415, 1.415, 1.415, 1.415]*u.arcsec)
    hemisphere = "south"
elif any(x in ukidss_fields for x in tbl['ALMAIMF_FIELDID']):
    print("Grabbing UKIDSS filters")
    sed_filters, wavelength_dict, filternames, zpts = get_filters("north")
    # one aperture per UKIDSS band (J, H, K)
    apertures = apertures.insert(0,[2, 2, 2]*u.arcsec)
    hemisphere = "north"
else:
    # fail loudly here rather than with a NameError on filternames below
    raise ValueError("ALMAIMF_FIELDID values match neither ukidss_fields nor virac_fields; "
                     "add the field to the appropriate list")
print("Filters:"+str(filternames))
print("Apertures:"+str(apertures))

# read distance from table (in kpc); nanmax ignores NaN entries
regiondistance = np.nanmax(tbl['Distance'])
print("Region literature distance (kpc): "+str(regiondistance))

# infer allowable distance range (in kpc): literature value +/- 0.2 kpc
distance_range=[regiondistance-0.2, regiondistance+0.2]
print("Distance range (kpc):"+str(distance_range))

# determine appropriate A_V range (in mag), using 2*region distance as the
# lower bound and 60 as the upper bound
av_range=[regiondistance*2,60]
print("Extinction range (mag): "+str(av_range))

# make extinction law
extinction = make_extinction()

# run fitting on each source in the region; results are keyed by SPICY ID,
# each value being a {geometry: fit result} dict from full_source_fit
region_fits = {}
for rownum, spicy_id in enumerate(tbl['SPICY']):
    region_fits[spicy_id] = full_source_fit(
        rownum, filternames, apertures,
        '/blue/adamginsburg/adamginsburg/SPICY_ALMAIMF/BriceTingle/robitaille_models-1.2',  # !filepath!
        extinction, distance_range*u.kpc, av_range)

In [None]:
# Save all fit results (SPICY ID -> {geometry: fit result}) as a pickle.
# Opening with 'wb' already truncates any existing file.
# NOTE: only unpickle this file from trusted sources.
with open('fitinfo.p', 'wb') as fh: # !filepath!
    pickle.dump(region_fits, fh)