# ParSNIP Limitation Testing

This notebook tests the limitations of ParSNIP by modifying datasets, making predictions and classifications, and comparing the results to those from the unmodified datasets. Currently, it tests how the accuracy of ParSNIP degrades as the number of observations per bandpass decreases.

Written by John Delker (jfla@uw.edu)

### Important Notes
* To easily run this, I recommend using the kernel called "John Delker's Python 3.10".
* 'cutoff' variables refer to the maximum number of data points (observations) kept per bandpass. For example, with a cutoff of 32, at most 32 observations are used in each bandpass.
* This notebook can be used with some non-plasticc datasets such as PS1, but a few settings will need to be changed, such as the names of bandpasses. I recommend making sure it runs with plasticc before trying to change the dataset used.
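
As a rough illustration of what a cutoff means, the sketch below keeps only the first N observations in each band of a toy light curve. The `apply_cutoff` helper and the tuple layout are hypothetical and are not part of ParSNIP or lcdata, and processing the observations in time order is an assumption of this sketch.

```python
from collections import defaultdict

def apply_cutoff(observations, cutoff):
    """Keep at most `cutoff` observations per band (hypothetical helper).

    `observations` is a list of (time, band, flux) tuples. A cutoff of None
    keeps everything, mirroring the control case used in this notebook.
    """
    if cutoff is None:
        return sorted(observations)
    kept_per_band = defaultdict(int)
    kept = []
    # Walk through the observations in time order (an assumption of this sketch)
    for obs in sorted(observations):
        band = obs[1]
        if kept_per_band[band] < cutoff:
            kept.append(obs)
            kept_per_band[band] += 1
    return kept

toy_curve = [
    (1.0, "lsstg", 10.0), (2.0, "lsstr", 12.0), (3.0, "lsstg", 15.0),
    (4.0, "lsstg", 11.0), (5.0, "lsstr", 9.0),
]
trimmed = apply_cutoff(toy_curve, 2)
print(trimmed)
```

With a cutoff of 2, the third `lsstg` observation is dropped while both `lsstr` observations are kept.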

## Initial Setup

In [251]:
# Load dependencies
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from astropy.table import Table, vstack
import time
import parsnip
import lcdata
from ipywidgets import interact, interactive, interactive_output, fixed, interact_manual
import ipywidgets as widgets
from collections import namedtuple
from IPython.display import display
import pandas
import os

# Hide a few warnings that would otherwise fill the page
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings("ignore", message="'verbose' argument is deprecated and will be removed in a future release of LightGBM.")

## Settings

In [173]:
all_bands = ["lsstu", "lsstg", "lsstr", "lssti", "lsstz", "lssty"] # Names of the bands used in the dataset
band_colors = ["blue", "green", "red", "purple", "brown", "black"] # Colors used for the above bands when plotting
model = parsnip.load_model("plasticc") # Options: "ps1", "plasticc" - If changed, the bands and below dataset will need to be changed.
Curve = namedtuple('Curve', 'time flux results')

# Path to the public data folder so that this notebook can be used anywhere on epyc
parsnip_data = "/epyc/data/parsnip_tests/"

# Path to the dataset you want to train the model on for classifications
# plasticc_test_slim is the first 5000 entries of the full plasticc dataset.
# You can alternatively use any dataset or subset of data you choose.
dataset_path = os.path.join(parsnip_data, "data/plasticc_test_slim.h5")

# Path to predictions from the model
predictions_path = os.path.join(parsnip_data, "predictions/parsnip_predictions_plasticc_train_aug_100.h5")

In [59]:
# Prepares the dataset, removing any objects whose type ParSNIP cannot explicitly classify
# NOTE: If not using Plasticc, some of this may need tweaking
plasticc_dataset = parsnip.load_dataset(dataset_path)
plasticc_dataset = plasticc_dataset[~(plasticc_dataset.meta['type'] == 'CaRT')]
plasticc_dataset = plasticc_dataset[~(plasticc_dataset.meta['type'] == 'ILOT')]

Parsing 'plasticc_test_slim.h5' as PLAsTiCC dataset...
Rejecting 0 non-supernova-like light curves.
Dataset contains 5000 light curves.


In [4]:
# Train a classifier
classifier = parsnip.Classifier()
training_predictions = Table.read(predictions_path)
training_classifications = classifier.train(training_predictions)

Training classifier with keys:
    color
    color_error
    s1
    s1_error
    s2
    s2_error
    s3
    s3_error
    luminosity
    luminosity_error
    reference_time_error
[100]	valid_0's multi_logloss: 0.495647
[100]	valid_0's multi_logloss: 0.509514
[100]	valid_0's multi_logloss: 0.536781
[100]	valid_0's multi_logloss: 0.531477
[100]	valid_0's multi_logloss: 0.469776
[100]	valid_0's multi_logloss: 0.518461
[100]	valid_0's multi_logloss: 0.532923
[100]	valid_0's multi_logloss: 0.488277
[100]	valid_0's multi_logloss: 0.491106
[100]	valid_0's multi_logloss: 0.549328


## Modifying the datasets

In [5]:
def restrict_bands(dataset, bands):
    '''For each sample in the dataset, removes all points that are not part of the given bands'''
    
    # Makes a copy of the dataset so that the original is left unmodified
    modified_dataset = lcdata.Dataset(dataset.meta.copy(), dataset.light_curves.copy())
    
    # Group the light curves by band
    for i in range(0, len(modified_dataset)):
        light_curve = dataset.light_curves[i].group_by('band')
        group_count = len(light_curve.groups.indices) - 1

        # Mask out all points that are part of any band that is being removed for this test
        mask = np.ones(len(light_curve), dtype=bool)
        for n in range(0, group_count):
            if light_curve.groups.keys["band"][n] in bands: continue
            mask[light_curve.groups.indices[n] : light_curve.groups.indices[n + 1]] = False
            
        # Apply the mask and return to the original sorting method to prevent possible issues
        modified_dataset.light_curves[i] = light_curve[mask]
        modified_dataset.light_curves[i].sort("time")
    
    return modified_dataset 

In [6]:
def restrict_observation_count(dataset, max_observations_per_band):
    '''Removes observations past the max number of observations specified.'''
    
    # Makes a copy of the dataset so that the original is left unmodified
    modified_dataset = lcdata.Dataset(dataset.meta.copy(), dataset.light_curves.copy())
    
    # For each object in the dataset, limit the number of observations in each band
    for i in range(0, len(modified_dataset)):

        # Group the light curves by band so that each band can be manipulated individually
        light_curve = dataset.light_curves[i].group_by('band')
        group_count = len(light_curve.groups.indices) - 1

        # Create a mask to remove points from each group/band
        mask = np.ones(len(light_curve), dtype=bool)
        for n in range(0, group_count):
            lower_bound = light_curve.groups.indices[n]
            upper_bound = light_curve.groups.indices[n + 1]

            # Only mask out the portion of this band that exceeds the observation cutoff
            # Note: A cutoff of None retains all original observations
            if max_observations_per_band is not None and lower_bound + max_observations_per_band < upper_bound:
                mask[lower_bound + max_observations_per_band : upper_bound] = False

        # Apply the mask and return to the original sorting method to prevent possible issues
        modified_dataset.light_curves[i] = light_curve[mask]
        modified_dataset.light_curves[i].sort("time")
    
    return modified_dataset
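
Both helpers above build a boolean mask from the group boundaries that `group_by('band')` exposes through `groups.indices`. The following sketch mimics that bookkeeping with plain Python lists. The `boundaries_mask` function and the toy indices are hypothetical, and the real astropy objects behave like arrays rather than lists.

```python
def boundaries_mask(indices, max_per_group):
    """Build a keep-mask from group boundaries (hypothetical helper).

    `indices` plays the role of `light_curve.groups.indices`: group n spans
    rows indices[n] up to (but not including) indices[n + 1].
    """
    mask = [True] * indices[-1]
    for n in range(len(indices) - 1):
        lower, upper = indices[n], indices[n + 1]
        # Drop everything past the first `max_per_group` rows of this group
        for row in range(lower + max_per_group, upper):
            mask[row] = False
    return mask

# Three groups with 4, 1, and 3 rows respectively
indices = [0, 4, 5, 8]
mask = boundaries_mask(indices, 2)
print(mask)
```

Groups that already hold fewer rows than the limit are left untouched, because the inner range is empty for them.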

## Classifying modified data

In [217]:
def limitation_test(sample_size, bands, observation_cutoffs = [], object_index = None):
    '''
    Takes a sample of the given dataset and returns the predictions and classifications for each sample
    along with predictions and classifications made using degraded versions of those samples.
    '''
    
    global model, plasticc_dataset
    
    # If only looking at one object, grab that part of the dataset; otherwise, grab the first X objects
    if not(object_index is None): dataset = plasticc_dataset[object_index]
    else: dataset = plasticc_dataset[0:sample_size]
        
    # Remove unused bands from the dataset
    dataset = restrict_bands(dataset, bands)
    
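    # Classify the dataset with points removed for each cutoff given, and once without points removed as a control
    # (a new list is built so that the caller's list and the default argument are never mutated)
    observation_cutoffs = [None] + list(observation_cutoffs)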
    classified_data = { "used_bands": bands, "used_cutoffs": observation_cutoffs }
    for cutoff in observation_cutoffs:
    
        # Modifies the dataset with a cutoff for observations, then generates predictions and classifications
        modified_dataset = restrict_observation_count(dataset, cutoff)
        predictions = model.predict_dataset(modified_dataset)
        classifications = classifier.classify(predictions)

        # For every object in the dataset, store the classification and relevant info
        for index in range(0, len(predictions)):
            
            # Stores on a per-object basis to easily compare how an individual object is affected by modifications
            object_id = classifications["object_id"][index]
            object_info = classified_data.get(object_id)
            if object_info is None: object_info = []
            
            # Find the classification considered most likely by the classifier
            top_prediction = None
            for c in classifications.colnames:
                if c == "object_id": continue
                if top_prediction is None or classifications[c][index] > classifications[top_prediction][index]:
                    top_prediction = c
            
            # Uses ParSNIP to predict the full light curve based on modified samples
            predicted_curve = model.predict_light_curve(modified_dataset.light_curves[index], False)
            
            # Collects all the most important info from dataset, predictions, and classifications into one place
            object_info.append({
                "cutoff": cutoff, 
                "truth": modified_dataset.meta['type'][index],
                "prediction": top_prediction,
                "predictions": predictions[index],
                "light_curve": modified_dataset.light_curves[index], 
                "predicted_curve": Curve(predicted_curve[0], predicted_curve[1], predicted_curve[2]),
                "SNIa": classifications["SNIa"][index], 
                "SNII": classifications["SNII"][index], 
                "SLSN-I": classifications["SLSN-I"][index], 
                "SNIa-91bg": classifications["SNIa-91bg"][index], 
                "SNIax": classifications["SNIax"][index], 
                "SNIbc": classifications["SNIbc"][index], 
                "TDE": classifications["TDE"][index], 
                "KN": classifications["KN"][index]
            })
            
            classified_data.update({ object_id: object_info })
        
    return classified_data
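
The loop in `limitation_test` that searches for the most likely class is equivalent to an argmax over the class-probability columns. A compact sketch using a plain dictionary follows. The `top_class` function and the probabilities are made up for illustration and do not come from the classifier.

```python
def top_class(probabilities):
    """Return the class name with the highest probability (hypothetical helper)."""
    return max(probabilities, key=probabilities.get)

# Made-up probabilities for a single object
example = {"SNIa": 0.10, "SNII": 0.55, "TDE": 0.35}
print(top_class(example))
```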

## Analysis & plotting functions

In [260]:
def plot_curve(dataset, cutoffs, object_index, cutoff, used_bands, shown_bands, show_scatter, isolate_variability = True):
    '''
    Predicts the light curves for an object, for each modification made to that object, and plots
    those light curves with scatter, error bars, and the unmodified light curve for comparison.
    '''
    
    global model
    
    if "ALL" in used_bands: used_bands = all_bands
    if "ALL" in shown_bands: shown_bands = all_bands
        
    cutoff_value = cutoffs[cutoff - 1] if cutoff > 0 else None
    used_cutoffs = [cutoff_value] if cutoff > 0 else []
    cutoff_index = 1 if cutoff > 0 else 0
        
    # Make classifications for the given settings
    data = limitation_test(1, used_bands, used_cutoffs, object_index)
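    # The first two keys are "used_bands" and "used_cutoffs"; the third is the ID of the selected object
    object_id = list(data.keys())[2]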
    true_class = data[object_id][0]['truth']
    
    fig, ax = plt.subplots(1, 1, figsize = (12, 5))
        
    primary_light_curve = data[object_id][0]["light_curve"]
    primary_predicted_curve = data[object_id][0]["predicted_curve"]
    start_time = np.min(primary_predicted_curve.time)
    
    ref_time = data[object_id][0]['predictions']['reference_time'] - start_time
    ref_error = data[object_id][0]['predictions']['reference_time_error']
    
    # TODO: Figure out why these values are so incorrect
    #parsnip_scale = data[object_id][0]['predictions']['parsnip_scale']
    #amplitude = data[object_id][0]['predictions']['amplitude'] * parsnip_scale
    #amplitude_error = data[object_id][0]['predictions']['amplitude_error'] * parsnip_scale
    
    # Determining the vertical axis limits to show as much as possible without losing any data
    if show_scatter != "off": min_y = min(primary_light_curve["flux"])
    else: min_y = min(primary_predicted_curve.flux.flatten())
    max_y = max(primary_predicted_curve.flux.flatten())
    
    for band in shown_bands:
        if not(band in used_bands): continue
        band_index = all_bands.index(band)
        
        # Plots the predicted light curve for the unmodified sample in the given band
        ax.plot((primary_predicted_curve.time - np.min(start_time)), primary_predicted_curve.flux[0][band_index], 
                            label = f"{len(primary_light_curve)} points (original)", 
                            linestyle = "--", linewidth = 1, color = band_colors[band_index])
        
        light_curve = data[object_id][cutoff_index]["light_curve"]
        light_curve_mask = light_curve["band"] == band

        # Plot the individual observations and their flux error in the given band
        if show_scatter != "off":
            error = light_curve[light_curve_mask]['fluxerr'] if show_scatter == "with error" else np.zeros(len(light_curve[light_curve_mask]['fluxerr']))
            ax.errorbar(light_curve[light_curve_mask]['time'] - start_time, light_curve[light_curve_mask]['flux'], 
                                    yerr = error, fmt = '.', label = band, color = band_colors[band_index])

        # Plot a vertical bar indicating the reference time and its error
        ax.axvspan(ref_time - ref_error, ref_time + ref_error, alpha=0.2, color='red')
        
        # TODO: Also plot a horizontal line for the amplitude. This code already works, but
        #       the value for the amplitude seems very incorrect so the units may be off.
        #ax.axhspan(amplitude - amplitude_error, amplitude + amplitude_error, alpha=0.2, color='green')

        # Plot the predicted light curve for the modified sample (the unmodified one is always plotted above)
        if cutoff_value is not None:
            predicted_curve = data[object_id][cutoff_index]["predicted_curve"]
            ax.plot(predicted_curve.time - start_time, predicted_curve.flux[0][band_index], 
                                label = f"{len(light_curve)} points", linewidth = 1.5, 
                                color = band_colors[band_index])

            max_y = max(max(predicted_curve.flux.flatten()), max_y)
        
    # Restrict the viewed area to the variable portion of the light curve
    if isolate_variability:
        max_time = max(primary_predicted_curve.time.flatten()) - start_time
        ax.set_xlim(left = max(ref_time - 80, 0), right = min(ref_time + 300, max_time))
    ax.set_ylim(top = max_y + 2, bottom = min_y - 2)

    # What is the predicted class and its associated probability?
    predicted_class = data[object_id][cutoff_index]['prediction']
    prediction_probability = data[object_id][cutoff_index][predicted_class]
    title = f"{(prediction_probability * 100):.2f}% {predicted_class}"

    # If the predicted class is incorrect, what is the predicted probability of the true class?
    if predicted_class != true_class: 
        truth_probability = data[object_id][cutoff_index][true_class]
        title = f"{title} ({(truth_probability * 100):.2f}% {true_class})"

    # Make predicted and true class probabilities the title of each plot
    ax.set_title(title)
    fig.suptitle(object_id)
    #plt.legend()

    plt.show();

In [259]:
def plot_interactable_curve(dataset, cutoffs):
    ''' Creates an interactive plot that compares predicted light curves. '''
    
    object_index = widgets.BoundedIntText(min = 0, max = len(plasticc_dataset) - 1, step = 1, value = 0, description = "Object Index:", continuous_update = False)
    cutoff = widgets.IntSlider(min = 0, max = len(cutoffs), step = 1, value = 0, description = "Cutoff:", continuous_update = False)
    used_bands = widgets.SelectMultiple(options = ["ALL"] + all_bands, rows = 3, value = ["ALL"], description = "Use Bands:")
    shown_bands = widgets.SelectMultiple(options = ["ALL"] + all_bands, rows = 3, value = ["ALL"], description = "See Bands:")
    show_scatter = widgets.RadioButtons(options = ['with error', 'no error', 'off'], description = "Scatter:")
    
    # TODO: Figure out why checkboxes aren't showing up with ipywidgets
    #isolate_variability = widgets.Checkbox(value=True, description="Isolate Variability")
    
    ui = widgets.HBox([
        widgets.VBox([object_index, cutoff, show_scatter]),#, isolate_variability]), 
        widgets.VBox([used_bands, shown_bands])
    ])
    
    output = widgets.interactive_output(
        plot_curve, {
            "dataset": fixed(dataset), 
            "cutoffs": fixed(cutoffs), 
            "object_index": object_index,
            "cutoff": cutoff, 
            "used_bands": used_bands, 
            "shown_bands": shown_bands, 
            "show_scatter": show_scatter
            #"isolate_variability": isolate_variability
        });
    
    display(output, ui)

In [171]:
def plot_class_grid(data, class_analyzed = "ALL", max_objects = -1):
    '''
    Creates a table of data where each column is a modification (point cutoff) to the data, and each row is the
    probability of an object being that type. When multiple objects are used, their values are stacked to get an
    average probability for that type.
    
    The probabilities in a particular column, relative to the first column (which has no modifications), tell us
    how much the accuracy of ParSNIP degrades with those modifications.
    
    Parameters
    ----------
    data : dict
        A dictionary of objects with their classifications and predictions for a number of modifications.
        
    class_analyzed: string
        If set to "ALL", then all classes will be individually analyzed and tables will be made for each.
        Otherwise, this can be set to the name of the class for which you want a single analysis table made.
        
    max_objects: int
        The maximum number of objects to include in the analysis. The values of each object will be stacked.
        The default of -1 will have no maximum.
    '''
    
    # NOTE: Plasticc dataset has no KN objects, but ParSNIP can still predict them. So it must be an included row.
    predictable_classes = { "SNIa" : 0., "SNII": 0, "SLSN-I": 0, "SNIa-91bg": 0, "SNIax": 0, "SNIbc": 0, "TDE": 0., "KN": 0. }
    
    classes = []
    if class_analyzed == "ALL": classes = ["SNIa", "SNII", "SLSN-I", "SNIa-91bg", "SNIax", "SNIbc", "TDE"]
    else: classes = [class_analyzed]
    
    for class_type in classes:
        
        data_table = { }
        object_count = 0

        for object_id in data:
            if object_id == "used_bands" or object_id == "used_cutoffs": continue
            if not(data[object_id][0]['truth'] == class_type): continue
            for index in range(0, len(data[object_id])):
                for key in data[object_id][index]:
                    if not(key in ["cutoff", "light_curve", "truth", "prediction", "predictions", "predicted_curve"]): 
                        cutoff = str(data[object_id][index]["cutoff"])
                        if not(cutoff in data_table):
                            data_table[cutoff] = predictable_classes.copy()
                        data_table[cutoff][key] += data[object_id][index][key]

            object_count += 1
            if max_objects > 0 and object_count >= max_objects: break

        # Divides each sum by the number of objects included so that each value is a mean probability
        for cutoff in data_table:
            for key in data_table[cutoff]:
                data_table[cutoff][key] /= object_count

        data_frame = pandas.DataFrame(data_table)
        
        # TODO: Make it instead show a grid plot for each object (WIP code below)
        display(data_frame)
        print(f"OBJECTS INCLUDED: {object_count} {class_type}\n")
        
        '''
        nrows = 8
        ncols = len(dataFrame.columns)
        print(ncols)

        Cellid = [2, 4 ,5, 11]
        Cellval = [20, 45 ,55, 77]

        data = np.zeros(nrows*ncols)
        data[Cellid] = Cellval
        data = np.ma.array(data.reshape((nrows, ncols)), mask=data==0)

        fig, ax = plt.subplots()
        ax.imshow(data, cmap="Greens", origin="lower", vmin=0, vmax=1)

        # optionally add grid
        ax.set_xticks(range(len(dataFrame.columns))-0.5, dataFrame.columns, minor=True)
        ax.set_xticks(range(8)-0.5, dataFrame, minor=False)
        ax.set_yticks(np.arange(nrows+1)-0.5, minor=True)
        ax.grid(which="minor")
        ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
        plt.show()
        '''
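
The stacking described in the `plot_class_grid` docstring amounts to summing each class probability per cutoff and dividing by the number of objects. The sketch below shows that arithmetic on two made-up objects. The `mean_probabilities` helper and its input layout are hypothetical simplifications of the dictionary returned by `limitation_test`.

```python
from collections import defaultdict

def mean_probabilities(objects, class_names):
    """Average class probabilities per cutoff across objects (hypothetical helper).

    `objects` holds one entry per object; each entry is a list of dicts with a
    "cutoff" key plus one probability per class name.
    """
    sums = defaultdict(lambda: dict.fromkeys(class_names, 0.0))
    for results in objects:
        for result in results:
            for name in class_names:
                sums[str(result["cutoff"])][name] += result[name]
    count = len(objects)
    # Turn the per-cutoff sums into per-cutoff means
    return {cutoff: {name: total / count for name, total in row.items()}
            for cutoff, row in sums.items()}

# Two made-up objects, each classified with no cutoff and with a cutoff of 4
objects = [
    [{"cutoff": None, "SNIa": 0.8, "SNII": 0.2}, {"cutoff": 4, "SNIa": 0.4, "SNII": 0.6}],
    [{"cutoff": None, "SNIa": 0.6, "SNII": 0.4}, {"cutoff": 4, "SNIa": 0.2, "SNII": 0.8}],
]
table = mean_probabilities(objects, ["SNIa", "SNII"])
print(table)
```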

## Testing: Comparing predicted light curves after modifications

In [261]:
plot_interactable_curve(plasticc_dataset, [32, 16, 8, 4])

Output()

HBox(children=(VBox(children=(BoundedIntText(value=0, description='Object Index:', max=4965), IntSlider(value=…

## Testing: Comparing classifications after modifications

In [218]:
test_classifications = limitation_test(4500, all_bands, [32, 16, 8, 4])

In [172]:
plot_class_grid(test_classifications, class_analyzed = "ALL", max_objects = -1)

| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.373396 | 0.262883 | 0.148789 | 0.109685 | 0.101304 |
| SNII | 0.2901 | 0.212229 | 0.12821 | 0.094215 | 0.071499 |
| SLSN-I | 0.013942 | 0.014538 | 0.018662 | 0.014913 | 0.01712 |
| SNIa-91bg | 0.058301 | 0.067831 | 0.074259 | 0.067694 | 0.066807 |
| SNIax | 0.074113 | 0.065721 | 0.049596 | 0.046059 | 0.044959 |
| SNIbc | 0.036103 | 0.057061 | 0.079636 | 0.078237 | 0.078638 |
| TDE | 0.151222 | 0.124843 | 0.105606 | 0.088925 | 0.082332 |
| KN | 0.002824 | 0.194894 | 0.395245 | 0.500272 | 0.537341 |


OBJECTS INCLUDED: 1777 SNIa



| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.019665 | 0.022209 | 0.026992 | 0.028916 | 0.030445 |
| SNII | 0.424987 | 0.280862 | 0.149879 | 0.090232 | 0.068656 |
| SLSN-I | 0.010255 | 0.010607 | 0.011993 | 0.010109 | 0.012017 |
| SNIa-91bg | 0.014194 | 0.037145 | 0.049199 | 0.05129 | 0.055715 |
| SNIax | 0.053725 | 0.046301 | 0.037349 | 0.035233 | 0.036625 |
| SNIbc | 0.089532 | 0.096431 | 0.100019 | 0.094207 | 0.093563 |
| TDE | 0.378479 | 0.272728 | 0.186689 | 0.125985 | 0.10316 |
| KN | 0.009162 | 0.233716 | 0.43788 | 0.564029 | 0.599819 |


OBJECTS INCLUDED: 2226 SNII



| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.019865 | 0.023194 | 0.073519 | 0.09697 | 0.057734 |
| SNII | 0.316798 | 0.149664 | 0.11678 | 0.143971 | 0.122273 |
| SLSN-I | 0.587105 | 0.35115 | 0.240171 | 0.194944 | 0.168885 |
| SNIa-91bg | 0.015156 | 0.05861 | 0.026937 | 0.041209 | 0.034974 |
| SNIax | 0.003318 | 0.019589 | 0.02149 | 0.030853 | 0.024956 |
| SNIbc | 0.047714 | 0.091086 | 0.044185 | 0.065584 | 0.06931 |
| TDE | 0.008927 | 0.045866 | 0.138003 | 0.082448 | 0.073536 |
| KN | 0.001117 | 0.260841 | 0.338914 | 0.344021 | 0.448332 |


OBJECTS INCLUDED: 10 SLSN-I



| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.021027 | 0.024924 | 0.03163 | 0.027464 | 0.026558 |
| SNII | 0.062647 | 0.055962 | 0.052045 | 0.039352 | 0.041293 |
| SLSN-I | 0.028994 | 0.027661 | 0.014225 | 0.015236 | 0.01242 |
| SNIa-91bg | 0.78448 | 0.541001 | 0.355135 | 0.210204 | 0.116424 |
| SNIax | 0.030416 | 0.03629 | 0.026911 | 0.035333 | 0.031128 |
| SNIbc | 0.057197 | 0.073712 | 0.071515 | 0.080167 | 0.076255 |
| TDE | 0.002533 | 0.023401 | 0.038904 | 0.068132 | 0.066326 |
| KN | 0.012706 | 0.217048 | 0.409635 | 0.524112 | 0.629597 |


OBJECTS INCLUDED: 48 SNIa-91bg



| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.195395 | 0.125914 | 0.085214 | 0.075638 | 0.050834 |
| SNII | 0.323909 | 0.210098 | 0.162257 | 0.107248 | 0.088763 |
| SLSN-I | 0.002607 | 0.005017 | 0.006171 | 0.005103 | 0.011947 |
| SNIa-91bg | 0.120818 | 0.103 | 0.0774 | 0.066723 | 0.065309 |
| SNIax | 0.150116 | 0.110156 | 0.071892 | 0.059043 | 0.052459 |
| SNIbc | 0.142996 | 0.129403 | 0.099378 | 0.102755 | 0.090524 |
| TDE | 0.044377 | 0.054389 | 0.062376 | 0.060609 | 0.063948 |
| KN | 0.019781 | 0.262023 | 0.435312 | 0.522881 | 0.576217 |


OBJECTS INCLUDED: 121 SNIax



| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.032178 | 0.033423 | 0.027377 | 0.026716 | 0.029491 |
| SNII | 0.267445 | 0.171444 | 0.114952 | 0.074556 | 0.063913 |
| SLSN-I | 0.003551 | 0.004659 | 0.006142 | 0.007498 | 0.008175 |
| SNIa-91bg | 0.203366 | 0.154774 | 0.08934 | 0.085222 | 0.074816 |
| SNIax | 0.07057 | 0.054235 | 0.039305 | 0.035588 | 0.037505 |
| SNIbc | 0.364203 | 0.259038 | 0.174473 | 0.133267 | 0.106743 |
| TDE | 0.021438 | 0.036677 | 0.050147 | 0.053012 | 0.070769 |
| KN | 0.03725 | 0.285751 | 0.498265 | 0.58414 | 0.608588 |


OBJECTS INCLUDED: 311 SNIbc



| Predicted class | None | 32 | 16 | 8 | 4 |
| --- | --- | --- | --- | --- | --- |
| SNIa | 0.004164 | 0.005109 | 0.024813 | 0.02084 | 0.016331 |
| SNII | 0.023545 | 0.089726 | 0.029436 | 0.023396 | 0.029107 |
| SLSN-I | 0.003433 | 0.005508 | 0.006276 | 0.066368 | 0.006052 |
| SNIa-91bg | 0.00231 | 0.002701 | 0.04168 | 0.032104 | 0.024473 |
| SNIax | 0.004474 | 0.004772 | 0.024973 | 0.016611 | 0.044397 |
| SNIbc | 0.016086 | 0.016229 | 0.063564 | 0.058197 | 0.055669 |
| TDE | 0.425996 | 0.552339 | 0.119647 | 0.144034 | 0.169801 |
| KN | 0.519991 | 0.323617 | 0.689612 | 0.63845 | 0.654171 |


OBJECTS INCLUDED: 7 TDE
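
One way to read these tables is to divide the true-class probability at each cutoff by its value with no cutoff. The sketch below does this for the SNIa row of the first table above (true SNIa objects). The `retained_fraction` helper is hypothetical, and this ratio is only one possible summary, not a metric defined by ParSNIP.

```python
def retained_fraction(row):
    """Divide each probability by the no-cutoff ("None") value (hypothetical helper)."""
    baseline = row["None"]
    return {cutoff: value / baseline for cutoff, value in row.items()}

# Mean SNIa probability for true SNIa objects, copied from the first table above
snia_row = {"None": 0.373396, "32": 0.262883, "16": 0.148789, "8": 0.109685, "4": 0.101304}
fractions = retained_fraction(snia_row)
for cutoff, fraction in fractions.items():
    print(f"cutoff {cutoff}: {fraction:.2f} of the baseline probability")
```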



## Easy references to data

In [263]:
# Just an easy way to look at what kind of data is stored in predictions
for object_id in test_classifications:
    if object_id == "used_bands" or object_id == "used_cutoffs": continue
    print(test_classifications[object_id][0]["predictions"].columns)
    break

<TableColumns names=('object_id','ra','dec','type','redshift','ddf_bool','hostgal_specz','hostgal_photoz','hostgal_photoz_err','distmod','mwebv','target','true_submodel','true_distmod','true_lensdmu','true_vpec','true_rv','true_av','true_peakmjd','libid_cadence','tflux_u','tflux_g','tflux_r','tflux_i','tflux_z','tflux_y','parsnip_reference_time','parsnip_scale','reference_time','reference_time_error','color','color_error','amplitude','amplitude_error','s1','s1_error','s2','s2_error','s3','s3_error','total_s2n','count','count_s2n_3','count_s2n_5','count_s2n_3_pre','count_s2n_3_rise','count_s2n_3_fall','count_s2n_3_post','model_chisq','model_dof','luminosity','luminosity_error')>
