In [1]:
import os 
import numpy as np
import pandas as pd
import pickle 
from astropy.table import Table

In [2]:
# Directory that holds the kilonova HEAD/PHOT FITS files. It defaults to the
# original cluster location; set the KN_DATASET_DIR environment variable to
# run the notebook against a copy of the data stored elsewhere.
kn_dataset_dir = os.environ.get(
    "KN_DATASET_DIR",
    "/sps/lsst/users/bbiswas/data/kilonova_datasets",
)

# HEAD files are read as the per-object metadata table and PHOT files as the
# photometry table in the loading cells below.
train_data_head_path = os.path.join(kn_dataset_dir, 'train_final_master_HEAD.FITS')
train_data_phot_path = os.path.join(kn_dataset_dir, 'train_final_master_PHOT.FITS')
test_data_head_path = os.path.join(kn_dataset_dir, 'test_final_master_HEAD.FITS')
test_data_phot_path = os.path.join(kn_dataset_dir, 'test_final_master_PHOT.FITS')

In [3]:
# Read the training HEAD (metadata) and PHOT (photometry) FITS tables and
# convert each one to a pandas DataFrame, in that order.
df_header, df_phot = (
    Table.read(fits_path, format='fits').to_pandas()
    for fits_path in (train_data_head_path, train_data_phot_path)
)

In [4]:
from kndetect.utils import load_pcs, get_event_type, get_data_dir_path
from kndetect.features import extract_features_all_lightcurves
from kndetect.training import append_y_true_col

In [5]:
# pcs is handed to extract_features_all_lightcurves in the feature cells
# (presumably the principal components the light curves are fitted with -
# confirm in kndetect.utils). data_dir is the base directory under which the
# model pickle and the cached feature CSVs are read and written.
pcs, data_dir = load_pcs(), get_data_dir_path()

In [6]:
# Run configuration for the whole notebook.
# True -> read the cached feature CSVs instead of re-extracting features.
use_already_trained_features = False
# True -> load the saved classifier instead of training a new one.
use_already_trained_models = False
# Forwarded to extract_features_all_lightcurves; also selects the output folder.
mimic_alerts = True
# True -> write the trained classifier and the feature CSVs to disk.
save_data = True

# Name used both for the feature sub-folder and for the model pickle file.
sub_directory = "partial" if mimic_alerts else "complete"

In [7]:
if not use_already_trained_features:
    # Extract per-object features from the training photometry, grouping on
    # the "SNID" column and using the g and r bands (given as bytes).
    train_features_df = extract_features_all_lightcurves(
        df_phot, "SNID", pcs, [b'g', b'r'], mimic_alerts=mimic_alerts
    )
    # Add the y_true label column. Types 149, 150 and 151 are the kilonova
    # classes (see the type_names column in the displayed tables); the type
    # of each object is looked up in the header table via SNID / SNTYPE.
    train_features_df = append_y_true_col(
        features_df=train_features_df,
        prediction_type_nos=[149, 150, 151],
        meta_df=df_header,
        meta_key_col_name="SNID",
        meta_type_col_name="SNTYPE",
    )
else:
    # Reuse the features cached by an earlier run of this notebook.
    train_features_df = pd.read_csv(
        os.path.join(data_dir, sub_directory, "train_features.csv"),
        index_col=0,
    )

100%|██████████| 22280/22280 [35:12<00:00, 10.54it/s]


In [8]:
# Display the training feature table (one row per object, including the
# type, type_names and y_true label columns) as a visual sanity check.
train_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,current_dates,type,type_names,y_true
0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,-1.295441e-04,3.733547e-11,8.520474e-01,1.725680,373.748047,1757,57690.4057,150,150: KN GW170817,True
1,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,-2.162157e-09,6.537960e-09,8.402953e-01,2.586591,355.378204,6415,57690.4057,141,141: 91BG,False
2,9.134064e-01,2.179555e-01,3.234515e-02,1.014444,270.629425,8.225509e-01,-1.590835e-09,1.937334e-01,0.967757,518.246521,7707,57690.4057,103,103: Core collapse Type Ibc,False
3,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,-1.654896e-09,4.840368e-08,1.140608e+00,0.720051,239.701599,8267,57690.4057,151,151: KN Karsen 2017,True
4,8.998934e-01,-2.876286e-10,1.293946e-01,1.148403,263.286224,1.121489e+00,-6.510165e-10,2.543369e-08,1.164721,291.099426,12578,57690.4057,102,102: MOSFIT-Ibc,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22275,1.165127e+00,-7.011757e-10,1.263837e-07,1.150831,543.045044,1.403667e+00,4.003274e-01,-5.266488e-09,1.693237,601.617798,137062467,57690.4057,112,112: Core collapse Type II,False
22276,7.207955e-01,-8.300510e-10,2.975757e-01,0.830262,1141.677979,7.687465e-01,-3.069789e-09,2.758771e-01,0.000073,1213.934937,137071784,57690.4057,170,170: AGN,False
22277,5.826164e-01,-4.391905e-09,4.366269e-01,1.181515,295.870941,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,137071978,57690.4057,143,143: Iax,False
22278,-5.137166e-09,8.289812e-02,8.067113e-01,2.122228,525.301697,-1.472420e-09,1.411990e-01,7.406151e-01,1.771357,401.749268,137079473,57690.4057,151,151: KN Karsen 2017,True


# Now train the classifier

In [9]:
from kndetect.training import train_classifier

In [10]:
if not use_already_trained_models:
    # Fit a new classifier on the training features. train_classifier returns
    # a second value, bound to features_df, which the cells shown here do not
    # use afterwards.
    clf, features_df = train_classifier(train_features_df)
    if save_data:
        models_dir = os.path.join(data_dir, "models")
        # Create the folder on first use so the freshly trained classifier is
        # not lost to a FileNotFoundError when the pickle is written.
        os.makedirs(models_dir, exist_ok=True)
        # The model is stored as "<sub_directory>.pkl" (partial / complete).
        with open(os.path.join(models_dir, sub_directory + ".pkl"), 'wb') as files:
            pickle.dump(clf, files)
else:
    # Reuse a previously saved classifier for the selected mode.
    # NOTE(review): load_classifier presumably resolves the file name against
    # the same models folder - confirm in kndetect.predict.
    from kndetect.predict import load_classifier
    clf = load_classifier(sub_directory + ".pkl")

In [11]:
# Generate test features and predict probabilities

In [12]:
# Read the test HEAD (metadata) and PHOT (photometry) FITS tables and
# convert each one to a pandas DataFrame, in that order.
df_header_test, df_phot_test = (
    Table.read(fits_path, format='fits').to_pandas()
    for fits_path in (test_data_head_path, test_data_phot_path)
)

In [13]:
if not use_already_trained_features:
    # Extract per-object features from the test photometry with the same
    # key column, bands and alert mode as for the training set.
    test_features_df = extract_features_all_lightcurves(
        df_phot_test, "SNID", pcs, [b'g', b'r'], mimic_alerts=mimic_alerts
    )
    # Add the y_true label column. Types 149, 150 and 151 are the kilonova
    # classes (see the type_names column in the displayed tables); the type
    # of each object is looked up in the test header table via SNID / SNTYPE.
    test_features_df = append_y_true_col(
        features_df=test_features_df,
        prediction_type_nos=[149, 150, 151],
        meta_df=df_header_test,
        meta_key_col_name="SNID",
        meta_type_col_name="SNTYPE",
    )
else:
    # Reuse the features cached by an earlier run of this notebook.
    test_features_df = pd.read_csv(
        os.path.join(data_dir, sub_directory, "test_features.csv"),
        index_col=0,
    )

100%|██████████| 21288/21288 [33:35<00:00, 10.56it/s]


In [14]:
from kndetect.predict import load_classifier, predict_kn_score

In [15]:
# Score every test object with the classifier. Only `probabilities` is used
# below; `filtered_indices` is returned as well but not used in the cells
# shown here.
probabilities, filtered_indices = predict_kn_score(clf, test_features_df)

In [16]:
# Store the second probability column as the per-object prediction score
# (presumably the kilonova-class probability - confirm in kndetect.predict).
test_features_df['y_pred_score'] = probabilities[:, 1]

In [17]:
# Display the test feature table, now including the y_pred_score column,
# to compare the scores against the y_true labels by eye.
test_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,current_dates,type,type_names,y_true,y_pred_score
0,-1.701267e-09,-1.502815e-03,9.905790e-01,2.488623,3813.011456,1.803165e-08,-3.237767e-09,1.250420e+00,4.353821,5518.907806,1612,57697.1431,149,149: KN GRANDMA,True,0.506574
1,1.163185e+00,-9.205347e-10,3.442886e-09,3.757906,1002.060974,9.461943e-01,-1.553868e-03,6.545767e-11,3.371526,1348.136475,10871,57697.1431,162,162: ILOT,False,0.000014
2,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,1.014029e+00,-3.255174e-09,3.318235e-07,0.338975,805.522644,10872,57697.1431,150,150: KN GW170817,True,0.000000
3,6.208387e-01,-3.508960e-10,3.876444e-01,3.998751,2329.209717,1.123977e+00,-2.933828e-09,2.325028e-07,1.756737,1575.078979,11422,57697.1431,180,180: RRLyrae,False,0.000000
4,-4.352636e-01,-9.784811e-11,9.542391e-01,34.598697,9882.000000,-2.949151e-01,-1.788313e-10,8.172060e-01,35.521548,11146.453125,13390,57697.1431,180,180: RRLyrae,False,0.033333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21283,1.178081e+00,5.443015e-07,-3.088650e-10,1.466874,892.556213,2.207520e-01,3.998960e-12,7.410286e-01,4.632906,1089.336670,137049400,57697.1431,183,183: PHOEBE,False,0.038095
21284,1.574288e-01,3.273697e-01,6.462521e-01,1.714669,229.678802,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000,137051059,57697.1431,181,M 181: dwarf_flares,False,0.000000
21285,5.445467e-01,-3.293925e-13,4.796387e-01,0.995758,374.421295,7.017638e-01,4.872597e-01,3.879600e-01,1.208195,200.605377,137054088,57697.1431,143,143: Iax,False,0.000014
21286,-2.347246e-02,-2.999360e-09,1.034537e+00,1.327719,1000.960266,-6.598237e-03,-5.268800e-10,9.953161e-01,3.317436,3378.102295,137071432,57697.1431,181,M 181: dwarf_flares,False,0.603388


In [18]:
if save_data:
    # Cache both feature tables so a later run can set
    # use_already_trained_features = True and skip the slow extraction.
    # The test table is written with its y_pred_score column included.
    features_dir = os.path.join(data_dir, sub_directory)
    # Create the folder first: without it to_csv raises and the results of
    # the long feature-extraction runs above would not be saved.
    os.makedirs(features_dir, exist_ok=True)
    train_features_df.to_csv(os.path.join(features_dir, "train_features.csv"))
    test_features_df.to_csv(os.path.join(features_dir, "test_features.csv"))