Skip to content

Commit

Permalink
Fix Bessel order gathering.
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed May 13, 2012
1 parent a94843a commit ab1e425
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions sumpy/codegen.py
Expand Up @@ -102,6 +102,8 @@ def bessel_j(self, order, arg):
elif order < 0: elif order < 0:
return (-1)**order*self.bessel_j(-order, arg) return (-1)**order*self.bessel_j(-order, arg)
else: else:
assert abs(order) < top_order

# AS (9.1.27) # AS (9.1.27)
nu = order+1 nu = order+1
return prim.CommonSubexpression( return prim.CommonSubexpression(
Expand All @@ -114,6 +116,10 @@ def bessel_j(self, order, arg):




class BesselTopOrderGatherer(WalkMapper): class BesselTopOrderGatherer(WalkMapper):
"""This mapper walks the expression tree to find the highest-order
Bessel J being used, so that all other Js can be computed by the
(stable) downward recurrence.
"""
def __init__(self): def __init__(self):
self.bessel_j_arg_to_top_order = {} self.bessel_j_arg_to_top_order = {}


Expand Down Expand Up @@ -142,14 +148,20 @@ def map_substitution(self, expr):
order, _ = call.parameters order, _ = call.parameters
arg, = expr.values arg, = expr.values


# AS (9.1.31)
n_derivs = len(expr.child.variables) n_derivs = len(expr.child.variables)
import sympy as sp import sympy as sp

# AS (9.1.31)
if order >= 0:
order_str = str(order)
else:
order_str = "m"+str(-order)
k = n_derivs
return prim.CommonSubexpression( return prim.CommonSubexpression(
2**(-n_derivs)*sum( 2**(-k)*sum(
(-1)**idx*int(sp.binomial(n_derivs, idx)) * function(i, arg) (-1)**idx*int(sp.binomial(k, idx)) * function(i, arg)
for idx, i in enumerate(range(order-n_derivs, order+n_derivs+1, 2))), for idx, i in enumerate(range(order-k, order+k+1, 2))),
"d%d_%s_%d" % (n_derivs, function.name, order)) "d%d_%s_%s" % (n_derivs, function.name, order_str))
else: else:
return IdentityMapper.map_substitution(self, expr) return IdentityMapper.map_substitution(self, expr)


Expand Down Expand Up @@ -304,13 +316,14 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[]):
sympy_conv = SympyToPymbolicMapper() sympy_conv = SympyToPymbolicMapper()
assignments = [(name, sympy_conv(expr)) for name, expr in assignments] assignments = [(name, sympy_conv(expr)) for name, expr in assignments]


# gather information bdr = BesselDerivativeReplacer()
assignments = [(name, bdr(expr)) for name, expr in assignments]

btog = BesselTopOrderGatherer() btog = BesselTopOrderGatherer()
for name, expr in assignments: for name, expr in assignments:
btog(expr) btog(expr)


# do the rest of the conversion # do the rest of the conversion
bdr = BesselDerivativeReplacer()
bessel_sub = BesselSubstitutor(BesselGetter(btog.bessel_j_arg_to_top_order)) bessel_sub = BesselSubstitutor(BesselGetter(btog.bessel_j_arg_to_top_order))
vcr = VectorComponentRewriter(vector_names) vcr = VectorComponentRewriter(vector_names)
pwr = PowerRewriter() pwr = PowerRewriter()
Expand All @@ -328,6 +341,8 @@ def convert_expr(expr):
expr = m(expr) expr = m(expr)
return expr return expr




import loopy as lp import loopy as lp
return [ return [
lp.Instruction(id=None, lp.Instruction(id=None,
Expand Down

0 comments on commit ab1e425

Please sign in to comment.