# ParSNIP Limitation Testing

This notebook tests the limitations of ParSNIP by modifying datasets, making predictions and classifications on them, and comparing the results to those from the unmodified datasets. Currently, it tests how the accuracy of ParSNIP degrades as the number of observations decreases.

Written by John Delker (jfla@uw.edu)

## Initial Setup

In [51]:
# Load dependencies
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from astropy.table import Table, vstack
import time
import parsnip
import lcdata
from ipywidgets import interact, interactive, interactive_output, fixed, interact_manual
import ipywidgets as widgets
from collections import namedtuple
from IPython.display import display
import pandas
import os

# Hide a few warnings that would otherwise fill the page
import warnings
# Silence every FutureWarning and UserWarning for the rest of the session
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
# Silence the specific LightGBM deprecation message about the 'verbose' argument
warnings.filterwarnings("ignore", message="'verbose' argument is deprecated and will be removed in a future release of LightGBM.")

## Settings

In [58]:
all_bands = ["lsstu", "lsstg", "lsstr", "lssti", "lsstz", "lssty"] # Names of the bands used in the dataset
band_colors = ["blue", "green", "red", "purple", "brown", "black"] # Colors used for the above bands when plotting
model = parsnip.load_model("plasticc") # Options: "ps1", "plasticc" - If changed, the bands and below dataset will need to be changed.
# Bundles the three values returned by model.predict_light_curve (see limitation_test)
Curve = namedtuple('Curve', 'time flux results')

# Root directory that holds both the test dataset and the precomputed predictions
parsnip_data = "/epyc/data/parsnip_tests/"

# Path to the dataset whose light curves are degraded and classified in the tests below
# plasticc_test_slim is the first 5000 entries of the full PLAsTiCC dataset.
# You can alternatively use any dataset or subset of data you choose.
dataset_path = os.path.join(parsnip_data, "data/plasticc_test_slim.h5")

# Path to the precomputed ParSNIP predictions that the classifier is trained on
predictions_path = os.path.join(parsnip_data, "predictions/parsnip_predictions_plasticc_train_aug_100.h5")

In [59]:
# Prepares the dataset, removing any objects whose type ParSNIP cannot explicitly classify
# NOTE: If not using Plasticc, some of this may need tweaking
plasticc_dataset = parsnip.load_dataset(dataset_path)
# Keep only the objects whose meta 'type' is neither 'CaRT' nor 'ILOT' (one filter per type)
plasticc_dataset = plasticc_dataset[~(plasticc_dataset.meta['type'] == 'CaRT')]
plasticc_dataset = plasticc_dataset[~(plasticc_dataset.meta['type'] == 'ILOT')]

Parsing 'plasticc_test_slim.h5' as PLAsTiCC dataset...
Rejecting 0 non-supernova-like light curves.
Dataset contains 5000 light curves.


In [4]:
# Train a classifier on the precomputed predictions table read from predictions_path
classifier = parsnip.Classifier()
training_predictions = Table.read(predictions_path)
# The value returned by train() is kept for inspection; no later cell shown here reads it
training_classifications = classifier.train(training_predictions)

Training classifier with keys:
    color
    color_error
    s1
    s1_error
    s2
    s2_error
    s3
    s3_error
    luminosity
    luminosity_error
    reference_time_error
[100]	valid_0's multi_logloss: 0.495647
[100]	valid_0's multi_logloss: 0.509514
[100]	valid_0's multi_logloss: 0.536781
[100]	valid_0's multi_logloss: 0.531477
[100]	valid_0's multi_logloss: 0.469776
[100]	valid_0's multi_logloss: 0.518461
[100]	valid_0's multi_logloss: 0.532923
[100]	valid_0's multi_logloss: 0.488277
[100]	valid_0's multi_logloss: 0.491106
[100]	valid_0's multi_logloss: 0.549328


## Modifying the datasets

In [5]:
def restrict_bands(dataset, bands):
    '''
    Returns a new dataset in which every light curve keeps only the observations
    whose band is listed in `bands`.
    '''

    # Build the result from copies of the metadata and light curve array so the input is left unmodified
    modified_dataset = lcdata.Dataset(dataset.meta.copy(), dataset.light_curves.copy())

    for curve_index in range(len(modified_dataset)):
        # Grouping by band makes each band's rows contiguous; groups.indices holds the group boundaries
        grouped_curve = dataset.light_curves[curve_index].group_by('band')
        boundaries = grouped_curve.groups.indices

        # Start by keeping every row, then switch off the rows of each band that is not wanted
        keep = np.ones(len(grouped_curve), dtype=bool)
        for group_index in range(len(boundaries) - 1):
            band_name = grouped_curve.groups.keys["band"][group_index]
            if band_name not in bands:
                keep[boundaries[group_index] : boundaries[group_index + 1]] = False

        # Store the filtered curve, then restore time ordering (grouping reordered the rows by band)
        modified_dataset.light_curves[curve_index] = grouped_curve[keep]
        modified_dataset.light_curves[curve_index].sort("time")

    return modified_dataset

In [6]:
def restrict_observation_count(dataset, max_observations_per_band):
    '''
    Returns a new dataset in which each band of each light curve keeps at most
    `max_observations_per_band` observations (the leading rows of each band group).
    A cutoff of None retains all original observations.
    '''

    # Build the result from copies of the metadata and light curve array so the input is left unmodified
    modified_dataset = lcdata.Dataset(dataset.meta.copy(), dataset.light_curves.copy())

    for curve_index in range(len(modified_dataset)):

        # Group by band so each band occupies one contiguous run of rows
        grouped_curve = dataset.light_curves[curve_index].group_by('band')
        boundaries = grouped_curve.groups.indices

        # Keep everything by default; only trim when a cutoff was actually requested
        keep = np.ones(len(grouped_curve), dtype=bool)
        if max_observations_per_band is not None:
            for group_index in range(len(boundaries) - 1):
                first_dropped = boundaries[group_index] + max_observations_per_band
                group_end = boundaries[group_index + 1]

                # Only bands holding more rows than the cutoff lose their trailing rows
                if first_dropped < group_end:
                    keep[first_dropped : group_end] = False

        # Store the trimmed curve, then restore time ordering (grouping reordered the rows by band)
        modified_dataset.light_curves[curve_index] = grouped_curve[keep]
        modified_dataset.light_curves[curve_index].sort("time")

    return modified_dataset

## Classifying modified data

In [7]:
def limitation_test(sample_size, bands, observation_cutoffs = None, object_index = None):
    '''
    Takes a sample of the prepared dataset and returns the predictions and classifications for each
    object, along with predictions and classifications made using degraded versions of that sample.

    Reads the module-level `plasticc_dataset`, `model` and `classifier`.

    Parameters:
        sample_size: Number of leading objects to use when `object_index` is None.
        bands: Bands to keep; every other band is removed before classifying.
        observation_cutoffs: Optional iterable of per-band observation cutoffs to test.
            The unmodified control (cutoff None) is always run first. The iterable passed
            in is never modified.
        object_index: If given, only this entry of the dataset is used instead of a sample.

    Returns:
        A dict with the keys "used_bands" and "used_cutoffs" plus one key per object id.
        Each object id maps to a list with one entry per cutoff (control first) holding the
        cutoff, true type, top prediction, light curve, predicted curve and class probabilities.
    '''

    # If only looking at one object, grab that part of the dataset; otherwise, grab the first X objects
    if object_index is not None:
        dataset = plasticc_dataset[object_index]
    else:
        dataset = plasticc_dataset[0:sample_size]

    # Remove unused bands from the dataset
    dataset = restrict_bands(dataset, bands)

    # Build a fresh list so that neither a shared default nor the caller's list is ever mutated;
    # the leading None is the control run without any points removed
    cutoffs = [None] + ([] if observation_cutoffs is None else list(observation_cutoffs))
    classified_data = { "used_bands": bands, "used_cutoffs": cutoffs }

    # Class probabilities copied into every entry, in this order
    class_names = ("SNIa", "SNII", "SLSN-I", "SNIa-91bg", "SNIax", "SNIbc", "TDE", "KN")

    for cutoff in cutoffs:

        # Modifies the dataset with a cutoff for observations, then generates predictions and classifications
        modified_dataset = restrict_observation_count(dataset, cutoff)
        predictions = model.predict_dataset(modified_dataset)
        classifications = classifier.classify(predictions)
        class_columns = [column for column in classifications.colnames if column != "object_id"]

        for index in range(len(predictions)):
            object_id = classifications["object_id"][index]

            # The most likely class; ties resolve to the earliest column, as a strict '>' scan would
            top_prediction = max(class_columns, key = lambda column: classifications[column][index], default = None)

            # Uses ParSNIP to predict the full light curve based on the modified sample
            predicted_curve = model.predict_light_curve(modified_dataset.light_curves[index], False)

            # Collects all the most important info from dataset, predictions, and classifications into one place
            entry = {
                "cutoff": cutoff,
                "truth": modified_dataset.meta['type'][index],
                "prediction": top_prediction,
                "light_curve": modified_dataset.light_curves[index],
                "predicted_curve": Curve(predicted_curve[0], predicted_curve[1], predicted_curve[2]),
            }
            for class_name in class_names:
                entry[class_name] = classifications[class_name][index]

            # Stored per object so it is easy to compare how one object is affected by each modification
            classified_data.setdefault(object_id, []).append(entry)

    return classified_data

## Analyzing the Data

In [8]:
def plot_curve(dataset, cutoffs, object_index, cutoff, used_bands, shown_bands, show_scatter):
    '''
    Classifies one object with and without an observation cutoff and plots the result.

    For every band that is both used and shown, this draws the light curve predicted from the
    unmodified sample (dashed), optionally the observations that were kept (with or without
    error bars) and, when a cutoff is selected, the light curve predicted from the reduced
    sample (solid). The title reports the predicted class and its probability, followed by the
    probability of the true class when the prediction is wrong.

    Parameters:
        dataset: Not read here; limitation_test works on the module-level plasticc_dataset.
        cutoffs: List of per-band observation cutoffs that `cutoff` selects from.
        object_index: Index of the object to plot, passed through to limitation_test.
        cutoff: 0 for no cutoff, otherwise the 1-based position of the cutoff in `cutoffs`.
        used_bands: Bands given to the model; including "ALL" selects every band in all_bands.
        shown_bands: Bands to draw; including "ALL" selects every band in all_bands.
        show_scatter: "with error", "no error", or "off".
    '''

    if "ALL" in used_bands: used_bands = all_bands
    if "ALL" in shown_bands: shown_bands = all_bands

    # Position 0 means "no cutoff"; limitation_test always puts the control entry first,
    # so the entry for a selected cutoff sits at index 1
    cutoff_value = cutoffs[cutoff - 1] if cutoff > 0 else None
    used_cutoffs = [cutoff_value] if cutoff > 0 else []
    cutoff_index = 1 if cutoff > 0 else 0

    # Make classifications for the given settings
    data = limitation_test(1, used_bands, used_cutoffs, object_index)

    # The only key that is not bookkeeping is the id of the requested object
    object_id = next(key for key in data if key not in ("used_bands", "used_cutoffs"))
    control_entry = data[object_id][0]
    selected_entry = data[object_id][cutoff_index]
    true_class = control_entry['truth']

    fig, ax = plt.subplots(1, 1, figsize = (12, 5))

    primary_light_curve = control_entry["light_curve"]
    primary_predicted_curve = control_entry["predicted_curve"]
    start_time = np.min(primary_predicted_curve.time)

    light_curve = selected_entry["light_curve"]
    predicted_curve = selected_entry["predicted_curve"]

    # Determining the vertical axis limits to show as much as possible without losing any data
    # NOTE(review): the upper limit only considers predicted fluxes, so observations above the
    # predicted maximum fall outside the axes - confirm this is intended
    if show_scatter != "off": min_y = min(primary_light_curve["flux"])
    else: min_y = min(primary_predicted_curve.flux.flatten())
    max_y = max(primary_predicted_curve.flux.flatten())
    if cutoff_value is not None:
        max_y = max(max(predicted_curve.flux.flatten()), max_y)

    for band in shown_bands:
        if band not in used_bands: continue
        band_index = all_bands.index(band)

        # Plots the predicted light curve for the unmodified sample in the given band
        ax.plot(primary_predicted_curve.time - start_time, primary_predicted_curve.flux[0][band_index],
                label = f"{len(primary_light_curve)} points (original)",
                linestyle = "--", linewidth = 1, color = band_colors[band_index])

        # Plot the individual observations and their flux error in the given band
        if show_scatter != "off":
            band_points = light_curve[light_curve["band"] == band]
            error = band_points['fluxerr'] if show_scatter == "with error" else np.zeros(len(band_points['fluxerr']))
            ax.errorbar(band_points['time'] - start_time, band_points['flux'],
                        yerr = error, fmt = '.', label = band, color = band_colors[band_index])

        # Plots the light curve predicted from the reduced sample, if a cutoff is selected
        if cutoff_value is not None:
            ax.plot(predicted_curve.time - start_time, predicted_curve.flux[0][band_index],
                    label = f"{len(light_curve)} points", linewidth = 1.5,
                    color = band_colors[band_index])

    ax.set_ylim(top = max_y + 2, bottom = min_y - 2)
    ax.set_xlabel("Time since start of predicted curve")
    ax.set_ylabel("Flux")

    # What is the predicted class and its associated probability?
    predicted_class = selected_entry['prediction']
    prediction_probability = selected_entry[predicted_class]
    title = f"{(prediction_probability * 100):.2f}% {predicted_class}"

    # If the predicted class is incorrect, what is the predicted probability of the true class?
    if predicted_class != true_class:
        truth_probability = selected_entry[true_class]
        title = f"{title} ({(truth_probability * 100):.2f}% {true_class})"

    # Make predicted and true class probabilities the title of each plot
    ax.set_title(title)
    fig.suptitle(object_id)
    #plt.legend()

    plt.show();

In [9]:
def plot_interactable_curve(dataset, cutoffs):
    ''' Builds the widget controls and wires them to plot_curve so predicted light curves can be compared interactively. '''

    def band_selector(description):
        # Multi-select listing "ALL" followed by every known band, with "ALL" chosen initially
        return widgets.SelectMultiple(options = ["ALL"] + all_bands, rows = 3, value = ["ALL"], description = description)

    # The object index is bounded by the module-level dataset that the classification code indexes into
    last_object_index = len(plasticc_dataset) - 1
    index_box = widgets.BoundedIntText(min = 0, max = last_object_index, step = 1, value = 0, description = "Object Index:", continuous_update = False)

    # Position 0 stands for "no cutoff"; positions 1..len(cutoffs) select an entry of `cutoffs`
    cutoff_slider = widgets.IntSlider(min = 0, max = len(cutoffs), step = 1, value = 0, description = "Cutoff:", continuous_update = False)
    used_band_select = band_selector("Use Bands:")
    shown_band_select = band_selector("See Bands:")
    scatter_buttons = widgets.RadioButtons(options = ['with error', 'no error', 'off'], description = "Scatter:")

    # Two columns of controls laid out side by side
    left_column = widgets.VBox([index_box, cutoff_slider, scatter_buttons])
    right_column = widgets.VBox([used_band_select, shown_band_select])
    controls = widgets.HBox([left_column, right_column])

    # Map each plot_curve argument to a fixed value or to the widget that drives it
    plot_arguments = {
        "dataset": fixed(dataset),
        "cutoffs": fixed(cutoffs),
        "object_index": index_box,
        "cutoff": cutoff_slider,
        "used_bands": used_band_select,
        "shown_bands": shown_band_select,
        "show_scatter": scatter_buttons
    }
    plot_area = widgets.interactive_output(plot_curve, plot_arguments)

    display(plot_area, controls)

In [10]:
def print_classifications(data):
    ''' Debug helper: prints, for every object and every modification, the top prediction and the probability assigned to each class. '''
    bookkeeping_keys = ("used_bands", "used_cutoffs")
    hidden_fields = ("cutoff", "light_curve", "truth", "predicted_curve")

    for object_id, entries in data.items():
        # Skip the entries that describe the run rather than an object
        if object_id in bookkeeping_keys:
            continue

        print("==============================")
        print(f"OBJECT {object_id}, TRUTH - {entries[0]['truth']}")
        print("==============================")

        for entry in entries:
            print("------------------------------")
            print(f"{len(entry['light_curve'])} POINTS USED")
            print("------------------------------")

            # Everything left after hiding the bulky fields is the prediction and the class probabilities
            for field, value in entry.items():
                if field in hidden_fields:
                    continue
                text = f"{field}: {value}"
                print(text)

In [48]:
def print_accuracy(data):
    '''
    ~~FUNCTION IS A WORK IN PROGRESS AND NOT VERY HELPFUL AT THE MOMENT~~
    Prints a table showing how ParSNIP's class probabilities line up with the true classifications.
    Each column is a true type and each row is a possible predicted type.
    For every object, the probability assigned to each predicted type (taken from the object's first
    entry, i.e. the unmodified control) is added to the column of the object's true type. Each sum is
    then divided by the number of objects of that true type (the "Count" row) and scaled to percent.
    Values closer to 100 mean ParSNIP generally leans towards that type.
    '''

    # Rows are the classes ParSNIP can predict, plus the running object count.
    # Columns are the true types; there is no "KN" column because the Plasticc dataset has no
    # KN objects, even though ParSNIP can still predict that class.
    row_labels = ["SNIa", "SNII", "SLSN-I", "SNIa-91bg", "SNIax", "SNIbc", "TDE", "KN", "Count"]
    true_types = ["SNIa", "SNII", "SLSN-I", "SNIa-91bg", "SNIax", "SNIbc", "TDE"]
    table = pandas.DataFrame({
        true_type: { label: 0. for label in row_labels }
        for true_type in true_types
    })

    bookkeeping_keys = ("used_bands", "used_cutoffs")
    non_probability_fields = ("cutoff", "light_curve", "truth", "prediction", "predicted_curve")

    # For each object, add its probability for each type to the column of its true type
    for object_id, entries in data.items():
        if object_id in bookkeeping_keys:
            continue
        control_entry = entries[0]
        truth = control_entry['truth']
        table.loc["Count", truth] += 1
        for field, probability in control_entry.items():
            if field in non_probability_fields:
                continue
            table.loc[field, truth] += probability

    # Turn each summed probability into a percentage of the objects that are truly that type
    for true_type in table.columns:
        object_count = table.loc["Count", true_type]
        for label in table.index:
            if label == "Count":
                continue
            table.loc[label, true_type] = (table.loc[label, true_type] * 100) / object_count

    print(table)

## Testing

In [12]:
# Open the interactive viewer; the cutoff slider selects no cutoff or 32, 16, 8, or 4 observations per band
plot_interactable_curve(plasticc_dataset, [32, 16, 8, 4])

Output()

HBox(children=(VBox(children=(BoundedIntText(value=0, description='Object Index:', max=4965), IntSlider(value=…

In [13]:
# TODO: Make a plot where x axis is the point cutoff, y axis is the classifications and it is chopped into grids. 
#       Each block should either have a value or be colored based on how likely that classification is for that cutoff
#       And possibly plot a stacked version for every sample combined.

In [16]:
# Classify the first 4500 objects of the prepared dataset using every band; no observation cutoffs are passed
test_classifications = limitation_test(4500, all_bands)

Preprocessing dataset: 100%|██████████| 4500/4500 [00:09<00:00, 470.37it/s]
Preprocessing dataset: 100%|██████████| 4500/4500 [00:09<00:00, 464.65it/s]


In [47]:
# Print, for each true type, the average probability (in percent) assigned to every class
print_accuracy(test_classifications)

                  SNIa         SNII     SLSN-I  SNIa-91bg       SNIax  \
SNIa         37.339582     1.966538   1.986523   2.102705   19.539480   
SNII         29.009998    42.498728  31.679757   6.264664   32.390906   
SLSN-I        1.394237     1.025496  58.710483   2.899390    0.260742   
SNIa-91bg     5.830061     1.419360   1.515628  78.447973   12.081801   
SNIax         7.411273     5.372542   0.331822   3.041586   15.011631   
SNIbc         3.610319     8.953219   4.771415   5.719727   14.299634   
TDE          15.122160    37.847889   0.892675   0.253343    4.437745   
KN            0.282370     0.916228   0.111698   1.270612    1.978060   
Count      1777.000000  2226.000000  10.000000  48.000000  121.000000   

                SNIbc        TDE  
SNIa         3.217834   0.416387  
SNII        26.744472   2.354541  
SLSN-I       0.355136   0.343325  
SNIa-91bg   20.336642   0.230993  
SNIax        7.056954   0.447421  
SNIbc       36.420253   1.608646  
TDE          2.143752  4