# Creating pW6 vs. pW7 Comparison Plots for the Final SDSS and CSP Data

In [None]:
import os
import sys

# Make the parent directory importable so the `scripts` package imported
# in the next cell can be found. An absolute path keeps the entry
# unambiguous even if the working directory changes later in the session.
# Removing any earlier copy first means re-running this cell does not stack
# duplicates, while the entry still takes precedence at the front.
_parent_dir = os.path.abspath('..')
if _parent_dir in sys.path:
    sys.path.remove(_parent_dir)
sys.path.insert(0, _parent_dir)


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sndata.csp import DR3
from sndata.sdss import Sako18Spec

from scripts.run_csp import get_csp_t0
from scripts.run_sdss import get_sdss_t0


In [None]:
# Directory the notebook runs from. Pipeline results live in a sibling
# `results` directory one level up.
curr_dir = Path('.').resolve()
results_dir = curr_dir.parent.joinpath('results')

# Output location for figures, created on demand.
fig_dir = curr_dir.joinpath('figures', 'phase_evolution')
fig_dir.mkdir(exist_ok=True, parents=True)


In [None]:
def _csp_t0_or_nan(obj_id):
    """Return the CSP t0 for ``obj_id``, or NaN if ``get_csp_t0`` raises ValueError.

    NOTE(review): assumes ``get_csp_t0`` raises ``ValueError`` for targets
    without a usable t0 and otherwise returns a number — confirm against
    ``scripts.run_csp``.
    """
    try:
        return get_csp_t0(obj_id)

    except ValueError:
        return np.nan


def _sdss_t0_or_nan(obj_id):
    """Return the SDSS t0 for ``obj_id``, or NaN if ``get_sdss_t0`` raises ValueError.

    NOTE(review): assumes ``get_sdss_t0`` raises ``ValueError`` for targets
    without a usable t0 and otherwise returns a number — confirm against
    ``scripts.run_sdss``.
    """
    try:
        return get_sdss_t0(obj_id)

    except ValueError:
        return np.nan


# Element-wise versions that accept an array of object ids and return a float
# array of t0 values (NaN where unavailable). ``otypes=[float]`` fixes the
# output dtype up front. This keeps NaN and real values in one float array
# and allows empty inputs, which ``np.vectorize`` otherwise rejects.
wrapped_get_csp_t0 = np.vectorize(_csp_t0_or_nan, otypes=[float])
wrapped_get_sdss_t0 = np.vectorize(_sdss_t0_or_nan, otypes=[float])


In [None]:
def read_in_pipeline_result(path, survey, drop_flagged=False):
    """Read pEW values from an analysis pipeline output file

    Adds the following columns:
      - ``phase``: measurement time minus the survey's t0
      - ``spec_type``: for CSP, the ``Subtype1`` column of CSP DR3 table 2;
        for SDSS, ``'unknown'``
      - ``delta_t`` / ``is_peak``: flag the measurement of each
        (feat_name, obj_id) pair taken nearest peak brightness
      - the columns returned by ``branch_classification``

    For SDSS, the ``master`` table from ``Sako18Spec`` is also inner-joined.

    Args:
        path          (str): Path of the file to read
        survey        (str): Read in data for either `csp` or `sdss`
        drop_flagged (bool): Optionally drop flagged measurements / spectra

    Returns:
        A pandas DataFrame indexed by feat_name and obj_id

    Raises:
        ValueError: If ``survey`` is not ``'csp'`` or ``'sdss'``
    """

    # Validate first. Without a known survey no phases can be computed, and
    # every step below depends on the ``phase`` column.
    if survey not in ('csp', 'sdss'):
        raise ValueError(
            f'Could not calculate phases for survey {survey}. Expected "csp" or "sdss".')

    df = pd.read_csv(path, index_col=['feat_name', 'obj_id'])
    obj_id = df.index.get_level_values('obj_id')

    if survey == 'csp':
        # Phases relative to the CSP t0 values
        df['phase'] = df['time'] - wrapped_get_csp_t0(obj_id)

        csp_table_2 = DR3().load_table(2)
        subtypes = pd.DataFrame(
            {'spec_type': csp_table_2['Subtype1']}, index=csp_table_2['SN'])
        df = df.join(subtypes, on='obj_id')

    else:  # survey == 'sdss'
        # Phases relative to the SDSS t0 values
        df['phase'] = df['time'] - wrapped_get_sdss_t0(obj_id)
        df['spec_type'] = 'unknown'

        sako_master = Sako18Spec().load_table('master').to_pandas()
        sako_master = sako_master.rename({'CID': 'obj_id'}, axis='columns')
        sako_master['obj_id'] = sako_master['obj_id'].astype(int)
        sako_master = sako_master.set_index('obj_id')

        # Inner join keeps only targets present in the Sako master table
        df = df.join(sako_master, how='inner')

    if drop_flagged:
        # .copy() so the column assignments below act on an independent
        # frame rather than a filtered view (avoids SettingWithCopyWarning)
        df = df[(df['spec_flag'] != 1) & (df['feat_flag'] != 1)].copy()

    # Label the measurement of each (feat_name, obj_id) pair taken nearest
    # peak brightness. After sorting by |phase|, the first occurrence of each
    # index value is the closest one.
    df['delta_t'] = df['phase'].abs()
    df = df.sort_values('delta_t')
    df['is_peak'] = ~df.index.duplicated()

    # NOTE(review): branch_classification is neither defined nor imported in
    # this notebook as shown — confirm where it is expected to come from.
    df = df.join(branch_classification(df), on='obj_id')
    return df


In [None]:
def plot_branch_classifications(pipeline_data, phase_cutoff=7, figsize=(10, 10)):
    """Create a Branch classification figure (pW6 pEW vs. pW7 pEW)

    For each feature, only the measurement flagged ``is_peak`` is used, and
    only if its ``delta_t`` is at most ``phase_cutoff``. pW6 and pW7 values
    are matched per obj_id, and rows with any missing value (including
    ``branch_type``) are dropped. Points are styled and labelled by
    ``branch_type``. Classes without a predefined style are plotted with the
    default style and labelled with the raw class name.

    Args:
        pipeline_data (DataFrame): Data that has been read from a pipeline output file
        phase_cutoff      (float): Only use measurements taken within so many days of peak brightness
        figsize           (Tuple): Size of the figure in inches

    Raises:
        KeyError: If no pW6 or no pW7 measurements survive the cuts
            (raised by ``.loc`` on the missing feat_name)
    """

    _, axis = plt.subplots(figsize=figsize)

    # Legend labels and marker styles for each Branch class
    plot_args = {
        'CL': dict(label='Cool', color='blue', marker='s'),
        'BL': dict(label='Broad Line', color='red', marker='^'),
        'SS': dict(label='Shallow Silicon', color='green', marker='*'),
        'CN': dict(label='Core Normal', color='black', marker='.'),
    }

    peak_vals = pipeline_data[pipeline_data['is_peak']]
    peak_vals = peak_vals[peak_vals['delta_t'] <= phase_cutoff]

    # One row per obj_id with pW6 (suffix _6) and pW7 (suffix _7) values
    pw6_data = peak_vals[['pew', 'pew_samperr', 'branch_type']].loc['pW6']
    pw7_data = peak_vals[['pew', 'pew_samperr']].loc['pW7']
    all_data = pw6_data.join(pw7_data, lsuffix='_6', rsuffix='_7').dropna()

    for branch_class, data in all_data.groupby('branch_type'):
        # Work on a copy of the style so popping the label never mutates the
        # table, and fall back to the raw class name for unknown classes.
        style = dict(plot_args.get(branch_class, {}))
        label = f"{style.pop('label', branch_class)} ({len(data)})"
        axis.errorbar(
            data['pew_7'],
            data['pew_6'],
            xerr=data['pew_samperr_7'],
            yerr=data['pew_samperr_6'],
            linestyle='',
            label=label,
            **style)

    axis.set_title('Strength of pW6 vs pW7')
    axis.legend(loc='upper left')
    axis.set_xlabel('pew of pW7')
    axis.set_ylabel('pew of pW6')
    

In [None]:
# Read the final SDSS pipeline measurements, discarding flagged values
sdss_results = read_in_pipeline_result(
    results_dir / 'final_sdss.csv',
    'sdss',
    drop_flagged=True,
)


In [None]:
# pW6 vs pW7 plot, colored by Branch class, for the SDSS sample
plot_branch_classifications(pipeline_data=sdss_results)


In [None]:
# Load CSP measurements and drop any flagged values
csp_results = read_in_pipeline_result(results_dir / 'final_csp.csv', 'csp', drop_flagged=True)


In [None]:
# pW6 vs pW7 plot, colored by Branch class, for the CSP sample
plot_branch_classifications(pipeline_data=csp_results)
