Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fix Bessel order gathering.

  • Loading branch information...
commit ab1e42584a97be2056ac6024dc568d06e197df99 1 parent a94843a
@inducer authored
Showing with 22 additions and 7 deletions.
  1. +22 −7 sumpy/codegen.py
View
29 sumpy/codegen.py
@@ -102,6 +102,8 @@ def bessel_j(self, order, arg):
elif order < 0:
return (-1)**order*self.bessel_j(-order, arg)
else:
+ assert abs(order) < top_order
+
# AS (9.1.27)
nu = order+1
return prim.CommonSubexpression(
@@ -114,6 +116,10 @@ def bessel_j(self, order, arg):
class BesselTopOrderGatherer(WalkMapper):
+ """This mapper walks the expression tree to find the highest-order
+ Bessel J being used, so that all other Js can be computed by the
+ (stable) downward recurrence.
+ """
def __init__(self):
self.bessel_j_arg_to_top_order = {}
@@ -142,14 +148,20 @@ def map_substitution(self, expr):
order, _ = call.parameters
arg, = expr.values
- # AS (9.1.31)
n_derivs = len(expr.child.variables)
import sympy as sp
+
+ # AS (9.1.31)
+ if order >= 0:
+ order_str = str(order)
+ else:
+ order_str = "m"+str(-order)
+ k = n_derivs
return prim.CommonSubexpression(
- 2**(-n_derivs)*sum(
- (-1)**idx*int(sp.binomial(n_derivs, idx)) * function(i, arg)
- for idx, i in enumerate(range(order-n_derivs, order+n_derivs+1, 2))),
- "d%d_%s_%d" % (n_derivs, function.name, order))
+ 2**(-k)*sum(
+ (-1)**idx*int(sp.binomial(k, idx)) * function(i, arg)
+ for idx, i in enumerate(range(order-k, order+k+1, 2))),
+ "d%d_%s_%s" % (n_derivs, function.name, order_str))
else:
return IdentityMapper.map_substitution(self, expr)
@@ -304,13 +316,14 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[]):
sympy_conv = SympyToPymbolicMapper()
assignments = [(name, sympy_conv(expr)) for name, expr in assignments]
- # gather information
+ bdr = BesselDerivativeReplacer()
+ assignments = [(name, bdr(expr)) for name, expr in assignments]
+
btog = BesselTopOrderGatherer()
for name, expr in assignments:
btog(expr)
# do the rest of the conversion
- bdr = BesselDerivativeReplacer()
bessel_sub = BesselSubstitutor(BesselGetter(btog.bessel_j_arg_to_top_order))
vcr = VectorComponentRewriter(vector_names)
pwr = PowerRewriter()
@@ -328,6 +341,8 @@ def convert_expr(expr):
expr = m(expr)
return expr
+
+
import loopy as lp
return [
lp.Instruction(id=None,
Please sign in to comment.
Something went wrong with that request. Please try again.