Skip to content

Commit 643697b

Browse files
committed
ENH: add a spline fitting demo
1 parent 8e8154a commit 643697b

File tree

1 file changed

+274
-0
lines changed

1 file changed

+274
-0
lines changed

04-spline_demo.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
import scipy.interpolate as si
2+
import numpy as np
3+
from functools import reduce
4+
5+
# uncomment this to set the backend
6+
# import matplotlib
7+
# matplotlib.use('Qt4Agg')
8+
import matplotlib.pyplot as plt
9+
10+
11+
class TooFewPointsException(Exception):
12+
...
13+
14+
15+
class SplineFitter:
16+
def __init__(self, ax, pix_err=1):
17+
self.canvas = ax.get_figure().canvas
18+
self.cid = None
19+
self.pt_lst = []
20+
self.pt_plot = ax.plot([], [], marker='o',
21+
linestyle='none', zorder=5)[0]
22+
self.sp_plot = ax.plot([], [], lw=3, color='r')[0]
23+
self.pix_err = pix_err
24+
self.connect_sf()
25+
26+
def clear(self):
27+
'''Clears the points'''
28+
self.pt_lst = []
29+
self.redraw()
30+
31+
def connect_sf(self):
32+
if self.cid is None:
33+
self.cid = self.canvas.mpl_connect('button_press_event',
34+
self.click_event)
35+
36+
def disconnect_sf(self):
37+
if self.cid is not None:
38+
self.canvas.mpl_disconnect(self.cid)
39+
self.cid = None
40+
41+
def click_event(self, event):
42+
''' Extracts locations from the user'''
43+
if event.key == 'shift':
44+
self.clear()
45+
return
46+
if event.xdata is None or event.ydata is None:
47+
return
48+
if event.button == 1:
49+
self.pt_lst.append((event.xdata, event.ydata))
50+
elif event.button == 3:
51+
self.remove_pt((event.xdata, event.ydata))
52+
self.ev = event
53+
self.redraw()
54+
55+
def remove_pt(self, loc):
56+
if len(self.pt_lst) > 0:
57+
self.pt_lst.pop(np.argmin(list(map(lambda x:
58+
np.sqrt((x[0] - loc[0]) ** 2 +
59+
(x[1] - loc[1]) ** 2),
60+
self.pt_lst))))
61+
62+
def redraw(self):
63+
if len(self.pt_lst) > 5:
64+
SC = SplineCurve.from_pts(self.pt_lst, pix_err=self.pix_err)
65+
new_pts = SC.q_phi_to_xy(0, np.linspace(0, 2 * np.pi, 1000))
66+
center = SC.cntr
67+
self.sp_plot.set_xdata(new_pts[0])
68+
self.sp_plot.set_ydata(new_pts[1])
69+
self.pt_lst.sort(key=lambda x:
70+
np.arctan2(x[1] - center[1], x[0] - center[0]))
71+
else:
72+
self.sp_plot.set_xdata([])
73+
self.sp_plot.set_ydata([])
74+
if len(self.pt_lst) > 0:
75+
x, y = zip(*self.pt_lst)
76+
else:
77+
x, y = [], []
78+
self.pt_plot.set_xdata(x)
79+
self.pt_plot.set_ydata(y)
80+
81+
self.canvas.draw_idle()
82+
83+
@property
84+
def points(self):
85+
'''Returns the clicked points in the format the rest of the
86+
code expects'''
87+
return np.vstack(self.pt_lst).T
88+
89+
@property
90+
def SplineCurve(self):
91+
curve = SplineCurve.from_pts(self.pt_lst, pix_err=self.pix_err)
92+
return curve
93+
94+
95+
class SplineCurve:
96+
'''
97+
A class that wraps the scipy.interpolation objects
98+
'''
99+
@classmethod
100+
def _get_spline(cls, points, pix_err=2, need_sort=True, **kwargs):
101+
'''
102+
Returns a closed spline for the points handed in.
103+
Input is assumed to be a (2xN) array
104+
105+
=====
106+
input
107+
=====
108+
109+
:param points: the points to fit the spline to
110+
:type points: a 2xN ndarray or a list of len =2 tuples
111+
112+
:param pix_err: the error is finding the spline in pixels
113+
:param need_sort: if the points need to be sorted
114+
or should be processed as-is
115+
116+
=====
117+
output
118+
=====
119+
tck
120+
The return data from the spline fitting
121+
'''
122+
if type(points) is np.ndarray:
123+
# make into a list
124+
pt_lst = zip(*points)
125+
# get center
126+
center = np.mean(points, axis=1).reshape(2, 1)
127+
else:
128+
# make a copy of the list
129+
pt_lst = list(points)
130+
131+
# compute center
132+
def tmp_fun(x, y): (x[0] + y[0], x[1] + y[1])
133+
134+
center = np.array(reduce(tmp_fun, pt_lst)).reshape(2, 1)
135+
center /= len(pt_lst)
136+
if len(pt_lst) < 5:
137+
raise TooFewPointsException("not enough points")
138+
139+
if need_sort:
140+
# sort the list by angle around center
141+
pt_lst.sort(key=lambda x: np.arctan2(x[1] - center[1],
142+
x[0] - center[0]))
143+
# add first point to end because it is periodic (makes the
144+
# interpolation code happy)
145+
pt_lst.append(pt_lst[0])
146+
# make array for handing in to spline fitting
147+
pt_array = np.vstack(pt_lst).T
148+
# do spline fitting
149+
150+
tck, u = si.splprep(pt_array, s=len(pt_lst) * (pix_err ** 2), per=True)
151+
return tck
152+
153+
@classmethod
154+
def from_pts(cls, new_pts, **kwargs):
155+
tck = cls._get_spline(new_pts, **kwargs)
156+
this = cls(tck)
157+
this.raw_pts = new_pts
158+
return this
159+
160+
def __init__(self, tck):
161+
'''Use `from_pts` class method to construct instance
162+
'''
163+
self.tck = tck
164+
self._cntr = None
165+
self._circ = None
166+
self._th_offset = None
167+
168+
def write_to_hdf(self, parent_group, name=None):
169+
'''
170+
Writes out the essential data (spline of central curve) to hdf file.
171+
'''
172+
if name is not None:
173+
curve_group = parent_group.create_group(name)
174+
else:
175+
curve_group = parent_group
176+
curve_group.attrs['tck0'] = self.tck[0]
177+
curve_group.attrs['tck1'] = np.vstack(self.tck[1])
178+
curve_group.attrs['tck2'] = self.tck[2]
179+
180+
@property
181+
def circ(self):
182+
'''returns a rough estimate of the circumference'''
183+
if self._circ is None:
184+
new_pts = si.splev(np.linspace(0, 1, 1000), self.tck, ext=2)
185+
self._circ = np.sum(np.sqrt(np.sum(np.diff(new_pts, axis=1) ** 2,
186+
axis=0)))
187+
return self._circ
188+
189+
@property
190+
def cntr(self):
191+
'''returns a rough estimate of the circumference'''
192+
if self._cntr is None:
193+
new_pts = si.splev(np.linspace(0, 1, 1000), self.tck, ext=2)
194+
self._cntr = np.mean(new_pts, 1)
195+
return self._cntr
196+
197+
@property
198+
def th_offset(self):
199+
"""
200+
The angle from the y-axis for (x, y) at `phi=0`
201+
"""
202+
if self._th_offset is None:
203+
x, y = self.q_phi_to_xy(0, 0) - self.cntr.reshape(2, 1)
204+
self._th_offset = np.arctan2(y, x)
205+
return self._th_offset
206+
207+
@property
208+
def tck0(self):
209+
return self.tck[0]
210+
211+
@property
212+
def tck1(self):
213+
return self.tck[1]
214+
215+
@property
216+
def tck2(self):
217+
return self.tck[2]
218+
219+
def q_phi_to_xy(self, q, phi, cross=None):
220+
'''Converts q, phi pairs -> x, y pairs. All other code that
221+
does this should move to using this so that there is minimal
222+
breakage when we change over to using additive q instead of
223+
multiplicative'''
224+
# make sure data is arrays
225+
q = np.asarray(q)
226+
# convert real units -> interpolation units
227+
phi = np.mod(np.asarray(phi), 2 * np.pi) / (2 * np.pi)
228+
# get the shapes
229+
q_shape, phi_shape = [_.shape if (_.shape != () and
230+
len(_) > 1) else None for
231+
_ in (q, phi)]
232+
233+
# flatten everything
234+
q = q.ravel()
235+
phi = phi.ravel()
236+
# sanity checks on shapes
237+
if cross is False:
238+
if phi_shape != q_shape:
239+
raise ValueError("q and phi must have same" +
240+
" dimensions to broadcast")
241+
if cross is None:
242+
if ((phi_shape is not None) and (q_shape is not None)
243+
and (phi_shape == q_shape)):
244+
cross = False
245+
elif q_shape is None:
246+
cross = False
247+
q = q[0]
248+
else:
249+
cross = True
250+
251+
x, y = si.splev(phi, self.tck, ext=2)
252+
dx, dy = si.splev(phi, self.tck, der=1, ext=2)
253+
norm = np.sqrt(dx ** 2 + dy ** 2)
254+
nx, ny = dy / norm, -dx / norm
255+
256+
# if cross, then
257+
if cross:
258+
data_out = zip(
259+
*map(lambda q_: ((x + q_ * nx).reshape(phi_shape),
260+
(y + q_ * ny).reshape(phi_shape)),
261+
q)
262+
)
263+
else:
264+
265+
data_out = np.vstack([(x + q * nx).reshape(phi_shape),
266+
(y + q * ny).reshape(phi_shape)])
267+
268+
return data_out
269+
270+
271+
fig, ax = plt.subplots()
272+
sp = SplineFitter(ax, .001)
273+
plt.ion()
274+
plt.show()

0 commit comments

Comments
 (0)