diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..cae7050
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,29 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement (CLA). You (or your employer) retain the copyright to your
+contribution; this simply gives us permission to use and redistribute your
+contributions as part of the project. Head over to
+<https://cla.developers.google.com/> to see your current agreements on file or
+to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code Reviews
+
+All submissions, including submissions by project members, require review. For
+external contributions, we use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index e69de29..d5ac5ea 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,47 @@
+# python_graphs
+
+This package is for computing graph representations of Python programs for
+machine learning applications. It includes the following modules:
+
+* `control_flow` For computing control flow graphs statically from Python
+ programs.
+* `data_flow` For computing data flow analyses of Python programs.
+* `program_graph` For computing graphs statically to represent arbitrary
+ Python programs or functions.
+* `cyclomatic_complexity` For computing the cyclomatic complexity of a Python function.
+
+
+## Installation
+
+To install python_graphs with pip, run: `pip install python_graphs`.
+
+To install python_graphs from source, run: `python setup.py develop`.
+
+## Common Tasks
+
+**Generate a control flow graph from a function `fn`:**
+
+```python
+from python_graphs import control_flow
+graph = control_flow.get_control_flow_graph(fn)
+```
+
+**Generate a program graph from a function `fn`:**
+
+```python
+from python_graphs import program_graph
+graph = program_graph.get_program_graph(fn)
+```
+
+**Compute the cyclomatic complexity of a function `fn`:**
+
+```python
+from python_graphs import control_flow
+from python_graphs import cyclomatic_complexity
+graph = control_flow.get_control_flow_graph(fn)
+value = cyclomatic_complexity.cyclomatic_complexity(graph)
+```
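+
+**Compute basic statistics of a program graph:** (a brief sketch using the
+`program_graph_analysis` helpers in `python_graphs/analysis`; the toy `fn`
+below is a stand-in for your own function, and assumes the `analysis`
+subpackage is importable in your environment)
+
+```python
+from python_graphs import program_graph
+from python_graphs.analysis import program_graph_analysis
+def fn(n):
+  return n + 1
+graph = program_graph.get_program_graph(fn)
+num_nodes = program_graph_analysis.num_nodes(graph)
+ast_height = program_graph_analysis.graph_ast_height(graph)
+```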
+
+---
+
+This is not an officially supported Google product.
diff --git a/python_graphs/analysis/program_graph_analysis.py b/python_graphs/analysis/program_graph_analysis.py
new file mode 100644
index 0000000..c21699a
--- /dev/null
+++ b/python_graphs/analysis/program_graph_analysis.py
@@ -0,0 +1,150 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions to analyze program graphs.
+
+Computes properties such as the height of a program graph's AST.
+"""
+
+import gast as ast
+import networkx as nx
+
+
+def num_nodes(graph):
+ """Returns the number of nodes in a ProgramGraph."""
+ return len(graph.all_nodes())
+
+
+def num_edges(graph):
+ """Returns the number of edges in a ProgramGraph."""
+ return len(graph.edges)
+
+
+def ast_height(ast_node):
+ """Computes the height of an AST from the given node.
+
+ Args:
+ ast_node: An AST node.
+
+ Returns:
+    The height of the AST starting at ast_node. A leaf node or single-node AST
+    has a height of 1. For example, ast.parse('1') yields
+    Module -> Expr -> Constant, which has height 3.
+ """
+ max_child_height = 0
+ for child_node in ast.iter_child_nodes(ast_node):
+ max_child_height = max(max_child_height, ast_height(child_node))
+ return 1 + max_child_height
+
+
+def graph_ast_height(graph):
+ """Computes the height of the AST of a ProgramGraph.
+
+ Args:
+ graph: A ProgramGraph.
+
+ Returns:
+ The height of the graph's AST. A single-node AST has a height of 1.
+ """
+ return ast_height(graph.to_ast())
+
+
+def degrees(graph):
+ """Returns a list of node degrees in a ProgramGraph.
+
+ Args:
+ graph: A ProgramGraph.
+
+ Returns:
+ An (unsorted) list of node degrees (in-degree plus out-degree).
+ """
+ return [len(graph.neighbors(node)) for node in graph.all_nodes()]
+
+
+def in_degrees(graph):
+ """Returns a list of node in-degrees in a ProgramGraph.
+
+ Args:
+ graph: A ProgramGraph.
+
+ Returns:
+ An (unsorted) list of node in-degrees.
+ """
+ return [len(graph.incoming_neighbors(node)) for node in graph.all_nodes()]
+
+
+def out_degrees(graph):
+ """Returns a list of node out-degrees in a ProgramGraph.
+
+ Args:
+ graph: A ProgramGraph.
+
+ Returns:
+ An (unsorted) list of node out-degrees.
+ """
+ return [len(graph.outgoing_neighbors(node)) for node in graph.all_nodes()]
+
+
+def _program_graph_to_nx(program_graph, directed=False):
+ """Converts a ProgramGraph to a NetworkX graph.
+
+ Args:
+ program_graph: A ProgramGraph.
+ directed: Whether the graph should be treated as a directed graph.
+
+ Returns:
+ A NetworkX graph that can be analyzed by the networkx module.
+ """
+ # Create a dict-of-lists representation, where {0: [1]} represents a directed
+ # edge from node 0 to node 1.
+ dict_of_lists = {}
+ for node in program_graph.all_nodes():
+ neighbor_ids = [neighbor.id
+ for neighbor in program_graph.outgoing_neighbors(node)]
+ dict_of_lists[node.id] = neighbor_ids
+ return nx.DiGraph(dict_of_lists) if directed else nx.Graph(dict_of_lists)
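+
+# A quick illustration of the dict-of-lists form (a hypothetical 3-node graph,
+# not part of the library API): nx.Graph({0: [1], 1: [], 2: []}) has nodes
+# {0, 1, 2} and the single undirected edge (0, 1), while
+# nx.DiGraph({0: [1], 1: [], 2: []}) keeps the edge directed as 0 -> 1.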
+
+
+def diameter(graph):
+ """Returns the diameter of a ProgramGraph.
+
+ Note: this is very slow for large graphs.
+
+ Args:
+ graph: A ProgramGraph.
+
+ Returns:
+ The diameter of the graph. A single-node graph has diameter 0. The graph is
+ treated as an undirected graph.
+
+ Raises:
+ networkx.exception.NetworkXError: Raised if the graph is not connected.
+ """
+ nx_graph = _program_graph_to_nx(graph, directed=False)
+ return nx.algorithms.distance_measures.diameter(nx_graph)
+
+
+def max_betweenness(graph):
+ """Returns the maximum node betweenness centrality in a ProgramGraph.
+
+ Note: this is very slow for large graphs.
+
+ Args:
+ graph: A ProgramGraph.
+
+ Returns:
+ The maximum betweenness centrality value among all nodes in the graph. The
+ graph is treated as an undirected graph.
+ """
+ nx_graph = _program_graph_to_nx(graph, directed=False)
+ return max(nx.algorithms.centrality.betweenness_centrality(nx_graph).values())
diff --git a/python_graphs/analysis/program_graph_analysis_test.py b/python_graphs/analysis/program_graph_analysis_test.py
new file mode 100644
index 0000000..d8fb589
--- /dev/null
+++ b/python_graphs/analysis/program_graph_analysis_test.py
@@ -0,0 +1,366 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for program_graph_analysis.py."""
+
+from absl.testing import absltest
+import gast as ast
+import networkx as nx
+
+from python_graphs import program_graph
+from python_graphs.analysis import program_graph_analysis as pga
+
+
+class ProgramGraphAnalysisTest(absltest.TestCase):
+
+ def setUp(self):
+ super(ProgramGraphAnalysisTest, self).setUp()
+ self.singleton = self.create_singleton_graph()
+ self.disconnected = self.create_disconnected_graph()
+ self.cycle_3 = self.create_cycle_3()
+ self.chain_4 = self.create_chain_4()
+ self.wide_tree = self.create_wide_tree()
+
+ def create_singleton_graph(self):
+ """Returns a graph with one node and zero edges."""
+ graph = program_graph.ProgramGraph()
+ node = program_graph.make_node_from_syntax('singleton_node')
+ graph.add_node(node)
+ graph.root_id = node.id
+ return graph
+
+ def create_disconnected_graph(self):
+ """Returns a disconnected graph with two nodes and zero edges."""
+ graph = program_graph.ProgramGraph()
+ a = program_graph.make_node_from_syntax('a')
+ b = program_graph.make_node_from_syntax('b')
+ graph.add_node(a)
+ graph.add_node(b)
+ graph.root_id = a.id
+ return graph
+
+ def create_cycle_3(self):
+ """Returns a 3-cycle graph, A -> B -> C -> A."""
+ graph = program_graph.ProgramGraph()
+ a = program_graph.make_node_from_syntax('A')
+ b = program_graph.make_node_from_ast_value('B')
+ c = program_graph.make_node_from_syntax('C')
+ graph.add_node(a)
+ graph.add_node(b)
+ graph.add_node(c)
+ graph.add_new_edge(a, b)
+ graph.add_new_edge(b, c)
+ graph.add_new_edge(c, a)
+ graph.root_id = a.id
+ return graph
+
+ def create_chain_4(self):
+ """Returns a chain of 4 nodes, A -> B -> C -> D."""
+ graph = program_graph.ProgramGraph()
+ a = program_graph.make_node_from_syntax('A')
+ b = program_graph.make_node_from_ast_value('B')
+ c = program_graph.make_node_from_syntax('C')
+ d = program_graph.make_node_from_ast_value('D')
+ graph.add_node(a)
+ graph.add_node(b)
+ graph.add_node(c)
+ graph.add_node(d)
+ graph.add_new_edge(a, b)
+ graph.add_new_edge(b, c)
+ graph.add_new_edge(c, d)
+ graph.root_id = a.id
+ return graph
+
+ def create_wide_tree(self):
+ """Returns a tree where the root has 4 children that are all leaves."""
+ graph = program_graph.ProgramGraph()
+ root = program_graph.make_node_from_syntax('root')
+ graph.add_node(root)
+ graph.root_id = root.id
+ for i in range(4):
+ leaf = program_graph.make_node_from_ast_value(i)
+ graph.add_node(leaf)
+ graph.add_new_edge(root, leaf)
+ return graph
+
+ def ids_from_cycle_3(self):
+ """Returns a triplet of IDs from the 3-cycle graph in cycle order."""
+ root = self.cycle_3.root
+ id_a = root.id
+ id_b = self.cycle_3.outgoing_neighbors(root)[0].id
+ id_c = self.cycle_3.incoming_neighbors(root)[0].id
+ return id_a, id_b, id_c
+
+ def test_num_nodes_returns_expected(self):
+ self.assertEqual(pga.num_nodes(self.singleton), 1)
+ self.assertEqual(pga.num_nodes(self.disconnected), 2)
+ self.assertEqual(pga.num_nodes(self.cycle_3), 3)
+ self.assertEqual(pga.num_nodes(self.chain_4), 4)
+ self.assertEqual(pga.num_nodes(self.wide_tree), 5)
+
+ def test_num_edges_returns_expected(self):
+ self.assertEqual(pga.num_edges(self.singleton), 0)
+ self.assertEqual(pga.num_edges(self.disconnected), 0)
+ self.assertEqual(pga.num_edges(self.cycle_3), 3)
+ self.assertEqual(pga.num_edges(self.chain_4), 3)
+ self.assertEqual(pga.num_edges(self.wide_tree), 4)
+
+ def test_ast_height_returns_expected_for_constructed_expression_ast(self):
+ # Testing the expression "1".
+    # Height 3: Module -> Expr -> Constant.
+ ast_node = ast.Module(
+ body=[ast.Expr(value=ast.Constant(value=1, kind=None))],
+ type_ignores=[])
+ self.assertEqual(pga.ast_height(ast_node), 3)
+
+ # Testing the expression "1 + 1".
+    # Height 4: Module -> Expr -> BinOp -> Constant.
+ ast_node = ast.Module(
+ body=[
+ ast.Expr(
+ value=ast.BinOp(
+ left=ast.Constant(value=1, kind=None),
+ op=ast.Add(),
+ right=ast.Constant(value=1, kind=None)))
+ ],
+ type_ignores=[])
+ self.assertEqual(pga.ast_height(ast_node), 4)
+
+ # Testing the expression "a + 1".
+ # Height 5: Module -> Expr -> BinOp -> Name -> Load.
+ ast_node = ast.Module(
+ body=[
+ ast.Expr(
+ value=ast.BinOp(
+ left=ast.Name(
+ id='a',
+ ctx=ast.Load(),
+ annotation=None,
+ type_comment=None),
+ op=ast.Add(),
+ right=ast.Constant(value=1, kind=None)))
+ ],
+ type_ignores=[])
+ self.assertEqual(pga.ast_height(ast_node), 5)
+
+ # Testing the expression "a.b + 1".
+ # Height 6: Module -> Expr -> BinOp -> Attribute -> Name -> Load.
+ ast_node = ast.Module(
+ body=[
+ ast.Expr(
+ value=ast.BinOp(
+ left=ast.Attribute(
+ value=ast.Name(
+ id='a',
+ ctx=ast.Load(),
+ annotation=None,
+ type_comment=None),
+ attr='b',
+ ctx=ast.Load()),
+ op=ast.Add(),
+ right=ast.Constant(value=1, kind=None)))
+ ],
+ type_ignores=[])
+ self.assertEqual(pga.ast_height(ast_node), 6)
+
+ def test_ast_height_returns_expected_for_constructed_function_ast(self):
+ # Testing the function declaration "def foo(n): return".
+ # Height 5: Module -> FunctionDef -> arguments -> Name -> Param.
+ ast_node = ast.Module(
+ body=[
+ ast.FunctionDef(
+ name='foo',
+ args=ast.arguments(
+ args=[
+ ast.Name(
+ id='n',
+ ctx=ast.Param(),
+ annotation=None,
+ type_comment=None)
+ ],
+ posonlyargs=[],
+ vararg=None,
+ kwonlyargs=[],
+ kw_defaults=[],
+ kwarg=None,
+ defaults=[]),
+ body=[ast.Return(value=None)],
+ decorator_list=[],
+ returns=None,
+ type_comment=None)
+ ],
+ type_ignores=[])
+ self.assertEqual(pga.ast_height(ast_node), 5)
+
+ # Testing the function declaration "def foo(n): return n + 1".
+ # Height 6: Module -> FunctionDef -> Return -> BinOp -> Name -> Load.
+ ast_node = ast.Module(
+ body=[
+ ast.FunctionDef(
+ name='foo',
+ args=ast.arguments(
+ args=[
+ ast.Name(
+ id='n',
+ ctx=ast.Param(),
+ annotation=None,
+ type_comment=None)
+ ],
+ posonlyargs=[],
+ vararg=None,
+ kwonlyargs=[],
+ kw_defaults=[],
+ kwarg=None,
+ defaults=[]),
+ body=[
+ ast.Return(
+ value=ast.BinOp(
+ left=ast.Name(
+ id='n',
+ ctx=ast.Load(),
+ annotation=None,
+ type_comment=None),
+ op=ast.Add(),
+ right=ast.Constant(value=1, kind=None)))
+ ],
+ decorator_list=[],
+ returns=None,
+ type_comment=None)
+ ],
+ type_ignores=[],
+ )
+ self.assertEqual(pga.ast_height(ast_node), 6)
+
+ def test_ast_height_returns_expected_for_parsed_ast(self):
+    # Height 3: Module -> Expr -> Constant.
+ self.assertEqual(pga.ast_height(ast.parse('1')), 3)
+
+ # Height 6: Module -> Expr -> BinOp -> Attribute -> Name -> Load.
+ self.assertEqual(pga.ast_height(ast.parse('a.b + 1')), 6)
+
+ # Height 6: Module -> FunctionDef -> Return -> BinOp -> Name -> Load.
+ self.assertEqual(pga.ast_height(ast.parse('def foo(n): return n + 1')), 6)
+
+ # Height 9: Module -> FunctionDef -> If -> Return -> BinOp -> Call
+ # -> BinOp -> Name -> Load.
+ # Adding whitespace before "def foo" causes an IndentationError in parse().
+ ast_node = ast.parse("""def foo(n):
+ if n <= 0:
+ return 0
+ else:
+ return 1 + foo(n - 1)
+ """)
+ self.assertEqual(pga.ast_height(ast_node), 9)
+
+ def test_graph_ast_height_returns_expected(self):
+ # Height 6: Module -> FunctionDef -> Return -> BinOp -> Name -> Load.
+ def foo1(n):
+ return n + 1
+
+ graph = program_graph.get_program_graph(foo1)
+ self.assertEqual(pga.graph_ast_height(graph), 6)
+
+ # Height 9: Module -> FunctionDef -> If -> Return -> BinOp -> Call
+ # -> BinOp -> Name -> Load.
+ def foo2(n):
+ if n <= 0:
+ return 0
+ else:
+ return 1 + foo2(n - 1)
+
+ graph = program_graph.get_program_graph(foo2)
+ self.assertEqual(pga.graph_ast_height(graph), 9)
+
+ def test_degrees_returns_expected(self):
+ self.assertCountEqual(pga.degrees(self.singleton), [0])
+ self.assertCountEqual(pga.degrees(self.disconnected), [0, 0])
+ self.assertCountEqual(pga.degrees(self.cycle_3), [2, 2, 2])
+ self.assertCountEqual(pga.degrees(self.chain_4), [1, 2, 2, 1])
+ self.assertCountEqual(pga.degrees(self.wide_tree), [4, 1, 1, 1, 1])
+
+ def test_in_degrees_returns_expected(self):
+ self.assertCountEqual(pga.in_degrees(self.singleton), [0])
+ self.assertCountEqual(pga.in_degrees(self.disconnected), [0, 0])
+ self.assertCountEqual(pga.in_degrees(self.cycle_3), [1, 1, 1])
+ self.assertCountEqual(pga.in_degrees(self.chain_4), [0, 1, 1, 1])
+ self.assertCountEqual(pga.in_degrees(self.wide_tree), [0, 1, 1, 1, 1])
+
+ def test_out_degrees_returns_expected(self):
+ self.assertCountEqual(pga.out_degrees(self.singleton), [0])
+ self.assertCountEqual(pga.out_degrees(self.disconnected), [0, 0])
+ self.assertCountEqual(pga.out_degrees(self.cycle_3), [1, 1, 1])
+ self.assertCountEqual(pga.out_degrees(self.chain_4), [1, 1, 1, 0])
+ self.assertCountEqual(pga.out_degrees(self.wide_tree), [4, 0, 0, 0, 0])
+
+ def test_diameter_returns_expected_if_connected(self):
+ self.assertEqual(pga.diameter(self.singleton), 0)
+ self.assertEqual(pga.diameter(self.cycle_3), 1)
+ self.assertEqual(pga.diameter(self.chain_4), 3)
+ self.assertEqual(pga.diameter(self.wide_tree), 2)
+
+ def test_diameter_throws_exception_if_disconnected(self):
+ with self.assertRaises(nx.exception.NetworkXError):
+ pga.diameter(self.disconnected)
+
+ def test_program_graph_to_nx_undirected_has_correct_edges(self):
+ id_a, id_b, id_c = self.ids_from_cycle_3()
+ nx_graph = pga._program_graph_to_nx(self.cycle_3, directed=False)
+ self.assertCountEqual(nx_graph.nodes(), [id_a, id_b, id_c])
+ expected_adj = {
+ id_a: {
+ id_b: {},
+ id_c: {}
+ },
+ id_b: {
+ id_a: {},
+ id_c: {}
+ },
+ id_c: {
+ id_a: {},
+ id_b: {}
+ },
+ }
+ self.assertEqual(nx_graph.adj, expected_adj)
+
+ def test_program_graph_to_nx_directed_has_correct_edges(self):
+ id_a, id_b, id_c = self.ids_from_cycle_3()
+ nx_digraph = pga._program_graph_to_nx(self.cycle_3, directed=True)
+ self.assertCountEqual(nx_digraph.nodes(), [id_a, id_b, id_c])
+ expected_adj = {
+ id_a: {
+ id_b: {}
+ },
+ id_b: {
+ id_c: {}
+ },
+ id_c: {
+ id_a: {}
+ },
+ }
+ self.assertEqual(nx_digraph.adj, expected_adj)
+
+ def test_max_betweenness_returns_expected(self):
+ self.assertAlmostEqual(pga.max_betweenness(self.singleton), 0)
+ self.assertAlmostEqual(pga.max_betweenness(self.disconnected), 0)
+ self.assertAlmostEqual(pga.max_betweenness(self.cycle_3), 0)
+
+ # Middle nodes are in 2 shortest paths, normalizer = (4-1)*(4-2)/2 = 3
+ self.assertAlmostEqual(pga.max_betweenness(self.chain_4), 2 / 3)
+
+ # Root is in 6 shortest paths, normalizer = (5-1)*(5-2)/2 = 6
+ self.assertAlmostEqual(pga.max_betweenness(self.wide_tree), 6 / 6)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/analysis/run_program_graph_analysis.py b/python_graphs/analysis/run_program_graph_analysis.py
new file mode 100644
index 0000000..6f214dc
--- /dev/null
+++ b/python_graphs/analysis/run_program_graph_analysis.py
@@ -0,0 +1,277 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs the program graph analysis for datasets of programs.
+
+Analyzes each dataset of programs, producing plots for properties such as the
+AST height.
+"""
+
+import inspect
+import math
+
+from absl import app
+from absl import logging
+import matplotlib.pyplot as plt
+import numpy as np
+from python_graphs import control_flow_test_components as cftc
+from python_graphs import program_graph
+from python_graphs import program_graph_test_components as pgtc
+from python_graphs.analysis import program_graph_analysis
+import six
+from six.moves import range
+
+
+TARGET_NUM_BINS = 15 # A reasonable constant number of histogram bins.
+MAX_NUM_BINS = 20 # The maximum number of bins reasonable on a histogram.
+
+
+def test_components():
+ """Generates functions from two sets of test components.
+
+ Yields:
+ All functions in the program graph and control flow test components files.
+ """
+ for unused_name, fn in inspect.getmembers(pgtc, predicate=inspect.isfunction):
+ yield fn
+
+ for unused_name, fn in inspect.getmembers(cftc, predicate=inspect.isfunction):
+ yield fn
+
+
+def get_graph_generator(function_generator):
+ """Generates ProgramGraph objects from functions.
+
+ Args:
+ function_generator: A function generator.
+
+ Yields:
+ ProgramGraph objects for the functions.
+ """
+ for index, function in enumerate(function_generator):
+ try:
+ graph = program_graph.get_program_graph(function)
+ yield graph
+ except SyntaxError:
+ # get_program_graph can fail for programs with different string encodings.
+ logging.info('SyntaxError in get_program_graph for function index %d. '
+ 'First 100 chars of function source:\n%s',
+                     index, inspect.getsource(function)[:100])
+ except RuntimeError:
+ # get_program_graph can fail for programs that are only return statements.
+ logging.info('RuntimeError in get_program_graph for function index %d. '
+ 'First 100 chars of function source:\n%s',
+                     index, inspect.getsource(function)[:100])
+
+
+def get_percentiles(data, percentiles, integer_valued=True):
+ """Returns a dict of percentiles of the data.
+
+ Args:
+ data: An unsorted list of datapoints.
+ percentiles: A list of ints or floats in the range [0, 100] representing the
+ percentiles to compute.
+ integer_valued: Whether or not the values are all integers. If so,
+ interpolate to the nearest datapoint (instead of computing a fractional
+ value between the two nearest datapoints).
+
+ Returns:
+ A dict mapping each element of percentiles to the computed result.
+ """
+  # For integer data, 'nearest' keeps each percentile at an actual datapoint.
+ interpolation = 'nearest' if integer_valued else 'linear'
+ results = np.percentile(data, percentiles, interpolation=interpolation)
+ return {percentiles[i]: results[i] for i in range(len(percentiles))}
+
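get_percentiles above leans on `np.percentile`'s `interpolation` keyword, which was renamed to `method` in NumPy 1.22. A self-contained sketch of the same idea with a compatibility fallback (`percentiles_demo` is a hypothetical stand-in, not part of this module):

```python
import numpy as np

def percentiles_demo(data, qs, integer_valued=True):
    # 'nearest' snaps each percentile to an actual datapoint; 'linear'
    # interpolates between the two nearest datapoints.
    how = 'nearest' if integer_valued else 'linear'
    try:
        results = np.percentile(data, qs, method=how)  # NumPy >= 1.22
    except TypeError:
        results = np.percentile(data, qs, interpolation=how)  # older NumPy
    return dict(zip(qs, results))

data = [1, 2, 3, 4]
print(percentiles_demo(data, [50], integer_valued=False))  # {50: 2.5}
print(percentiles_demo(data, [50]))  # snaps to one of the real datapoints
```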
+
+def analyze_graph(graph, identifier):
+ """Performs various analyses on a graph.
+
+ Args:
+ graph: A ProgramGraph to analyze.
+ identifier: A unique identifier for this graph (for later aggregation).
+
+ Returns:
+ A pair (identifier, result_dict), where result_dict contains the results of
+ analyses run on the graph.
+ """
+ num_nodes = program_graph_analysis.num_nodes(graph)
+ num_edges = program_graph_analysis.num_edges(graph)
+ ast_height = program_graph_analysis.graph_ast_height(graph)
+
+ degree_percentiles = [25, 50, 90]
+ degrees = get_percentiles(program_graph_analysis.degrees(graph),
+ degree_percentiles)
+ in_degrees = get_percentiles(program_graph_analysis.in_degrees(graph),
+ degree_percentiles)
+ out_degrees = get_percentiles(program_graph_analysis.out_degrees(graph),
+ degree_percentiles)
+
+ diameter = program_graph_analysis.diameter(graph)
+ max_betweenness = program_graph_analysis.max_betweenness(graph)
+
+ # TODO(kshi): Turn this into a protobuf and fix everywhere else in this file.
+ # Eventually this should be parallelized (currently takes ~6 hours to run).
+ result_dict = {
+ 'num_nodes': num_nodes,
+ 'num_edges': num_edges,
+ 'ast_height': ast_height,
+ 'degrees': degrees,
+ 'in_degrees': in_degrees,
+ 'out_degrees': out_degrees,
+ 'diameter': diameter,
+ 'max_betweenness': max_betweenness,
+ }
+
+ return (identifier, result_dict)
+
+
+def create_bins(values, integer_valued=True, log_x=False):
+ """Creates appropriate histogram bins.
+
+ Args:
+ values: The values to be plotted in a histogram.
+ integer_valued: Whether the values are all integers.
+ log_x: Whether to plot the x-axis using a log scale.
+
+ Returns:
+ An object (sequence, integer, or 'auto') that can be used as the 'bins'
+ keyword argument to plt.hist(). If there are no values to plot, or all of
+ the values are identical, then 'auto' is returned.
+ """
+ if not values:
+ return 'auto' # No data to plot; let pyplot handle this case.
+ min_value = min(values)
+ max_value = max(values)
+ if min_value == max_value:
+ return 'auto' # All values are identical; let pyplot handle this case.
+
+ if log_x:
+ return np.logspace(np.log10(min_value), np.log10(max_value + 1),
+ num=(TARGET_NUM_BINS + 1))
+ elif integer_valued:
+ # The minimum integer width resulting in at most MAX_NUM_BINS bins.
+ bin_width = math.ceil((max_value - min_value + 1) / MAX_NUM_BINS)
+ # Place bin boundaries between integers.
+ return np.arange(min_value - 0.5, max_value + bin_width + 0.5, bin_width)
+ else:
+ return TARGET_NUM_BINS
+
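The integer branch of create_bins is the subtle one. This standalone sketch (`integer_bins` is a hypothetical name) reproduces it to show the bin edges land on half-integers, so every integer value falls strictly inside a bin rather than on a boundary:

```python
import math
import numpy as np

MAX_NUM_BINS = 20  # Mirrors the module-level constant above.

def integer_bins(values):
    # The smallest integer bin width that yields at most MAX_NUM_BINS bins,
    # with bin boundaries placed between integers.
    min_value, max_value = min(values), max(values)
    bin_width = math.ceil((max_value - min_value + 1) / MAX_NUM_BINS)
    return np.arange(min_value - 0.5, max_value + bin_width + 0.5, bin_width)

edges = integer_bins(list(range(11)))  # values 0..10
print(edges[0], edges[-1], len(edges))  # -0.5 10.5 12
```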
+
+def create_histogram(values, title, percentiles=False, integer_valued=True,
+ log_x=False, log_y=False):
+ """Returns a histogram of integer values computed from a dataset.
+
+ Args:
+ values: A list of integer values to plot, or if percentiles is True, then
+ each value is a dict mapping some chosen percentiles in [0, 100] to the
+ corresponding data value.
+ title: The figure title.
+ percentiles: Whether to plot multiple histograms for percentiles.
+ integer_valued: Whether the values are all integers, which affects how the
+ data is partitioned into bins.
+ log_x: Whether to plot the x-axis using a log scale.
+ log_y: Whether to plot the y-axis using a log scale.
+
+ Returns:
+ A histogram figure.
+ """
+ figure = plt.figure()
+
+ if percentiles:
+ for percentile in sorted(values[0].keys()):
+ new_values = [percentile_dict[percentile]
+ for percentile_dict in values]
+ bins = create_bins(new_values, integer_valued=integer_valued, log_x=log_x)
+ plt.hist(new_values, bins=bins, alpha=0.5, label='{}%'.format(percentile))
+ plt.legend(loc='upper right')
+ else:
+ bins = create_bins(values, integer_valued=integer_valued, log_x=log_x)
+ plt.hist(values, bins=bins)
+
+ if log_x:
+ plt.xscale('log', nonposx='clip')
+ if log_y:
+ plt.yscale('log', nonposy='clip')
+ plt.title(title)
+ return figure
+
+
+def save_histogram(all_results, result_key, dataset_name, path_root,
+ percentiles=False, integer_valued=True,
+ log_x=False, log_y=False):
+ """Saves a histogram image to disk.
+
+ Args:
+ all_results: A list of dicts containing all analysis results for each graph.
+ result_key: The key in the result dicts specifying what data to plot.
+ dataset_name: The name of the dataset, which appears in the figure title and
+ the image filename.
+ path_root: The directory to save the histogram image in.
+ percentiles: Whether the data has multiple percentiles to plot.
+ integer_valued: Whether the values are all integers, which affects how the
+ data is partitioned into bins.
+ log_x: Whether to plot the x-axis using a log scale.
+ log_y: Whether to plot the y-axis using a log scale.
+ """
+ values = [result[result_key] for result in all_results]
+ title = '{} distribution for {}'.format(result_key, dataset_name)
+ figure = create_histogram(values, title, percentiles=percentiles,
+ integer_valued=integer_valued,
+ log_x=log_x, log_y=log_y)
+ path = '{}/{}-{}.png'.format(path_root, result_key, dataset_name)
+ figure.savefig(path)
+ logging.info('Saved image %s', path)
+
+
+def main(argv):
+ del argv # Unused.
+
+ dataset_pairs = [
+ (test_components(), 'test_components'),
+ ]
+ path_root = '/tmp/program_graph_analysis'
+
+ for function_generator, dataset_name in dataset_pairs:
+ logging.info('Analyzing graphs in dataset %s...', dataset_name)
+ graph_generator = get_graph_generator(function_generator)
+ all_results = []
+ for index, graph in enumerate(graph_generator):
+ identifier = '{}-{}'.format(dataset_name, index)
+ # Discard the identifiers (not needed until this is parallelized).
+ all_results.append(analyze_graph(graph, identifier)[1])
+
+ if all_results:
+ logging.info('Creating plots for dataset %s...', dataset_name)
+ for result_key in ['num_nodes', 'num_edges']:
+ save_histogram(all_results, result_key, dataset_name, path_root,
+ percentiles=False, integer_valued=True, log_x=True)
+ for result_key in ['ast_height', 'diameter']:
+ save_histogram(all_results, result_key, dataset_name, path_root,
+ percentiles=False, integer_valued=True)
+ for result_key in ['max_betweenness']:
+ save_histogram(all_results, result_key, dataset_name, path_root,
+ percentiles=False, integer_valued=False)
+ for result_key in ['degrees', 'in_degrees', 'out_degrees']:
+ save_histogram(all_results, result_key, dataset_name, path_root,
+ percentiles=True, integer_valued=True)
+ else:
+      logging.warning('Dataset %s is empty.', dataset_name)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/python_graphs/control_flow.py b/python_graphs/control_flow.py
new file mode 100644
index 0000000..a5d8f8c
--- /dev/null
+++ b/python_graphs/control_flow.py
@@ -0,0 +1,1291 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Computes the control flow graph for a Python program from its AST.
+"""
+
+import uuid
+
+from absl import logging # pylint: disable=unused-import
+import gast as ast
+from python_graphs import instruction as instruction_module
+from python_graphs import program_utils
+import six
+
+
+def get_control_flow_graph(program):
+ """Get a ControlFlowGraph for the provided AST node.
+
+ Args:
+ program: Either an AST node, source string, or a function.
+
+  Returns:
+ A ControlFlowGraph.
+ """
+ control_flow_visitor = ControlFlowVisitor()
+ node = program_utils.program_to_ast(program)
+ control_flow_visitor.run(node)
+ return control_flow_visitor.graph
+
+
+class ControlFlowGraph(object):
+ """A control flow graph for a Python program.
+
+ Attributes:
+ blocks: All blocks contained in the control flow graph.
+ nodes: All control flow nodes in the control flow graph.
+ start_block: The entry point to the program.
+ """
+
+ def __init__(self):
+ self.blocks = []
+ self.nodes = []
+
+ self.start_block = self.new_block(prunable=False)
+    self.start_block.label = '<start>'
+
+ def add_node(self, control_flow_node):
+ self.nodes.append(control_flow_node)
+
+ def new_block(self, node=None, label=None, prunable=True):
+ block = BasicBlock(node=node, label=label, prunable=prunable)
+ block.graph = self
+ self.blocks.append(block)
+ return block
+
+ def get_control_flow_nodes(self):
+ return self.nodes
+
+  def get_enter_blocks(self):
+    """Returns entry blocks for all functions."""
+    return six.moves.filter(
+        lambda block: block.label.startswith('<entry:'),
+        self.blocks)
+
+  def get_blocks_by_function_name(self, name):
+    """Returns entry blocks for functions with the given name."""
+    return six.moves.filter(
+        lambda block: block.label == '<entry:{name}>'.format(name=name),
+        self.blocks)
+
+ def get_block_by_function_name(self, name):
+ return next(self.get_blocks_by_function_name(name))
+
+ def get_control_flow_nodes_by_source(self, source):
+ module = ast.parse(source, mode='exec') # TODO(dbieber): Factor out 4 lines
+ node = module.body[0]
+ if isinstance(node, ast.Expr):
+ node = node.value
+
+ return six.moves.filter(
+ lambda cfn: cfn.instruction.contains_subprogram(node),
+ self.get_control_flow_nodes())
+
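Several query methods here repeat the same parse-and-unwrap dance on the query source (note the TODO about factoring out those lines). A standalone sketch using the standard library's `ast` (the file itself uses `gast`, which mirrors the same node types; `parse_query` is a hypothetical helper):

```python
import ast

def parse_query(source):
    # Parse the snippet as a module and take its first statement; a bare
    # expression statement is unwrapped to the expression itself.
    module = ast.parse(source, mode='exec')
    node = module.body[0]
    if isinstance(node, ast.Expr):
        node = node.value
    return node

print(type(parse_query('x = 1')).__name__)   # Assign (a statement)
print(type(parse_query('x + 1')).__name__)   # BinOp (unwrapped expression)
```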
+ def get_control_flow_node_by_source(self, source):
+ return next(self.get_control_flow_nodes_by_source(source))
+
+ def get_control_flow_nodes_by_source_and_identifier(self, source, name):
+ for control_flow_node in self.get_control_flow_nodes_by_source(source):
+ for node in ast.walk(control_flow_node.instruction.node):
+ if isinstance(node, ast.Name) and node.id == name:
+ for i2 in self.get_control_flow_nodes_by_ast_node(node):
+ yield i2
+
+ def get_control_flow_node_by_source_and_identifier(self, source, name):
+ return next(
+ self.get_control_flow_nodes_by_source_and_identifier(source, name))
+
+ def get_blocks_by_source(self, source):
+ """Yields blocks that contain instructions matching the query source."""
+ module = ast.parse(source, mode='exec')
+ node = module.body[0]
+ if isinstance(node, ast.Expr):
+ node = node.value
+
+ for block in self.blocks:
+ for control_flow_node in block.control_flow_nodes:
+ if control_flow_node.instruction.contains_subprogram(node):
+ yield block
+ break
+
+ def get_block_by_source(self, source):
+ return next(self.get_blocks_by_source(source))
+
+ def get_blocks_by_source_and_ast_node_type(self, source, node_type):
+ """Blocks with an Instruction matching node_type and containing source."""
+ module = ast.parse(source, mode='exec')
+ node = module.body[0]
+ if isinstance(node, ast.Expr):
+ node = node.value
+
+ for block in self.blocks:
+ for instruction in block.instructions:
+ if (isinstance(instruction.node, node_type)
+ and instruction.contains_subprogram(node)):
+ yield block
+ break
+
+ def get_block_by_source_and_ast_node_type(self, source, node_type):
+ """A block with an Instruction matching node_type and containing source."""
+ return next(self.get_blocks_by_source_and_ast_node_type(source, node_type))
+
+ def get_block_by_ast_node_and_label(self, node, label):
+ """Gets the block corresponding to `node` having label `label`."""
+ for block in self.blocks:
+ if block.node is node and block.label == label:
+ return block
+
+ def get_blocks_by_ast_node_type_and_label(self, node_type, label):
+ """Gets the blocks with node type `node_type` having label `label`."""
+ for block in self.blocks:
+ if isinstance(block.node, node_type) and block.label == label:
+ yield block
+
+ def get_block_by_ast_node_type_and_label(self, node_type, label):
+ """Gets a block with node type `node_type` having label `label`."""
+ return next(self.get_blocks_by_ast_node_type_and_label(node_type, label))
+
+ def prune(self):
+ """Prunes all prunable blocks from the graph."""
+ progress = True
+ while progress:
+ progress = False
+ for block in iter(self.blocks):
+ if block.can_prune():
+ to_remove = block.prune()
+ self.blocks.remove(to_remove)
+ progress = True
+
+ def compact(self):
+ """Prunes unused blocks and merges blocks when possible."""
+ self.prune()
+ for block in iter(self.blocks):
+ while block.can_merge():
+ to_remove = block.merge()
+ self.blocks.remove(to_remove)
+ for block in self.blocks:
+ block.compact()
+
+
+class Frame(object):
+ """A Frame indicates how statements affect control flow in parts of a program.
+
+ Frames are introduced when the program enters a new loop, function definition,
+ or try/except/finally block.
+
+ A Frame indicates how an exit such as a continue, break, exception, or return
+ affects control flow. For example, a continue statement inside of a loop sends
+ control back to the loop's condition. In nested loops, a continue statement
+ sends control back to the condition of the innermost loop containing the
+ continue statement.
+
+ Attributes:
+ kind: One of LOOP, FUNCTION, TRY_EXCEPT, or TRY_FINALLY.
+ blocks: A dictionary with the blocks relevant to the frame.
+ """
+
+ # Kinds:
+ LOOP = 'loop'
+ FUNCTION = 'function'
+ TRY_EXCEPT = 'try-except'
+ TRY_FINALLY = 'try-finally'
+
+ def __init__(self, kind, **blocks):
+ self.kind = kind
+ self.blocks = blocks
+
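The innermost-frame lookups defined later in this file (e.g. get_current_loop_frame) all follow one pattern: walk the frame stack in reverse, collecting try-finally frames until the target kind is found. A toy, self-contained sketch of that pattern (this stub Frame stands in for the class above):

```python
class Frame(object):
    # Minimal stand-in for the Frame class: only a kind, no blocks.
    LOOP = 'loop'
    FUNCTION = 'function'
    TRY_FINALLY = 'try-finally'

    def __init__(self, kind):
        self.kind = kind

def current_loop_frames(frames):
    # Scan from innermost to outermost; try-finally frames encountered
    # before the loop frame must have their finally blocks run first.
    result = []
    for frame in reversed(frames):
        if frame.kind == Frame.TRY_FINALLY:
            result.append(frame)
        if frame.kind == Frame.LOOP:
            result.append(frame)
            return result
    return None  # No enclosing loop frame.

frames = [Frame(Frame.FUNCTION), Frame(Frame.LOOP),
          Frame(Frame.TRY_FINALLY), Frame(Frame.LOOP)]
# The innermost LOOP frame is found first; no try-finally frame sits
# inside it, so the result is just that loop frame.
print([f.kind for f in current_loop_frames(frames)])  # ['loop']
```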
+
+class BasicBlock(object):
+ """A basic block in a control flow graph.
+
+  Either all instructions (generally, AST nodes) in a basic block are executed,
+  or none are (except when an exception interrupts the block partway through).
+  The instructions are executed in a straight-line manner.
+
+ Attributes:
+ graph: The control flow graph which this basic block is a part of.
+
+ next: Indicates which basic blocks may be executed after this basic block.
+ prev: Indicates which basic blocks may lead to the execution of this basic
+ block in a Python program.
+ control_flow_nodes: A list of the ControlFlowNodes contained in this basic
+ block. Each ControlFlowNode corresponds to a single Instruction.
+ control_flow_node_indexes: Maps from id(control_flow_node) to the
+ ControlFlowNode's index in self.control_flow_nodes. Only available once
+ the block is compacted.
+
+ branches: A map from booleans to the basic block reachable by making the
+ branch decision indicated by that boolean.
+ exits_from_middle: These basic blocks may be exited to at any point during
+ the execution of this basic block.
+ exits_from_end: These basic blocks may only be exited to at the end of
+ the execution of this basic block.
+ node: The AST node this basic block is associated with.
+ prunable: Whether this basic block may be pruned from the control flow graph
+ if empty. Set to False for special blocks, such as enter and exit blocks.
+ label: A label for the basic block.
+ identities: A list of (node, label) pairs that refer to this basic block.
+ This starts as (self.node, self.label), but old identities are preserved
+ during merging and pruning. Allows lookup of blocks by node and label,
+ e.g. for finding the after block of a particular if statement.
+ labels: Labels, used for example by data flow analyses. Maps from label name
+ to value.
+ """
+
+ def __init__(self, node=None, label=None, prunable=True):
+ self.graph = None
+ self.next = set()
+ self.prev = set()
+ self.control_flow_nodes = []
+ self.control_flow_node_indexes = None
+
+ self.branches = {}
+ self.exits_from_middle = set()
+ self.exits_from_end = set()
+ self.node = node
+ self.prunable = prunable
+ self.label = label
+ self.identities = [(node, label)]
+ self.labels = {}
+
+ def has_label(self, label):
+ """Returns whether this BasicBlock has the specified label."""
+ return label in self.labels
+
+ def set_label(self, label, value):
+ """Sets the value of a label on the BasicBlock."""
+ self.labels[label] = value
+
+ def get_label(self, label):
+ """Gets the value of a label on the BasicBlock."""
+ return self.labels[label]
+
+ def is_empty(self):
+ """Whether this block is empty."""
+ return not self.control_flow_nodes
+
+ def exits_to(self, block):
+ """Whether this block exits to `block`."""
+ return block in self.next
+
+ def raises_to(self, block):
+ """Whether this block exits to `block` in the case of an exception."""
+ return block in self.next and block in self.exits_from_middle
+
+ def add_exit(self, block, interrupting=False, branch=None):
+ """Adds an exit from this block to `block`."""
+ self.next.add(block)
+ block.prev.add(self)
+
+ if branch is not None:
+ self.branches[branch] = block
+
+ if interrupting:
+ self.exits_from_middle.add(block)
+ else:
+ self.exits_from_end.add(block)
+
+ def remove_exit(self, block):
+ """Removes the exit from this block to `block`."""
+ self.next.remove(block)
+ block.prev.remove(self)
+
+ if block in self.exits_from_middle:
+ self.exits_from_middle.remove(block)
+ if block in self.exits_from_end:
+ self.exits_from_end.remove(block)
+ for branch_decision, branch_exit in self.branches.copy().items():
+ if branch_exit is block:
+ del self.branches[branch_decision]
+
+ def can_prune(self):
+ return self.is_empty() and self.prunable
+
+ def prune(self):
+ """Prunes the empty block from its control flow graph.
+
+ A block is prunable if it has no control flow nodes and has not been marked
+ as unprunable (e.g. because it's the exit block, or a return block, etc).
+
+ Returns:
+ The block removed by the prune operation. That is, self.
+ """
+ assert self.can_prune()
+ prevs = self.prev.copy()
+ nexts = self.next.copy()
+ for prev_block in prevs:
+ exits_from_middle = prev_block.exits_from_middle.copy()
+ exits_from_end = prev_block.exits_from_end.copy()
+ branches = prev_block.branches.copy()
+ for next_block in nexts:
+ if self in exits_from_middle:
+ prev_block.add_exit(next_block, interrupting=True)
+ if self in exits_from_end:
+ prev_block.add_exit(next_block, interrupting=False)
+
+ for branch_decision, branch_exit in branches.items():
+ if branch_exit is self:
+ prev_block.branches[branch_decision] = next_block
+
+ for prev_block in prevs:
+ prev_block.remove_exit(self)
+ for next_block in nexts:
+ self.remove_exit(next_block)
+ next_block.identities = next_block.identities + self.identities
+ return self
+
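The rewiring that prune performs can be sketched on a plain edge list: drop every edge touching the empty block, then bridge each predecessor directly to each successor (`prune_node` is a toy stand-in, not the BasicBlock API, which also preserves interrupting/branch metadata):

```python
def prune_node(edges, node):
    # Predecessors and successors of the block being removed.
    preds = [a for (a, b) in edges if b == node]
    succs = [b for (a, b) in edges if a == node]
    # Drop every edge touching the pruned block, then bridge around it.
    remaining = [(a, b) for (a, b) in edges if node not in (a, b)]
    bridges = [(p, s) for p in preds for s in succs
               if (p, s) not in remaining]
    return remaining + bridges

edges = [('compare', 'c = 1'), ('compare', 'c = 2'),
         ('c = 1', 'empty'), ('c = 2', 'empty'), ('empty', 'return')]
print(prune_node(edges, 'empty'))
# [('compare', 'c = 1'), ('compare', 'c = 2'),
#  ('c = 1', 'return'), ('c = 2', 'return')]
```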
+ def can_merge(self):
+ if len(self.exits_from_end) != 1:
+ return False
+ next_block = next(iter(self.exits_from_end))
+ if not next_block.prunable:
+ return False
+ if self.exits_from_middle != next_block.exits_from_middle:
+ return False
+    return len(next_block.prev) == 1
+
+ def merge(self):
+ """Merge this block with its one successor.
+
+ Returns:
+ The successor block removed by the merge operation.
+ """
+ assert self.can_merge()
+ next_block = next(iter(self.exits_from_end))
+
+ exits_from_middle = next_block.exits_from_middle.copy()
+ exits_from_end = next_block.exits_from_end.copy()
+
+ for branch_decision, branch_exit in next_block.branches.items():
+ self.branches[branch_decision] = branch_exit
+
+ self.remove_exit(next_block)
+ for block in next_block.next.copy():
+ next_block.remove_exit(block)
+ if block in exits_from_middle:
+ self.add_exit(block, interrupting=True)
+ if block in exits_from_end:
+ self.add_exit(block, interrupting=False)
+ for control_flow_node in next_block.control_flow_nodes:
+ control_flow_node.block = self
+ self.control_flow_nodes.append(control_flow_node)
+ self.prunable = self.prunable and next_block.prunable
+ self.label = self.label or next_block.label
+ self.identities = self.identities + next_block.identities
+ # Note: self.exits_from_middle is unchanged.
+ return next_block
+
+ def add_instruction(self, instruction):
+ assert isinstance(instruction, instruction_module.Instruction)
+ control_flow_node = ControlFlowNode(graph=self.graph,
+ block=self,
+ instruction=instruction)
+ self.graph.add_node(control_flow_node)
+ self.control_flow_nodes.append(control_flow_node)
+
+ def compact(self):
+ self.control_flow_node_indexes = {}
+ for index, control_flow_node in enumerate(self.control_flow_nodes):
+ self.control_flow_node_indexes[control_flow_node.uuid] = index
+
+ def index_of(self, control_flow_node):
+ """Returns the index of the Instruction in this BasicBlock."""
+ return self.control_flow_node_indexes[control_flow_node.uuid]
+
+
+class ControlFlowNode(object):
+ """A node in a control flow graph.
+
+ Corresponds to a single Instruction contained in a single BasicBlock.
+
+ Attributes:
+ graph: The ControlFlowGraph which this node is a part of.
+ block: The BasicBlock in which this node's instruction resides.
+ instruction: The Instruction corresponding to this node.
+ labels: Metadata attached to this node, for example for use by data flow
+ analyses.
+ uuid: A unique identifier for the ControlFlowNode.
+ """
+
+ def __init__(self, graph, block, instruction):
+ self.graph = graph
+ self.block = block
+ self.instruction = instruction
+ self.labels = {}
+ self.uuid = uuid.uuid4()
+
+ @property
+ def next(self):
+ """Returns the set of possible next instructions."""
+ if self.block is None:
+ return None
+ index_in_block = self.block.index_of(self)
+ if len(self.block.control_flow_nodes) > index_in_block + 1:
+ return {self.block.control_flow_nodes[index_in_block + 1]}
+ control_flow_nodes = set()
+ for next_block in self.block.next:
+ if next_block.control_flow_nodes:
+ control_flow_nodes.add(next_block.control_flow_nodes[0])
+ else:
+        # If next_block is empty, the pruning phase of control flow graph
+        # construction guarantees it has no successors; only unprunable
+        # empty blocks (such as the exit block) survive pruning.
+ assert not next_block.next
+ return control_flow_nodes
+
+ @property
+ def prev(self):
+ """Returns the set of possible previous instructions."""
+ if self.block is None:
+ return None
+ index_in_block = self.block.index_of(self)
+ if index_in_block - 1 >= 0:
+ return {self.block.control_flow_nodes[index_in_block - 1]}
+ control_flow_nodes = set()
+ for prev_block in self.block.prev:
+ if prev_block.control_flow_nodes:
+ control_flow_nodes.add(prev_block.control_flow_nodes[-1])
+ else:
+        # If prev_block is empty, the pruning phase of control flow graph
+        # construction guarantees it has no predecessors; only unprunable
+        # empty blocks (such as the start block) survive pruning.
+ assert not prev_block.prev
+ return control_flow_nodes
+
+ @property
+ def branches(self):
+ """Returns the branch options available at the end of this node.
+
+ Returns:
+ A dictionary with possible keys True and False, and values given by the
+ node that is reached by taking the True/False branch. An empty dictionary
+ indicates that there are no branches to take, and so self.next gives the
+ next node (in a set of size 1). A value of None indicates that taking that
+ branch leads to the exit, since there are no exit ControlFlowNodes in a
+ ControlFlowGraph.
+ """
+ if self.block is None:
+ return {} # We're not in a block. No branch decision.
+ index_in_block = self.block.index_of(self)
+ if len(self.block.control_flow_nodes) > index_in_block + 1:
+ return {} # We're not yet at the end of the block. No branch decision.
+
+ branches = {} # We're at the end of the block.
+ for key, next_block in self.block.branches.items():
+ if next_block.control_flow_nodes:
+ branches[key] = next_block.control_flow_nodes[0]
+ else:
+        # If next_block is empty, the pruning phase of control flow graph
+        # construction guarantees it has no successors; only unprunable
+        # empty blocks (such as the exit block) survive pruning.
+ assert not next_block.next
+ branches[key] = None # Indicates exit; there is no node to return.
+ return branches
+
+ def has_label(self, label):
+ """Returns whether this Instruction has the specified label."""
+ return label in self.labels
+
+ def set_label(self, label, value):
+ """Sets the value of a label on the Instruction."""
+ self.labels[label] = value
+
+ def get_label(self, label):
+ """Gets the value of a label on the Instruction."""
+ return self.labels[label]
+
+
+# pylint: disable=invalid-name,g-doc-return-or-yield,g-doc-args
+class ControlFlowVisitor(object):
+ """A visitor for determining the control flow of a Python program from an AST.
+
+ The main function of interest here is `visit`, which causes the visitor to
+ construct the control flow graph for the node passed to visit.
+
+ Basic control flow:
+ The state of the Visitor consists of a sequence of frames, and a current
+ basic block. When an AST node is visited by `visit`, it is added to the
+ current basic block. When a node can indicate a possible change in control,
+ new basic blocks are created and exits between the basic blocks are added
+ as appropriate.
+
+ For example, an If statement introduces two possibilities for control flow.
+ Consider the program:
+
+ if a > b:
+ c = 1
+ else:
+ c = 2
+ return c
+
+ There are four basic blocks in this program: let's call them `compare`,
+ `c = 1`, `c = 2`, and `return`. The exits between the blocks are:
+ `compare` -> `c = 1`, `compare` -> `c = 2`, `c = 1` -> `return`, and
+ `c = 2` -> `return`.
+
+ Frames:
+ There are four kinds of frames: function frames, loop frames, try-except, and
+ try-finally frames. All AST nodes in a function definition are in that
+ function's function frame. All AST nodes in the body of a loop are in that
+ loop's loop frame. And all AST nodes in the try and except blocks of a
+ try/except/finally are in that try's try-finally frame.
+
+ A function frame contains information about where control should flow to in
+ the case of a return statement or an uncaught exception.
+
+ A loop frame contains information about where control should pass to in the
+ case of a continue or break statement.
+
+ A try-except frame contains information about where control should flow to in
+ the case of an exception.
+
+ A try-finally frame contains information about where control should flow to
+ in the case of an exit (such as a finally block that must run before a return,
+ continue, or break statement can be executed).
+
+ Attributes:
+ graph: The control flow graph being generated by the visitor.
+ frames: The current frames. Each frame in this list contains all frames that
+ come after it in the list.
+ """
+
+ def __init__(self):
+ self.graph = ControlFlowGraph()
+ self.frames = []
+
+ def run(self, node):
+ start_block = self.graph.start_block
+ end_block = self.visit(node, start_block)
+    exit_block = self.new_block(node=node, label='<exit>', prunable=False)
+ end_block.add_exit(exit_block)
+ self.graph.compact()
+
+ def visit(self, node, current_block):
+ """Visit node, either an AST node or a list.
+
+ Args:
+ node: The AST node being visited. Not necessarily an instance of ast.AST;
+ node may also be a list, primitive, or Instruction.
+ current_block: The basic block whose execution necessarily precedes the
+ execution of `node`.
+
+    Returns:
+ The final basic block for the node.
+ """
+ assert isinstance(node, ast.AST)
+
+ if isinstance(node, instruction_module.INSTRUCTION_AST_NODES):
+ self.add_new_instruction(current_block, node)
+
+ method_name = 'visit_' + node.__class__.__name__
+ method = getattr(self, method_name, None)
+ if method is not None:
+ current_block = method(node, current_block)
+ return current_block
+
+ def visit_list(self, items, current_block):
+ """Visit each of the items in a list from the AST."""
+ for item in items:
+ current_block = self.visit(item, current_block)
+ return current_block
+
+ def add_new_instruction(self, block, node, accesses=None, source=None):
+ assert isinstance(node, ast.AST)
+ instruction = instruction_module.Instruction(
+ node, accesses=accesses, source=source)
+ self.add_instruction(block, instruction)
+
+ def add_instruction(self, block, instruction):
+ assert isinstance(instruction, instruction_module.Instruction)
+ block.add_instruction(instruction)
+
+ # Any instruction may raise an exception.
+ if not block.exits_from_middle:
+ self.raise_through_frames(block, interrupting=True)
+
+ def raise_through_frames(self, block, interrupting=True):
+ """Adds exits for the control flow of a raised exception.
+
+ `interrupting` means the exit can occur at any point (exit_from_middle).
+ `not interrupting` means the exit can only occur at the end of the block.
+
+ The reason to raise_through_frames with interrupting=False is for an
+ exception that already has been partially raised, but has passed control to
+ a finally block, and is now being raised at the end of that finally block.
+
+ Args:
+ block: The block where the exception's control flow begins.
+ interrupting: Whether the exception can be raised from any point in block.
+ If False, the exception is only raised from the end of block.
+ """
+ frames = self.get_current_exception_handling_frames()
+
+ if frames is None:
+ return
+
+ for frame in frames:
+ if frame.kind == Frame.TRY_FINALLY:
+ # Exit to finally and have finally exit to whatever's next...
+ final_block = frame.blocks['final_block']
+ block.add_exit(final_block, interrupting=interrupting)
+ block = frame.blocks['final_block_end']
+ interrupting = False
+ elif frame.kind == Frame.TRY_EXCEPT:
+ handler_block = frame.blocks['handler_block']
+ block.add_exit(handler_block, interrupting=interrupting)
+        interrupting = False  # A try-except frame fully catches the exception.
+ elif frame.kind == Frame.FUNCTION:
+ raise_block = frame.blocks['raise_block']
+ block.add_exit(raise_block, interrupting=interrupting)
+
+ def new_block(self, node=None, label=None, prunable=True):
+ """Create a new block."""
+ return self.graph.new_block(node=node, label=label, prunable=prunable)
+
+ def enter_loop_frame(self, continue_block, break_block):
+ # The loop body is the interior of the frame.
+ # The continue block (loop condition) and break block (loop's after block)
+ # are the exits from the frame.
+ self.frames.append(Frame(Frame.LOOP,
+ continue_block=continue_block,
+ break_block=break_block))
+
+ def enter_function_frame(self, return_block, raise_block):
+ # The function body is the interior of the frame.
+ # The return block and raise block are the exits from the frame.
+ self.frames.append(Frame(Frame.FUNCTION,
+ return_block=return_block,
+ raise_block=raise_block))
+
+ def enter_try_except_frame(self, handler_block):
+ # The try block is the interior of the frame.
+ # handler_block is where the frame exits to on an exception.
+ self.frames.append(Frame(Frame.TRY_EXCEPT,
+ handler_block=handler_block))
+
+ def enter_try_finally_frame(self, final_block, final_block_end):
+ # The try block and handler blocks are the interior of the frame.
+ # The finally block is the exit from the frame.
+ self.frames.append(Frame(Frame.TRY_FINALLY,
+ final_block=final_block,
+ final_block_end=final_block_end))
+
+ def exit_frame(self):
+ """Exits the innermost current frame.
+
+ Note: Each enter_* function must be matched to exactly one exit_frame call
+ in reverse order.
+
+ Returns:
+ The frame being exited.
+ """
+ return self.frames.pop()
+
+ def get_current_loop_frame(self):
+ """Gets the current loop frame and contained current try-finally frames.
+
+ In order to exit the current loop frame, we must first enter the finally
+ blocks of all current contained try-finally frames.
+
+ Returns:
+ A list of frames, all of which are try-finally frames except for the last,
+ which is the current loop frame. Each of the returned try-finally
+ frames is contained within the current loop frame.
+ """
+ frames = []
+ for frame in reversed(self.frames):
+ if frame.kind == Frame.TRY_FINALLY:
+ frames.append(frame)
+ if frame.kind == Frame.LOOP:
+ frames.append(frame)
+ return frames
+ # There are no loop frames.
+ return None
+
+ def get_current_function_frame(self):
+ """Gets the current function frame and contained current try-finally frames.
+
+ In order to exit the current function frame, we must first enter the finally
+ blocks of all current contained try-finally frames.
+
+ Returns:
+ A list of frames, all of which are try-finally frames except for the last,
+ which is the current function frame. Each of the returned try-finally
+ frames is contained within the current function frame.
+ """
+ frames = []
+ for frame in reversed(self.frames):
+ if frame.kind == Frame.TRY_FINALLY:
+ frames.append(frame)
+ if frame.kind == Frame.FUNCTION:
+ frames.append(frame)
+ return frames
+ # There are no function frames.
+ return None
+
+ def get_current_exception_handling_frames(self):
+ """Get all exception handling frames containing the current block.
+
+ Returns:
+ A list of frames, all of which are exception handling frames containing
+ the current block. Any instruction contained in a try-except frame may
+ exit to the frame's exception handling block, with the caveat that an
+ instruction cannot exit through a TRY_FINALLY frame without passing first
+ through the frame's finally block. (The instruction will exit to the
+ finally block, and the finally block in turn will exit to the exception
+ handler.) A function frame's raise block serves to catch exceptions as
+ well.
+ """
+ frames = []
+ # Traverse frames from innermost to outermost until a frame that fully
+ # catches the exception is found.
+ for frame in reversed(self.frames):
+ if frame.kind == Frame.TRY_FINALLY:
+ frames.append(frame)
+ if frame.kind == Frame.TRY_EXCEPT:
+ # A try-except frame catches any exception, even if the frame's except
+ # statements do not match the exception. In this case, the final except
+ # will reraise the exception to higher frames.
+ frames.append(frame)
+ return frames
+ if frame.kind == Frame.FUNCTION:
+ # A function frame's raise_block catches any exception that reaches it.
+ frames.append(frame)
+ return frames
+ # There is no frame to fully catch the exception.
+ return None
+
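The traversal logic of these frame getters can be illustrated with a standalone sketch. Plain strings stand in for the Frame objects above, and `frames_to_exit_loop` is a hypothetical mirror of `get_current_loop_frame`, not part of the library:

```python
# Frame kinds, as strings rather than Frame class constants.
LOOP, FUNCTION, TRY_EXCEPT, TRY_FINALLY = 'LOOP', 'FUNCTION', 'TRY_EXCEPT', 'TRY_FINALLY'

def frames_to_exit_loop(frame_stack):
    """Collects try-finally frames (innermost first) up to the nearest loop frame."""
    frames = []
    for kind in reversed(frame_stack):
        if kind == TRY_FINALLY:
            frames.append(kind)
        if kind == LOOP:
            frames.append(kind)
            return frames
    return None  # There are no loop frames.

# A break inside `while: try/finally: try/finally:` crosses two finally blocks
# before it can reach the loop's exit:
print(frames_to_exit_loop([FUNCTION, LOOP, TRY_FINALLY, TRY_FINALLY]))
# ['TRY_FINALLY', 'TRY_FINALLY', 'LOOP']
print(frames_to_exit_loop([FUNCTION]))  # None: no enclosing loop
```

The innermost-first ordering is what lets callers thread control through each pending finally block in turn before reaching the loop's exit.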
+ def visit_Module(self, node, current_block):
+ return self.visit_list(node.body, current_block)
+
+ def visit_ClassDef(self, node, current_block):
+ """Visit a ClassDef node of the AST.
+
+ Blocks:
+ current_block: The block in which the class is defined.
+ """
+ # TODO(dbieber): Make sure all statements are handled, such as base classes.
+ # http://greentreesnakes.readthedocs.io/en/latest/nodes.html#ClassDef
+ # The body is executed before the decorators.
+ current_block = self.visit_list(node.body, current_block)
+ for decorator in node.decorator_list:
+ self.add_new_instruction(current_block, decorator)
+ assert isinstance(node.name, six.string_types)
+ self.add_new_instruction(
+ current_block,
+ node,
+ accesses=instruction_module.create_writes(node.name, node),
+ source=instruction_module.CLASS)
+ return current_block
+
+ def visit_FunctionDef(self, node, current_block):
+ """Visit a FunctionDef node of the AST.
+
+ Blocks:
+ current_block: The block in which the function is defined.
+ """
+ # First defaults are computed, then decorators are run, then the function
+ # definition is assigned to the function name.
+ current_block = self.handle_argument_defaults(node.args, current_block)
+ for decorator in node.decorator_list:
+ self.add_new_instruction(current_block, decorator)
+ assert isinstance(node.name, six.string_types)
+ self.add_new_instruction(
+ current_block,
+ node,
+ accesses=instruction_module.create_writes(node.name, node),
+ source=instruction_module.FUNCTION)
+ self.handle_function_definition(node, node.name, node.args, node.body)
+ return current_block
+
+ def visit_Lambda(self, node, current_block):
+ """Visit a Lambda node of the AST.
+
+ Blocks:
+ current_block: The block in which the lambda is defined.
+ """
+ current_block = self.handle_argument_defaults(node.args, current_block)
+ self.handle_function_definition(node, 'lambda', node.args, node.body)
+ return current_block
+
+ def handle_function_definition(self, node, name, args, body):
+ """A helper fn for Lambda and FunctionDef.
+
+ Note that this function doesn't require a block as input, since it doesn't
+ modify the blocks where the function definition resides.
+
+ Blocks:
+ entry_block: The block where control flow starts when the function is
+ called.
+ return_block: The block the function returns to.
+ raise_block: The block the function raises uncaught exceptions to.
+ fn_block: The first used block of the FunctionDef.
+
+ Args:
+ node: The AST node of the function definition, either a FunctionDef or
+ Lambda node.
+ name: The function's name, a string.
+ args: The function's args, an ast.arguments node.
+ body: The function's body, a list of AST nodes.
+ """
+ return_block = self.new_block(node=node, label='<return>', prunable=False)
+ raise_block = self.new_block(node=node, label='<raise>', prunable=False)
+ self.enter_function_frame(return_block, raise_block)
+
+ entry_block = self.new_block(node=node, label='<entry:{}>'.format(name),
+ prunable=False)
+ fn_block = self.new_block(node=node, label='fn_block')
+ entry_block.add_exit(fn_block)
+ fn_block = self.handle_argument_writes(args, fn_block)
+ fn_block = self.visit_list(body, fn_block)
+ fn_block.add_exit(return_block)
+ self.exit_frame()
+
+ def handle_argument_defaults(self, node, current_block):
+ """Add Instructions for all of a FunctionDef's default values.
+
+ Note that these instructions are in the block containing the function, not
+ in the function definition itself.
+ """
+ for default in node.defaults:
+ self.add_new_instruction(current_block, default)
+ for default in node.kw_defaults:
+ self.add_new_instruction(current_block, default)
+ return current_block
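The reason these instructions belong in the enclosing block is that Python evaluates default values exactly once, when the `def` statement executes, in the enclosing scope. The classic mutable-default pitfall makes this visible (`append_to` is an illustrative example, not part of the library):

```python
# Defaults are evaluated once, at function definition time, in the enclosing
# scope -- which is why handle_argument_defaults adds their instructions to
# the block containing the definition, not to the function's own body.
def append_to(item, bucket=[]):  # the [] is created exactly once
    bucket.append(item)
    return bucket

print(append_to(1))  # [1]
print(append_to(2))  # [1, 2] -- the same list object is reused across calls
```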
+
+ def handle_argument_writes(self, node, current_block):
+ """Add Instructions for all of a FunctionDef's arguments.
+
+ These instructions are part of a function's body.
+ """
+ accesses = []
+ if node.args:
+ for arg in node.args:
+ accesses.extend(instruction_module.create_writes(arg, node))
+ if node.vararg:
+ accesses.extend(instruction_module.create_writes(node.vararg, node))
+ if node.kwonlyargs:
+ for arg in node.kwonlyargs:
+ accesses.extend(instruction_module.create_writes(arg, node))
+ if node.kwarg:
+ accesses.extend(instruction_module.create_writes(node.kwarg, node))
+
+ if accesses:
+ self.add_new_instruction(
+ current_block,
+ node,
+ accesses=accesses,
+ source=instruction_module.ARGS)
+ return current_block
+
+ def visit_If(self, node, current_block):
+ """Visit an If node of the AST.
+
+ Blocks:
+ current_block: This is where the if statement resides. The if statement's
+ test is added here.
+ after_block: The block to which control is passed after the if statement
+ is completed.
+ true_block: The true branch of the if statement.
+ false_block: The false branch of the if statement.
+ """
+ self.add_new_instruction(current_block, node.test)
+ after_block = self.new_block(node=node, label='after_block')
+ true_block = self.new_block(node=node, label='true_block')
+ current_block.add_exit(true_block, branch=True)
+ true_block = self.visit_list(node.body, true_block)
+ true_block.add_exit(after_block)
+ if node.orelse:
+ false_block = self.new_block(node=node, label='false_block')
+ current_block.add_exit(false_block, branch=False)
+ false_block = self.visit_list(node.orelse, false_block)
+ false_block.add_exit(after_block)
+ else:
+ current_block.add_exit(after_block, branch=False)
+ return after_block
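The wiring above can be sketched as plain edge construction, with string names standing in for the BasicBlock objects (the block names follow the docstring; `if_statement_edges` is a hypothetical helper for illustration only):

```python
def if_statement_edges(has_orelse):
    """Returns the set of edges visit_If creates, as (from, to) name pairs."""
    edges = {('current_block', 'true_block'),   # branch=True
             ('true_block', 'after_block')}
    if has_orelse:
        edges |= {('current_block', 'false_block'),  # branch=False
                  ('false_block', 'after_block')}
    else:
        # With no else clause, the false branch goes straight to after_block.
        edges.add(('current_block', 'after_block'))
    return edges

print(('current_block', 'after_block') in if_statement_edges(has_orelse=False))  # True
print(('current_block', 'after_block') in if_statement_edges(has_orelse=True))   # False
```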
+
+ def visit_While(self, node, current_block):
+ """Visit a While node of the AST.
+
+ Blocks:
+ current_block: This is where the while statement resides.
+ """
+ test_instruction = instruction_module.Instruction(node.test)
+ return self.handle_Loop(node, test_instruction, current_block)
+
+ def visit_For(self, node, current_block):
+ """Visit a For node of the AST.
+
+ Blocks:
+ current_block: This is where the for statement resides.
+ """
+ self.add_new_instruction(current_block, node.iter)
+ # node.target is a Name, Tuple, or List node.
+ # We wrap it in an Instruction so it knows where its write is coming from.
+ target = instruction_module.Instruction(
+ node.target,
+ accesses=instruction_module.create_writes(node.target, node),
+ source=instruction_module.ITERATOR)
+ return self.handle_Loop(node, target, current_block)
+
+ def handle_Loop(self, node, loop_instruction, current_block):
+ """A helper fn for For and While.
+
+ Args:
+ node: The AST node representing the loop.
+ loop_instruction: The Instruction in the loop header, such as a test or an
+ assignment from an iterator.
+ current_block: The BasicBlock containing the loop.
+
+ Blocks:
+ current_block: This is where the loop resides.
+ test_block: Contains the part of the loop header that is repeated. For a
+ While, this is the loop condition. For a For, this is assignment to the
+ target variable.
+ test_block_end: The last block in the test (often the same as test_block).
+ body_block: The body of the loop.
+ else_block: Executed if the loop terminates naturally.
+ after_block: Follows the completion of the loop.
+ """
+ # We do not add an instruction for the ast.For or ast.While node.
+ test_block = self.new_block(node=node, label='test_block')
+ current_block.add_exit(test_block)
+ self.add_instruction(test_block, loop_instruction)
+ body_block = self.new_block(node=node, label='body_block')
+ after_block = self.new_block(node=node, label='after_block')
+
+ test_block.add_exit(body_block, branch=True)
+ # In the loop, continue goes to test_block and break goes to after_block.
+ self.enter_loop_frame(test_block, after_block)
+ body_block = self.visit_list(node.body, body_block)
+ body_block.add_exit(test_block)
+ self.exit_frame()
+
+ # If a loop exits via its test (rather than via a break) and it has
+ # an orelse, then it enters the orelse.
+ if node.orelse:
+ else_block = self.new_block(node=node, label='else_block')
+ test_block.add_exit(else_block, branch=False)
+ else_block = self.visit_list(node.orelse, else_block)
+ else_block.add_exit(after_block)
+ else:
+ test_block.add_exit(after_block, branch=False)
+
+ return after_block
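As with if statements, the loop wiring can be sketched with plain string block names (`loop_edges` is a hypothetical illustration, not library code). Note the back edge from the body to the test, and that else_block is only reachable when the loop exits via its test:

```python
def loop_edges(has_orelse):
    """Returns the set of edges handle_Loop creates, as (from, to) name pairs."""
    edges = {('current_block', 'test_block'),
             ('test_block', 'body_block'),    # branch=True
             ('body_block', 'test_block')}    # back edge; also the continue target
    if has_orelse:
        edges |= {('test_block', 'else_block'),   # branch=False
                  ('else_block', 'after_block')}
    else:
        edges.add(('test_block', 'after_block'))  # branch=False
    return edges

# A break statement would add an edge from inside the body to after_block
# directly, bypassing else_block -- see visit_Break.
print(('body_block', 'test_block') in loop_edges(has_orelse=True))  # True
```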
+
+ def visit_Try(self, node, current_block):
+ """Visit a Try node of the AST.
+
+ Blocks:
+ current_block: This is where the try statement resides.
+ after_block: The block to which control flows after the conclusion of the
+ full try statement (including e.g. the else and finally sections, if
+ present).
+ handler_blocks: A list of blocks corresponding to the except statements.
+ bare_handler_block: The handler block corresponding to a bare except
+ statement. One of handler_blocks or None.
+ handler_body_blocks: A list of blocks corresponding to the bodies of the
+ except sections.
+ final_block: The block corresponding to the finally section.
+ final_block_end: The last block corresponding to the finally section.
+ try_block: The block corresponding to the try section.
+ try_block_end: The last block corresponding to the try section.
+ else_block: The block corresponding to the else section.
+ """
+ # We do not add an instruction for the ast.Try node.
+ after_block = self.new_block()
+ handler_blocks = [self.new_block() for _ in node.handlers]
+ handler_body_blocks = [self.new_block() for _ in node.handlers]
+
+ # If there is a bare except clause, determine its handler block.
+ # Only the last except is permitted to be a bare except.
+ if node.handlers and node.handlers[-1].type is None:
+ bare_handler_block = handler_blocks[-1]
+ else:
+ bare_handler_block = None
+
+ if node.finalbody:
+ final_block = self.new_block(node=node, label='final_block')
+ final_block_end = self.visit_list(node.finalbody, final_block)
+ final_block_end.add_exit(after_block)
+ self.enter_try_finally_frame(final_block, final_block_end)
+ else:
+ final_block = after_block
+
+ if node.handlers:
+ self.enter_try_except_frame(handler_blocks[0])
+
+ # The exits from the try_block may happen at any point since any instruction
+ # can throw an exception.
+ try_block = self.new_block(node=node, label='try_block')
+ current_block.add_exit(try_block)
+ try_block_end = self.visit_list(node.body, try_block)
+ if node.orelse: # The try body can exit to the else block.
+ else_block = self.new_block(node=node, label='else_block')
+ try_block_end.add_exit(else_block)
+ else: # If there is no else block, the try body can exit to final/after.
+ try_block_end.add_exit(final_block)
+
+ if node.handlers:
+ self.exit_frame() # Exit the try-except frame.
+
+ previous_handler_block_end = None
+ for handler, handler_block, handler_body_block in zip(node.handlers,
+ handler_blocks,
+ handler_body_blocks):
+ previous_handler_block_end = self.handle_ExceptHandler(
+ handler, handler_block, handler_body_block, final_block,
+ previous_handler_block_end=previous_handler_block_end)
+
+ if bare_handler_block is None and previous_handler_block_end is not None:
+ # If no exceptions match, then raise up through the frames.
+ # (A bare-except will always match.)
+ self.raise_through_frames(previous_handler_block_end, interrupting=False)
+
+ if node.orelse:
+ else_block = self.visit_list(node.orelse, else_block)
+ else_block.add_exit(final_block) # orelse exits to final/after
+
+ if node.finalbody:
+ self.exit_frame() # Exit the try-finally frame.
+
+ return after_block
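The bare-except check above relies on Python only permitting a handler with no exception type in the last position. A standalone version of the same check, using the standard library's ast module (the code above uses gast, whose Try nodes expose the same `handlers`/`type` fields):

```python
import ast

def has_bare_handler(try_source):
    """True if the try statement's final handler is a bare `except:`."""
    try_node = ast.parse(try_source).body[0]
    return bool(try_node.handlers) and try_node.handlers[-1].type is None

print(has_bare_handler(
    "try:\n  pass\nexcept ValueError:\n  pass\nexcept:\n  pass"))  # True
print(has_bare_handler(
    "try:\n  pass\nexcept ValueError:\n  pass"))  # False
```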
+
+ def handle_ExceptHandler(self, handler, handler_block, handler_body_block,
+ final_block, previous_handler_block_end=None):
+ """Create the blocks appropriate for an exception handler.
+
+ Args:
+ handler: The AST ExceptHandler node.
+ handler_block: The block corresponding to the ExceptHandler header.
+ handler_body_block: The block corresponding to the ExceptHandler body.
+ final_block: Where the handler body should exit to when it executes
+ successfully.
+ previous_handler_block_end: The last block corresponding to the previous
+ ExceptHandler header, if there is one, or None otherwise. The previous
+ handler's header should exit to this handler's header if the exception
+ doesn't match the previous handler's header.
+
+ Returns:
+ The last (usually the only) block in the handler's header.
+
+ Note that rather than having a visit_ExceptHandler function, we instead
+ use the following logic. This is because except statements don't follow
+ the visitor pattern exactly. Specifically, a handler may exit to either
+ its body or to the next handler, but under the visitor pattern the
+ handler would not know the block belonging to the next handler.
+ """
+ if handler.type is not None:
+ self.add_new_instruction(handler_block, handler.type)
+ # An ExceptHandler header can only have a single Instruction, so there is
+ # only one handler_block BasicBlock.
+ handler_block.add_exit(handler_body_block)
+
+ if previous_handler_block_end is not None:
+ previous_handler_block_end.add_exit(handler_block)
+ previous_handler_block_end = handler_block
+
+ if handler.name is not None:
+ # handler.name is a Name, Tuple, or List AST node.
+ self.add_new_instruction(
+ handler_body_block,
+ handler.name,
+ accesses=instruction_module.create_writes(handler.name, handler),
+ source=instruction_module.EXCEPTION)
+ handler_body_block = self.visit_list(handler.body, handler_body_block)
+ handler_body_block.add_exit(final_block) # handler exits to final/after
+ return previous_handler_block_end
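The sequential handler chaining this function builds can be sketched as follows (`handler_chain_edges` is an illustrative stand-in; strings replace the BasicBlock objects). Each typed handler's header exits either to its own body, on a match, or to the next handler's header:

```python
def handler_chain_edges(handler_names):
    """Returns the (from, to) edges linking a sequence of except handlers."""
    edges = []
    for i, name in enumerate(handler_names):
        edges.append((name + '_header', name + '_body'))  # exception matches
        if i + 1 < len(handler_names):
            # No match: fall through to the next handler's header.
            edges.append((name + '_header', handler_names[i + 1] + '_header'))
    return edges

print(handler_chain_edges(['ValueError', 'KeyError']))
# [('ValueError_header', 'ValueError_body'),
#  ('ValueError_header', 'KeyError_header'),
#  ('KeyError_header', 'KeyError_body')]
```

When the last handler is not a bare except, the final header has no next handler to fall through to, which is why visit_Try raises through the frames from `previous_handler_block_end`.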
+
+ def visit_Return(self, node, current_block):
+ """Visit a Return node of the AST.
+
+ Blocks:
+ current_block: This is where the return statement resides.
+ return_block: The containing function's return block. All successful exits
+ from the function lead here.
+
+ Raises:
+ RuntimeError: If a return AST node is visited while not in a function
+ frame.
+ """
+ # The Return statement is an Instruction. Don't visit the node's children.
+ frames = self.get_current_function_frame()
+ if frames is None:
+ raise RuntimeError('return occurs outside of a function frame.')
+ try_finally_frames = frames[:-1]
+ function_frame = frames[-1]
+
+ return_block = function_frame.blocks['return_block']
+ return self.handle_ExitStatement(node,
+ return_block,
+ try_finally_frames,
+ current_block)
+
+ def visit_Yield(self, node, current_block):
+ """Visit a Yield node of the AST.
+
+ The current implementation of yield statements allows control to flow
+ directly through them. TODO(dbieber): Introduce a node in between
+ yielding and resuming execution.
+ TODO(dbieber): Yield nodes aren't even visited since they are contained in
+ Expr nodes. Determine if Yield can occur outside of an Expr. Check for
+ Yield when visiting Expr.
+ """
+ logging.warning('yield visited: %s', ast.dump(node))
+ # The Yield statement is an Instruction. Don't visit children.
+ return current_block
+
+ def visit_Continue(self, node, current_block):
+ """Visit a Continue node of the AST.
+
+ Blocks:
+ current_block: This is where the continue statement resides.
+ continue_block: The block of the containing loop's header. For a For,
+ this is the target variable assignment. For a While, this is the loop
+ condition.
+
+ Raises:
+ RuntimeError: If a continue AST node is visited while not in a loop frame.
+ """
+ frames = self.get_current_loop_frame()
+ if frames is None:
+ raise RuntimeError('continue occurs outside of a loop frame.')
+
+ try_finally_frames = frames[:-1]
+ loop_frame = frames[-1]
+
+ continue_block = loop_frame.blocks['continue_block']
+ return self.handle_ExitStatement(node,
+ continue_block,
+ try_finally_frames,
+ current_block)
+
+ def visit_Break(self, node, current_block):
+ """Visit a Break node of the AST.
+
+ Blocks:
+ current_block: This is where the break statement resides.
+ break_block: The block that the containing loop exits to.
+
+ Raises:
+ RuntimeError: If a break AST node is visited while not in a loop frame.
+ """
+ frames = self.get_current_loop_frame()
+ if frames is None:
+ raise RuntimeError('break occurs outside of a loop frame.')
+
+ try_finally_frames = frames[:-1]
+ loop_frame = frames[-1]
+
+ break_block = loop_frame.blocks['break_block']
+ return self.handle_ExitStatement(node,
+ break_block,
+ try_finally_frames,
+ current_block)
+
+ def visit_Raise(self, node, current_block):
+ """Visit a Raise node of the AST.
+
+ Blocks:
+ current_block: This is where the raise statement resides.
+ after_block: An unreachable block for code that follows the raise
+ statement.
+ """
+ del current_block
+ # The Raise statement is an Instruction. Don't visit children.
+
+ # Note there is no exit to the after_block. It is unreachable.
+ after_block = self.new_block(node=node, label='after_block')
+ return after_block
+
+ def handle_ExitStatement(self, node, next_block, try_finally_frames,
+ current_block):
+ """A helper fn for Return, Continue, and Break.
+
+ An exit statement is a statement such as return, continue, break, or raise.
+ Such a statement causes control to leave through a frame's exit. Any
+ instructions immediately following an exit statement will be unreachable.
+
+ Args:
+ node: The AST node of the exit statement.
+ next_block: The block the exit statement exits to.
+ try_finally_frames: A possibly empty list of try-finally frames whose
+ finally blocks must be executed before control can pass to next_block.
+ current_block: The block the exit statement resides in.
+
+ Blocks:
+ current_block: This is where the exit statement resides.
+ next_block: The block the exit statement exits to (after first passing
+ through all the finally blocks.)
+ final_block: The start of a finally section that control must pass
+ through on the way to next_block.
+ final_block_end: The end of a finally section that control must pass
+ through on the way to next_block.
+ after_block: An unreachable block for code that follows the exit
+ statement.
+ """
+ for try_finally_frame in try_finally_frames:
+ final_block = try_finally_frame.blocks['final_block']
+ current_block.add_exit(final_block)
+ current_block = try_finally_frame.blocks['final_block_end']
+
+ current_block.add_exit(next_block)
+
+ # Note there is no exit to the after_block. It is unreachable.
+ after_block = self.new_block(node=node, label='after_block')
+ return after_block
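The finally chaining above reduces to simple edge construction: control threads through each pending finally section, innermost first, before reaching the exit's target block. A sketch, with (final_block, final_block_end) name pairs standing in for the Frame objects (`exit_statement_edges` is hypothetical, for illustration):

```python
def exit_statement_edges(current_block, next_block, try_finally_frames):
    """Returns the (from, to) edges handle_ExitStatement would add."""
    edges = []
    for final_block, final_block_end in try_finally_frames:
        edges.append((current_block, final_block))
        current_block = final_block_end  # Resume chaining after the finally.
    edges.append((current_block, next_block))
    return edges

# A break crossing one try/finally on its way out of a loop:
print(exit_statement_edges('break_stmt', 'after_block',
                           [('final_block', 'final_block_end')]))
# [('break_stmt', 'final_block'), ('final_block_end', 'after_block')]
```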
diff --git a/python_graphs/control_flow_graphviz.py b/python_graphs/control_flow_graphviz.py
new file mode 100644
index 0000000..fff84bb
--- /dev/null
+++ b/python_graphs/control_flow_graphviz.py
@@ -0,0 +1,113 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Graphviz render for control flow graphs."""
+
+from absl import logging # pylint: disable=unused-import
+import astunparse
+import gast as ast
+import pygraphviz
+
+LEFT_ALIGN = '\l' # pylint: disable=anomalous-backslash-in-string
+
+
+def render(graph, include_src=None, path='/tmp/graph.png'):
+ g = to_graphviz(graph, include_src=include_src)
+ g.draw(path, prog='dot')
+
+
+def trim(line, max_length=30):
+ if len(line) <= max_length:
+ return line
+ return line[:max_length - 3] + '...'
+
+
+def unparse(node):
+ source = astunparse.unparse(node)
+ trimmed_source = '\n'.join(trim(line) for line in source.split('\n'))
+ return trimmed_source.strip().replace('\n', LEFT_ALIGN)
+
+
+def write_as_str(write):
+ if isinstance(write, ast.AST):
+ return unparse(write)
+ else:
+ return write
+
+
+def get_label_for_instruction(instruction):
+ if instruction.source is not None:
+ line = ', '.join(instruction.get_write_names())
+ line += ' <- ' + instruction.source
+ return line
+ else:
+ return unparse(instruction.node)
+
+
+def get_label(block):
+ """Gets the source code for a control flow basic block."""
+ lines = []
+ for control_flow_node in block.control_flow_nodes:
+ instruction = control_flow_node.instruction
+ line = get_label_for_instruction(instruction)
+ if line.strip():
+ lines.append(line)
+
+ return LEFT_ALIGN.join(lines) + LEFT_ALIGN
+
+
+def to_graphviz(graph, include_src=None):
+ """To graphviz."""
+ g = pygraphviz.AGraph(strict=False, directed=True)
+ for block in graph.blocks:
+ node_attrs = {}
+ label = get_label(block)
+ # We only show the special angle-bracketed block labels,
+ # e.g. <entry:...>, <return>, and <raise>.
+ if block.label is not None and block.label.startswith('<'):
+ node_attrs['style'] = 'bold'
+ if not label.rstrip(LEFT_ALIGN):
+ label = block.label + LEFT_ALIGN
+ else:
+ label = block.label + LEFT_ALIGN + label
+ node_attrs['label'] = label
+ node_attrs['fontname'] = 'Courier New'
+ node_attrs['fontsize'] = 10.0
+
+ node_id = id(block)
+ g.add_node(node_id, **node_attrs)
+ for next_node in block.next:
+ next_node_id = id(next_node)
+ if next_node in block.exits_from_middle:
+ edge_attrs = {}
+ edge_attrs['style'] = 'dashed'
+ g.add_edge(node_id, next_node_id, **edge_attrs)
+ if next_node in block.exits_from_end:
+ edge_attrs = {}
+ edge_attrs['style'] = 'solid'
+ g.add_edge(node_id, next_node_id, **edge_attrs)
+
+ if include_src is not None:
+ node_id = id(include_src)
+ node_attrs = {}  # Use fresh attributes rather than the last block's dict.
+ node_attrs['label'] = include_src.replace('\n', LEFT_ALIGN)
+ node_attrs['fontname'] = 'Courier New'
+ node_attrs['fontsize'] = 10.0
+ node_attrs['shape'] = 'box'
+ g.add_node(node_id, **node_attrs)
+
+ return g
diff --git a/python_graphs/control_flow_graphviz_test.py b/python_graphs/control_flow_graphviz_test.py
new file mode 100644
index 0000000..6741c19
--- /dev/null
+++ b/python_graphs/control_flow_graphviz_test.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for control_flow_graphviz.py."""
+
+import inspect
+
+from absl.testing import absltest
+from python_graphs import control_flow
+from python_graphs import control_flow_graphviz
+from python_graphs import control_flow_test_components as tc
+
+
+class ControlFlowGraphvizTest(absltest.TestCase):
+
+ def test_to_graphviz_for_all_test_components(self):
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ graph = control_flow.get_control_flow_graph(fn)
+ control_flow_graphviz.to_graphviz(graph)
+
+ def test_get_label_multi_op_expression(self):
+ graph = control_flow.get_control_flow_graph(tc.multi_op_expression)
+ block = graph.get_block_by_source('1 + 2 * 3')
+ self.assertEqual(
+ control_flow_graphviz.get_label(block).strip(),
+ 'return (1 + (2 * 3))\\l')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/control_flow_test.py b/python_graphs/control_flow_test.py
new file mode 100644
index 0000000..5156d01
--- /dev/null
+++ b/python_graphs/control_flow_test.py
@@ -0,0 +1,308 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for control_flow.py."""
+
+import inspect
+
+from absl import logging # pylint: disable=unused-import
+from absl.testing import absltest
+import gast as ast
+from python_graphs import control_flow
+from python_graphs import control_flow_test_components as tc
+from python_graphs import instruction as instruction_module
+from python_graphs import program_utils
+import six
+
+
+class ControlFlowTest(absltest.TestCase):
+
+ def get_block(self, graph, selector):
+ if isinstance(selector, control_flow.BasicBlock):
+ return selector
+ elif isinstance(selector, six.string_types):
+ return graph.get_block_by_source(selector)
+
+ def assertSameBlock(self, graph, selector1, selector2):
+ block1 = self.get_block(graph, selector1)
+ block2 = self.get_block(graph, selector2)
+ self.assertEqual(block1, block2)
+
+ def assertExitsTo(self, graph, selector1, selector2):
+ block1 = self.get_block(graph, selector1)
+ block2 = self.get_block(graph, selector2)
+ self.assertTrue(block1.exits_to(block2))
+
+ def assertNotExitsTo(self, graph, selector1, selector2):
+ block1 = self.get_block(graph, selector1)
+ block2 = self.get_block(graph, selector2)
+ self.assertFalse(block1.exits_to(block2))
+
+ def assertRaisesTo(self, graph, selector1, selector2):
+ block1 = self.get_block(graph, selector1)
+ block2 = self.get_block(graph, selector2)
+ self.assertTrue(block1.raises_to(block2))
+
+ def assertNotRaisesTo(self, graph, selector1, selector2):
+ block1 = self.get_block(graph, selector1)
+ block2 = self.get_block(graph, selector2)
+ self.assertFalse(block1.raises_to(block2))
+
+ def test_control_flow_straight_line_code(self):
+ graph = control_flow.get_control_flow_graph(tc.straight_line_code)
+ self.assertSameBlock(graph, 'x = 1', 'y = x + 2')
+ self.assertSameBlock(graph, 'x = 1', 'z = y * 3')
+ self.assertSameBlock(graph, 'x = 1', 'return z')
+
+ def test_control_flow_simple_if_statement(self):
+ graph = control_flow.get_control_flow_graph(tc.simple_if_statement)
+ x1_block = 'x = 1'
+ y2_block = 'y = 2'
+ xy_block = 'x > y'
+ y3_block = 'y = 3'
+ return_block = 'return y'
+ self.assertSameBlock(graph, x1_block, y2_block)
+ self.assertSameBlock(graph, x1_block, xy_block)
+ self.assertExitsTo(graph, xy_block, y3_block)
+ self.assertExitsTo(graph, xy_block, return_block)
+ self.assertExitsTo(graph, y3_block, return_block)
+ self.assertNotExitsTo(graph, y3_block, x1_block)
+ self.assertNotExitsTo(graph, return_block, x1_block)
+ self.assertNotExitsTo(graph, return_block, y3_block)
+
+ def test_control_flow_simple_for_loop(self):
+ graph = control_flow.get_control_flow_graph(tc.simple_for_loop)
+ x1_block = 'x = 1'
+ iter_block = 'range'
+ target_block = 'y'
+ body_block = 'y + 3'
+ return_block = 'return z'
+ self.assertSameBlock(graph, x1_block, iter_block)
+ self.assertExitsTo(graph, iter_block, target_block)
+ self.assertExitsTo(graph, target_block, body_block)
+ self.assertNotExitsTo(graph, body_block, return_block)
+ self.assertExitsTo(graph, target_block, return_block)
+
+ def test_control_flow_simple_while_loop(self):
+ graph = control_flow.get_control_flow_graph(tc.simple_while_loop)
+ x1_block = 'x = 1'
+ test_block = 'x < 2'
+ body_block = 'x += 3'
+ return_block = 'return x'
+
+ self.assertExitsTo(graph, x1_block, test_block)
+ self.assertExitsTo(graph, test_block, body_block)
+ self.assertExitsTo(graph, body_block, test_block)
+ self.assertNotExitsTo(graph, body_block, return_block)
+ self.assertExitsTo(graph, test_block, return_block)
+
+ def test_control_flow_break_in_while_loop(self):
+ graph = control_flow.get_control_flow_graph(tc.break_in_while_loop)
+ # This is just one block since there's no edge from the while loop end
+ # back to the while loop test, and so the 'x = 1' line can be merged with
+ # the test.
+ x1_and_test_block = 'x < 2'
+ body_block = 'x += 3'
+ return_block = 'return x'
+
+ self.assertExitsTo(graph, x1_and_test_block, body_block)
+ self.assertExitsTo(graph, body_block, return_block)
+ self.assertNotExitsTo(graph, body_block, x1_and_test_block)
+ self.assertExitsTo(graph, x1_and_test_block, return_block)
+
+ def test_control_flow_nested_while_loops(self):
+ graph = control_flow.get_control_flow_graph(tc.nested_while_loops)
+ x1_block = 'x = 1'
+ outer_test_block = 'x < 2'
+ y3_block = 'y = 3'
+ inner_test_block = 'y < 4'
+ y5_block = 'y += 5'
+ x6_block = 'x += 6'
+ return_block = 'return x'
+
+ self.assertExitsTo(graph, x1_block, outer_test_block)
+ self.assertExitsTo(graph, outer_test_block, y3_block)
+ self.assertExitsTo(graph, outer_test_block, return_block)
+ self.assertExitsTo(graph, y3_block, inner_test_block)
+ self.assertExitsTo(graph, inner_test_block, y5_block)
+ self.assertExitsTo(graph, inner_test_block, x6_block)
+ self.assertExitsTo(graph, y5_block, inner_test_block)
+ self.assertExitsTo(graph, x6_block, outer_test_block)
+
+ def test_control_flow_exception_handling(self):
+ graph = control_flow.get_control_flow_graph(tc.exception_handling)
+ self.assertSameBlock(graph, 'before_stmt0', 'before_stmt1')
+ self.assertExitsTo(graph, 'before_stmt1', 'try_block')
+ self.assertNotExitsTo(graph, 'before_stmt0', 'except_block1')
+ self.assertNotExitsTo(graph, 'before_stmt1', 'final_block_stmt0')
+ self.assertRaisesTo(graph, 'try_block', 'error_type')
+ self.assertRaisesTo(graph, 'error_type', 'except_block2_stmt0')
+ self.assertExitsTo(graph, 'except_block1', 'after_stmt0')
+
+ self.assertRaisesTo(graph, 'after_stmt0', 'except_block2_stmt0')
+ self.assertNotRaisesTo(graph, 'try_block', 'except_block2_stmt0')
+
+ def test_control_flow_try_with_loop(self):
+ graph = control_flow.get_control_flow_graph(tc.try_with_loop)
+ self.assertSameBlock(graph, 'for_body0', 'for_body1')
+ self.assertSameBlock(graph, 'except_body0', 'except_body1')
+
+ self.assertExitsTo(graph, 'before_stmt0', 'iterator')
+ self.assertExitsTo(graph, 'iterator', 'target')
+ self.assertExitsTo(graph, 'target', 'for_body0')
+ self.assertExitsTo(graph, 'for_body1', 'target')
+ self.assertExitsTo(graph, 'target', 'after_stmt0')
+
+ self.assertRaisesTo(graph, 'iterator', 'except_body0')
+ self.assertRaisesTo(graph, 'target', 'except_body0')
+ self.assertRaisesTo(graph, 'for_body1', 'except_body0')
+
+ def test_control_flow_break_in_finally(self):
+ graph = control_flow.get_control_flow_graph(tc.break_in_finally)
+
+ # The exception handlers are tried sequentially until one matches.
+ self.assertRaisesTo(graph, 'try0', 'Exception0')
+ self.assertExitsTo(graph, 'Exception0', 'Exception1')
+ self.assertExitsTo(graph, 'Exception1', 'finally_stmt0')
+ # If the finally block were to finish and the exception hadn't matched, then
+ # the exception would exit to the FunctionDef's raise_block. However, the
+ # break statement prevents the finally from finishing and so the exception
+ # is lost when the break statement is reached.
+ # TODO(dbieber): Add the following assert.
+ # raise_block = graph.get_raise_block('break_in_finally')
+ # self.assertNotExitsFromEndTo(graph, 'finally_stmt1', raise_block)
+ # The finally block can of course still raise an exception of its own, so
+ # the following is still true:
+ # TODO(dbieber): Add the following assert.
+ # self.assertRaisesTo(graph, 'finally_stmt1', raise_block)
+
+ # An exception in the except handlers could flow to the finally block.
+ self.assertRaisesTo(graph, 'Exception0', 'finally_stmt0')
+ self.assertRaisesTo(graph, 'exception0_stmt0', 'finally_stmt0')
+ self.assertRaisesTo(graph, 'Exception1', 'finally_stmt0')
+
+ # The break statement flows to after0, rather than to the loop header.
+ self.assertNotExitsTo(graph, 'finally_stmt1', 'target0')
+ self.assertExitsTo(graph, 'finally_stmt1', 'after0')
+
+ def test_control_flow_for_loop_with_else(self):
+ graph = control_flow.get_control_flow_graph(tc.for_with_else)
+ self.assertExitsTo(graph, 'target', 'for_stmt0')
+ self.assertSameBlock(graph, 'for_stmt0', 'condition')
+
+ # If break is encountered, then the else clause is skipped.
+ self.assertExitsTo(graph, 'condition', 'after_stmt0')
+
+ # The else clause executes if the loop completes without reaching the break.
+ self.assertExitsTo(graph, 'target', 'else_stmt0')
+ self.assertNotExitsTo(graph, 'target', 'after_stmt0')
+
+ def test_control_flow_lambda(self):
+ graph = control_flow.get_control_flow_graph(tc.create_lambda)
+ self.assertNotExitsTo(graph, 'before_stmt0', 'args')
+ self.assertNotExitsTo(graph, 'before_stmt0', 'output')
+
+ def test_control_flow_generator(self):
+ graph = control_flow.get_control_flow_graph(tc.generator)
+ self.assertExitsTo(graph, 'target', 'yield_statement')
+ self.assertSameBlock(graph, 'yield_statement', 'after_stmt0')
+
+ def test_control_flow_inner_fn_while_loop(self):
+ graph = control_flow.get_control_flow_graph(tc.fn_with_inner_fn)
+ self.assertExitsTo(graph, 'x = 10', 'True')
+ self.assertExitsTo(graph, 'True', 'True')
+ self.assertSameBlock(graph, 'True', 'True')
+
+ def test_control_flow_example_class(self):
+ graph = control_flow.get_control_flow_graph(tc.ExampleClass)
+ self.assertSameBlock(graph, 'method_stmt0', 'method_stmt1')
+
+ def test_control_flow_return_outside_function(self):
+ with self.assertRaises(RuntimeError) as error:
+ control_flow.get_control_flow_graph('return x')
+ self.assertContainsSubsequence(str(error.exception),
+ 'outside of a function frame')
+
+ def test_control_flow_continue_outside_loop(self):
+ control_flow.get_control_flow_graph('for i in j: continue')
+ with self.assertRaises(RuntimeError) as error:
+ control_flow.get_control_flow_graph('if x: continue')
+ self.assertContainsSubsequence(str(error.exception),
+ 'outside of a loop frame')
+
+ def test_control_flow_break_outside_loop(self):
+ control_flow.get_control_flow_graph('for i in j: break')
+ with self.assertRaises(RuntimeError) as error:
+ control_flow.get_control_flow_graph('if x: break')
+ self.assertContainsSubsequence(str(error.exception),
+ 'outside of a loop frame')
+
+ def test_control_flow_for_all_test_components(self):
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ control_flow.get_control_flow_graph(fn)
+
+ def test_control_flow_for_all_test_components_ast_to_instruction(self):
+ """All INSTRUCTION_AST_NODES in an AST correspond to one Instruction.
+
+ This assumes that a simple statement can't contain another simple statement.
+ However, Yield nodes are the exception to this as they are contained within
+ Expr nodes.
+
+ We omit Yield nodes from INSTRUCTION_AST_NODES despite them being listed
+ as simple statements in the Python docs.
+ """
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ node = program_utils.program_to_ast(fn)
+ graph = control_flow.get_control_flow_graph(node)
+ for n in ast.walk(node):
+ if not isinstance(n, instruction_module.INSTRUCTION_AST_NODES):
+ continue
+ control_flow_nodes = list(graph.get_control_flow_nodes_by_ast_node(n))
+ self.assertLen(control_flow_nodes, 1, ast.dump(n))
+
+ def test_control_flow_reads_and_writes_appear_once(self):
+ """Asserts each read and write in an Instruction is unique in the graph.
+
+ Note that in the case of AugAssign, the same Name AST node is used once as
+ a read and once as a write.
+ """
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ reads = set()
+ writes = set()
+ node = program_utils.program_to_ast(fn)
+ graph = control_flow.get_control_flow_graph(node)
+ for instruction in graph.get_instructions():
+ # Check that all reads are unique.
+ for read in instruction.get_reads():
+ if isinstance(read, tuple):
+ read = read[1]
+ self.assertIsInstance(read, ast.Name, 'Unexpected read type.')
+ self.assertNotIn(read, reads,
+ instruction_module.access_name(read))
+ reads.add(read)
+
+ # Check that all writes are unique.
+ for write in instruction.get_writes():
+ if isinstance(write, tuple):
+ write = write[1]
+ if isinstance(write, six.string_types):
+ continue
+ self.assertIsInstance(write, ast.Name)
+ self.assertNotIn(write, writes,
+ instruction_module.access_name(write))
+ writes.add(write)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/control_flow_test_components.py b/python_graphs/control_flow_test_components.py
new file mode 100644
index 0000000..ba55b54
--- /dev/null
+++ b/python_graphs/control_flow_test_components.py
@@ -0,0 +1,322 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test components for testing control flow.
+
+Many of these components would produce RuntimeErrors if run. Their purpose is
+for the testing of the control_flow module.
+"""
+
+
+# pylint: disable=missing-docstring
+# pylint: disable=pointless-statement,undefined-variable
+# pylint: disable=unused-variable,unused-argument
+# pylint: disable=bare-except,lost-exception,unreachable
+# pylint: disable=keyword-arg-before-vararg
+def straight_line_code():
+ x = 1
+ y = x + 2
+ z = y * 3
+ return z
+
+
+def simple_if_statement():
+ x = 1
+ y = 2
+ if x > y:
+ y = 3
+ return y
+
+
+def simple_for_loop():
+ x = 1
+ for y in range(x + 2):
+ z = y + 3
+ return z
+
+
+def tuple_in_for_loop():
+ a, b = 0, 1
+ for a, b in [(1, 2), (2, 3)]:
+ if a > b:
+ break
+ return b - a
+
+
+def simple_while_loop():
+ x = 1
+ while x < 2:
+ x += 3
+ return x
+
+
+def break_in_while_loop():
+ x = 1
+ while x < 2:
+ x += 3
+ break
+ return x
+
+
+def nested_while_loops():
+ x = 1
+ while x < 2:
+ y = 3
+ while y < 4:
+ y += 5
+ x += 6
+ return x
+
+
+def multiple_excepts():
+ try:
+ x = 1
+ except ValueError:
+ x = 2
+ x = 3
+ except RuntimeError:
+ x = 4
+ except:
+ x = 5
+ return x
+
+
+def try_finally():
+ header0
+ try:
+ try0
+ try1
+ except Exception0 as value0:
+ exception0_stmt0
+ finally:
+ finally_stmt0
+ finally_stmt1
+ after0
+
+
+def exception_handling():
+ try:
+ before_stmt0
+ before_stmt1
+ try:
+ try_block
+ except error_type as value:
+ except_block1
+ after_stmt0
+ after_stmt1
+ except:
+ except_block2_stmt0
+ except_block2_stmt1
+ finally:
+ final_block_stmt0
+ final_block_stmt1
+ end_block_stmt0
+ end_block_stmt1
+
+
+def fn_with_args(a, b=10, *varargs, **kwargs):
+ body_stmt0
+ body_stmt1
+ return
+
+
+def fn1(a, b):
+ return a + b
+
+
+def fn2(a, b):
+ c = a
+ if a > b:
+ c -= b
+ return c
+
+
+def fn3(a, b):
+ c = a
+ if a > b:
+ c -= b
+ c += 1
+ c += 2
+ c += 3
+ else:
+ c += b
+ return c
+
+
+def fn4(i):
+ count = 0
+ for i in range(i):
+ count += 1
+ return count
+
+
+def fn5(i):
+ count = 0
+ for _ in range(i):
+ if count > 5:
+ break
+ count += 1
+ return count
+
+
+def fn6():
+ count = 0
+ while count < 10:
+ count += 1
+ return count
+
+
+def fn7():
+ try:
+ raise ValueError('This will be caught.')
+ except ValueError as e:
+ del e
+ return
+
+
+def try_with_else():
+ try:
+ raise ValueError('This will be caught.')
+ except ValueError as e:
+ del e
+ else:
+ return 1
+ return 2
+
+
+def for_with_else():
+ for target in iterator:
+ for_stmt0
+ if condition:
+ break
+ for_stmt1
+ else:
+ else_stmt0
+ else_stmt1
+ after_stmt0
+
+
+def fn8(a):
+ a += 1
+
+
+def nested_loops(a):
+ """A test function illustrating nested loops."""
+ for i in range(a):
+ while True:
+ break
+ unreachable = 10
+ for j in range(i):
+ for k in range(j):
+ if j * k > 10:
+ continue
+ unreachable = 5
+ if i + j == 10:
+ return True
+ return False
+
+
+def try_with_loop():
+ before_stmt0
+ try:
+ for target in iterator:
+ for_body0
+ for_body1
+ except:
+ except_body0
+ except_body1
+ after_stmt0
+
+
+def break_in_finally():
+ header0
+ for target0 in iter0:
+ try:
+ try0
+ try1
+ except Exception0 as value0:
+ exception0_stmt0
+ except Exception1 as value1:
+ exception1_stmt0
+ exception1_stmt1
+ finally:
+ finally_stmt0
+ finally_stmt1
+ # This breaks out of the for-loop.
+ break
+ after0
+
+
+def break_in_try():
+ count = 0
+ for _ in range(10):
+ try:
+ count += 1
+ # This breaks out of the for-loop through the finally block.
+ break
+ except ValueError:
+ pass
+ finally:
+ count += 2
+ return count
+
+
+def nested_try_excepts():
+ try:
+ try:
+ x = 0
+ x += 1
+ try:
+ x = 2 + 2
+ except ValueError(1+1) as e:
+ x = 3 - 3
+ finally:
+ x = 4
+ except RuntimeError:
+ x = 5 * 5
+ finally:
+ x = 6 ** 6
+ except:
+ x = 7 / 7
+ return x
+
+
+def multi_op_expression():
+ return 1 + 2 * 3
+
+
+def create_lambda():
+ before_stmt0
+ fn = lambda args: output
+ after_stmt0
+
+
+def generator():
+ for target in iterator:
+ yield yield_statement
+ after_stmt0
+
+
+def fn_with_inner_fn():
+ def inner_fn():
+ x = 10
+ while True:
+ pass
+
+
+class ExampleClass(object):
+
+ def method0(self, arg):
+ method_stmt0
+ method_stmt1
diff --git a/python_graphs/control_flow_visualizer.py b/python_graphs/control_flow_visualizer.py
new file mode 100644
index 0000000..185c23b
--- /dev/null
+++ b/python_graphs/control_flow_visualizer.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Create control flow graph visualizations for the test components.
+
+
+python -m python_graphs.control_flow_visualizer
+"""
+
+import inspect
+import os
+
+from absl import app
+from absl import flags
+from absl import logging # pylint: disable=unused-import
+
+from python_graphs import control_flow
+from python_graphs import control_flow_graphviz
+from python_graphs import control_flow_test_components as tc
+from python_graphs import program_utils
+
+FLAGS = flags.FLAGS
+
+
+def render_functions(functions):
+ for name, function in functions:
+ logging.info(name)
+ graph = control_flow.get_control_flow_graph(function)
+ path = '/tmp/control_flow_graphs/{}.png'.format(name)
+ source = program_utils.getsource(function) # pylint: disable=protected-access
+ control_flow_graphviz.render(graph, include_src=source, path=path)
+
+
+def render_filepaths(filepaths):
+ for filepath in filepaths:
+ filename = os.path.basename(filepath).split('.')[0]
+ logging.info(filename)
+ with open(filepath, 'r') as f:
+ source = f.read()
+ graph = control_flow.get_control_flow_graph(source)
+ path = '/tmp/control_flow_graphs/{}.png'.format(filename)
+ control_flow_graphviz.render(graph, include_src=source, path=path)
+
+
+def main(argv):
+ del argv # Unused.
+
+ functions = [
+ (name, fn)
+ for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction)
+ ]
+ render_functions(functions)
+
+ # Add filepaths here to visualize their functions.
+ filepaths = [
+ __file__,
+ ]
+ render_filepaths(filepaths)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/python_graphs/cyclomatic_complexity.py b/python_graphs/cyclomatic_complexity.py
new file mode 100644
index 0000000..4ac9812
--- /dev/null
+++ b/python_graphs/cyclomatic_complexity.py
@@ -0,0 +1,49 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Computes the cyclomatic complexity of a program or control flow graph."""
+
+
+def cyclomatic_complexity(control_flow_graph):
+ """Computes the cyclomatic complexity of a function from its cfg."""
+ enter_block = next(control_flow_graph.get_enter_blocks())
+
+ new_blocks = []
+ seen_block_ids = set()
+ new_blocks.append(enter_block)
+ seen_block_ids.add(id(enter_block))
+ num_edges = 0
+
+ while new_blocks:
+ block = new_blocks.pop()
+ for next_block in block.exits_from_end:
+ num_edges += 1
+ if id(next_block) not in seen_block_ids:
+ new_blocks.append(next_block)
+ seen_block_ids.add(id(next_block))
+ num_nodes = len(seen_block_ids)
+
+ p = 1 # num_connected_components
+ e = num_edges
+ n = num_nodes
+ return e - n + 2 * p
+
+
+def cyclomatic_complexity2(control_flow_graph):
+ """Computes the cyclomatic complexity of a program from its cfg."""
+ # Assumes a single connected component.
+ p = 1 # num_connected_components
+ e = sum(len(block.exits_from_end) for block in control_flow_graph.blocks)
+ n = len(control_flow_graph.blocks)
+ return e - n + 2 * p
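The `E - N + 2P` formula used above can be illustrated without the library. A minimal sketch, assuming a hypothetical adjacency-dict representation of the control flow graph for a simple if statement (the edge and node counts are what matter, not the library's API):

```python
def cyclomatic_complexity_from_edges(edges, num_components=1):
    """Computes E - N + 2P for a graph given as an adjacency mapping."""
    num_nodes = len(edges)
    num_edges = sum(len(targets) for targets in edges.values())
    return num_edges - num_nodes + 2 * num_components


# Hypothetical CFG for a simple if statement:
# entry -> test; test -> (body | exit); body -> exit.
if_statement_cfg = {
    'entry': ['test'],
    'test': ['body', 'exit'],  # the conditional branch adds 1
    'body': ['exit'],
    'exit': [],
}
print(cyclomatic_complexity_from_edges(if_statement_cfg))  # 2
```

This agrees with the expectation in cyclomatic_complexity_test.py that `simple_if_statement` has complexity 2.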
diff --git a/python_graphs/cyclomatic_complexity_test.py b/python_graphs/cyclomatic_complexity_test.py
new file mode 100644
index 0000000..9aee1bd
--- /dev/null
+++ b/python_graphs/cyclomatic_complexity_test.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for cyclomatic_complexity.py."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from python_graphs import control_flow
+from python_graphs import control_flow_test_components as tc
+from python_graphs import cyclomatic_complexity
+
+
+class CyclomaticComplexityTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ (tc.straight_line_code, 1),
+ (tc.simple_if_statement, 2),
+ (tc.simple_for_loop, 2),
+ )
+ def test_cyclomatic_complexity(self, component, target_value):
+ graph = control_flow.get_control_flow_graph(component)
+ value = cyclomatic_complexity.cyclomatic_complexity(graph)
+ self.assertEqual(value, target_value)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/data_flow.py b/python_graphs/data_flow.py
new file mode 100644
index 0000000..a1afa4e
--- /dev/null
+++ b/python_graphs/data_flow.py
@@ -0,0 +1,233 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data flow analysis of Python programs."""
+
+import collections
+
+from absl import logging # pylint: disable=unused-import
+import gast as ast
+
+from python_graphs import control_flow
+from python_graphs import instruction as instruction_module
+
+
+READ = instruction_module.READ
+WRITE = instruction_module.WRITE
+
+
+class Analysis(object):
+ """Base class for a data flow analysis.
+
+ Attributes:
+ label: The name of the analysis.
+ forward: (bool) True for forward analyses, False for backward analyses.
+ in_label: The name of the analysis, suffixed with _in.
+ out_label: The name of the analysis, suffixed with _out.
+ before_label: Either the in_label or out_label depending on the direction of
+ the analysis. Marks the before_value on a node during an analysis.
+ after_label: Either the in_label or out_label depending on the direction of
+ the analysis. Marks the after_value on a node during an analysis.
+ """
+
+ def __init__(self, label, forward):
+ self.label = label
+ self.forward = forward
+
+ self.in_label = label + '_in'
+ self.out_label = label + '_out'
+
+ self.before_label = self.in_label if forward else self.out_label
+ self.after_label = self.out_label if forward else self.in_label
+
+ def aggregate_previous_after_values(self, previous_after_values):
+ """Computes the before value for a node from the previous after values.
+
+ This is the 'meet' or 'join' function of the analysis.
+ TODO(dbieber): Update terminology to match standard textbook notation.
+
+ Args:
+ previous_after_values: The after values of all before nodes.
+ Returns:
+ The before value for the current node.
+ """
+ raise NotImplementedError
+
+ def compute_after_value(self, node, before_value):
+ """Computes the after value for a node from the node and the before value.
+
+ This is the 'transfer' function of the analysis.
+ TODO(dbieber): Update terminology to match standard textbook notation.
+
+ Args:
+ node: The node or block for which to compute the after value.
+ before_value: The before value of the node.
+ Returns:
+ The computed after value for the node.
+ """
+ raise NotImplementedError
+
+ def visit(self, node):
+ """Visit the nodes of the control flow graph, performing the analysis.
+
+ Terminology:
+ in_value: The value of the analysis at the start of a node.
+ out_value: The value of the analysis at the end of a node.
+ before_value: in_value in a forward analysis; out_value in a backward
+ analysis.
+ after_value: out_value in a forward analysis; in_value in a backward
+ analysis.
+
+ Args:
+ node: A graph element that supports the .next / .prev API, such as a
+ ControlFlowNode from a ControlFlowGraph or a BasicBlock from a
+ ControlFlowGraph.
+ """
+ to_visit = collections.deque([node])
+ while to_visit:
+ node = to_visit.popleft()
+
+ before_nodes = node.prev if self.forward else node.next
+ after_nodes = node.next if self.forward else node.prev
+ previous_after_values = [
+ before_node.get_label(self.after_label)
+ for before_node in before_nodes
+ if before_node.has_label(self.after_label)]
+
+ if node.has_label(self.after_label):
+ initial_after_value_hash = hash(node.get_label(self.after_label))
+ else:
+ initial_after_value_hash = None
+ before_value = self.aggregate_previous_after_values(previous_after_values)
+ node.set_label(self.before_label, before_value)
+ after_value = self.compute_after_value(node, before_value)
+ node.set_label(self.after_label, after_value)
+ if hash(after_value) != initial_after_value_hash:
+ for after_node in after_nodes:
+ to_visit.append(after_node)
+
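The worklist loop in `Analysis.visit` can be sketched standalone. This is a minimal, hypothetical forward analysis (computing which nodes can reach each node) on a toy three-node loop graph, not the library's API; it shows the key property that a node's successors are only re-enqueued when its after-value changes. (`Analysis.visit` compares hashes rather than values, which is why its after-values must be hashable.)

```python
import collections

# Toy graph: entry -> loop; loop -> loop, exit.
succs = {'entry': ['loop'], 'loop': ['loop', 'exit'], 'exit': []}
preds = {'entry': [], 'loop': ['entry', 'loop'], 'exit': ['loop']}

after = {}  # node -> frozenset of nodes that can reach it (incl. itself)


def compute_after(node, before_value):
    # Transfer function: a node is reached by whatever reaches its
    # predecessors, plus itself.
    return frozenset(before_value | {node})


to_visit = collections.deque(['entry'])
while to_visit:
    node = to_visit.popleft()
    before_value = set()
    for pred in preds[node]:  # join: union over predecessors' after-values
        before_value |= after.get(pred, frozenset())
    after_value = compute_after(node, before_value)
    if after.get(node) != after_value:  # re-enqueue only on change
        after[node] = after_value
        to_visit.extend(succs[node])

print(sorted(after['exit']))  # ['entry', 'exit', 'loop']
```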
+
+def get_while_loop_variables(node, graph=None):
+ """Gets the set of loop variables used for while loop rewriting.
+
+ This is the set of variables used for rewriting a while loop into its
+ functional form.
+
+ Args:
+ node: An ast.While AST node.
+ graph: (Optional) The ControlFlowGraph of the function or program containing
+ the while loop. If not present, the control flow graph for the while loop
+ will be computed.
+ Returns:
+ The set of variable identifiers that are live at the start of the loop's
+ test and at the start of the loop's body.
+ """
+ graph = graph or control_flow.get_control_flow_graph(node)
+ test_block = graph.get_block_by_ast_node(node.test)
+
+  analysis = LivenessAnalysis()
+  for block in graph.get_exit_blocks():
+    analysis.visit(block)
+ # TODO(dbieber): Move this logic into the Analysis class to avoid the use of
+ # magic strings.
+ live_variables = test_block.get_label('liveness_in')
+ written_variables = {
+ write.id
+ for write in instruction_module.get_writes_from_ast_node(node)
+ if isinstance(write, ast.Name)
+ }
+ return live_variables & written_variables
+
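The intersection at the end of `get_while_loop_variables` can be illustrated with the `simple_while_loop` fixture from data_flow_test.py. The sets below are worked out by hand for that example, not computed by the library:

```python
# simple_while_loop body:
#   a = 2; b = 10; x = 1
#   while x < b:
#     tmp = x + a
#     x = tmp + 1
live_at_test = {'x', 'b', 'a'}   # read at or after the loop test
written_in_loop = {'x', 'tmp'}   # assigned inside the loop
loop_variables = live_at_test & written_in_loop
print(loop_variables)  # {'x'}
```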
+
+class LivenessAnalysis(Analysis):
+ """Liveness analysis by basic block.
+
+ In the liveness analysis, the in_value of a block is the set of variables
+ that are live at the start of a block. "Live" means that the current value of
+ the variable may be used later in the execution. The out_value of a block is
+ the set of variable identifiers that are live at the end of the block.
+
+ Since this is a backward analysis, the "before_value" is the out_value and the
+ "after_value" is the in_value.
+ """
+
+ def __init__(self):
+ super(LivenessAnalysis, self).__init__(label='liveness', forward=False)
+
+ def aggregate_previous_after_values(self, previous_after_values):
+ """Computes the out_value (before_value) of a block.
+
+ Args:
+ previous_after_values: A list of the sets of live variables at the start
+ of each of the blocks following the current block.
+ Returns:
+ The set of live variables at the end of the current block. This is the
+ union of live variable sets at the start of each subsequent block.
+ """
+ result = set()
+ for before_value in previous_after_values:
+ result |= before_value
+ return frozenset(result)
+
+ def compute_after_value(self, block, before_value):
+ """Computes the liveness analysis gen and kill sets for a basic block.
+
+ The gen set is the set of variables read by the block before they are
+ written to.
+ The kill set is the set of variables written to by the basic block.
+
+ Args:
+ block: The BasicBlock to analyze.
+ before_value: The out_value for block (the set of variables live at the
+        end of the block).
+ Returns:
+ The in_value for block (the set of variables live at the start of the
+ block).
+ """
+ gen = set()
+ kill = set()
+ for control_flow_node in block.control_flow_nodes:
+ instruction = control_flow_node.instruction
+ for read in instruction.get_read_names():
+ if read not in kill:
+ gen.add(read)
+ kill.update(instruction.get_write_names())
+ return frozenset((before_value - kill) | gen)
+
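The gen/kill computation above can be sketched in isolation. A minimal example, with instructions modeled as hypothetical `(reads, writes)` pairs rather than Instruction objects, for the block `tmp = x + a; x = tmp + 1`:

```python
def liveness_in(block, live_out):
    """Backward transfer: in = (out - kill) | gen."""
    gen, kill = set(), set()
    for reads, writes in block:
        gen |= reads - kill  # reads not yet killed by an earlier write
        kill |= writes
    return frozenset((live_out - kill) | gen)


# Block body: tmp = x + a; x = tmp + 1
block = [
    ({'x', 'a'}, {'tmp'}),
    ({'tmp'}, {'x'}),
]
# x is read before it is written, so it stays live; tmp and x are killed.
print(sorted(liveness_in(block, frozenset({'x', 'b'}))))  # ['a', 'b', 'x']
```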
+
+class FrozenDict(dict):
+
+ def __hash__(self):
+ return hash(tuple(sorted(self.items())))
+
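A note on `FrozenDict`: `Analysis.visit` hashes after-values to detect convergence, and a plain `dict` is unhashable. A minimal self-contained sketch of the behavior (the class is copied from above; the keys and values are illustrative):

```python
class FrozenDict(dict):
    """A dict hashable via its sorted item tuples."""

    def __hash__(self):
        return hash(tuple(sorted(self.items())))


a = FrozenDict({'read-x': frozenset({'n1'}), 'write-x': frozenset({'n2'})})
b = FrozenDict({'write-x': frozenset({'n2'}), 'read-x': frozenset({'n1'})})
assert hash(a) == hash(b)  # insertion order does not affect the hash
```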
+
+class LastAccessAnalysis(Analysis):
+ """Computes for each variable its possible last reads and last writes."""
+
+ def __init__(self):
+ super(LastAccessAnalysis, self).__init__(label='last_access', forward=True)
+
+ def aggregate_previous_after_values(self, previous_after_values):
+ result = collections.defaultdict(frozenset)
+ for previous_after_value in previous_after_values:
+ for key, value in previous_after_value.items():
+ result[key] |= value
+ return FrozenDict(result)
+
+ def compute_after_value(self, node, before_value):
+ result = before_value.copy()
+ for access in node.instruction.accesses:
+ kind_and_name = instruction_module.access_kind_and_name(access)
+ result[kind_and_name] = frozenset([access])
+ return FrozenDict(result)
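The per-key union join in `LastAccessAnalysis.aggregate_previous_after_values` can be sketched with plain dicts. The access strings below are illustrative stand-ins for the Instruction accesses the library actually tracks:

```python
import collections

# Two predecessor branches, each with its own possible last write of `count`.
branch1 = {'write-count': frozenset({'count = 0'})}
branch2 = {'write-count': frozenset({'count += 1'})}

result = collections.defaultdict(frozenset)
for branch in (branch1, branch2):
    for key, value in branch.items():
        result[key] |= value  # union the possibilities across branches

print(sorted(result['write-count']))  # ['count += 1', 'count = 0']
```

After the merge, both writes remain possible last accesses of `count`, which is exactly what `test_data_flow_nested_loops` asserts for the `last_access_in` of `count += 1`.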
diff --git a/python_graphs/data_flow_test.py b/python_graphs/data_flow_test.py
new file mode 100644
index 0000000..a3849ee
--- /dev/null
+++ b/python_graphs/data_flow_test.py
@@ -0,0 +1,138 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for data_flow.py."""
+
+import inspect
+
+from absl import logging # pylint: disable=unused-import
+from absl.testing import absltest
+import gast as ast
+
+from python_graphs import control_flow
+from python_graphs import control_flow_test_components as tc
+from python_graphs import data_flow
+from python_graphs import program_utils
+
+
+class DataFlowTest(absltest.TestCase):
+
+ def test_get_while_loop_variables(self):
+ root = program_utils.program_to_ast(tc.nested_while_loops)
+ graph = control_flow.get_control_flow_graph(root)
+
+ # node = graph.get_ast_node_by_type(ast.While)
+ # TODO(dbieber): data_flow.get_while_loop_variables(node, graph)
+
+ analysis = data_flow.LivenessAnalysis()
+ for block in graph.get_exit_blocks():
+ analysis.visit(block)
+
+ for block in graph.get_blocks_by_ast_node_type_and_label(
+ ast.While, 'test_block'):
+ logging.info(block.get_label('liveness_out'))
+
+ def test_liveness_simple_while_loop(self):
+ def simple_while_loop():
+ a = 2
+ b = 10
+ x = 1
+ while x < b:
+ tmp = x + a
+ x = tmp + 1
+
+ program_node = program_utils.program_to_ast(simple_while_loop)
+ graph = control_flow.get_control_flow_graph(program_node)
+
+ # TODO(dbieber): Use unified query system.
+ while_node = [
+ node for node in ast.walk(program_node)
+ if isinstance(node, ast.While)][0]
+ loop_variables = data_flow.get_while_loop_variables(while_node, graph)
+ self.assertEqual(loop_variables, {'x'})
+
+ def test_data_flow_nested_loops(self):
+ def fn():
+ count = 0
+ for x in range(10):
+ for y in range(10):
+ if x == y:
+ count += 1
+ return count
+
+ program_node = program_utils.program_to_ast(fn)
+ graph = control_flow.get_control_flow_graph(program_node)
+
+ # Perform the analysis.
+ analysis = data_flow.LastAccessAnalysis()
+ analysis.visit(graph.start_block.control_flow_nodes[0])
+ for node in graph.get_enter_control_flow_nodes():
+ analysis.visit(node)
+
+ # Verify correctness.
+ node = graph.get_control_flow_node_by_source('count += 1')
+ last_accesses_in = node.get_label('last_access_in')
+ last_accesses_out = node.get_label('last_access_out')
+ self.assertLen(last_accesses_in['write-count'], 2) # += 1, = 0
+ self.assertLen(last_accesses_in['read-count'], 1) # += 1
+ self.assertLen(last_accesses_out['write-count'], 1) # += 1
+ self.assertLen(last_accesses_out['read-count'], 1) # += 1
+
+ def test_last_accesses_analysis(self):
+ root = program_utils.program_to_ast(tc.nested_while_loops)
+ graph = control_flow.get_control_flow_graph(root)
+
+ analysis = data_flow.LastAccessAnalysis()
+ analysis.visit(graph.start_block.control_flow_nodes[0])
+
+ for node in graph.get_enter_control_flow_nodes():
+ analysis.visit(node)
+
+ for block in graph.blocks:
+ for cfn in block.control_flow_nodes:
+ self.assertTrue(cfn.has_label('last_access_in'))
+ self.assertTrue(cfn.has_label('last_access_out'))
+
+ node = graph.get_control_flow_node_by_source('y += 5')
+ last_accesses = node.get_label('last_access_out')
+ # TODO(dbieber): Add asserts that these are the correct accesses.
+ self.assertLen(last_accesses['write-x'], 2) # x = 1, x += 6
+ self.assertLen(last_accesses['read-x'], 1) # x < 2
+
+ node = graph.get_control_flow_node_by_source('return x')
+ last_accesses = node.get_label('last_access_out')
+ self.assertLen(last_accesses['write-x'], 2) # x = 1, x += 6
+ self.assertLen(last_accesses['read-x'], 1) # x < 2
+
+ def test_liveness_analysis_all_test_components(self):
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ root = program_utils.program_to_ast(fn)
+ graph = control_flow.get_control_flow_graph(root)
+
+ analysis = data_flow.LivenessAnalysis()
+ for block in graph.get_exit_blocks():
+ analysis.visit(block)
+
+ def test_last_access_analysis_all_test_components(self):
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ root = program_utils.program_to_ast(fn)
+ graph = control_flow.get_control_flow_graph(root)
+
+ analysis = data_flow.LastAccessAnalysis()
+ for node in graph.get_enter_control_flow_nodes():
+ analysis.visit(node)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/examples/control_flow_example.py b/python_graphs/examples/control_flow_example.py
new file mode 100644
index 0000000..01dfee1
--- /dev/null
+++ b/python_graphs/examples/control_flow_example.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Example generating a control flow graph from a Python function.
+
+Generates an image visualizing the control flow graph for each of the functions
+in control_flow_test_components.py. Saves the resulting images to the directory
+`out`.
+
+Usage:
+python -m python_graphs.examples.control_flow_example
+"""
+
+import inspect
+import os
+
+from absl import app
+
+from python_graphs import control_flow
+from python_graphs import control_flow_graphviz
+from python_graphs import control_flow_test_components as tc
+from python_graphs import program_utils
+
+
+def plot_control_flow_graph(fn, path):
+ graph = control_flow.get_control_flow_graph(fn)
+ source = program_utils.getsource(fn)
+ control_flow_graphviz.render(graph, include_src=source, path=path)
+
+
+def main(argv) -> None:
+ del argv # Unused
+
+ # Create the output directory.
+ os.makedirs('out', exist_ok=True)
+
+ # For each function in control_flow_test_components.py, visualize its
+ # control flow graph. Save the results in the output directory.
+ for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ path = f'out/{name}_cfg.png'
+ plot_control_flow_graph(fn, path)
+ print('Done. See the `out` directory for the results.')
+
+
+if __name__ == '__main__':
+ app.run(main)
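The example's main loop leans on `inspect.getmembers` to enumerate the test components. A minimal sketch (using a throwaway module in place of `control_flow_test_components`, so nothing here is the library's API) shows what that enumeration yields:

```python
import inspect
import types

# Stand-in for control_flow_test_components: a throwaway module with two
# functions, deliberately defined out of alphabetical order.
module = types.ModuleType('components')
exec('def beta(): pass\ndef alpha(): pass', module.__dict__)

# getmembers yields (name, value) pairs sorted by name, so the example
# produces its output images in alphabetical order regardless of the
# order in which the functions are defined.
names = [name for name, fn in inspect.getmembers(module, predicate=inspect.isfunction)]
print(names)  # ['alpha', 'beta']
```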
diff --git a/python_graphs/examples/cyclomatic_complexity_example.py b/python_graphs/examples/cyclomatic_complexity_example.py
new file mode 100644
index 0000000..3077a73
--- /dev/null
+++ b/python_graphs/examples/cyclomatic_complexity_example.py
@@ -0,0 +1,46 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Example computing the cyclomatic complexity of various Python functions.
+
+For each of the functions in control_flow_test_components.py, this computes and
+prints the function's cyclomatic complexity.
+
+Usage:
+python -m python_graphs.examples.cyclomatic_complexity_example
+"""
+
+import inspect
+
+from absl import app
+
+from python_graphs import control_flow
+from python_graphs import control_flow_test_components as tc
+from python_graphs import cyclomatic_complexity
+
+
+def main(argv) -> None:
+ del argv # Unused
+
+ # For each function in control_flow_test_components.py, compute its cyclomatic
+ # complexity and print the result.
+ for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ print(f'{name}: ', end='')
+ graph = control_flow.get_control_flow_graph(fn)
+ value = cyclomatic_complexity.cyclomatic_complexity(graph)
+ print(value)
+
+
+if __name__ == '__main__':
+ app.run(main)
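The library computes cyclomatic complexity from the control flow graph; the classical definition is M = E - N + 2P for E edges, N nodes, and P connected components. A standalone sketch over a hand-built graph (the function and edge names here are illustrative, not the library's API):

```python
# Cyclomatic complexity of a control flow graph: M = E - N + 2P.
# For a single function's CFG, P (connected components) is 1.
def cyclomatic_complexity(edges, num_nodes, num_components=1):
    return len(edges) - num_nodes + 2 * num_components

# CFG of `if x: a()\nelse: b()` sketched by hand:
# entry branches to two blocks, both of which flow to the exit.
edges = [('entry', 'then'), ('entry', 'else'), ('then', 'exit'), ('else', 'exit')]
value = cyclomatic_complexity(edges, num_nodes=4)
print(value)  # 4 - 4 + 2 = 2: two linearly independent paths
```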
diff --git a/python_graphs/examples/program_graph_example.py b/python_graphs/examples/program_graph_example.py
new file mode 100644
index 0000000..b54a720
--- /dev/null
+++ b/python_graphs/examples/program_graph_example.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Example generating a complete program graph from a Python function.
+
+Generates an image visualizing the complete program graph for each function
+in program_graph_test_components.py. Saves the resulting images to the directory
+`out`.
+
+Usage:
+python -m python_graphs.examples.program_graph_example
+"""
+
+import inspect
+import os
+
+from absl import app
+from python_graphs import program_graph
+from python_graphs import program_graph_graphviz
+from python_graphs import program_graph_test_components as tc
+
+
+def main(argv) -> None:
+ del argv # Unused
+
+ # Create the output directory.
+ os.makedirs('out', exist_ok=True)
+
+ # For each function in program_graph_test_components.py, visualize its
+ # program graph. Save the results in the output directory.
+ for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ path = f'out/{name}-program-graph.png'
+ graph = program_graph.get_program_graph(fn)
+ program_graph_graphviz.render(graph, path=path)
+ print('Done. See the `out` directory for the results.')
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/python_graphs/instruction.py b/python_graphs/instruction.py
new file mode 100644
index 0000000..18d5d73
--- /dev/null
+++ b/python_graphs/instruction.py
@@ -0,0 +1,400 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An Instruction represents an executable unit of a Python program.
+
+Almost all simple statements correspond to Instructions, except for statements
+like pass, continue, and break, whose effects are already represented in the
+structure of the control-flow graph.
+
+In addition to simple statements, assignments that occur outside of any simple
+statement, such as those made implicitly by function and class definitions,
+also correspond to Instructions.
+
+The complete set of places where Instructions occur in source are listed here:
+
+1. <simple_statement>  (Any node in INSTRUCTION_AST_NODES used as a statement.)
+2. if <condition>: ... (elif is the same.)
+3+4. for <target> in <iter>: ...
+5. while <condition>: ...
+6. try: ... except <exception>: ...
+7. TODO(dbieber): Test for "with <expr> as <target>: ..."
+
+In the code:
+
+@decorator
+def fn(args=defaults):
+ body
+
+Outside of the function definition, we get the following instructions:
+8. Each decorator is an Instruction.
+9. Each default is an Instruction.
+10. The assignment of the function def to the function name is an Instruction.
+Inside the function definition, we get the following instructions:
+11. An Instruction for the assignment of values to the arguments.
+(1, again) And then the body can consist of multiple Instructions too.
+
+Likewise in the code:
+
+@decorator
+class C(object):
+ body
+
+The following are Instructions:
+(8, again) Each decorator is an Instruction
+12. The assignment of the class to the variable C is an Instruction.
+(1, again) And then the body can consist of multiple Instructions too.
+13. TODO(dbieber): The base class (object) is an Instruction too.
+"""
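The docstring's claim that decorators and argument defaults execute at definition time, outside the function body (and hence are separate Instructions), can be observed directly with plain Python:

```python
# Record the order in which the pieces of a function definition run.
events = []

def decorator(fn):
    events.append('decorator')
    return fn

def default():
    events.append('default')
    return 0

@decorator
def fn(x=default()):
    events.append('body')

# The default expression and the decorator both ran when the `def`
# statement executed; the body has not run at all yet.
print(events)  # ['default', 'decorator']
```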
+
+import gast as ast
+import six
+
+# Types of accesses:
+READ = 'read'
+WRITE = 'write'
+
+# Context lists
+WRITE_CONTEXTS = (ast.Store, ast.Del, ast.Param, ast.AugStore)
+READ_CONTEXTS = (ast.Load, ast.AugLoad)
+
+# Sources of implicit writes:
+CLASS = 'class'
+FUNCTION = 'function'
+ARGS = 'args'
+KWARG = 'kwarg'
+KWONLYARGS = 'kwonlyargs'
+VARARG = 'vararg'
+ITERATOR = 'iter'
+EXCEPTION = 'exception'
+
+INSTRUCTION_AST_NODES = (
+ ast.Expr, # expression_stmt
+ ast.Assert, # assert_stmt
+ ast.Assign, # assignment_stmt
+ ast.AugAssign, # augmented_assignment_stmt
+ ast.Delete, # del_stmt
+ ast.Print, # print_stmt
+ ast.Return, # return_stmt
+ # ast.Yield, # yield_stmt. ast.Yield nodes are contained in ast.Expr nodes.
+ ast.Raise, # raise_stmt
+ ast.Import, # import_stmt
+ ast.ImportFrom,
+ ast.Global, # global_stmt
+ ast.Exec, # exec_stmt
+)
+
+# https://docs.python.org/2/reference/simple_stmts.html
+SIMPLE_STATEMENT_AST_NODES = INSTRUCTION_AST_NODES + (
+ ast.Pass, # pass_stmt
+ ast.Break, # break_stmt
+ ast.Continue, # continue_stmt
+)
+
+
+def _canonicalize(node):
+ if isinstance(node, list) and len(node) == 1:
+ return _canonicalize(node[0])
+ if isinstance(node, ast.Module):
+ return _canonicalize(node.body)
+ if isinstance(node, ast.Expr):
+ return _canonicalize(node.value)
+ return node
+
+
+def represent_same_program(node1, node2):
+ """Whether AST nodes node1 and node2 represent the same program syntactically.
+
+  Two programs are the same syntactically if they have equivalent ASTs, up to
+  some small changes. The context field of Name nodes can change without
+  changing the syntax represented by the AST. This allows, for example, the
+  short program 'x' (a read) to match the subprogram 'x' of 'x = 3' (in which
+  x is a write), since the two programs are syntactically identical ('x' and
+  'x').
+
+ Except for the context field of Name nodes, the two nodes are recursively
+ checked for exact equality.
+
+ Args:
+ node1: An AST node. This can be an ast.AST object, a primitive, or a list of
+ AST nodes (primitives or ast.AST objects).
+ node2: An AST node. This can be an ast.AST object, a primitive, or a list of
+ AST nodes (primitives or ast.AST objects).
+
+ Returns:
+ Whether the two nodes represent equivalent programs.
+ """
+ node1 = _canonicalize(node1)
+ node2 = _canonicalize(node2)
+
+ if type(node1) != type(node2): # pylint: disable=unidiomatic-typecheck
+ return False
+ if not isinstance(node1, ast.AST):
+ return node1 == node2
+
+ fields1 = list(ast.iter_fields(node1))
+ fields2 = list(ast.iter_fields(node2))
+ if len(fields1) != len(fields2):
+ return False
+
+ for (field1, value1), (field2, value2) in zip(fields1, fields2):
+ if field1 == 'ctx':
+ continue
+ if field1 != field2 or type(value1) is not type(value2):
+ return False
+    if isinstance(value1, list):
+      # zip silently truncates to the shorter list, so differing lengths
+      # must be rejected explicitly.
+      if len(value1) != len(value2):
+        return False
+      for item1, item2 in zip(value1, value2):
+        if not represent_same_program(item1, item2):
+          return False
+ elif not represent_same_program(value1, value2):
+ return False
+
+ return True
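A condensed analogue of this check, written against the stdlib `ast` module rather than `gast` (so a simplification, not the library's implementation), illustrates why skipping the `ctx` field lets a read match a write:

```python
import ast

def same_program(a, b):
    """Recursively compare two AST fragments, ignoring Name contexts."""
    if type(a) is not type(b):
        return False
    if not isinstance(a, ast.AST):
        if isinstance(a, list):
            return len(a) == len(b) and all(
                same_program(x, y) for x, y in zip(a, b))
        return a == b
    for (f1, v1), (f2, v2) in zip(ast.iter_fields(a), ast.iter_fields(b)):
        if f1 == 'ctx':
            continue  # a read of `x` should match a write of `x`
        if f1 != f2 or not same_program(v1, v2):
            return False
    return True

write_y = ast.parse('y = 1').body[0].targets[0]  # Name with Store context
read_y = ast.parse('y').body[0].value            # Name with Load context
print(same_program(write_y, read_y))  # True: only ctx differs
```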
+
+
+class AccessVisitor(ast.NodeVisitor):
+ """Visitor that computes an ordered list of accesses.
+
+ Accesses are ordered based on a depth-first traversal of the AST, using the
+ order of fields defined in `gast`, except for Assign nodes, for which the RHS
+ is ordered before the LHS.
+
+ This may differ from Python execution semantics in two ways:
+
+  - Both sides of short-circuit `and`/`or` expressions and both branches of
+    conditional `X if Y else Z` expressions are considered to be evaluated,
+    even if one of them is actually skipped at runtime.
+ - For AST nodes whose field order doesn't match the Python interpreter's
+ evaluation order, the field order is used instead. Most AST nodes match
+ execution order, but some differ (e.g. for dictionary literals, the
+ interpreter alternates evaluating keys and values, but the field order has
+ all keys and then all values). Assignments are a special case; the
+ AccessVisitor evaluates the RHS first even though the LHS occurs first in
+ the expression.
+
+ Attributes:
+ accesses: List of accesses encountered by the visitor.
+ """
+
+ # TODO(dbieber): Include accesses of ast.Subscript and ast.Attribute targets.
+
+ def __init__(self):
+ self.accesses = []
+
+ def visit_Name(self, node):
+ """Visit a Name, adding it to the list of accesses."""
+ self.accesses.append(node)
+
+ def visit_Assign(self, node):
+ """Visit an Assign, ordering RHS accesses before LHS accesses."""
+ self.visit(node.value)
+ for target in node.targets:
+ self.visit(target)
+
+ def visit_AugAssign(self, node):
+ """Visit an AugAssign, which contains both a read and a write."""
+ # An AugAssign is a read as well as a write, even with the ctx of a write.
+ self.visit(node.value)
+ # Add a read access if we are assigning to a name.
+ if isinstance(node.target, ast.Name):
+ # TODO(dbieber): Use a proper type instead of a tuple for accesses.
+ self.accesses.append(('read', node.target, node))
+ # Add the write access as normal.
+ self.visit(node.target)
+
+
+def get_accesses_from_ast_node(node):
+ """Get all accesses for an AST node, in depth-first AST field order."""
+ visitor = AccessVisitor()
+ visitor.visit(node)
+ return visitor.accesses
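The visitor above can be mirrored with the stdlib `ast` module. This sketch is a simplification (it classifies every non-Load context as a write and omits the tuple-based access records), but it shows the RHS-before-LHS ordering for assignments:

```python
import ast

class Accesses(ast.NodeVisitor):
    """Collect (kind, name) pairs for Name nodes in depth-first order."""

    def __init__(self):
        self.accesses = []

    def visit_Name(self, node):
        kind = 'read' if isinstance(node.ctx, ast.Load) else 'write'
        self.accesses.append((kind, node.id))

    def visit_Assign(self, node):
        # Order the RHS before the LHS, matching execution semantics.
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)

visitor = Accesses()
visitor.visit(ast.parse('c = fn(a, b)'))
print(visitor.accesses)  # [('read', 'fn'), ('read', 'a'), ('read', 'b'), ('write', 'c')]
```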
+
+
+def get_reads_from_ast_node(ast_node):
+ """Get all reads for an AST node, in depth-first AST field order.
+
+ Args:
+ ast_node: The AST node of interest.
+
+ Returns:
+    A list of reads performed by that AST node.
+ """
+ return [
+ access for access in get_accesses_from_ast_node(ast_node)
+ if access_is_read(access)
+ ]
+
+
+def get_writes_from_ast_node(ast_node):
+ """Get all writes for an AST node, in depth-first AST field order.
+
+ Args:
+ ast_node: The AST node of interest.
+
+ Returns:
+ A list of writes performed by that AST node.
+ """
+ return [
+ access for access in get_accesses_from_ast_node(ast_node)
+ if access_is_write(access)
+ ]
+
+
+def create_writes(node, parent=None):
+ # TODO(dbieber): Use a proper type instead of a tuple for accesses.
+ if isinstance(node, ast.AST):
+ return [
+ ('write', n, parent) for n in ast.walk(node) if isinstance(n, ast.Name)
+ ]
+ else:
+ return [('write', node, parent)]
+
+
+def access_is_read(access):
+ if isinstance(access, ast.AST):
+ assert isinstance(access, ast.Name), access
+ return isinstance(access.ctx, READ_CONTEXTS)
+ else:
+ return access[0] == 'read'
+
+
+def access_is_write(access):
+ if isinstance(access, ast.AST):
+ assert isinstance(access, ast.Name), access
+ return isinstance(access.ctx, WRITE_CONTEXTS)
+ else:
+ return access[0] == 'write'
+
+
+def access_name(access):
+ if isinstance(access, ast.AST):
+ return access.id
+ elif isinstance(access, tuple):
+ if isinstance(access[1], six.string_types):
+ return access[1]
+ elif isinstance(access[1], ast.Name):
+ return access[1].id
+ raise ValueError('Unexpected access type.', access)
+
+
+def access_kind(access):
+ if access_is_read(access):
+ return 'read'
+ elif access_is_write(access):
+ return 'write'
+
+
+def access_kind_and_name(access):
+ return '{}-{}'.format(access_kind(access), access_name(access))
+
+
+def access_identifier(name, kind):
+ return '{}-{}'.format(kind, name)
+
+
+class Instruction(object):
+ # pyformat:disable
+ """Represents an executable unit of a Python program.
+
+ An Instruction is a part of an AST corresponding to a simple statement or
+ assignment, not corresponding to control flow. The part of the AST is not
+ necessarily an AST node. It may be an AST node, or it may instead be a string
+ (such as a variable name).
+
+ Instructions play an important part in control flow graphs. An Instruction
+ is the smallest unit of a control flow graph (wrapped in a ControlFlowNode).
+ A control flow graph consists of basic blocks which represent a sequence of
+ Instructions that are executed in a straight-line manner, or not at all.
+
+ Conceptually an Instruction is immutable. This means that while Python does
+ permit the mutation of an Instruction, in practice an Instruction object
+ should not be modified once it is created.
+
+ Note that an Instruction may be interrupted by an exception mid-execution.
+ This is captured in control flow graphs via interrupting exits from basic
+ blocks to either exception handlers or special 'raises' blocks.
+
+ In addition to pure simple statements, an Instruction can represent a number
+ of different parts of code. These are all listed explicitly in the module
+ docstring.
+
+ In the common case, the accesses made by an Instruction are given by the Name
+ AST nodes contained in the Instruction's AST node. In some cases, when the
+ instruction.source field is not None, the accesses made by an Instruction are
+ not simply the Name AST nodes of the Instruction's node. For example, in a
+ function definition, the only access is the assignment of the function def to
+ the variable with the function's name; the Name nodes contained in the
+ function definition are not part of the function definition Instruction, and
+ instead are part of other Instructions that make up the function. The set of
+ accesses made by an Instruction is computed when the Instruction is created
+ and available via the accesses attribute of the Instruction.
+
+ Attributes:
+ node: The AST node corresponding to the instruction.
+ accesses: (optional) An ordered list of all reads and writes made by this
+ instruction. Each item in `accesses` is one of either:
+ - A 3-tuple with fields (kind, node, parent). kind is either 'read' or
+ 'write'. node is either a string or Name AST node. parent is an AST
+ node where node occurs.
+ - A Name AST node
+ # TODO(dbieber): Use a single type for all accesses.
+ source: (optional) The source of the writes. For example in the for loop
+      `for x in items: pass` there is an instruction for the Name node "x". Its
+ source is ITERATOR, indicating that this instruction corresponds to x
+ being assigned a value from an iterator. When source is not None, the
+ Python code corresponding to the instruction does not coincide with the
+ Python code corresponding to the instruction's node.
+ """
+ # pyformat:enable
+
+ def __init__(self, node, accesses=None, source=None):
+ if not isinstance(node, ast.AST):
+ raise TypeError('node must be an instance of ast.AST.', node)
+ self.node = node
+ if accesses is None:
+ accesses = get_accesses_from_ast_node(node)
+ self.accesses = accesses
+ self.source = source
+
+ def contains_subprogram(self, node):
+ """Whether this Instruction contains the given AST as a subprogram.
+
+ Computes whether `node` is a subtree of this Instruction's AST.
+ If the Instruction represents an implied write, then the node must match
+ against the Instruction's writes.
+
+ Args:
+ node: The node to check the instruction against for a match.
+
+ Returns:
+ (bool) Whether or not this Instruction contains the node, syntactically.
+ """
+ if self.source is not None:
+ # Only exact matches are permissible if source is not None.
+ return represent_same_program(node, self.node)
+ for subtree in ast.walk(self.node):
+ if represent_same_program(node, subtree):
+ return True
+ return False
+
+ def get_reads(self):
+ return {access for access in self.accesses if access_is_read(access)}
+
+ def get_read_names(self):
+ return {access_name(access) for access in self.get_reads()}
+
+ def get_writes(self):
+ return {access for access in self.accesses if access_is_write(access)}
+
+ def get_write_names(self):
+ return {access_name(access) for access in self.get_writes()}
diff --git a/python_graphs/instruction_test.py b/python_graphs/instruction_test.py
new file mode 100644
index 0000000..1f26b6e
--- /dev/null
+++ b/python_graphs/instruction_test.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for instruction module."""
+
+from absl.testing import absltest
+import gast as ast
+from python_graphs import instruction as instruction_module
+
+
+def create_instruction(source):
+ node = ast.parse(source)
+ node = instruction_module._canonicalize(node)
+ return instruction_module.Instruction(node)
+
+
+class InstructionTest(absltest.TestCase):
+
+ def test_instruction(self):
+ self.assertIsNotNone(instruction_module.Instruction)
+
+ def test_represent_same_program_basic_positive_case(self):
+ program1 = ast.parse('x + 1')
+ program2 = ast.parse('x + 1')
+ self.assertTrue(
+ instruction_module.represent_same_program(program1, program2))
+
+ def test_represent_same_program_basic_negative_case(self):
+ program1 = ast.parse('x + 1')
+ program2 = ast.parse('x + 2')
+ self.assertFalse(
+ instruction_module.represent_same_program(program1, program2))
+
+ def test_represent_same_program_different_contexts(self):
+ full_program1 = ast.parse('y = x + 1') # y is a write
+ program1 = full_program1.body[0].targets[0] # 'y'
+ program2 = ast.parse('y') # y is a read
+ self.assertTrue(
+ instruction_module.represent_same_program(program1, program2))
+
+ def test_get_accesses(self):
+ instruction = create_instruction('x + 1')
+ self.assertEqual(instruction.get_read_names(), {'x'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ instruction = create_instruction('return x + y + z')
+ self.assertEqual(instruction.get_read_names(), {'x', 'y', 'z'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ instruction = create_instruction('fn(a, b, c)')
+ self.assertEqual(instruction.get_read_names(), {'a', 'b', 'c', 'fn'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ instruction = create_instruction('c = fn(a, b, c)')
+ self.assertEqual(instruction.get_read_names(), {'a', 'b', 'c', 'fn'})
+ self.assertEqual(instruction.get_write_names(), {'c'})
+
+ def test_get_accesses_augassign(self):
+ instruction = create_instruction('x += 1')
+ self.assertEqual(instruction.get_read_names(), {'x'})
+ self.assertEqual(instruction.get_write_names(), {'x'})
+
+ instruction = create_instruction('x *= y')
+ self.assertEqual(instruction.get_read_names(), {'x', 'y'})
+ self.assertEqual(instruction.get_write_names(), {'x'})
+
+ def test_get_accesses_augassign_subscript(self):
+ instruction = create_instruction('x[0] *= y')
+ # This is not currently considered a write of x. It is a read of x.
+ self.assertEqual(instruction.get_read_names(), {'x', 'y'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ def test_get_accesses_augassign_attribute(self):
+ instruction = create_instruction('x.attribute *= y')
+ # This is not currently considered a write of x. It is a read of x.
+ self.assertEqual(instruction.get_read_names(), {'x', 'y'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ def test_get_accesses_subscript(self):
+ instruction = create_instruction('x[0] = y')
+ # This is not currently considered a write of x. It is a read of x.
+ self.assertEqual(instruction.get_read_names(), {'x', 'y'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ def test_get_accesses_attribute(self):
+ instruction = create_instruction('x.attribute = y')
+ # This is not currently considered a write of x. It is a read of x.
+ self.assertEqual(instruction.get_read_names(), {'x', 'y'})
+ self.assertEqual(instruction.get_write_names(), set())
+
+ def test_access_ordering(self):
+ instruction = create_instruction('c = fn(a, b + c, d / a)')
+ access_names_and_kinds = [(instruction_module.access_name(access),
+ instruction_module.access_kind(access))
+ for access in instruction.accesses]
+ self.assertEqual(access_names_and_kinds, [('fn', 'read'), ('a', 'read'),
+ ('b', 'read'), ('c', 'read'),
+ ('d', 'read'), ('a', 'read'),
+ ('c', 'write')])
+
+ instruction = create_instruction('c += fn(a, b + c, d / a)')
+ access_names_and_kinds = [(instruction_module.access_name(access),
+ instruction_module.access_kind(access))
+ for access in instruction.accesses]
+ self.assertEqual(access_names_and_kinds, [('fn', 'read'), ('a', 'read'),
+ ('b', 'read'), ('c', 'read'),
+ ('d', 'read'), ('a', 'read'),
+ ('c', 'read'), ('c', 'write')])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/program_graph.py b/python_graphs/program_graph.py
new file mode 100644
index 0000000..4da2abd
--- /dev/null
+++ b/python_graphs/program_graph.py
@@ -0,0 +1,963 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Creates ProgramGraphs from a program or function's AST.
+
+A ProgramGraph represents a Python program or function. The nodes in a
+ProgramGraph represent an Instruction (see instruction.py), an AST node, or a
+piece of syntax from the program. The edges in a ProgramGraph represent the
+relationships between these nodes.
+"""
+
+import codecs
+import collections
+import os
+
+from absl import logging
+import astunparse
+from astunparse import unparser
+import gast as ast
+from python_graphs import control_flow
+from python_graphs import data_flow
+from python_graphs import instruction as instruction_module
+from python_graphs import program_graph_dataclasses as pb
+from python_graphs import program_utils
+import six
+from six.moves import builtins
+from six.moves import filter
+
+NEWLINE_TOKEN = '#NEWLINE#'
+UNINDENT_TOKEN = '#UNINDENT#'
+INDENT_TOKEN = '#INDENT#'
+
+
+class ProgramGraph(object):
+ """A ProgramGraph represents a Python program or function.
+
+ Attributes:
+ root_id: The id of the root ProgramGraphNode.
+ nodes: Maps from node id to the ProgramGraphNode with that id.
+ edges: A list of the edges (from_node.id, to_node.id, edge type) in the
+ graph.
+ child_map: Maps from node id to a list of that node's AST children node ids.
+ parent_map: Maps from node id to that node's AST parent node id.
+    neighbors_map: Maps from node id to a list of (edge, neighbor node id)
+      pairs for the edges incident to that node.
+ ast_id_to_program_graph_node: Maps from an AST node's object id to the
+ corresponding AST program graph node, if it exists.
+ root: The root ProgramGraphNode.
+ """
+
+ def __init__(self):
+ """Constructs an empty ProgramGraph with no root."""
+ self.root_id = None
+
+ self.nodes = {}
+ # TODO(charlessutton): Seems odd to have Edge proto objects as part of the
+ # program graph object if node protos aren't. Consider a more consistent
+ # treatment.
+ self.edges = []
+
+ self.ast_id_to_program_graph_node = {}
+ self.child_map = collections.defaultdict(list)
+ self.parent_map = collections.defaultdict(lambda: None)
+ self.neighbors_map = collections.defaultdict(list)
+
+ # Accessors
+ @property
+ def root(self):
+ if self.root_id not in self.nodes:
+ raise ValueError('Graph has no root node.')
+ return self.nodes[self.root_id]
+
+ def all_nodes(self):
+ return self.nodes.values()
+
+ def get_node(self, obj):
+ """Returns the node in the program graph corresponding to an object.
+
+ Arguments:
+ obj: Can be an integer, AST node, ProgramGraphNode, or program graph node
+ protobuf.
+
+    Returns:
+      The ProgramGraphNode corresponding to obj.
+
+    Raises:
+      ValueError: no node exists in the program graph matching obj.
+    """
+ if isinstance(obj, six.integer_types) and obj in self.nodes:
+ return self.get_node_by_id(obj)
+ elif isinstance(obj, ProgramGraphNode):
+ # assert obj in self.nodes.values()
+ return obj
+ elif isinstance(obj, pb.Node):
+ return self.get_node_by_id(obj.id)
+ elif isinstance(obj, (ast.AST, list)):
+ return self.get_node_by_ast_node(obj)
+ else:
+ raise ValueError('Unexpected value for obj.', obj)
+
+ def get_node_by_id(self, obj):
+ """Gets a ProgramGraph node for the given integer id."""
+ return self.nodes[obj]
+
+  def get_node_by_access(self, access):
+    """Gets a ProgramGraph node for the given read or write."""
+    if isinstance(access, ast.Name):
+      return self.get_node(access)
+    if isinstance(access, tuple):
+      if isinstance(access[1], ast.Name):
+        return self.get_node(access[1])
+      return self.get_node(access[2])
+    raise ValueError('Could not find node for access.', access)
+
+ def get_nodes_by_source(self, source):
+ """Generates the nodes in the program graph containing the query source.
+
+ Args:
+ source: The query source.
+
+ Returns:
+ A generator of all nodes in the program graph with an Instruction with
+ source that includes the query source.
+ """
+ module = ast.parse(source, mode='exec') # TODO(dbieber): Factor out 4 lines
+ # TODO(dbieber): Use statements beyond the first statement from source.
+ node = module.body[0]
+ # If the query source is an Expression, and the matching instruction matches
+ # the value field of that Expression, then the matching instruction is
+ # considered a match. This allows us to match subexpressions which appear in
+ # ast.Expr nodes in the query but not in the parent.
+ if isinstance(node, ast.Expr):
+ node = node.value
+
+ def matches_source(pg_node):
+ if pg_node.has_instruction():
+ return pg_node.instruction.contains_subprogram(node)
+ else:
+ return instruction_module.represent_same_program(pg_node.ast_node, node)
+
+ return filter(matches_source, self.nodes.values())
+
+ def get_node_by_source(self, node):
+ # We use min since nodes can contain each other and we want the most
+ # specific one.
+ return min(
+ self.get_nodes_by_source(node), key=lambda x: len(ast.dump(x.node)))
+
+ def get_nodes_by_function_name(self, name):
+ return filter(
+ lambda n: n.has_instance_of(ast.FunctionDef) and n.node.name == name,
+ self.nodes.values())
+
+ def get_node_by_function_name(self, name):
+ return next(self.get_nodes_by_function_name(name))
+
+ def get_node_by_ast_node(self, ast_node):
+ return self.ast_id_to_program_graph_node[id(ast_node)]
+
+ def contains_ast_node(self, ast_node):
+ return id(ast_node) in self.ast_id_to_program_graph_node
+
+ def get_ast_nodes_of_type(self, ast_type):
+ for node in six.itervalues(self.nodes):
+ if node.node_type == pb.NodeType.AST_NODE and node.ast_type == ast_type:
+ yield node
+
+ # TODO(dbieber): Unify selectors across program_graph and control_flow.
+ def get_nodes_by_source_and_identifier(self, source, name):
+ for pg_node in self.get_nodes_by_source(source):
+ for node in ast.walk(pg_node.node):
+ if isinstance(node, ast.Name) and node.id == name:
+ if self.contains_ast_node(node):
+ yield self.get_node_by_ast_node(node)
+
+ def get_node_by_source_and_identifier(self, source, name):
+ return next(self.get_nodes_by_source_and_identifier(source, name))
+
+ # Graph Construction Methods
+ def add_node(self, node):
+ """Adds a ProgramGraphNode to this graph.
+
+ Args:
+ node: The ProgramGraphNode that should be added.
+
+ Returns:
+ The node that was added.
+
+ Raises:
+ ValueError: the node has already been added to this graph.
+ """
+ assert isinstance(node, ProgramGraphNode), 'Not a ProgramGraphNode'
+ if node.id in self.nodes:
+ raise ValueError('Already contains node', self.nodes[node.id], node.id)
+ if node.ast_node is not None:
+ if self.contains_ast_node(node.ast_node):
+ raise ValueError('Already contains ast node', node.ast_node)
+ self.ast_id_to_program_graph_node[id(node.ast_node)] = node
+ self.nodes[node.id] = node
+ return node
+
+ def add_node_from_instruction(self, instruction):
+ """Adds a node to the program graph."""
+ node = make_node_from_instruction(instruction)
+ return self.add_node(node)
+
+ def add_edge(self, edge):
+ """Adds an edge between two nodes in the graph.
+
+ Args:
+ edge: The edge, a pb.Edge proto.
+ """
+ assert isinstance(edge, pb.Edge), 'Not a pb.Edge'
+ self.edges.append(edge)
+
+ n1 = self.get_node_by_id(edge.id1)
+ n2 = self.get_node_by_id(edge.id2)
+ if edge.type == pb.EdgeType.FIELD: # An AST node.
+ self.child_map[edge.id1].append(edge.id2)
+ # TODO(charlessutton): Add the below sanity check back once Instruction
+ # updates are complete.
+ # pylint: disable=line-too-long
+ # other_parent_id = self.parent_map[edge.id2]
+ # if other_parent_id and other_parent_id != edge.id1:
+ # raise Exception('Node {} {} with two parents\n {} {}\n {} {}'
+ # .format(edge.id2, dump_node(self.get_node(edge.id2)),
+ # edge.id1, dump_node(self.get_node(edge.id1)),
+ # other_parent_id, dump_node(self.get_node(other_parent_id))))
+ # pylint: enable=line-too-long
+ self.parent_map[n2.id] = edge.id1
+ self.neighbors_map[n1.id].append((edge, edge.id2))
+ self.neighbors_map[n2.id].append((edge, edge.id1))
+
+ def remove_edge(self, edge):
+ """Removes an edge from the graph.
+
+ If there are multiple copies of the same edge, only one copy is removed.
+
+ Args:
+ edge: The edge, a pb.Edge proto.
+ """
+ self.edges.remove(edge)
+
+ n1 = self.get_node_by_id(edge.id1)
+ n2 = self.get_node_by_id(edge.id2)
+
+ if edge.type == pb.EdgeType.FIELD: # An AST node.
+ self.child_map[edge.id1].remove(edge.id2)
+ del self.parent_map[n2.id]
+
+ self.neighbors_map[n1.id].remove((edge, edge.id2))
+ self.neighbors_map[n2.id].remove((edge, edge.id1))
+
+ def add_new_edge(self, n1, n2, edge_type=None, field_name=None):
+ """Adds a new edge between two nodes in the graph.
+
+ Both nodes must already be part of the graph.
+
+ Args:
+ n1: Specifies the from node of the edge. Can be any object type accepted
+ by get_node.
+ n2: Specifies the to node of the edge. Can be any object type accepted by
+ get_node.
+      edge_type: The type of edge, a value from the pb.EdgeType enum.
+      field_name: For AST edges, a string describing the Python AST field.
+
+ Returns:
+ The new edge.
+ """
+ n1 = self.get_node(n1)
+ n2 = self.get_node(n2)
+ new_edge = pb.Edge(
+ id1=n1.id, id2=n2.id, type=edge_type, field_name=field_name)
+ self.add_edge(new_edge)
+ return new_edge
+
+ # AST Methods
+ # TODO(charlessutton): Consider whether AST manipulation should be moved
+ # e.g., to a more general graph object.
+ def to_ast(self, node=None):
+    """Converts the program graph to a Python AST."""
+ if node is None:
+ node = self.root
+ return self._build_ast(node=node, update_references=False)
+
+ def reconstruct_ast(self):
+    """Reconstructs all internal ProgramGraphNode.ast_node references.
+
+ After calling this method, all nodes of type AST_NODE will have their
+ `ast_node` property refer to subtrees of a reconstructed AST object, and
+ self.ast_id_to_program_graph_node will contain only entries from this new
+ AST.
+
+ Note that only AST nodes reachable by fields from the root node will be
+    converted; this should be all of them, but this is not checked.
+ """
+ self.ast_id_to_program_graph_node.clear()
+ self._build_ast(node=self.root, update_references=True)
+
+ def _build_ast(self, node, update_references):
+ """Helper method: builds an AST and optionally sets ast_node references.
+
+ Args:
+ node: Program graph node to build an AST for.
+ update_references: Whether to modify this node and all of its children so
+ that they point to the reconstructed AST node.
+
+ Returns:
+ AST node corresponding to the program graph node.
+ """
+ if node.node_type == pb.NodeType.AST_NODE:
+ ast_node = getattr(ast, node.ast_type)()
+ adjacent_edges = self.neighbors_map[node.id]
+ for edge, other_node_id in adjacent_edges:
+ if other_node_id == edge.id1: # it's an incoming edge
+ continue
+ if edge.type == pb.EdgeType.FIELD:
+ child_id = other_node_id
+ child = self.get_node_by_id(child_id)
+ setattr(
+ ast_node, edge.field_name,
+ self._build_ast(node=child, update_references=update_references))
+ if update_references:
+ node.ast_node = ast_node
+ self.ast_id_to_program_graph_node[id(ast_node)] = node
+ return ast_node
+ elif node.node_type == pb.NodeType.AST_LIST:
+ list_items = {}
+ adjacent_edges = self.neighbors_map[node.id]
+ for edge, other_node_id in adjacent_edges:
+ if other_node_id == edge.id1: # it's an incoming edge
+ continue
+ if edge.type == pb.EdgeType.FIELD:
+ child_id = other_node_id
+ child = self.get_node_by_id(child_id)
+ unused_field_name, index = parse_list_field_name(edge.field_name)
+ list_items[index] = self._build_ast(
+ node=child, update_references=update_references)
+
+ ast_list = []
+ for index in six.moves.range(len(list_items)):
+ ast_list.append(list_items[index])
+ return ast_list
+ elif node.node_type == pb.NodeType.AST_VALUE:
+ return node.ast_value
+ else:
+ raise ValueError('This ProgramGraphNode does not correspond to a node in'
+ ' an AST.')
+
+ def walk_ast_descendants(self, node=None):
+ """Yields the nodes that correspond to the descendants of node in the AST.
+
+ Args:
+ node: the node in the program graph corresponding to the root of the AST
+ subtree that should be walked. If None, defaults to the root of the
+ program graph.
+
+ Yields:
+ All nodes corresponding to descendants of node in the AST.
+ """
+ if node is None:
+ node = self.root
+ frontier = [node]
+ while frontier:
+ current = frontier.pop()
+ for child_id in reversed(self.child_map[current.id]):
+ frontier.append(self.get_node_by_id(child_id))
+ yield current
+
+ def parent(self, node):
+ """Returns the AST parent of an AST program graph node.
+
+ Args:
+ node: A ProgramGraphNode.
+
+ Returns:
+ The node's AST parent, which is also a ProgramGraphNode.
+ """
+ parent_id = self.parent_map[node.id]
+ if parent_id is None:
+ return None
+ else:
+ return self.get_node_by_id(parent_id)
+
+ def children(self, node):
+ """Yields the (direct) AST children of an AST program graph node.
+
+ Args:
+ node: A ProgramGraphNode.
+
+ Yields:
+ The AST children of node, which are ProgramGraphNode objects.
+ """
+ for child_id in self.child_map[node.id]:
+ yield self.get_node_by_id(child_id)
+
+ def neighbors(self, node, edge_type=None):
+ """Returns the incoming and outgoing neighbors of a program graph node.
+
+ Args:
+ node: A ProgramGraphNode.
+ edge_type: If provided, only edges of this type are considered.
+
+ Returns:
+ The incoming and outgoing neighbors of node, which are ProgramGraphNode
+ objects but not necessarily AST nodes.
+ """
+ adj_edges = self.neighbors_map[node.id]
+ if edge_type is None:
+ ids = list(tup[1] for tup in adj_edges)
+ else:
+ ids = list(tup[1] for tup in adj_edges if tup[0].type == edge_type)
+ return [self.get_node_by_id(id0) for id0 in ids]
+
+ def incoming_neighbors(self, node, edge_type=None):
+ """Returns the incoming neighbors of a program graph node.
+
+ Args:
+ node: A ProgramGraphNode.
+ edge_type: If provided, only edges of this type are considered.
+
+ Returns:
+ The incoming neighbors of node, which are ProgramGraphNode objects but not
+ necessarily AST nodes.
+ """
+ adj_edges = self.neighbors_map[node.id]
+ result = []
+ for edge, neighbor_id in adj_edges:
+ if edge.id2 == node.id:
+ if (edge_type is None) or (edge.type == edge_type):
+ result.append(self.get_node_by_id(neighbor_id))
+ return result
+
+ def outgoing_neighbors(self, node, edge_type=None):
+ """Returns the outgoing neighbors of a program graph node.
+
+ Args:
+ node: A ProgramGraphNode.
+ edge_type: If provided, only edges of this type are considered.
+
+ Returns:
+ The outgoing neighbors of node, which are ProgramGraphNode objects but not
+ necessarily AST nodes.
+ """
+ adj_edges = self.neighbors_map[node.id]
+ result = []
+ for edge, neighbor_id in adj_edges:
+ if edge.id1 == node.id:
+ if (edge_type is None) or (edge.type == edge_type):
+ result.append(self.get_node_by_id(neighbor_id))
+ return result
+
+ def dump_tree(self, start_node=None):
+ """Returns a string representation for debugging."""
+
+ def dump_tree_recurse(node, indent, all_lines):
+ """Create a string representation for a subtree."""
+ indent_str = ' ' + ('--' * indent)
+ node_str = dump_node(node)
+ line = ' '.join([indent_str, node_str, '\n'])
+ all_lines.append(line)
+      # Output long-distance edges.
+ for edge, neighbor_id in self.neighbors_map[node.id]:
+ if (not is_ast_edge(edge) and not is_syntax_edge(edge) and
+ node.id == edge.id1):
+ type_str = edge.type.name
+ line = [indent_str, '--((', type_str, '))-->', str(neighbor_id), '\n']
+ all_lines.append(' '.join(line))
+ for child in self.children(node):
+ dump_tree_recurse(child, indent + 1, all_lines)
+ return all_lines
+
+ if start_node is None:
+ start_node = self.root
+ return ''.join(dump_tree_recurse(start_node, 0, []))
+
+ # TODO(charlessutton): Consider whether this belongs in ProgramGraph
+ # or in make_synthesis_problems.
+ def copy_with_placeholder(self, node):
+ """Returns a new program graph in which the subtree of NODE is removed.
+
+ In the new graph, the subtree headed by NODE is replaced by a single
+ node of type PLACEHOLDER, which is connected to the AST parent of NODE
+ by the same edge type as in the original graph.
+
+ The new program graph will share structure (i.e. the ProgramGraphNode
+ objects) with the original graph.
+
+ Args:
+      node: A node in this program graph.
+
+ Returns:
+      A new ProgramGraph object with NODE replaced by a placeholder.
+ """
+ descendant_ids = {n.id for n in self.walk_ast_descendants(node)}
+ new_graph = ProgramGraph()
+ new_graph.add_node(self.root)
+ new_graph.root_id = self.root_id
+ for edge in self.edges:
+ v1 = self.nodes[edge.id1]
+ v2 = self.nodes[edge.id2]
+ # Omit edges that are adjacent to the subtree rooted at `node` UNLESS this
+ # is the AST edge to the root of the subtree.
+ # In that case, create an edge to a new placeholder node
+ adj_bad_subtree = ((edge.id1 in descendant_ids) or
+ (edge.id2 in descendant_ids))
+ if adj_bad_subtree:
+ if edge.id2 == node.id and is_ast_edge(edge):
+ placeholder = ProgramGraphNode()
+ placeholder.node_type = pb.NodeType.PLACEHOLDER
+ placeholder.id = node.id
+ new_graph.add_node(placeholder)
+ new_graph.add_new_edge(v1, placeholder, edge_type=edge.type)
+ else:
+ # nodes on the edge have not been added yet
+ if edge.id1 not in new_graph.nodes:
+ new_graph.add_node(v1)
+ if edge.id2 not in new_graph.nodes:
+ new_graph.add_node(v2)
+ new_graph.add_new_edge(v1, v2, edge_type=edge.type)
+ return new_graph
+
+ def copy_subgraph(self, node):
+ """Returns a new program graph containing only the subtree rooted at NODE.
+
+ All edges that connect nodes in the subtree are included, both AST edges
+ and other types of edges.
+
+ Args:
+      node: A node in this program graph.
+
+ Returns:
+      A new ProgramGraph object whose root is NODE.
+ """
+ descendant_ids = {n.id for n in self.walk_ast_descendants(node)}
+ new_graph = ProgramGraph()
+ new_graph.add_node(node)
+ new_graph.root_id = node.id
+ for edge in self.edges:
+ v1 = self.nodes[edge.id1]
+ v2 = self.nodes[edge.id2]
+      # Keep only edges whose endpoints both lie inside the subtree
+      # rooted at NODE.
+ good_edge = ((edge.id1 in descendant_ids) and
+ (edge.id2 in descendant_ids))
+ if good_edge:
+ if edge.id1 not in new_graph.nodes:
+ new_graph.add_node(v1)
+ if edge.id2 not in new_graph.nodes:
+ new_graph.add_node(v2)
+ new_graph.add_new_edge(v1, v2, edge_type=edge.type)
+ return new_graph
+
+
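The traversal in `walk_ast_descendants` is an explicit-stack preorder walk: children are pushed in reversed order so that popping visits them left to right. A minimal standalone sketch of the same pattern, using a plain dict as a child map with made-up integer ids (not the module's API):

```python
def walk_preorder(child_map, root):
    # Explicit-stack DFS; reversing the children makes the stack pop
    # them left to right, yielding a preorder traversal.
    frontier = [root]
    while frontier:
        current = frontier.pop()
        for child in reversed(child_map.get(current, [])):
            frontier.append(child)
        yield current

child_map = {1: [2, 3], 2: [4]}
print(list(walk_preorder(child_map, 1)))  # [1, 2, 4, 3]
```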
+def is_ast_node(node):
+ return node.node_type == pb.NodeType.AST_NODE
+
+
+def is_ast_edge(edge):
+ # TODO(charlessutton): Expand to enumerate edge types in gast.
+ return edge.type == pb.EdgeType.FIELD
+
+
+def is_syntax_edge(edge):
+ return edge.type == pb.EdgeType.SYNTAX
+
+
+def dump_node(node):
+ type_str = '[' + node.node_type.name + ']'
+ elements = [type_str, str(node.id), node.ast_type]
+ if node.ast_value:
+ elements.append(str(node.ast_value))
+ if node.syntax:
+ elements.append(str(node.syntax))
+ return ' '.join(elements)
+
+
+def get_program_graph(program):
+ """Constructs a program graph to represent the given program."""
+ program_node = program_utils.program_to_ast(program) # An AST node.
+
+ # TODO(dbieber): Refactor sections of graph building into separate functions.
+ program_graph = ProgramGraph()
+
+ # Perform control flow analysis.
+ control_flow_graph = control_flow.get_control_flow_graph(program_node)
+
+ # Add AST_NODE program graph nodes corresponding to Instructions in the
+ # control flow graph.
+ for control_flow_node in control_flow_graph.get_control_flow_nodes():
+ program_graph.add_node_from_instruction(control_flow_node.instruction)
+
+ # Add AST_NODE program graph nodes corresponding to AST nodes.
+ for ast_node in ast.walk(program_node):
+ if not program_graph.contains_ast_node(ast_node):
+ pg_node = make_node_from_ast_node(ast_node)
+ program_graph.add_node(pg_node)
+
+ root = program_graph.get_node_by_ast_node(program_node)
+ program_graph.root_id = root.id
+
+ # Add AST edges (FIELD). Also add AST_LIST and AST_VALUE program graph nodes.
+ for ast_node in ast.walk(program_node):
+ for field_name, value in ast.iter_fields(ast_node):
+ if isinstance(value, list):
+ pg_node = make_node_for_ast_list()
+ program_graph.add_node(pg_node)
+ program_graph.add_new_edge(
+ ast_node, pg_node, pb.EdgeType.FIELD, field_name)
+ for index, item in enumerate(value):
+ list_field_name = make_list_field_name(field_name, index)
+ if isinstance(item, ast.AST):
+ program_graph.add_new_edge(pg_node, item, pb.EdgeType.FIELD,
+ list_field_name)
+ else:
+ item_node = make_node_from_ast_value(item)
+ program_graph.add_node(item_node)
+ program_graph.add_new_edge(pg_node, item_node, pb.EdgeType.FIELD,
+ list_field_name)
+ elif isinstance(value, ast.AST):
+ program_graph.add_new_edge(
+ ast_node, value, pb.EdgeType.FIELD, field_name)
+ else:
+ pg_node = make_node_from_ast_value(value)
+ program_graph.add_node(pg_node)
+ program_graph.add_new_edge(
+ ast_node, pg_node, pb.EdgeType.FIELD, field_name)
+
+ # Add SYNTAX_NODE nodes. Also add NEXT_SYNTAX and LAST_LEXICAL_USE edges.
+ # Add these edges using a custom AST unparser to visit leaf nodes in preorder.
+ SyntaxNodeUnparser(program_node, program_graph)
+
+ # Perform data flow analysis.
+ analysis = data_flow.LastAccessAnalysis()
+ for node in control_flow_graph.get_enter_control_flow_nodes():
+ analysis.visit(node)
+
+ # Add control flow edges (CFG_NEXT).
+ for control_flow_node in control_flow_graph.get_control_flow_nodes():
+ instruction = control_flow_node.instruction
+ for next_control_flow_node in control_flow_node.next:
+ next_instruction = next_control_flow_node.instruction
+ program_graph.add_new_edge(
+ instruction.node, next_instruction.node,
+ edge_type=pb.EdgeType.CFG_NEXT)
+
+ # Add data flow edges (LAST_READ and LAST_WRITE).
+ for control_flow_node in control_flow_graph.get_control_flow_nodes():
+ # Start with the most recent accesses before this instruction.
+ last_accesses = control_flow_node.get_label('last_access_in').copy()
+ for access in control_flow_node.instruction.accesses:
+ # Extract the node and identifiers for the current access.
+ pg_node = program_graph.get_node_by_access(access)
+ access_name = instruction_module.access_name(access)
+ read_identifier = instruction_module.access_identifier(
+ access_name, 'read')
+ write_identifier = instruction_module.access_identifier(
+ access_name, 'write')
+ # Find previous reads.
+ for read in last_accesses.get(read_identifier, []):
+ read_pg_node = program_graph.get_node_by_access(read)
+ program_graph.add_new_edge(
+ pg_node, read_pg_node, edge_type=pb.EdgeType.LAST_READ)
+ # Find previous writes.
+ for write in last_accesses.get(write_identifier, []):
+ write_pg_node = program_graph.get_node_by_access(write)
+ program_graph.add_new_edge(
+ pg_node, write_pg_node, edge_type=pb.EdgeType.LAST_WRITE)
+ # Update the state to refer to this access as the most recent one.
+ if instruction_module.access_is_read(access):
+ last_accesses[read_identifier] = [access]
+ elif instruction_module.access_is_write(access):
+ last_accesses[write_identifier] = [access]
+
+ # Add COMPUTED_FROM edges.
+ for node in ast.walk(program_node):
+ if isinstance(node, ast.Assign):
+ for value_node in ast.walk(node.value):
+ if isinstance(value_node, ast.Name):
+ # TODO(dbieber): If possible, improve precision of these edges.
+ for target in node.targets:
+ program_graph.add_new_edge(
+ target, value_node, edge_type=pb.EdgeType.COMPUTED_FROM)
+
+ # Add CALLS, FORMAL_ARG_NAME and RETURNS_TO edges.
+ for node in ast.walk(program_node):
+ if isinstance(node, ast.Call):
+ if isinstance(node.func, ast.Name):
+ # TODO(dbieber): Use data flow analysis instead of all function defs.
+ func_defs = list(program_graph.get_nodes_by_function_name(node.func.id))
+ # For any possible last writes that are a function definition, add the
+ # formal_arg_name and returns_to edges.
+ if not func_defs:
+ # TODO(dbieber): Add support for additional classes of functions,
+ # such as attributes of known objects and builtins.
+ if node.func.id in dir(builtins):
+ message = 'Function is builtin.'
+ else:
+ message = 'Cannot statically determine the function being called.'
+ logging.debug('%s (%s)', message, node.func.id)
+ for func_def in func_defs:
+ fn_node = func_def.node
+ # Add calls edge from the call node to the function definition.
+ program_graph.add_new_edge(node, fn_node, edge_type=pb.EdgeType.CALLS)
+ # Add returns_to edges from the function's return statements to the
+ # call node.
+ for inner_node in ast.walk(func_def.node):
+ # TODO(dbieber): Determine if the returns_to should instead go to
+ # the next instruction after the Call node instead.
+ if isinstance(inner_node, ast.Return):
+ program_graph.add_new_edge(
+ inner_node, node, edge_type=pb.EdgeType.RETURNS_TO)
+
+ # Add formal_arg_name edges from the args of the Call node to the
+ # args in the FunctionDef.
+ for index, arg in enumerate(node.args):
+ formal_arg = None
+ if index < len(fn_node.args.args):
+ formal_arg = fn_node.args.args[index]
+ elif fn_node.args.vararg:
+ # Since args.vararg is a string, we use the arguments node.
+ # TODO(dbieber): Use a node specifically for the vararg.
+ formal_arg = fn_node.args
+ if formal_arg is not None:
+ # Note: formal_arg can be an AST node or a string.
+ program_graph.add_new_edge(
+ arg, formal_arg, edge_type=pb.EdgeType.FORMAL_ARG_NAME)
+ else:
+ # TODO(dbieber): If formal_arg is None, then remove all
+ # formal_arg_name edges for this FunctionDef.
+ logging.debug('formal_arg is None')
+ for keyword in node.keywords:
+ name = keyword.arg
+ formal_arg = None
+ for arg in fn_node.args.args:
+ if isinstance(arg, ast.Name) and arg.id == name:
+ formal_arg = arg
+ break
+ else:
+ if fn_node.args.kwarg:
+ # Since args.kwarg is a string, we use the arguments node.
+ # TODO(dbieber): Use a node specifically for the kwarg.
+ formal_arg = fn_node.args
+ if formal_arg is not None:
+ program_graph.add_new_edge(
+ keyword.value, formal_arg,
+ edge_type=pb.EdgeType.FORMAL_ARG_NAME)
+ else:
+ # TODO(dbieber): If formal_arg is None, then remove all
+ # formal_arg_name edges for this FunctionDef.
+ logging.debug('formal_arg is None')
+ else:
+ # TODO(dbieber): Add a special case for Attributes.
+ logging.debug(
+ 'Cannot statically determine the function being called. (%s)',
+ astunparse.unparse(node.func).strip())
+
+ return program_graph
+
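The COMPUTED_FROM pass above pairs each assignment target with every `Name` read on the right-hand side. A minimal sketch of that extraction using the stdlib `ast` module (the module itself uses `gast`, whose `Assign` nodes have the same shape; this sketch restricts targets to `Name` nodes for simplicity):

```python
import ast

def computed_from_pairs(source):
    # For each Assign, pair each Name target with each Name read on the
    # right-hand side, mirroring the COMPUTED_FROM edges added above.
    pairs = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Assign):
            for value_node in ast.walk(node.value):
                if isinstance(value_node, ast.Name):
                    for target in node.targets:
                        if isinstance(target, ast.Name):
                            pairs.append((target.id, value_node.id))
    return pairs

print(computed_from_pairs('z = x + y'))  # [('z', 'x'), ('z', 'y')]
```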
+
+class SyntaxNodeUnparser(unparser.Unparser):
+  """An Unparser subclass used to create syntax token nodes for program graphs."""
+
+ def __init__(self, ast_node, graph):
+ self.graph = graph
+
+ self.current_ast_node = None # The AST node currently being unparsed.
+ self.last_syntax_node = None
+ self.last_lexical_uses = {}
+ self.last_indent = 0
+
+ with codecs.open(os.devnull, 'w', encoding='utf-8') as devnull:
+ super(SyntaxNodeUnparser, self).__init__(ast_node, file=devnull)
+
+ def dispatch(self, ast_node):
+ """Dispatcher function, dispatching tree type T to method _T."""
+ tmp_ast_node = self.current_ast_node
+ self.current_ast_node = ast_node
+ super(SyntaxNodeUnparser, self).dispatch(ast_node)
+ self.current_ast_node = tmp_ast_node
+
+ def fill(self, text=''):
+ """Indent a piece of text, according to the current indentation level."""
+ text_with_whitespace = NEWLINE_TOKEN
+ if self.last_indent > self._indent:
+ text_with_whitespace += UNINDENT_TOKEN * (self.last_indent - self._indent)
+ elif self.last_indent < self._indent:
+ text_with_whitespace += INDENT_TOKEN * (self._indent - self.last_indent)
+ self.last_indent = self._indent
+ text_with_whitespace += text
+ self._add_syntax_node(text_with_whitespace)
+ super(SyntaxNodeUnparser, self).fill(text)
+
+ def write(self, text):
+ """Append a piece of text to the current line."""
+ if isinstance(text, ast.AST): # text may be a Name, Tuple, or List node.
+ return self.dispatch(text)
+ self._add_syntax_node(text)
+ super(SyntaxNodeUnparser, self).write(text)
+
+ def _add_syntax_node(self, text):
+ text = text.strip()
+ if not text:
+ return
+ syntax_node = make_node_from_syntax(six.text_type(text))
+ self.graph.add_node(syntax_node)
+ self.graph.add_new_edge(
+ self.current_ast_node, syntax_node, edge_type=pb.EdgeType.SYNTAX)
+ if self.last_syntax_node:
+ self.graph.add_new_edge(
+ self.last_syntax_node, syntax_node, edge_type=pb.EdgeType.NEXT_SYNTAX)
+ self.last_syntax_node = syntax_node
+
+ def _Name(self, node):
+ if node.id in self.last_lexical_uses:
+ self.graph.add_new_edge(
+ node,
+ self.last_lexical_uses[node.id],
+ edge_type=pb.EdgeType.LAST_LEXICAL_USE)
+ self.last_lexical_uses[node.id] = node
+ super(SyntaxNodeUnparser, self)._Name(node)
+
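The `fill` method above tracks the indentation change between consecutive lines and emits one indent or unindent token per level of change. A standalone sketch of that logic; the token spellings below are placeholders, since the real NEWLINE_TOKEN, INDENT_TOKEN, and UNINDENT_TOKEN constants are defined elsewhere in the module and are not shown in this diff:

```python
# Placeholder spellings for illustration only.
NEWLINE_TOKEN = '#NEWLINE#'
INDENT_TOKEN = '#INDENT#'
UNINDENT_TOKEN = '#UNINDENT#'

def leading_tokens(last_indent, indent):
    # Mirrors SyntaxNodeUnparser.fill: one newline token, then one
    # indent/unindent token per level changed since the previous line.
    text = NEWLINE_TOKEN
    if last_indent > indent:
        text += UNINDENT_TOKEN * (last_indent - indent)
    elif last_indent < indent:
        text += INDENT_TOKEN * (indent - last_indent)
    return text

print(leading_tokens(0, 2))  # #NEWLINE##INDENT##INDENT#
```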
+
+class ProgramGraphNode(object):
+ """A single node in a Program Graph.
+
+ Corresponds to either a SyntaxNode or an Instruction (as in a
+ ControlFlowGraph).
+
+ Attributes:
+ node_type: One of the node types from pb.NodeType.
+ id: A unique id for the node.
+ instruction: If applicable, the corresponding Instruction.
+ ast_node: If available, the AST node corresponding to the ProgramGraphNode.
+ ast_type: If available, the type of the AST node, as a string.
+ ast_value: If available, the primitive Python value corresponding to the
+ node.
+ syntax: For SYNTAX_NODEs, the syntax information stored in the node.
+ node: If available, the AST node for this program graph node or its
+ instruction.
+ """
+
+ def __init__(self):
+ self.node_type = None
+ self.id = None
+
+ self.instruction = None
+ self.ast_node = None
+ self.ast_type = ''
+ self.ast_value = ''
+ self.syntax = ''
+
+ def has_instruction(self):
+ return self.instruction is not None
+
+ def has_instance_of(self, t):
+ """Whether the node's instruction is an instance of type `t`."""
+ if self.instruction is None:
+ return False
+ return isinstance(self.instruction.node, t)
+
+ @property
+ def node(self):
+ if self.ast_node is not None:
+ return self.ast_node
+ if self.instruction is None:
+ return None
+ return self.instruction.node
+
+ def __repr__(self):
+ return str(self.id) + ' ' + str(self.ast_type)
+
+
+def make_node_from_syntax(text):
+ node = ProgramGraphNode()
+ node.node_type = pb.NodeType.SYNTAX_NODE
+ node.id = program_utils.unique_id()
+ node.syntax = text
+ return node
+
+
+def make_node_from_instruction(instruction):
+ """Creates a ProgramGraphNode corresponding to an existing Instruction.
+
+ Args:
+ instruction: An Instruction object.
+
+ Returns:
+ A ProgramGraphNode corresponding to that instruction.
+ """
+ ast_node = instruction.node
+ node = make_node_from_ast_node(ast_node)
+ node.instruction = instruction
+ return node
+
+
+def make_node_from_ast_node(ast_node):
+ """Creates a program graph node for the provided AST node.
+
+ This is only called when the AST node doesn't already correspond to an
+ Instruction in the program's control flow graph.
+
+ Args:
+ ast_node: An AST node from the program being analyzed.
+
+ Returns:
+ A node in the program graph corresponding to the AST node.
+ """
+ node = ProgramGraphNode()
+ node.node_type = pb.NodeType.AST_NODE
+ node.id = program_utils.unique_id()
+ node.ast_node = ast_node
+ node.ast_type = type(ast_node).__name__
+ return node
+
+
+def make_node_for_ast_list():
+ node = ProgramGraphNode()
+ node.node_type = pb.NodeType.AST_LIST
+ node.id = program_utils.unique_id()
+ return node
+
+
+def make_node_from_ast_value(value):
+ """Creates a ProgramGraphNode for the provided value.
+
+ `value` is a primitive value appearing in a Python AST.
+
+  For example, the number 1 in Python has the AST node Num(n=1). Here, the
+  value 1 is a primitive appearing in the AST, and it gets its own
+  ProgramGraphNode with node_type AST_VALUE.
+
+ Args:
+ value: A primitive value appearing in an AST.
+
+ Returns:
+ A ProgramGraphNode corresponding to the provided value.
+ """
+ node = ProgramGraphNode()
+ node.node_type = pb.NodeType.AST_VALUE
+ node.id = program_utils.unique_id()
+ node.ast_value = value
+ return node
+
+
+def make_list_field_name(field_name, index):
+ return '{}:{}'.format(field_name, index)
+
+
+def parse_list_field_name(list_field_name):
+ field_name, index = list_field_name.split(':')
+ index = int(index)
+ return field_name, index
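The two helpers above form a round trip: a list element's position is encoded into the FIELD edge's name when the graph is built and recovered when the AST is rebuilt. A standalone copy for illustration:

```python
def make_list_field_name(field_name, index):
    # Encode a list element's position into the edge's field name.
    return '{}:{}'.format(field_name, index)

def parse_list_field_name(list_field_name):
    # Invert make_list_field_name.
    field_name, index = list_field_name.split(':')
    return field_name, int(index)

name = make_list_field_name('body', 2)
print(name)                         # body:2
print(parse_list_field_name(name))  # ('body', 2)
```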
diff --git a/python_graphs/program_graph_dataclasses.py b/python_graphs/program_graph_dataclasses.py
new file mode 100644
index 0000000..750ed6a
--- /dev/null
+++ b/python_graphs/program_graph_dataclasses.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The dataclasses for representing a Program Graph."""
+
+import enum
+from typing import List, Optional, Text
+import dataclasses
+
+
+class NodeType(enum.Enum):
+ UNSPECIFIED = 0
+ AST_NODE = 1
+ AST_LIST = 2
+ AST_VALUE = 3
+ SYNTAX_NODE = 4
+ PLACEHOLDER = 5
+
+
+@dataclasses.dataclass
+class Node:
+ """Represents a node in a program graph."""
+ id: int
+ type: NodeType
+
+ # If an AST node, a string that identifies what type of AST node,
+ # e.g. "Num" or "Expr". These are defined by the underlying AST for the
+ # language.
+ ast_type: Optional[Text] = ""
+
+  # For primitive-valued AST nodes, the repr of the value, such as:
+  # - the name of an identifier for a Name node
+  # - the number attached to a Num node
+  # For such nodes, ast_type holds the Python type of the value, not the
+  # type of the parent AST node.
+ ast_value_repr: Optional[Text] = ""
+
+ # For syntax nodes, the syntax attached to the node.
+ syntax: Optional[Text] = ""
+
+
+class EdgeType(enum.Enum):
+ """The different kinds of edges that can appear in a program graph."""
+ UNSPECIFIED = 0
+ CFG_NEXT = 1
+ LAST_READ = 2
+ LAST_WRITE = 3
+ COMPUTED_FROM = 4
+ RETURNS_TO = 5
+ FORMAL_ARG_NAME = 6
+ FIELD = 7
+ SYNTAX = 8
+ NEXT_SYNTAX = 9
+ LAST_LEXICAL_USE = 10
+ CALLS = 11
+
+
+@dataclasses.dataclass
+class Edge:
+ id1: int
+ id2: int
+ type: EdgeType
+ field_name: Optional[Text] = None # For FIELD edges, the field name.
+ has_back_edge: bool = False
+
+
+@dataclasses.dataclass
+class Graph:
+ nodes: List[Node]
+ edges: List[Edge]
+ root_id: int
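Because these are plain dataclasses, graphs can be assembled by hand, and edges compare field-by-field, which is what equality-based checks like `assertIn(edge, graph.edges)` rely on. A minimal sketch using a trimmed copy of the definitions (only the fields needed here; the real module is `program_graph_dataclasses`):

```python
import dataclasses
import enum
from typing import List, Optional

class EdgeType(enum.Enum):
    UNSPECIFIED = 0
    FIELD = 7

@dataclasses.dataclass
class Node:
    id: int
    ast_type: Optional[str] = ''

@dataclasses.dataclass
class Edge:
    id1: int
    id2: int
    type: EdgeType
    field_name: Optional[str] = None

@dataclasses.dataclass
class Graph:
    nodes: List[Node]
    edges: List[Edge]
    root_id: int

graph = Graph(
    nodes=[Node(id=1, ast_type='Module'), Node(id=2, ast_type='Expr')],
    edges=[Edge(id1=1, id2=2, type=EdgeType.FIELD, field_name='body')],
    root_id=1)
# Dataclass equality is field-by-field, so edges compare by value.
print(Edge(1, 2, EdgeType.FIELD, 'body') in graph.edges)  # True
```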
diff --git a/python_graphs/program_graph_graphviz.py b/python_graphs/program_graph_graphviz.py
new file mode 100644
index 0000000..02e8e4a
--- /dev/null
+++ b/python_graphs/program_graph_graphviz.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Graphviz visualizations of Program Graphs."""
+
+from absl import logging # pylint: disable=unused-import
+import pygraphviz
+from python_graphs import program_graph_dataclasses as pb
+import six
+
+
+def to_graphviz(graph):
+ """Creates a graphviz representation of a ProgramGraph.
+
+ Args:
+    graph: A ProgramGraph object to visualize.
+
+  Returns:
+ A pygraphviz object representing the ProgramGraph.
+ """
+ g = pygraphviz.AGraph(strict=False, directed=True)
+ for unused_key, node in graph.nodes.items():
+ node_attrs = {}
+ if node.ast_type:
+ node_attrs['label'] = six.ensure_binary(node.ast_type, 'utf-8')
+ else:
+ node_attrs['shape'] = 'point'
+    node_type_colors = {}
+ if node.node_type in node_type_colors:
+ node_attrs['color'] = node_type_colors[node.node_type]
+ node_attrs['colorscheme'] = 'svg'
+
+ g.add_node(node.id, **node_attrs)
+ for edge in graph.edges:
+ edge_attrs = {}
+ edge_attrs['label'] = edge.type.name
+ edge_colors = {
+ pb.EdgeType.LAST_READ: 'red',
+ pb.EdgeType.LAST_WRITE: 'red',
+ }
+ if edge.type in edge_colors:
+ edge_attrs['color'] = edge_colors[edge.type]
+ edge_attrs['colorscheme'] = 'svg'
+ g.add_edge(edge.id1, edge.id2, **edge_attrs)
+ return g
+
+
+def render(graph, path='/tmp/graph.png'):
+ g = to_graphviz(graph)
+ g.draw(path, prog='dot')
diff --git a/python_graphs/program_graph_graphviz_test.py b/python_graphs/program_graph_graphviz_test.py
new file mode 100644
index 0000000..56ece83
--- /dev/null
+++ b/python_graphs/program_graph_graphviz_test.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for program_graph_graphviz.py."""
+
+import inspect
+
+from absl.testing import absltest
+from python_graphs import control_flow_test_components as tc
+from python_graphs import program_graph
+from python_graphs import program_graph_graphviz
+
+
+class ControlFlowGraphvizTest(absltest.TestCase):
+
+ def test_to_graphviz_for_all_test_components(self):
+ for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction):
+ graph = program_graph.get_program_graph(fn)
+ program_graph_graphviz.to_graphviz(graph)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/program_graph_test.py b/python_graphs/program_graph_test.py
new file mode 100644
index 0000000..6fb4e61
--- /dev/null
+++ b/python_graphs/program_graph_test.py
@@ -0,0 +1,292 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for program_graph.py."""
+
+import collections
+import inspect
+import time
+
+from absl import logging
+from absl.testing import absltest
+import gast as ast
+
+from python_graphs import control_flow_test_components as cftc
+from python_graphs import program_graph
+from python_graphs import program_graph_dataclasses as pb
+from python_graphs import program_graph_test_components as pgtc
+from python_graphs import program_utils
+
+
+def test_components():
+ """Generates functions from two sets of test components.
+
+ Yields:
+ Functions from the program graph and control flow test components files.
+ """
+ for unused_name, fn in inspect.getmembers(pgtc, predicate=inspect.isfunction):
+ yield fn
+
+ for unused_name, fn in inspect.getmembers(cftc, predicate=inspect.isfunction):
+ yield fn
+
+
+class ProgramGraphTest(absltest.TestCase):
+
+ def assertEdge(self, graph, n1, n2, edge_type):
+ """Asserts that an edge of type edge_type exists from n1 to n2 in graph."""
+ edge = pb.Edge(id1=n1.id, id2=n2.id, type=edge_type)
+ self.assertIn(edge, graph.edges)
+
+ def assertNoEdge(self, graph, n1, n2, edge_type):
+ """Asserts that no edge of type edge_type exists from n1 to n2 in graph."""
+ edge = pb.Edge(id1=n1.id, id2=n2.id, type=edge_type)
+ self.assertNotIn(edge, graph.edges)
+
+ def test_get_program_graph_test_components(self):
+ self.analyze_get_program_graph(test_components(), start=0)
+
+ def analyze_get_program_graph(self, program_generator, start=0):
+ # TODO(dbieber): Remove the counting and logging logic from this method,
+ # and instead just get_program_graph for each program in the generator.
+ # The counting and logging logic is for development purposes only.
+ num_edges = 0
+ num_edges_by_type = collections.defaultdict(int)
+ num_nodes = 0
+    num_graphs = 0
+ times = {}
+ for index, program in enumerate(program_generator):
+ if index < start:
+ continue
+ start_time = time.time()
+ graph = program_graph.get_program_graph(program)
+ end_time = time.time()
+ times[index] = end_time - start_time
+ num_edges += len(graph.edges)
+ for edge in graph.edges:
+ num_edges_by_type[edge.type] += 1
+ num_nodes += len(graph.nodes)
+ num_graphs += 1
+ if index % 100 == 0:
+ logging.debug(sorted(times.items(), key=lambda kv: -kv[1])[:10])
+ logging.info('%d %d %d', num_edges, num_nodes, num_graphs)
+ logging.info('%f %f', num_edges / num_graphs, num_nodes / num_graphs)
+ for edge_type in num_edges_by_type:
+ logging.info('%s %f', edge_type,
+ num_edges_by_type[edge_type] / num_graphs)
+
+ logging.info(times)
+ logging.info(sorted(times.items(), key=lambda kv: -kv[1])[:10])
+
+ def test_last_lexical_use_edges_function_call(self):
+ graph = program_graph.get_program_graph(pgtc.function_call)
+ read = graph.get_node_by_source_and_identifier('return z', 'z')
+ write = graph.get_node_by_source_and_identifier(
+ 'z = function_call_helper(x, y)', 'z')
+ self.assertEdge(graph, read, write, pb.EdgeType.LAST_LEXICAL_USE)
+
+ def test_last_write_edges_function_call(self):
+ graph = program_graph.get_program_graph(pgtc.function_call)
+ write_z = graph.get_node_by_source_and_identifier(
+ 'z = function_call_helper(x, y)', 'z')
+ read_z = graph.get_node_by_source_and_identifier('return z', 'z')
+ self.assertEdge(graph, read_z, write_z, pb.EdgeType.LAST_WRITE)
+
+ write_y = graph.get_node_by_source_and_identifier('y = 2', 'y')
+ read_y = graph.get_node_by_source_and_identifier(
+ 'z = function_call_helper(x, y)', 'y')
+ self.assertEdge(graph, read_y, write_y, pb.EdgeType.LAST_WRITE)
+
+ def test_last_read_edges_assignments(self):
+ graph = program_graph.get_program_graph(pgtc.assignments)
+ write_a0 = graph.get_node_by_source_and_identifier('a, b = 0, 0', 'a')
+ read_a0 = graph.get_node_by_source_and_identifier('c = 2 * a + 1', 'a')
+ write_a1 = graph.get_node_by_source_and_identifier('a = c + 3', 'a')
+ self.assertEdge(graph, write_a1, read_a0, pb.EdgeType.LAST_READ)
+ self.assertNoEdge(graph, write_a0, read_a0, pb.EdgeType.LAST_READ)
+
+ read_a1 = graph.get_node_by_source_and_identifier('return a, b, c, d', 'a')
+ self.assertEdge(graph, read_a1, read_a0, pb.EdgeType.LAST_READ)
+
+ def test_last_read_last_write_edges_repeated_identifier(self):
+ graph = program_graph.get_program_graph(pgtc.repeated_identifier)
+ write_x0 = graph.get_node_by_source_and_identifier('x = 0', 'x')
+
+ stmt1 = graph.get_node_by_source('x = x + 1').ast_node
+ read_x0 = graph.get_node_by_ast_node(stmt1.value.left)
+ write_x1 = graph.get_node_by_ast_node(stmt1.targets[0])
+
+ stmt2 = graph.get_node_by_source('x = (x + (x + x)) + x').ast_node
+ read_x1 = graph.get_node_by_ast_node(stmt2.value.left.left)
+ read_x2 = graph.get_node_by_ast_node(stmt2.value.left.right.left)
+ read_x3 = graph.get_node_by_ast_node(stmt2.value.left.right.right)
+ read_x4 = graph.get_node_by_ast_node(stmt2.value.right)
+ write_x2 = graph.get_node_by_ast_node(stmt2.targets[0])
+
+ read_x5 = graph.get_node_by_source_and_identifier('return x', 'x')
+
+ self.assertEdge(graph, write_x1, read_x0, pb.EdgeType.LAST_READ)
+ self.assertEdge(graph, read_x1, read_x0, pb.EdgeType.LAST_READ)
+ self.assertEdge(graph, read_x2, read_x1, pb.EdgeType.LAST_READ)
+ self.assertEdge(graph, read_x3, read_x2, pb.EdgeType.LAST_READ)
+ self.assertEdge(graph, read_x4, read_x3, pb.EdgeType.LAST_READ)
+ self.assertEdge(graph, write_x2, read_x4, pb.EdgeType.LAST_READ)
+ self.assertEdge(graph, read_x5, read_x4, pb.EdgeType.LAST_READ)
+
+ self.assertEdge(graph, read_x0, write_x0, pb.EdgeType.LAST_WRITE)
+ self.assertEdge(graph, write_x1, write_x0, pb.EdgeType.LAST_WRITE)
+ self.assertEdge(graph, read_x2, write_x1, pb.EdgeType.LAST_WRITE)
+ self.assertEdge(graph, read_x3, write_x1, pb.EdgeType.LAST_WRITE)
+ self.assertEdge(graph, read_x4, write_x1, pb.EdgeType.LAST_WRITE)
+ self.assertEdge(graph, write_x2, write_x1, pb.EdgeType.LAST_WRITE)
+ self.assertEdge(graph, read_x5, write_x2, pb.EdgeType.LAST_WRITE)
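The LAST_READ and LAST_WRITE chains asserted above can be reproduced in miniature with a straight-line-only sketch built on the standard-library `ast` module. This is an illustration of the edge semantics only, not the library's algorithm: `python_graphs` computes these edges over the control-flow graph, so it also handles branches and loops, which this toy ignores.

```python
import ast


def name_events(tree):
    """Yield ('read' | 'write', identifier) in evaluation order.

    For an assignment the right-hand side is evaluated before the
    target is bound, so the value's names are visited first.
    """
    if isinstance(tree, ast.Assign):
        yield from name_events(tree.value)
        for target in tree.targets:
            yield from name_events(target)
    elif isinstance(tree, ast.Name):
        yield ('write' if isinstance(tree.ctx, ast.Store) else 'read', tree.id)
    else:
        for child in ast.iter_child_nodes(tree):
            yield from name_events(child)


def dataflow_edges(source):
    """Link each identifier occurrence to its most recent write and read."""
    last = {'write': {}, 'read': {}}
    edges = []
    for i, (kind, name) in enumerate(name_events(ast.parse(source))):
        for edge_type in ('write', 'read'):
            if name in last[edge_type]:
                edges.append((i, last[edge_type][name], 'LAST_' + edge_type.upper()))
        last[kind][name] = i  # Record edges before updating the event tables.
    return edges


edges = dataflow_edges('x = 0\nx = x + 1\ny = x')
# Event order: write x (0), read x (1), write x (2), read x (3), write y (4).
print((1, 0, 'LAST_WRITE') in edges)  # the read in `x + 1` sees the write in `x = 0`
print((2, 1, 'LAST_READ') in edges)   # the second write sees the preceding read
```

As in the test above, a write participates in both edge types: it links back to the previous write of the same name (LAST_WRITE) and to the last read of it (LAST_READ).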
+
+ def test_computed_from_edges(self):
+ graph = program_graph.get_program_graph(pgtc.assignments)
+ target_c = graph.get_node_by_source_and_identifier('c = 2 * a + 1', 'c')
+ from_a = graph.get_node_by_source_and_identifier('c = 2 * a + 1', 'a')
+ self.assertEdge(graph, target_c, from_a, pb.EdgeType.COMPUTED_FROM)
+
+ target_d = graph.get_node_by_source_and_identifier('d = b - c + 2', 'd')
+ from_b = graph.get_node_by_source_and_identifier('d = b - c + 2', 'b')
+ from_c = graph.get_node_by_source_and_identifier('d = b - c + 2', 'c')
+ self.assertEdge(graph, target_d, from_b, pb.EdgeType.COMPUTED_FROM)
+ self.assertEdge(graph, target_d, from_c, pb.EdgeType.COMPUTED_FROM)
+
+ def test_calls_edges(self):
+ graph = program_graph.get_program_graph(pgtc)
+ call = graph.get_node_by_source('function_call_helper(x, y)')
+ self.assertIsInstance(call.node, ast.Call)
+ function_call_helper_def = graph.get_node_by_function_name(
+ 'function_call_helper')
+ assignments_def = graph.get_node_by_function_name('assignments')
+ self.assertEdge(graph, call, function_call_helper_def, pb.EdgeType.CALLS)
+ self.assertNoEdge(graph, call, assignments_def, pb.EdgeType.CALLS)
+
+ def test_formal_arg_name_edges(self):
+ graph = program_graph.get_program_graph(pgtc)
+ x = graph.get_node_by_source_and_identifier('function_call_helper(x, y)',
+ 'x')
+ y = graph.get_node_by_source_and_identifier('function_call_helper(x, y)',
+ 'y')
+ function_call_helper_def = graph.get_node_by_function_name(
+ 'function_call_helper')
+ arg0_ast_node = function_call_helper_def.node.args.args[0]
+ arg0 = graph.get_node_by_ast_node(arg0_ast_node)
+ arg1_ast_node = function_call_helper_def.node.args.args[1]
+ arg1 = graph.get_node_by_ast_node(arg1_ast_node)
+ self.assertEdge(graph, x, arg0, pb.EdgeType.FORMAL_ARG_NAME)
+ self.assertEdge(graph, y, arg1, pb.EdgeType.FORMAL_ARG_NAME)
+ self.assertNoEdge(graph, x, arg1, pb.EdgeType.FORMAL_ARG_NAME)
+ self.assertNoEdge(graph, y, arg0, pb.EdgeType.FORMAL_ARG_NAME)
+
+ def test_returns_to_edges(self):
+ graph = program_graph.get_program_graph(pgtc)
+ call = graph.get_node_by_source('function_call_helper(x, y)')
+ return_stmt = graph.get_node_by_source('return arg0 + arg1')
+ self.assertEdge(graph, return_stmt, call, pb.EdgeType.RETURNS_TO)
+
+ def test_syntax_information(self):
+ # TODO(dbieber): Test that program graphs correctly capture syntax
+ # information. Do this once representation of syntax in program graphs
+ # stabilizes.
+ pass
+
+ def test_ast_acyclic(self):
+ for name, fn in inspect.getmembers(cftc, predicate=inspect.isfunction):
+ graph = program_graph.get_program_graph(fn)
+ ast_nodes = set()
+ worklist = [graph.root]
+ while worklist:
+ current = worklist.pop()
+ self.assertNotIn(
+ current, ast_nodes,
+ 'ProgramGraph AST cyclic. Function {}\nAST {}'.format(
+ name, graph.dump_tree()))
+ ast_nodes.add(current)
+ worklist.extend(graph.children(current))
+
+ def test_neighbors_children_consistent(self):
+ for unused_name, fn in inspect.getmembers(
+ cftc, predicate=inspect.isfunction):
+ graph = program_graph.get_program_graph(fn)
+ for node in graph.all_nodes():
+ if node.node_type == pb.NodeType.AST_NODE:
+ children0 = set(graph.outgoing_neighbors(node, pb.EdgeType.FIELD))
+ children1 = set(graph.children(node))
+ self.assertEqual(children0, children1)
+
+ def test_walk_ast_descendants(self):
+ for unused_name, fn in inspect.getmembers(
+ cftc, predicate=inspect.isfunction):
+ graph = program_graph.get_program_graph(fn)
+ for node in graph.walk_ast_descendants():
+ self.assertIn(node, graph.all_nodes())
+
+ def test_roundtrip_ast(self):
+ for unused_name, fn in inspect.getmembers(
+ cftc, predicate=inspect.isfunction):
+ ast_representation = program_utils.program_to_ast(fn)
+ graph = program_graph.get_program_graph(fn)
+ ast_reproduction = graph.to_ast()
+ self.assertEqual(ast.dump(ast_representation), ast.dump(ast_reproduction))
+
+ def test_reconstruct_missing_ast(self):
+ for unused_name, fn in inspect.getmembers(
+ cftc, predicate=inspect.isfunction):
+ graph = program_graph.get_program_graph(fn)
+ ast_original = graph.root.ast_node
+ # Remove the AST.
+ for node in graph.all_nodes():
+ node.ast_node = None
+ # Reconstruct it.
+ graph.reconstruct_ast()
+ ast_reproduction = graph.root.ast_node
+ # Check reconstruction.
+ self.assertEqual(ast.dump(ast_original), ast.dump(ast_reproduction))
+ # Check that all AST_NODE nodes are set.
+ for node in graph.all_nodes():
+ if node.node_type == pb.NodeType.AST_NODE:
+ self.assertIsInstance(node.ast_node, ast.AST)
+ self.assertIs(graph.get_node_by_ast_node(node.ast_node), node)
+ # Check that old AST nodes are no longer referenced.
+ self.assertFalse(graph.contains_ast_node(ast_original))
+
+ def test_remove(self):
+ graph = program_graph.get_program_graph(pgtc.assignments)
+
+    for edge in list(graph.edges):
+ # Remove the edge.
+ graph.remove_edge(edge)
+ self.assertNotIn(edge, graph.edges)
+ self.assertNotIn((edge, edge.id2), graph.neighbors_map[edge.id1])
+ self.assertNotIn((edge, edge.id1), graph.neighbors_map[edge.id2])
+
+ if edge.type == pb.EdgeType.FIELD:
+ self.assertNotIn(edge.id2, graph.child_map[edge.id1])
+ self.assertNotIn(edge.id2, graph.parent_map)
+
+ # Add the edge again.
+ graph.add_edge(edge)
+ self.assertIn(edge, graph.edges)
+ self.assertIn((edge, edge.id2), graph.neighbors_map[edge.id1])
+ self.assertIn((edge, edge.id1), graph.neighbors_map[edge.id2])
+
+ if edge.type == pb.EdgeType.FIELD:
+ self.assertIn(edge.id2, graph.child_map[edge.id1])
+ self.assertIn(edge.id2, graph.parent_map)
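`test_remove` exercises the invariant that `edges` and `neighbors_map` stay mutually consistent across removal and re-insertion. A toy container showing just that bookkeeping — the attribute names mirror the test, but this is not the library's `ProgramGraph` class:

```python
import collections


class TinyGraph:
    """Edge list plus a node-id -> [(edge, other_id), ...] adjacency map."""

    def __init__(self):
        self.edges = []
        self.neighbors_map = collections.defaultdict(list)

    def add_edge(self, edge):
        id1, id2, _ = edge
        self.edges.append(edge)
        # Each undirected lookup direction gets its own adjacency entry.
        self.neighbors_map[id1].append((edge, id2))
        self.neighbors_map[id2].append((edge, id1))

    def remove_edge(self, edge):
        id1, id2, _ = edge
        self.edges.remove(edge)
        # Removal must undo both adjacency entries, or a later query
        # through neighbors_map would resurrect the deleted edge.
        self.neighbors_map[id1].remove((edge, id2))
        self.neighbors_map[id2].remove((edge, id1))


graph = TinyGraph()
edge = (1, 2, 'FIELD')
graph.add_edge(edge)
graph.remove_edge(edge)
print(edge in graph.edges)                  # False
print((edge, 2) in graph.neighbors_map[1])  # False
graph.add_edge(edge)                        # re-adding restores both views
print(edge in graph.edges)                  # True
```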
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python_graphs/program_graph_test_components.py b/python_graphs/program_graph_test_components.py
new file mode 100644
index 0000000..a396da4
--- /dev/null
+++ b/python_graphs/program_graph_test_components.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test components for testing program graphs."""
+
+
+# pylint: disable=missing-docstring
+# pylint: disable=pointless-statement,undefined-variable
+# pylint: disable=unused-variable,unused-argument
+# pylint: disable=bare-except,lost-exception,unreachable
+# pylint: disable=global-variable-undefined
+def function_call():
+ x = 1
+ y = 2
+ z = function_call_helper(x, y)
+ return z
+
+
+def function_call_helper(arg0, arg1):
+ return arg0 + arg1
+
+
+def assignments():
+ a, b = 0, 0
+ c = 2 * a + 1
+ d = b - c + 2
+ a = c + 3
+ return a, b, c, d
+
+
+def fn_with_globals():
+ global global_a, global_b, global_c
+ global_a = 10
+ global_b = 20
+ global_c = 30
+ return global_a + global_b + global_c
+
+
+def fn_with_inner_fn():
+
+ def inner_fn():
+ while True:
+ pass
+
+
+def repeated_identifier():
+ x = 0
+ x = x + 1
+ x = (x + (x + x)) + x
+ return x
diff --git a/python_graphs/program_graph_visualizer.py b/python_graphs/program_graph_visualizer.py
new file mode 100644
index 0000000..7af130b
--- /dev/null
+++ b/python_graphs/program_graph_visualizer.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Create program graph visualizations for the test components.
+
+Usage:
+python -m python_graphs.program_graph_visualizer
+"""
+
+import inspect
+
+from absl import app
+from absl import logging # pylint: disable=unused-import
+
+from python_graphs import control_flow_test_components as tc
+from python_graphs import program_graph
+from python_graphs import program_graph_graphviz
+
+
+def render_functions(functions):
+ for name, function in functions:
+ logging.info(name)
+ graph = program_graph.get_program_graph(function)
+ path = '/tmp/program_graphs/{}.png'.format(name)
+ program_graph_graphviz.render(graph, path=path)
+
+
+def main(argv):
+ del argv # Unused.
+
+  functions = inspect.getmembers(tc, predicate=inspect.isfunction)
+ render_functions(functions)
+
+
+if __name__ == '__main__':
+ app.run(main)
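The `render` call above depends on `pygraphviz`. When only the graph structure is needed, a plain DOT string can be emitted with the standard library alone and rendered by hand with `dot`. The sketch below draws just the AST skeleton; it is a stripped-down stand-in, and the real `program_graph_graphviz.render` output (not shown here) is assumed to carry richer information such as edge types.

```python
import ast


def ast_to_dot(source):
    """Emit a Graphviz DOT description of a program's AST.

    Only syntax (parent -> child) edges are drawn; there are no
    data-flow edges in this sketch.
    """
    tree = ast.parse(source)
    lines = ['digraph ast {']
    ids = {}

    def visit(node):
        ids[node] = len(ids)
        lines.append('  n{} [label="{}"];'.format(ids[node], type(node).__name__))
        for child in ast.iter_child_nodes(node):
            visit(child)
            lines.append('  n{} -> n{};'.format(ids[node], ids[child]))

    visit(tree)
    lines.append('}')
    return '\n'.join(lines)


dot = ast_to_dot('x = 1')
print(dot.splitlines()[0])  # digraph ast {
# Save to graph.dot, then render with: dot -Tpng graph.dot -o graph.png
```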
diff --git a/python_graphs/program_utils.py b/python_graphs/program_utils.py
new file mode 100644
index 0000000..74117e6
--- /dev/null
+++ b/python_graphs/program_utils.py
@@ -0,0 +1,62 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Program utility functions."""
+
+import inspect
+import textwrap
+import uuid
+
+import gast as ast
+import six
+
+
+def getsource(obj):
+ """Gets the source for the given object.
+
+ Args:
+ obj: A module, class, method, function, traceback, frame, or code object.
+ Returns:
+ The source of the object, if available.
+ """
+ if inspect.ismethod(obj):
+ func = obj.__func__
+ else:
+ func = obj
+ source = inspect.getsource(func)
+ return textwrap.dedent(source)
+
+
+def program_to_ast(program):
+ """Convert a program to its AST.
+
+ Args:
+ program: Either an AST node, source string, or a function.
+ Returns:
+ The root AST node of the AST representing the program.
+ """
+ if isinstance(program, ast.AST):
+ return program
+ if isinstance(program, six.string_types):
+ source = program
+ else:
+ source = getsource(program)
+ module_node = ast.parse(source, mode='exec')
+ return module_node
+
+
+def unique_id():
+ """Returns a unique id that is suitable for identifying graph nodes."""
+ return uuid.uuid4().int & ((1 << 64) - 1)
+
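`program_to_ast` accepts three input shapes: an AST node (returned as-is), a source string, and a live function object whose source is fetched and dedented first. The same dispatch can be exercised end to end with the standard-library `ast` standing in for `gast` (which wraps it in a version-independent node set):

```python
import ast
import inspect
import textwrap
import uuid


def program_to_ast(program):
    # Already an AST node: pass it through unchanged.
    if isinstance(program, ast.AST):
        return program
    # A source string parses directly; a function object needs its
    # source fetched and dedented first (a method defined inside a
    # class body is indented, which ast.parse would reject).
    if isinstance(program, str):
        source = program
    else:
        source = textwrap.dedent(inspect.getsource(program))
    return ast.parse(source, mode='exec')


def unique_id():
    # Truncate the 128-bit random UUID to 64 bits, a convenient
    # fixed width for graph node ids.
    return uuid.uuid4().int & ((1 << 64) - 1)


def sample(a, b):
    return a + b


tree = program_to_ast(sample)
print(type(tree).__name__)                     # Module
print(program_to_ast(tree) is tree)            # True: AST input is a no-op
print(type(program_to_ast('x = 1')).__name__)  # Module
print(unique_id().bit_length() <= 64)          # True
```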
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9c558e3
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+.
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..a3abdfe
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The setup.py file for python_graphs."""
+
+from setuptools import setup
+
+LONG_DESCRIPTION = """
+python_graphs is a static analysis tool for performing control flow and data
+flow analyses on Python programs, and for constructing Program Graphs.
+Python Program Graphs are graph representations of Python programs suitable
+for use with graph neural networks.
+""".strip()
+
+SHORT_DESCRIPTION = """
+A library for generating graph representations of Python programs.""".strip()
+
+DEPENDENCIES = [
+ 'absl-py',
+ 'astunparse',
+ 'gast',
+ 'six',
+ 'pygraphviz',
+]
+
+TEST_DEPENDENCIES = [
+]
+
+VERSION = '1.0.0'
+URL = 'https://github.com/google-research/python-graphs'
+
+setup(
+ name='python_graphs',
+ version=VERSION,
+ description=SHORT_DESCRIPTION,
+ long_description=LONG_DESCRIPTION,
+ url=URL,
+
+ author='David Bieber',
+ author_email='dbieber@google.com',
+ license='Apache Software License',
+
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+
+ 'Intended Audience :: Developers',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+
+ 'License :: OSI Approved :: Apache Software License',
+
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+
+ 'Operating System :: OS Independent',
+ 'Operating System :: POSIX',
+ 'Operating System :: MacOS',
+ 'Operating System :: Unix',
+ ],
+
+ keywords='python program control flow data flow graph neural network',
+
+ packages=['python_graphs'],
+
+ install_requires=DEPENDENCIES,
+ tests_require=TEST_DEPENDENCIES,
+)