Skip to content

Commit

Permalink
adds new matplotlib module for abstracting matplotlib routines
Browse files Browse the repository at this point in the history
  • Loading branch information
Eduardo Blancas Reyes committed Apr 26, 2019
1 parent 2ff760f commit 97d341e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 9 deletions.
26 changes: 17 additions & 9 deletions examples/feature_importances.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
"""
Feature importances plot
%load_ext autoreload
%autoreload 2
"""
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import train_test_split

from sklearn_evaluation import plot

data = datasets.make_classification(200, 10, 5, class_sep=0.65)
X = data[0]
y = data[1]

X, y = datasets.make_classification(200, 20, 5, class_sep=0.65)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

est = RandomForestClassifier()
est.fit(X_train, y_train)
model = RandomForestClassifier(n_estimators=1)
model.fit(X_train, y_train)

# plot all features
plot.feature_importances(model)
plt.show()

plot.feature_importances(est, top_n=5)
plt.show()
# only top 5
plot.feature_importances(model, top_n=5)
plt.show()
3 changes: 3 additions & 0 deletions sklearn_evaluation/plot/matplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Matplotlib plotting code
"""
31 changes: 31 additions & 0 deletions sklearn_evaluation/plot/matplotlib/bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import matplotlib.pyplot as plt


def horizontal(values, labels=None, error=None):
"""Horizontal bar plot
Examples
--------
>>> from sklearn_evaluation.plot.matplotlib import bar
>>> values = np.random.rand(10)
>>> bar.horizontal(values)
>>> plt.show()
Notes
-----
https://matplotlib.org/gallery/lines_bars_and_markers/barh.html
"""
y_pos = np.arange(len(values))
ax = plt.gca()

if error is None:
ax.barh(y_pos, values)
else:
ax.barh(y_pos, values, xerr=error)

ax.set_yticks(y_pos if labels is None else labels)
ax.set_yticklabels(y_pos)
ax.invert_yaxis()
return ax

0 comments on commit 97d341e

Please sign in to comment.