Skip to content

Commit 0ebd99c

Browse files
committed
Added fig_code from https://github.com/jakevdp/sklearn_pycon2015, which is used in the scikit-learn notebooks.
1 parent 815accf commit 0ebd99c

File tree

12 files changed

+1425
-0
lines changed

12 files changed

+1425
-0
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
Tutorial Diagrams
3+
-----------------
4+
5+
This script plots the flow-charts used in the scikit-learn tutorials.
6+
"""
7+
8+
import numpy as np
9+
import pylab as pl
10+
from matplotlib.patches import Circle, Rectangle, Polygon, Arrow, FancyArrow
11+
12+
def create_base(box_bg = '#CCCCCC',
13+
arrow1 = '#88CCFF',
14+
arrow2 = '#88FF88',
15+
supervised=True):
16+
fig = pl.figure(figsize=(9, 6), facecolor='w')
17+
ax = pl.axes((0, 0, 1, 1),
18+
xticks=[], yticks=[], frameon=False)
19+
ax.set_xlim(0, 9)
20+
ax.set_ylim(0, 6)
21+
22+
patches = [Rectangle((0.3, 3.6), 1.5, 1.8, zorder=1, fc=box_bg),
23+
Rectangle((0.5, 3.8), 1.5, 1.8, zorder=2, fc=box_bg),
24+
Rectangle((0.7, 4.0), 1.5, 1.8, zorder=3, fc=box_bg),
25+
26+
Rectangle((2.9, 3.6), 0.2, 1.8, fc=box_bg),
27+
Rectangle((3.1, 3.8), 0.2, 1.8, fc=box_bg),
28+
Rectangle((3.3, 4.0), 0.2, 1.8, fc=box_bg),
29+
30+
Rectangle((0.3, 0.2), 1.5, 1.8, fc=box_bg),
31+
32+
Rectangle((2.9, 0.2), 0.2, 1.8, fc=box_bg),
33+
34+
Circle((5.5, 3.5), 1.0, fc=box_bg),
35+
36+
Polygon([[5.5, 1.7],
37+
[6.1, 1.1],
38+
[5.5, 0.5],
39+
[4.9, 1.1]], fc=box_bg),
40+
41+
FancyArrow(2.3, 4.6, 0.35, 0, fc=arrow1,
42+
width=0.25, head_width=0.5, head_length=0.2),
43+
44+
FancyArrow(3.75, 4.2, 0.5, -0.2, fc=arrow1,
45+
width=0.25, head_width=0.5, head_length=0.2),
46+
47+
FancyArrow(5.5, 2.4, 0, -0.4, fc=arrow1,
48+
width=0.25, head_width=0.5, head_length=0.2),
49+
50+
FancyArrow(2.0, 1.1, 0.5, 0, fc=arrow2,
51+
width=0.25, head_width=0.5, head_length=0.2),
52+
53+
FancyArrow(3.3, 1.1, 1.3, 0, fc=arrow2,
54+
width=0.25, head_width=0.5, head_length=0.2),
55+
56+
FancyArrow(6.2, 1.1, 0.8, 0, fc=arrow2,
57+
width=0.25, head_width=0.5, head_length=0.2)]
58+
59+
if supervised:
60+
patches += [Rectangle((0.3, 2.4), 1.5, 0.5, zorder=1, fc=box_bg),
61+
Rectangle((0.5, 2.6), 1.5, 0.5, zorder=2, fc=box_bg),
62+
Rectangle((0.7, 2.8), 1.5, 0.5, zorder=3, fc=box_bg),
63+
FancyArrow(2.3, 2.9, 2.0, 0, fc=arrow1,
64+
width=0.25, head_width=0.5, head_length=0.2),
65+
Rectangle((7.3, 0.85), 1.5, 0.5, fc=box_bg)]
66+
else:
67+
patches += [Rectangle((7.3, 0.2), 1.5, 1.8, fc=box_bg)]
68+
69+
for p in patches:
70+
ax.add_patch(p)
71+
72+
pl.text(1.45, 4.9, "Training\nText,\nDocuments,\nImages,\netc.",
73+
ha='center', va='center', fontsize=14)
74+
75+
pl.text(3.6, 4.9, "Feature\nVectors",
76+
ha='left', va='center', fontsize=14)
77+
78+
pl.text(5.5, 3.5, "Machine\nLearning\nAlgorithm",
79+
ha='center', va='center', fontsize=14)
80+
81+
pl.text(1.05, 1.1, "New Text,\nDocument,\nImage,\netc.",
82+
ha='center', va='center', fontsize=14)
83+
84+
pl.text(3.3, 1.7, "Feature\nVector",
85+
ha='left', va='center', fontsize=14)
86+
87+
pl.text(5.5, 1.1, "Predictive\nModel",
88+
ha='center', va='center', fontsize=12)
89+
90+
if supervised:
91+
pl.text(1.45, 3.05, "Labels",
92+
ha='center', va='center', fontsize=14)
93+
94+
pl.text(8.05, 1.1, "Expected\nLabel",
95+
ha='center', va='center', fontsize=14)
96+
pl.text(8.8, 5.8, "Supervised Learning Model",
97+
ha='right', va='top', fontsize=18)
98+
99+
else:
100+
pl.text(8.05, 1.1,
101+
"Likelihood\nor Cluster ID\nor Better\nRepresentation",
102+
ha='center', va='center', fontsize=12)
103+
pl.text(8.8, 5.8, "Unsupervised Learning Model",
104+
ha='right', va='top', fontsize=18)
105+
106+
107+
108+
def plot_supervised_chart(annotate=False):
109+
create_base(supervised=True)
110+
if annotate:
111+
fontdict = dict(color='r', weight='bold', size=14)
112+
pl.text(1.9, 4.55, 'X = vec.fit_transform(input)',
113+
fontdict=fontdict,
114+
rotation=20, ha='left', va='bottom')
115+
pl.text(3.7, 3.2, 'clf.fit(X, y)',
116+
fontdict=fontdict,
117+
rotation=20, ha='left', va='bottom')
118+
pl.text(1.7, 1.5, 'X_new = vec.transform(input)',
119+
fontdict=fontdict,
120+
rotation=20, ha='left', va='bottom')
121+
pl.text(6.1, 1.5, 'y_new = clf.predict(X_new)',
122+
fontdict=fontdict,
123+
rotation=20, ha='left', va='bottom')
124+
125+
def plot_unsupervised_chart():
126+
create_base(supervised=False)
127+
128+
129+
if __name__ == '__main__':
130+
plot_supervised_chart(False)
131+
plot_supervised_chart(True)
132+
plot_unsupervised_chart()
133+
pl.show()
134+
135+

scikit-learn/fig_code/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .data import *
2+
from .figures import *
3+
4+
from .sgd_separator import plot_sgd_separator
5+
from .linear_regression import plot_linear_regression
6+
from .helpers import plot_iris_knn

scikit-learn/fig_code/__init__.py~

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .sgd_separator import plot_sgd_separator
2+
from .linear_regression import plot_linear_regression
3+
from .ML_flow_chart import plot_supervised_chart, plot_unsupervised_chart
4+
from .helpers import plot_iris_knn

scikit-learn/fig_code/data.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
3+
4+
def linear_data_sample(N=40, rseed=0, m=3, b=-2):
5+
rng = np.random.RandomState(rseed)
6+
7+
x = 10 * rng.rand(N)
8+
dy = m / 2 * (1 + rng.rand(N))
9+
y = m * x + b + dy * rng.randn(N)
10+
11+
return (x, y, dy)
12+
13+
14+
def linear_data_sample_big_errs(N=40, rseed=0, m=3, b=-2):
15+
rng = np.random.RandomState(rseed)
16+
17+
x = 10 * rng.rand(N)
18+
dy = m / 2 * (1 + rng.rand(N))
19+
dy[20:25] *= 10
20+
y = m * x + b + dy * rng.randn(N)
21+
22+
return (x, y, dy)
23+
24+
25+
def sample_light_curve(phased=True):
26+
from astroML.datasets import fetch_LINEAR_sample
27+
data = fetch_LINEAR_sample()
28+
t, y, dy = data[18525697].T
29+
30+
if phased:
31+
P_best = 0.580313015651
32+
t /= P_best
33+
34+
return (t, y, dy)
35+
36+
37+
def sample_light_curve_2(phased=True):
38+
from astroML.datasets import fetch_LINEAR_sample
39+
data = fetch_LINEAR_sample()
40+
t, y, dy = data[10022663].T
41+
42+
if phased:
43+
P_best = 0.61596079804
44+
t /= P_best
45+
46+
return (t, y, dy)
47+

0 commit comments

Comments
 (0)