In [None]:
from copy import copy

import numpy as np
import sncosmo
from astropy.table import Table
from sndata.csp import dr1, dr3

# Fetch both CSP data releases (DR1 supplies the spectra and DR3 the
# photometry used further down), then register the DR3 filter curves with
# sncosmo so they can be looked up by name via sncosmo.get_bandpass.
# force=True presumably allows re-registration when the cell is re-run —
# confirm against the sndata documentation.
for data_release in (dr1, dr3):
    data_release.download_module_data()

dr3.register_filters(force=True)


In [None]:
# Collect ids for targets that appear in EITHER data release (set union).
# NOTE: if only targets having both DR1 spectra and DR3 photometry are wanted,
# use ``set(dr1_ids).intersection(dr3_ids)`` instead.
dr1_ids = dr1.get_available_ids()
dr3_ids = dr3.get_available_ids()
all_obj_ids = set(dr1_ids).union(dr3_ids)

# Map each DR3 band name to its sncosmo Bandpass object (requires the
# filters registered in the setup cell).
BAND_DICT = {b: sncosmo.get_bandpass(b) for b in dr3.band_names}


In [None]:
def match_obs_times(spectra, photometry, delay=1):
    """Pair each spectroscopic epoch with the nearest photometric point per band.

    Args:
        spectra: Table of spectral data with a ``date`` column. Dates are
            assumed to be JD (photometry is shifted from MJD to JD to match)
            — NOTE(review): confirm the spectra dates really are JD.
        photometry: sncosmo-style photometry table with ``time`` (MJD),
            ``band`` and ``flux`` columns. It is NOT modified.
        delay: Maximum separation, in days, between a spectrum and its
            nearest photometric point. Farther matches are masked.

    Returns:
        A masked astropy Table with one row per unique spectral date
        (column ``spec_time``) and one column per photometric band holding
        the flux of the nearest-in-time observation in that band.
        ``meta['bands']`` lists (in sorted order) the bands with at least one
        unmasked match; ``meta['obj_id']`` is ``spectra.meta['obj_id']`` if
        present, else None.
    """
    mjd_to_jd = 2400000.5

    out_table = Table(
        data=[np.unique(spectra['date'])],
        names=['spec_time'],
        masked=True
    )
    out_table.meta['obj_id'] = spectra.meta.get('obj_id')
    out_table.meta['bands'] = []

    spec_times = np.asarray(out_table['spec_time'], dtype=float)

    # Convert to JD on a plain array instead of mutating the input table:
    # ``copy(Table)`` is shallow and shares column data with the caller, so an
    # in-place ``+=`` would shift the caller's times on every call.
    phot_times_jd = np.asarray(photometry['time'], dtype=float) + mjd_to_jd
    phot_bands = np.asarray(photometry['band'])

    # np.unique gives a deterministic (sorted) band order, unlike set()
    for band in np.unique(phot_bands):
        in_band = phot_bands == band
        band_times = phot_times_jd[in_band]
        band_flux = photometry['flux'][in_band]

        # (n_spec, n_phot) matrix of absolute time differences
        delta_t = np.abs(spec_times[:, None] - band_times)
        nearest = delta_t.argmin(axis=1)
        # Per-row distance to the chosen photometric point
        nearest_dt = delta_t[np.arange(len(spec_times)), nearest]

        out_table[band] = band_flux[nearest]
        out_table[band].mask = nearest_dt > delay
        if not np.all(out_table[band].mask):
            out_table.meta['bands'].append(band)

    return out_table


def _trapezoid(y, x):
    """Trapezoidal integral of ``y`` over ``x`` (``x`` must be sorted); 0.0 for < 2 points."""
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if x.size < 2:
        return 0.0
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def calc_synthetic_photometry(spectra, photometry):
    """Match spectra to photometry and add band-averaged synthetic fluxes.

    Calls ``match_obs_times`` and then, for every band in the resulting
    ``meta['bands']``, adds a column ``<band>_synth``. For each spectral epoch
    (unique ``date`` in ``spectra``), that column holds the photon-weighted
    mean flux density of the spectrum through the band:

        integral(flux * T * wavelength) / integral(T * wavelength)

    where T is the band transmission from ``BAND_DICT``. Each epoch is
    integrated separately, on wavelength-sorted samples. An entry is NaN when
    the epoch has fewer than two samples or no overlap with the band.

    NOTE(review): the value is in the spectrum's flux-density units and is not
    calibrated to the photometric zero point, so it is not yet directly
    comparable with the matched photometric ``flux`` columns — TODO. Partial
    band coverage by the spectrum is not detected.

    Args:
        spectra: Table with ``date``, ``wavelength`` and ``flux`` columns.
        photometry: sncosmo-style photometry table (see ``match_obs_times``).

    Returns:
        The table from ``match_obs_times`` with the extra ``<band>_synth`` columns.
    """
    match_table = match_obs_times(spectra, photometry)

    dates = np.asarray(spectra['date'])
    wavelength = np.asarray(spectra['wavelength'], dtype=float)
    spec_flux = np.asarray(spectra['flux'], dtype=float)

    # Split the concatenated spectra into one wavelength-sorted (wave, flux)
    # pair per epoch, in the same order as the rows of match_table.
    epochs = []
    for spec_time in match_table['spec_time']:
        in_epoch = dates == spec_time
        order = np.argsort(wavelength[in_epoch])
        epochs.append((wavelength[in_epoch][order], spec_flux[in_epoch][order]))

    for band_name in match_table.meta['bands']:
        band = BAND_DICT[band_name]
        synth_flux = np.full(len(epochs), np.nan)
        for i, (wave, flux) in enumerate(epochs):
            weight = band(wave) * wave
            norm = _trapezoid(weight, wave)
            if norm > 0:
                synth_flux[i] = _trapezoid(flux * weight, wave) / norm
        match_table[band_name + '_synth'] = synth_flux

    return match_table
    
# Demo: DR1 spectra vs. DR3 photometry for a single supernova
DEMO_OBJ_ID = '2005kc'
spectra_table = dr1.get_data_for_id(DEMO_OBJ_ID)
phot_table = dr3.get_data_for_id(DEMO_OBJ_ID, format_sncosmo=True)
calc_synthetic_photometry(spectra_table, phot_table)
