# Spectroscopically Classifying SDSS SNe

This notebook identifies SN 1991bg-like SNe and compares the results of photometric and spectroscopic classifications.

#### Table of Contents:
1. <a href='#reading_in_data'>Reading in the Data</a>: Reading in data from both the analysis pipeline and external publications.
1. <a href='#spectroscopic_classification'>Spectroscopic Classification</a>: Subtyping of spectroscopically observed targets.


In [1]:
import math
import sys
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from astropy.cosmology import WMAP9 as wmap9
from astropy.table import Table
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from multiprocessing import Pool
from scipy import stats
from sklearn.utils import resample
from sndata.csp import dr1
from sndata.sdss import sako18, sako18spec
from tqdm import tqdm_notebook

sys.path.insert(0, '../')
from phot_class.classification import classify_targets
from phot_class import spectra as spec_class
from phot_class import fit_func_wraps
 
dr1.download_module_data()
sako18spec.download_module_data()
sako18.download_module_data()

# Output directory for figures
fig_dir = Path('./notebook_figs/classification')
fig_dir.mkdir(exist_ok=True, parents=True)
results_dir = Path('../results/').resolve()



## Reading in the Data <a id='reading_in_data'></a>

We read in spectroscopic measurements and classifications from external publications in addition to our own analysis pipeline.


In [None]:
@np.vectorize
def calc_julian_date(date):
    """
    Convert a date string into a Julian date float.
    
    Args:
        date (str): The date to convert in %Y-%m-%d format

    Returns:
        The Julian date as a float
    """

    
    date = datetime.strptime(date, '%Y-%m-%d')
    julian_datetime = (
        367 * date.year - 
        int((7 * (date.year + int((date.month + 9) / 12.0))) / 4.0) + 
        int((275 * date.month) / 9.0) + date.day + 
        1721013.5 + 
        (date.hour + date.minute / 60.0 + date.second / math.pow(60, 2)) / 24.0 - 
        0.5 * math.copysign(1, 100 * date.year + date.month - 190002.5) + 0.5
    )

    return julian_datetime


def read_spec_results(path):
    """Read in spectroscopic measurements from the analysis pipeline
    
    Args:
        path (str): The path of the ecsv file to read
        
    Returns:
        A Pandas DataFrame
    """

    # Read in pipeline results
    spec_class = Table.read(path).to_pandas()
    spec_class.set_index(['obj_id', 'feat_name'], inplace=True)
    spec_class['jd'] = calc_julian_date(spec_class.date)
    
    # Get time of peak brightness
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        master = sako18spec.load_table('master').to_pandas()
        master = master.rename(columns={'CID': 'obj_id'}).set_index('obj_id')

    # Calculate days since maximum
    peak_jd = master.MJDatPeakrmag + 2400000.5
    spec_class['days'] = spec_class.jd - peak_jd
    
    return spec_class
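
The formula in `calc_julian_date` can be spot-checked against an independent calculation. The sketch below is not part of the analysis pipeline. It assumes midnight UT and relies on Python's proleptic Gregorian ordinal, in which day 1 (0001-01-01) corresponds to JD 1721425.5. The function name `julian_date_from_ordinal` is hypothetical. The conversion to MJD uses the same 2400000.5 offset applied to `MJDatPeakrmag` above.

```python
from datetime import datetime


def julian_date_from_ordinal(date_str):
    """Convert a %Y-%m-%d date string to a Julian date at midnight UT.

    Hypothetical helper used only to cross-check calc_julian_date.
    """

    date = datetime.strptime(date_str, '%Y-%m-%d')

    # toordinal() counts days from 0001-01-01 (day 1), which is JD 1721425.5
    return date.toordinal() + 1721424.5


jd = julian_date_from_ordinal('2000-01-01')
mjd = jd - 2400000.5  # Same JD <-> MJD offset used for the peak dates
print(jd, mjd)
```

Comparing the two functions on a handful of observation dates is a cheap way to confirm the `days` column is not offset by a fixed amount.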


In [None]:
spec_results = read_spec_results(results_dir / 'spec_class/sdss_sako18spec_rv3_1_bin5.0_methgauss_step5.ecsv')
spec_results.head()


In [None]:
folatelli_13 = dr1.load_table(6).to_pandas()
branch_06 = pd.DataFrame({
    'obj_id': ['1981B', '1984A', '1986G', '1989B', '1990N', '1991M', '1991T', '1991bg', '1992A', '1994D', '1994ae', '1996X', '1997cn', '1998ag', '1998bu', '1999aw', '1999by', '1999ee', '2000cx', '2001ay', '2001el', '2002bf', '2002bo', '2002cx'],
    'pw7': [127, 204, 126, 124, 88, 137, 29, 92, 107, 96, 89, 87, 101, 78, 94, 58, 95, 82, 51, 150, 95, 171, 146, 18 ],
    'pw6': [17, 23, 33, 20, 12, 19, 0, 49, 19, 19, 7, 17, 45, 12, 16, 1, 46, 5, 2, 8, 16, 10, 11, 0]
})

branch_06.set_index('obj_id', inplace=True)
branch_06.head()


## Spectroscopic Classification <a id='spectroscopic_classification'></a>

This section assigns Branch et al. (2006) style subtypes to SDSS spectra. This classification scheme relies on the pseudo-equivalent width (pEW) of Si ii at λ5972 vs λ6355 (hereafter referred to as features pW6 and pW7).

#### Section Contents:
1. <a href='#spectral_binning'>Spectral Binning</a>: A reminder of some of the manipulations performed on each spectrum during the analysis.
1. <a href='#selecting_good_data'>Selecting the "Good" Data</a>: Selects only measurements near peak brightness and drops results from noisy spectra.
1. <a href='#branch'>Branch Classification Plot</a>: Plots Si ii λ5972 vs Si ii λ6355 and assigns subtypes to each SN.
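
As a reminder of what a pEW measures, the sketch below integrates the fractional depth of a feature below a pseudo-continuum, $\mathrm{pEW} = \int (1 - F(\lambda) / F_c(\lambda)) \, d\lambda$. It is an illustration only: the straight-line continuum drawn between the feature end points and the function name `pseudo_equivalent_width` are assumptions, and the pipeline's own `feature_pew` function may differ in detail.

```python
import numpy as np


def pseudo_equivalent_width(wave, flux):
    """Estimate the pEW of a feature spanning the full input arrays.

    Assumes a linear pseudo-continuum through the first and last points.
    """

    wave = np.asarray(wave, dtype=float)
    flux = np.asarray(flux, dtype=float)

    # Straight-line pseudo-continuum between the feature end points
    slope = (flux[-1] - flux[0]) / (wave[-1] - wave[0])
    continuum = flux[0] + slope * (wave - wave[0])

    # Trapezoidal integration of the fractional depth below the continuum
    depth = 1 - flux / continuum
    return np.sum(0.5 * (depth[1:] + depth[:-1]) * np.diff(wave))


# Toy absorption feature: a Gaussian dip on a flat continuum
wave = np.linspace(6000, 6400, 201)
flux = 1 - 0.4 * np.exp(-0.5 * ((wave - 6200) / 40) ** 2)
pew = pseudo_equivalent_width(wave, flux)
print(f'pEW = {pew:.1f} Angstroms')
```

Because the continuum is anchored at the feature boundaries, the result is sensitive to where those boundaries are placed, which is why noisy spectra are inspected by eye later in this notebook.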

### Spectral Binning <a id='spectral_binning'></a>

We pause for a moment to remind ourselves that the values in this notebook are calculated after each spectrum is shifted to the rest frame, corrected for Milky Way extinction, and binned to a resolution of five angstroms (unless otherwise noted). We use a Gaussian filter to reduce the resolution, although averaging or summing within each bin is also available through the `method` argument.
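
To make the difference between these options concrete, the following sketch contrasts averaging within fixed-width bins against Gaussian smoothing on a toy spectrum. It is not the implementation behind `spec_class.bin_spectrum`: the helper names `bin_average` and `smooth_gaussian`, the kernel width, and the toy data are all assumptions. Unlike binning, the smoothing shown here keeps the original wavelength sampling.

```python
import numpy as np


def bin_average(wave, flux, bin_size=5.0):
    """Average the flux in fixed-width wavelength bins."""

    edges = np.arange(wave.min(), wave.max() + bin_size, bin_size)
    sums, _ = np.histogram(wave, bins=edges, weights=flux)
    counts, _ = np.histogram(wave, bins=edges)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # Skip empty bins to avoid dividing by zero
    filled = counts > 0
    return centers[filled], sums[filled] / counts[filled]


def smooth_gaussian(flux, sigma_pix):
    """Smooth the flux with a normalized Gaussian kernel (width in pixels)."""

    half_width = int(4 * sigma_pix)
    offsets = np.arange(-half_width, half_width + 1)
    kernel = np.exp(-0.5 * (offsets / sigma_pix) ** 2)
    kernel /= kernel.sum()

    # mode='same' keeps the input length; values near the edges are biased low
    return np.convolve(flux, kernel, mode='same')


rng = np.random.default_rng(0)
wave = np.arange(4000.0, 4100.0, 1.0)
flux = 1.0 + 0.1 * rng.standard_normal(wave.size)

bin_wave, bin_flux = bin_average(wave, flux, bin_size=5.0)
smooth_flux = smooth_gaussian(flux, sigma_pix=2.0)
```

Both approaches suppress pixel-to-pixel noise, but smoothing correlates neighboring pixels, which is worth keeping in mind when interpreting per-pixel error estimates.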


In [None]:
data = sako18spec.get_data_for_id('5635')
wave = data['wavelength']
flux = data['flux']
z = data.meta['z']
ra = data.meta['ra']
dec = data.meta['dec']

# Bin the flux using a Gaussian filter
bin_wave, bin_flux = spec_class.bin_spectrum(wave, flux, method='gauss')

# Correct for extinction and shift to rest frame
rest_wave, rest_flux = spec_class.correct_extinction(bin_wave, bin_flux, ra, dec, z)

plt.figure(figsize=(16, 8))
plt.plot(wave, flux, linewidth=.5, label='Original')
plt.plot(bin_wave, bin_flux + 5e-17, linewidth=1, label='Binned Spectrum')
plt.plot(rest_wave, rest_flux + 1e-16, linewidth=.5, label='Rest Framed Spectrum')

plt.ylim(ymin=0)
plt.ylabel('Flux')
plt.xlabel(r'Wavelength ($\AA$)')
plt.legend()


### Selecting the "Good" Data <a id='selecting_good_data'></a>

We start by checking why some of our calculations failed. Note that range- and index-related errors are raised deliberately during the analysis to flag cases where the spectrum was too noisy to identify the feature.



In [None]:
spec_results[spec_results.pew.isna()].msg.value_counts()


We also check the number of observations available for each target. The data release includes targets with multiple observations; however, depending on how the analysis pipeline was run, duplicate observations may have already been dropped.

In [None]:
# Divide by two since there are two feature measurements per object.
(spec_results.sid.value_counts() / 2).hist()
plt.title('Number of spectra per object (Not including host)')
plt.ylabel('Number of Targets')
plt.xlabel('Number of Spectra')
plt.show()


Moving forward, we drop results that:
1. Are not measured from the spectrum taken closest to peak brightness.
1. Do not have pW7 measurements due to the spectral range of the observation.
1. Fail visual inspection.

We start with the first two conditions:

In [None]:
def get_tmax_pew(spec_data):
    """Keep only pew measurements performed nearest tmax
    
    Args:
        spec_data (DataFrame): Measurements from the analysis pipeline
    
    Returns:
        A pandas DataFrame
    """
    
    # Determine what features were measured
    features = spec_data.index.get_level_values('feat_name').unique()

    data_frames = []
    for feat_name in features:
        feat_data = spec_data.xs(feat_name, level='feat_name').copy()
        feat_data['feat_name'] = feat_name
        feat_data.set_index('feat_name', append=True, inplace=True)
        
        feat_data['sort'] = feat_data.days.abs()         
        feat_data = feat_data.sort_values(by='sort')
        feat_data = feat_data[~feat_data.index.duplicated()]
        feat_data = feat_data.drop(axis=1, labels='sort')
        data_frames.append(feat_data)
    
    all_data = pd.concat(data_frames)
    all_data.dropna(subset=['pew'], inplace=True)
    
    return all_data.loc[all_data.xs('pW7', level='feat_name').index]


In [None]:
spec_results_peak = get_tmax_pew(spec_results)
spec_results_peak.head()
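
The selection performed by `get_tmax_pew` can be illustrated on a small table. The sketch below uses made-up values and mirrors the same three steps: keep the measurement nearest to peak for each object and feature, drop failed measurements, and keep only objects that still have a pW7 value. The `toy` table and its numbers are hypothetical.

```python
import pandas as pd

# Hypothetical measurements: object '1' has two epochs, object '2' has one
# epoch with a failed pW7 measurement
toy = pd.DataFrame({
    'obj_id': ['1', '1', '1', '1', '2', '2'],
    'feat_name': ['pW6', 'pW7', 'pW6', 'pW7', 'pW6', 'pW7'],
    'days': [-3.0, -3.0, 10.0, 10.0, 1.5, 1.5],
    'pew': [20.0, 95.0, 25.0, 110.0, 40.0, None],
})

# Keep the measurement nearest to peak for each (object, feature) pair
nearest = (
    toy.assign(abs_days=toy.days.abs())
    .sort_values('abs_days')
    .drop_duplicates(subset=['obj_id', 'feat_name'], keep='first')
    .drop(columns='abs_days')
    .set_index(['obj_id', 'feat_name'])
    .sort_index()
)

# Drop failed measurements, then keep only objects that still have pW7
nearest = nearest.dropna(subset=['pew'])
has_pw7 = nearest.xs('pW7', level='feat_name').index
nearest = nearest[nearest.index.get_level_values('obj_id').isin(has_pw7)]
print(nearest)
```

In this toy case only object `'1'` survives, represented by its epoch three days before peak.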


Next we perform a visual inspection to drop results from any spectra that are particularly noisy. The analysis pipeline already requires the visual inspection of each spectrum, and many noisy spectra may have already been flagged and skipped over. However, the goal of the initial inspection is to be overly ambitious in what we can measure while acknowledging a secondary cut is required later on. We perform that cut here. We also drop spectra with H$\alpha$ lines, since that indicates a potential problem with the host galaxy subtraction.

The following few cells plot each spectrum and ask for an input. Valid inputs are as follows:
1. `<Enter>` indicates a good spectrum.
1. A number replots the spectrum with a new upper bound on the y-axis.
1. `n` indicates the spectrum is too noisy.
1. `h` indicates a strong H-alpha line.
1. `r` indicates the spectrum does not cover enough of the necessary wavelength range.

The `start_from` variable can be used to start the iteration process from a given object Id. Results generated by the notebook author are hardcoded below.

In [None]:
ha = []  # Significant H-alpha
noisy = []  # Noisy spectra
bad_range = []  # Incomplete wavelength range

# Set the object ID to start from. '99999' is higher than any object Id,
# thus skipping to the end.
start_from = '99999'  


In [None]:
from IPython.display import clear_output

indices = spec_results_peak.index.get_level_values('obj_id').unique()
for i, idx in enumerate(sorted(indices)):
    if idx < start_from:
        continue
        
    inp = '5e-17'    
    while inp:
        if inp == 'h':
            ha.append(idx)
            break
            
        if inp == 'n':
            noisy.append(idx)
            break
            
        if inp == 'r':
            bad_range.append(idx)
            break
            
        clear_output()
        # Note: plot_outliers is defined in a later cell of this notebook
        f, a = plot_outliers(spec_results_peak.loc[[idx]])
        a[0].set_ylim(0, float(inp))
        plt.show()
        inp = input(f'{i}/{len(indices)} ({i / len(indices) * 100:.2f}%)')
    
    last_index = idx
    clear_output()
 

Here are results tabulated by the notebook author.

In [None]:
ha = [
    '10096', '10805', '12844', '12856', '12874', '12927', '12950', '12977', 
    '13025', '13070', '13099', '13152', '13174', '13254', '13354', '13467', 
    '13830', '14261', '14279', '15129', '15136', '15213', '15234', '15467', 
    '16099', '16099', '16578', '16637', '16692', '16776', '16847', '17176', 
    '17220', '17389', '17568', '17605', '18375', '18485', '18697', '18903', 
    '19003', '19008', '19207', '19353', '19626', '19775', '19969', '20142', 
    '20245', '20528', '21062', '2330', '2635', '2789', '3901', '4524', '5717', 
    '5751', '6057', '6108', '6249', '6773', '7876', '8921'
]

noisy = [
    '11067', '11300', '11452', '11557', '1166', '13610', '13655', '13689', 
    '13736', '14421', '15203', '15229', '15287', '15356', '15383', '15456', 
    '16072', '16350', '16414', '16789', '17048', '18749', '18855', '18890', 
    '18927', '18959', '19230', '20106', '20144', '20184', '20227', '20345', 
    '20432', '20581', '20768', '21034', '21510', '2372', '2533', '2689', '3080', 
    '3199', '3452', '4577', '4679', '6127', '6137', '744', '7475', '7512', '762', 
    '774', '7947', '8598', '9457', '7143'
]

bad_range = ['16116', '17171', '17208', '17215', '19616', '19658']


In [None]:
bad_obj_ids = set(ha + noisy + bad_range)
good_spectra = spec_results_peak.drop(bad_obj_ids)
num_objects = spec_results_peak.index.get_level_values('obj_id').nunique()
print(f'{len(bad_obj_ids)} bad spectra out of {num_objects}')


In [None]:
def plot_all_spectra(
        spec_measurements, xlim=(5000, 7000), space_scale=1.5, num_columns=2):
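    """Plot the spectra used in our data sample as vertically offset curves

    Args:
        spec_measurements (DataFrame): Measurements from the analysis pipeline
        xlim                  (tuple): Lower and upper wavelength limits of the x-axis
        space_scale           (float): Vertical offset between consecutive spectra
        num_columns             (int): Number of columns to divide the spectra into

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """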

    fig, axes = plt.subplots(1, num_columns, figsize=(8.5, 11), sharex=True)

    # Divide object ids into separate collections for each figure column
    spec_to_plot = list(spec_measurements.index.get_level_values('obj_id').unique())
    spec_per_col = int(np.ceil(len(spec_to_plot) / num_columns))
    spectra_cols = [spec_to_plot[i * spec_per_col: (i + 1) * spec_per_col] for i in range(num_columns)]
    
    for obj_ids_in_column, axis in zip(spectra_cols, axes.flatten()):
        
        yticks = []
        for i, obj_id in enumerate(obj_ids_in_column):
            target_data = sako18spec.get_data_for_id(obj_id)
            target_measurements = spec_measurements.loc[obj_id]
            offset = i * space_scale

            # Keep only the SN spectra used in the classification
            date = target_measurements.date.iloc[0]
            target_data = target_data[target_data['date'] == date]
            target_data = target_data[target_data['type'] != 'Gal']

            # Correct for extinction, shift to rest frame, and bin spectrum
            wave = target_data['wavelength']
            flux = target_data['flux']
            z = target_data.meta['z']
            ra = target_data.meta['ra']
            dec = target_data.meta['dec']

            bin_wave, bin_flux = spec_class.bin_spectrum(wave, flux)
            rest_wave, rest_flux = spec_class.correct_extinction(
                bin_wave, bin_flux, ra, dec, z)

            # Scale and offset spectra to same order of magnitude
            _, pw7_end = spec_class.guess_feature_bounds(
                rest_wave, rest_flux, spec_class.line_locations['pW7'])
            scale = rest_flux[np.where(rest_wave == pw7_end)[0][0]]
            rest_flux /= scale
            rest_flux += offset

            axis.plot(rest_wave, rest_flux, lw=1, color='k')
            
            # Add the object Id as a tickmark label
            min_idx = np.argmin(abs(rest_wave - xlim[0]))
            yticks.append(rest_flux[min_idx])

        axis.set_ylim(0, (i + 1) * space_scale)
        axis.set_yticks(yticks)
        axis.set_yticklabels(obj_ids_in_column)
        axis.set_xlim(xlim)
        axis.set_xlabel('Wavelength')

    axes[0].set_ylabel('Candidate Id')
    plt.tight_layout()
    return fig, axes


In [None]:
obj_ids = sorted(good_spectra.index.get_level_values('obj_id').unique())

# Break object Ids into chunks and create multiple figures
num_figs = 5
n = len(obj_ids) // num_figs
for i in range(num_figs - 1):
    fig_ids = obj_ids[i * n: (i + 1) * n]
    plot_all_spectra(spec_results_peak.loc[fig_ids])
    plt.show()

# The final figure picks up any remaining object Ids
fig_ids = obj_ids[(num_figs - 1) * n:]
plot_all_spectra(spec_results_peak.loc[fig_ids])
plt.show()


### Branch Classification Plot <a id='branch'></a>

We plot the pEW of Si ii at λ5972 vs λ6355. 
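
The subtype boundaries used in the cells below (pW6 = 30 Å, pW7 = 70 Å and 105 Å) can be summarized as a single function. This is a minimal sketch: the function name `branch_subtype` is hypothetical, and the order in which overlapping conditions are resolved (Cool first, then Shallow Silicon, then Broad Line) is an assumption rather than something fixed by the plotting code. The example values are taken from the `branch_06` table defined earlier.

```python
def branch_subtype(pw6, pw7):
    """Assign a Branch-style subtype from the pW6 and pW7 pseudo-equivalent widths.

    The precedence of the checks is an assumption made for this sketch.
    """

    if pw6 > 30:
        return 'Cool'

    if pw7 < 70:
        return 'Shallow Silicon'

    if pw7 > 105:
        return 'Broad Line'

    return 'Core Normal'


# (pW6, pW7) values copied from the branch_06 table above
examples = {
    '1991bg': (49, 92),
    '1991T': (0, 29),
    '1984A': (23, 204),
    '1994D': (19, 96),
}

for name, (pw6, pw7) in examples.items():
    print(name, branch_subtype(pw6, pw7))
```

Making the precedence explicit guarantees that every measurement receives exactly one label, including points that fall exactly on a boundary.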


In [None]:
def get_colors(pw6, pw7):
    """Get the color of each point based on its coordinates
    
    Args:
        pw6 (ndarray): Array of EW measurements for feature 6
        pw7 (ndarray): Array of EW measurements for feature 7
        
    Returns:
        An array of matplotlib color names
    """
    
    color = np.empty(len(pw6), dtype='U10')
    color[:] = 'black'  # Default to black
    color[pw6 > 30] = 'blue'  # Blue
    color[(pw6 < 30) & (pw7 > 105)] = 'red'  # Red
    color[pw7 < 70] = 'green'  # Green

    return color

def subplot_published_classes(axis):
    """Plot Si ii pEW at λ5972 vs λ6355 from CSP and Branch 2006
    
    Args:
        axis (Axis): A matplotlib axis
    """

    axis.scatter(
        branch_06.pw7, 
        branch_06.pw6, 
        marker='D',
        facecolor='none', 
        edgecolor=get_colors(branch_06.pw6, branch_06.pw7),
        zorder=2,
        alpha=.5,
        label='Branch 2006'
    )

    axis.scatter(
        folatelli_13.pW7, 
        folatelli_13.pW6, 
        marker='v',
        facecolor='none', 
        edgecolor=get_colors(folatelli_13.pW6, folatelli_13.pW7),
        zorder=2,
        alpha=.5,
        label='Folatelli 2013'
    )
    
def get_pew_above_snr(df, snr):
    """Return measurements taken within 7 days of peak and above a given SNR
    
    Args:
        df (DataFrame): Dataframe of spectroscopic pipeline results
        snr    (float): The signal to noise ratio
        
    Returns:
        A DataFrame
    """
    
    lt_7_days = df[df.days.abs() < 7]
    good_snr = lt_7_days[lt_7_days.pew / lt_7_days.pew_samperr  >= snr][['pew', 'pew_samperr']]
    pw6 = good_snr.xs('pW6', level='feat_name')
    pw7 = good_snr.xs('pW7', level='feat_name')
    pew = pd.merge(pw6, pw7, on='obj_id', suffixes=('_pw6', '_pw7'))
    pew = pew.join(df.date)
    return pew
    

def plot_si_ratio(spec_data, snr_ratios=(1, 2, 3), plot_external_data=True):
    """Plot the pW6 vs pW7 silicon pEW ratios for different SNR cutoffs
    
    Args:
        spec_data     (DataFrame): Measurements from the analysis pipeline
        snr_ratios        (tuple): SNR cutoffs to plot
        plot_external_data (bool): Whether to plot data from CSP and Branch 2006
        
    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """
    
    # Keep only Type Ia spectra
    si_data = spec_data[spec_data.type.isin(['Ia', 'Ia-pec', 'Ia?'])].dropna(subset=['pew'])
    
    num_subplots = len(snr_ratios)
    fig, axes = plt.subplots(1, num_subplots, figsize=(num_subplots * 6, 6), sharex=True, sharey=True)  
    flat_ax = [axes] if num_subplots == 1 else axes.flatten()
    for snr, axis in zip(snr_ratios, flat_ax):
        
        # Keep only data with both pw6 and pw7 measurements greater than snr
        plot_data = get_pew_above_snr(si_data, snr)
        
        axis.errorbar(
            x=plot_data.pew_pw7, 
            y=plot_data.pew_pw6, 
            xerr=plot_data.pew_samperr_pw7, 
            yerr=plot_data.pew_samperr_pw6, 
            linestyle='', 
            ecolor='grey', 
            color='grey',
            alpha=.3, 
            zorder=0)
        
        if plot_external_data:
            subplot_published_classes(axis)
        
        # Isolate each subtype
        cl = plot_data[plot_data.pew_pw6 > 30]
        bl = plot_data[(plot_data.pew_pw6 < 30) & (plot_data.pew_pw7 > 105)]
        ss = plot_data[plot_data.pew_pw7 < 70]
        cn = plot_data[(plot_data.pew_pw6 <= 30) & (70 <= plot_data.pew_pw7) & (plot_data.pew_pw7 <= 105)]
        
        axis.scatter(cl.pew_pw7, cl.pew_pw6, color='C0', zorder=1, label='Cool')
        axis.scatter(bl.pew_pw7, bl.pew_pw6, color='C3', zorder=1, label='Broad Line')
        axis.scatter(ss.pew_pw7, ss.pew_pw6, color='C2', zorder=1, label='Shallow Silicon')
        axis.scatter(cn.pew_pw7, cn.pew_pw6, color='k', zorder=1, label='Core Normal')
        
        axis.set_xlabel(r'Si ii $\lambda$6355', fontsize=14)
        axis.set_title(rf'SNR $>$ {snr}')
        
    flat_ax[0].set_xlim(-10, 500)
    flat_ax[0].set_ylim(-10, 200)
    flat_ax[0].set_ylabel(r'Si ii $\lambda$5972', fontsize=14) 
    flat_ax[-1].legend()
    return fig, flat_ax


We plot the subtypes against results from existing publications. Note that because a Gaussian filter is used to reduce noise, the error bars are artificially small.

In [None]:
fig, axes = plot_si_ratio(good_spectra, (5, 10, 15, 20))
fig.suptitle('Branch Si II Subtypes')
plt.savefig(fig_dir / 'branch_subtypes.pdf')
plt.show()


For curiosity's sake, we investigate the furthest outliers in the first plot.

In [None]:
def get_outliers(spec_data, pw6_cutoff, pw7_cutoff, snr_cutoff=0):
    """Select measurements with pw6 pew > pw6_cutoff and pw7 pew > pw7_cutoff
    
    Args:
        spec_data (DataFrame): Data to select on
        pw6_cutoff    (float): The PEW cutoff for feature pw6
        pw7_cutoff    (float): The PEW cutoff for feature pw7
        snr_cutoff    (float): SNR cutoff to apply before selection
        
    Returns:
        A DataFrame with measurements matching the specified criteria
    """
    good_snr = spec_data[spec_data.pew / spec_data.pew_samperr > snr_cutoff]

    pw6 = good_snr.xs('pW6', level='feat_name')
    outliers = set(pw6[pw6.pew > pw6_cutoff].index)

    pw7 = good_snr.xs('pW7', level='feat_name')
    outliers = list(outliers.intersection(pw7[pw7.pew > pw7_cutoff].index))

    outlier_meas = good_snr.loc[outliers]
    pw6_indices = set(outlier_meas.xs('pW6', level='feat_name').index)
    pw7_indices = set(outlier_meas.xs('pW7', level='feat_name').index)
    outliers = sorted(pw6_indices.intersection(pw7_indices))

    return good_snr.loc[outliers]

def subplot_feature_pew(wave, flux, axis, feat_start, feat_end, **kwargs):
    """Shade in the PEW of spectral properties

    Args:
        wave     (ndarray): The spectrum's wavelengths
        flux     (ndarray): The flux for each wavelength
        axis        (Axis): The axis to plot on
        feat_start (float): Where the feature starts
        feat_end   (float): Where the feature ends
        Any other kwargs for ``axis.fill_between``
    """

    idx_start = np.where(wave == feat_start)[0][0]
    idx_end = np.where(wave == feat_end)[0][0]
    feat_wave = wave[idx_start: idx_end + 1]
    feat_flux = flux[idx_start: idx_end + 1]

    continuum, norm_flux, pew = spec_class.feature_pew(feat_wave, feat_flux)
    axis.fill_between(feat_wave, feat_flux, continuum, alpha=.75, zorder=0, **kwargs)
    

def plot_outliers(outlier_data):
    """Plot a collection of spectra

    Args:
        outlier_data (DataFrame): PEW Measurements for spectra to plot

    Returns:
        A matplotlib Figure
        An array of matplotlib axes
    """

    obj_ids = outlier_data.index.get_level_values('obj_id').unique()
    fig, axes = plt.subplots(len(obj_ids), 1, figsize=(15, 4 * len(obj_ids)), sharex=True)
    axes = np.atleast_1d(axes)

    for obj_id, axis in zip(obj_ids, axes):
        target_data = sako18spec.get_data_for_id(obj_id)
        target_data.sort('wavelength')

        # Keep only the SN spectra used in the classification
        date = outlier_data.loc[obj_id].date.iloc[0]
        pw6_pew = outlier_data.loc[obj_id, 'pW6'].pew
        pw7_pew = outlier_data.loc[obj_id, 'pW7'].pew

        target_data = target_data[target_data['date'] == date]
        target_data = target_data[target_data['type'] != 'Gal']

        # Correct for extinction and shift to rest frame
        wave = target_data['wavelength']
        flux = target_data['flux']
        z = target_data.meta['z']
        ra = target_data.meta['ra']
        dec = target_data.meta['dec']
        rest_wave, rest_flux = spec_class.correct_extinction(wave, flux, ra, dec, z)
        bin_wave, bin_flux = spec_class.bin_spectrum(rest_wave, rest_flux, method='gauss')

        axis.plot(rest_wave, rest_flux, lw=1, label='Restframed Spectrum', color='grey', alpha=.75)
        for feat_name, row in outlier_data.loc[obj_id].iterrows():
            subplot_feature_pew(bin_wave, bin_flux, axis, row.feat_start, row.feat_end, label=feat_name)

        axis.plot(bin_wave, bin_flux, lw=2, label='Binned Spectrum', color='k')
        axis.set_xlabel('Wavelength')
        axis.set_ylabel('Flux')
        axis.set_ylim(0, 3e-17)
        axis.set_title(f'Object Id: {obj_id}')
        
    axis.legend()
    axis.set_xlim(3000, 7500)

    return fig, axes


In [None]:
outliers = get_outliers(good_spectra, 75, 0)
outliers


In [None]:
_ = plot_outliers(outliers)


We also plot a few of the dropped spectra

In [None]:
fig, axes = plot_outliers(spec_results_peak.loc[['6304', '7947', '7876']])
axes[0].set_ylim(0, .03)
axes[0].set_xlabel('')
axes[1].set_ylim(0, 6e-17)
axes[1].set_xlabel('')
axes[2].set_ylim(0, 4e-16)
axes[2].set_xlabel(r'Wavelength ($\AA$)')

plt.savefig(fig_dir / 'Dropped_spectra.pdf')
