In [1]:
import pandas as pd
import numpy as np
import os
from copy import deepcopy

from actsnfink import *
from sklearn.ensemble import RandomForestClassifier

from sklearn.model_selection import cross_validate
from sklearn.metrics import precision_score
import pickle

In [2]:
input_dir = '/media/ELAsTICC/Fink/first_year/early_SNIa/all_features/'

flist = os.listdir(input_dir)
flist.remove('.ipynb_checkpoints')

data_list = []
for fname in flist:
    data_temp = pd.read_csv(input_dir + fname, index_col=False)
    
    col_remove = []
    for colname in data_temp.keys():
        if 'Unnamed' in colname:
            col_remove.append(colname)
            
    if len(col_remove) > 0:
        data_temp.drop(columns=col_remove, inplace=True)

    data_list.append(data_temp)
    
    print(fname, len(data_temp.keys()), data_temp.shape[0], np.unique(data_temp['classId'].values))
    
data_pd = pd.concat(data_list, ignore_index=True)

print('Total: ', data_pd.shape[0])

class_111.csv 48 2507323 [111]
class_131.csv 48 170026 [131]
class_212.csv 48 2230557 [212]
class_214.csv 48 1301841 [214]
class_124.csv 48 12923 [124]
class_123.csv 48 4023 [123]
class_135.csv 48 2761 [135]
class_213.csv 48 429117 [213]
class_221.csv 48 453542 [221]
class_113.csv 48 2451398 [113]
class_115.csv 48 49795 [115]
class_133.csv 48 4043 [133]
class_132.csv 48 30705 [132]
class_121.csv 48 5 [121]
class_134.csv 48 12833 [134]
class_211.csv 48 47713 [211]
class_112.csv 48 583007 [112]
class_114.csv 48 60497 [114]
Total:  10352109


In [3]:
features_names_rep = list(data_pd.keys())
for name in ['diaObjectId', 'alertId', 'mwebv_err', 'mwebv_err.1']:
    features_names_rep.remove(name)

In [4]:
data_unique = data_pd.drop_duplicates(subset=features_names_rep, keep='first')

# separate train and test per object
objects = np.unique(data_unique['diaObjectId'].values)

objects_train = np.random.choice(objects, size=int(len(objects)/2), replace=False)
train_flag = np.isin(data_unique['diaObjectId'].values, objects_train)

data_train_all = data_unique[train_flag]
data_test_all = data_unique[~train_flag]

# separate ias and nonias
train_ia_flag = data_train_all['classId'].values == 111
data_train_ia = data_train_all[train_ia_flag]
data_train_others = data_train_all[~train_ia_flag]

test_ia_flag = data_test_all['classId'].values == 111
data_test_ia = data_test_all[test_ia_flag]
data_test_others = data_test_all[~test_ia_flag]

In [16]:
ntrain = 500000
ntest = min(1000000, data_test_ia.shape[0])

# construct a sample 50/50  for Ia/others
data_train_use = pd.concat([data_train_ia.sample(n=ntrain, replace=False),
                           data_train_others.sample(n=ntrain, replace=False)], 
                           ignore_index=True)
data_train_use = data_train_use.sample(frac=1, replace=False)

data_test_use = pd.concat([data_test_ia.sample(n=ntest, replace=False),
                           data_test_others.sample(n=ntest, replace=False)],
                         ignore_index=True)
data_test_use = data_test_use.sample(frac=1, replace=False)

In [17]:
data_train_use.drop(columns=['mwebv_err', 'mwebv_err.1'], inplace=True)
data_test_use.drop(columns=['mwebv_err', 'mwebv_err.1'], inplace=True)

In [18]:
data_train_features = deepcopy(data_train_use[features_names_rep[1:]])
data_train_labels = data_train_use['classId'].values == 111

data_test_features = deepcopy(data_test_use[features_names_rep[1:]])
#data_test_features = deepcopy(data_test_use[features_use])

data_test_labels = data_test_use['classId'].values == 111

In [8]:
data_train_use.to_csv('/media/ELAsTICC/Fink/first_year/early_SNIa/final_model/train.csv', 
                      index=False)

In [9]:
data_test_use.to_csv('/media/ELAsTICC/Fink/first_year/early_SNIa/final_model/test.csv',
                       index=False)

In [19]:
data_test_use.shape

(1932828, 46)

In [42]:
nest = 50
seed = 42
max_depth = 50
n_jobs = 20
min_samples_leaf=0.00001

clf = RandomForestClassifier(n_estimators=nest, random_state=seed,
                             max_depth=max_depth, n_jobs=n_jobs, 
                             min_samples_leaf=min_samples_leaf)
clf.fit(data_train_features, data_train_labels)

RandomForestClassifier(max_depth=50, min_samples_leaf=1e-05, n_estimators=50,
                       n_jobs=20, random_state=42)

In [49]:
pred = clf.predict(data_test_features)

In [50]:
data_test_features.shape

(1932828, 43)

In [51]:
sum(data_test_labels)

966414

In [52]:
clf.score(data_train_features, data_train_labels)

0.905351

In [53]:
clf.score(data_test_features, data_test_labels)

0.8212381029248335

In [54]:
sum(data_test_labels[pred == 1])/sum(pred)

0.783404812374364

In [55]:
filename = '../data/earlysnia_elasticc_small.pkl'
pickle.dump(clf, open(filename, 'wb'))

In [16]:
np.array(features_names_rep[1:])[clf.feature_importances_.argsort()]

array(['nrise_u', 'nrise_g', 'mse_u', 'a_u', 'b_u', 'mse_g', 'nrise_Y',
       'a_g', 'b_g', 'nrise_z', 'c_u', 'mse_Y', 'nrise_r', 'c_g',
       'snratio_u', 'a_Y', 'b_Y', 'mse_z', 'snratio_g', 'nrise_i', 'c_Y',
       'a_z', 'c_z', 'b_z', 'mse_r', 'decl', 'ra', 'snratio_z', 'mse_i',
       'snratio_i', 'snratio_r', 'b_r', 'snratio_Y', 'a_i', 'c_i', 'a_r',
       'c_r', 'b_i', 'hostgal_dec', 'hostgal_snsep', 'hostgal_ra',
       'hostgal_zphot_err', 'hostgal_zphot'], dtype='<U17')

### Crossvalidate

In [18]:
scoring = ['precision_macro', 'recall_macro']

In [20]:
nest = 30
max_depth = 30
n_jobs = 20
cv=10

clf = RandomForestClassifier(n_estimators=nest, random_state=seed,
                             max_depth=max_depth)
scores = cross_validate(clf, data_train_use[list(data_train_use.keys())[3:]], 
                        data_train_use['classId'].values == 111, scoring=scoring,
                       return_train_score=True, cv=cv)

In [21]:
scores['test_precision_macro']

array([0.85566007, 0.85508283, 0.85397706, 0.85309592, 0.85130072,
       0.85452727, 0.85328974, 0.85552044, 0.85523075, 0.8555537 ])

In [22]:
scores['train_precision_macro']

array([0.97005914, 0.96994843, 0.97106427, 0.97168394, 0.97099679,
       0.9689634 , 0.97081157, 0.97067437, 0.96986394, 0.9707479 ])