forked from sympy/sympy
/
function.py
699 lines (589 loc) · 21.2 KB
/
function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
"""
There are several kinds of functions:
1) defined functions like exp or sin, which have a name and a body
(in the sense that the function can be evaluated).
e = exp
2) undefined functions, which have a name but no body. Undefined
functions can be defined using the Function class as follows:
f = Function('f')
(the result is a new function class, i.e. an instance of FunctionClass,
not an instance of Function)
3) anonymous (lambda) functions, which have no name but have a body with
dummy variables. They are created with the Lambda class:
f = Lambda(x, exp(x)*x)
Not implemented yet: inferring the arguments from the free symbols of the
expression, e.g. f = Lambda(exp(x)*x), and combining functions with
lambdas, e.g. f = exp * Lambda(x,x).
4) not implemented yet: composition of functions, like (sin+cos)(x); this
works in sympycore, but needs to be ported back to SymPy.
Example:
>>> from sympy import *
>>> f = Function("f")
>>> x = Symbol("x")
>>> f(x)
f(x)
>>> print srepr(f(x).func)
Function('f')
>>> f(x).args
(x,)
"""
from basic import Basic, Atom, S, C, sympify
from basic import BasicType, BasicMeta
from operations import AssocOp
from cache import cacheit
from itertools import repeat
from numbers import Rational,Integer
from symbol import Symbol
from add import Add
from multidimensional import vectorize
from sympy import mpmath
class FunctionClass(BasicMeta):
    """
    Metaclass of all function classes (FunctionClass is a subclass of type).

    Use Function('<function name>' [ , signature ]) to create
    undefined function classes.
    """

    _new = type.__new__

    def __new__(cls, arg1, arg2, arg3=None, **options):
        """Create a new function class.

        Two call forms are accepted:

        * FunctionClass(name, bases, attrdict) -- the ordinary metaclass
          protocol used when a ``class`` statement is executed;
        * FunctionClass(ftype, name[, signature]) -- clone the namespace of
          the function class ``ftype`` into a new subclass called ``name``,
          marked with ``undefined_Function = True`` and, if given, the
          ``signature`` attribute.
        """
        assert not options, repr(options)
        if not isinstance(arg1, type):
            # Plain metaclass protocol: (name, bases, attrdict).
            return type.__new__(cls, arg1, arg2, arg3)
        base, new_name, sig = arg1, arg2, arg3
        #XXX this probably needs some fixing:
        assert base.__name__.endswith('Function'), repr(base)
        namespace = base.__dict__.copy()
        namespace['undefined_Function'] = True
        if sig is not None:
            namespace['signature'] = sig
        return type.__new__(cls, new_name, (base,), namespace)

    def __repr__(cls):
        # A function class prints as its bare name, e.g. "f" or "sin".
        return cls.__name__
class Function(Basic):
    """
    Base class for applied functions.
    Constructor of undefined classes.

    Function('f') (a single string) creates a new undefined function class
    named 'f'; calling a function class on arguments, e.g. f(x) or sin(x),
    creates an applied-function instance.
    """

    __metaclass__ = FunctionClass

    is_Function = True

    nargs = None

    @vectorize(1)
    @cacheit
    def __new__(cls, *args, **options):
        # NOTE: this __new__ is twofold:
        #
        # 1 -- it can create another *class*, which can then be instantiated by
        #      itself e.g. Function('f') creates a new class f(Function)
        #
        # 2 -- on the other hand, we instantiate -- that is we create an
        #      *instance* of a class created earlier in 1.
        #
        # So please keep, both (1) and (2) in mind.

        # (1) create new function class
        #     UC: Function('f')
        if cls is Function:
            #when user writes Function("f"), do an equivalent of:
            #taking the whole class Function(...):
            #and rename the Function to "f" and return f, thus:
            #In [13]: isinstance(f, Function)
            #Out[13]: False
            #In [14]: isinstance(f, FunctionClass)
            #Out[14]: True
            if len(args) == 1 and isinstance(args[0], str):
                #always create Function
                return FunctionClass(Function, *args)
            # Anything else (no name, several names, a non-string) is a
            # usage error. TypeError is still an Exception subclass, so
            # existing `except Exception` handlers keep working.
            raise TypeError("You need to specify exactly one string")

        # (2) create new instance of a class created in (1)
        #     UC: Function('f')(x)
        #     UC: sin(x)
        # Build a real list: on Python 3 map() would return a one-shot
        # iterator that canonize() exhausts before Basic.__new__ sees it.
        args = [sympify(a) for a in args]

        # these lines should be refactored
        for opt in ["nargs", "dummy", "comparable", "noncommutative", "commutative"]:
            if opt in options:
                del options[opt]
        # up to here.

        if options.get('evaluate') is False:
            return Basic.__new__(cls, *args, **options)

        r = cls.canonize(*args)
        if isinstance(r, Basic):
            return r
        elif r is None:
            pass
        elif not isinstance(r, tuple):
            args = (r,)
        return Basic.__new__(cls, *args, **options)

    @property
    def is_commutative(self):
        return True

    @classmethod
    def canonize(cls, *args):
        """
        Returns a canonical form of cls applied to arguments args.

        The canonize() method is called when the class cls is about to be
        instantiated and it should return either some simplified instance
        (possible of some other class), or if the class cls should be
        unmodified, return None.

        Example of canonize() for the function "sign"
        ---------------------------------------------

        @classmethod
        def canonize(cls, arg):
            if arg is S.NaN:
                return S.NaN
            if arg is S.Zero: return S.Zero
            if arg.is_positive: return S.One
            if arg.is_negative: return S.NegativeOne
            if isinstance(arg, C.Mul):
                coeff, terms = arg.as_coeff_terms()
                if coeff is not S.One:
                    return cls(coeff) * cls(C.Mul(*terms))

        """
        return

    @property
    def func(self):
        return self.__class__

    def _eval_subs(self, old, new):
        # Replacing a whole function class by another one with the same
        # nargs re-applies the new class to the existing arguments.
        if self == old:
            return new
        elif isinstance(old, FunctionClass) and isinstance(new, FunctionClass):
            if old == self.func and old.nargs == new.nargs:
                return new(*self.args[:])
        obj = self.func._eval_apply_subs(*(self.args[:] + (old,) + (new,)))
        if obj is not None:
            return obj
        return Basic._seq_subs(self, old, new)

    def _eval_expand_basic(self, *args):
        return None

    def _eval_evalf(self, prec):
        # Lookup mpmath function based on name
        fname = self.func.__name__
        try:
            if not hasattr(mpmath, fname):
                from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
                fname = MPMATH_TRANSLATIONS[fname]
            func = getattr(mpmath, fname)
        except (AttributeError, KeyError):
            return

        # Convert all args to mpf or mpc
        try:
            args = [arg._to_mpmath(prec) for arg in self.args]
        except ValueError:
            return

        # Set mpmath precision and apply. Make sure precision is restored
        # afterwards
        orig = mpmath.mp.prec
        try:
            mpmath.mp.prec = prec
            v = func(*args)
        finally:
            mpmath.mp.prec = orig

        return Basic._from_mpmath(v, prec)

    def _eval_is_comparable(self):
        # Comparable iff every argument is; None if any argument is unknown.
        if self.is_Function:
            r = True
            for s in self.args:
                c = s.is_comparable
                if c is None:
                    return
                if not c:
                    r = False
            return r
        return

    def _eval_derivative(self, s):
        """Chain rule: d/ds f(a1, ..., an) = sum_i f.fdiff(i) * d(ai)/ds."""
        terms = []
        for i, a in enumerate(self.args):
            da = a.diff(s)
            if da is S.Zero:
                continue
            if isinstance(self.func, FunctionClass):
                # fdiff() uses 1-based argument indices.
                terms.append(self.fdiff(i + 1) * da)
        return Add(*terms)

    def _eval_is_commutative(self):
        r = True
        for a in self._args:
            c = a.is_commutative
            if c is None:
                return None
            if not c:
                r = False
        return r

    def _eval_eq_nonzero(self, other):
        # Returns None (undecided) when the function types or lengths differ.
        if isinstance(other.func, self.func.__class__) and len(self)==len(other):
            for a1,a2 in zip(self,other):
                if not (a1==a2):
                    return False
            return True

    def as_base_exp(self):
        return self, S.One

    def count_ops(self, symbolic=True):
        #      f()             args
        return 1   + Add(*[t.count_ops(symbolic) for t in self.args])

    def _eval_oseries(self, order):
        """Series of a one-argument function self(arg) up to `order`."""
        assert self.func.nargs==1, repr(self.func)
        arg = self.args[0]
        x = order.symbols[0]
        if not C.Order(1,x).contains(arg):
            return self.func(arg)
        arg0 = arg.limit(x, 0)
        if arg0 is not S.Zero:
            e = self.func(arg)
            e1 = e.expand()
            if e==e1:
                #for example when e = sin(x+1) or e = sin(cos(x))
                #let's try the general algorithm
                term = e.subs(x, S.Zero)
                series = S.Zero
                fact = S.One
                i = 0
                while not order.contains(term) or term == 0:
                    series += term
                    i += 1
                    fact *= Rational(i)
                    e = e.diff(x)
                    term = e.subs(x, S.Zero)*(x**i)/fact
                return series
            return e1.oseries(order)
        return self._compute_oseries(arg, order, self.func.taylor_term, self.func)

    def nseries(self, x, x0, n):
        return self._series(x, x0, n)

    def _eval_is_polynomial(self, syms):
        # A function application is polynomial in syms only if none of its
        # arguments contain any of them.
        for arg in self.args:
            if arg.has(*syms):
                return False
        return True

    def _eval_expand_complex(self, *args):
        func = self.func(*[ a._eval_expand_complex(*args) for a in self.args ])
        return C.re(func) + S.ImaginaryUnit * C.im(func)

    def _eval_rewrite(self, pattern, rule, **hints):
        if hints.get('deep', False):
            args = [ a._eval_rewrite(pattern, rule, **hints) for a in self ]
        else:
            args = self.args[:]

        if pattern is None or isinstance(self.func, pattern):
            if hasattr(self, rule):
                rewritten = getattr(self, rule)(*args)

                if rewritten is not None:
                    return rewritten

        return self.func(*args, **self._assumptions)

    def fdiff(self, argindex=1):
        """Return the unevaluated derivative of self with respect to its
        argindex-th argument (1-based).

        Raises TypeError if nargs is known and argindex is out of range.
        """
        if self.nargs is not None:
            if isinstance(self.nargs, tuple):
                nargs = self.nargs[-1]
            else:
                nargs = self.nargs
            if not (1<=argindex<=nargs):
                raise TypeError("argument index %r is out of range [1,%s]" % (argindex,nargs))
        return Derivative(self,self.args[argindex-1],evaluate=False)

    @classmethod
    def _eval_apply_evalf(cls, arg, prec=None):
        """Numerically evaluate cls applied to arg.

        arg is evalf'ed (at precision `prec` if given); if the result is a
        number, the method of that number named after this class (e.g.
        ``.sin()`` for sin) is called and its result returned. Returns None
        otherwise.

        The original body referenced an undefined name `prec` and always
        raised NameError; `prec` is now an optional parameter.
        """
        if prec is None:
            arg = arg.evalf()
        else:
            arg = arg.evalf(prec)

        #if cls.nargs == 1:
        # common case for functions with 1 argument
        #if arg.is_Number:
        if arg.is_number:
            func_evalf = getattr(arg, cls.__name__)
            return func_evalf()

    def _eval_as_leading_term(self, x):
        """General method for the leading term"""
        arg = self.args[0].as_leading_term(x)

        if C.Order(1,x).contains(arg):
            return arg
        else:
            return self.func(arg)

    @classmethod
    def taylor_term(cls, n, x, *previous_terms):
        """General method for the taylor term.

        This method is slow, because it differentiates n-times.  Subclasses can
        redefine it to make it faster by using the "previous_terms".
        """
        x = sympify(x)
        return cls(x).diff(x, n).subs(x, 0) * x**n / C.Factorial(n)
class WildFunction(Function, Atom):
    """
    Wildcard function used in pattern matching.

    matches() binds the wildcard to the candidate expression, unless the
    wildcard is already bound to something different, or its nargs (1 by
    default) differs from the candidate's.

    XXX is this as intended, does it work ?
    """

    nargs = 1

    def __new__(cls, name=None, **assumptions):
        if name is None:
            # Generate a fresh name from the shared Symbol dummy counter.
            name = 'Wf%s' % (Symbol.dummycount + 1) # XXX refactor dummy counting
            Symbol.dummycount += 1
        obj = Function.__new__(cls, name, **assumptions)
        obj.name = name
        return obj

    def matches(pattern, expr, repl_dict=None, evaluate=False):
        """Try to match `pattern` against `expr`.

        Returns a new dict extending repl_dict with {pattern: expr}, the
        unchanged repl_dict when pattern is already bound to expr, or None
        when there is no match.
        """
        # None instead of a shared mutable {} default.
        if repl_dict is None:
            repl_dict = {}
        for p,v in repl_dict.items():
            if p==pattern:
                if v==expr:
                    return repl_dict
                return None
        if pattern.nargs is not None:
            # Candidates without an nargs attribute cannot match an arity
            # constraint; treat them as non-matching instead of crashing.
            if pattern.nargs != getattr(expr, 'nargs', None):
                return None
        repl_dict = repl_dict.copy()
        repl_dict[pattern] = expr
        return repl_dict

    @classmethod
    def _eval_apply_evalf(cls, arg):
        return

    @property
    def is_number(self):
        return False
class Derivative(Basic):
    """
    Carries out differentiation of the given expression with respect to symbols.

    expr must define ._eval_derivative(symbol) method that returns
    the differentiation result or None.

    Examples:

    Derivative(Derivative(expr, x), y) -> Derivative(expr, x, y)
    Derivative(expr, x, 3) -> Derivative(expr, x, x, x)
    """

    is_Derivative = True

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of the Derivative.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)

        The second example shows why we don't return a list, but a generator,
        so that the code that calls _symbolgen can return earlier for special
        cases, like x.diff(x, 10**6).

        A Symbol immediately followed by an Integer n is repeated n times;
        the Integer itself is consumed. Every element is looked up by
        position, so a symbol equal to the last one still gets its count
        (e.g. (x, 2, x) -> (x, x, x)).
        """
        count = len(symbols)
        for i in range(count):
            s = sympify(symbols[i])
            if isinstance(s, Integer):
                # repetition counts are handled with the preceding symbol
                continue
            if isinstance(s, Symbol) and i + 1 < count:
                next_s = sympify(symbols[i + 1])
                if isinstance(next_s, Integer):
                    # handle cases like (x, 3) -> x, x, x
                    for copy_s in repeat(s, int(next_s)):
                        yield copy_s
                    continue
            yield s

    def __new__(cls, expr, *symbols, **assumptions):
        expr = sympify(expr)
        if not symbols:
            return expr
        symbols = Derivative._symbolgen(*symbols)
        if not assumptions.get("evaluate", False):
            obj = Basic.__new__(cls, expr, *symbols)
            return obj
        unevaluated_symbols = []
        for s in symbols:
            s = sympify(s)
            assert isinstance(s, Symbol), repr(s)
            if not expr.has(s):
                return S.Zero
            obj = expr._eval_derivative(s)
            if obj is None:
                unevaluated_symbols.append(s)
            elif obj is S.Zero:
                # A zero derivative stays zero under further
                # differentiation (previously this returned `expr`).
                return S.Zero
            else:
                expr = obj

        if not unevaluated_symbols:
            return expr
        return Basic.__new__(cls, expr, *unevaluated_symbols)

    def _eval_derivative(self, s):
        if s not in self.symbols:
            obj = self.expr.diff(s)
            if isinstance(obj, Derivative):
                return Derivative(obj.expr, *(obj.symbols+self.symbols))
            return Derivative(obj, *self.symbols)
        return Derivative(self.expr, *((s,)+self.symbols), **{'evaluate': False})

    def doit(self):
        return Derivative(self.expr, *self.symbols,**{'evaluate': True})

    @property
    def expr(self):
        return self._args[0]

    @property
    def symbols(self):
        return self._args[1:]

    def _eval_subs(self, old, new):
        if self==old:
            return new
        return Derivative(self.args[0].subs(old, new), *self.args[1:], **{'evaluate': True})

    def matches(pattern, expr, repl_dict=None, evaluate=False):
        """Match a Derivative pattern against expr.

        Only a Derivative with the same number of differentiation symbols
        can match; the structural comparison is delegated to Basic.matches.
        Returns the extended replacement dict, or None for no match.
        """
        # this method needs a cleanup.
        # None instead of a shared mutable {} default.
        if repl_dict is None:
            repl_dict = {}
        for p,v in repl_dict.items():
            if p==pattern:
                if v==expr:
                    return repl_dict
                return None
        assert isinstance(pattern, Derivative)
        if isinstance(expr, Derivative):
            if len(expr.symbols) == len(pattern.symbols):
                return Basic.matches(pattern, expr, repl_dict, evaluate)
        # The old unreachable tail (including a deliberate `stop` NameError)
        # was removed: everything else is simply a non-match.
        return None
class Lambda(Function):
    """
    Lambda(x, expr) represents a lambda function similar to Python's
    'lambda x: expr'. A function of several variables is written as
    Lambda((x, y, ...), expr).

    A simple example:
    >>> from sympy import Symbol
    >>> x = Symbol('x')
    >>> f = Lambda(x, x**2)
    >>> f(4)
    16

    For multivariate functions, use:
    >>> x = Symbol('x')
    >>> y = Symbol('y')
    >>> z = Symbol('z')
    >>> t = Symbol('t')
    >>> f2 = Lambda(x,y,z,t,x+y**z+t**z)
    >>> f2(1,2,3,4)
    73

    Multivariate functions can be curried for partial applications:
    >>> sum2numbers = Lambda(x,y,x+y)
    >>> sum2numbers(1,2)
    3
    >>> plus1 = sum2numbers(1)
    >>> plus1(3)
    4

    A handy shortcut for lots of arguments:
    >>> from sympy import *
    >>> var('x y z')
    (x, y, z)
    >>> p = x, y, z
    >>> f = Lambda(p, x + y*z)
    >>> f(*p)
    x + y*z
    """

    # a minimum of 2 arguments (parameter, expression) are needed
    nargs = 2

    def __new__(cls,*args):
        assert len(args) >= 2,"Must have at least one parameter and an expression"
        if len(args) == 2 and isinstance(args[0], (list, tuple)):
            # Lambda((x, y, ...), expr) -> Lambda(x, y, ..., expr)
            args = tuple(args[0])+(args[1],)
        obj = Function.__new__(cls,*args)
        # instance nargs = number of parameters + 1 (the body)
        obj.nargs = len(args)
        return obj

    @classmethod
    def canonize(cls,*args):
        """Replace every parameter with a fresh dummy Symbol of the same
        name, in both the parameter list and the body."""
        obj = Basic.__new__(cls, *args)
        #use dummy variables internally, just to be sure
        nargs = len(args)

        expression = args[nargs-1]
        funargs = [Symbol(arg.name, dummy=True) for arg in args[:nargs-1]]
        #probably could use something like foldl here
        for arg,funarg in zip(args[:nargs-1],funargs):
            expression = expression.subs(arg,funarg)
        funargs.append(expression)
        obj._args = tuple(funargs)

        return obj

    def apply(self, *args):
        """Applies the Lambda function "self" to the arguments given.
        This supports partial application.

        Example:
        >>> from sympy import Symbol
        >>> x = Symbol('x')
        >>> y = Symbol('y')
        >>> f = Lambda(x, x**2)
        >>> f.apply(4)
        16
        >>> sum2numbers = Lambda(x,y,x+y)
        >>> sum2numbers(1,2)
        3
        >>> plus1 = sum2numbers(1)
        >>> plus1(3)
        4

        """
        nparams = self.nargs - 1
        assert nparams >= len(args),"Cannot call function with more parameters than function variables: %s (%d variables) called with %d arguments" % (str(self),nparams,len(args))

        #replace arguments
        expression = self.args[self.nargs-1]
        for arg,funarg in zip(args,self.args[:nparams]):
            expression = expression.subs(funarg,arg)

        #curry the rest
        if nparams != len(args):
            unused_args = list(self.args[len(args):nparams])
            unused_args.append(expression)
            return Lambda(*tuple(unused_args))
        return expression

    def __call__(self, *args):
        return self.apply(*args)

    def __eq__(self, other):
        """Alpha-equivalence: equal when both have the same number of
        parameters and the bodies agree after renaming other's parameters
        to self's."""
        if isinstance(other, Lambda):
            if not len(self.args) == len(other.args):
                return False

            selfexpr = self.args[self.nargs-1]
            otherexpr = other.args[other.nargs-1]
            for selfarg,otherarg in zip(self.args[:self.nargs-1],other.args[:other.nargs-1]):
                otherexpr = otherexpr.subs(otherarg,selfarg)
            if selfexpr == otherexpr:
                return True
        return False

    def __ne__(self, other):
        # Keep != consistent with the alpha-equivalence __eq__ above instead
        # of falling back to an inherited structural comparison.
        return not self.__eq__(other)

    # Python 3 sets __hash__ to None in a class that defines __eq__ without
    # __hash__, which would make Lambdas unhashable (and break cacheit).
    # Keep the inherited hash, as Python 2 did.
    # NOTE(review): this hash is structural, so alpha-equivalent Lambdas
    # built from different parameter names may hash differently.
    __hash__ = Basic.__hash__
@vectorize(0, 1, 2)
def diff(f, x, times = 1, evaluate=True):
    """Differentiate f with respect to x, `times` times.

    A thin wrapper that builds Derivative(f, x, times) with the given
    `evaluate` flag, unifying .diff() and the Derivative class. Its
    interface is similar to that of integrate().

    See http://documents.wolfram.com/v5/Built-inFunctions/AlgebraicComputation/Calculus/D.html
    """
    return Derivative(f, x, times, evaluate=evaluate)
@vectorize(0)
def expand(e, **hints):
    """
    Expand an expression using hints.

    Converts e with sympify() and calls its .expand() method, passing all
    hints through unchanged. See the docstring of Basic.expand() for a
    thorough description of the hints; in isympy you can just type
    Basic.expand? and press enter.
    """
    expr = sympify(e)
    return expr.expand(**hints)
# /cyclic/
# Inject classes defined in this module as attributes of other core modules.
# Presumably this breaks circular imports (hence the "/cyclic/" marker):
# those modules reference these names but do not import them from here.
# Each module is bound to the throwaway name `_`, patched, and `_` is deleted
# again so no module names leak into this module's namespace.
# NOTE(review): these are Python 2 implicit relative imports of sibling
# sympy.core modules; under Python 3, `import numbers` would resolve to the
# standard-library `numbers` module instead.
import basic as _
_.Derivative = Derivative
_.FunctionClass = FunctionClass
del _
import add as _
_.FunctionClass = FunctionClass
del _
import mul as _
_.FunctionClass = FunctionClass
_.WildFunction = WildFunction
del _
import operations as _
_.Lambda = Lambda
_.WildFunction = WildFunction
del _
import symbol as _
_.Function = Function
_.WildFunction = WildFunction
del _
import numbers as _
_.FunctionClass = FunctionClass
del _