## Let's read in Kelsey's file and play with some classifiers in scikit learn.
%matplotlib widget seems to work in vscode for zoomable plots

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
# Modelling
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, ConfusionMatrixDisplay
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from scipy.stats import randint
from sklearn.model_selection import cross_val_score

# Tree Visualisation
from sklearn.tree import export_graphviz


In [2]:
ddir = "data/"
tcefile = "TESS_ML_20240412.txt"
columns = ["pass","win","type","injnum","period","tzero","power","dur","depth","snr","deptest1","deptest2","stmass","stradius","chisq",
           "bicft","bicplus","bicminus","snrshape","noise","depth_shape2","depth_shape3","mindBIC","23","24","25","26","27","28","29","30","31",
           "Tmag","Teff","logg","injrec","trprob"]
tces = pd.read_csv(ddir+tcefile, names=columns)

## Define some metrics and a test set.

In [3]:
metrics=["period","power","dur","deptest1","deptest2","chisq",
           "bicft","bicplus","bicminus","snrshape","noise","depth_shape2","depth_shape3","mindBIC","23","24","25","26","27","28","29","30","31",
           "Tmag", "logg", "Teff"]
X = np.array(tces[metrics])
X[~np.isfinite(X)] = -9999
X_train, X_test, y_train, y_test = train_test_split(X, tces['injrec'], test_size=0.2)

len(X_train)


161273

In [4]:
print(np.sum(y_test), len(y_test))

5104.0 69117


In [5]:
rf = RandomForestClassifier(n_estimators=10,max_depth=4, min_samples_split=2)
rf.fit(X_train, y_train)
scores= cross_val_score(rf, X_train, y_train, cv=6)
scores.mean()

0.9804430983540323

In [6]:
y_pred = rf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

Accuracy: 0.9820594065135929


In [7]:
from sklearn import tree
print(tree.export_text(rf.estimators_[0]))
#plt.figure(figsize=10,10)
#_ = tree.plot

|--- feature_8 <= 130.23
|   |--- feature_12 <= 6.05
|   |   |--- feature_13 <= 3.95
|   |   |   |--- feature_5 <= 0.91
|   |   |   |   |--- class: 0.0
|   |   |   |--- feature_5 >  0.91
|   |   |   |   |--- class: 0.0
|   |   |--- feature_13 >  3.95
|   |   |   |--- feature_6 <= 10.26
|   |   |   |   |--- class: 0.0
|   |   |   |--- feature_6 >  10.26
|   |   |   |   |--- class: 0.0
|   |--- feature_12 >  6.05
|   |   |--- feature_15 <= 5.50
|   |   |   |--- feature_23 <= 1595.18
|   |   |   |   |--- class: 0.0
|   |   |   |--- feature_23 >  1595.18
|   |   |   |   |--- class: 0.0
|   |   |--- feature_15 >  5.50
|   |   |   |--- feature_5 <= 1.58
|   |   |   |   |--- class: 1.0
|   |   |   |--- feature_5 >  1.58
|   |   |   |   |--- class: 0.0
|--- feature_8 >  130.23
|   |--- feature_13 <= 10.65
|   |   |--- feature_10 <= 19.45
|   |   |   |--- feature_20 <= 26.98
|   |   |   |   |--- class: 0.0
|   |   |   |--- feature_20 >  26.98
|   |   |   |   |--- class: 0.0
|   |   |--- feature

In [8]:
isort = np.argsort(rf.feature_importances_)
for i,m in enumerate(metrics):
    print(m, 100*rf.feature_importances_[i])

period 3.0389166160048777
power 0.0
dur 0.5051852209011063
deptest1 1.086289277373661
deptest2 0.6832601148800549
chisq 2.063590208022201
snr 36.120605353825944
bicft 8.875701273719136
bicplus 9.145564286028709
bicminus 0.07874688108370986
snrshape 8.693675353599021
noise 0.0
depth_shape2 19.436176543524333
depth_shape3 6.256121107960415
mindBIC 0.5403978879835392
23 0.2473401458693971
24 0.22586953722893857
25 1.058043752496841
26 1.0809347077908624
27 0.0
28 0.5013079691100828
29 0.0022547488735519617
30 0.21870344465984665
31 0.14131556906378298
Tmag 0.0


In [19]:
import dtreeviz

viz = dtreeviz.model(rf.estimators_[0], X_train, y_train,
                target_name="target",
                feature_names=metrics, class_names=["FP","PC"])

In [20]:
viz.view(depth_range_to_display=(0, 2))

ExecutableNotFound: failed to execute 'dot', make sure the Graphviz executables are on your systems' PATH

<dtreeviz.utils.DTreeVizRender at 0x1675cb8f0>