-
Notifications
You must be signed in to change notification settings - Fork 6
/
piecewise.py
232 lines (194 loc) · 7.29 KB
/
piecewise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#!/usr/bin/env python
# coding: utf8
'''
piecewise-defined functions
'''
__author__ = "Philippe Guglielmetti"
__cfyright__ = "Cfyright 2013, Philippe Guglielmetti"
__license__ = "LGPL"
import bisect, math
from . import expr, math2, itertools2
class Piecewise(expr.Expr):
    '''
    piecewise function defined by a sorted list of (startx, Expr)

    Each Expr applies from its startx up to the next startx. The constructor
    starts the list with a point at the period start (-inf when not periodic)
    holding the default value.
    '''
    def __init__(self, init=(), default=0, period=None):
        '''
        :param init: iterable of (startx, value) pairs, or a Piecewise to copy
        :param default: value of the piece starting at the period start (-inf when
          not periodic), unless init defines a value at that very point
        :param period: length p of a period [0, p), or a (start, end) tuple.
          None (or any falsy value) keeps the period of a copied Piecewise,
          and means non-periodic otherwise
        '''
        # Note : started by deriving a list of (point,value), but this leads to a problem:
        # the value is taken into account in sort order by bisect
        # so instead of defining one more class with a __cmp__ method, I split both lists
        if init is None:
            init = ()
        if period is not None and math2.is_number(period):
            period = (0, period)
        if not period:
            # a truthy default here would make copies silently drop init's period
            period = getattr(init, 'period', (-math2.inf, +math2.inf))
        self.period = period  # set first : append() below reduces x by the period
        try:  # copy constructor ?
            self.x = list(init.x)
            self.y = list(init.y)
        except AttributeError:
            self.x = []
            self.y = []
            self.append(period[0], default)
            self.extend(init)  # a point of init at period[0] overrides default
        super(Piecewise, self).__init__(0)  # to initialize context and such stuff
        self.body = '?'  # should not happen
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        return (self.x[i], self.y[i])
    def is_periodic(self):
        ''':return: length of the period, or False if the function is not periodic'''
        start, end = self.period
        if math.isinf(start) or math.isinf(end):
            return False
        return end - start
    def _str_period(self):
        p = self.is_periodic()
        return ", period=%s" % p if p else ""
    def __str__(self):
        return str(list(self)) + self._str_period()
    def __repr__(self):
        return repr(list(self)) + self._str_period()
    def latex(self):
        ''':return: string LaTex formula'''
        # materialize first : iterating self removes redundant pieces, so indexing
        # self while iterating could pick a bound belonging to a removed piece
        pieces = list(self)
        last = len(pieces) - 1
        def condition(i):
            lo = pieces[i][0]
            hi = pieces[i + 1][0] if i < last else math2.inf
            if i == 0:
                return r'{x}<{' + str(hi) + '}'
            elif i == last:
                return r'{x}\geq{' + str(lo) + '}'
            else:
                return r'{' + str(lo) + r'}\leq{x}<{' + str(hi) + '}'
        lines = [f.latex() + '&' + condition(i) for i, (_, f) in enumerate(pieces)]
        return r'\begin{cases}' + r'\\'.join(lines) + r'\end{cases}'
    def _x(self, x):
        '''handle periodicity : reduce x into [period start, period end)'''
        p = self.is_periodic()
        if not p:
            return x
        start = self.period[0]
        return (x - start) % p + start
    def index(self, x):
        '''return index of piece'''
        return bisect.bisect_right(self.x, self._x(x)) - 1
    def __call__(self, x):
        '''returns value of Expr at point x (a list of values if x is iterable)'''
        if itertools2.isiterable(x):
            return [self(v) for v in x]
        i = self.index(x)
        xx = self._x(x)
        return self.y[i](xx)
    def insort(self, x, v=None):
        '''insert a point (or returns it if it already exists)
        note : method name follows bisect.insort convention

        :param x: abscissa of the point, reduced into the period if periodic
        :param v: value of the piece starting at x, converted to Expr.
          If None, the piece containing x is split at x (both halves share its Expr).
          Otherwise v becomes the value from x on, replacing any existing value at x
        :return: index of the point in self.x
        '''
        x = self._x(x)
        i = bisect.bisect_left(self.x, x)  # do not use self.index here !
        if i < len(self) and x == self.x[i]:
            if v is not None:  # explicit value : the last one wins
                self.y[i] = expr.Expr(v)
            return i
        # insert either the v value, or copy the current value at x
        # note : we might have consecutive tuples with the same y value
        if v is not None:
            self.y.insert(i, expr.Expr(v))
        else:  # split the piece at x
            # when i == 0, y[-1] is copied : the last piece, which wraps around
            # for a periodic function
            self.y.insert(i, self.y[i - 1])
        self.x.insert(i, x)
        return i
    def __iter__(self):
        '''iterates through discontinuities as (x, Expr) pairs.
        Consecutive pieces with equal Expr are merged as a side effect :
        the redundant entries are deleted from self while iterating'''
        prev = None
        i = 0
        while i < len(self):
            x, y = self.x[i], self.y[i]
            if y == prev:  # simplify
                self.y.pop(i)
                self.x.pop(i)
            else:
                yield x, y
                prev = y
                i += 1
    def append(self, x, y=None):
        '''appends a (x,y) piece. In fact inserts it at correct position

        :param x: start of the piece, or a (x, y) pair when y is None
        :param y: value of the piece; replaces the value of an existing point at x
        :return: self, to allow chained calls
        '''
        if y is None:
            (x, y) = x
        self.insort(x, y)  # insort reduces x into the period itself
        return self  # to allow chained calls
    def extend(self, iterable):
        '''appends an iterable of (x,y) values'''
        for p in iterable:
            self.append(p)
    def iapply(self, f, right):
        '''apply function to self, in place

        :param f: function forwarded to Expr.apply
        :param right: falsy for a monadic f, applied to each Expr;
          a Piecewise : each of its pieces is combined with the overlapping part of self;
          a (start, value[, end]) tuple : f(piece, value) is applied to the pieces of
          self in [start, end), or from start to the last piece if end is omitted
        :return: self, to allow chained calls
        '''
        if not right:  # monadic . apply to each expr
            self.y = [v.apply(f) for v in self.y]
        elif isinstance(right, Piecewise):  # combine each piece of right with self
            # snapshot : lazy iteration drops redundant points (which would make
            # right[i + 1] a wrong bound), and right may be self
            pieces = list(right)
            for k, (x, v) in enumerate(pieces):
                if k + 1 < len(pieces):
                    self.iapply(f, (x, v, pieces[k + 1][0]))
                else:  # last piece extends to the end
                    self.iapply(f, (x, v))
        else:  # assume a triplet (start,value,end) as called above
            i = self.insort(right[0])
            if len(right) > 2:
                j = self.insort(right[2])
                # inserting the end point before the start point shifts the latter
                i = self.insort(right[0])
                if j < i:
                    i, j = j, i
            else:
                j = len(self)
            for k in range(i, j):
                self.y[k] = self.y[k].apply(f, right[1])  # calls Expr.apply
        return self
    def apply(self, f, right=None):
        '''apply function to copy of self (the copy keeps the period of self)'''
        return Piecewise(self).iapply(f, right)
    def applx(self, f):
        ''' apply a function to each x value (and to each Expr through Expr.applx)

        For a periodic function the new x values are reduced into the period and
        re-sorted, so that shifts in either direction keep the list ordered
        :return: self, to allow chained calls
        '''
        xs = [f(x) for x in self.x]
        ys = [y.applx(f) for y in self.y]
        if self.is_periodic():
            pairs = sorted(zip([self._x(x) for x in xs], ys), key=lambda xy: xy[0])
            xs = [x for x, _ in pairs]
            ys = [y for _, y in pairs]
        self.x, self.y = xs, ys
        return self
    def __lshift__(self, dx):
        return Piecewise(self).applx(lambda x: x - dx)
    def __rshift__(self, dx):
        return Piecewise(self).applx(lambda x: x + dx)
    def _switch_points(self, xmin, xmax):
        '''yield the (x, y) vertices of the graph of self over [xmin, xmax]

        At a discontinuity two vertices with the same x are yielded (left limit
        first, then value) so that steps are drawn vertically.
        '''
        base = list(self)  # also drops redundant pieces
        p = self.is_periodic()
        if p:
            # unroll the pattern over every period intersecting [xmin, xmax];
            # each Expr is evaluated in the coordinates of its own period,
            # consistently with __call__ which evaluates at self._x(x)
            start = self.period[0]
            first = int(math.floor((xmin - start) / p))
            last = int(math.ceil((xmax - start) / p))
            pieces = [(x + k * p, f, k * p)
                      for k in range(first, last + 1) for x, f in base]
        else:
            pieces = [(x, f, 0) for x, f in base]
        # the piece in effect at xmin is the last one starting at or before xmin
        starts = [x for x, _, _ in pieces]
        k0 = max(bisect.bisect_right(starts, xmin) - 1, 0)
        prev = None  # (Expr, shift) of the previous piece
        for x, f, shift in pieces[k0:]:
            if x > xmax:  # pieces beyond the plot range are not drawn
                break
            x = max(x, xmin)  # clamp the first piece to the plot start
            y = f(x - shift)
            if prev is not None:
                left = prev[0](x - prev[1])  # left limit of the previous piece at x
                if not math2.isclose(y, left):  # step
                    yield x, left
            yield x, y
            prev = (f, shift)
    def points(self, xmin=None, xmax=None):
        ''':return: x,y lists of float : points for a line plot

        The default range is derived from self.x[1] .. self.x[-1]
        (self.x[0] being the default piece), so at least two points are needed.
        '''
        dx = self.x[-1] - self.x[1]
        p = self.is_periodic()
        if xmin is None:
            # by default we extend the range by 10%
            xmin = min(0, self.x[1] - dx * .1)
        if xmax is None:
            if p:
                # by default we show 2.5 periods
                xmax = xmin + p * 2.5
            else:
                # by default we extend the range by 10%
                xmax = self.x[-1] + dx * .1
        resx = []
        resy = []
        for x, y in self._switch_points(xmin, xmax):
            resx.append(x)
            resy.append(y)
        if not resx or resx[-1] < xmax:  # close the graph at xmax
            resx.append(xmax)
            resy.append(self(xmax))
        return resx, resy
    def _plot(self, ax, xmax=None, **kwargs):
        '''plots function'''
        (x, y) = self.points(xmax=xmax)
        return super(Piecewise, self)._plot(ax, x, y, **kwargs)