In [1]:
import os 
import numpy as np
import pandas as pd
import pickle 
from astropy.table import Table

In [2]:
train_data_head_path = '/sps/lsst/users/bbiswas/data/kilonova_datasets/train_final_master_HEAD.FITS'
train_data_phot_path = '/sps/lsst/users/bbiswas/data/kilonova_datasets/train_final_master_PHOT.FITS'
test_data_head_path = '/sps/lsst/users/bbiswas/data/kilonova_datasets/test_final_master_HEAD.FITS'
test_data_phot_path = '/sps/lsst/users/bbiswas/data/kilonova_datasets/test_final_master_PHOT.FITS'

In [3]:
df_header = Table.read(train_data_head_path, format='fits').to_pandas()
df_phot = Table.read(train_data_phot_path, format='fits').to_pandas()

In [4]:
from kndetect.utils import load_pcs, get_event_type, get_data_dir_path
from kndetect.features import extract_features_all_lightcurves
from kndetect.training import append_y_true_col

In [5]:
pcs = load_pcs()
data_dir = get_data_dir_path()

In [6]:
use_already_trained_features = False
use_already_trained_models = False
mimic_alerts = False
save_data = True

if mimic_alerts:
    sub_directory = "partial"
else: 
    sub_directory = "complete"

In [7]:
if use_already_trained_features:
    train_features_df = pd.read_csv(os.path.join(data_dir, sub_directory, "train_features.csv"),index_col=0)

else:
    train_features_df = extract_features_all_lightcurves(df_phot, 
                                                         "SNID", 
                                                         pcs, 
                                                         [b'g', b'r'],
                                                         mimic_alerts=mimic_alerts)
    train_features_df = append_y_true_col(features_df=train_features_df,
                                          prediction_type_nos=[149, 150, 151],
                                          meta_df = df_header, 
                                          meta_key_col_name = "SNID",
                                          meta_type_col_name = "SNTYPE")

100%|██████████| 22280/22280 [30:55<00:00, 12.01it/s]


In [8]:
train_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,type,type_names,y_true
0,-4.319741e-11,2.607949e-03,8.730999e-01,0.885027,230.670456,-5.799074e-09,-1.078548e-10,8.311306e-01,1.486240,373.748047,1757,150,150: KN GW170817,True
1,3.548435e-01,-6.470021e-09,7.598155e-01,1.038731,427.853790,7.659615e-01,-5.040123e-09,2.726789e-01,2.054284,788.683228,6415,141,141: 91BG,False
2,9.648955e-01,2.575183e-01,-5.976476e-11,0.941676,270.629425,8.970144e-01,1.025804e-11,9.012200e-02,1.150067,518.246521,7707,103,103: Core collapse Type Ibc,False
3,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,3.095666e-03,-9.974697e-10,1.097973e+00,0.865141,534.019775,8267,151,151: KN Karsen 2017,True
4,4.132830e-01,-5.778576e-09,6.755667e-01,1.064655,263.286224,3.814125e-01,-2.046229e-01,6.647399e-01,1.000773,291.099426,12578,102,102: MOSFIT-Ibc,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22275,1.018561e+00,-7.883219e-09,-1.950012e-10,1.756306,888.318420,1.048414e+00,8.001247e-01,-2.323796e-09,2.026613,666.974243,137062467,112,112: Core collapse Type II,False
22276,8.376594e-01,-1.224038e-09,1.249077e-01,1.517219,1141.677979,9.648918e-01,3.121808e-03,-1.027595e-10,1.321579,1213.934937,137071784,170,170: AGN,False
22277,7.701428e-01,3.264617e-14,2.541590e-01,1.195978,295.870941,8.503447e-01,-5.887991e-10,1.231181e-01,0.745460,408.148193,137071978,143,143: Iax,False
22278,-7.452845e-09,6.499812e-02,7.853952e-01,1.440366,525.301697,-2.146433e-09,1.233759e-02,1.524061e+00,0.950912,445.314880,137079473,151,151: KN Karsen 2017,True


# Now Train the classifier

In [9]:
from kndetect.training import train_classifier

In [10]:
if not use_already_trained_models:
    clf, features_df = train_classifier(train_features_df)
    if save_data:
        with open(os.path.join(data_dir, "models", sub_directory + ".pkl"), 'wb') as files:
            pickle.dump(clf, files)
else:
    from kndetect.predict import load_classifier
    clf = load_classifier(sub_directory + ".pkl")

In [11]:
# Generate Test features and Predict probabilities

In [12]:
df_header_test = Table.read(test_data_head_path, format='fits').to_pandas()
df_phot_test = Table.read(test_data_phot_path, format='fits').to_pandas()

In [13]:
if use_already_trained_features:
    test_features_df = pd.read_csv(os.path.join(data_dir, sub_directory, "test_features.csv"),index_col=0)

else:
    test_features_df = extract_features_all_lightcurves(df_phot_test, 
                                                         "SNID", 
                                                         pcs, 
                                                         [b'g', b'r'],
                                                         mimic_alerts=mimic_alerts)
    test_features_df = append_y_true_col(features_df=test_features_df,
                                          prediction_type_nos=[149, 150, 151],
                                          meta_df = df_header_test, 
                                          meta_key_col_name = "SNID",
                                          meta_type_col_name = "SNTYPE")

100%|██████████| 21288/21288 [29:08<00:00, 12.17it/s]


In [14]:
from kndetect.predict import load_classifier, predict_kn_score

In [15]:
probabilities, filtered_indices = predict_kn_score(clf, test_features_df)

In [16]:
test_features_df['y_pred_score'] = probabilities.T[1]

In [17]:
test_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,type,type_names,y_true,y_pred_score
0,-1.701267e-09,-1.502815e-03,9.905790e-01,2.488623,3813.011456,1.803165e-08,-3.237767e-09,1.250420e+00,4.353821,5518.907806,1612,149,149: KN GRANDMA,True,0.619572
1,1.593189e+00,1.239719e-01,-2.053517e-01,5.750576,1215.978638,1.102284e+00,-3.680246e-09,3.050131e-11,9.126114,1900.364746,10871,162,162: ILOT,False,0.000000
2,-8.353341e-09,-3.133266e-10,1.044624e+00,1.317318,764.863892,2.147154e-05,-2.774969e-09,1.089900e+00,0.920627,805.522644,10872,150,150: KN GW170817,True,0.806835
3,-4.666291e-01,-3.102359e-09,1.739795e+00,13.765211,2329.209717,-1.100664e-05,1.117389e-10,7.239675e-01,11.834136,1588.494263,11422,180,180: RRLyrae,False,0.000000
4,-2.676815e-01,-7.956247e-10,8.081538e-01,41.273416,20902.640625,-1.497711e-01,2.239404e-01,8.265010e-01,35.546320,11635.100586,13390,180,180: RRLyrae,False,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21283,-5.095310e-10,1.445970e+00,1.225910e+00,4.000762,1325.314697,1.071865e+00,1.785420e-01,3.227832e-10,1.808287,1154.043457,137049400,183,183: PHOEBE,False,0.042045
21284,1.925799e-01,3.392688e-01,6.042987e-01,1.606676,229.678802,-6.846556e-10,-3.315549e-10,8.413947e-01,1.900921,785.573853,137051059,181,M 181: dwarf_flares,False,0.600987
21285,7.285401e-01,-2.495702e-10,2.336911e-01,1.023906,374.421295,1.046595e+00,-3.272608e-10,7.844658e-10,1.273695,341.236908,137054088,143,143: Iax,False,0.000000
21286,-1.846900e-03,-4.616774e-09,9.970780e-01,1.367441,1000.960266,-6.074228e-06,-2.292031e-08,9.927710e-01,2.048569,3378.102295,137071432,181,M 181: dwarf_flares,False,0.785229


In [18]:
if save_data:
    train_features_df.to_csv(os.path.join(data_dir, sub_directory, "train_features.csv"))
    test_features_df.to_csv(os.path.join(data_dir, sub_directory, "test_features.csv"))