# Visual inspection of CSP spectra


In [None]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sndata.csp import DR1, DR3

sys.path.insert(0, '../')
from scripts.run_csp import pre_process, get_csp_t0


In [None]:
# Create the sndata interfaces to CSP DR1 and DR3 and download their data
dr1 = DR1()
dr1.download_module_data()

dr3 = DR3()
dr3.download_module_data()

# Pipeline result files live in ../results relative to the working directory
results_dir = Path('.').resolve().parent.joinpath('results')

# Figures are written to ./figs/classification (created if missing)
fig_dir = Path('figs') / 'classification'
fig_dir.mkdir(parents=True, exist_ok=True)


In [None]:
def read_in_pipeline_result(path):
    """Read pEW values from an analysis pipeline results file

    Adds a ``phase`` column (observation time minus the CSP t0 of the
    object) and a ``spec_type`` column holding the spectral subtype
    listed in CSP DR3 Table 2 (column ``Subtype1``).

    Args:
        path (str or Path): Path of the CSV file to read

    Returns:
        A pandas DataFrame indexed by obj_id and feat_name
    """

    df = pd.read_csv(path, index_col=['obj_id', 'feat_name'])

    # Phase is time relative to t0 (negative before t0), i.e. time - t0.
    # Look up t0 once per object instead of once per row.
    obj_ids = df.index.get_level_values('obj_id')
    t0_by_id = {oid: get_csp_t0(oid) for oid in obj_ids.unique()}
    t0 = np.array([t0_by_id[oid] for oid in obj_ids])
    df['phase'] = df['time'] - t0

    # TODO: Branch-style classifications (from the pW6 / pW7 pEW values)
    # are not implemented; an earlier draft built the inputs but never
    # attached any classification to the returned frame.

    # Add spectral subtypes from CSP DR3 Table 2
    csp_table_2 = dr3.load_table(2)
    subtypes = pd.DataFrame(
        {'spec_type': csp_table_2['Subtype1']},
        index=csp_table_2['SN'],
    )
    df = df.join(subtypes, on='obj_id')

    return df


In [None]:
ella = read_in_pipeline_result(results_dir / 'ella_csp.csv')
emily = read_in_pipeline_result(results_dir / 'emily_csp.csv')
anish = read_in_pipeline_result(results_dir / 'anish_csp.csv')

# Tag every column with the person whose results it came from.
# ``DataFrame.join``'s lsuffix/rsuffix only rename *overlapping* columns:
# after joining ella and emily every column is already ``*_ella`` or
# ``*_emily``, so a later ``rsuffix='_anish'`` never applies and Anish's
# columns would keep their bare names. Suffix each frame explicitly instead.
combined = (
    ella.add_suffix('_ella')
    .join(emily.add_suffix('_emily'))
    .join(anish.add_suffix('_anish'))
)
combined.head()


In [None]:
def subplot_feature_pew(wave, flux, axis, feat_start, feat_end, **kwargs):
    """Shade in the pEW of a spectral feature

    Fills the region between the observed flux and the feature continuum
    over the wavelength range ``feat_start`` .. ``feat_end``.

    Args:
        wave     (ndarray): The spectrum's wavelengths
        flux     (ndarray): The flux for each wavelength
        axis        (Axis): The axis to plot on
        feat_start (float): Where the feature starts
        feat_end   (float): Where the feature ends
        Any other kwargs for ``axis.fill_between``
    """

    wave = np.asarray(wave)
    flux = np.asarray(flux)

    # Locate the feature bounds by nearest wavelength rather than exact
    # float equality: bounds read back from a results file need not match
    # the wavelength array bit-for-bit. When an exact match exists this
    # yields the same (first) index as an equality search. nanargmin keeps
    # any NaN wavelengths from being selected.
    idx_start = int(np.nanargmin(np.abs(wave - feat_start)))
    idx_end = int(np.nanargmin(np.abs(wave - feat_end)))
    feat_wave = wave[idx_start: idx_end + 1]
    feat_flux = flux[idx_start: idx_end + 1]

    # NOTE(review): ``spec_class`` is not imported in the visible notebook
    # source; confirm it is defined before this function is called.
    # Only the continuum is needed for shading.
    continuum, _norm_flux, _pew = spec_class.feature_pew(feat_wave, feat_flux)
    axis.fill_between(feat_wave, feat_flux, continuum, alpha=.75, zorder=0, **kwargs)


In [None]:
def plot_measurements(obj_id, phase, results):
    """Gather the data needed to plot pipeline measurements for one spectrum

    Loads the CSP DR1 data for ``obj_id``, runs it through ``pre_process``,
    and selects the rows of ``results`` for that object whose ``phase``
    column, rounded to 2 decimals, equals ``phase``.

    NOTE(review): the remainder of this function appears to be elided
    (``# ...``); this docstring describes only the visible code.

    Args:
        obj_id: Object ID used for the DR1 lookup and for ``results.loc``
        phase (float): Phase to select; compared for exact equality with
            ``results.phase`` rounded to 2 decimals
        results (DataFrame): Pipeline measurements whose first index level
            is the object ID (presumably from ``read_in_pipeline_result``)
    """
    data = dr1.get_data_for_id(obj_id)
    processed_data = pre_process(data)
    
    obj_results = results.loc[obj_id]
    # Exact equality on the rounded phase: callers presumably pass a phase
    # already rounded to 2 decimals — TODO confirm
    obj_results = obj_results[obj_results.phase.round(2) == phase]
    
    # ...
