Permalink
Browse files

Change _cg_simp_add logic to use pattern matching

  • Loading branch information...
1 parent 4d44a0e commit ca33290cd27c1168a7b5ef1319b09a1bd9ed8972 @flacjacket committed Jul 5, 2011
Showing with 171 additions and 168 deletions.
  1. +167 −164 sympy/physics/quantum/cg.py
  2. +4 −4 sympy/physics/quantum/tests/test_cg.py
View
@@ -4,7 +4,7 @@
# -Implement new simpifications
"""Clebsch-Gordon Coefficients."""
-from sympy import Expr, Add, Function, Mul, Pow, sqrt, Sum, symbols, sympify, Wild
+from sympy import Expr, Add, Function, Mul, Pow, sqrt, Sum, symbols, sympify, Wild, expand
from sympy.printing.pretty.stringpict import prettyForm, stringPict
from sympy.physics.quantum.kronecker import KroneckerDelta
@@ -265,135 +265,168 @@ def _cg_simp_add(e):
"""Takes a sum of terms involving Clebsch-Gordan coefficients and
simplifies the terms.
- This method creates lists of the terms in each of the arguments of the
- summation. Each term is then parsed for the Clebsch-Gordan coefficients in
- it and, where applicable, apply symmetries to the simplify the terms in
- the sum.
-
- References
- ==========
-
- [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988.
+ First, we create two lists, cg_part, which is all the terms involving CG
+ coefficients, and other_part, which is all other terms. The cg_part list
+ is then passed to the simplification methods, which return the new cg_part
+ and any additional terms that are added to other_part
"""
cg_part = []
other_part = []
+
+ e = expand(e)
for arg in e.args:
if arg.has(CG):
- if isinstance(arg, CG):
- terms = [arg]
- elif isinstance(arg, Sum):
+ if isinstance(arg, Sum):
other_part.append(_cg_simp_sum(arg))
- continue
elif isinstance(arg, Mul):
- terms = []
+ terms = 1
for term in arg.args:
- if isinstance(term, Pow) and sympify(term.exp).is_number:
- [ terms.append(term.base) for _ in range(term.exp) ]
- elif isinstance(term, Sum):
- terms.append(_cg_simp_sum(term))
+ if isinstance(term, Sum):
+ terms *= _cg_simp_sum(term)
else:
- terms.append(term)
- elif isinstance(arg, Pow):
- terms = []
- if sympify(arg.exp).is_number:
- [ terms.append(arg.base) for i in range(arg.exp) ]
+ terms *= term
+ if terms.has(CG):
+ cg_part.append(terms)
else:
- other_part.append(arg)
- continue
+ other_part.append(terms)
else:
- other_part.append(arg)
- continue
- cg_part.append(terms)
+ cg_part.append(arg)
else:
other_part.append(arg)
+
+ cg_part, other = _check_varsh_871_1(cg_part)
+ other_part.append(other)
+ cg_part, other = _check_varsh_871_2(cg_part)
+ other_part.append(other)
+ cg_part, other = _check_varsh_872_9(cg_part)
+ other_part.append(other)
+ return Add(*cg_part)+Add(*other_part)
+
+def _check_varsh_871_1(term_list):
+ # Sum( CG(a,alpha,b,0,a,alpha), (alpha, -a, a)) == KroneckerDelta(b,0)
+ a,alpha,b,lt = map(Wild,('a','alpha','b','lt'))
+ expr = lt*CG(a,alpha,b,0,a,alpha)
+ simp = (2*a+1)*KroneckerDelta(b,0)
+ sign = lt/abs(lt)
+ build_expr = 2*a+1
+ index_expr = a+alpha
+ return _check_cg_simp(expr, simp, sign, lt, term_list, (a,alpha,b,lt), (a,b), build_expr, index_expr)
+
+
+def _check_varsh_871_2(term_list):
+ # Sum((-1)**(a-alpha)*CG(a,alpha,a,-alpha,c,0),(alpha,-a,a))
+ a,alpha,c,lt = map(Wild,('a','alpha','c','lt'))
+ expr = lt*CG(a,alpha,a,-alpha,c,0)
+ simp = sqrt(2*a+1)*KroneckerDelta(c,0)
+ sign = (-1)**(a-alpha)*lt/abs(lt)
+ build_expr = 2*a+1
+ index_expr = a+alpha
+ return _check_cg_simp(expr, simp, sign, lt, term_list, (a,alpha,c,lt), (a,c), build_expr, index_expr)
+
+def _check_varsh_872_9(term_list):
+ # Sum( CG(a,alpha,b,beta,c,gamma)*CG(a,alpha',b,beta',c,gamma), (gamma, -c, c), (c, abs(a-b), a+b))
+ a,alpha,b,beta,c,gamma,lt = map(Wild, ('a','alpha','b','beta','c','gamma','lt'))
+ expr = lt*CG(a,alpha,b,beta,c,gamma)**2
+ simp = 1
+ sign = lt/abs(lt)
+ # Note: there are some weird evaluation issues with max and pattern matching
+ # The following expression is equivalent to:
+ #build_expr = a+b+1-max(abs(a-b),abs(alpha+beta))
+ # but this expression can have some strange evaluation problems, as well
+ x = abs(a-b)
+ y = abs(alpha+beta)
+ build_expr = a+b+1-(x*(1+cmp(x,y)) + y*(1+cmp(y,x)))/2
+ index_expr = a+b-c
+ return _check_cg_simp(expr, simp, sign, lt, term_list, (a,alpha,b,beta,c,gamma,lt), (a,alpha,b,beta), build_expr, index_expr)
+
+def _check_cg_simp(expr, simp, sign, lt, term_list, variables, dep_variables, build_index_expr, index_expr):
+ """ Checks for simplifications that can be made, returning a tuple of the
+ simplified list of terms and any terms generated by simplification.
+
+ Parameters
+ ==========
+
+ expr: expression
+ The expression with Wild terms that will be matched to the terms in
+ the sum
+
+ simp: expression
+ The expression with Wild terms that is substituted in place of the CG
+ terms in the case of simplification
+
+ sign: expression
+ The expression with Wild terms denoting the sign that is on expr that
+ must match
+
+ lt: expression
+ The expression with Wild terms that gives the leading term of the
+ matched expr
+
+ term_list: list
+ A list of all of the terms is the sum to be simplified
+
+ variables: list
+ A list of all the variables that appears in expr
+
+ dep_variables: list
+ A list of the variables that must match for all the terms in the sum,
+ i.e. the dependant variables
+
+ build_index_expr: expression
+ Expression with Wild terms giving the number of elements in cg_index
+
+ index_expr: expression
+ Expression with Wild terms giving the index terms have when storing
+ them to cg_index
+
+ """
+ other_part = 0
i = 0
- while i < len(cg_part):
- cg, coeff, sign = _cg_list(cg_part[i])
- cg_count = len(cg)
- if cg_count == 1:
- if not cg[0].j1.is_number:
- i += 1
- continue
- # Simplifies Varshalovich 8.7.1 Eq 1
- # Sum( CG(a,alpha,b,0,a,alpha), (alpha,-a,a) ) = (2*a+1) * delta_b,0
- cg_index = [None]*(2*cg[0].j1+1)
- for term in cg_part:
- cg2, coeff2, sign2 = _cg_list(term)
- if not len(cg2) == 1:
- continue
- cg2, coeff2, sign2 = _has_cg(_check_871_1, cg, sign, cg2, coeff2, sign2)
- if not len(cg2) == 1:
- continue
- if not cg2[0].m1.is_number or not cg2[0].j1.is_number:
- continue
- cg_index[cg2[0].m1+cg2[0].j1] = term,cg2[0],coeff2
- if cg_index.count(None) == 0:
- min_coeff = min([abs(term[2]) for term in cg_index])
- for term in cg_index:
- cg_part.pop(cg_part.index(term[0]))
- if not term[2] == min_coeff*sign:
- cg_part.append((term[1],term[2]-min_coeff*sign))
- other_part.append((2*cg[0].j1+1)*min_coeff*sign*KroneckerDelta(cg[0].j2,0))
- continue
- # Simplifies Varshalovich 8.7.1 Eq 2
- # Sum( (-1)**(a-alpha) * CG(a,alpha,a,-alpha,c,0), (alpha,-a,a) ) = sqrt(2*a+1) * delta_c,0
- cg_index = [None]*(2*cg[0].j1+1)
- for term in cg_part:
- cg2, coeff2, sign2 = _cg_list(term)
- if not len(cg2) == 1:
- continue
- cg2, coeff2, sign2 = _has_cg(_check_871_2, cg, sign, cg2, coeff2, sign2)
- if not len(cg2) == 1:
- continue
- if not cg2[0].m1.is_number or not cg2[0].j1.is_number:
- continue
- cg_index[cg2[0].m1+cg2[0].j1] = term,cg2[0],coeff2
- if cg_index.count(None) == 0:
- min_coeff = min([abs(term[2]) for term in cg_index])
- for term in cg_index:
- sign2 = (-1)**((term[1].j1-term[1].m1)+(cg[0].j1-cg[0].m1))
- cg_part.pop(cg_part.index(term[0]))
- if not term[2] == min_coeff*sign*sign2:
- cg_part.append((term[1],term[2]-min_coeff*sign*sign2))
- other_part.append(sqrt(2*cg[0].j1+1)*min_coeff*sign*sign2*KroneckerDelta(cg[0].j3,0))
- continue
- if cg_count == 2:
- if not (cg[0].j1.is_number and cg[0].j2.is_number):
- i += 1
- continue
- # Simplifies Varshalovich 8.7.2 Eq 9
- # Sum( CG(a,alpha,b,beta,c,gamma) * CG(a,alpha',b,beta',c,gamma), (c, abs(a-b), a+b), (gamma, -c, c) ) =
- # delta_alpha,alpha' * delta_beta,beta'
- if cg[0].m1.is_number and cg[0].m2.is_number:
- cg_index = [None]*(cg[0].j1+cg[0].j2-max(abs(cg[0].j1-cg[0].j2),abs(cg[0].m1+cg[0].m2))+1)
- else:
- # TODO: Symbolic simplification for this case
- #cg_index = [None]*(cg[0].j1+cg[0].j2-abs(cg[0].j1-cg[0].j2)+1)
+ while i < len(term_list):
+ sub_1 = _check_cg(term_list[i], expr, len(variables))
+ if sub_1 is None:
+ i += 1
+ continue
+ if not sympify(build_index_expr.subs(sub_1)).is_number:
+ i += 1
+ continue
+ sub_dep = [(x,sub_1[x]) for x in dep_variables]
+ cg_index = [None] * build_index_expr.subs(sub_1)
+ for j in range(i,len(term_list)):
+ sub_2 = _check_cg(term_list[j], expr.subs(sub_dep), len(variables)-len(dep_variables), sign=(sign.subs(sub_1),sign.subs(sub_dep)))
+ if sub_2 is None:
continue
- for term in cg_part:
- cg2, coeff2, sign2 = _cg_list(term)
- if not len(cg2) == 2:
- continue
- cg2, coeff2, sign2 = _has_cg(_check_872_9, cg, sign, cg2, coeff2, sign2)
- if not len(cg2) == 2:
- continue
- if cg[0].m1.is_number and cg[0].m2.is_number:
- cg_index[cg2[0].j3-max(abs(cg[0].j1-cg[0].j2),abs(cg[0].m1+cg[0].m2))] = term,cg2[0]*cg2[1],coeff2
- else:
- #cg_index[cg2[0].j3-abs(cg[0].j1-cg[0].j2)] = term,cg2[0]*cg2[1],coeff2
- continue
- if cg_index.count(None) == 0:
- min_coeff = min([abs(term[2]) for term in cg_index])
- for term in cg_index:
- cg_part.pop(cg_part.index(term[0]))
- if not term[2] == min_coeff*sign:
- cg_part.append((term[1],term[2]-min_coeff*sign))
- other_part.append(min_coeff*sign*KroneckerDelta(cg[0].m1,cg[1].m1)*KroneckerDelta(cg[0].m2,cg[1].m2))
+ if not sympify(index_expr.subs(sub_dep).subs(sub_2)).is_number:
continue
- i += 1
- return Add(*[Mul(*i) for i in cg_part])+Add(*other_part)
+ cg_index[index_expr.subs(sub_dep).subs(sub_2)] = j, expr.subs(lt,1).subs(sub_dep).subs(sub_2), lt.subs(sub_2), sign.subs(sub_dep).subs(sub_2)
+ if cg_index.count(None) == 0:
+ min_lt = min(*[ abs(term[2]) for term in cg_index ])
+ indicies = [ term[0] for term in cg_index]
+ indicies.sort()
+ indicies.reverse()
+ [ term_list.pop(i) for i in indicies ]
+ for term in cg_index:
+ if abs(term[2]) > min_lt:
+ term_list.append( (term[2]-min_lt*term[3]) * term[1] )
+ other_part += min_lt * (sign*simp).subs(sub_1)
+ else:
+ i += 1
+ return term_list, other_part
+
+def _check_cg(cg_term, expr, length, sign=None):
+ """Checks whether a term matches the given expression"""
+ # TODO: Check for symmetries
+ matches = cg_term.match(expr)
+ if matches is None:
+ return
+ if not sign is None:
+ if not isinstance(sign, tuple):
+ raise TypeError('sign must be a tuple')
+ if not sign[0] == (sign[1]).subs(matches):
+ return
+ if len(matches) == length:# and sign1 == sign2.subs(matches):
+ return matches
-#TODO: Implement symbolic simplification of Sum objects
def _cg_simp_sum(e):
e = _check_varsh_sum_871_1(e)
e = _check_varsh_sum_871_2(e)
@@ -406,7 +439,7 @@ def _check_varsh_sum_871_1(e):
b = Wild('b')
match = e.match(Sum(CG(a,alpha,b,0,a,alpha),(alpha,-a,a)))
if not match is None and len(match) == 2:
- return (2*match.get(a)+1)*KroneckerDelta(match.get(b),0)
+ return ((2*a+1)*KroneckerDelta(b,0)).subs(match)
return e
def _check_varsh_sum_871_2(e):
@@ -415,7 +448,7 @@ def _check_varsh_sum_871_2(e):
c = Wild('c')
match = e.match(Sum((-1)**(a-alpha)*CG(a,alpha,a,-alpha,c,0),(alpha,-a,a)))
if not match is None and len(match) == 2:
- return sqrt(2*match.get(a)+1)*KroneckerDelta(match.get(c),0)
+ return (sqrt(2*a+1)*KroneckerDelta(c,0)).subs(match)
return e
def _check_varsh_sum_872_4(e):
@@ -429,59 +462,29 @@ def _check_varsh_sum_872_4(e):
gammap = Wild('gammap')
match1 = e.match(Sum(CG(a,alpha,b,beta,c,gamma)*CG(a,alpha,b,beta,cp,gammap),(alpha,-a,a),(beta,-b,b)))
if not match1 is None and len(match1) == 8:
- return KroneckerDelta(match1.get(c),match1.get(cp))*KroneckerDelta(match1.get(gamma),match1.get(gammap))
+ return (KroneckerDelta(c,cp)*KroneckerDelta(gamma,gammap)).subs(match1)
match2 = e.match(Sum(CG(a,alpha,b,beta,c,gamma)**2,(alpha,-a,a),(beta,-b,b)))
if not match2 is None and len(match2) == 6:
return 1
return e
-def _check_871_1(cg1, cg2, sign1, sign2):
- cg1 = cg1[0]
- cg2 = cg2[0]
- if cg1.j1 == cg2.j1 and cg1.j2 == cg2.j2 and cg1.j3 == cg2.j3 and \
- cg2.m2 == 0 and cg2.m1 == cg2.m3 and cg2.j1 == cg2.j3 and \
- sign1 == sign2:
- return True
- return False
-
-def _check_871_2(cg1, cg2, sign1, sign2):
- cg1 = cg1[0]
- cg2 = cg2[0]
- if cg1.j1 == cg2.j1 and cg1.j2 == cg2.j2 and cg1.j3 == cg2.j3 and \
- cg2.j1 == cg2.j2 and cg2.m3 == 0 and cg2.m1 == -cg2.m2 and \
- sign1*(-1)**(cg1.j1-cg1.m1) == sign2*(-1)**(cg2.j1-cg2.m1):
- return True
- return False
-
-def _check_872_9(cg1, cg2, sign1, sign2):
- cg1_1 = cg1[0]
- cg1_2 = cg1[1]
- cg2_1 = cg2[0]
- cg2_2 = cg2[1]
- if cg1_1.j1 == cg2_1.j1 and cg1_1.j2 == cg2_1.j2 and cg1_1.m1 == cg2_1.m1 and cg1_1.m2 == cg2_1.m2 and \
- cg1_2.j1 == cg2_2.j1 and cg1_2.j2 == cg2_2.j2 and cg1_2.m1 == cg2_2.m1 and cg1_2.m2 == cg2_2.m2 and \
- cg2_1.j1 == cg2_2.j1 and cg2_1.j2 == cg2_2.j2 and cg2_1.j3 == cg2_2.j3 and cg2_1.m3 == cg2_2.m3 and \
- sign1 == sign2:
- return True
- return False
-
-def _cg_list(arg_list):
- coeff = 1
+
+def _cg_list(term):
+ if isinstance(term, CG):
+ return (term,), 1, 1
cg = []
- for term in arg_list:
- if isinstance(term, CG):
- cg.append(term)
- elif isinstance(term, Pow):
- terms = []
- if isinstance(term.base, CG) and sympify(term.exp).is_number:
- [ cg.append(term.base) for i in range(term.exp) ]
+ coeff = 1
+ if not (isinstance(term, Mul) or isinstance(term, Pow)):
+ raise NotImplementedError('term must be CG, Add, Mul or Pow')
+ if isinstance(term, Pow) and sympify(term.exp).is_number:
+ if sympify(term.exp).is_number:
+ [ cg.append(term.base) for _ in range(term.exp) ]
else:
- coeff *= term
- return cg, coeff, coeff/abs(coeff)
-
-
-def _has_cg(f, cg1, sign1, cg2, coeff2, sign2):
- # This should be extended to check if symmetries can be applied to give the necessary cg terms
- if f(cg1, cg2, sign1, sign2):
- return cg2, coeff2, sign2
- return [], [], []
+ return (term,), 1, 1
+ if isinstance(term, Mul):
+ for arg in term.args:
+ if isinstance(arg, CG):
+ cg.append(arg)
+ else:
+ coeff *= arg
+ return cg, coeff, coeff/abs(coeff)
Oops, something went wrong.

0 comments on commit ca33290

Please sign in to comment.