# Aaron's example notebook 

Uses `create_model.py` functions to generate and use a Random Forest model on BASE-9 output to identify good vs. bad sampling.

Below are commands to create a `conda` environment containing libraries needed for this project:

```
conda create --name base9 python=3.12 numpy jupyter matplotlib astropy pandas arviz scipy scikit-learn
conda activate base9
pip install diptest
```

In [1]:
from create_model import *
#importing functions from .py file

%load_ext autoreload
%autoreload 2

In [2]:
# Directory containing the BASE-9 res files we will use to train the model
directory = "data/NGC2682/jw_output"

# Build the feature table from the res files in `directory`, using column 0 (age).
# Note: this may take a couple of minutes to run.
features_df = create_features(directory, column=0)
 
# Display the resulting feature table
features_df

Unnamed: 0,source_id,Width,Upper_bound,Lower_bound,Stdev,SnR,Dip_p,Dip_value,KS_value,KS_p,ESS
0,608154449852709120,0.542966,1.183637,1.103694,0.237557,38.492771,0.000000,0.076963,0.456867,0.000000e+00,10134.731571
1,608303231815505920,0.755195,1.183637,1.103694,0.464201,20.268446,0.003637,0.006149,0.144000,2.699721e-111,10100.933421
2,608141294367793024,1.705887,1.183637,1.103694,0.725193,12.193566,0.000000,0.024109,0.127287,1.124591e-87,9992.886462
3,608068623521152384,1.344996,1.183637,1.103694,0.657152,14.096636,0.000085,0.007507,0.190204,1.606217e-194,9682.699257
4,608038764908563968,1.255172,1.183637,1.103694,0.639640,14.644034,0.000000,0.014813,0.197085,1.205499e-210,9694.481703
...,...,...,...,...,...,...,...,...,...,...,...
1423,604694561637360640,1.555578,1.183637,1.103694,0.709386,12.669910,0.000000,0.021891,0.152263,1.939749e-123,9982.136760
1424,604711505283863808,2.474428,1.183637,1.103694,1.159964,7.203835,0.000000,0.118215,0.215391,7.182757e-248,8964.186543
1425,604703465105196416,1.633171,1.183637,1.103694,0.750779,11.732707,0.000000,0.052372,0.172874,7.621408e-162,9955.299659
1426,604712531781276928,1.240942,1.183637,1.103694,0.596285,15.420267,0.000000,0.012573,0.167347,2.121818e-150,9657.308944


# Creating ML model

In [3]:
# Load the table of stars whose sampling quality is already known;
# these known labels are what the model is trained against.
df1 = pd.read_csv('data/NGC2682/NGC2682_Age_Stats.csv', sep=',')

# Name of the label column used for training. Rows with no value in this
# column carry no label, so they are filtered out.
label_column_name = "Single Sampling"
has_label = df1[label_column_name].notna()
label_df = df1.loc[has_label]

In [4]:
# Create the model from the feature table and the labelled rows.
# The call returns the fitted pipeline followed by the full, training, and test
# feature/label sets, unpacked in that order.
pipe, X, y, X_train, y_train, X_test, y_test = create_model(features_df, label_df, label_column_name=label_column_name)

# Test the model: predict labels for the test set.
# y_test is passed as well -- presumably so make_preds can report accuracy and
# feature importances against the known labels (confirm in create_model.py).
y_pred = make_preds(pipe, X_test, y_test = y_test)

There are 161 training elements with classification = Bad
There are 161 training elements with classification = Good
Accuracy: 0.955607476635514
              precision    recall  f1-score   support

         Bad       1.00      0.95      0.97       346
        Good       0.82      0.99      0.90        82

    accuracy                           0.96       428
   macro avg       0.91      0.97      0.93       428
weighted avg       0.96      0.96      0.96       428

Feature Importance Ranking:
Width          0.313331
SnR            0.261396
Stdev          0.212242
KS_value       0.138108
Dip_value      0.044622
ESS            0.023788
Dip_p          0.003994
KS_p           0.002519
Upper_bound    0.000000
Lower_bound    0.000000
dtype: float64


In [5]:
# Save the trained pipeline under the given filename so it can be reloaded
# later with load_model instead of being retrained.
save_model(pipe, filename = 'amg_model.pkl')

# Applying new dataset to saved model

In [6]:
# Directory containing the NGC 6819 res files that the saved model will classify.
# Note: this rebinds `directory`, which previously pointed at the training data.
directory = 'data/NGC6819/ngc6819_single_resfiles'

# Create the feature table with `create_features` from the files in `directory`.
# column=0 (age) is passed explicitly so the features are computed from the same
# res-file column as the training features, rather than relying on the default.
# Note: this may take a few minutes to run.
ngc6819_features_df = create_features(directory, column=0)

# Display the resulting feature table
ngc6819_features_df


Unnamed: 0,source_id,Width,Upper_bound,Lower_bound,Stdev,SnR,Dip_p,Dip_value,KS_value,KS_p,ESS
0,gaia_2076377646922516096_sin2,1.310742,0.371344,0.17444,0.705594,12.625670,0.0,0.055795,0.194899,5.741272e-206,10270.940503
1,gaia_2076269108813963776_sin2,1.465552,0.371344,0.17444,0.688275,12.954116,0.0,0.026590,0.147045,5.342355e-117,9896.861579
2,gaia_2076395170383622016_sin2,1.609451,0.371344,0.17444,0.793903,11.146351,0.0,0.013034,0.146612,2.256215e-115,9651.344083
3,gaia_2076390192531116544_sin2,1.370839,0.371344,0.17444,0.657100,13.391178,0.0,0.009949,0.157108,1.142071e-134,10104.686208
4,gaia_2076479596566965376_sin2,1.188421,0.371344,0.17444,0.610059,14.706439,0.0,0.068436,0.190924,3.387986e-199,9728.625263
...,...,...,...,...,...,...,...,...,...,...,...
1693,gaia_2076286593616203520_sin2,0.948054,0.371344,0.17444,0.563526,16.162504,0.0,0.012569,0.210435,3.263369e-246,9587.069179
1694,gaia_2076394109541237248_sin2,1.592687,0.371344,0.17444,0.727018,12.184470,0.0,0.016283,0.207952,1.347615e-236,9843.404012
1695,gaia_2076490213726294528_sin2,1.546252,0.371344,0.17444,0.728562,12.026299,0.0,0.017071,0.173333,5.065443e-164,9338.604991
1696,gaia_2076490041927622016_sin2,1.496974,0.371344,0.17444,0.702897,12.574082,0.0,0.016518,0.171621,2.331086e-163,9360.880574


In [7]:
# Read the saved model back in, using the same filename that was passed to save_model.
# This rebinds `pipe`, replacing the pipeline object created during training.
# NOTE(review): the .pkl extension suggests a pickle file -- only load model
# files that come from a trusted source.
pipe = load_model('amg_model.pkl')

In [8]:
# Use the loaded model to predict a label for every row of the new feature table.
# No y_test is passed here, since this call is made without known labels.
y_pred_6819 = make_preds(pipe, ngc6819_features_df)
# Display the array of predicted labels
y_pred_6819

array(['Bad', 'Bad', 'Bad', ..., 'Bad', 'Bad', 'Good'],
      shape=(1698,), dtype=object)

In [12]:
# Sanity check: tally how many entries received each predicted label
n_good = np.count_nonzero(y_pred_6819 == "Good")
n_bad = np.count_nonzero(y_pred_6819 == "Bad")
print(f'{n_good} are labeled Good and {n_bad} are labeled Bad')


409 are labeled Good and 1289 are labeled Bad


# Comparing ML labels with Dr. Jeffrey's labels

In [13]:
# Column names assigned to the whitespace-separated check file
column_names = ['file', 'star', 'rank']

# Meaning of the values in the 'rank' column:
# rank 1 - good sampling (chain converges reasonably)
# rank 2 - poor sampling (no sampling or highly correlated)
# rank 3 - parts of chain seem convergent
# rank 4 - flat distribution

# Read the check file, skipping its first line, and discard the 'file' column.
# The separator is a regex, so it is written as a raw string: a plain '\s+'
# is an invalid escape sequence and triggers a SyntaxWarning.
txt_6819 = pd.read_csv(
    'data/NGC6819/ngc6819_checkRes_all.txt',
    sep=r'\s+',
    names=column_names,
    skiprows=1,
).drop(columns='file')

# Display the resulting table of stars and ranks
txt_6819

  txt_6819 = pd.read_csv( 'data/NGC6819/ngc6819_checkRes_all.txt',sep='\s+',names=column_names, skiprows=1)


Unnamed: 0,star,rank
0,2076220554197701504,4
1,2076220829075672960,4
2,2076220966514688128,3
3,2076224784750223232,4
4,2076227116907912960,4
...,...,...
1683,2076616859415037824,3
1684,2076617276038017664,4
1685,2076617447828109440,4
1686,2076617477890443008,2
