In [2]:
from pathlib import Path
import os
import sys
# Put the parent of the current working directory on the import path (only once).
# Presumably this is where the project-local `settings` and `utils` packages
# imported below live -- TODO confirm against the repository layout.
if str(Path.cwd().parent) not in sys.path:
    sys.path.append(str(Path.cwd().parent))
    
import warnings
from cycler import cycler
import pandas as pd
import matplotlib.pyplot as plt
from settings.paths import  validation_path, rf_path, bmdn_path, flex_path, match_path
from utils.metrics import print_metrics_xval, print_metrics_test
from utils.preprocessing import rename_aper, prep_wise, missing_input, mag_redshift_selection, flag_observation


# Global matplotlib font size applied to every figure drawn after this cell.
plt.rcParams["font.size"] = 22
# Named colours as RGB tuples with components in the 0-1 range.
# They are not referenced in the cells visible here.
blue = (0, 0.48, 0.70)
orange = (230/255,159/255, 0)
yellow = (0.94, 0.89, 0.26)
pink = (0.8, 0.47, 0.65)
# Hex colour list installed below as the default line-colour cycle
# (presumably a colour-blind-friendly palette, going by the `CB_` prefix).
CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a',
                  '#f781bf', '#a65628', '#984ea3',
                  '#999999', '#e41a1c', '#dede00']
plt.rcParams['axes.prop_cycle'] = cycler('color', CB_color_cycle)

# Silences every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# IPython autoreload in mode 2: re-import modules before each cell execution,
# so edits to the project-local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import pickle

# Display the `best_params_` stored in the saved search object for the
# broad+GALEX+WISE+narrow feature set.
# NOTE(security): unpickling can execute arbitrary code; only load files that
# were produced by this project.
grid_search_file = os.path.join(rf_path, 'GridSearch_broad+GALEX+WISE+narrow.sav')
# The context manager closes the file handle, which the bare `open(...)` call left open.
with open(grid_search_file, 'rb') as handle:
    grid_search = pickle.load(handle)
# Last expression of the cell, so the notebook still renders the parameter dict.
grid_search.best_params_


{'bootstrap': True,
 'max_depth': None,
 'min_samples_leaf': 2,
 'min_samples_split': 2,
 'n_estimators': 400}

In [5]:
# Cross-matched catalogue, passed through the project preprocessing helpers
# in the same order as before (each helper receives the previous result).
catalog_file = os.path.join(match_path, "STRIPE82_DR4_DR16Q1a_unWISE2a_GALEXDR672a.csv")
data = (
    pd.read_table(catalog_file, sep=",")
    .pipe(mag_redshift_selection, rmax=22, zmax=5)
    .pipe(prep_wise)
    .pipe(flag_observation)
    # .pipe(correction)  # disabled step, kept for reference
    .pipe(missing_input)
)

# Test set first, then train set; both are indexed by their "index" column.
test, train = (
    pd.read_csv(os.path.join(validation_path, split_file), index_col="index")
    for split_file in ("test.csv", "train.csv")
)

In [6]:
# def fraction_of_non_observation(df, survey="WISE"):
#     # if survey=="WISE":
#     #     print(len(df[df["objID_x"].isna()])/len(df))
#     #     return 
    
#     if survey=="GALEX":
#         print(len(df[df["name"].isna()])/len(df))
#         return 
    
def fraction_of_detection(df, survey="WISE"):
    """Print the fraction of rows of ``df`` with a measurement in a survey's bands.

    A band counts as detected when its magnitude column differs from the
    sentinel value 99 (a NaN entry also compares as "different from 99").

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the magnitude columns needed for ``survey``:
        ``W1``/``W2`` for "WISE", ``FUVmag``/``NUVmag`` for "GALEX",
        and all four for "both".
    survey : str, default "WISE"
        One of "WISE", "GALEX" or "both". Any other value prints nothing.

    Returns
    -------
    None
        The fractions are only printed. For an empty ``df`` every fraction is
        reported as ``nan`` instead of raising ``ZeroDivisionError``.
    """
    n_rows = len(df)

    def _fraction(mask):
        # An empty table has no meaningful fraction; report nan rather than divide by zero.
        return int(mask.sum()) / n_rows if n_rows else float("nan")

    if survey == "WISE":
        has_w1 = df["W1"] != 99
        print("W1:", _fraction(has_w1))
        has_w2 = df["W2"] != 99
        print("W2:", _fraction(has_w2))
        print("W1 and W2:", _fraction(has_w1 & has_w2))
    elif survey == "GALEX":
        has_fuv = df["FUVmag"] != 99
        print("FUV:", _fraction(has_fuv))
        has_nuv = df["NUVmag"] != 99
        print("NUV:", _fraction(has_nuv))
        print("FUV & NUV:", _fraction(has_fuv & has_nuv))
    elif survey == "both":
        has_all_bands = (
            (df["W1"] != 99)
            & (df["W2"] != 99)
            & (df["FUVmag"] != 99)
            & (df["NUVmag"] != 99)
        )
        print("W1, W2, FUV & NUV:", _fraction(has_all_bands))

In [7]:
# Detection fractions over the full spectroscopic sample, one survey setting at a time.
# Swap `data` for `train` or `test` to inspect an individual split instead.
for survey_name in ("WISE", "GALEX", "both"):
    fraction_of_detection(data, survey=survey_name)

print("Total # of spec sample:", len(data))

W1: 0.9307713191155621
W2: 0.8586769629875418
W1 and W2: 0.8499291122439745
FUV: 0.11055473439715242
NUV: 0.3080148411812615
FUV & NUV: 0.11055473439715242
W1, W2, FUV & NUV: 0.10910681427407921


In [6]:
# NOTE(review): in the cells visible here, `val` is only assigned inside the
# validation-file loops further down, so this cell presumably raises NameError
# on a fresh kernel run from top to bottom -- confirm, then move or remove it.
val.W1

index
15907    14.280789
4418     17.278070
7135     16.536206
9067     16.919188
272      15.735855
           ...    
34673    17.464103
37753    16.813281
4682     16.408280
27694    15.543058
16841    16.887201
Name: W1, Length: 4972, dtype: float64

In [7]:
import numpy as np
from settings.columns import list_feat
# Index lookups per validation file, keyed by the file name up to its first dot.
id_only_splus, id_wise_galex = {}, {}
id_wise, id_galex = {}, {}
id_not_wise, id_not_galex = {}, {}
id_complete_case = {}
feat_mag = list_feat(broad=True, narrow=True, galex=True, wise=True)

for file in np.sort(os.listdir(validation_path)):
    # Only the validation CSV files are of interest here.
    if not (file.startswith("val") and file.endswith(".csv")):
        continue

    key = file.split(".")[0]
    print(file)
    val = pd.read_csv(os.path.join(validation_path, file), index_col="index")
    print(val.W1)

    # Rows matched to neither WISE nor GALEX (both identifier columns are missing)
    idx = val.loc[val["objID_x"].isna() & val["name"].isna()].index
    print("Without WISE and GALEX (only S-PLUS):", len(idx))

    # WISE: sentinel 99 in both W1 and W2 versus everything else
    wise_missing = val.query("W1==99 and W2==99").index
    wise_present = val.drop(wise_missing).index
    id_not_wise[key] = wise_missing
    id_wise[key] = wise_present
    print("Without WISE in W1 and W2 (with WISE in at least one band):", len(wise_missing), "(", len(wise_present), ")")

    # GALEX: sentinel 99 in both FUV and NUV versus everything else
    galex_missing = val.query("FUVmag == 99 and NUVmag ==99 ").index
    galex_present = val.drop(galex_missing).index
    id_not_galex[key] = galex_missing
    id_galex[key] = galex_present
    print("Without GALEX in NUV and FUV (with GALEX in at least one band):", len(galex_missing), "(", len(galex_present), ")")

    id_only_splus[key] = idx
    id_wise_galex[key] = val.drop(idx).index

    # Complete case: every magnitude feature is below 50
    complete_rows = val[feat_mag].lt(50).all(axis=1)
    id_complete_case[key] = val.loc[complete_rows].index
    print("Have valid measurements in ALL bands:", len(id_complete_case[key]))
    print("-----------")

valf0.csv
index
33650    17.924442
26140    16.393015
20290    16.098439
9177     17.374940
13119    17.795592
           ...    
7007     15.844301
10239    15.933164
15507    14.824965
9157     17.098956
32165    18.129544
Name: W1, Length: 4973, dtype: float64
Without WISE and GALEX (only S-PLUS): 277
Without WISE in W1 and W2 (with WISE in at least one band): 301 ( 4672 )
Without GALEX in NUV and FUV (with GALEX in at least one band): 3427 ( 1546 )
Have valid measurements in ALL bands: 543
-----------
valf0_error_replaced.csv
index
33650    17.924442
26140    16.393015
20290    16.098439
9177     17.374940
13119    17.795592
           ...    
7007     15.844301
10239    15.933164
15507    14.824965
9157     17.098956
32165    18.129544
Name: W1, Length: 4973, dtype: float64
Without WISE and GALEX (only S-PLUS): 277
Without WISE in W1 and W2 (with WISE in at least one band): 301 ( 4672 )
Without GALEX in NUV and FUV (with GALEX in at least one band): 3427 ( 1546 )
Have valid measur

In [10]:
import numpy as np
from settings.columns import list_feat
# Index lookups per validation file, keyed by the file name up to its first dot.
id_only_splus, id_wise_galex = {}, {}
id_wise, id_galex = {}, {}
id_not_wise, id_not_galex = {}, {}
id_complete_case = {}
feat_mag = list_feat(broad=True, narrow=True, galex=True, wise=True)

for file in np.sort(os.listdir(validation_path)):
    # Only the validation CSV files are of interest here.
    if not (file.startswith("val") and file.endswith(".csv")):
        continue

    key = file.split(".")[0]
    print(file)

    val = pd.read_csv(os.path.join(validation_path, file), index_col="index")
    print(len(val))

    # Missing identifier columns mark rows without a WISE / GALEX entry
    lacks_wise = val["objID_x"].isna()
    lacks_galex = val["name"].isna()

    idx = val.loc[lacks_wise & lacks_galex].index  # without WISE AND GALEX
    print("Without WISE and GALEX (only S-PLUS):", len(idx))

    wise_missing = val.loc[lacks_wise].index
    wise_present = val.drop(wise_missing).index
    id_not_wise[key] = wise_missing
    id_wise[key] = wise_present
    print("Without WISE in W1 and W2 (with WISE in at least one band):", len(wise_missing), "(", len(wise_present), ")")

    galex_missing = val.loc[lacks_galex].index
    galex_present = val.drop(galex_missing).index
    id_not_galex[key] = galex_missing
    id_galex[key] = galex_present
    print("Without GALEX in NUV and FUV (with GALEX in at least one band):", len(galex_missing), "(", len(galex_present), ")")

    id_only_splus[key] = idx
    id_wise_galex[key] = val.drop(idx).index

    # Complete case: every magnitude feature is below 50
    complete_rows = val[feat_mag].lt(50).all(axis=1)
    id_complete_case[key] = val.loc[complete_rows].index
    print("Have valid measurements in ALL bands:", len(id_complete_case[key]))
    print("-----------")

valf0.csv
4973
Without WISE and GALEX (only S-PLUS): 277
Without WISE in W1 and W2 (with WISE in at least one band): 301 ( 4672 )
Without GALEX in NUV and FUV (with GALEX in at least one band): 3427 ( 1546 )
Have valid measurements in ALL bands: 543
-----------
valf0_error_replaced.csv
4973
Without WISE and GALEX (only S-PLUS): 277
Without WISE in W1 and W2 (with WISE in at least one band): 301 ( 4672 )
Without GALEX in NUV and FUV (with GALEX in at least one band): 3427 ( 1546 )
Have valid measurements in ALL bands: 543
-----------
valf1.csv
4973
Without WISE and GALEX (only S-PLUS): 274
Without WISE in W1 and W2 (with WISE in at least one band): 297 ( 4676 )
Without GALEX in NUV and FUV (with GALEX in at least one band): 3414 ( 1559 )
Have valid measurements in ALL bands: 532
-----------
valf1_error_replaced.csv
4973
Without WISE and GALEX (only S-PLUS): 274
Without WISE in W1 and W2 (with WISE in at least one band): 297 ( 4676 )
Without GALEX in NUV and FUV (with GALEX in at least o

In [11]:

# Metrics from crossvalidation
print("---RF---")
# os.listdir returns entries in arbitrary order; sorting makes the report order
# reproducible, in line with the sorted validation-file loops elsewhere in this notebook.
for file in sorted(os.listdir(rf_path)):
    if file.startswith("val"):
        print("-----")
        # Feature-set label: the text after the last "z_", with the file extension
        # removed (splitext instead of a hard-coded 4-character slice).
        print(os.path.splitext(file.split("z_")[-1])[0])
        print("-----")
        results = pd.read_csv(os.path.join(rf_path, file), index_col="index")

        # UNCOMMENT THE LINES OF INTEREST
        # The optional second argument is one of the per-file index dictionaries
        # built in the validation-subset cell.
        print("Complete sample")
        print_metrics_xval(results)
        # print("Complete-case scenario (no missing values)")
        # print_metrics_xval(results, id_complete_case)
        # print("only S-PLUS sample")
        # print_metrics_xval(results, id_only_splus)
        # print("S-PLUS+WISE+GALEX sample")
        # print_metrics_xval(results, id_wise_galex)
        # print("with WISE")
        # print_metrics_xval(results, id_wise)
        # print("without WISE")
        # print_metrics_xval(results, id_not_wise)
        # print("with GALEX")
        # print_metrics_xval(results, id_galex)
        # print("without GALEX")
        # print_metrics_xval(results, id_not_galex)


            

---RF---
-----
broad+GALEX+WISE
-----
Complete sample
RMSE 0.4248 0.0103
NMAD 0.102 0.0023
bias -0.001 0.0062
n15 0.2283 0.0047
n30 0.0704 0.0016
-----
broad+WISE+narrow
-----
Complete sample
RMSE 0.5338 0.0049
NMAD 0.1414 0.0031
bias 0.002 0.0052
n15 0.3563 0.003
n30 0.1451 0.0039
-----
broad+narrow
-----
Complete sample
RMSE 0.5755 0.0031
NMAD 0.1795 0.002
bias 0.0012 0.0029
n15 0.4283 0.004
n30 0.1848 0.0027
-----
broad+GALEX+WISE+narrow
-----
Complete sample
RMSE 0.4097 0.0095
NMAD 0.0931 0.0019
bias 0.0008 0.0066
n15 0.2167 0.0054
n30 0.0662 0.0011
-----
broad
-----
Complete sample
RMSE 0.6468 0.006
NMAD 0.2173 0.0031
bias 0.0011 0.0049
n15 0.4921 0.0046
n30 0.2248 0.0027


In [18]:
# Random Forest
# Predictions for the two feature sets; the first CSV column becomes the index.
rf_all = pd.read_csv(os.path.join(rf_path,"test_z_broad+GALEX+WISE+narrow.csv"), index_col=0)
rf_broad = pd.read_csv(os.path.join(rf_path,"test_z_broad+GALEX+WISE.csv"), index_col=0)

# BMDN
# NOTE(review): the directory names say "crossval", yet these frames are compared
# against the test set in a later cell -- confirm they hold test-set predictions.
bmdn_all = pd.read_csv(os.path.join(bmdn_path,"crossval_model_dr4_BNWG", "Results_DF.csv"))
bmdn_broad = pd.read_csv(os.path.join(bmdn_path,"crossval_model_dr4_BWG", "Results_DF.csv"))
# The index is overwritten positionally with the RF index: this assumes the BMDN rows
# are in the same order as the RF rows -- TODO confirm. Only a length mismatch would raise.
bmdn_all.index = rf_all.index
bmdn_broad.index = rf_broad.index
# FlexCoDE
flex_all = pd.read_csv(os.path.join(flex_path,"test_z_broad+GALEX+WISE+narrow.csv"))
flex_broad = pd.read_csv(os.path.join(flex_path,"test_z_broad+GALEX+WISE.csv"))
# Same positional relabelling as for BMDN, with the same row-order assumption.
flex_all.index = rf_all.index
flex_broad.index = rf_broad.index


In [12]:
# Metrics from testing set
print("---RF---")

# Test objects whose WISE identifier column (`objID_x`) is missing
id_no_wise = test.loc[test["objID_x"].isna()].index
for title, rf_pred in (
    ("broad+GALEX+WISE - no WISE (W1 and W2)", rf_broad),
    ("broad+GALEX+WISE+narrow - no WISE (W1 and W2)", rf_all),
):
    print(title)
    print_metrics_test(
        test.loc[id_no_wise, "Z"].to_numpy(),
        rf_pred.loc[id_no_wise, "z_pred"].to_numpy(),
    )

# Test objects whose GALEX identifier column (`name`) is missing
id_no_galex = test.loc[test["name"].isna()].index
for title, rf_pred in (
    ("broad+GALEX+WISE - no GALEX (FUV and NUV)", rf_broad),
    ("broad+GALEX+WISE+narrow - no GALEX (FUV and NUV)", rf_all),
):
    print(title)
    print_metrics_test(
        test.loc[id_no_galex, "Z"].to_numpy(),
        rf_pred.loc[id_no_galex, "z_pred"].to_numpy(),
    )

---RF---
broad+GALEX+WISE - no WISE (W1 and W2)
RMSE 0.6115
NMAD 0.1535
bias -0.0115
n15 0.3286
n30 0.127
broad+GALEX+WISE+narrow - no WISE (W1 and W2)
RMSE 0.5951
NMAD 0.1352
bias -0.0123
n15 0.3327
n30 0.1109
broad+GALEX+WISE - no GALEX (FUV and NUV)
RMSE 0.472
NMAD 0.1145
bias -0.0035
n15 0.2623
n30 0.082
broad+GALEX+WISE+narrow - no GALEX (FUV and NUV)
RMSE 0.4556
NMAD 0.1024
bias 0.0027
n15 0.2555
n30 0.0789


In [28]:
# Metrics from testing set
# One (header, prediction without narrow bands, prediction with narrow bands) row per model.
model_predictions = (
    ("---RF---", rf_broad.z_pred, rf_all.z_pred),
    ("---FlexCoDE---", flex_broad.z_flex_peak, flex_all.z_flex_peak),
    ("---BMDN---", bmdn_broad.zphot, bmdn_all.zphot),
)
for header, pred_broad, pred_all in model_predictions:
    print(header)
    print("Without narrow bands")
    print_metrics_test(test.Z.to_numpy(), pred_broad.to_numpy())
    print("With narrow bands")
    print_metrics_test(test.Z.to_numpy(), pred_all.to_numpy())


# Unweighted mean of the three point estimates
print("---Average---")
print("Without narrow bands")
aver_broad = rf_broad.z_pred.add(flex_broad.z_flex_peak).add(bmdn_broad.zphot).div(3)
print_metrics_test(test.Z.to_numpy(), aver_broad.to_numpy())
print("With narrow bands")
aver_all = rf_all.z_pred.add(flex_all.z_flex_peak).add(bmdn_all.zphot).div(3)
print_metrics_test(test.Z.to_numpy(), aver_all.to_numpy())


---RF---
Without narrow bands
RMSE 0.4229
NMAD 0.1003
bias -0.0019
n15 0.2245
n30 0.0681
With narrow bands
RMSE 0.4084
NMAD 0.0903
bias 0.0028
n15 0.2198
n30 0.0656
---FlexCoDE---
Without narrow bands
RMSE 0.4774
NMAD 0.0845
bias 0.0443
n15 0.2216
n30 0.0841
With narrow bands
RMSE 0.4551
NMAD 0.0392
bias 0.0163
n15 0.2092
n30 0.0779
---BMDN---
Without narrow bands
RMSE 0.4479
NMAD 0.0829
bias 0.0154
n15 0.2049
n30 0.0724
With narrow bands
RMSE 0.4211
NMAD 0.0468
bias 0.0033
n15 0.1889
n30 0.0661
---Average---
Without narrow bands
RMSE 0.4169
NMAD 0.0863
bias 0.0193
n15 0.2037
n30 0.0638
With narrow bands
RMSE 0.3894
NMAD 0.0587
bias 0.0075
n15 0.1876
n30 0.0576
