Skip to content

Commit

Permalink
properly propagate string comparison optimisation into cascaded comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
scoder committed Aug 9, 2012
1 parent 4b858e4 commit 4d1a524
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 3 deletions.
12 changes: 9 additions & 3 deletions Cython/Compiler/ExprNodes.py
Expand Up @@ -8696,9 +8696,9 @@ def is_ptr_contains(self):
return (container_type.is_ptr or container_type.is_array) \
and not container_type.is_string

def find_special_bool_compare_function(self, env):
def find_special_bool_compare_function(self, env, operand1):
if self.operator in ('==', '!='):
type1, type2 = self.operand1.type, self.operand2.type
type1, type2 = operand1.type, self.operand2.type
if type1.is_pyobject and type2.is_pyobject:
if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
env.use_utility_code(UtilityCode.load_cached("UnicodeEquals", "StringTools.c"))
Expand Down Expand Up @@ -8901,7 +8901,7 @@ def analyse_types(self, env):
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
common_type = py_object_type
self.is_pycmp = True
elif self.find_special_bool_compare_function(env):
elif self.find_special_bool_compare_function(env, self.operand1):
common_type = None # if coercion needed, the method call above has already done it
self.is_pycmp = False # result is bint
self.is_temp = True # must check for error return
Expand All @@ -8916,6 +8916,7 @@ def analyse_types(self, env):

if self.cascade:
self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.optimise_comparison(env, self.operand2)
self.cascade.coerce_cascaded_operands_to_temp(env)
if self.is_python_result():
self.type = PyrexTypes.py_object_type
Expand Down Expand Up @@ -9079,6 +9080,11 @@ def analyse_types(self, env):
# Return whether this node's second operand has a Python object type
# (reads self.operand2.type.is_pyobject; no other operand is inspected).
def has_python_operands(self):
return self.operand2.type.is_pyobject

# Run the special bool-compare lookup for this link of a comparison cascade.
# `operand1` is the left-hand operand supplied by the caller (the preceding
# comparison passes its own operand2); this node's operand2 is the right side.
# The call is then repeated on the next link, if any, with this node's
# operand2 as that link's left-hand operand.
# NOTE(review): the result of find_special_bool_compare_function() is
# discarded here, whereas analyse_types() branches on it to set
# is_pycmp/is_temp -- confirm that nothing needs updating on this node.
def optimise_comparison(self, env, operand1):
self.find_special_bool_compare_function(env, operand1)
if self.cascade:
self.cascade.optimise_comparison(env, self.operand2)

def coerce_operands_to_pyobjects(self, env):
self.operand2 = self.operand2.coerce_to_pyobject(env)
if self.operand2.type is dict_type and self.operator in ('in', 'not_in'):
Expand Down
193 changes: 193 additions & 0 deletions tests/run/string_comparison.pyx
@@ -0,0 +1,193 @@

# Module-level fixtures referenced by the doctests in this file: for each
# literal flavour (b-prefixed, unprefixed, u-prefixed) there is one value
# holding the letters "abcdefg" and one holding the digits "1234567".
bstring1 = b"abcdefg"
bstring2 = b"1234567"

string1 = "abcdefg"
string2 = "1234567"

ustring1 = u"abcdefg"
ustring2 = u"1234567"

# unicode

# Return the result of `==` on two arguments declared as `unicode`.
# The doctests cover identical objects, equal values built by concatenation,
# and unequal values.
def unicode_eq(unicode s1, unicode s2):
"""
>>> unicode_eq(ustring1, ustring1)
True
>>> unicode_eq(ustring1+ustring2, ustring1+ustring2)
True
>>> unicode_eq(ustring1, ustring2)
False
"""
return s1 == s2

# Return the result of `!=` on two arguments declared as `unicode`.
# Same inputs as the `==` variant, with the expected results inverted.
def unicode_neq(unicode s1, unicode s2):
"""
>>> unicode_neq(ustring1, ustring1)
False
>>> unicode_neq(ustring1+ustring2, ustring1+ustring2)
False
>>> unicode_neq(ustring1, ustring2)
True
"""
return s1 != s2

# Return the result of `==` between a `unicode` argument and the literal
# u"abcdefg". The sliced input in the second doctest has the same value as
# ustring1 but is built at call time -- presumably so the comparison cannot
# succeed on object identity alone (TODO confirm intent).
def unicode_literal_eq(unicode s):
"""
>>> unicode_literal_eq(ustring1)
True
>>> unicode_literal_eq((ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_literal_eq(ustring2)
False
"""
return s == u"abcdefg"

# Return the result of `!=` between a `unicode` argument and the literal
# u"abcdefg"; doctest inputs mirror the `==` literal variant with inverted
# expected results.
def unicode_literal_neq(unicode s):
"""
>>> unicode_literal_neq(ustring1)
False
>>> unicode_literal_neq((ustring1+ustring2)[:len(ustring1)])
False
>>> unicode_literal_neq(ustring2)
True
"""
return s != u"abcdefg"

# Return the result of the chained comparison s1 == s2 == u"abcdefg" on two
# `unicode` arguments: true only when s1 equals s2 and s2 equals the literal.
def unicode_cascade(unicode s1, unicode s2):
"""
>>> unicode_cascade(ustring1, ustring1)
True
>>> unicode_cascade(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_cascade(ustring1, ustring2)
False
"""
return s1 == s2 == u"abcdefg"

# Disabled test: the whole definition is wrapped in a string literal so it is
# not compiled. It extends the chained comparison past the typed operands to
# an operand cast to `<object>` and a module-level global.
''' # NOTE: currently crashes
def unicode_cascade_untyped_end(unicode s1, unicode s2):
"""
>>> unicode_cascade_untyped_end(ustring1, ustring1)
True
>>> unicode_cascade_untyped_end(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_cascade_untyped_end(ustring1, ustring2)
False
"""
return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1
'''

# str

# Return the result of `==` on two arguments declared as `str`.
# The doctests cover identical objects, equal values built by concatenation,
# and unequal values.
def str_eq(str s1, str s2):
"""
>>> str_eq(string1, string1)
True
>>> str_eq(string1+string2, string1+string2)
True
>>> str_eq(string1, string2)
False
"""
return s1 == s2

# Return the result of `!=` on two arguments declared as `str`.
# Same inputs as the `==` variant, with the expected results inverted.
def str_neq(str s1, str s2):
"""
>>> str_neq(string1, string1)
False
>>> str_neq(string1+string2, string1+string2)
False
>>> str_neq(string1, string2)
True
"""
return s1 != s2

# Return the result of `==` between a `str` argument and the unprefixed
# literal "abcdefg". The sliced input in the second doctest has the same
# value as string1 but is built at call time -- presumably so the comparison
# cannot succeed on object identity alone (TODO confirm intent).
def str_literal_eq(str s):
"""
>>> str_literal_eq(string1)
True
>>> str_literal_eq((string1+string2)[:len(string1)])
True
>>> str_literal_eq(string2)
False
"""
return s == "abcdefg"

# Return the result of `!=` between a `str` argument and the unprefixed
# literal "abcdefg"; doctest inputs mirror the `==` literal variant with
# inverted expected results.
def str_literal_neq(str s):
"""
>>> str_literal_neq(string1)
False
>>> str_literal_neq((string1+string2)[:len(string1)])
False
>>> str_literal_neq(string2)
True
"""
return s != "abcdefg"

# Return the result of the chained comparison s1 == s2 == "abcdefg" on two
# `str` arguments: true only when s1 equals s2 and s2 equals the literal.
def str_cascade(str s1, str s2):
"""
>>> str_cascade(string1, string1)
True
>>> str_cascade(string1, (string1+string2)[:len(string1)])
True
>>> str_cascade(string1, string2)
False
"""
return s1 == s2 == "abcdefg"

# bytes

# Return the result of `==` on two arguments declared as `bytes`.
# The doctests cover identical objects, equal values built by concatenation,
# and unequal values.
def bytes_eq(bytes s1, bytes s2):
"""
>>> bytes_eq(bstring1, bstring1)
True
>>> bytes_eq(bstring1+bstring2, bstring1+bstring2)
True
>>> bytes_eq(bstring1, bstring2)
False
"""
return s1 == s2

# Return the result of `!=` on two arguments declared as `bytes`.
# Same inputs as the `==` variant, with the expected results inverted.
def bytes_neq(bytes s1, bytes s2):
"""
>>> bytes_neq(bstring1, bstring1)
False
>>> bytes_neq(bstring1+bstring2, bstring1+bstring2)
False
>>> bytes_neq(bstring1, bstring2)
True
"""
return s1 != s2

# Return the result of `==` between a `bytes` argument and the literal
# b"abcdefg". The sliced input in the second doctest has the same value as
# bstring1 but is built at call time -- presumably so the comparison cannot
# succeed on object identity alone (TODO confirm intent).
def bytes_literal_eq(bytes s):
"""
>>> bytes_literal_eq(bstring1)
True
>>> bytes_literal_eq((bstring1+bstring2)[:len(bstring1)])
True
>>> bytes_literal_eq(bstring2)
False
"""
return s == b"abcdefg"

# Return the result of `!=` between a `bytes` argument and the literal
# b"abcdefg"; doctest inputs mirror the `==` literal variant with inverted
# expected results.
def bytes_literal_neq(bytes s):
"""
>>> bytes_literal_neq(bstring1)
False
>>> bytes_literal_neq((bstring1+bstring2)[:len(bstring1)])
False
>>> bytes_literal_neq(bstring2)
True
"""
return s != b"abcdefg"

# Return the result of the chained comparison s1 == s2 == b"abcdefg" on two
# `bytes` arguments: true only when s1 equals s2 and s2 equals the literal.
def bytes_cascade(bytes s1, bytes s2):
"""
>>> bytes_cascade(bstring1, bstring1)
True
>>> bytes_cascade(bstring1, (bstring1+bstring2)[:len(bstring1)])
True
>>> bytes_cascade(bstring1, bstring2)
False
"""
return s1 == s2 == b"abcdefg"

0 comments on commit 4d1a524

Please sign in to comment.