Skip to content

Commit

Permalink
Avoid exponential recursion when coercing nested conditional expressi…
Browse files Browse the repository at this point in the history
…ons.

This used to coerce the nesting tree twice at each condition, once for `coerce_to()` and once for `analyse_result_type()`, both calling each other for the entire subtree.

Closes #5197
  • Loading branch information
scoder committed Jan 5, 2023
1 parent 2cff1bf commit cf19b86
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 9 deletions.
28 changes: 19 additions & 9 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12118,30 +12118,40 @@ def analyse_types(self, env):
return self.analyse_result_type(env)

def analyse_result_type(self, env):
self.type = PyrexTypes.independent_spanning_type(
self.true_val.type, self.false_val.type)
true_val_type = self.true_val.type
false_val_type = self.false_val.type
self.type = PyrexTypes.independent_spanning_type(true_val_type, false_val_type)

if self.type.is_reference:
self.type = PyrexTypes.CFakeReferenceType(self.type.ref_base_type)
if self.type.is_pyobject:
self.result_ctype = py_object_type
elif self.true_val.is_ephemeral() or self.false_val.is_ephemeral():
error(self.pos, "Unsafe C derivative of temporary Python reference used in conditional expression")
if self.true_val.type.is_pyobject or self.false_val.type.is_pyobject:
self.true_val = self.true_val.coerce_to(self.type, env)
self.false_val = self.false_val.coerce_to(self.type, env)

if true_val_type.is_pyobject or false_val_type.is_pyobject:
if true_val_type != self.type:
self.true_val = self.true_val.coerce_to(self.type, env)
if false_val_type != self.type:
self.false_val = self.false_val.coerce_to(self.type, env)

if self.type.is_error:
self.type_error()
return self

def coerce_to_integer(self, env):
self.true_val = self.true_val.coerce_to_integer(env)
self.false_val = self.false_val.coerce_to_integer(env)
if not self.true_val.type.is_int:
self.true_val = self.true_val.coerce_to_integer(env)
if not self.false_val.type.is_int:
self.false_val = self.false_val.coerce_to_integer(env)
self.result_ctype = None
return self.analyse_result_type(env)

def coerce_to(self, dst_type, env):
self.true_val = self.true_val.coerce_to(dst_type, env)
self.false_val = self.false_val.coerce_to(dst_type, env)
if self.true_val.type != dst_type:
self.true_val = self.true_val.coerce_to(dst_type, env)
if self.false_val.type != dst_type:
self.false_val = self.false_val.coerce_to(dst_type, env)
self.result_ctype = None
return self.analyse_result_type(env)

Expand Down
46 changes: 46 additions & 0 deletions tests/run/if_else_expr.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mode: run
# tag: condexpr
# ticket: 5197

cimport cython

Expand Down Expand Up @@ -67,3 +68,48 @@ def test_cfunc_ptrs(double x, bint round_down):
3.0
"""
return (math.floor if round_down else math.ceil)(x)


def performance_gh5197(patternsList):
"""
>>> performance_gh5197([]) # do not actually run anything, just see that things work at all
"""
# Coercing the types in nested conditional expressions used to slow down exponentially.
# See https://github.com/cython/cython/issues/5197
import re
matched=[]
for _ in range(len(patternsList)):
try:
matched.append(patternsList[_].split('|')[-1].split('/')[-1] + 'pattr1' if re.search('^SomeString.*EndIng$')\
else patternsList[_].split('|a')[-1].split('/a')[-1] + 'pattr2' if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|a')[-1].split('/a')[-1] + 'pattr2' if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
# else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] if re.search('^SomeOtherString.?Number.*EndIng$')\
else patternsList[_].split('|b')[-1].split('/b')[-1] + 'pattr2' + patternsList[_].split('/')[-1].split('//')[-1] )
except Exception as e:
matched.append('Error at Indx:%s-%s' %(_, patternsList[_]))

0 comments on commit cf19b86

Please sign in to comment.