forked from ploomber/sklearn-evaluation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adds new matplotlib module for abstracting matplotlib routines
- Loading branch information
Eduardo Blancas Reyes
committed
Apr 26, 2019
1 parent
2ff760f
commit 97d341e
Showing
3 changed files
with
51 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,26 @@ | ||
""" | ||
Feature importances plot | ||
%load_ext autoreload | ||
%autoreload 2 | ||
""" | ||
import matplotlib.pyplot as plt | ||
from sklearn import datasets | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.cross_validation import train_test_split | ||
from sklearn.model_selection import train_test_split | ||
|
||
from sklearn_evaluation import plot | ||
|
||
data = datasets.make_classification(200, 10, 5, class_sep=0.65) | ||
X = data[0] | ||
y = data[1] | ||
|
||
X, y = datasets.make_classification(200, 20, 5, class_sep=0.65) | ||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) | ||
|
||
est = RandomForestClassifier() | ||
est.fit(X_train, y_train) | ||
model = RandomForestClassifier(n_estimators=1) | ||
model.fit(X_train, y_train) | ||
|
||
# plot all features | ||
plot.feature_importances(model) | ||
plt.show() | ||
|
||
plot.feature_importances(est, top_n=5) | ||
plt.show() | ||
# only top 5 | ||
plot.feature_importances(model, top_n=5) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Matplotlib plotting code | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def horizontal(values, labels=None, error=None): | ||
"""Horizontal bar plot | ||
Examples | ||
-------- | ||
>>> from sklearn_evaluation.plot.matplotlib import bar | ||
>>> values = np.random.rand(10) | ||
>>> bar.horizontal(values) | ||
>>> plt.show() | ||
Notes | ||
----- | ||
https://matplotlib.org/gallery/lines_bars_and_markers/barh.html | ||
""" | ||
y_pos = np.arange(len(values)) | ||
ax = plt.gca() | ||
|
||
if error is None: | ||
ax.barh(y_pos, values) | ||
else: | ||
ax.barh(y_pos, values, xerr=error) | ||
|
||
ax.set_yticks(y_pos if labels is None else labels) | ||
ax.set_yticklabels(y_pos) | ||
ax.invert_yaxis() | ||
return ax |