# Data Cuts

This notebook selects, and caches to file, the object IDs of SDSS-observed transients that have observations in at least two bands and at least one observation both before and after maximum light.

In [None]:
from functools import lru_cache
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import pandas as pd
from astropy.table import Table
from sndata.sdss import sako18
from tqdm.notebook import tqdm

# Make sure the sako18 data release is available locally before it is read
# (presumably skipped when the data is already on disk - confirm in sndata docs)
sako18.download_module_data()

# Absolute path of the directory that holds fit results and cached outputs
results_dir = Path('../results/').resolve()


First we read in the SDSS classifications.

In [None]:
# Build a table of SDSS classifications keyed by object Id
sdss_master = sako18.load_table('master')
sako_classification = pd.DataFrame(
    {'obj_id': sdss_master['CID'], 'spec_class': sdss_master['Classification']}
).set_index('obj_id')

# Number of objects in each classification (shown as the cell output)
sako_classification.spec_class.value_counts()


We identify any non-transient objects (those classified as `Variable` or `AGN`) so that they can be excluded from the selection.

In [None]:
# Rows whose classification is 'Variable' or 'AGN'
non_sne = sako_classification.loc[
    sako_classification['spec_class'].isin(('Variable', 'AGN'))
]


We check observations for each object and see if they pass our prescribed data cuts. Results are cached to file.

In [None]:
@lru_cache(maxsize=None)
def _load_salt2_fits(path):
    """Read the SALT2 fit table from disk, indexed by object Id.

    The parsed table is memoized so that the file is read once per process
    rather than once per object. Call ``_load_salt2_fits.cache_clear()`` if
    the file is modified on disk during the session.

    Args:
        path (Path): Path of the ECSV table to read

    Returns:
        A DataFrame indexed by ``obj_id``
    """

    return Table.read(path).to_pandas(index='obj_id')


def passes_cut(obj_id):
    """Return if an object has data in >=2 bands and at
    least one point before and after max.

    Only observations with ``flux / fluxerr >= 5`` are considered when
    counting bands. The before / after max requirement is read from the
    ``pre_max`` and ``post_max`` columns of ``sdss_salt2_fits.ecsv`` in
    the results directory.

    Return is false for all AGN and Variables, and for any object
    that has no entry in the SALT2 fit table.

    Args:
        obj_id (str): The object Id

    Returns:
        A boolean
    """

    if obj_id in non_sne.index:
        return False

    obj_data = sako18.get_data_for_id(obj_id).to_pandas()

    # Count only the bands that retain at least one SNR >= 5 observation
    obj_data = obj_data[obj_data.flux / obj_data.fluxerr >= 5]
    num_bands = len(obj_data.band.unique())

    # Memoized read: parsing the table for every object dominated the runtime
    salt2_fits = _load_salt2_fits(results_dir / 'sdss_salt2_fits.ecsv')

    # Keep the try body to the single lookup that can raise KeyError
    try:
        salt_fit = salt2_fits.loc[obj_id]

    except KeyError:
        return False

    return all(
        (salt_fit.pre_max, salt_fit.post_max, (num_bands >= 2))
    )


def get_good_ids(cache_file=results_dir / 'good_ids.npy'):
    """Get SDSS object Ids for targets passing observation cuts

    Results are cached to file for performance. If the cache file exists
    it is loaded and returned without re-evaluating any cuts; otherwise
    every available object is checked with ``passes_cut`` in a process
    pool and the passing Ids are written to the cache file.

    Args:
        cache_file (str): Path to cache return to. A ``.npy`` suffix is
            appended if missing, matching the behavior of ``numpy.save``

    Returns:
        An array of object Ids
    """

    cache_path = Path(cache_file)

    # np.save appends '.npy' when the name lacks it. Apply the same rule up
    # front so the existence check looks at the file that actually gets written.
    if not cache_path.name.endswith('.npy'):
        cache_path = cache_path.with_name(cache_path.name + '.npy')

    if cache_path.exists():
        return np.load(cache_path)

    sako_obj_ids = sako18.get_available_ids()
    with Pool() as p:
        is_good_id = list(tqdm(p.imap(passes_cut, sako_obj_ids), total=len(sako_obj_ids)))

    # Explicit boolean mask; the result stays an array so that the return
    # type is the same whether or not the cache was hit.
    mask = np.asarray(is_good_id, dtype=bool)
    good_obj_ids = np.array(sako_obj_ids)[mask]

    # Make sure the output directory exists so the computed result is not
    # lost to a failed save.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(cache_path, good_obj_ids)
    return good_obj_ids


In [None]:
# Number of objects passing the cuts (evaluated on the first run, loaded from the cache file afterwards)
len(get_good_ids())
