In [1]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建高斯朴素贝叶斯模型
gnb = GaussianNB()

# 训练模型
gnb.fit(X_train, y_train)

# 预测
predictions = gnb.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, predictions)
print("Accuracy:", accuracy)

Accuracy: 1.0


##### 贝叶斯参数优化

In [5]:
'''
bayesian-optimization是一个基于贝叶斯推理和高斯过程的约束全局优化包，它试图在尽可能少的迭代中找到未知函数的最值。
贝叶斯最优化能够在不需要大量计算资源的情况下，有效探索参数空间
'''
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from bayes_opt import BayesianOptimization
import numpy as np

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 定义优化目标函数
#先要定义一个目标函数。比如此时，函数输入为随机森林的所有参数，
#输出为模型交叉验证5次的AUC均值，作为我们的目标函数。
#因为bayes_opt库只支持最大值，所以最后的输出如果是越小越好，那么需要在前面加上负号，以转为最大值。
def decision_tree_cv(max_depth, min_samples_split):
    dt = DecisionTreeClassifier(max_depth=int(max_depth), min_samples_split=int(min_samples_split))
    scores = cross_val_score(dt, X, y, cv=5)
    return np.mean(scores)

# 定义Bayesian Optimization对象
optimizer = BayesianOptimization(
    f=decision_tree_cv,
    pbounds={"max_depth": (1, 10), "min_samples_split": (2, 20)},
    random_state=42,
)

# 执行优化过程
optimizer.maximize(init_points=5, n_iter=25)

# 输出最优参数
print("Best parameters found:")
print(optimizer.max["params"])

# 使用最优参数重新训练模型并评估性能
best_max_depth = int(optimizer.max["params"]["max_depth"])
best_min_samples_split = int(optimizer.max["params"]["min_samples_split"])
best_dt = DecisionTreeClassifier(max_depth=best_max_depth, min_samples_split=best_min_samples_split)
best_dt.fit(X, y)
accuracy = best_dt.score(X, y)
print("Test accuracy with best parameters:", accuracy)

|   iter    |  target   | max_depth | min_sa... |
-------------------------------------------------
| [0m1        [0m | [0m0.9667   [0m | [0m4.371    [0m | [0m19.11    [0m |
| [0m2        [0m | [0m0.9667   [0m | [0m7.588    [0m | [0m12.78    [0m |
| [0m3        [0m | [0m0.9333   [0m | [0m2.404    [0m | [0m4.808    [0m |
| [0m4        [0m | [0m0.6667   [0m | [0m1.523    [0m | [0m17.59    [0m |
| [0m5        [0m | [0m0.9667   [0m | [0m6.41     [0m | [0m14.75    [0m |
| [0m6        [0m | [0m0.9667   [0m | [0m8.013    [0m | [0m14.37    [0m |
| [0m7        [0m | [0m0.9667   [0m | [0m7.157    [0m | [0m18.65    [0m |
| [0m8        [0m | [0m0.9667   [0m | [0m5.115    [0m | [0m9.347    [0m |
| [0m9        [0m | [0m0.9667   [0m | [0m8.834    [0m | [0m8.648    [0m |
| [0m10       [0m | [0m0.9667   [0m | [0m7.026    [0m | [0m5.48     [0m |
| [0m11       [0m | [0m0.9667   [0m | [0m10.0     [0m | [0m3.155    [0m 