File Name: train_samples_classifier 

Description: This notebook trains a decision tree classifier to predict the KMeans cluster label (`KMeans_labels`) of each original recipe from its botanical columns. The trained classifier is then used to assess the generated samples.

In [3]:
import pickle
from pathlib import Path

import pandas as pd 
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split 
from sklearn.metrics import accuracy_score, precision_score, recall_score

In [4]:
# Load the original recipe data and drop 'Juniper' (without mutating in place, so this
# cell is idempotent on re-run).
# NOTE(review): the positional slice 25:277 is applied *after* 'Juniper' is dropped, so it
# refers to post-drop column positions — confirm it still selects exactly the botanical
# columns if the CSV layout ever changes (iloc silently truncates on a short frame).
df = pd.read_csv('../../data/unlabelled_data.csv').drop(columns=['Juniper'])

botanicals = df.iloc[:, 25:277]
labels = df["KMeans_labels"]

# Guard against target leakage: the label column must never be part of the feature slice.
assert "KMeans_labels" not in botanicals.columns, "label column leaked into the feature columns"

# Hold out 30% for the final evaluation; stratify so each label keeps its proportion in
# both splits. A fixed random_state makes the split — and therefore every metric computed
# from it — reproducible across Restart & Run All.
x_train, x_test, y_train, y_test = train_test_split(
    botanicals, labels, test_size=0.3, stratify=labels, random_state=42
)

In [6]:
# Hyper-parameter grid for the decision tree: 2 x 9 x 4 x 9 = 648 candidates, each fitted
# once per CV fold (cv=5 below).
params = {
    # Compare both split-quality criteria.
    'criterion': ['gini', 'entropy'],

    # Upper bounds on tree depth.
    # NOTE(review): if the training set is small (the recorded test metrics suggest only a
    # few hundred rows), a fully grown tree may never reach depth 40, in which case every
    # value here behaves like max_depth=None — consider adding shallower depths if depth is
    # meant to act as a regulariser.
    'max_depth': [40, 60, 80, 100, 120, 150, 180, 200, 250],

    # Minimum weighted fraction of the *total* training samples (not per class) that each
    # leaf must hold; larger values force coarser leaves and limit overfitting.
    'min_weight_fraction_leaf': [0.0, 0.1, 0.2, 0.3],

    # Minimum number of samples a node must contain before it is allowed to split.
    'min_samples_split': [2, 3, 4, 5, 6, 7, 8, 9, 10],
}


gs_decision_tree = GridSearchCV(
    # The tree randomly permutes features at every split, so tied splits can resolve
    # differently between runs; a fixed random_state makes fitting deterministic.
    estimator=DecisionTreeClassifier(random_state=42),
    param_grid=params,
    scoring='accuracy',
    cv=5,
)


gs_decision_tree.fit(x_train, y_train)

print(gs_decision_tree.best_params_)

# Evaluate the refitted best estimator on the held-out test split.
pred = gs_decision_tree.predict(x_test)

print(f'accuracy score {accuracy_score(y_test, pred)}')
# Per-class scores (average=None), one entry per class label in sorted order.
# zero_division=0 makes the "no predicted / no true samples for a class" case explicit
# (the default already scores it as 0, but also emits a warning).
print(f'precision {precision_score(y_test, pred, average=None, zero_division=0)}')
print(f'recall {recall_score(y_test, pred, average=None, zero_division=0)}')



{'criterion': 'gini', 'max_depth': 120, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0}
accuracy score 0.9096385542168675
precision [0.66666667 0.88       0.95890411 1.         0.8125     1.
 0.5        0.92857143 0.8        1.        ]
recall [0.66666667 0.95652174 1.         1.         1.         0.85714286
 0.33333333 0.86666667 0.44444444 0.9375    ]


In [10]:
# Persist the whole fitted GridSearchCV object: it exposes predict() through the refitted
# best estimator, plus best_params_ / cv_results_, for use when assessing generated samples.
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
MODEL_PATH = Path("../../models/samples_classifier.pkl")
# open(..., "wb") fails if the target directory is missing, so create it first.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
with MODEL_PATH.open("wb") as f:
    pickle.dump(gs_decision_tree, f)