-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
mul.py
1640 lines (1402 loc) · 53.2 KB
/
mul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import print_function, division
from collections import defaultdict
import operator
from functools import cmp_to_key
from .sympify import sympify
from .basic import Basic
from .singleton import S
from .operations import AssocOp
from .cache import cacheit
from .logic import fuzzy_not, _fuzzy_group
from .compatibility import reduce, range
from .expr import Expr
# Internal sentinel meaning "there are still non-commutative objects
# scheduled -- don't forget to process them".
class NC_Marker:
    """Sentinel that Mul.flatten appends to its work list after queueing
    the non-commutative args of a Mul; it is never kept as a factor."""
    # Every "is_*" flag flatten inspects is False, so the marker falls
    # through to flatten's non-commutative branch.
    is_commutative = False
    is_Mul = False
    is_Number = False
    is_Order = False
    is_Poly = False
# Sort key that orders commutative args canonically via Basic.compare.
_args_sortkey = cmp_to_key(Basic.compare)


def _mulsort(args):
    """Reorder the list ``args`` in place into canonical order.

    The same list object is mutated; nothing is returned.
    """
    args[:] = sorted(args, key=_args_sortkey)
def _unevaluated_Mul(*args):
    """Build a well-formed Mul without evaluating it.

    Numeric factors are multiplied together and placed in slot 0, nested
    Muls are flattened into their factors, and the commutative factors are
    put in canonical order. Non-commutative parts of nested Muls are kept
    (as Muls) at the end. Use this when args have changed but you still
    want an unevaluated Mul back.

    Examples
    ========

    >>> from sympy.core.mul import _unevaluated_Mul as uMul
    >>> from sympy import S, sqrt, Mul
    >>> from sympy.abc import x
    >>> a = uMul(*[S(3.0), x, S(2)])
    >>> a.args[0]
    6.00000000000000
    >>> a.args[1]
    x

    Two unevaluated Muls with the same arguments will
    always compare as equal during testing:

    >>> m = uMul(sqrt(2), sqrt(3))
    >>> m == uMul(sqrt(3), sqrt(2))
    True
    >>> u = Mul(sqrt(3), sqrt(2), evaluate=False)
    >>> m == uMul(u)
    True
    >>> m == Mul(*m.args)
    False
    """
    # Work stack: items are taken from the end; the commutative factors of
    # nested Muls are pushed back on so they get processed too.
    pending = list(args)
    commutative = []
    noncommutative = []
    coeff = S.One
    while pending:
        term = pending.pop()
        if term.is_Mul:
            c, nc = term.args_cnc()
            pending.extend(c)
            if nc:
                noncommutative.append(Mul._from_args(nc))
            continue
        if term.is_Number:
            coeff *= term
            continue
        commutative.append(term)
    _mulsort(commutative)
    if coeff is not S.One:
        commutative.insert(0, coeff)
    if noncommutative:
        commutative.append(Mul._from_args(noncommutative))
    return Mul._from_args(commutative)
class Mul(Expr, AssocOp):
__slots__ = []
is_Mul = True
@classmethod
def flatten(cls, seq):
    """Return commutative, noncommutative and order arguments by
    combining related terms.

    Returns a 3-tuple ``(c_part, nc_part, order_symbols)``: the list of
    commutative factors (with any numeric coefficient in slot 0), the list
    of non-commutative factors, and the order symbols collected from any
    Order objects in ``seq`` (or None).

    ``seq`` must be a mutable list: new factors produced while processing
    are appended to it so the same loop picks them up.

    Notes
    =====

    * In an expression like ``a*b*c``, Python processes this through SymPy
      as ``Mul(Mul(a, b), c)``. This can have undesirable consequences.

      - Sometimes terms are not combined as one would like:
        {c.f. https://github.com/sympy/sympy/issues/4596}

        >>> from sympy import Mul, sqrt
        >>> from sympy.abc import x, y, z
        >>> 2*(x + 1) # this is the 2-arg Mul behavior
        2*x + 2
        >>> y*(x + 1)*2
        2*y*(x + 1)
        >>> 2*(x + 1)*y # 2-arg result will be obtained first
        y*(2*x + 2)
        >>> Mul(2, x + 1, y) # all 3 args simultaneously processed
        2*y*(x + 1)
        >>> 2*((x + 1)*y) # parentheses can control this behavior
        2*y*(x + 1)

        Powers with compound bases may not find a single base to
        combine with unless all arguments are processed at once.
        Post-processing may be necessary in such cases.
        {c.f. https://github.com/sympy/sympy/issues/5728}

        >>> a = sqrt(x*sqrt(y))
        >>> a**3
        (x*sqrt(y))**(3/2)
        >>> Mul(a,a,a)
        (x*sqrt(y))**(3/2)
        >>> a*a*a
        x*sqrt(y)*sqrt(x*sqrt(y))
        >>> _.subs(a.base, z).subs(z, a.base)
        (x*sqrt(y))**(3/2)

      - If more than two terms are being multiplied then all the
        previous terms will be re-processed for each new argument.
        So if each of ``a``, ``b`` and ``c`` were :class:`Mul`
        expression, then ``a*b*c`` (or building up the product
        with ``*=``) will process all the arguments of ``a`` and
        ``b`` twice: once when ``a*b`` is computed and again when
        ``c`` is multiplied.

        Using ``Mul(a, b, c)`` will process all arguments once.

    * The results of Mul are cached according to arguments, so flatten
      will only be called once for ``Mul(a, b, c)``. If you can
      structure a calculation so the arguments are most likely to be
      repeats then this can save time in computing the answer. For
      example, say you had a Mul, M, that you wished to divide by ``d[i]``
      and multiply by ``n[i]`` and you suspect there are many repeats
      in ``n``. It would be better to compute ``M*n[i]/d[i]`` rather
      than ``M/d[i]*n[i]`` since every time n[i] is a repeat, the
      product, ``M*n[i]`` will be returned without flattening -- the
      cached value will be returned. If you divide by the ``d[i]``
      first (and those are more unique than the ``n[i]``) then that will
      create a new Mul, ``M/d[i]`` the args of which will be traversed
      again when it is multiplied by ``n[i]``.

      {c.f. https://github.com/sympy/sympy/issues/5706}

      This consideration is moot if the cache is turned off.

    NB
    --

    The validity of the above notes depends on the implementation
    details of Mul and flatten which may change at any time. Therefore,
    you should only consider them when your code is highly performance
    sensitive.
    """

    rv = None
    # Fast path for exactly two factors where one is a (nonzero) Rational
    # and the other (after removing its own coefficient) is an Add.
    if len(seq) == 2:
        a, b = seq
        # make sure a Rational, if present, ends up in ``a``
        if b.is_Rational:
            a, b = b, a
        assert a is not S.One
        if not a.is_zero and a.is_Rational:
            r, b = b.as_coeff_Mul()
            if b.is_Add:
                if r is not S.One: # 2-arg hack
                    if a*r is S.One:
                        rv = [b], [], None
                    else:
                        # leave the Mul as a Mul
                        rv = [cls(a*r, b, evaluate=False)], [], None
                elif b.is_commutative:
                    if a is S.One:
                        rv = [b], [], None
                    else:
                        # distribute ``a`` over the terms of ``b``,
                        # e.g. 2*(x + 1) -> 2*x + 2
                        r, b = b.as_coeff_Add()
                        bargs = [_keep_coeff(a, bi) for bi in Add.make_args(b)]
                        _addsort(bargs)
                        ar = a*r
                        if ar:
                            bargs.insert(0, ar)
                        bargs = [Add._from_args(bargs)]
                        rv = bargs, [], None
        if rv:
            return rv

    # apply associativity, separate commutative part of seq
    c_part = []         # out: commutative factors
    nc_part = []        # out: non-commutative factors

    nc_seq = []         # non-commutative factors still to be processed

    coeff = S.One       # standalone term
                        # e.g. 3 * ...

    c_powers = []       # (base, exp) pairs
                        # e.g. (x, n) for x**n

    num_exp = []        # (num-base, exp) pairs
                        # e.g. (3, y) for ... * 3**y * ...

    neg1e = S.Zero      # exponent on -1 extracted from Number-based Pow and I

    pnum_rat = {}       # num-base -> list of Rational exponents
                        # e.g. {3: [1/2]} for ... * 3**(1/2) * ...

    order_symbols = None

    # --- PART 1 ---
    #
    # "collect powers and coeff":
    #
    # o coeff
    # o c_powers
    # o num_exp
    # o neg1e
    # o pnum_rat
    #
    # NOTE: this is optimized for all-objects-are-commutative case
    # NOTE: ``seq`` may grow while this loop runs (nested Mul args,
    # re-queued powers and NC_Marker are appended to it).
    for o in seq:
        # O(x)
        if o.is_Order:
            o, order_symbols = o.as_expr_variables(order_symbols)

        # Mul([...])
        if o.is_Mul:
            if o.is_commutative:
                seq.extend(o.args)    # XXX zerocopy?

            else:
                # NCMul can have commutative parts as well
                for q in o.args:
                    if q.is_commutative:
                        seq.append(q)
                    else:
                        nc_seq.append(q)

                # append non-commutative marker, so we don't forget to
                # process scheduled non-commutative objects
                seq.append(NC_Marker)

            continue

        # 3
        elif o.is_Number:
            if o is S.NaN or coeff is S.ComplexInfinity and o is S.Zero:
                # we know for sure the result will be nan
                return [S.NaN], [], None
            elif coeff.is_Number:  # it could be zoo
                coeff *= o
                if coeff is S.NaN:
                    # we know for sure the result will be nan
                    return [S.NaN], [], None
            continue

        elif o is S.ComplexInfinity:
            if not coeff:
                # 0 * zoo = NaN
                return [S.NaN], [], None
            if coeff is S.ComplexInfinity:
                # zoo * zoo = zoo
                return [S.ComplexInfinity], [], None
            coeff = S.ComplexInfinity
            continue

        elif o is S.ImaginaryUnit:
            # I is tracked as (-1)**(1/2)
            neg1e += S.Half
            continue

        elif o.is_commutative:
            # o = b**e
            b, e = o.as_base_exp()

            # numeric-base power, e.g. 3**y
            if o.is_Pow:
                if b.is_Number:

                    # get all the factors with numeric base so they can be
                    # combined below, but don't combine negatives unless
                    # the exponent is an integer
                    if e.is_Rational:
                        if e.is_Integer:
                            coeff *= Pow(b, e)  # it is an unevaluated power
                            continue
                        elif e.is_negative:    # also a sign of an unevaluated power
                            seq.append(Pow(b, e))
                            continue
                        elif b.is_negative:
                            # move the sign into the power of -1
                            neg1e += e
                            b = -b
                        if b is not S.One:
                            pnum_rat.setdefault(b, []).append(e)
                        continue
                    elif b.is_positive or e.is_integer:
                        num_exp.append((b, e))
                        continue

                elif b is S.ImaginaryUnit and e.is_Rational:
                    # I**e == (-1)**(e/2)
                    neg1e += e/2
                    continue

            c_powers.append((b, e))

        # NON-COMMUTATIVE
        # TODO: Make non-commutative exponents not combine automatically
        else:
            if o is not NC_Marker:
                nc_seq.append(o)

            # process nc_seq (if any)
            while nc_seq:
                o = nc_seq.pop(0)
                if not nc_part:
                    nc_part.append(o)
                    continue

                # try to combine last terms: a**b * a**c -> a**(b + c)
                o1 = nc_part.pop()
                b1, e1 = o1.as_base_exp()
                b2, e2 = o.as_base_exp()
                new_exp = e1 + e2
                # Only allow powers to combine if the new exponent is
                # not an Add. This allow things like a**2*b**3 == a**5
                # if a.is_commutative == False, but prohibits
                # a**x*a**y and x**a*x**b from combining (x,y commute).
                if b1 == b2 and (not new_exp.is_Add):
                    o12 = b1 ** new_exp

                    # now o12 could be a commutative object
                    if o12.is_commutative:
                        seq.append(o12)
                        continue
                    else:
                        nc_seq.insert(0, o12)

                else:
                    nc_part.append(o1)
                    nc_part.append(o)

    # We do want a combined exponent if it would not be an Add, such as
    #   x**y * x**(2*y) -> x**(3*y)
    # We determine if two exponents have the same term by using
    # as_coeff_Mul.
    #
    # Unfortunately, this isn't smart enough to consider combining into
    # exponents that might already be adds, so things like:
    #   x**(z - y) * x**y
    # will be left alone. This is because checking every possible
    # combination can slow things down.

    # gather exponents of common bases...
    def _gather(c_powers):
        # Group (base, exp) pairs by base and by the non-coefficient part
        # of exp, summing the numeric coefficients within each group.
        new_c_powers = []
        common_b = {} # b:e
        for b, e in c_powers:
            co = e.as_coeff_Mul()
            common_b.setdefault(b, {}).setdefault(co[1], []).append(co[0])
        for b, d in common_b.items():
            for di, li in d.items():
                d[di] = Add(*li)
        for b, e in common_b.items():
            for t, c in e.items():
                new_c_powers.append((b, c*t))
        return new_c_powers

    # in c_powers
    c_powers = _gather(c_powers)

    # and in num_exp
    num_exp = _gather(num_exp)

    # --- PART 2 ---
    #
    # o process collected powers  (x**0 -> 1; x**1 -> x; otherwise Pow)
    # o combine collected powers  (2**x * 3**x -> 6**x)
    #   with numeric base

    # ................................
    # now we have:
    # - coeff:
    # - c_powers:    (b, e)
    # - num_exp:     (2, e)
    # - pnum_rat:    {1/3: [1/3, 2/3, 1/4]}

    #  x**0 -> 1        x**1 -> x
    for b, e in c_powers:
        if e is S.One:
            if b.is_Number:
                coeff *= b
            else:
                c_part.append(b)
        elif e is not S.Zero:
            c_part.append(Pow(b, e))

    #  2**x * 3**x -> 6**x
    inv_exp_dict = {}   # exp -> Mul(num-bases)
                        # e.g. x: 6 for ... * 2**x * 3**x * ...
    for b, e in num_exp:
        inv_exp_dict.setdefault(e, []).append(b)
    for e, b in inv_exp_dict.items():
        inv_exp_dict[e] = cls(*b)
    c_part.extend([Pow(b, e) for e, b in inv_exp_dict.items() if e])

    # b, e -> e' = sum(e), b
    # {1/5: [1/3], 1/2: [1/12, 1/4]} -> {1/3: [1/5, 1/2]}
    comb_e = {}
    for b, e in pnum_rat.items():
        comb_e.setdefault(Add(*e), []).append(b)
    del pnum_rat
    # process them, reducing exponents to values less than 1
    # and updating coeff if necessary else adding them to
    # num_rat for further processing
    num_rat = []
    for e, b in comb_e.items():
        b = cls(*b)
        if e.q == 1:
            coeff *= Pow(b, e)
            continue
        if e.p > e.q:
            # split off the integer part of the exponent into coeff
            e_i, ep = divmod(e.p, e.q)
            coeff *= Pow(b, e_i)
            e = Rational(ep, e.q)
        num_rat.append((b, e))
    del comb_e

    # extract gcd of bases in num_rat
    # 2**(1/3)*6**(1/4) -> 2**(1/3+1/4)*3**(1/4)
    pnew = defaultdict(list)
    i = 0  # steps through num_rat which may grow
    while i < len(num_rat):
        bi, ei = num_rat[i]
        grow = []
        for j in range(i + 1, len(num_rat)):
            bj, ej = num_rat[j]
            g = bi.gcd(bj)
            if g is not S.One:
                # 4**r1*6**r2 -> 2**(r1+r2)  *  2**r1 *  3**r2
                # this might have a gcd with something else
                e = ei + ej
                if e.q == 1:
                    coeff *= Pow(g, e)
                else:
                    if e.p > e.q:
                        e_i, ep = divmod(e.p, e.q)  # change e in place
                        coeff *= Pow(g, e_i)
                        e = Rational(ep, e.q)
                    grow.append((g, e))
                # update the jth item
                num_rat[j] = (bj/g, ej)
                # update bi that we are checking with
                bi = bi/g
                if bi is S.One:
                    break
        if bi is not S.One:
            obj = Pow(bi, ei)
            if obj.is_Number:
                coeff *= obj
            else:
                # changes like sqrt(12) -> 2*sqrt(3)
                for obj in Mul.make_args(obj):
                    if obj.is_Number:
                        coeff *= obj
                    else:
                        assert obj.is_Pow
                        bi, ei = obj.args
                        pnew[ei].append(bi)

        num_rat.extend(grow)
        i += 1

    # combine bases of the new powers
    for e, b in pnew.items():
        pnew[e] = cls(*b)

    # handle -1 and I
    if neg1e:
        # treat I as (-1)**(1/2) and compute -1's total exponent
        p, q = neg1e.as_numer_denom()
        # if the integer part is odd, extract -1
        n, p = divmod(p, q)
        if n % 2:
            coeff = -coeff
        # if it's a multiple of 1/2 extract I
        if q == 2:
            c_part.append(S.ImaginaryUnit)
        elif p:
            # see if there is any positive base this power of
            # -1 can join
            neg1e = Rational(p, q)
            for e, b in pnew.items():
                if e == neg1e and b.is_positive:
                    pnew[e] = -b
                    break
            else:
                # keep it separate; we've already evaluated it as
                # much as possible so evaluate=False
                c_part.append(Pow(S.NegativeOne, neg1e, evaluate=False))

    # add all the pnew powers
    c_part.extend([Pow(b, e) for e, b in pnew.items()])

    # oo, -oo
    if (coeff is S.Infinity) or (coeff is S.NegativeInfinity):
        # Drop factors of known sign, folding negative ones into the sign
        # of the infinite coefficient.
        def _handle_for_oo(c_part, coeff_sign):
            new_c_part = []
            for t in c_part:
                if t.is_positive:
                    continue
                if t.is_negative:
                    coeff_sign *= -1
                    continue
                new_c_part.append(t)
            return new_c_part, coeff_sign
        c_part, coeff_sign = _handle_for_oo(c_part, 1)
        nc_part, coeff_sign = _handle_for_oo(nc_part, coeff_sign)
        coeff *= coeff_sign

    # zoo
    if coeff is S.ComplexInfinity:
        # zoo might be
        #   infinite_real + bounded_im
        #   bounded_real + infinite_im
        #   infinite_real + infinite_im
        # and non-zero real or imaginary will not change that status.
        c_part = [c for c in c_part if not (c.is_nonzero and
                                            c.is_extended_real is not None)]
        nc_part = [c for c in nc_part if not (c.is_nonzero and
                                              c.is_extended_real is not None)]

    # 0
    elif coeff is S.Zero:
        # we know for sure the result will be 0 except the multiplicand
        # is infinity
        if any(c.is_finite is False for c in c_part):
            return [S.NaN], [], order_symbols
        return [coeff], [], order_symbols

    # check for straggling Numbers that were produced
    _new = []
    for i in c_part:
        if i.is_Number:
            coeff *= i
        else:
            _new.append(i)
    c_part = _new

    # order commutative part canonically
    _mulsort(c_part)

    # current code expects coeff to be always in slot-0
    if coeff is not S.One:
        c_part.insert(0, coeff)

    # we are done
    if (not nc_part and len(c_part) == 2 and c_part[0].is_Number and
            c_part[1].is_Add):
        # 2*(1+a) -> 2 + 2 * a
        coeff = c_part[0]
        c_part = [Add(*[coeff*f for f in c_part[1].args])]

    return c_part, nc_part, order_symbols
def _eval_power(b, e):
    """Raise the product ``b`` to the power ``e``.

    Only the commutative factors are distributed over the power; the
    non-commutative factors stay grouped because (A*B)**3 is
    A*B*A*B*A*B, not A**3*B**3.
    """
    cargs, nc = b.args_cnc(split_1=False)
    if e.is_Integer:
        distributed = Mul(*[Pow(base, e, evaluate=False) for base in cargs])
        return distributed * Pow(Mul._from_args(nc), e, evaluate=False)
    unexpanded = Pow(b, e, evaluate=False)
    if not (e.is_Rational or e.is_Float):
        return unexpanded
    return unexpanded._eval_expand_power_base()
@classmethod
def class_key(cls):
    """Return the class-level key tuple ``(3, 0, <class name>)``."""
    name = cls.__name__
    return (3, 0, name)
def _eval_evalf(self, prec):
    """Numerically evaluate the product to ``prec``.

    A leading -1 coefficient is stripped off, the rest is evaluated, and
    the result is negated. Numeric results are expanded before returning.
    """
    coeff, rest = self.as_coeff_Mul()
    if coeff is not S.NegativeOne:
        result = AssocOp._eval_evalf(self, prec)
    elif rest.is_Mul:
        result = -AssocOp._eval_evalf(rest, prec)
    else:
        evaluated = rest._eval_evalf(prec)
        # fall back to the unevaluated factor when evalf gives nothing
        result = -(rest if evaluated is None else evaluated)
    return result.expand() if result.is_number else result
@cacheit
def as_two_terms(self):
    """Return head and tail of self.

    This is the most efficient way to get the head and tail of an
    expression.

    - if you want only the head, use self.args[0];
    - if you want to process the arguments of the tail then use
      self.as_coeff_mul() which gives the leading numeric coefficient
      (or 1 if there is none) and a tuple containing the remaining
      factors.
    - if you want the coefficient when self is treated as an Add
      then use self.as_coeff_add()[0]

    A single-argument product returns ``(S.One, self)``; a two-argument
    product returns its args unchanged; otherwise the tail is rebuilt
    from the remaining args with ``_new_rawargs`` (no re-evaluation).

    >>> from sympy.abc import x, y
    >>> (3*x*y).as_two_terms()
    (3, x*y)
    """
    args = self.args

    if len(args) == 1:
        return S.One, self
    elif len(args) == 2:
        return args
    else:
        return args[0], self._new_rawargs(*args[1:])
@cacheit
def as_coeff_mul(self, *deps, **kwargs):
    """Split self into a leading part and a tuple of factors.

    With ``deps``, the factors that contain none of ``deps`` are rebuilt
    into the first element (via ``_new_rawargs``) and the factors that do
    contain them are returned as a tuple.

    Without ``deps``, a leading Number is peeled off. When ``rational``
    is True (the default), only a Rational is peeled; a negative
    non-Rational Number yields ``-1`` with its negation kept as the
    first factor. Otherwise ``(S.One, self.args)`` is returned.
    """
    rational = kwargs.pop('rational', True)
    if deps:
        independent, dependent = [], []
        for factor in self.args:
            (dependent if factor.has(*deps) else independent).append(factor)
        return self._new_rawargs(*independent), tuple(dependent)
    lead = self.args[0]
    if lead.is_Number:
        if not rational or lead.is_Rational:
            return lead, self.args[1:]
        if lead.is_negative:
            return S.NegativeOne, (-lead,) + self.args[1:]
    return S.One, self.args
def as_coeff_Mul(self, rational=False):
    """Efficiently extract the coefficient of a product.

    Returns ``(coeff, rest)``. If the first arg is a Number (a Rational
    when ``rational`` is True), it becomes ``coeff``. If ``rational`` is
    True and the first arg is a negative non-Rational Number, ``-1`` is
    returned with the negated number kept in ``rest``. Otherwise the
    result is ``(S.One, self)``.
    """
    head, tail = self.args[0], self.args[1:]
    if not head.is_Number:
        return S.One, self
    if rational and not head.is_Rational:
        if head.is_negative:
            return S.NegativeOne, self._new_rawargs(*((-head,) + tail))
        return S.One, self
    rest = tail[0] if len(tail) == 1 else self._new_rawargs(*tail)
    return head, rest
def as_real_imag(self, deep=True, **hints):
    """Return ``(re, im)`` of the product.

    Factors are sorted into known-real, known-imaginary, Add factors
    (expanded at the end), and "other". A commutative factor whose complex
    conjugate was already seen among "other" is paired with it and
    contributes ``Abs(x)**2`` as a real factor. If ``hints['ignore']``
    equals the product of the "other" factors, None is returned.
    """
    from sympy import Abs, expand_mul, im, re

    other = []
    real_factors = []
    imag_factors = []
    add_part = S.One
    for factor in self.args:
        if factor.is_extended_real:
            real_factors.append(factor)
            continue
        if factor.is_imaginary:
            imag_factors.append(factor)
            continue
        if not factor.is_commutative:
            other.append(factor)
            continue
        # search for complex conjugate pairs:
        for idx, seen in enumerate(other):
            if seen == factor.conjugate():
                real_factors.append(Abs(seen)**2)
                del other[idx]
                break
        else:
            if factor.is_Add:
                add_part *= factor
            else:
                other.append(factor)

    m = self.func(*other)
    if hints.get('ignore') == m:
        return

    # An odd number of imaginary factors leaves one imaginary unit behind;
    # the remaining ones pair up into a real factor folded into reco.
    imco = im(imag_factors.pop(0)) if len(imag_factors) % 2 else S.Zero
    reco = self.func(*(real_factors + imag_factors))
    r, i = reco*re(m), reco*im(m)

    if add_part == 1:
        if m == 1:
            return (reco, S.Zero) if imco is S.Zero else (S.Zero, reco*imco)
        return (r, i) if imco is S.Zero else (-imco*i, imco*r)

    addre, addim = expand_mul(add_part, deep=False).as_real_imag()
    if imco is not S.Zero:
        r, i = -imco*i, imco*r
        return (r*addre - i*addim, r*addim + i*addre)
    return (r*addre - i*addim, i*addre + r*addim)
@staticmethod
def _expandsums(sums):
    """Helper for _eval_expand_mul.

    ``sums`` is a list of Basic instances (Adds, or wrapped single
    factors). Returns the tuple of terms of their expanded product,
    computed by recursively splitting the list in half.
    """
    count = len(sums)
    if count == 1:
        return sums[0].args
    half = count // 2
    left = Mul._expandsums(sums[:half])
    right = Mul._expandsums(sums[half:])
    products = [Mul(lt, rt) for lt in left for rt in right]
    # the sum may collapse to a single term, hence make_args
    return Add.make_args(Add(*products))
def _eval_expand_mul(self, **hints):
    """Distribute the product over any Add factors it contains.

    Non-commutative factors are wrapped in ``Basic(...)`` so that
    ``_expandsums`` treats each as a one-term sum and keeps its position.
    """
    from sympy import fraction

    # Handle things like 1/(x*(x + 1)), which are automatically converted
    # to 1/x*1/(x + 1): expand numerator and denominator separately first.
    numer, denom = fraction(self)
    if denom.is_Mul:
        numer, denom = [part._eval_expand_mul(**hints) if part.is_Mul else part
                        for part in (numer, denom)]
    expr = numer/denom
    if not expr.is_Mul:
        return expr

    plain, sums = [], []
    has_add = False
    for factor in expr.args:
        if factor.is_Add:
            sums.append(factor)
            has_add = True
        elif factor.is_commutative:
            plain.append(factor)
        else:
            sums.append(Basic(factor))  # Wrapper
    if not has_add:
        return expr

    plain = self.func(*plain)
    if not sums:
        return plain
    expanded = []
    for term in self.func._expandsums(sums):
        product = self.func(plain, term)
        if product.is_Mul and any(arg.is_Add for arg in product.args):
            product = product._eval_expand_mul()
        expanded.append(product)
    return Add(*expanded)
@cacheit
def _eval_derivative(self, s):
    """Product rule: sum over factors of the product with that factor
    replaced by its derivative with respect to ``s`` (zero derivatives
    are skipped)."""
    factors = list(self.args)
    pieces = []
    for idx, factor in enumerate(factors):
        dfactor = factor.diff(s)
        if dfactor:
            replaced = factors[:idx] + [dfactor] + factors[idx + 1:]
            pieces.append(self.func(*replaced))
    return Add(*pieces)
def _matches_simple(self, expr, repl_dict):
    """Match a pattern of the form coeff*single_factor against ``expr``.

    e.g. (w*3).matches('x*5') -> {w: x*5/3}. Returns None when the
    non-coefficient part has more than one factor.
    """
    coeff, rest = self.as_coeff_Mul()
    factors = Mul.make_args(rest)
    if len(factors) != 1:
        return None
    target = self.__class__._combine_inverse(expr, coeff)
    return factors[0].matches(target, repl_dict)
def matches(self, expr, repl_dict=None):
    """Match the pattern ``self`` against ``expr``.

    Parameters
    ==========

    expr : the expression to match (sympified first).
    repl_dict : optional dict of existing bindings; it is never mutated.
        Defaults to a fresh empty dict (``None`` replaces the former
        shared mutable ``{}`` default).

    Returns a dict of bindings, or None if there is no match (an empty
    result is also reported as None).
    """
    if repl_dict is None:
        repl_dict = {}
    expr = sympify(expr)
    if self.is_commutative and expr.is_commutative:
        return AssocOp._matches_commutative(self, expr, repl_dict)
    elif self.is_commutative is not expr.is_commutative:
        return None

    # Match the commutative parts first, then the non-commutative parts.
    c1, nc1 = self.args_cnc()
    c2, nc2 = expr.args_cnc()
    repl_dict = repl_dict.copy()
    if c1:
        if not c2:
            c2 = [1]
        a = self.func(*c1)
        if isinstance(a, AssocOp):
            repl_dict = a._matches_commutative(self.func(*c2), repl_dict)
        else:
            repl_dict = a.matches(self.func(*c2), repl_dict)
    # None means the commutative part failed. An empty dict is a successful
    # match with no bindings and must still go on to match the
    # non-commutative part; the old truthiness test skipped it.
    if repl_dict is not None:
        a = self.func(*nc1)
        if isinstance(a, self.func):
            repl_dict = a._matches(self.func(*nc2), repl_dict)
        else:
            repl_dict = a.matches(self.func(*nc2), repl_dict)
    return repl_dict or None
def _matches(self, expr, repl_dict=None):
    """Match this (typically non-commutative) product against ``expr``.

    A leading -1 on either side is stripped and tracked in ``sign``.
    Factors common to both sides are removed. If a single Wild remains in
    the pattern, it absorbs the rest of ``expr``; otherwise the remaining
    factors are matched pairwise in order.

    ``repl_dict`` is copied, never mutated. It defaults to a fresh dict
    (``None`` replaces the former shared mutable ``{}`` default). Returns
    a dict of bindings or None.
    """
    from sympy import Wild
    if repl_dict is None:
        repl_dict = {}
    # weed out negative one prefixes
    sign = 1
    a, b = self.as_two_terms()
    if a is S.NegativeOne:
        if b.is_Mul:
            sign = -sign
        else:
            # the remainder, b, is not a Mul anymore
            return b.matches(-expr, repl_dict)
    expr = sympify(expr)
    if expr.is_Mul and expr.args[0] is S.NegativeOne:
        expr = -expr
        sign = -sign

    if not expr.is_Mul:
        # expr can only match if it matches b and a matches +/- 1
        if len(self.args) == 2:
            # quickly test for equality
            if b == expr:
                return a.matches(Rational(sign), repl_dict)
            # do more expensive match
            dd = b.matches(expr, repl_dict)
            if dd is None:
                return None
            return a.matches(Rational(sign), dd)
        return None

    d = repl_dict.copy()

    # Weed out identical terms. Test against the *remaining* expr factors
    # (ee), not expr.args: a factor may repeat in the pattern more often
    # than in expr (e.g. non-commutative A*B*A vs A*C), and testing
    # expr.args made ee.remove() raise ValueError on the second occurrence.
    pp = list(self.args)
    ee = list(expr.args)
    for p in self.args:
        if p in ee:
            ee.remove(p)
            pp.remove(p)

    # only one symbol left in pattern -> match the remaining expression
    if len(pp) == 1 and isinstance(pp[0], Wild):
        if len(ee) == 1:
            d[pp[0]] = sign * ee[0]
        else:
            d[pp[0]] = sign * expr.func(*ee)
        return d

    if len(ee) != len(pp):
        return None

    for p, e in zip(pp, ee):
        d = p.xreplace(d).matches(e, d)
        if d is None:
            return None
    return d
@staticmethod
def _combine_inverse(lhs, rhs):
    """
    Returns lhs/rhs, but treats arguments like symbols, so things like
    oo/oo return 1, instead of a nan.
    """
    if lhs == rhs:
        return S.One

    def _float_equiv(first, second):
        # If both objects are added to 0 they share the same
        # "normalization" and are more likely to compare the same. Since
        # Add(foo, 0) would drop the 0, __add__ is called directly.
        if not (first.is_Float and second.is_comparable):
            return False
        return first.__add__(0) == second.evalf().__add__(0)

    if _float_equiv(lhs, rhs) or _float_equiv(rhs, lhs):
        return S.One
    if not (lhs.is_Mul and rhs.is_Mul):
        return lhs/rhs

    # cancel factors of rhs found in lhs (possibly up to sign)
    numer = list(lhs.args)
    denom = [1]
    for x in rhs.args:
        if x in numer:
            numer.remove(x)
        elif -x in numer:
            numer.remove(-x)
            denom.append(-1)
        else:
            denom.append(x)
    return lhs.func(*numer)/rhs.func(*denom)
def as_powers_dict(self):
    """Return a defaultdict(int) mapping each factor's base to the sum of
    its exponents across all factors."""
    powers = defaultdict(int)
    for base, exp in (arg.as_base_exp() for arg in self.args):
        powers[base] += exp
    return powers
def as_numer_denom(self):
    """Return ``(numerator, denominator)`` built from each factor's own
    numerator/denominator split."""
    # Rebuild with self.func rather than _from_args: once numerators and
    # denominators are separated, their canonical order is not guaranteed
    # to match the original.
    numers, denoms = zip(*[arg.as_numer_denom() for arg in self.args])
    return self.func(*numers), self.func(*denoms)
def as_base_exp(self):
    """Return ``(Mul(*bases), e)`` when every factor has the same exponent
    ``e`` and at most one base is non-commutative; else ``(self, 1)``."""
    shared_exp = None
    collected = []
    n_noncomm = 0
    for base, exp in (arg.as_base_exp() for arg in self.args):
        if not base.is_commutative:
            n_noncomm += 1
        if shared_exp is not None and (exp != shared_exp or n_noncomm > 1):
            return self, S.One
        if shared_exp is None:
            shared_exp = exp
        collected.append(base)
    return self.func(*collected), shared_exp
def _eval_is_polynomial(self, syms):
    """True only if every factor is polynomial in ``syms``."""
    for factor in self.args:
        if not factor._eval_is_polynomial(syms):
            return False
    return True
def _eval_is_rational_function(self, syms):
    """True only if every factor is a rational function of ``syms``."""
    for factor in self.args:
        if not factor._eval_is_rational_function(syms):
            return False
    return True
def _eval_is_algebraic_expr(self, syms):
    """True only if every factor is an algebraic expression in ``syms``."""
    for factor in self.args:
        if not factor._eval_is_algebraic_expr(syms):
            return False
    return True
def _eval_is_finite(self):
    """Fuzzy-combine the ``is_finite`` flags of all factors."""
    flags = (factor.is_finite for factor in self.args)
    return _fuzzy_group(flags)
def _eval_is_commutative(self):
    """Fuzzy-combine the ``is_commutative`` flags of all factors."""
    flags = (factor.is_commutative for factor in self.args)
    return _fuzzy_group(flags)
def _eval_is_complex(self):
    """Fuzzy-combine the ``is_complex`` flags of all factors (quick exit)."""
    flags = (factor.is_complex for factor in self.args)
    return _fuzzy_group(flags, quick_exit=True)
def _eval_is_infinite(self):
    """Infinite if some factor is infinite and no factor is (or may be)
    zero; an infinite factor times a zero factor gives NaN's verdict."""
    factors = self.args
    if not any(f.is_infinite for f in factors):
        return None
    if any(f.is_zero for f in factors):
        return S.NaN.is_infinite
    if any(f.is_zero is None for f in factors):
        return None
    return True
def _eval_is_rational(self):
    """Rational if all factors are; when the group test is False the
    product may still be rational by being zero, so defer to is_zero."""
    verdict = _fuzzy_group((f.is_rational for f in self.args), quick_exit=True)
    if verdict is False:
        return self.is_zero
    if verdict:
        return verdict
    return None
def _eval_is_algebraic(self):
    """Algebraic if all factors are; when the group test is False the
    product may still be algebraic by being zero, so defer to is_zero."""
    verdict = _fuzzy_group((f.is_algebraic for f in self.args), quick_exit=True)
    if verdict is False:
        return self.is_zero
    if verdict:
        return verdict
    return None
def _eval_is_zero(self):
zero = infinite = False
for a in self.args: