In [1]:
# Local project modules developed alongside this notebook. The explicit
# reloads let a re-run of this cell pick up on-disk edits to tree.py /
# random_forest.py without restarting the kernel; on a fresh kernel they are
# redundant (each module was just imported) but harmless.
# Gotcha: names pulled in later via `from random_forest import ...` only see
# the reloaded code if that later import cell is re-run as well.
import importlib
import tree
import random_forest

# Reload `tree` first — presumably random_forest imports from tree, so it
# should re-execute against the fresh copy. Confirm before reordering.
importlib.reload(tree)
# Last expression of the cell: the returned module object is the cell output.
importlib.reload(random_forest)

<module 'random_forest' from '/home/julia/uma-random-forest/src/random_forest.py'>

In [2]:
from datasets.mushrooms import MushroomDataset
from random_forest import RandomForestClassifier, TournamentRandomForestClassifier
import math, time
from utils.experiments import grid_search


# Load dataset

In [3]:
# Raw mushroom data file, relative to the notebook's working directory.
path = "../data/mushroom/agaricus-lepiota.data"

dataset = MushroomDataset(path=path)
# `clean()`'s return value is not used, so it presumably cleans `dataset`
# in place — confirm in datasets/mushrooms.py.
dataset.clean()

# Train/validation split with a fixed seed so every grid search below is
# scored on the same rows. test_size=0.2 presumably means a 20% hold-out.
X_train, X_val, y_train, y_val = dataset.split(
    test_size=0.2,
    random_state=42,
)

In [4]:
# Passed below as `max_features`: the common sqrt(total feature count) rule
# of thumb, rounded to the nearest integer (presumably the number of features
# considered per split — confirm in random_forest.py).
n_columns = X_train.shape[1]
n_features = round(math.sqrt(n_columns))

# Random forest classifier

In [5]:
# Hyper-parameter grid for the plain random forest: each key maps to the list
# of candidate values that grid_search combines.
params_matrix = dict(
    n_trees=[100, 200, 300],
    max_depth=[3, 5, 7],
    max_split_values=[1000],
    max_features=[n_features],
)

In [6]:

# Grid-search the plain random forest over `params_matrix`, training on
# (X_train, y_train) and scoring on (X_val, y_val) — presumably; see
# utils/experiments.grid_search. `save_path` is forwarded as `path`,
# presumably where the per-configuration results are written as CSV.
save_path = "../out/mushrooms/random_forest_classifier.csv"
n_calls = 1  # presumably repetitions per configuration — confirm in grid_search

# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time.time() follows the wall clock and can jump on adjustments.
time_start = time.perf_counter()
best_params, score, all_results = grid_search(
    params_matrix,
    RandomForestClassifier,
    X_train,
    X_val,
    y_train,
    y_val,
    n_calls,
    path=save_path,
)
print(f"Execution time: {time.perf_counter() - time_start:.2f} s")
print(f"Best params: {best_params}")
print(f"Best score: {score}")
print(f"All results: {all_results}")

  0%|          | 0/25 [00:00<?, ?it/s]INFO:root:RandomForestClassifier: n_trees=10, max_depth=3
INFO:root:DecisionTreeClassifier(max_depth=3) created
INFO:root:Node(split_feature=None, split_val=None, depth=3) created
INFO:root:Node(split_feature=None, split_val=None, depth=3) created
INFO:root:Node(split_feature=21, split_val=0.5, depth=2) created
INFO:root:Node(split_feature=None, split_val=None, depth=3) created
INFO:root:Node(split_feature=None, split_val=None, depth=3) created
INFO:root:Node(split_feature=67, split_val=0.5, depth=2) created
INFO:root:Node(split_feature=27, split_val=0.5, depth=1) created
INFO:root:Node(split_feature=None, split_val=None, depth=1) created
INFO:root:Node(split_feature=37, split_val=0.5, depth=0) created
INFO:root:DecisionTreeClassifier(max_depth=3) created
INFO:root:Node(split_feature=None, split_val=None, depth=3) created
INFO:root:Node(split_feature=None, split_val=None, depth=3) created
INFO:root:Node(split_feature=36, split_val=0.5, depth=2) cre

Execution time: 71.89852547645569
Best params: {'n_trees': 10, 'max_depth': 7, 'max_split_values': 1000, 'max_features': 11}
Best score: 0.9624615384615385
All rEsults: [{'n_trees': 10, 'max_depth': 3, 'max_split_values': 1000, 'max_features': 11, 'accuracy': 0.7341538461538462, 'precision': 0.6611764705882353, 'recall': 1.0, 'f1': 0.7960339943342776}, {'n_trees': 10, 'max_depth': 4, 'max_split_values': 1000, 'max_features': 11, 'accuracy': 0.8763076923076923, 'precision': 0.8074712643678161, 'recall': 1.0, 'f1': 0.8934817170111288}, {'n_trees': 10, 'max_depth': 5, 'max_split_values': 1000, 'max_features': 11, 'accuracy': 0.827076923076923, 'precision': 0.75, 'recall': 1.0, 'f1': 0.8571428571428571}, {'n_trees': 10, 'max_depth': 6, 'max_split_values': 1000, 'max_features': 11, 'accuracy': 0.8578461538461538, 'precision': 0.7849162011173184, 'recall': 1.0, 'f1': 0.8794992175273866}, {'n_trees': 10, 'max_depth': 7, 'max_split_values': 1000, 'max_features': 11, 'accuracy': 0.9624615384615




# Tournament random forest

In [12]:
# Hyper-parameter grid for the tournament variant: 3 x 3 x 3 x 1 = 27
# combinations, each presumably evaluated `n_calls` times by grid_search,
# so this search is considerably more expensive than the one above.
params_matrix = dict(
    n_trees=[100, 200, 300],
    max_depth=[3, 5, 7],
    tournament_size=[3, 5, 7],
    max_features=[n_features],
)

In [16]:
# Grid-search the tournament random forest on the same train/validation split
# as the plain random forest. `grid_search` comes from the setup-cell import,
# so this cell gives the same result on a fresh kernel without any module
# reloading. (For live edits to utils/experiments.py, enable `%autoreload 2`
# in the setup cell rather than reloading here.)
# Results are persisted via `path`, mirroring the plain random-forest run.
save_path = "../out/mushrooms/tournament_random_forest_classifier.csv"
n_calls = 10  # presumably repetitions per configuration — confirm in grid_search

# perf_counter: monotonic clock intended for measuring elapsed time.
time_start = time.perf_counter()
best_params, score, all_results = grid_search(
    params_matrix,
    TournamentRandomForestClassifier,
    X_train,
    X_val,
    y_train,
    y_val,
    n_calls,
    path=save_path,
)
print(f"Execution time: {time.perf_counter() - time_start:.2f} s")
print(f"Best params: {best_params}")
print(f"Best score: {score}")
print(f"All results: {all_results}")

100%|██████████| 1/1 [01:26<00:00, 86.28s/it]

Best params: {'n_trees': 10, 'max_depth': 5, 'tournament_size': 7, 'max_features': 5}
Best score: 0.6236307692307693
All rEsults: [{'n_trees': 10, 'max_depth': 5, 'tournament_size': 7, 'max_features': 5, 'accuracy': 0.6236307692307693, 'precision': 0.584507137486008, 'recall': 1.0, 'f1': 0.7362310462544407}]



