In [1]:
import os 
import numpy as np
import pandas as pd
import pickle 
from astropy.table import Table

In [2]:
# Directory holding the kilonova train/test FITS files. It defaults to the
# original cluster location; set the KN_DATASET_DIR environment variable to
# point the notebook at another copy of the data without editing this cell.
dataset_dir = os.environ.get(
    'KN_DATASET_DIR', '/sps/lsst/users/bbiswas/data/kilonova_datasets'
)

# The *_HEAD files are used later as the metadata tables (SNID / SNTYPE lookup),
# the *_PHOT files as the input to feature extraction.
train_data_head_path = os.path.join(dataset_dir, 'train_final_master_HEAD.FITS')
train_data_phot_path = os.path.join(dataset_dir, 'train_final_master_PHOT.FITS')
test_data_head_path = os.path.join(dataset_dir, 'test_final_master_HEAD.FITS')
test_data_phot_path = os.path.join(dataset_dir, 'test_final_master_PHOT.FITS')

In [3]:
# Read the training HEAD and PHOT tables from FITS (in that order) and convert
# each one to a pandas DataFrame.
df_header, df_phot = (
    Table.read(fits_path, format='fits').to_pandas()
    for fits_path in (train_data_head_path, train_data_phot_path)
)

In [4]:
from kndetect.utils import load_pcs, get_event_type, get_data_dir_path
from kndetect.features import extract_features_all_lightcurves
from kndetect.training import append_y_true_col

In [5]:
# Package-level resources used throughout the notebook:
#   pcs      -- passed to extract_features_all_lightcurves for both data sets
#               (presumably the principal components the light curves are fit with;
#               confirm against kndetect.utils.load_pcs)
#   data_dir -- root folder for the cached feature CSVs and the pickled models
pcs, data_dir = load_pcs(), get_data_dir_path()

In [6]:
# Run-time switches for this notebook.
use_already_trained_features = False  # True: read cached feature CSVs instead of re-extracting
use_already_trained_models = False  # True: load a saved classifier instead of training one
mimic_alerts = True  # forwarded to extract_features_all_lightcurves
save_data = True  # True: write the classifier pickle and the feature CSVs

# Name used for the feature sub-folder and the model pickle:
# "partial" when mimicking alerts, "complete" otherwise.
sub_directory = "partial" if mimic_alerts else "complete"

In [7]:
# Build the training feature table: extract it from the photometry, or reuse
# the CSV cached by an earlier run.
if not use_already_trained_features:
    # Extract features per SNID for the g and r bands.
    train_features_df = extract_features_all_lightcurves(
        df_phot, "SNID", pcs, [b'g', b'r'], mimic_alerts=mimic_alerts
    )
    # Attach the y_true label column; types 149, 150 and 151 are the ones
    # treated as positives (the KN classes in the table shown below).
    train_features_df = append_y_true_col(
        features_df=train_features_df,
        prediction_type_nos=[149, 150, 151],
        meta_df=df_header,
        meta_key_col_name="SNID",
        meta_type_col_name="SNTYPE",
    )
else:
    train_features_csv = os.path.join(data_dir, sub_directory, "train_features.csv")
    train_features_df = pd.read_csv(train_features_csv, index_col=0)

100%|██████████| 22280/22280 [35:33<00:00, 10.44it/s] 


In [8]:
# Show the training feature table (feature columns plus key, type and y_true)
# as this cell's output.
train_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,current_dates,type,type_names,y_true
0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,-1.295441e-04,3.733547e-11,8.520474e-01,1.725680,373.748047,1757,57661.2481,150,150: KN GW170817,True
1,4.928263e-01,-2.579259e-09,5.578413e-01,0.527078,427.853790,8.560393e-01,-4.833575e-09,1.434888e-01,2.521909,788.683228,6415,57465.4896,141,141: 91BG,False
2,3.197094e-01,3.588890e-12,7.492444e-01,0.614854,217.574081,1.090284e+00,-4.766958e-03,-3.617276e-11,2.237485,477.578766,7707,57528.2030,103,103: Core collapse Type Ibc,False
3,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,-1.654896e-09,4.840368e-08,1.140608e+00,0.720051,239.701599,8267,57668.3579,151,151: KN Karsen 2017,True
4,7.967472e-01,-9.334097e-10,2.444602e-01,1.155829,263.286224,1.121489e+00,-6.510165e-10,2.543369e-08,1.164721,291.099426,12578,57533.3567,102,102: MOSFIT-Ibc,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22275,1.165127e+00,-7.011757e-10,1.263837e-07,1.150831,543.045044,1.403667e+00,4.003274e-01,-5.266488e-09,1.693237,601.617798,137062467,57520.2377,112,112: Core collapse Type II,False
22276,1.188802e+00,2.902721e-12,4.557042e-08,3.301094,926.267822,1.161873e+00,1.096777e-01,3.074452e-08,2.181212,1133.099243,137071784,57687.5039,170,170: AGN,False
22277,8.299908e-01,-1.118831e-10,1.630860e-01,1.593124,295.870941,8.605131e-01,-7.895845e-12,1.088392e-01,0.695266,408.148193,137071978,57621.3001,143,143: Iax,False
22278,-1.390330e-09,7.625778e-02,7.951868e-01,1.967141,525.301697,-6.475022e-04,-1.179644e-10,1.503577e+00,1.196400,445.314880,137079473,57496.2954,151,151: KN Karsen 2017,True


# Now train the classifier

In [9]:
from kndetect.training import train_classifier

In [10]:
if not use_already_trained_models:
    # Fit a fresh classifier on the training features. train_classifier also
    # returns a second value, kept here as features_df.
    clf, features_df = train_classifier(train_features_df)
    if save_data:
        # Create the models folder if it is missing, so the save cannot fail
        # with FileNotFoundError after training has already completed.
        models_dir = os.path.join(data_dir, "models")
        os.makedirs(models_dir, exist_ok=True)
        with open(os.path.join(models_dir, sub_directory + ".pkl"), 'wb') as model_file:
            pickle.dump(clf, model_file)
else:
    # Reuse a previously saved classifier instead of training a new one.
    from kndetect.predict import load_classifier
    clf = load_classifier(sub_directory + ".pkl")

In [11]:
# Generate test features and predict probabilities

In [12]:
# Read the test HEAD and PHOT tables from FITS (in that order) and convert
# each one to a pandas DataFrame.
df_header_test, df_phot_test = (
    Table.read(fits_path, format='fits').to_pandas()
    for fits_path in (test_data_head_path, test_data_phot_path)
)

In [13]:
# Build the test feature table: extract it from the photometry, or reuse the
# CSV cached by an earlier run.
if not use_already_trained_features:
    # Extract features per SNID for the g and r bands.
    test_features_df = extract_features_all_lightcurves(
        df_phot_test, "SNID", pcs, [b'g', b'r'], mimic_alerts=mimic_alerts
    )
    # Attach the y_true label column; types 149, 150 and 151 are the ones
    # treated as positives, matching the training cell.
    test_features_df = append_y_true_col(
        features_df=test_features_df,
        prediction_type_nos=[149, 150, 151],
        meta_df=df_header_test,
        meta_key_col_name="SNID",
        meta_type_col_name="SNTYPE",
    )
else:
    test_features_csv = os.path.join(data_dir, sub_directory, "test_features.csv")
    test_features_df = pd.read_csv(test_features_csv, index_col=0)

100%|██████████| 21288/21288 [34:09<00:00, 10.39it/s]


In [14]:
from kndetect.predict import load_classifier, predict_kn_score

In [15]:
# Score the test features with the classifier. filtered_indices is returned
# alongside the probabilities but is not used by the cells shown in this notebook.
probabilities, filtered_indices = predict_kn_score(clf, test_features_df)

In [16]:
# Store entry 1 of the transposed probabilities as the per-object score
# (presumably the positive/KN class column -- TODO confirm against predict_kn_score).
# NOTE(review): assumes probabilities has one entry per row of test_features_df,
# in the same order -- verify, since predict_kn_score also returns filtered_indices.
test_features_df['y_pred_score'] = probabilities.T[1]

In [17]:
# Show the test feature table, now including the y_pred_score column,
# as this cell's output.
test_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,current_dates,type,type_names,y_true,y_pred_score
0,2.042297e-01,-1.645565e-09,9.108830e-01,0.000209,3813.011456,9.230618e-01,-1.680576e-09,1.023044e-01,0.000738,5518.907806,1612,2.0000,149,149: KN GRANDMA,True,0.151754
1,1.236020e+00,-1.569641e-09,-4.680368e-09,6.036989,1215.978638,1.016291e+00,2.591295e-02,-1.204437e-09,10.778131,1606.578125,10871,57684.3980,162,162: ILOT,False,0.000000
2,-2.212779e-09,-1.463765e-09,1.044626e+00,2.186456,764.863892,1.456365e-07,-1.835298e-09,1.104328e+00,1.291791,805.522644,10872,57705.1047,150,150: KN GW170817,True,0.642428
3,2.982296e-01,-4.453453e-09,6.560412e-01,16.582328,2168.536621,7.428626e-06,-9.493877e-10,6.703501e-01,13.028280,1423.442017,11422,57687.2331,180,180: RRLyrae,False,0.173077
4,-6.659900e-01,1.784912e+00,2.000000e+00,26.798630,2189.897949,-5.592150e-01,-2.193268e-09,1.636869e+00,19.215868,6558.107910,13390,57532.2883,180,180: RRLyrae,False,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21283,1.022996e+00,5.234718e-12,-7.694322e-04,1.330253,1250.087036,3.856643e-01,2.292315e-01,6.077984e-01,3.664189,1089.336670,137049400,57561.2518,183,183: PHOEBE,False,0.000030
21284,1.574288e-01,3.273697e-01,6.462521e-01,1.714669,229.678802,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,137051059,57662.3631,181,M 181: dwarf_flares,False,0.000000
21285,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,8.124684e-01,-1.200614e-09,2.048080e-01,2.555934,321.823578,137054088,57528.2784,143,143: Iax,False,0.000000
21286,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,1.363102e-03,-1.398756e-09,1.133022e+00,0.962335,3378.102295,137071432,57618.3748,181,M 181: dwarf_flares,False,0.000000


In [18]:
if save_data:
    # Create the output folder if it is missing, so the lengthy feature
    # extraction is not lost to a failing to_csv call.
    features_dir = os.path.join(data_dir, sub_directory)
    os.makedirs(features_dir, exist_ok=True)

    # Cache both feature tables; these are the files read back when
    # use_already_trained_features is True. The test table is written with
    # the y_pred_score column added above.
    train_features_df.to_csv(os.path.join(features_dir, "train_features.csv"))
    test_features_df.to_csv(os.path.join(features_dir, "test_features.csv"))