In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss
from matplotlib.colors import ListedColormap
import pandas as pd
import time

# Load the Iris dataset and keep only feature columns 2 and 3
# (petal length and petal width), giving a two-feature problem.
iris = load_iris()
X, y = iris.data[:, 2:4], iris.target

# Hold out 30% of the samples as a test set; the fixed random_state makes
# the split reproducible from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=0,
)


In [None]:
# Train a softmax (multinomial logistic) regression model.
# `multi_class` is deliberately left at its default: the argument is deprecated
# since scikit-learn 1.5 (passing it raises a FutureWarning), and with the
# lbfgs solver on a multi-class target the default already fits the
# multinomial/softmax model (scikit-learn >= 0.22).
# C is the inverse regularization strength, so C=100 means weak regularization.
model = LogisticRegression(solver='lbfgs', C=100)
model.fit(X_train, y_train)

# Predict on the held-out test set and report the accuracy.
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print("测试集准确率：", acc)

# Visualize the decision regions of the fitted model.
# Build a dense grid that covers the range of both features plus a margin.
margin = 0.5
grid_step = 0.02
x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, grid_step),
    np.arange(y_min, y_max, grid_step),
)

# Classify every grid point, then fold the flat predictions back into the
# grid's 2-D shape so they can be drawn as filled contours.
grid_points = np.c_[xx.ravel(), yy.ravel()]
Z = model.predict(grid_points).reshape(xx.shape)

# Light colours for the predicted regions, saturated colours for the samples.
region_cmap = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
point_cmap = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

fig, ax = plt.subplots(figsize=(8, 6))
ax.contourf(xx, yy, Z, alpha=0.3, cmap=region_cmap)
# All samples (train and test) are drawn, coloured by their true label.
ax.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=point_cmap)
ax.set_xlabel("Petal length")
ax.set_ylabel("Petal width")
ax.set_title("Softmax分类边界")
plt.show()


In [None]:
# Hyperparameter comparison.
# NOTE: the lbfgs solver has no learning-rate parameter. The values in
# `learning_rates` are only ever used as C = 1/value, so they act as a
# regularization strength (larger value -> smaller C -> stronger
# regularization), not as an optimizer step size. The variable name and the
# "Learning Rate" column label are kept unchanged so that anything consuming
# `df` keeps working.
learning_rates = [0.01, 0.1, 1]
max_iters = [50, 100, 200]
results = []

for lr in learning_rates:
    for iters in max_iters:
        # `multi_class` is left at its default: it is deprecated since
        # scikit-learn 1.5, and lbfgs already fits the multinomial model on a
        # multi-class target (scikit-learn >= 0.22).
        # NOTE: this rebinds the notebook-level `model` on every iteration.
        model = LogisticRegression(solver='lbfgs', C=1/lr, max_iter=iters)

        # perf_counter() is a monotonic, high-resolution clock meant for
        # measuring short intervals; time.time() follows the wall clock and
        # can jump if the system time is adjusted.
        start = time.perf_counter()
        model.fit(X_train, y_train)
        end = time.perf_counter()

        # Pass the class labels explicitly so log_loss lines up with the
        # columns of predict_proba even if a split is missing a class.
        train_loss = log_loss(y_train, model.predict_proba(X_train),
                              labels=model.classes_)
        test_loss = log_loss(y_test, model.predict_proba(X_test),
                             labels=model.classes_)
        results.append([lr, iters, train_loss, test_loss, end - start])

# Tabulate the results: one row per (learning_rates value, max_iter) pair.
df = pd.DataFrame(results, columns=["Learning Rate", "Max Iter", "Train Loss", "Test Loss", "Train Time (s)"])
print(df)
