Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Making Beanstalk work for Python 3.9+ (#1302)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1302

During the OSS release of Bean Machine Neeraj had the foresight to expand testing to include Python 3.9. This revealed (in good time) that the Beanstalk compiler does not yet work with this version of Python. Upon review of the changes from Python 3.8 to 3.9 (https://docs.python.org/3/whatsnew/3.9.html) it was noted that there were changes to the AST, and in particular, the removal of the ast.Index constructor (https://bugs.python.org/issue34822). This diff adds basic versioning support to Beanstalk as well as version-specific patterns and constructors for slices so that the existing rewrite rules can continue to work as intended.

Notes about this can be found at https://fburl.com/beanstalk-py39-notes

Reviewed By: neerajprad

Differential Revision: D33357242

fbshipit-source-id: 563cc348c97db80e21ba93c9a6d4489f575f8d3d
  • Loading branch information
wtaha authored and facebook-github-bot committed Jan 4, 2022
1 parent 4d80c84 commit cc289d7
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 9 deletions.
49 changes: 48 additions & 1 deletion src/beanmachine/ppl/compiler/ast_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""Pattern matching for ASTs"""
import ast
import math
from platform import python_version
from typing import Any, Dict

import torch
Expand All @@ -25,6 +26,21 @@
)
from beanmachine.ppl.compiler.rules import RuleDomain

# To support different Python versions correctly, in particular changes from 3.8 to 3.9,
# some functionality defined in this module needs to be version dependent.

_python_version = [int(i) for i in python_version().split(".")[:2]]
_python_3_9_or_later = _python_version >= [3, 9]

# Assertions about changes across versions that we address in this module

if _python_3_9_or_later:
dummy_value = ast.Constant(1)
assert ast.Index(dummy_value) == dummy_value
else:
dummy_value = ast.Constant(1)
assert ast.Index(dummy_value) != dummy_value


def _get_children(node: Any) -> Dict[str, Any]:
if isinstance(node, ast.AST):
Expand Down Expand Up @@ -309,10 +325,41 @@ def if_statement(
return type_and_attributes(ast.If, {"test": test, "body": body, "orelse": orelse})


def index(value: Pattern = _any) -> Pattern:
# Note: The following pattern definition is valid only for Python
# versions less than 3.9. As a result, it is followed by a
# version-dependent redefinition


def _index(value: Pattern = _any) -> Pattern:
return type_and_attributes(ast.Index, {"value": value})


def index(value: Pattern = _any):
if _python_3_9_or_later:
return match_every(value, negate(slice_pattern()))
else:
return _index(value=value)


# The following definition should not be necessary in 3.9
# since ast.Index should be identity in this version. It is
# nevertheless included for clarity.


def ast_index(value, **other):
if _python_3_9_or_later:
return value
else:
return ast.Index(value=value, **other)


def get_value(slice_field):
if _python_3_9_or_later:
return slice_field
else:
return slice_field.value


def slice_pattern(
lower: Pattern = _any, upper: Pattern = _any, step: Pattern = _any
) -> Pattern:
Expand Down
6 changes: 4 additions & 2 deletions src/beanmachine/ppl/compiler/bm_to_bmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
binary_compare,
binop,
call,
get_value,
index,
keyword,
load,
Expand Down Expand Up @@ -196,7 +197,8 @@ def _handle_comparison(p: Pattern, s: str) -> PatternRule:
_handle_index = PatternRule(
assign(value=subscript(slice=index())),
lambda a: ast.Assign(
a.targets, _make_bmg_call("handle_index", [a.value.value, a.value.slice.value])
a.targets,
_make_bmg_call("handle_index", [a.value.value, get_value(a.value.slice)]),
),
)

Expand Down Expand Up @@ -247,7 +249,7 @@ def _or_none(a):
"handle_subscript_assign",
[
a.targets[0].value,
a.targets[0].slice.value,
get_value(a.targets[0].slice),
_ast_none,
_ast_none,
a.value,
Expand Down
14 changes: 8 additions & 6 deletions src/beanmachine/ppl/compiler/single_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
ast_compare,
ast_dict,
ast_dictComp,
ast_index,
ast_generator,
ast_domain,
ast_for,
Expand All @@ -95,6 +96,7 @@
attribute,
binop,
call,
get_value,
expr,
index,
keyword,
Expand Down Expand Up @@ -1007,10 +1009,10 @@ def _handle_assign_subscript_slice_index_2(self) -> Rule:
#
return self._make_right_assignment_rule(
subscript(value=name(), slice=index(value=_not_identifier)),
lambda original_right: original_right.slice.value,
lambda original_right: get_value(original_right.slice),
lambda original_right, new_name: ast.Subscript(
value=original_right.value,
slice=ast.Index(value=new_name),
slice=ast_index(value=new_name),
ctx=original_right.ctx,
),
"handle_assign_subscript_slice_index_2",
Expand Down Expand Up @@ -2202,10 +2204,10 @@ def _handle_left_value_subscript_slice_index(self) -> Rule:
"""Rewrites like a[b.c] = z → x = b.c; a[x] = z"""
return self._make_left_any_assignment_rule(
subscript(value=name(), slice=index(value=_not_identifier)),
lambda original_left: original_left.slice.value,
lambda original_left: get_value(original_left.slice),
lambda original_left, new_name: ast.Subscript(
value=original_left.value,
slice=ast.Index(
slice=ast_index(
value=new_name,
ctx=ast.Load(),
),
Expand Down Expand Up @@ -2346,7 +2348,7 @@ def _handle_left_value_list_not_starred(self) -> Rule:
targets=[source_term.targets[0].elts[0]],
value=ast.Subscript(
value=source_term.value,
slice=ast.Index(value=ast.Num(n=0)),
slice=ast_index(value=ast.Num(n=0)),
ctx=ast.Load(),
),
),
Expand Down Expand Up @@ -2397,7 +2399,7 @@ def _handle_left_value_list_starred(self) -> Rule:
targets=[source_term.targets[0].elts[-1]],
value=ast.Subscript(
value=source_term.value,
slice=ast.Index(value=ast.Num(n=-1)),
slice=ast_index(value=ast.Num(n=-1)),
ctx=ast.Load(),
),
),
Expand Down

0 comments on commit cc289d7

Please sign in to comment.