|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +from sklearn import (datasets, linear_model, naive_bayes, neural_network, neighbors, svm, tree, ensemble, metrics) |
| 4 | +from matplotlib.colors import ListedColormap |
| 5 | + |
# Load the Iris dataset, keeping only the first two features so the
# decision boundaries can be drawn on a 2-D plane.
iris = datasets.load_iris()
iris.data = iris.data[:, :2]
iris.feature_names = iris.feature_names[:2]
# One RGB color per class (setosa=red, versicolor=green, virginica=blue).
iris.color = np.array([(1, 0, 0), (0, 1, 0), (0, 0, 1)])
| 11 | + |
# Instantiate training models.
# NOTE: the gamma labels use raw strings — '$\gamma$' in a normal string is an
# invalid escape sequence (\g), which is a SyntaxWarning on Python 3.12+.
models = [
    {'name': 'linear_model.SGD', 'obj': linear_model.SGDClassifier()},
    {'name': 'naive_bayes.Gaussian', 'obj': naive_bayes.GaussianNB()},
    {'name': 'neural_network.MLP', 'obj': neural_network.MLPClassifier()},
    {'name': 'neighbors.KNN', 'obj': neighbors.KNeighborsClassifier()},

    {'name': 'svm.LinearSVC', 'obj': svm.LinearSVC()},
    {'name': 'svm.SVC(linear)', 'obj': svm.SVC(kernel='linear')},
    {'name': 'svm.SVC(poly,2)', 'obj': svm.SVC(kernel='poly', degree=2)},
    {'name': 'svm.SVC(poly,3)', 'obj': svm.SVC(kernel='poly')},  # degree defaults to 3
    {'name': 'svm.SVC(poly,4)', 'obj': svm.SVC(kernel='poly', degree=4)},
    {'name': 'svm.SVC(rbf)', 'obj': svm.SVC(kernel='rbf')},
    {'name': r'svm.SVC(rbf,$\gamma$=1)', 'obj': svm.SVC(kernel='rbf', gamma=1)},
    {'name': r'svm.SVC(rbf,$\gamma$=4)', 'obj': svm.SVC(kernel='rbf', gamma=4)},
    {'name': r'svm.SVC(rbf,$\gamma$=16)', 'obj': svm.SVC(kernel='rbf', gamma=16)},
    {'name': r'svm.SVC(rbf,$\gamma$=64)', 'obj': svm.SVC(kernel='rbf', gamma=64)},
    {'name': 'svm.SVC(sigmoid)', 'obj': svm.SVC(kernel='sigmoid')},

    {'name': 'tree.DecisionTree(2)', 'obj': tree.DecisionTreeClassifier(max_depth=2)},
    {'name': 'tree.DecisionTree(4)', 'obj': tree.DecisionTreeClassifier(max_depth=4)},
    {'name': 'tree.DecisionTree(N)', 'obj': tree.DecisionTreeClassifier()},
    {'name': 'tree.ExtraTree', 'obj': tree.ExtraTreeClassifier()},

    {'name': 'ensemble.RandomForest(10)', 'obj': ensemble.RandomForestClassifier(n_estimators=10)},
    {'name': 'ensemble.RandomForest(100)', 'obj': ensemble.RandomForestClassifier()},
    {'name': 'ensemble.ExtraTrees(10)', 'obj': ensemble.ExtraTreesClassifier(n_estimators=10)},
    {'name': 'ensemble.ExtraTrees(100)', 'obj': ensemble.ExtraTreesClassifier()},
    {'name': 'ensemble.AdaBoost(DTree)', 'obj': ensemble.AdaBoostClassifier(tree.DecisionTreeClassifier())},
]
| 42 | + |
# Build a dense grid covering the data range (padded by 1 unit on each side);
# each grid point will be classified to paint the decision regions.
lo = iris.data.min(axis=0) - 1
hi = iris.data.max(axis=0) + 1
xx, yy = np.meshgrid(np.arange(lo[0], hi[0], 0.01), np.arange(lo[1], hi[1], 0.01))
xy = np.hstack((xx.reshape(-1, 1), yy.reshape(-1, 1)))
| 47 | + |
for model in models:
    clf = model['obj']

    # Fit on the (truncated, 2-feature) training set.
    clf.fit(iris.data, iris.target)

    # Evaluate on the same data (training accuracy — no held-out split here).
    predict = clf.predict(iris.data)
    model['acc'] = metrics.balanced_accuracy_score(iris.target, predict)

    # Classify every grid point and shade the decision regions.
    region = clf.predict(xy).reshape(xx.shape)
    plt.figure()
    plt.contourf(xx, yy, region, cmap=ListedColormap(iris.color), alpha=0.2)

    # Overlay the samples: fill = true class, edge = predicted class,
    # so misclassified points show a mismatched outline.
    plt.title(f"{model['name']} ({model['acc']:.3f})")
    plt.scatter(iris.data[:, 0], iris.data[:, 1],
                c=iris.color[iris.target], edgecolors=iris.color[predict])
    plt.xlabel(iris.feature_names[0])
    plt.ylabel(iris.feature_names[1])

plt.show()
0 commit comments