# Preprocess JWST JADES mocks (SPRITZ, Bisigello et al.)




In [1]:
// This is a Scala / Spark cell
import spark.implicits._
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.functions.lit

// JWST JADES mocks, main catalogue: contains photometry and redshift
val df_scala_0 = spark.read.parquet("/opt/datasets/Mock_cat_physparam_JADES_Deep_1_po-1_2p.pq")
// keep the listed parameter columns, every column whose name contains "JWST", and the ID join key
val cols_0 = df_scala_0.columns.toSeq
val cols_0_filt =
  (Seq("z","M","SFR_UV","SFR_IR","LBOL_AGN","12logOH_W") ++ cols_0.filter(name => name.contains("JWST"))) :+ "ID"
val df_scala_0_filt = df_scala_0.select(cols_0_filt.head, cols_0_filt.tail: _*)

// JWST JADES mocks, second catalogue: extra filters
val df_scala_1 = spark.read.parquet("/opt/datasets/Mock_cat_extrafilters_JADES_Deep_1_po-1_2p.pq")
// keep every column whose name contains "LSST", plus the ID join key
val cols_1 = df_scala_1.columns.toSeq
val cols_1_filt = cols_1.filter(name => name.contains("LSST")) :+ "ID"
val df_scala_1_filt = df_scala_1.select(cols_1_filt.head, cols_1_filt.tail: _*)

// inner join of the two selections on ID; declared as a var because a later cell
// reassigns df_scala when it appends derived columns
var df_scala = df_scala_0_filt.join(df_scala_1_filt, Seq("ID"), "inner")

// print summary statistics, then leave the dataframe as the cell's result
df_scala.summary().show()
df_scala

+-------+-----------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+-------------------+--------------------+-------------------+--------------------+------------------+--------------------+-------------------+--------------------+-------------------+--------------------+-------------------+--------------------+-------------------+--------------------+-------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|summary|               ID|                 z|                 M|             SFR_UV|             SFR_IR|           LBOL_AGN|         12logOH_W|  JWST_NIRCam_F090W|errJWST_NIRCam_F090W|  JWST_NIRCam_F115W|errJWST_NIRCam_F115W| JWST_NIRCam_F150W|errJWST_NIRCam_F150W|  JWST_NIRCam_F200W|errJWST_NIRCam_F200W|  JWST_NIRCam_F277W|errJWST_NIRCam_F277W|  JWST_NIRCam

[ID: double, z: double ... 29 more fields]

In [2]:
// Log-transform the flux columns, then standardise every log column
// (subtract its mean, divide by its sample standard deviation).
// Columns created by an earlier run of this cell ("*_log", "*_log_scaled") are excluded from the
// selection, so re-running the cell no longer piles up "*_log_log..." columns on the mutated df_scala.
val isDerived = (name: String) => name.endsWith("_log") || name.endsWith("_log_scaled")
val cols_input = df_scala.columns.toSeq.filterNot(isDerived)
val cols_scale = cols_input.filter(_.contains("LSST")) ++ cols_input.filter(_.startsWith("JWST"))
// NOTE(review): this is a one-element List wrapping the whole Seq and is not used in this cell;
// kept unchanged in case a later cell references it.
val cols_scale_list = List(cols_scale)

val t1 = System.nanoTime

// Compute mean and stddev of all log columns in ONE aggregation job instead of two jobs per column.
// Layout of the result row: (avg_0, stddev_0, avg_1, stddev_1, ...), in cols_scale order.
val statExprs = cols_scale.flatMap { c =>
    val logged = log10(col(c))
    Seq(avg(logged), stddev(logged))
}
// None when there is nothing to scale, so .head is never called on an empty list of expressions
val stats = statExprs.headOption.map(first => df_scala.agg(first, statExprs.tail: _*).head())

stats.foreach { row =>
    cols_scale.zipWithIndex.foreach { case (c, i) =>
        println(c)
        val logColName = c + "_log"
        val newColName = logColName + "_scaled"
        val c_avg = row.get(2 * i)
        val c_stddev = row.get(2 * i + 1)
        // withColumn replaces a column of the same name, so a re-run overwrites instead of duplicating
        df_scala = df_scala
            .withColumn(logColName, log10(col(c)))
            .withColumn(newColName, (col(logColName) - c_avg) / c_stddev)
    }
}
val duration = (System.nanoTime - t1) / 1e9d
// interpolate into a single string; println("time taken:", duration) printed a tuple
println(s"time taken: $duration s")

// columns handed over to the Python side: identifiers/physical parameters plus the scaled log fluxes
val cols_toPandas = Seq("ID","z","M","SFR_UV","SFR_IR","LBOL_AGN","12logOH_W") ++ df_scala.columns.toSeq.filter(_.contains("_scaled"))
val df_toPandas = df_scala.select(cols_toPandas.head, cols_toPandas.tail: _*)
df_toPandas.show()




LSST_g
LSST_i
LSST_r
LSST_u
LSST_y
LSST_z
JWST_NIRCam_F090W
JWST_NIRCam_F115W
JWST_NIRCam_F150W
JWST_NIRCam_F200W
JWST_NIRCam_F277W
JWST_NIRCam_F335M
JWST_NIRCam_F356W
JWST_NIRCam_F410M
JWST_NIRCam_F444W
(time taken:,13.267348013)
+----+--------+--------+---------+---------+--------+---------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+
|  ID|       z|       M|   SFR_UV|   SFR_IR|LBOL_AGN|12logOH_W|  LSST_g_log_scaled|   LSST_i_log_scaled|   LSST_r_log_scaled|   LSST_u_log_scaled|   LSST_y_log_scaled|   LSST_z_log_scaled|JWST_NIRCam_F090W_log_scaled|JWST_NIRCam_F115W_log_scaled|JWST_NIRCam_F150W_log_scaled|JWST_NIRCam_F200W_log_scaled|JWST_NIRCam_F27

In [3]:
// replace special characters ("(" and ")") in the column names -- currently disabled
//val cols_mod = new ListBuffer[String]()
//val cols = df_scala.columns.toSeq
//val len_cols = cols.length
//for (i <- 0 to len_cols-1){
//    val col_new = cols(i).replace("(","_").replace(")","")
//    cols_mod += col_new
//}    
//val cols_new = cols_mod.toList
//val df_scala_mod = df_scala.toDF(cols_new:_*)
//println(cols)

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sklearn
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.decomposition import PCA
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import make_scorer, accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, classification_report

In [6]:
# Back in Python: pull the Scala Spark dataframe into pandas,
# marking every missing entry with the sentinel value -99.
df = df_toPandas.toPandas()
df = df.fillna(-99)
# from here on the 'z' column holds 1 + z
df['z'] = df['z'] + 1.

In [7]:
# Features: every column whose name contains "JWST"; regression target: the 'z' column
cols_ml = [name for name in df.columns if "JWST" in name]
target = 'z'

# Count the -99 sentinels per source and keep only sources with fewer than 7 of them
df['missing'] = (df[cols_ml] == -99).sum(axis=1)
df = df[df['missing'] < 7]

# Report how many sentinel (missing) values remain in each feature column
for col in cols_ml:
    n_na = int((df[col] == -99).sum())
    print(col,"has",n_na,"missing values.")

# Random 70/30 train/test split (unseeded, so it differs from run to run)
X_train, X_test, y_train, y_test = train_test_split(
    df[cols_ml],
    df[target],
    test_size=0.30,
)

# Random forest regressor with default hyper-parameters
clf = RandomForestRegressor()
print('Training the model ...')
clf.fit(X_train,y_train)
print('predicting the labels ...')
preds = clf.predict(X_test)
print('... all done.')

# Scatter estimate: 1.48 * median of |true - predicted| / true on the test set
rel_err = np.abs(y_test - preds) / y_test
nmad = 1.48 * np.median(rel_err)
print("Photo-z NMAD:",np.round(nmad,6))

              ID          z  ...  JWST_NIRCam_F444W_log_scaled  missing
44          44.0   1.526928  ...                    -99.000000        5
428        428.0   1.357948  ...                    -99.000000        5
429        429.0   1.392398  ...                    -99.000000        5
653        653.0   2.738840  ...                     -0.932931        5
880        880.0   1.665224  ...                    -99.000000        5
...          ...        ...  ...                           ...      ...
142380  142380.0  10.168962  ...                     -1.352637        5
142381  142381.0  10.224633  ...                     -0.461080        5
142382  142382.0  10.318645  ...                     -0.413084        5
142384  142384.0  10.535000  ...                     -0.761490        5
142386  142386.0  10.807749  ...                     -0.346965        5

[3411 rows x 23 columns]
JWST_NIRCam_F090W_log_scaled 41686
JWST_NIRCam_F115W_log_scaled 16527
JWST_NIRCam_F150W_log_scaled 5516
JWST_N