In [2]:
import pickle

import joblib
import pandas as pd
from scipy.stats import randint
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split, RandomizedSearchCV

# Read the raw mushroom table from disk.
df = pd.read_csv("mushrooms.csv")

# Pull the label out and one-hot encode the remaining feature columns.
# The label column itself is never passed through get_dummies.
X = pd.get_dummies(df.drop(columns="class"), drop_first=False)
y = df["class"]

# Hold out 20% of the rows for testing; the fixed seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)


# Base estimator whose hyperparameters are tuned below; seeded for repeatability.
rf = RandomForestClassifier(random_state=42)

# Search space: every entry is a discrete uniform distribution over integers.
# scipy's randint(low, high) samples from low up to, but not including, high.
param_dist = dict(
    n_estimators=randint(50, 300),  # how many trees to grow
    max_depth=randint(3, 30),  # depth limit for each tree
    min_samples_split=randint(2, 20),  # samples needed before a node may split
    min_samples_leaf=randint(1, 20),  # samples that must remain in a leaf
)

# Randomized search: 50 sampled configurations, each scored with 5-fold
# cross-validation; n_jobs=-1 lets scikit-learn use every available core.
random_search = RandomizedSearchCV(
    estimator=rf,
    param_distributions=param_dist,
    n_iter=50,
    cv=5,
    random_state=42,
    n_jobs=-1,
)

# Run the search on the training split. fit() hands back the search object
# itself, so the refitted best estimator can be read straight off the result.
best_rf = random_search.fit(X_train, y_train).best_estimator_

# Predict on the held-out split and print the per-class precision/recall/F1.
y_pred = best_rf.predict(X_test)
print(classification_report(y_test, y_pred))


# Show which sampled configuration won the search.
print("Best hyperparameters:", random_search.best_params_)

# Serialize the tuned forest with pickle.
with open("rf_model.pkl", "wb") as model_file:
    pickle.dump(best_rf, model_file)

# Store the one-hot-encoded column index next to the model (presumably so new
# data can be aligned to the training layout -- confirm with the consumer).
# As the cell's last expression, joblib.dump's return value (the list of
# written file names) is echoed in the cell output.
joblib.dump(X.columns, "X_columns.pkl")


              precision    recall  f1-score   support

           e       1.00      1.00      1.00       843
           p       1.00      1.00      1.00       782

    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625

Best hyperparameters: {'max_depth': 26, 'min_samples_leaf': 4, 'min_samples_split': 9, 'n_estimators': 201}


['X_columns.pkl']

In [3]:
# Overall accuracy on the held-out test split, using the predictions that the
# tuned forest produced in the previous cell (y_pred) against the true labels.
final_acc = accuracy_score(y_test, y_pred)
print("Final Test Accuracy:", final_acc)


Final Test Accuracy: 1.0
