In [1]:
from pathlib import Path
import os
import sys
if str(Path.cwd().parent) not in sys.path:
    sys.path.append(str(Path.cwd().parent))
    
import warnings
from cycler import cycler
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedKFold

from settings.paths import match_path, validation_path
from settings.columns import specz, calculate_colors, create_colors, list_feat
from utils.metrics import nmad, bias ,out_frac, rmse
from utils.preprocessing import create_bins, rename_aper, mag_redshift_selection, prep_wise, missing_input, flag_observation
from utils.crossvalidation import xval_results, save_folds


plt.rcParams["font.size"] = 22
blue = (0, 0.48, 0.70)
orange = (230/255,159/255, 0)
yellow = (0.94, 0.89, 0.26)
pink = (0.8, 0.47, 0.65)
CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a',
                  '#f781bf', '#a65628', '#984ea3',
                  '#999999', '#e41a1c', '#dede00']
plt.rcParams['axes.prop_cycle'] = cycler('color', CB_color_cycle)

warnings.filterwarnings("ignore")
%load_ext autoreload
%autoreload 2

In [2]:
from utils.correct_extinction import correction
aper = "PStotal"

In [3]:
# table_dr2 = pd.read_table(os.path.join(match_path,"STRIPE82_DR2_DR16Q1a_unWISE2a_GALEXDR672a.csv"), sep=",")
# table_dr2 = rename_aper(table_dr2)

In [3]:
table_dr4 = pd.read_table(os.path.join(match_path,"STRIPE82_DR4_DR16Q1a_unWISE2a_GALEXDR672a.csv"), sep=",")
table_dr4 = rename_aper(table_dr4)

In [5]:
feat_mag = list_feat(aper=aper, broad = True, narrow = True, galex = True, wise = True)
feat = create_colors(broad = True, narrow=True, wise=True, galex=True, aper=aper)
# data = table_corrected.copy(deep=True)

In [10]:
data = mag_redshift_selection(table_dr4, rmax=22, zmax=5)
data = prep_wise(data)

In [14]:
from settings.columns import wise, galex, splus, aper, calculate_colors, specz, error_splus


In [19]:
data[["u_PStotal", "e_u_PStotal"]]

Unnamed: 0,u_PStotal,e_u_PStotal
0,22.462389,0.637063
1,25.834753,13.570198
2,99.000000,99.000000
3,22.812483,0.915610
4,23.874598,2.208114
...,...,...
37973,20.698160,0.159381
37974,19.455515,0.052979
37975,20.565550,0.129557
37976,20.754107,0.156790


In [28]:
data = missing_input(data, 0.5)

In [38]:

data = flag_observation(data)
data = correction(data)
data = missing_input(data)
data = calculate_colors(data, broad = True, narrow= True, wise = True, galex= True, aper=aper)
data, bins, itvs = create_bins(data = data, bin_size=0.5, return_data = True, var = specz)

In [72]:
# Did two steps of train_test_split just to ensure that some specific targets (9) are in the test set
# These 9 targets were originally used to plot some PDFs in thr first draft of the paper

# sample = pd.read_csv(os.path.join(data_path, "sample_plot_paper_9.txt"), delim_whitespace=True)
# list_ids = sample["ID_1"].apply(lambda x: x.split(" ")[0].split("-")[-1])
# check=False
# k=0
# while check==False:
#     train, test = train_test_split(data, test_size=0.5, random_state=823, stratify=data['Zclass'])
#     train2, test2 = train_test_split(test, test_size=0.5, random_state=124, stratify=test['Zclass'])
#     list_ids_test = test2["ID"].apply(lambda x: x.split(" ")[0].split("-")[-1])
#     compare = [True if i in list_ids_test.values else False for i in list_ids.values]
#     print(compare)
#     check = all(compare)
#     k+=1

[True, True, True, True, True, True, True, True, True]


In [39]:
from utils.preprocessing import split_data
train, test = split_data(data)

In [81]:
# list_ids_test = test["ID"].apply(lambda x: x.split(" ")[0].split("-")[-1])
# compare = [True if i in list_ids_test.values else False for i in list_ids.values]

In [44]:
# train, test = train_test_split(data, random_state=22, stratify=data['Zclass'])
zclass_train = train["Zclass"]

In [17]:
#number of objects with non-observation in GALEX
len(data[data["name"].isna()])

#number of objects with non-observation in WISE
len(data[data["objID_x"].isna()])

2005

In [18]:
flags = ["flag_WISE", "flag_GALEX"]

In [41]:
feat = {}
feat["broad"] =  create_colors(broad = True, narrow=False, wise=False, galex=False, aper=aper)
feat["broad"] = feat["broad"]
feat["broad+narrow"]=  create_colors(broad = True, narrow=True, wise=False, galex=False, aper=aper)
feat["broad+narrow"]=   feat["broad+narrow"]

feat["broad+GALEX+WISE"]=  create_colors(broad = True, narrow=False, wise=True, galex=True, aper=aper)

feat["broad+WISE+narrow"]=  create_colors(broad = True, narrow=True, wise=False, galex=True, aper=aper)

feat["broad+GALEX+WISE+narrow"]=  create_colors(broad = True, narrow=True, wise=True, galex=True, aper=aper)


In [42]:
color_feat = {}
color_feat["broad"] = "#377eb8"
color_feat["broad+narrow"] =  "#ff7f00"
color_feat["broad+GALEX+WISE"]= "#4daf4a"
color_feat["broad+GALEX+WISE+narrow"]= "#f781bf"


In [45]:
# skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=47)
# i=0
# for train_index, val_index in skf.split(train, zclass_train):
#     mag_train_cv, mag_val_cv = train.iloc[train_index], train.iloc[val_index]
#     mag_train_cv.index.name = "index"
#     mag_val_cv.index.name = "index"
#     mag_train_cv.to_csv(os.path.join(validation_path,"trainf"+str(i)+"_latest.csv"), sep=",")
#     mag_val_cv.to_csv(os.path.join(validation_path,"valf"+str(i)+"_latest.csv"), sep=",")
#     i = i+1


save_folds(train, zclass_train)

test.index.name = 'index'
test.to_csv(os.path.join(validation_path, "test_error_replaced.csv"))

In [21]:
z={}

In [31]:
for key, value in feat.items():
    print(key)
    z[key] = xval_results(feat[key], key, save_model=False, save_result=False)


broad
broad+narrow
broad+GALEX+WISE
broad+GALEX+WISE+flags
broad+WISE+narrow
broad+WISE+narrow+flags
broad+GALEX+WISE+narrow
broad+GALEX+WISE+narrow+flags


In [43]:
from utils.metrics import print_metrics
for key, value in feat.items():
    print(key)
    print_metrics(z[key], xval=True)

broad
RMSE 0.6451 0.0056
NMAD 0.2154 0.0045
bias 0.0013 0.0052
n30 0.2259 0.0035
n15 0.4886 0.0067
broad+narrow
RMSE 0.576 0.0032
NMAD 0.1809 0.0029
bias 0.0023 0.0029
n30 0.1874 0.0032
n15 0.4304 0.0047
broad+GALEX+WISE
RMSE 0.4245 0.0102
NMAD 0.1027 0.0025
bias -0.0014 0.0063
n30 0.0704 0.0023
n15 0.2272 0.0037
broad+GALEX+WISE+flags
RMSE 0.4245 0.0102
NMAD 0.1028 0.0026
bias -0.0015 0.0063
n30 0.0702 0.0017
n15 0.2272 0.0039
broad+WISE+narrow
RMSE 0.5334 0.0045
NMAD 0.1424 0.0039
bias 0.0026 0.0053
n30 0.1449 0.0054
n15 0.355 0.0024
broad+WISE+narrow+flags
RMSE 0.5267 0.0047
NMAD 0.1406 0.0023
bias 0.0027 0.0048
n30 0.1419 0.0039
n15 0.3477 0.0034
broad+GALEX+WISE+narrow
RMSE 0.4102 0.0093
NMAD 0.0931 0.0019
bias 0.0013 0.0069
n30 0.067 0.0016
n15 0.2155 0.0042
broad+GALEX+WISE+narrow+flags
RMSE 0.4101 0.0092
NMAD 0.0932 0.0018
bias 0.0012 0.0069
n30 0.0671 0.0015
n15 0.2164 0.0044


In [46]:
# Results when errors > 0.5 are replaced by 99

In [49]:
import pickle
from settings.paths import rf_path
dict_gridsearch= pickle.load(open(os.path.join(rf_path,'GridSearch_broad+GALEX+WISE+narrow+flags.sav'), 'rb')).best_params_
    

In [51]:
z={}
for key, value in feat.items():
    print(key)
    z[key] = xval_results(feat[key], key, dict_gridsearch, save_model=False, save_result=False)


broad
broad+narrow
broad+GALEX+WISE
broad+WISE+narrow
broad+GALEX+WISE+narrow


In [53]:
from utils.metrics import print_metrics_xval
for key, value in feat.items():
    print(key)
    print_metrics_xval(z[key])

broad
RMSE 0.6464 0.0057
NMAD 0.2174 0.0035
bias 0.001 0.0053
n15 0.4921 0.0073
n30 0.2249 0.003
broad+narrow
RMSE 0.5752 0.003
NMAD 0.18 0.0026
bias 0.0007 0.0029
n15 0.4281 0.0034
n30 0.1857 0.0034
broad+GALEX+WISE
RMSE 0.4244 0.0103
NMAD 0.1021 0.0025
bias -0.001 0.0059
n15 0.228 0.0057
n30 0.0703 0.0016
broad+WISE+narrow
RMSE 0.5337 0.0045
NMAD 0.1426 0.0028
bias 0.0021 0.0051
n15 0.3548 0.0031
n30 0.145 0.0036
broad+GALEX+WISE+narrow
RMSE 0.4099 0.0094
NMAD 0.0927 0.0021
bias 0.0009 0.0071
n15 0.2169 0.0038
n30 0.0665 0.0014
