In [1]:
import os 
import numpy as np
import pickle 
from astropy.table import Table

In [2]:
# Directory holding the kilonova training and test FITS files. It defaults to the
# original location; set the KN_DATASETS_DIR environment variable to point the
# notebook at a local copy instead of editing four hard-coded paths.
DATASETS_DIR = os.environ.get(
    'KN_DATASETS_DIR', '/sps/lsst/users/bbiswas/data/kilonova_datasets'
)

# Each split comes as a header (HEAD, per-object metadata) and a photometry (PHOT) table.
train_data_head_path = os.path.join(DATASETS_DIR, 'train_final_master_HEAD.FITS')
train_data_phot_path = os.path.join(DATASETS_DIR, 'train_final_master_PHOT.FITS')
test_data_head_path = os.path.join(DATASETS_DIR, 'test_final_master_HEAD.FITS')
test_data_phot_path = os.path.join(DATASETS_DIR, 'test_final_master_PHOT.FITS')

In [3]:
# Read the training header (object metadata) and photometry FITS tables into pandas.
df_header, df_phot = (
    Table.read(fits_path, format='fits').to_pandas()
    for fits_path in (train_data_head_path, train_data_phot_path)
)

In [4]:
from kndetect.utils import load_pcs, get_event_type, get_data_dir_path
from kndetect.predict_features import extract_features_all_lightcurves
from kndetect.training import append_y_true_col

In [5]:
# Load the basis used for the per-band light-curve fits. Presumably these are the
# principal components (PCs) behind the coeff1..coeff3 feature columns — confirm in
# kndetect.utils.load_pcs.
pcs = load_pcs()

In [6]:
# Fit every training light curve in the g and r bands (~41 min for 22,280 objects).
train_features_df = extract_features_all_lightcurves(
    df_phot,
    "SNID",        # column that identifies each object
    pcs,
    [b'g', b'r'],  # bands to fit, given as bytes
)

100%|██████████| 22280/22280 [40:58<00:00,  9.06it/s] 


In [7]:
# Attach the y_true label: objects whose SNTYPE is one of the listed codes are the
# positive class. The test output labels 149 "KN GRANDMA" and 150 "KN GW170817";
# 151 is presumably another kilonova model — confirm against the dataset's type map.
train_features_df = append_y_true_col(
    features_df=train_features_df,
    meta_df=df_header,
    meta_key_col_name="SNID",
    meta_type_col_name="SNTYPE",
    prediction_type_nos=[149, 150, 151],
)

# Train the classifier

In [8]:
from kndetect.training import train_classifier

In [9]:
# Fit the classifier on the labelled training features.
# NOTE(review): `features_df` (second return value) is never used later in this
# notebook — presumably the feature table the model was actually fit on; confirm in
# kndetect.training.train_classifier.
# NOTE(review): no random seed is set here; if train_classifier is stochastic, the
# saved model is not reproducible across runs — confirm.
clf, features_df = train_classifier(train_features_df)



In [10]:
# Persist the trained classifier as complete.pkl in kndetect's data directory.
data_dir = get_data_dir_path()
classifier_path = os.path.join(data_dir, 'complete.pkl')
with open(classifier_path, 'wb') as classifier_file:
    pickle.dump(clf, classifier_file)

In [11]:
# Generate test features and predict kilonova probabilities

In [12]:
# NOTE(review): this repeats the get_data_dir_path() call already stored in `data_dir`,
# and `data_dir_path` is never referenced again in this notebook — candidate for removal.
data_dir_path = get_data_dir_path()

In [13]:

# Read the test header (object metadata) and photometry FITS tables into pandas.
df_header_test, df_phot_test = (
    Table.read(fits_path, format='fits').to_pandas()
    for fits_path in (test_data_head_path, test_data_phot_path)
)

In [14]:
# Build the test feature table the same way as the training one: fit the g and r
# bands of every light curve (~38 min for 21,288 objects), then attach y_true from the
# test header. Kept as two statements so the slow extraction result is not lost if
# labelling fails.
test_features_df = extract_features_all_lightcurves(
    df_phot_test,
    "SNID",        # column that identifies each object
    pcs,
    [b'g', b'r'],  # bands to fit, given as bytes
)
test_features_df = append_y_true_col(
    features_df=test_features_df,
    meta_df=df_header_test,
    meta_key_col_name="SNID",
    meta_type_col_name="SNTYPE",
    prediction_type_nos=[149, 150, 151],
)

100%|██████████| 21288/21288 [38:02<00:00,  9.33it/s] 


In [15]:
from kndetect.predict import load_classifier, predict_kn_score

In [16]:
# Reload the classifier from disk by file name. This presumably resolves "complete.pkl"
# against the same kndetect data directory the save cell wrote to — confirm in
# kndetect.predict.load_classifier.
# NOTE(review): if this unpickles the file, only load model files you produced or trust.
clf1=load_classifier("complete.pkl")

In [17]:
# Score every test object with the reloaded classifier.
# NOTE(review): `filtered_indices` is never used below — presumably it marks the rows
# the classifier actually scored; confirm in kndetect.predict.predict_kn_score.
probabilities, filtered_indices = predict_kn_score(clf1, test_features_df)

In [18]:
# Store column 1 of the (n_objects, n_classes) probability array as the score;
# presumably this is P(kilonova), i.e. the y_true == True class.
test_features_df['y_pred_score'] = probabilities[:, 1]

In [19]:
# Show the test features alongside true labels and predicted scores; pandas truncates
# the ~21k-row display to its head and tail.
test_features_df

Unnamed: 0,coeff1_g,coeff2_g,coeff3_g,residuo_g,maxflux_g,coeff1_r,coeff2_r,coeff3_r,residuo_r,maxflux_r,key,type,type_names,y_true,y_pred_score
0,1.732623e-09,-2.509101e-11,1.097553e+00,1.938472,3813.011456,4.975693e-09,-2.662738e-09,1.250420e+00,4.353821,5518.907806,1612,149,149: KN GRANDMA,True,0.751210
1,1.593189e+00,1.239720e-01,-2.053517e-01,5.750576,1215.978638,1.102308e+00,-3.386281e-09,-8.771612e-09,9.126114,1900.364746,10871,162,162: ILOT,False,0.000000
2,-1.025571e-09,-3.375859e-09,1.044571e+00,1.317318,764.863892,-1.401216e-09,-3.309492e-09,1.104302e+00,0.919006,805.522644,10872,150,150: KN GW170817,True,0.816085
3,-4.666286e-01,-1.631006e-09,1.739795e+00,13.765211,2329.209717,3.035286e-06,1.190303e-09,7.239731e-01,11.834150,1588.494263,11422,180,180: RRLyrae,False,0.033333
4,-2.676815e-01,-6.766489e-10,8.081538e-01,41.273416,20902.640625,-1.497712e-01,2.239404e-01,8.265010e-01,35.546320,11635.100586,13390,180,180: RRLyrae,False,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21283,-3.042166e-09,1.438937e+00,1.222298e+00,4.003843,1325.314697,1.071865e+00,1.785420e-01,1.248190e-09,1.808287,1154.043457,137049400,183,183: PHOEBE,False,0.066667
21284,1.925799e-01,3.392688e-01,6.042987e-01,1.606676,229.678802,-1.666384e-09,-1.315977e-09,8.413959e-01,1.900921,785.573853,137051059,181,M 181: dwarf_flares,False,0.626033
21285,7.285402e-01,-4.506676e-10,2.336910e-01,1.023906,374.421295,1.046572e+00,-1.830059e-09,-1.155196e-08,1.273695,341.236908,137054088,143,143: Iax,False,0.000000
21286,-1.846909e-03,-4.560111e-09,9.970787e-01,1.367441,1000.960266,-2.760333e-09,-3.718945e-09,9.958187e-01,2.048643,3378.102295,137071432,181,M 181: dwarf_flares,False,0.809916
