Skip to content

Commit

Permalink
Properly refine or's
Browse files Browse the repository at this point in the history
Summary: When analyzing `e1 or e2`, we should only consider `e1`'s type in the case where the narrowing effects of `e1` hold. More concretely, when analyzing `x is not None or e2`, the result will only evaluate to `x is not None` if the condition holds, so we need to ensure that we're considering that effect when widening the type. The only expression that gets returned unconditionally is the last expression in an or, so we shouldn't apply any effects for those, but still need to build the effect as it gets consumed.

Reviewed By: DinoV

Differential Revision: D29272909

fbshipit-source-id: b11f221
  • Loading branch information
sinancepel authored and facebook-github-bot committed Jun 23, 2021
1 parent 3182eef commit 71e2a28
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
15 changes: 14 additions & 1 deletion Lib/compiler/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -6090,14 +6090,27 @@ def visitBoolOp(
# in a conditional context
effect.undo(self.local_types)
elif isinstance(node.op, ast.Or):
for value in node.values:
for value in node.values[:-1]:
new_effect = self.visit(value) or NO_EFFECT
effect = effect.or_(new_effect)

old_type = self.get_type(value)
# The or expression will only return the `value` we're visiting if it's
# effect holds, so we visit it assuming that the narrowing effects apply.
new_effect.apply(self.local_types)
self.visit(value)
new_effect.undo(self.local_types)

final_type = self.widen(final_type, self.get_type(value))
self.set_type(value, old_type, None)

new_effect.reverse(self.local_types)
# We know nothing about the last node of an or, so we simply widen with its type.
new_effect = self.visit(node.values[-1]) or NO_EFFECT
final_type = self.widen(final_type, self.get_type(node.values[-1]))

effect.undo(self.local_types)
effect = effect.or_(new_effect)
else:
for value in node.values:
self.visit(value)
Expand Down
37 changes: 37 additions & 0 deletions Lib/test/test_compiler/test_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -14419,6 +14419,43 @@ def f(s: Optional[str]) -> bytes:
self.assertEqual(f("A"), b"A")
self.assertEqual(f(None), b"")

def test_refine_or_expression(self):
codestr = """
from typing import Optional
def f(s: Optional[str]) -> str:
return s or "hi"
"""
with self.in_module(codestr) as mod:
f = mod["f"]
self.assertEqual(f("A"), "A")
self.assertEqual(f(None), "hi")

def test_refine_or_expression_with_multiple_optionals(self):
codestr = """
from typing import Optional
def f(s1: Optional[str], s2: Optional[str]) -> str:
return s1 or s2 or "hi"
"""
with self.in_module(codestr) as mod:
f = mod["f"]
self.assertEqual(f("A", None), "A")
self.assertEqual(f(None, "B"), "B")
self.assertEqual(f("A", "B"), "A")
self.assertEqual(f(None, None), "hi")

def test_or_expression_with_multiple_optionals_type_error(self):
codestr = """
from typing import Optional
def f(s1: Optional[str], s2: Optional[str]) -> str:
return s1 or s2
"""
self.type_error(
codestr, r"type mismatch: Optional\[str\] cannot be assigned to str"
)

def test_donotcompile_fn(self):
codestr = """
from __static__ import _donotcompile
Expand Down

0 comments on commit 71e2a28

Please sign in to comment.