From b653ad9b0fae3466e7011e29fe3c986a58388bcd Mon Sep 17 00:00:00 2001 From: David Bieber Date: Tue, 6 Apr 2021 18:02:54 +0000 Subject: [PATCH] Initial commit of python_graphs. Project import generated by Copybara. PiperOrigin-RevId: 367412032 --- CONTRIBUTING.md | 29 + LICENSE | 202 +++ README.md | 47 + .../analysis/program_graph_analysis.py | 150 ++ .../analysis/program_graph_analysis_test.py | 366 +++++ .../analysis/run_program_graph_analysis.py | 277 ++++ python_graphs/control_flow.py | 1291 +++++++++++++++++ python_graphs/control_flow_graphviz.py | 113 ++ python_graphs/control_flow_graphviz_test.py | 41 + python_graphs/control_flow_test.py | 308 ++++ python_graphs/control_flow_test_components.py | 322 ++++ python_graphs/control_flow_visualizer.py | 74 + python_graphs/cyclomatic_complexity.py | 49 + python_graphs/cyclomatic_complexity_test.py | 38 + python_graphs/data_flow.py | 233 +++ python_graphs/data_flow_test.py | 138 ++ .../examples/control_flow_example.py | 57 + .../examples/cyclomatic_complexity_example.py | 46 + .../examples/program_graph_example.py | 50 + python_graphs/instruction.py | 400 +++++ python_graphs/instruction_test.py | 123 ++ python_graphs/program_graph.py | 963 ++++++++++++ python_graphs/program_graph_dataclasses.py | 82 ++ python_graphs/program_graph_graphviz.py | 61 + python_graphs/program_graph_graphviz_test.py | 34 + python_graphs/program_graph_test.py | 292 ++++ .../program_graph_test_components.py | 61 + python_graphs/program_graph_visualizer.py | 51 + python_graphs/program_utils.py | 62 + requirements.txt | 1 + setup.py | 81 ++ 31 files changed, 6042 insertions(+) create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 python_graphs/analysis/program_graph_analysis.py create mode 100644 python_graphs/analysis/program_graph_analysis_test.py create mode 100644 python_graphs/analysis/run_program_graph_analysis.py create mode 100644 python_graphs/control_flow.py create mode 100644 python_graphs/control_flow_graphviz.py create mode 100644 python_graphs/control_flow_graphviz_test.py create mode 100644 python_graphs/control_flow_test.py create mode 100644 python_graphs/control_flow_test_components.py create mode 100644 python_graphs/control_flow_visualizer.py create mode 100644 python_graphs/cyclomatic_complexity.py create mode 100644 python_graphs/cyclomatic_complexity_test.py create mode 100644 python_graphs/data_flow.py create mode 100644 python_graphs/data_flow_test.py create mode 100644 python_graphs/examples/control_flow_example.py create mode 100644 python_graphs/examples/cyclomatic_complexity_example.py create mode 100644 python_graphs/examples/program_graph_example.py create mode 100644 python_graphs/instruction.py create mode 100644 python_graphs/instruction_test.py create mode 100644 python_graphs/program_graph.py create mode 100644 python_graphs/program_graph_dataclasses.py create mode 100644 python_graphs/program_graph_graphviz.py create mode 100644 python_graphs/program_graph_graphviz_test.py create mode 100644 python_graphs/program_graph_test.py create mode 100644 python_graphs/program_graph_test_components.py create mode 100644 python_graphs/program_graph_visualizer.py create mode 100644 python_graphs/program_utils.py create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..cae7050 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,29 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. 
There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement (CLA). You (or your employer) retain the copyright to your +contribution; this simply gives us permission to use and redistribute your +contributions as part of the project. Head over to + to see your current agreements on file or +to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code Reviews + +All submissions, including submissions by project members, require review. For +external contributions, we use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index e69de29..d5ac5ea 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,47 @@ +# python_graphs + +This package is for computing graph representations of Python programs for +machine learning applications. It includes the following modules: + +* `control_flow` For computing control flow graphs statically from Python + programs. +* `data_flow` For computing data flow analyses of Python programs. +* `program_graph` For computing graphs statically to represent arbitrary + Python programs or functions. +* `cyclomatic_complexity` For computing the cyclomatic complexity of a Python function. + + +## Installation + +To install python_graphs with pip, run: `pip install python_graphs`. + +To install python_graphs from source, run: `python setup.py develop`. 
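+
+To check that the install worked, you can build a control flow graph for a
+small function (a minimal sketch; `example` is just a stand-in function):
+
+```python
+from python_graphs import control_flow
+
+def example(x):
+  return x + 1
+
+graph = control_flow.get_control_flow_graph(example)
+print(len(graph.get_control_flow_nodes()))
+```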
+ +## Common Tasks + +**Generate a control flow graph from a function `fn`:** + +```python +from python_graphs import control_flow +graph = control_flow.get_control_flow_graph(fn) +``` + +**Generate a program graph from a function `fn`:** + +```python +from python_graphs import program_graph +graph = program_graph.get_program_graph(fn) +``` + +**Compute the cyclomatic complexity of a function `fn`:** + +```python +from python_graphs import control_flow +from python_graphs import cyclomatic_complexity +graph = control_flow.get_control_flow_graph(fn) +value = cyclomatic_complexity.cyclomatic_complexity(graph) +``` + +--- + +This is not an officially supported Google product. diff --git a/python_graphs/analysis/program_graph_analysis.py b/python_graphs/analysis/program_graph_analysis.py new file mode 100644 index 0000000..c21699a --- /dev/null +++ b/python_graphs/analysis/program_graph_analysis.py @@ -0,0 +1,150 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions to analyze program graphs. + +Computes properties such as the height of a program graph's AST. +""" + +import gast as ast +import networkx as nx + + +def num_nodes(graph): + """Returns the number of nodes in a ProgramGraph.""" + return len(graph.all_nodes()) + + +def num_edges(graph): + """Returns the number of edges in a ProgramGraph.""" + return len(graph.edges) + + +def ast_height(ast_node): + """Computes the height of an AST from the given node. + + Args: + ast_node: An AST node. + + Returns: + The height of the AST starting at ast_node. A leaf node or single-node AST + has a height of 1. + """ + max_child_height = 0 + for child_node in ast.iter_child_nodes(ast_node): + max_child_height = max(max_child_height, ast_height(child_node)) + return 1 + max_child_height + + +def graph_ast_height(graph): + """Computes the height of the AST of a ProgramGraph. + + Args: + graph: A ProgramGraph. + + Returns: + The height of the graph's AST. A single-node AST has a height of 1. + """ + return ast_height(graph.to_ast()) + + +def degrees(graph): + """Returns a list of node degrees in a ProgramGraph. + + Args: + graph: A ProgramGraph. + + Returns: + An (unsorted) list of node degrees (in-degree plus out-degree). + """ + return [len(graph.neighbors(node)) for node in graph.all_nodes()] + + +def in_degrees(graph): + """Returns a list of node in-degrees in a ProgramGraph. + + Args: + graph: A ProgramGraph. + + Returns: + An (unsorted) list of node in-degrees. + """ + return [len(graph.incoming_neighbors(node)) for node in graph.all_nodes()] + + +def out_degrees(graph): + """Returns a list of node out-degrees in a ProgramGraph. + + Args: + graph: A ProgramGraph. + + Returns: + An (unsorted) list of node out-degrees. + """ + return [len(graph.outgoing_neighbors(node)) for node in graph.all_nodes()] + + +def _program_graph_to_nx(program_graph, directed=False): + """Converts a ProgramGraph to a NetworkX graph. + + Args: + program_graph: A ProgramGraph. 
+ directed: Whether the graph should be treated as a directed graph. + + Returns: + A NetworkX graph that can be analyzed by the networkx module. + """ + # Create a dict-of-lists representation, where {0: [1]} represents a directed + # edge from node 0 to node 1. + dict_of_lists = {} + for node in program_graph.all_nodes(): + neighbor_ids = [neighbor.id + for neighbor in program_graph.outgoing_neighbors(node)] + dict_of_lists[node.id] = neighbor_ids + return nx.DiGraph(dict_of_lists) if directed else nx.Graph(dict_of_lists) + + +def diameter(graph): + """Returns the diameter of a ProgramGraph. + + Note: this is very slow for large graphs. + + Args: + graph: A ProgramGraph. + + Returns: + The diameter of the graph. A single-node graph has diameter 0. The graph is + treated as an undirected graph. + + Raises: + networkx.exception.NetworkXError: Raised if the graph is not connected. + """ + nx_graph = _program_graph_to_nx(graph, directed=False) + return nx.algorithms.distance_measures.diameter(nx_graph) + + +def max_betweenness(graph): + """Returns the maximum node betweenness centrality in a ProgramGraph. + + Note: this is very slow for large graphs. + + Args: + graph: A ProgramGraph. + + Returns: + The maximum betweenness centrality value among all nodes in the graph. The + graph is treated as an undirected graph. + """ + nx_graph = _program_graph_to_nx(graph, directed=False) + return max(nx.algorithms.centrality.betweenness_centrality(nx_graph).values()) diff --git a/python_graphs/analysis/program_graph_analysis_test.py b/python_graphs/analysis/program_graph_analysis_test.py new file mode 100644 index 0000000..d8fb589 --- /dev/null +++ b/python_graphs/analysis/program_graph_analysis_test.py @@ -0,0 +1,366 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for program_graph_analysis.py.""" + +from absl.testing import absltest +import gast as ast +import networkx as nx + +from python_graphs import program_graph +from python_graphs.analysis import program_graph_analysis as pga + + +class ProgramGraphAnalysisTest(absltest.TestCase): + + def setUp(self): + super(ProgramGraphAnalysisTest, self).setUp() + self.singleton = self.create_singleton_graph() + self.disconnected = self.create_disconnected_graph() + self.cycle_3 = self.create_cycle_3() + self.chain_4 = self.create_chain_4() + self.wide_tree = self.create_wide_tree() + + def create_singleton_graph(self): + """Returns a graph with one node and zero edges.""" + graph = program_graph.ProgramGraph() + node = program_graph.make_node_from_syntax('singleton_node') + graph.add_node(node) + graph.root_id = node.id + return graph + + def create_disconnected_graph(self): + """Returns a disconnected graph with two nodes and zero edges.""" + graph = program_graph.ProgramGraph() + a = program_graph.make_node_from_syntax('a') + b = program_graph.make_node_from_syntax('b') + graph.add_node(a) + graph.add_node(b) + graph.root_id = a.id + return graph + + def create_cycle_3(self): + """Returns a 3-cycle graph, A -> B -> C -> A.""" + graph = program_graph.ProgramGraph() + a = program_graph.make_node_from_syntax('A') + b = program_graph.make_node_from_ast_value('B') + c = program_graph.make_node_from_syntax('C') + graph.add_node(a) + graph.add_node(b) + graph.add_node(c) + graph.add_new_edge(a, b) + graph.add_new_edge(b, c) + graph.add_new_edge(c, a) + graph.root_id = a.id + return graph + + def create_chain_4(self): + """Returns a chain of 4 nodes, A -> B -> C -> D.""" + graph = program_graph.ProgramGraph() + a = program_graph.make_node_from_syntax('A') + b = program_graph.make_node_from_ast_value('B') + c = program_graph.make_node_from_syntax('C') + d = program_graph.make_node_from_ast_value('D') + graph.add_node(a) + graph.add_node(b) + graph.add_node(c) + graph.add_node(d) + graph.add_new_edge(a, b) + graph.add_new_edge(b, c) + graph.add_new_edge(c, d) + graph.root_id = a.id + return graph + + def create_wide_tree(self): + """Returns a tree where the root has 4 children that are all leaves.""" + graph = program_graph.ProgramGraph() + root = program_graph.make_node_from_syntax('root') + graph.add_node(root) + graph.root_id = root.id + for i in range(4): + leaf = program_graph.make_node_from_ast_value(i) + graph.add_node(leaf) + graph.add_new_edge(root, leaf) + return graph + + def ids_from_cycle_3(self): + """Returns a triplet of IDs from the 3-cycle graph in cycle order.""" + root = self.cycle_3.root + id_a = root.id + id_b = self.cycle_3.outgoing_neighbors(root)[0].id + id_c = self.cycle_3.incoming_neighbors(root)[0].id + return id_a, id_b, id_c + + def test_num_nodes_returns_expected(self): + self.assertEqual(pga.num_nodes(self.singleton), 1) + self.assertEqual(pga.num_nodes(self.disconnected), 2) + self.assertEqual(pga.num_nodes(self.cycle_3), 3) + self.assertEqual(pga.num_nodes(self.chain_4), 4) + self.assertEqual(pga.num_nodes(self.wide_tree), 5) + + def test_num_edges_returns_expected(self): + self.assertEqual(pga.num_edges(self.singleton), 0) + self.assertEqual(pga.num_edges(self.disconnected), 0) + self.assertEqual(pga.num_edges(self.cycle_3), 3) + self.assertEqual(pga.num_edges(self.chain_4), 3) + self.assertEqual(pga.num_edges(self.wide_tree), 4) + + def test_ast_height_returns_expected_for_constructed_expression_ast(self): + # Testing the expression "1". 
+ # Height 3: Module -> Expr -> Num. + ast_node = ast.Module( + body=[ast.Expr(value=ast.Constant(value=1, kind=None))], + type_ignores=[]) + self.assertEqual(pga.ast_height(ast_node), 3) + + # Testing the expression "1 + 1". + # Height 4: Module -> Expr -> BinOp -> Num. + ast_node = ast.Module( + body=[ + ast.Expr( + value=ast.BinOp( + left=ast.Constant(value=1, kind=None), + op=ast.Add(), + right=ast.Constant(value=1, kind=None))) + ], + type_ignores=[]) + self.assertEqual(pga.ast_height(ast_node), 4) + + # Testing the expression "a + 1". + # Height 5: Module -> Expr -> BinOp -> Name -> Load. + ast_node = ast.Module( + body=[ + ast.Expr( + value=ast.BinOp( + left=ast.Name( + id='a', + ctx=ast.Load(), + annotation=None, + type_comment=None), + op=ast.Add(), + right=ast.Constant(value=1, kind=None))) + ], + type_ignores=[]) + self.assertEqual(pga.ast_height(ast_node), 5) + + # Testing the expression "a.b + 1". + # Height 6: Module -> Expr -> BinOp -> Attribute -> Name -> Load. + ast_node = ast.Module( + body=[ + ast.Expr( + value=ast.BinOp( + left=ast.Attribute( + value=ast.Name( + id='a', + ctx=ast.Load(), + annotation=None, + type_comment=None), + attr='b', + ctx=ast.Load()), + op=ast.Add(), + right=ast.Constant(value=1, kind=None))) + ], + type_ignores=[]) + self.assertEqual(pga.ast_height(ast_node), 6) + + def test_ast_height_returns_expected_for_constructed_function_ast(self): + # Testing the function declaration "def foo(n): return". + # Height 5: Module -> FunctionDef -> arguments -> Name -> Param. + ast_node = ast.Module( + body=[ + ast.FunctionDef( + name='foo', + args=ast.arguments( + args=[ + ast.Name( + id='n', + ctx=ast.Param(), + annotation=None, + type_comment=None) + ], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[]), + body=[ast.Return(value=None)], + decorator_list=[], + returns=None, + type_comment=None) + ], + type_ignores=[]) + self.assertEqual(pga.ast_height(ast_node), 5) + + # Testing the function declaration "def foo(n): return n + 1". + # Height 6: Module -> FunctionDef -> Return -> BinOp -> Name -> Load. + ast_node = ast.Module( + body=[ + ast.FunctionDef( + name='foo', + args=ast.arguments( + args=[ + ast.Name( + id='n', + ctx=ast.Param(), + annotation=None, + type_comment=None) + ], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[]), + body=[ + ast.Return( + value=ast.BinOp( + left=ast.Name( + id='n', + ctx=ast.Load(), + annotation=None, + type_comment=None), + op=ast.Add(), + right=ast.Constant(value=1, kind=None))) + ], + decorator_list=[], + returns=None, + type_comment=None) + ], + type_ignores=[], + ) + self.assertEqual(pga.ast_height(ast_node), 6) + + def test_ast_height_returns_expected_for_parsed_ast(self): + # Height 3: Module -> Expr -> Num. + self.assertEqual(pga.ast_height(ast.parse('1')), 3) + + # Height 6: Module -> Expr -> BinOp -> Attribute -> Name -> Load. + self.assertEqual(pga.ast_height(ast.parse('a.b + 1')), 6) + + # Height 6: Module -> FunctionDef -> Return -> BinOp -> Name -> Load. + self.assertEqual(pga.ast_height(ast.parse('def foo(n): return n + 1')), 6) + + # Height 9: Module -> FunctionDef -> If -> Return -> BinOp -> Call + # -> BinOp -> Name -> Load. + # Adding whitespace before "def foo" causes an IndentationError in parse(). 
+ ast_node = ast.parse("""def foo(n): + if n <= 0: + return 0 + else: + return 1 + foo(n - 1) + """) + self.assertEqual(pga.ast_height(ast_node), 9) + + def test_graph_ast_height_returns_expected(self): + # Height 6: Module -> FunctionDef -> Return -> BinOp -> Name -> Load. + def foo1(n): + return n + 1 + + graph = program_graph.get_program_graph(foo1) + self.assertEqual(pga.graph_ast_height(graph), 6) + + # Height 9: Module -> FunctionDef -> If -> Return -> BinOp -> Call + # -> BinOp -> Name -> Load. + def foo2(n): + if n <= 0: + return 0 + else: + return 1 + foo2(n - 1) + + graph = program_graph.get_program_graph(foo2) + self.assertEqual(pga.graph_ast_height(graph), 9) + + def test_degrees_returns_expected(self): + self.assertCountEqual(pga.degrees(self.singleton), [0]) + self.assertCountEqual(pga.degrees(self.disconnected), [0, 0]) + self.assertCountEqual(pga.degrees(self.cycle_3), [2, 2, 2]) + self.assertCountEqual(pga.degrees(self.chain_4), [1, 2, 2, 1]) + self.assertCountEqual(pga.degrees(self.wide_tree), [4, 1, 1, 1, 1]) + + def test_in_degrees_returns_expected(self): + self.assertCountEqual(pga.in_degrees(self.singleton), [0]) + self.assertCountEqual(pga.in_degrees(self.disconnected), [0, 0]) + self.assertCountEqual(pga.in_degrees(self.cycle_3), [1, 1, 1]) + self.assertCountEqual(pga.in_degrees(self.chain_4), [0, 1, 1, 1]) + self.assertCountEqual(pga.in_degrees(self.wide_tree), [0, 1, 1, 1, 1]) + + def test_out_degrees_returns_expected(self): + self.assertCountEqual(pga.out_degrees(self.singleton), [0]) + self.assertCountEqual(pga.out_degrees(self.disconnected), [0, 0]) + self.assertCountEqual(pga.out_degrees(self.cycle_3), [1, 1, 1]) + self.assertCountEqual(pga.out_degrees(self.chain_4), [1, 1, 1, 0]) + self.assertCountEqual(pga.out_degrees(self.wide_tree), [4, 0, 0, 0, 0]) + + def test_diameter_returns_expected_if_connected(self): + self.assertEqual(pga.diameter(self.singleton), 0) + self.assertEqual(pga.diameter(self.cycle_3), 1) + self.assertEqual(pga.diameter(self.chain_4), 3) + self.assertEqual(pga.diameter(self.wide_tree), 2) + + def test_diameter_throws_exception_if_disconnected(self): + with self.assertRaises(nx.exception.NetworkXError): + pga.diameter(self.disconnected) + + def test_program_graph_to_nx_undirected_has_correct_edges(self): + id_a, id_b, id_c = self.ids_from_cycle_3() + nx_graph = pga._program_graph_to_nx(self.cycle_3, directed=False) + self.assertCountEqual(nx_graph.nodes(), [id_a, id_b, id_c]) + expected_adj = { + id_a: { + id_b: {}, + id_c: {} + }, + id_b: { + id_a: {}, + id_c: {} + }, + id_c: { + id_a: {}, + id_b: {} + }, + } + self.assertEqual(nx_graph.adj, expected_adj) + + def test_program_graph_to_nx_directed_has_correct_edges(self): + id_a, id_b, id_c = self.ids_from_cycle_3() + nx_digraph = pga._program_graph_to_nx(self.cycle_3, directed=True) + self.assertCountEqual(nx_digraph.nodes(), [id_a, id_b, id_c]) + expected_adj = { + id_a: { + id_b: {} + }, + id_b: { + id_c: {} + }, + id_c: { + id_a: {} + }, + } + self.assertEqual(nx_digraph.adj, expected_adj) + + def test_max_betweenness_returns_expected(self): + self.assertAlmostEqual(pga.max_betweenness(self.singleton), 0) + self.assertAlmostEqual(pga.max_betweenness(self.disconnected), 0) + self.assertAlmostEqual(pga.max_betweenness(self.cycle_3), 0) + + # Middle nodes are in 2 shortest paths, normalizer = (4-1)*(4-2)/2 = 3 + self.assertAlmostEqual(pga.max_betweenness(self.chain_4), 2 / 3) + + # Root is in 6 shortest paths, normalizer = (5-1)*(5-2)/2 = 6 + 
self.assertAlmostEqual(pga.max_betweenness(self.wide_tree), 6 / 6) + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/analysis/run_program_graph_analysis.py b/python_graphs/analysis/run_program_graph_analysis.py new file mode 100644 index 0000000..6f214dc --- /dev/null +++ b/python_graphs/analysis/run_program_graph_analysis.py @@ -0,0 +1,277 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runs the program graph analysis for datasets of programs. + +Analyzes each dataset of programs, producing plots for properties such as the +AST height. +""" + +import inspect +import math + +from absl import app +from absl import logging +import matplotlib.pyplot as plt +import numpy as np +from python_graphs import control_flow_test_components as cftc +from python_graphs import program_graph +from python_graphs import program_graph_test_components as pgtc +from python_graphs.analysis import program_graph_analysis +import six +from six.moves import range + + + +TARGET_NUM_BINS = 15 # A reasonable constant number of histogram bins. +MAX_NUM_BINS = 20 # The maximum number of bins reasonable on a histogram. + + +def test_components(): + """Generates functions from two sets of test components. + + Yields: + All functions in the program graph and control flow test components files. + """ + for unused_name, fn in inspect.getmembers(pgtc, predicate=inspect.isfunction): + yield fn + + for unused_name, fn in inspect.getmembers(cftc, predicate=inspect.isfunction): + yield fn + + + + +def get_graph_generator(function_generator): + """Generates ProgramGraph objects from functions. + + Args: + function_generator: A function generator. + + Yields: + ProgramGraph objects for the functions. + """ + for index, function in enumerate(function_generator): + try: + graph = program_graph.get_program_graph(function) + yield graph + except SyntaxError: + # get_program_graph can fail for programs with different string encodings. + logging.info('SyntaxError in get_program_graph for function index %d. ' + 'First 100 chars of function source:\n%s', + index, function[:100]) + except RuntimeError: + # get_program_graph can fail for programs that are only return statements. + logging.info('RuntimeError in get_program_graph for function index %d. ' + 'First 100 chars of function source:\n%s', + index, function[:100]) + + +def get_percentiles(data, percentiles, integer_valued=True): + """Returns a dict of percentiles of the data. + + Args: + data: An unsorted list of datapoints. + percentiles: A list of ints or floats in the range [0, 100] representing the + percentiles to compute. + integer_valued: Whether or not the values are all integers. If so, + interpolate to the nearest datapoint (instead of computing a fractional + value between the two nearest datapoints). + + Returns: + A dict mapping each element of percentiles to the computed result. + """ + # Ensure integer datapoints for cleaner binning if necessary. 
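+  # (With numpy, interpolation='nearest' snaps each requested percentile to an
+  # actual observed datapoint, while 'linear' interpolates between the two
+  # nearest datapoints.)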
+ interpolation = 'nearest' if integer_valued else 'linear' + results = np.percentile(data, percentiles, interpolation=interpolation) + return {percentiles[i]: results[i] for i in range(len(percentiles))} + + +def analyze_graph(graph, identifier): + """Performs various analyses on a graph. + + Args: + graph: A ProgramGraph to analyze. + identifier: A unique identifier for this graph (for later aggregation). + + Returns: + A pair (identifier, result_dict), where result_dict contains the results of + analyses run on the graph. + """ + num_nodes = program_graph_analysis.num_nodes(graph) + num_edges = program_graph_analysis.num_edges(graph) + ast_height = program_graph_analysis.graph_ast_height(graph) + + degree_percentiles = [25, 50, 90] + degrees = get_percentiles(program_graph_analysis.degrees(graph), + degree_percentiles) + in_degrees = get_percentiles(program_graph_analysis.in_degrees(graph), + degree_percentiles) + out_degrees = get_percentiles(program_graph_analysis.out_degrees(graph), + degree_percentiles) + + diameter = program_graph_analysis.diameter(graph) + max_betweenness = program_graph_analysis.max_betweenness(graph) + + # TODO(kshi): Turn this into a protobuf and fix everywhere else in this file. + # Eventually this should be parallelized (currently takes ~6 hours to run). + result_dict = { + 'num_nodes': num_nodes, + 'num_edges': num_edges, + 'ast_height': ast_height, + 'degrees': degrees, + 'in_degrees': in_degrees, + 'out_degrees': out_degrees, + 'diameter': diameter, + 'max_betweenness': max_betweenness, + } + + return (identifier, result_dict) + + +def create_bins(values, integer_valued=True, log_x=False): + """Creates appropriate histogram bins. + + Args: + values: The values to be plotted in a histogram. + integer_valued: Whether the values are all integers. + log_x: Whether to plot the x-axis using a log scale. + + Returns: + An object (sequence, integer, or 'auto') that can be used as the 'bins' + keyword argument to plt.hist(). If there are no values to plot, or all of + the values are identical, then 'auto' is returned. + """ + if not values: + return 'auto' # No data to plot; let pyplot handle this case. + min_value = min(values) + max_value = max(values) + if min_value == max_value: + return 'auto' # All values are identical; let pyplot handle this case. + + if log_x: + return np.logspace(np.log10(min_value), np.log10(max_value + 1), + num=(TARGET_NUM_BINS + 1)) + elif integer_valued: + # The minimum integer width resulting in at most MAX_NUM_BINS bins. + bin_width = math.ceil((max_value - min_value + 1) / MAX_NUM_BINS) + # Place bin boundaries between integers. + return np.arange(min_value - 0.5, max_value + bin_width + 0.5, bin_width) + else: + return TARGET_NUM_BINS + + +def create_histogram(values, title, percentiles=False, integer_valued=True, + log_x=False, log_y=False): + """Returns a histogram of integer values computed from a dataset. + + Args: + values: A list of integer values to plot, or if percentiles is True, then + each value is a dict mapping some chosen percentiles in [0, 100] to the + corresponding data value. + title: The figure title. + percentiles: Whether to plot multiple histograms for percentiles. + integer_valued: Whether the values are all integers, which affects how the + data is partitioned into bins. + log_x: Whether to plot the x-axis using a log scale. + log_y: Whether to plot the y-axis using a log scale. + + Returns: + A histogram figure. 
+ """ + figure = plt.figure() + + if percentiles: + for percentile in sorted(values[0].keys()): + new_values = [percentile_dict[percentile] + for percentile_dict in values] + bins = create_bins(new_values, integer_valued=integer_valued, log_x=log_x) + plt.hist(new_values, bins=bins, alpha=0.5, label='{}%'.format(percentile)) + plt.legend(loc='upper right') + else: + bins = create_bins(values, integer_valued=integer_valued, log_x=log_x) + plt.hist(values, bins=bins) + + if log_x: + plt.xscale('log', nonposx='clip') + if log_y: + plt.yscale('log', nonposy='clip') + plt.title(title) + return figure + + +def save_histogram(all_results, result_key, dataset_name, path_root, + percentiles=False, integer_valued=True, + log_x=False, log_y=False): + """Saves a histogram image to disk. + + Args: + all_results: A list of dicts containing all analysis results for each graph. + result_key: The key in the result dicts specifying what data to plot. + dataset_name: The name of the dataset, which appears in the figure title and + the image filename. + path_root: The directory to save the histogram image in. + percentiles: Whether the data has multiple percentiles to plot. + integer_valued: Whether the values are all integers, which affects how the + data is partitioned into bins. + log_x: Whether to plot the x-axis using a log scale. + log_y: Whether to plot the y-axis using a log scale. + """ + values = [result[result_key] for result in all_results] + title = '{} distribution for {}'.format(result_key, dataset_name) + figure = create_histogram(values, title, percentiles=percentiles, + integer_valued=integer_valued, + log_x=log_x, log_y=log_y) + path = '{}/{}-{}.png'.format(path_root, result_key, dataset_name) + figure.savefig(path) + logging.info('Saved image %s', path) + + +def main(argv): + del argv # Unused. + + dataset_pairs = [ + (test_components(), 'test_components'), + ] + path_root = '/tmp/program_graph_analysis' + + for function_generator, dataset_name in dataset_pairs: + logging.info('Analyzing graphs in dataset %s...', dataset_name) + graph_generator = get_graph_generator(function_generator) + all_results = [] + for index, graph in enumerate(graph_generator): + identifier = '{}-{}'.format(dataset_name, index) + # Discard the identifiers (not needed until this is parallelized). + all_results.append(analyze_graph(graph, identifier)[1]) + + if all_results: + logging.info('Creating plots for dataset %s...', dataset_name) + for result_key in ['num_nodes', 'num_edges']: + save_histogram(all_results, result_key, dataset_name, path_root, + percentiles=False, integer_valued=True, log_x=True) + for result_key in ['ast_height', 'diameter']: + save_histogram(all_results, result_key, dataset_name, path_root, + percentiles=False, integer_valued=True) + for result_key in ['max_betweenness']: + save_histogram(all_results, result_key, dataset_name, path_root, + percentiles=False, integer_valued=False) + for result_key in ['degrees', 'in_degrees', 'out_degrees']: + save_histogram(all_results, result_key, dataset_name, path_root, + percentiles=True, integer_valued=True) + else: + logging.warn('Dataset %s is empty.', dataset_name) + + +if __name__ == '__main__': + app.run(main) diff --git a/python_graphs/control_flow.py b/python_graphs/control_flow.py new file mode 100644 index 0000000..a5d8f8c --- /dev/null +++ b/python_graphs/control_flow.py @@ -0,0 +1,1291 @@ +# Copyright (C) 2021 Google Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Computes the control flow graph for a Python program from its AST. +""" + +import uuid + +from absl import logging # pylint: disable=unused-import +import gast as ast +from python_graphs import instruction as instruction_module +from python_graphs import program_utils +import six + + +def get_control_flow_graph(program): + """Get a ControlFlowGraph for the provided AST node. + + Args: + program: Either an AST node, source string, or a function. + Returns: + A ControlFlowGraph. + """ + control_flow_visitor = ControlFlowVisitor() + node = program_utils.program_to_ast(program) + control_flow_visitor.run(node) + return control_flow_visitor.graph + + +class ControlFlowGraph(object): + """A control flow graph for a Python program. + + Attributes: + blocks: All blocks contained in the control flow graph. + nodes: All control flow nodes in the control flow graph. + start_block: The entry point to the program. + """ + + def __init__(self): + self.blocks = [] + self.nodes = [] + + self.start_block = self.new_block(prunable=False) + self.start_block.label = '' + + def add_node(self, control_flow_node): + self.nodes.append(control_flow_node) + + def new_block(self, node=None, label=None, prunable=True): + block = BasicBlock(node=node, label=label, prunable=prunable) + block.graph = self + self.blocks.append(block) + return block + + def get_control_flow_nodes(self): + return self.nodes + + def get_enter_blocks(self): + """Returns entry blocks for all functions.""" + return six.moves.filter( + lambda block: block.label.startswith(''.format(name=name), + self.blocks) + + def get_block_by_function_name(self, name): + return next(self.get_blocks_by_function_name(name)) + + def get_control_flow_nodes_by_source(self, source): + module = ast.parse(source, mode='exec') # TODO(dbieber): Factor out 4 lines + node = module.body[0] + if isinstance(node, ast.Expr): + node = node.value + + return six.moves.filter( + lambda cfn: cfn.instruction.contains_subprogram(node), + self.get_control_flow_nodes()) + + def get_control_flow_node_by_source(self, source): + return next(self.get_control_flow_nodes_by_source(source)) + + def get_control_flow_nodes_by_source_and_identifier(self, source, name): + for control_flow_node in self.get_control_flow_nodes_by_source(source): + for node in ast.walk(control_flow_node.instruction.node): + if isinstance(node, ast.Name) and node.id == name: + for i2 in self.get_control_flow_nodes_by_ast_node(node): + yield i2 + + def get_control_flow_node_by_source_and_identifier(self, source, name): + return next( + self.get_control_flow_nodes_by_source_and_identifier(source, name)) + + def get_blocks_by_source(self, source): + """Yields blocks that contain instructions matching the query source.""" + module = ast.parse(source, mode='exec') + node = module.body[0] + if isinstance(node, ast.Expr): + node = node.value + + for block in self.blocks: + for control_flow_node in block.control_flow_nodes: + if 
control_flow_node.instruction.contains_subprogram(node): + yield block + break + + def get_block_by_source(self, source): + return next(self.get_blocks_by_source(source)) + + def get_blocks_by_source_and_ast_node_type(self, source, node_type): + """Blocks with an Instruction matching node_type and containing source.""" + module = ast.parse(source, mode='exec') + node = module.body[0] + if isinstance(node, ast.Expr): + node = node.value + + for block in self.blocks: + for instruction in block.instructions: + if (isinstance(instruction.node, node_type) + and instruction.contains_subprogram(node)): + yield block + break + + def get_block_by_source_and_ast_node_type(self, source, node_type): + """A block with an Instruction matching node_type and containing source.""" + return next(self.get_blocks_by_source_and_ast_node_type(source, node_type)) + + def get_block_by_ast_node_and_label(self, node, label): + """Gets the block corresponding to `node` having label `label`.""" + for block in self.blocks: + if block.node is node and block.label == label: + return block + + def get_blocks_by_ast_node_type_and_label(self, node_type, label): + """Gets the blocks with node type `node_type` having label `label`.""" + for block in self.blocks: + if isinstance(block.node, node_type) and block.label == label: + yield block + + def get_block_by_ast_node_type_and_label(self, node_type, label): + """Gets a block with node type `node_type` having label `label`.""" + return next(self.get_blocks_by_ast_node_type_and_label(node_type, label)) + + def prune(self): + """Prunes all prunable blocks from the graph.""" + progress = True + while progress: + progress = False + for block in iter(self.blocks): + if block.can_prune(): + to_remove = block.prune() + self.blocks.remove(to_remove) + progress = True + + def compact(self): + """Prunes unused blocks and merges blocks when possible.""" + self.prune() + for block in iter(self.blocks): + while block.can_merge(): + to_remove = block.merge() + self.blocks.remove(to_remove) + for block in self.blocks: + block.compact() + + +class Frame(object): + """A Frame indicates how statements affect control flow in parts of a program. + + Frames are introduced when the program enters a new loop, function definition, + or try/except/finally block. + + A Frame indicates how an exit such as a continue, break, exception, or return + affects control flow. For example, a continue statement inside of a loop sends + control back to the loop's condition. In nested loops, a continue statement + sends control back to the condition of the innermost loop containing the + continue statement. + + Attributes: + kind: One of LOOP, FUNCTION, TRY_EXCEPT, or TRY_FINALLY. + blocks: A dictionary with the blocks relevant to the frame. + """ + + # Kinds: + LOOP = 'loop' + FUNCTION = 'function' + TRY_EXCEPT = 'try-except' + TRY_FINALLY = 'try-finally' + + def __init__(self, kind, **blocks): + self.kind = kind + self.blocks = blocks + + +class BasicBlock(object): + """A basic block in a control flow graph. + + All instructions (generally, AST nodes) in a basic block are either executed + or none are (with the exception of blocks interrupted by exceptions). These + instructions are executed in a straight-line manner. + + Attributes: + graph: The control flow graph which this basic block is a part of. + + next: Indicates which basic blocks may be executed after this basic block. + prev: Indicates which basic blocks may lead to the execution of this basic + block in a Python program. 
+ control_flow_nodes: A list of the ControlFlowNodes contained in this basic + block. Each ControlFlowNode corresponds to a single Instruction. + control_flow_node_indexes: Maps from id(control_flow_node) to the + ControlFlowNode's index in self.control_flow_nodes. Only available once + the block is compacted. + + branches: A map from booleans to the basic block reachable by making the + branch decision indicated by that boolean. + exits_from_middle: These basic blocks may be exited to at any point during + the execution of this basic block. + exits_from_end: These basic blocks may only be exited to at the end of + the execution of this basic block. + node: The AST node this basic block is associated with. + prunable: Whether this basic block may be pruned from the control flow graph + if empty. Set to False for special blocks, such as enter and exit blocks. + label: A label for the basic block. + identities: A list of (node, label) pairs that refer to this basic block. + This starts as (self.node, self.label), but old identities are preserved + during merging and pruning. Allows lookup of blocks by node and label, + e.g. for finding the after block of a particular if statement. + labels: Labels, used for example by data flow analyses. Maps from label name + to value. + """ + + def __init__(self, node=None, label=None, prunable=True): + self.graph = None + self.next = set() + self.prev = set() + self.control_flow_nodes = [] + self.control_flow_node_indexes = None + + self.branches = {} + self.exits_from_middle = set() + self.exits_from_end = set() + self.node = node + self.prunable = prunable + self.label = label + self.identities = [(node, label)] + self.labels = {} + + def has_label(self, label): + """Returns whether this BasicBlock has the specified label.""" + return label in self.labels + + def set_label(self, label, value): + """Sets the value of a label on the BasicBlock.""" + self.labels[label] = value + + def get_label(self, label): + """Gets the value of a label on the BasicBlock.""" + return self.labels[label] + + def is_empty(self): + """Whether this block is empty.""" + return not self.control_flow_nodes + + def exits_to(self, block): + """Whether this block exits to `block`.""" + return block in self.next + + def raises_to(self, block): + """Whether this block exits to `block` in the case of an exception.""" + return block in self.next and block in self.exits_from_middle + + def add_exit(self, block, interrupting=False, branch=None): + """Adds an exit from this block to `block`.""" + self.next.add(block) + block.prev.add(self) + + if branch is not None: + self.branches[branch] = block + + if interrupting: + self.exits_from_middle.add(block) + else: + self.exits_from_end.add(block) + + def remove_exit(self, block): + """Removes the exit from this block to `block`.""" + self.next.remove(block) + block.prev.remove(self) + + if block in self.exits_from_middle: + self.exits_from_middle.remove(block) + if block in self.exits_from_end: + self.exits_from_end.remove(block) + for branch_decision, branch_exit in self.branches.copy().items(): + if branch_exit is block: + del self.branches[branch_decision] + + def can_prune(self): + return self.is_empty() and self.prunable + + def prune(self): + """Prunes the empty block from its control flow graph. + + A block is prunable if it has no control flow nodes and has not been marked + as unprunable (e.g. because it's the exit block, or a return block, etc). + + Returns: + The block removed by the prune operation. That is, self. 
+ """ + assert self.can_prune() + prevs = self.prev.copy() + nexts = self.next.copy() + for prev_block in prevs: + exits_from_middle = prev_block.exits_from_middle.copy() + exits_from_end = prev_block.exits_from_end.copy() + branches = prev_block.branches.copy() + for next_block in nexts: + if self in exits_from_middle: + prev_block.add_exit(next_block, interrupting=True) + if self in exits_from_end: + prev_block.add_exit(next_block, interrupting=False) + + for branch_decision, branch_exit in branches.items(): + if branch_exit is self: + prev_block.branches[branch_decision] = next_block + + for prev_block in prevs: + prev_block.remove_exit(self) + for next_block in nexts: + self.remove_exit(next_block) + next_block.identities = next_block.identities + self.identities + return self + + def can_merge(self): + if len(self.exits_from_end) != 1: + return False + next_block = next(iter(self.exits_from_end)) + if not next_block.prunable: + return False + if self.exits_from_middle != next_block.exits_from_middle: + return False + if len(next_block.prev) == 1: + return True + + def merge(self): + """Merge this block with its one successor. + + Returns: + The successor block removed by the merge operation. + """ + assert self.can_merge() + next_block = next(iter(self.exits_from_end)) + + exits_from_middle = next_block.exits_from_middle.copy() + exits_from_end = next_block.exits_from_end.copy() + + for branch_decision, branch_exit in next_block.branches.items(): + self.branches[branch_decision] = branch_exit + + self.remove_exit(next_block) + for block in next_block.next.copy(): + next_block.remove_exit(block) + if block in exits_from_middle: + self.add_exit(block, interrupting=True) + if block in exits_from_end: + self.add_exit(block, interrupting=False) + for control_flow_node in next_block.control_flow_nodes: + control_flow_node.block = self + self.control_flow_nodes.append(control_flow_node) + self.prunable = self.prunable and next_block.prunable + self.label = self.label or next_block.label + self.identities = self.identities + next_block.identities + # Note: self.exits_from_middle is unchanged. + return next_block + + def add_instruction(self, instruction): + assert isinstance(instruction, instruction_module.Instruction) + control_flow_node = ControlFlowNode(graph=self.graph, + block=self, + instruction=instruction) + self.graph.add_node(control_flow_node) + self.control_flow_nodes.append(control_flow_node) + + def compact(self): + self.control_flow_node_indexes = {} + for index, control_flow_node in enumerate(self.control_flow_nodes): + self.control_flow_node_indexes[control_flow_node.uuid] = index + + def index_of(self, control_flow_node): + """Returns the index of the Instruction in this BasicBlock.""" + return self.control_flow_node_indexes[control_flow_node.uuid] + + +class ControlFlowNode(object): + """A node in a control flow graph. + + Corresponds to a single Instruction contained in a single BasicBlock. + + Attributes: + graph: The ControlFlowGraph which this node is a part of. + block: The BasicBlock in which this node's instruction resides. + instruction: The Instruction corresponding to this node. + labels: Metadata attached to this node, for example for use by data flow + analyses. + uuid: A unique identifier for the ControlFlowNode. 
+ """ + + def __init__(self, graph, block, instruction): + self.graph = graph + self.block = block + self.instruction = instruction + self.labels = {} + self.uuid = uuid.uuid4() + + @property + def next(self): + """Returns the set of possible next instructions.""" + if self.block is None: + return None + index_in_block = self.block.index_of(self) + if len(self.block.control_flow_nodes) > index_in_block + 1: + return {self.block.control_flow_nodes[index_in_block + 1]} + control_flow_nodes = set() + for next_block in self.block.next: + if next_block.control_flow_nodes: + control_flow_nodes.add(next_block.control_flow_nodes[0]) + else: + # If next_block is empty, it isn't the case that some downstream block + # is nonempty. This is guaranteed by the pruning phase of control flow + # graph construction. + assert not next_block.next + return control_flow_nodes + + @property + def prev(self): + """Returns the set of possible previous instructions.""" + if self.block is None: + return None + index_in_block = self.block.index_of(self) + if index_in_block - 1 >= 0: + return {self.block.control_flow_nodes[index_in_block - 1]} + control_flow_nodes = set() + for prev_block in self.block.prev: + if prev_block.control_flow_nodes: + control_flow_nodes.add(prev_block.control_flow_nodes[-1]) + else: + # If prev_block is empty, it isn't the case that some upstream block + # is nonempty. This is guaranteed by the pruning phase of control flow + # graph construction. + assert not prev_block.prev + return control_flow_nodes + + @property + def branches(self): + """Returns the branch options available at the end of this node. + + Returns: + A dictionary with possible keys True and False, and values given by the + node that is reached by taking the True/False branch. An empty dictionary + indicates that there are no branches to take, and so self.next gives the + next node (in a set of size 1). A value of None indicates that taking that + branch leads to the exit, since there are no exit ControlFlowNodes in a + ControlFlowGraph. + """ + if self.block is None: + return {} # We're not in a block. No branch decision. + index_in_block = self.block.index_of(self) + if len(self.block.control_flow_nodes) > index_in_block + 1: + return {} # We're not yet at the end of the block. No branch decision. + + branches = {} # We're at the end of the block. + for key, next_block in self.block.branches.items(): + if next_block.control_flow_nodes: + branches[key] = next_block.control_flow_nodes[0] + else: + # If next_block is empty, it isn't the case that some downstream block + # is nonempty. This is guaranteed by the pruning phase of control flow + # graph construction. + assert not next_block.next + branches[key] = None # Indicates exit; there is no node to return. + return branches + + def has_label(self, label): + """Returns whether this Instruction has the specified label.""" + return label in self.labels + + def set_label(self, label, value): + """Sets the value of a label on the Instruction.""" + self.labels[label] = value + + def get_label(self, label): + """Gets the value of a label on the Instruction.""" + return self.labels[label] + + +# pylint: disable=invalid-name,g-doc-return-or-yield,g-doc-args +class ControlFlowVisitor(object): + """A visitor for determining the control flow of a Python program from an AST. + + The main function of interest here is `visit`, which causes the visitor to + construct the control flow graph for the node passed to visit. 
+
+  Basic control flow:
+    The state of the Visitor consists of a sequence of frames and a current
+    basic block. When an AST node is visited by `visit`, it is added to the
+    current basic block. When a node can indicate a possible change in control,
+    new basic blocks are created and exits between the basic blocks are added
+    as appropriate.
+
+    For example, an If statement introduces two possibilities for control flow.
+    Consider the program:
+
+    if a > b:
+      c = 1
+    else:
+      c = 2
+    return c
+
+    There are four basic blocks in this program: let's call them `compare`,
+    `c = 1`, `c = 2`, and `return`. The exits between the blocks are:
+    `compare` -> `c = 1`, `compare` -> `c = 2`, `c = 1` -> `return`, and
+    `c = 2` -> `return`.
+
+  Frames:
+    There are four kinds of frames: function frames, loop frames, try-except
+    frames, and try-finally frames. All AST nodes in a function definition are
+    in that function's function frame. All AST nodes in the body of a loop are
+    in that loop's loop frame. All AST nodes in the try block of a try/except
+    are in that try's try-except frame, and all AST nodes in the try and except
+    blocks of a try/finally are in that try's try-finally frame.
+
+    A function frame contains information about where control should flow to in
+    the case of a return statement or an uncaught exception.
+
+    A loop frame contains information about where control should pass to in the
+    case of a continue or break statement.
+
+    A try-except frame contains information about where control should flow to
+    in the case of an exception.
+
+    A try-finally frame contains information about where control should flow to
+    in the case of an exit (such as a finally block that must run before a
+    return, continue, or break statement can be executed).
+
+  Attributes:
+    graph: The control flow graph being generated by the visitor.
+    frames: The current frames. Each frame in this list contains all frames
+      that come after it in the list.
+  """
+
+  def __init__(self):
+    self.graph = ControlFlowGraph()
+    self.frames = []
+
+  def run(self, node):
+    start_block = self.graph.start_block
+    end_block = self.visit(node, start_block)
+    exit_block = self.new_block(node=node, label='<exit>', prunable=False)
+    end_block.add_exit(exit_block)
+    self.graph.compact()
+
+  def visit(self, node, current_block):
+    """Visit node, an AST node.
+
+    Args:
+      node: The AST node being visited (an instance of ast.AST). Lists of nodes
+        are handled separately by `visit_list`.
+      current_block: The basic block whose execution necessarily precedes the
+        execution of `node`.
+    Returns:
+      The final basic block for the node.
+    """
+    assert isinstance(node, ast.AST)
+
+    if isinstance(node, instruction_module.INSTRUCTION_AST_NODES):
+      self.add_new_instruction(current_block, node)
+
+    method_name = 'visit_' + node.__class__.__name__
+    method = getattr(self, method_name, None)
+    if method is not None:
+      current_block = method(node, current_block)
+    return current_block
+
+  def visit_list(self, items, current_block):
+    """Visit each of the items in a list from the AST."""
+    for item in items:
+      current_block = self.visit(item, current_block)
+    return current_block
+
+  def add_new_instruction(self, block, node, accesses=None, source=None):
+    assert isinstance(node, ast.AST)
+    instruction = instruction_module.Instruction(
+        node, accesses=accesses, source=source)
+    self.add_instruction(block, instruction)
+
+  def add_instruction(self, block, instruction):
+    assert isinstance(instruction, instruction_module.Instruction)
+    block.add_instruction(instruction)
+
+    # Any instruction may raise an exception.
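+    # raise_through_frames adds interrupting exits (exits_from_middle) for the
+    # whole block, so it only needs to run once per block: skip it if the block
+    # already has interrupting exits.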
+ if not block.exits_from_middle: + self.raise_through_frames(block, interrupting=True) + + def raise_through_frames(self, block, interrupting=True): + """Adds exits for the control flow of a raised exception. + + `interrupting` means the exit can occur at any point (exit_from_middle). + `not interrupting` means the exit can only occur at the end of the block. + + The reason to raise_through_frames with interrupting=False is for an + exception that already has been partially raised, but has passed control to + a finally block, and is now being raised at the end of that finally block. + + Args: + block: The block where the exception's control flow begins. + interrupting: Whether the exception can be raised from any point in block. + If False, the exception is only raised from the end of block. + """ + frames = self.get_current_exception_handling_frames() + + if frames is None: + return + + for frame in frames: + if frame.kind == Frame.TRY_FINALLY: + # Exit to finally and have finally exit to whatever's next... + final_block = frame.blocks['final_block'] + block.add_exit(final_block, interrupting=interrupting) + block = frame.blocks['final_block_end'] + interrupting = False + elif frame.kind == Frame.TRY_EXCEPT: + handler_block = frame.blocks['handler_block'] + block.add_exit(handler_block, interrupting=interrupting) + interrupting = False # return... + elif frame.kind == Frame.FUNCTION: + raise_block = frame.blocks['raise_block'] + block.add_exit(raise_block, interrupting=interrupting) + + def new_block(self, node=None, label=None, prunable=True): + """Create a new block.""" + return self.graph.new_block(node=node, label=label, prunable=prunable) + + def enter_loop_frame(self, continue_block, break_block): + # The loop body is the interior of the frame. + # The continue block (loop condition) and break block (loop's after block) + # are the exits from the frame. + self.frames.append(Frame(Frame.LOOP, + continue_block=continue_block, + break_block=break_block)) + + def enter_function_frame(self, return_block, raise_block): + # The function body is the interior of the frame. + # The return block and raise block are the exits from the frame. + self.frames.append(Frame(Frame.FUNCTION, + return_block=return_block, + raise_block=raise_block)) + + def enter_try_except_frame(self, handler_block): + # The try block is the interior of the frame. + # handler_block is where the frame exits to on an exception. + self.frames.append(Frame(Frame.TRY_EXCEPT, + handler_block=handler_block)) + + def enter_try_finally_frame(self, final_block, final_block_end): + # The try block and handler blocks are the interior of the frame. + # The finally block is the exit from the frame. + self.frames.append(Frame(Frame.TRY_FINALLY, + final_block=final_block, + final_block_end=final_block_end)) + + def exit_frame(self): + """Exits the innermost current frame. + + Note: Each enter_* function must be matched to exactly one exit_frame call + in reverse order. + + Returns: + The frame being exited. + """ + return self.frames.pop() + + def get_current_loop_frame(self): + """Gets the current loop frame and contained current try-finally frames. + + In order to exit the current loop frame, we must first enter the finally + blocks of all current contained try-finally frames. + + Returns: + A list of frames, all of which are try-finally frames except for the last, + which is the current loop frame. Each of the returned try-finally + frames is contained within the current loop frame. 
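+      Returns None if there is no current loop frame.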
+    """
+    frames = []
+    for frame in reversed(self.frames):
+      if frame.kind == Frame.TRY_FINALLY:
+        frames.append(frame)
+      if frame.kind == Frame.LOOP:
+        frames.append(frame)
+        return frames
+    # There are no loop frames.
+    return None
+
+  def get_current_function_frame(self):
+    """Gets the current function frame and contained current try-finally frames.
+
+    In order to exit the current function frame, we must first enter the finally
+    blocks of all current contained try-finally frames.
+
+    Returns:
+      A list of frames, all of which are try-finally frames except for the last,
+      which is the current function frame. Each of the returned try-finally
+      frames is contained within the current function frame.
+    """
+    frames = []
+    for frame in reversed(self.frames):
+      if frame.kind == Frame.TRY_FINALLY:
+        frames.append(frame)
+      if frame.kind == Frame.FUNCTION:
+        frames.append(frame)
+        return frames
+    # There are no function frames.
+    return None
+
+  def get_current_exception_handling_frames(self):
+    """Get all exception handling frames containing the current block.
+
+    Returns:
+      A list of frames, all of which are exception handling frames containing
+      the current block. Any instruction contained in a try-except frame may
+      exit to the frame's exception handling block, with the caveat that an
+      instruction cannot exit through a TRY_FINALLY frame without passing first
+      through the frame's finally block. (The instruction will exit to the
+      finally block, and the finally block in turn will exit to the exception
+      handler.) A function frame's raise block serves to catch exceptions as
+      well.
+    """
+    frames = []
+    # Traverse frames from innermost to outermost until a frame that fully
+    # catches the exception is found.
+    for frame in reversed(self.frames):
+      if frame.kind == Frame.TRY_FINALLY:
+        frames.append(frame)
+      if frame.kind == Frame.TRY_EXCEPT:
+        # A try-except frame catches any exception, even if the frame's except
+        # statements do not match the exception. In this case, the final except
+        # will reraise the exception to higher frames.
+        frames.append(frame)
+        return frames
+      if frame.kind == Frame.FUNCTION:
+        # A function frame's raise_block catches any exception that reaches it.
+        frames.append(frame)
+        return frames
+    # There is no frame to fully catch the exception.
+    return None
+
+  def visit_Module(self, node, current_block):
+    return self.visit_list(node.body, current_block)
+
+  def visit_ClassDef(self, node, current_block):
+    """Visit a ClassDef node of the AST.
+
+    Blocks:
+      current_block: The block in which the class is defined.
+    """
+    # TODO(dbieber): Make sure all statements are handled, such as base classes.
+    # http://greentreesnakes.readthedocs.io/en/latest/nodes.html#ClassDef
+    # The body is executed before the decorators.
+    current_block = self.visit_list(node.body, current_block)
+    for decorator in node.decorator_list:
+      self.add_new_instruction(current_block, decorator)
+    assert isinstance(node.name, six.string_types)
+    self.add_new_instruction(
+        current_block,
+        node,
+        accesses=instruction_module.create_writes(node.name, node),
+        source=instruction_module.CLASS)
+    return current_block
+
+  def visit_FunctionDef(self, node, current_block):
+    """Visit a FunctionDef node of the AST.
+
+    Blocks:
+      current_block: The block in which the function is defined.
+    """
+    # First defaults are computed, then decorators are run, then the functiondef
+    # is assigned to the function name.
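+    # The function's body is handled by handle_function_definition and is not
+    # connected to current_block, since the body does not execute at definition
+    # time.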
+    current_block = self.handle_argument_defaults(node.args, current_block)
+    for decorator in node.decorator_list:
+      self.add_new_instruction(current_block, decorator)
+    assert isinstance(node.name, six.string_types)
+    self.add_new_instruction(
+        current_block,
+        node,
+        accesses=instruction_module.create_writes(node.name, node),
+        source=instruction_module.FUNCTION)
+    self.handle_function_definition(node, node.name, node.args, node.body)
+    return current_block
+
+  def visit_Lambda(self, node, current_block):
+    """Visit a Lambda node of the AST.
+
+    Blocks:
+      current_block: The block in which the lambda is defined.
+    """
+    current_block = self.handle_argument_defaults(node.args, current_block)
+    self.handle_function_definition(node, 'lambda', node.args, node.body)
+    return current_block
+
+  def handle_function_definition(self, node, name, args, body):
+    """A helper fn for Lambda and FunctionDef.
+
+    Note that this function doesn't require a block as input, since it doesn't
+    modify the blocks where the function definition resides.
+
+    Blocks:
+      entry_block: The block where control flow starts when the function is
+        called.
+      return_block: The block the function returns to.
+      raise_block: The block the function raises uncaught exceptions to.
+      fn_block: The first used block of the FunctionDef.
+
+    Args:
+      node: The AST node of the function definition, either a FunctionDef or
+        Lambda node.
+      name: The function's name, a string.
+      args: The function's args, an ast.arguments node.
+      body: The function's body, a list of AST nodes.
+    """
+    return_block = self.new_block(node=node, label='<return>', prunable=False)
+    raise_block = self.new_block(node=node, label='<raise>', prunable=False)
+    self.enter_function_frame(return_block, raise_block)
+
+    entry_block = self.new_block(node=node, label='<entry:{}>'.format(name),
+                                 prunable=False)
+    fn_block = self.new_block(node=node, label='fn_block')
+    entry_block.add_exit(fn_block)
+    fn_block = self.handle_argument_writes(args, fn_block)
+    fn_block = self.visit_list(body, fn_block)
+    fn_block.add_exit(return_block)
+    self.exit_frame()
+
+  def handle_argument_defaults(self, node, current_block):
+    """Add Instructions for all of a FunctionDef's default values.
+
+    Note that these instructions are in the block containing the function, not
+    in the function definition itself.
+    """
+    for default in node.defaults:
+      self.add_new_instruction(current_block, default)
+    for default in node.kw_defaults:
+      self.add_new_instruction(current_block, default)
+    return current_block
+
+  def handle_argument_writes(self, node, current_block):
+    """Add Instructions for all of a FunctionDef's arguments.
+
+    These instructions are part of a function's body.
+    """
+    accesses = []
+    if node.args:
+      for arg in node.args:
+        accesses.extend(instruction_module.create_writes(arg, node))
+    if node.vararg:
+      accesses.extend(instruction_module.create_writes(node.vararg, node))
+    if node.kwonlyargs:
+      for arg in node.kwonlyargs:
+        accesses.extend(instruction_module.create_writes(arg, node))
+    if node.kwarg:
+      accesses.extend(instruction_module.create_writes(node.kwarg, node))
+
+    if accesses:
+      self.add_new_instruction(
+          current_block,
+          node,
+          accesses=accesses,
+          source=instruction_module.ARGS)
+    return current_block
+
+  def visit_If(self, node, current_block):
+    """Visit an If node of the AST.
+
+    Blocks:
+      current_block: This is where the if statement resides. The if statement's
+        test is added here.
+      after_block: The block to which control is passed after the if statement
+        is completed.
+ true_block: The true branch of the if statements. + false_block: The false branch of the if statements. + """ + self.add_new_instruction(current_block, node.test) + after_block = self.new_block(node=node, label='after_block') + true_block = self.new_block(node=node, label='true_block') + current_block.add_exit(true_block, branch=True) + true_block = self.visit_list(node.body, true_block) + true_block.add_exit(after_block) + if node.orelse: + false_block = self.new_block(node=node, label='false_block') + current_block.add_exit(false_block, branch=False) + false_block = self.visit_list(node.orelse, false_block) + false_block.add_exit(after_block) + else: + current_block.add_exit(after_block, branch=False) + return after_block + + def visit_While(self, node, current_block): + """Visit a While node of the AST. + + Blocks: + current_block: This is where the while statement resides. + """ + test_instruction = instruction_module.Instruction(node.test) + return self.handle_Loop(node, test_instruction, current_block) + + def visit_For(self, node, current_block): + """Visit a For node of the AST. + + Blocks: + current_block: This is where the for statement resides. + """ + self.add_new_instruction(current_block, node.iter) + # node.target is a Name, Tuple, or List node. + # We wrap it in an Instruction so it knows where its write is coming from. + target = instruction_module.Instruction( + node.target, + accesses=instruction_module.create_writes(node.target, node), + source=instruction_module.ITERATOR) + return self.handle_Loop(node, target, current_block) + + def handle_Loop(self, node, loop_instruction, current_block): + """A helper fn for For and While. + + Args: + node: The AST node representing the loop. + loop_instruction: The Instruction in the loop header, such as a test or an + assignment from an iterator. + current_block: The BasicBlock containing the loop. + + Blocks: + current_block: This is where the loop resides. + test_block: Contains the part of the loop header that is repeated. For a + While, this is the loop condition. For a For, this is assignment to the + target variable. + test_block_end: The last block in the test (often the same as test_block.) + body_block: The body of the loop. + else_block: Executed if the loop terminates naturally. + after_block: Follows the completion of the loop. + """ + # We do not add an instruction for the ast.For or ast.While node. + test_block = self.new_block(node=node, label='test_block') + current_block.add_exit(test_block) + self.add_instruction(test_block, loop_instruction) + body_block = self.new_block(node=node, label='body_block') + after_block = self.new_block(node=node, label='after_block') + + test_block.add_exit(body_block, branch=True) + # In the loop, continue goes to test_block and break goes to after_block. + self.enter_loop_frame(test_block, after_block) + body_block = self.visit_list(node.body, body_block) + body_block.add_exit(test_block) + self.exit_frame() + + # If a loop exits via its test (rather than via a break) and it has + # an orelse, then it enters the orelse. + if node.orelse: + else_block = self.new_block(node=node, label='else_block') + test_block.add_exit(else_block, branch=False) + else_block = self.visit_list(node.orelse, else_block) + else_block.add_exit(after_block) + else: + test_block.add_exit(after_block, branch=False) + + return after_block + + def visit_Try(self, node, current_block): + """Visit a Try node of the AST. + + Blocks: + current_block: This is where the try statement resides. 
+ after_block: The block to which control flows after the conclusion of the + full try statement (including e.g. the else and finally sections, if + present). + handler_blocks: A list of blocks corresponding to the except statements. + bare_handler_block: The handler block corresponding to a bare except + statement. One of handler_blocks or None. + handler_body_blocks: A list of blocks corresponding to the bodies of the + except sections. + final_block: The block corresponding to the finally section. + final_block_end: The last block corresponding to the finally section. + try_block: The block corresponding to the try section. + try_block_end: The last block corresponding to the try section. + else_block: The block corresponding to the else section. + """ + # We do not add an instruction for the ast.Try node. + after_block = self.new_block() + handler_blocks = [self.new_block() for _ in node.handlers] + handler_body_blocks = [self.new_block() for _ in node.handlers] + + # If there is a bare except clause, determine its handler block. + # Only the last except is permitted to be a bare except. + if node.handlers and node.handlers[-1].type is None: + bare_handler_block = handler_blocks[-1] + else: + bare_handler_block = None + + if node.finalbody: + final_block = self.new_block(node=node, label='final_block') + final_block_end = self.visit_list(node.finalbody, final_block) + final_block_end.add_exit(after_block) + self.enter_try_finally_frame(final_block, final_block_end) + else: + final_block = after_block + + if node.handlers: + self.enter_try_except_frame(handler_blocks[0]) + + # The exits from the try_block may happen at any point since any instruction + # can throw an exception. + try_block = self.new_block(node=node, label='try_block') + current_block.add_exit(try_block) + try_block_end = self.visit_list(node.body, try_block) + if node.orelse: # The try body can exit to the else block. + else_block = self.new_block(node=node, label='else_block') + try_block_end.add_exit(else_block) + else: # If there is no else block, the try body can exit to final/after. + try_block_end.add_exit(final_block) + + if node.handlers: + self.exit_frame() # Exit the try-except frame. + + previous_handler_block_end = None + for handler, handler_block, handler_body_block in zip(node.handlers, + handler_blocks, + handler_body_blocks): + previous_handler_block_end = self.handle_ExceptHandler( + handler, handler_block, handler_body_block, final_block, + previous_handler_block_end=previous_handler_block_end) + + if bare_handler_block is None and previous_handler_block_end is not None: + # If no exceptions match, then raise up through the frames. + # (A bare-except will always match.) + self.raise_through_frames(previous_handler_block_end, interrupting=False) + + if node.orelse: + else_block = self.visit_list(node.orelse, else_block) + else_block.add_exit(final_block) # orelse exits to final/after + + if node.finalbody: + self.exit_frame() # Exit the try-finally frame. + + return after_block + + def handle_ExceptHandler(self, handler, handler_block, handler_body_block, + final_block, previous_handler_block_end=None): + """Create the blocks appropriate for an exception handler. + + Args: + handler: The AST ExceptHandler node. + handler_block: The block corresponding the ExceptHandler header. + handler_body_block: The block corresponding to the ExceptHandler body. + final_block: Where the handler body should exit to when it executes + successfully. 
+ previous_handler_block_end: The last block corresponding to the previous + ExceptHandler header, if there is one, or None otherwise. The previous + handler's header should exit to this handler's header if the exception + doesn't match the previous handler's header. + + Returns: + The last (usually the only) block in the handler's header. + + Note that rather than having a visit_ExceptHandler function, we instead + use the following logic. This is because except statements don't follow + the visitor pattern exactly. Specifically, a handler may exit to either + its body or to the next handler, but under the visitor pattern the + handler would not know the block belonging to the next handler. + """ + if handler.type is not None: + self.add_new_instruction(handler_block, handler.type) + # An ExceptHandler header can only have a single Instruction, so there is + # only one handler_block BasicBlock. + handler_block.add_exit(handler_body_block) + + if previous_handler_block_end is not None: + previous_handler_block_end.add_exit(handler_block) + previous_handler_block_end = handler_block + + if handler.name is not None: + # handler.name is a Name, Tuple, or List AST node. + self.add_new_instruction( + handler_body_block, + handler.name, + accesses=instruction_module.create_writes(handler.name, handler), + source=instruction_module.EXCEPTION) + handler_body_block = self.visit_list(handler.body, handler_body_block) + handler_body_block.add_exit(final_block) # handler exits to final/after + return previous_handler_block_end + + def visit_Return(self, node, current_block): + """Visit a Return node of the AST. + + Blocks: + current_block: This is where the return statement resides. + return_block: The containing function's return block. All successful exits + from the function lead here. + + Raises: + RuntimeError: If a return AST node is visited while not in a function + frame. + """ + # The Return statement is an Instruction. Don't visit the node's children. + frames = self.get_current_function_frame() + if frames is None: + raise RuntimeError('return occurs outside of a function frame.') + try_finally_frames = frames[:-1] + function_frame = frames[-1] + + return_block = function_frame.blocks['return_block'] + return self.handle_ExitStatement(node, + return_block, + try_finally_frames, + current_block) + + def visit_Yield(self, node, current_block): + """Visit a Yield node of the AST. + + The current implementation of yields allows control to flow directly through + a yield statement. TODO(dbieber): Introduce a node in between + yielding and resuming execution. + TODO(dbieber): Yield nodes aren't even visited since they are contained in + Expr nodes. Determine if Yield can occur outside of an Expr. Check for + Yield when visiting Expr. + """ + logging.warn('yield visited: %s', ast.dump(node)) + # The Yield statement is an Instruction. Don't visit children. + return current_block + + def visit_Continue(self, node, current_block): + """Visit a Continue node of the AST. + + Blocks: + current_block: This is where the continue statement resides. + continue_block: The block of the containing loop's header. For a For, + this is the target variable assignment. For a While, this is the loop + condition. + + Raises: + RuntimeError: If a continue AST node is visited while not in a loop frame. 
+ """ + frames = self.get_current_loop_frame() + if frames is None: + raise RuntimeError('continue occurs outside of a loop frame.') + + try_finally_frames = frames[:-1] + loop_frame = frames[-1] + + continue_block = loop_frame.blocks['continue_block'] + return self.handle_ExitStatement(node, + continue_block, + try_finally_frames, + current_block) + + def visit_Break(self, node, current_block): + """Visit a Break node of the AST. + + Blocks: + current_block: This is where the break statement resides. + break_block: The block that the containing loop exits to. + + Raises: + RuntimeError: If a break AST node is visited while not in a loop frame. + """ + frames = self.get_current_loop_frame() + if frames is None: + raise RuntimeError('break occurs outside of a loop frame.') + + try_finally_frames = frames[:-1] + loop_frame = frames[-1] + + break_block = loop_frame.blocks['break_block'] + return self.handle_ExitStatement(node, + break_block, + try_finally_frames, + current_block) + + def visit_Raise(self, node, current_block): + """Visit a Raise node of the AST. + + Blocks: + current_block: This is where the raise statement resides. + after_block: An unreachable block for code that follows the raise + statement. + """ + del current_block + # The Raise statement is an Instruction. Don't visit children. + + # Note there is no exit to the after_block. It is unreachable. + after_block = self.new_block(node=node, label='after_block') + return after_block + + def handle_ExitStatement(self, node, next_block, try_finally_frames, + current_block): + """A helper fn for Return, Continue, and Break. + + An exit statement is a statement such as return, continue, break, or raise. + Such a statement causes control to leave through a frame's exit. Any + instructions immediately following an exit statement will be unreachable. + + Args: + node: The AST node of the exit statement. + next_block: The block the exit statement exits to. + try_finally_frames: A possibly empty list of try-finally frames whose + finally blocks must be executed before control can pass to next_block. + current_block: The block the exit statement resides in. + + Blocks: + current_block: This is where the exit statement resides. + next_block: The block the exit statement exits to (after first passing + through all the finally blocks.) + final_block: The start of a finally section that control must pass + through on the way to next_block. + final_block_end: The end of a finally section that control must pass + through on the way to next_block. + after_block: An unreachable block for code that follows the raise + statement. + """ + for try_finally_frame in try_finally_frames: + final_block = try_finally_frame.blocks['final_block'] + current_block.add_exit(final_block) + current_block = try_finally_frame.blocks['final_block_end'] + + current_block.add_exit(next_block) + + # Note there is no exit to the after_block. It is unreachable. + after_block = self.new_block(node=node, label='after_block') + return after_block diff --git a/python_graphs/control_flow_graphviz.py b/python_graphs/control_flow_graphviz.py new file mode 100644 index 0000000..fff84bb --- /dev/null +++ b/python_graphs/control_flow_graphviz.py @@ -0,0 +1,113 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Graphviz render for control flow graphs."""
+
+from absl import logging  # pylint: disable=unused-import
+import astunparse
+import gast as ast
+import pygraphviz
+
+LEFT_ALIGN = '\l'  # pylint: disable=anomalous-backslash-in-string
+
+
+def render(graph, include_src=None, path='/tmp/graph.png'):
+  """Renders a control flow graph to an image file using graphviz dot."""
+  g = to_graphviz(graph, include_src=include_src)
+  g.draw(path, prog='dot')
+
+
+def trim(line, max_length=30):
+  if len(line) <= max_length:
+    return line
+  return line[:max_length - 3] + '...'
+
+
+def unparse(node):
+  source = astunparse.unparse(node)
+  trimmed_source = '\n'.join(trim(line) for line in source.split('\n'))
+  return (
+      trimmed_source.strip()
+      .rstrip(' \n')
+      .lstrip(' \n')
+      .replace('\n', LEFT_ALIGN)
+  )
+
+
+def write_as_str(write):
+  if isinstance(write, ast.AST):
+    return unparse(write)
+  else:
+    return write
+
+
+def get_label_for_instruction(instruction):
+  if instruction.source is not None:
+    line = ', '.join(write for write in instruction.get_write_names())
+    line += ' <- ' + instruction.source
+    return line
+  else:
+    return unparse(instruction.node)
+
+
+def get_label(block):
+  """Gets the source code for a control flow basic block."""
+  lines = []
+  for control_flow_node in block.control_flow_nodes:
+    instruction = control_flow_node.instruction
+    line = get_label_for_instruction(instruction)
+    if line.strip():
+      lines.append(line)
+
+  return LEFT_ALIGN.join(lines) + LEFT_ALIGN
+
+
+def to_graphviz(graph, include_src=None):
+  """Converts a control flow graph to a pygraphviz graph."""
+  g = pygraphviz.AGraph(strict=False, directed=True)
+  for block in graph.blocks:
+    node_attrs = {}
+    label = get_label(block)
+    # We only show the special angle-bracketed block labels, such as
+    # <entry:...>, <exit>, <return>, and <raise>.
+    if block.label is not None and block.label.startswith('<'):
+      node_attrs['style'] = 'bold'
+      if not label.rstrip(LEFT_ALIGN):
+        label = block.label + LEFT_ALIGN
+      else:
+        label = block.label + LEFT_ALIGN + label
+    node_attrs['label'] = label
+    node_attrs['fontname'] = 'Courier New'
+    node_attrs['fontsize'] = 10.0
+
+    node_id = id(block)
+    g.add_node(node_id, **node_attrs)
+    for next_node in block.next:
+      next_node_id = id(next_node)
+      if next_node in block.exits_from_middle:
+        edge_attrs = {}
+        edge_attrs['style'] = 'dashed'
+        g.add_edge(node_id, next_node_id, **edge_attrs)
+      if next_node in block.exits_from_end:
+        edge_attrs = {}
+        edge_attrs['style'] = 'solid'
+        g.add_edge(node_id, next_node_id, **edge_attrs)
+
+  if include_src is not None:
+    node_id = id(include_src)
+    node_attrs['label'] = include_src.replace('\n', LEFT_ALIGN)
+    node_attrs['fontname'] = 'Courier New'
+    node_attrs['fontsize'] = 10.0
+    node_attrs['shape'] = 'box'
+    g.add_node(node_id, **node_attrs)
+
+  return g
diff --git a/python_graphs/control_flow_graphviz_test.py b/python_graphs/control_flow_graphviz_test.py
new file mode 100644
index 0000000..6741c19
--- /dev/null
+++ b/python_graphs/control_flow_graphviz_test.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for control_flow_graphviz.py.""" + +import inspect + +from absl.testing import absltest +from python_graphs import control_flow +from python_graphs import control_flow_graphviz +from python_graphs import control_flow_test_components as tc + + +class ControlFlowGraphvizTest(absltest.TestCase): + + def test_to_graphviz_for_all_test_components(self): + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + graph = control_flow.get_control_flow_graph(fn) + control_flow_graphviz.to_graphviz(graph) + + def test_get_label_multi_op_expression(self): + graph = control_flow.get_control_flow_graph(tc.multi_op_expression) + block = graph.get_block_by_source('1 + 2 * 3') + self.assertEqual( + control_flow_graphviz.get_label(block).strip(), + 'return (1 + (2 * 3))\\l') + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/control_flow_test.py b/python_graphs/control_flow_test.py new file mode 100644 index 0000000..5156d01 --- /dev/null +++ b/python_graphs/control_flow_test.py @@ -0,0 +1,308 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for control_flow.py.""" + +import inspect + +from absl import logging # pylint: disable=unused-import +from absl.testing import absltest +import gast as ast +from python_graphs import control_flow +from python_graphs import control_flow_test_components as tc +from python_graphs import instruction as instruction_module +from python_graphs import program_utils +import six + + +class ControlFlowTest(absltest.TestCase): + + def get_block(self, graph, selector): + if isinstance(selector, control_flow.BasicBlock): + return selector + elif isinstance(selector, six.string_types): + return graph.get_block_by_source(selector) + + def assertSameBlock(self, graph, selector1, selector2): + block1 = self.get_block(graph, selector1) + block2 = self.get_block(graph, selector2) + self.assertEqual(block1, block2) + + def assertExitsTo(self, graph, selector1, selector2): + block1 = self.get_block(graph, selector1) + block2 = self.get_block(graph, selector2) + self.assertTrue(block1.exits_to(block2)) + + def assertNotExitsTo(self, graph, selector1, selector2): + block1 = self.get_block(graph, selector1) + block2 = self.get_block(graph, selector2) + self.assertFalse(block1.exits_to(block2)) + + def assertRaisesTo(self, graph, selector1, selector2): + block1 = self.get_block(graph, selector1) + block2 = self.get_block(graph, selector2) + self.assertTrue(block1.raises_to(block2)) + + def assertNotRaisesTo(self, graph, selector1, selector2): + block1 = self.get_block(graph, selector1) + block2 = self.get_block(graph, selector2) + self.assertFalse(block1.raises_to(block2)) + + def test_control_flow_straight_line_code(self): + graph = control_flow.get_control_flow_graph(tc.straight_line_code) + self.assertSameBlock(graph, 'x = 1', 'y = x + 2') + self.assertSameBlock(graph, 'x = 1', 'z = y * 3') + self.assertSameBlock(graph, 'x = 1', 'return z') + + def test_control_flow_simple_if_statement(self): + graph = control_flow.get_control_flow_graph(tc.simple_if_statement) + x1_block = 'x = 1' + y2_block = 'y = 2' + xy_block = 'x > y' + y3_block = 'y = 3' + return_block = 'return y' + self.assertSameBlock(graph, x1_block, y2_block) + self.assertSameBlock(graph, x1_block, xy_block) + self.assertExitsTo(graph, xy_block, y3_block) + self.assertExitsTo(graph, xy_block, return_block) + self.assertExitsTo(graph, y3_block, return_block) + self.assertNotExitsTo(graph, y3_block, x1_block) + self.assertNotExitsTo(graph, return_block, x1_block) + self.assertNotExitsTo(graph, return_block, y3_block) + + def test_control_flow_simple_for_loop(self): + graph = control_flow.get_control_flow_graph(tc.simple_for_loop) + x1_block = 'x = 1' + iter_block = 'range' + target_block = 'y' + body_block = 'y + 3' + return_block = 'return z' + self.assertSameBlock(graph, x1_block, iter_block) + self.assertExitsTo(graph, iter_block, target_block) + self.assertExitsTo(graph, target_block, body_block) + self.assertNotExitsTo(graph, body_block, return_block) + self.assertExitsTo(graph, target_block, return_block) + + def test_control_flow_simple_while_loop(self): + graph = control_flow.get_control_flow_graph(tc.simple_while_loop) + x1_block = 'x = 1' + test_block = 'x < 2' + body_block = 'x += 3' + return_block = 'return x' + + self.assertExitsTo(graph, x1_block, test_block) + self.assertExitsTo(graph, test_block, body_block) + self.assertExitsTo(graph, body_block, test_block) + self.assertNotExitsTo(graph, body_block, return_block) + self.assertExitsTo(graph, test_block, return_block) + + def test_control_flow_break_in_while_loop(self): + 
graph = control_flow.get_control_flow_graph(tc.break_in_while_loop) + # This is just one block since there's no edge from the while loop end + # back to the while loop test, and so the 'x = 1' line can be merged with + # the test. + x1_and_test_block = 'x < 2' + body_block = 'x += 3' + return_block = 'return x' + + self.assertExitsTo(graph, x1_and_test_block, body_block) + self.assertExitsTo(graph, body_block, return_block) + self.assertNotExitsTo(graph, body_block, x1_and_test_block) + self.assertExitsTo(graph, x1_and_test_block, return_block) + + def test_control_flow_nested_while_loops(self): + graph = control_flow.get_control_flow_graph(tc.nested_while_loops) + x1_block = 'x = 1' + outer_test_block = 'x < 2' + y3_block = 'y = 3' + inner_test_block = 'y < 4' + y5_block = 'y += 5' + x6_block = 'x += 6' + return_block = 'return x' + + self.assertExitsTo(graph, x1_block, outer_test_block) + self.assertExitsTo(graph, outer_test_block, y3_block) + self.assertExitsTo(graph, outer_test_block, return_block) + self.assertExitsTo(graph, y3_block, inner_test_block) + self.assertExitsTo(graph, inner_test_block, y5_block) + self.assertExitsTo(graph, inner_test_block, x6_block) + self.assertExitsTo(graph, y5_block, inner_test_block) + self.assertExitsTo(graph, x6_block, outer_test_block) + + def test_control_flow_exception_handling(self): + graph = control_flow.get_control_flow_graph(tc.exception_handling) + self.assertSameBlock(graph, 'before_stmt0', 'before_stmt1') + self.assertExitsTo(graph, 'before_stmt1', 'try_block') + self.assertNotExitsTo(graph, 'before_stmt0', 'except_block1') + self.assertNotExitsTo(graph, 'before_stmt1', 'final_block_stmt0') + self.assertRaisesTo(graph, 'try_block', 'error_type') + self.assertRaisesTo(graph, 'error_type', 'except_block2_stmt0') + self.assertExitsTo(graph, 'except_block1', 'after_stmt0') + + self.assertRaisesTo(graph, 'after_stmt0', 'except_block2_stmt0') + self.assertNotRaisesTo(graph, 'try_block', 'except_block2_stmt0') + + def test_control_flow_try_with_loop(self): + graph = control_flow.get_control_flow_graph(tc.try_with_loop) + self.assertSameBlock(graph, 'for_body0', 'for_body1') + self.assertSameBlock(graph, 'except_body0', 'except_body1') + + self.assertExitsTo(graph, 'before_stmt0', 'iterator') + self.assertExitsTo(graph, 'iterator', 'target') + self.assertExitsTo(graph, 'target', 'for_body0') + self.assertExitsTo(graph, 'for_body1', 'target') + self.assertExitsTo(graph, 'target', 'after_stmt0') + + self.assertRaisesTo(graph, 'iterator', 'except_body0') + self.assertRaisesTo(graph, 'target', 'except_body0') + self.assertRaisesTo(graph, 'for_body1', 'except_body0') + + def test_control_flow_break_in_finally(self): + graph = control_flow.get_control_flow_graph(tc.break_in_finally) + + # The exception handlers are tried sequentially until one matches. + self.assertRaisesTo(graph, 'try0', 'Exception0') + self.assertExitsTo(graph, 'Exception0', 'Exception1') + self.assertExitsTo(graph, 'Exception1', 'finally_stmt0') + # If the finally block were to finish and the exception hadn't matched, then + # the exception would exit to the FunctionDef's raise_block. However, the + # break statement prevents the finally from finishing and so the exception + # is lost when the break statement is reached. + # TODO(dbieber): Add the following assert. 
+ # raise_block = graph.get_raise_block('break_in_finally') + # self.assertNotExitsFromEndTo(graph, 'finally_stmt1', raise_block) + # The finally block can of course still raise an exception of its own, so + # the following is still true: + # TODO(dbieber): Add the following assert. + # self.assertRaisesTo(graph, 'finally_stmt1', raise_block) + + # An exception in the except handlers could flow to the finally block. + self.assertRaisesTo(graph, 'Exception0', 'finally_stmt0') + self.assertRaisesTo(graph, 'exception0_stmt0', 'finally_stmt0') + self.assertRaisesTo(graph, 'Exception1', 'finally_stmt0') + + # The break statement flows to after0, rather than to the loop header. + self.assertNotExitsTo(graph, 'finally_stmt1', 'target0') + self.assertExitsTo(graph, 'finally_stmt1', 'after0') + + def test_control_flow_for_loop_with_else(self): + graph = control_flow.get_control_flow_graph(tc.for_with_else) + self.assertExitsTo(graph, 'target', 'for_stmt0') + self.assertSameBlock(graph, 'for_stmt0', 'condition') + + # If break is encountered, then the else clause is skipped. + self.assertExitsTo(graph, 'condition', 'after_stmt0') + + # The else clause executes if the loop completes without reaching the break. + self.assertExitsTo(graph, 'target', 'else_stmt0') + self.assertNotExitsTo(graph, 'target', 'after_stmt0') + + def test_control_flow_lambda(self): + graph = control_flow.get_control_flow_graph(tc.create_lambda) + self.assertNotExitsTo(graph, 'before_stmt0', 'args') + self.assertNotExitsTo(graph, 'before_stmt0', 'output') + + def test_control_flow_generator(self): + graph = control_flow.get_control_flow_graph(tc.generator) + self.assertExitsTo(graph, 'target', 'yield_statement') + self.assertSameBlock(graph, 'yield_statement', 'after_stmt0') + + def test_control_flow_inner_fn_while_loop(self): + graph = control_flow.get_control_flow_graph(tc.fn_with_inner_fn) + self.assertExitsTo(graph, 'x = 10', 'True') + self.assertExitsTo(graph, 'True', 'True') + self.assertSameBlock(graph, 'True', 'True') + + def test_control_flow_example_class(self): + graph = control_flow.get_control_flow_graph(tc.ExampleClass) + self.assertSameBlock(graph, 'method_stmt0', 'method_stmt1') + + def test_control_flow_return_outside_function(self): + with self.assertRaises(RuntimeError) as error: + control_flow.get_control_flow_graph('return x') + self.assertContainsSubsequence(str(error.exception), + 'outside of a function frame') + + def test_control_flow_continue_outside_loop(self): + control_flow.get_control_flow_graph('for i in j: continue') + with self.assertRaises(RuntimeError) as error: + control_flow.get_control_flow_graph('if x: continue') + self.assertContainsSubsequence(str(error.exception), + 'outside of a loop frame') + + def test_control_flow_break_outside_loop(self): + control_flow.get_control_flow_graph('for i in j: break') + with self.assertRaises(RuntimeError) as error: + control_flow.get_control_flow_graph('if x: break') + self.assertContainsSubsequence(str(error.exception), + 'outside of a loop frame') + + def test_control_flow_for_all_test_components(self): + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + control_flow.get_control_flow_graph(fn) + + def test_control_flow_for_all_test_components_ast_to_instruction(self): + """All INSTRUCTION_AST_NODES in an AST correspond to one Instruction. + + This assumes that a simple statement can't contain another simple statement. + However, Yield nodes are the exception to this as they are contained within + Expr nodes. 
+ + We omit Yield nodes from INSTRUCTION_AST_NODES despite them being listed + as simple statements in the Python docs. + """ + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + node = program_utils.program_to_ast(fn) + graph = control_flow.get_control_flow_graph(node) + for n in ast.walk(node): + if not isinstance(n, instruction_module.INSTRUCTION_AST_NODES): + continue + control_flow_nodes = list(graph.get_control_flow_nodes_by_ast_node(n)) + self.assertLen(control_flow_nodes, 1, ast.dump(n)) + + def test_control_flow_reads_and_writes_appear_once(self): + """Asserts each read and write in an Instruction is unique in the graph. + + Note that in the case of AugAssign, the same Name AST node is used once as + a read and once as a write. + """ + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + reads = set() + writes = set() + node = program_utils.program_to_ast(fn) + graph = control_flow.get_control_flow_graph(node) + for instruction in graph.get_instructions(): + # Check that all reads are unique. + for read in instruction.get_reads(): + if isinstance(read, tuple): + read = read[1] + self.assertIsInstance(read, ast.Name, 'Unexpected read type.') + self.assertNotIn(read, reads, + instruction_module.access_name(read)) + reads.add(read) + + # Check that all writes are unique. + for write in instruction.get_writes(): + if isinstance(write, tuple): + write = write[1] + if isinstance(write, six.string_types): + continue + self.assertIsInstance(write, ast.Name) + self.assertNotIn(write, writes, + instruction_module.access_name(write)) + writes.add(write) + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/control_flow_test_components.py b/python_graphs/control_flow_test_components.py new file mode 100644 index 0000000..ba55b54 --- /dev/null +++ b/python_graphs/control_flow_test_components.py @@ -0,0 +1,322 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test components for testing control flow. + +Many of these components would produce RuntimeErrors if run. Their purpose is +for the testing of the control_flow module. 
+""" + + +# pylint: disable=missing-docstring +# pylint: disable=pointless-statement,undefined-variable +# pylint: disable=unused-variable,unused-argument +# pylint: disable=bare-except,lost-exception,unreachable +# pylint: disable=keyword-arg-before-vararg +def straight_line_code(): + x = 1 + y = x + 2 + z = y * 3 + return z + + +def simple_if_statement(): + x = 1 + y = 2 + if x > y: + y = 3 + return y + + +def simple_for_loop(): + x = 1 + for y in range(x + 2): + z = y + 3 + return z + + +def tuple_in_for_loop(): + a, b = 0, 1 + for a, b in [(1, 2), (2, 3)]: + if a > b: + break + return b - a + + +def simple_while_loop(): + x = 1 + while x < 2: + x += 3 + return x + + +def break_in_while_loop(): + x = 1 + while x < 2: + x += 3 + break + return x + + +def nested_while_loops(): + x = 1 + while x < 2: + y = 3 + while y < 4: + y += 5 + x += 6 + return x + + +def multiple_excepts(): + try: + x = 1 + except ValueError: + x = 2 + x = 3 + except RuntimeError: + x = 4 + except: + x = 5 + return x + + +def try_finally(): + header0 + try: + try0 + try1 + except Exception0 as value0: + exception0_stmt0 + finally: + finally_stmt0 + finally_stmt1 + after0 + + +def exception_handling(): + try: + before_stmt0 + before_stmt1 + try: + try_block + except error_type as value: + except_block1 + after_stmt0 + after_stmt1 + except: + except_block2_stmt0 + except_block2_stmt1 + finally: + final_block_stmt0 + final_block_stmt1 + end_block_stmt0 + end_block_stmt1 + + +def fn_with_args(a, b=10, *varargs, **kwargs): + body_stmt0 + body_stmt1 + return + + +def fn1(a, b): + return a + b + + +def fn2(a, b): + c = a + if a > b: + c -= b + return c + + +def fn3(a, b): + c = a + if a > b: + c -= b + c += 1 + c += 2 + c += 3 + else: + c += b + return c + + +def fn4(i): + count = 0 + for i in range(i): + count += 1 + return count + + +def fn5(i): + count = 0 + for _ in range(i): + if count > 5: + break + count += 1 + return count + + +def fn6(): + count = 0 + while count < 10: + count += 1 + return count + + +def fn7(): + try: + raise ValueError('This will be caught.') + except ValueError as e: + del e + return + + +def try_with_else(): + try: + raise ValueError('This will be caught.') + except ValueError as e: + del e + else: + return 1 + return 2 + + +def for_with_else(): + for target in iterator: + for_stmt0 + if condition: + break + for_stmt1 + else: + else_stmt0 + else_stmt1 + after_stmt0 + + +def fn8(a): + a += 1 + + +def nested_loops(a): + """A test function illustrating nested loops.""" + for i in range(a): + while True: + break + unreachable = 10 + for j in range(i): + for k in range(j): + if j * k > 10: + continue + unreachable = 5 + if i + j == 10: + return True + return False + + +def try_with_loop(): + before_stmt0 + try: + for target in iterator: + for_body0 + for_body1 + except: + except_body0 + except_body1 + after_stmt0 + + +def break_in_finally(): + header0 + for target0 in iter0: + try: + try0 + try1 + except Exception0 as value0: + exception0_stmt0 + except Exception1 as value1: + exception1_stmt0 + exception1_stmt1 + finally: + finally_stmt0 + finally_stmt1 + # This breaks out of the for-loop. + break + after0 + + +def break_in_try(): + count = 0 + for _ in range(10): + try: + count += 1 + # This breaks out of the for-loop through the finally block. 
+ break + except ValueError: + pass + finally: + count += 2 + return count + + +def nested_try_excepts(): + try: + try: + x = 0 + x += 1 + try: + x = 2 + 2 + except ValueError(1+1) as e: + x = 3 - 3 + finally: + x = 4 + except RuntimeError: + x = 5 * 5 + finally: + x = 6 ** 6 + except: + x = 7 / 7 + return x + + +def multi_op_expression(): + return 1 + 2 * 3 + + +def create_lambda(): + before_stmt0 + fn = lambda args: output + after_stmt0 + + +def generator(): + for target in iterator: + yield yield_statement + after_stmt0 + + +def fn_with_inner_fn(): + def inner_fn(): + x = 10 + while True: + pass + + +class ExampleClass(object): + + def method0(self, arg): + method_stmt0 + method_stmt1 diff --git a/python_graphs/control_flow_visualizer.py b/python_graphs/control_flow_visualizer.py new file mode 100644 index 0000000..185c23b --- /dev/null +++ b/python_graphs/control_flow_visualizer.py @@ -0,0 +1,74 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Create control flow graph visualizations for the test components. + + +Usage: +python -m python_graphs.control_flow_visualizer +""" + +import inspect +import os + +from absl import app +from absl import flags +from absl import logging # pylint: disable=unused-import + +from python_graphs import control_flow +from python_graphs import control_flow_graphviz +from python_graphs import control_flow_test_components as tc +from python_graphs import program_utils + +FLAGS = flags.FLAGS + + +def render_functions(functions): + for name, function in functions: + logging.info(name) + graph = control_flow.get_control_flow_graph(function) + path = '/tmp/control_flow_graphs/{}.png'.format(name) + source = program_utils.getsource(function) # pylint: disable=protected-access + control_flow_graphviz.render(graph, include_src=source, path=path) + + +def render_filepaths(filepaths): + for filepath in filepaths: + filename = os.path.basename(filepath).split('.')[0] + logging.info(filename) + with open(filepath, 'r') as f: + source = f.read() + graph = control_flow.get_control_flow_graph(source) + path = '/tmp/control_flow_graphs/{}.png'.format(filename) + control_flow_graphviz.render(graph, include_src=source, path=path) + + +def main(argv): + del argv # Unused. + + functions = [ + (name, fn) + for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction) + ] + render_functions(functions) + + # Add filepaths here to visualize their functions. + filepaths = [ + __file__, + ] + render_filepaths(filepaths) + + +if __name__ == '__main__': + app.run(main) diff --git a/python_graphs/cyclomatic_complexity.py b/python_graphs/cyclomatic_complexity.py new file mode 100644 index 0000000..4ac9812 --- /dev/null +++ b/python_graphs/cyclomatic_complexity.py @@ -0,0 +1,49 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Computes the cyclomatic complexity of a program or control flow graph.""" + + +def cyclomatic_complexity(control_flow_graph): + """Computes the cyclomatic complexity of a function from its cfg.""" + enter_block = next(control_flow_graph.get_enter_blocks()) + + new_blocks = [] + seen_block_ids = set() + new_blocks.append(enter_block) + seen_block_ids.add(id(enter_block)) + num_edges = 0 + + while new_blocks: + block = new_blocks.pop() + for next_block in block.exits_from_end: + num_edges += 1 + if id(next_block) not in seen_block_ids: + new_blocks.append(next_block) + seen_block_ids.add(id(next_block)) + num_nodes = len(seen_block_ids) + + p = 1 # num_connected_components + e = num_edges + n = num_nodes + return e - n + 2 * p + + +def cyclomatic_complexity2(control_flow_graph): + """Computes the cyclomatic complexity of a program from its cfg.""" + # Assumes a single connected component. + p = 1 # num_connected_components + e = sum(len(block.exits_from_end) for block in control_flow_graph.blocks) + n = len(control_flow_graph.blocks) + return e - n + 2 * p diff --git a/python_graphs/cyclomatic_complexity_test.py b/python_graphs/cyclomatic_complexity_test.py new file mode 100644 index 0000000..9aee1bd --- /dev/null +++ b/python_graphs/cyclomatic_complexity_test.py @@ -0,0 +1,38 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for cyclomatic_complexity.py.""" + +from absl.testing import absltest +from absl.testing import parameterized + +from python_graphs import control_flow +from python_graphs import control_flow_test_components as tc +from python_graphs import cyclomatic_complexity + + +class CyclomaticComplexityTest(parameterized.TestCase): + + @parameterized.parameters( + (tc.straight_line_code, 1), + (tc.simple_if_statement, 2), + (tc.simple_for_loop, 2), + ) + def test_cyclomatic_complexity(self, component, target_value): + graph = control_flow.get_control_flow_graph(component) + value = cyclomatic_complexity.cyclomatic_complexity(graph) + self.assertEqual(value, target_value) + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/data_flow.py b/python_graphs/data_flow.py new file mode 100644 index 0000000..a1afa4e --- /dev/null +++ b/python_graphs/data_flow.py @@ -0,0 +1,233 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data flow analysis of Python programs.""" + +import collections + +from absl import logging # pylint: disable=unused-import +import gast as ast + +from python_graphs import control_flow +from python_graphs import instruction as instruction_module + + +READ = instruction_module.READ +WRITE = instruction_module.WRITE + + +class Analysis(object): + """Base class for a data flow analysis. + + Attributes: + label: The name of the analysis. + forward: (bool) True for forward analyses, False for backward analyses. + in_label: The name of the analysis, suffixed with _in. + out_label: The name of the analysis, suffixed with _out. + before_label: Either the in_label or out_label depending on the direction of + the analysis. Marks the before_value on a node during an analysis. + after_label: Either the in_label or out_label depending on the direction of + the analysis. Marks the after_value on a node during an analysis. + """ + + def __init__(self, label, forward): + self.label = label + self.forward = forward + + self.in_label = label + '_in' + self.out_label = label + '_out' + + self.before_label = self.in_label if forward else self.out_label + self.after_label = self.out_label if forward else self.in_label + + def aggregate_previous_after_values(self, previous_after_values): + """Computes the before value for a node from the previous after values. + + This is the 'meet' or 'join' function of the analysis. + TODO(dbieber): Update terminology to match standard textbook notation. + + Args: + previous_after_values: The after values of all before nodes. + Returns: + The before value for the current node. + """ + raise NotImplementedError + + def compute_after_value(self, node, before_value): + """Computes the after value for a node from the node and the before value. + + This is the 'transfer' function of the analysis. + TODO(dbieber): Update terminology to match standard textbook notation. + + Args: + node: The node or block for which to compute the after value. + before_value: The before value of the node. + Returns: + The computed after value for the node. + """ + raise NotImplementedError + + def visit(self, node): + """Visit the nodes of the control flow graph, performing the analysis. + + Terminology: + in_value: The value of the analysis at the start of a node. + out_value: The value of the analysis at the end of a node. + before_value: in_value in a forward analysis; out_value in a backward + analysis. + after_value: out_value in a forward analysis; in_value in a backward + analysis. + + Args: + node: A graph element that supports the .next / .prev API, such as a + ControlFlowNode from a ControlFlowGraph or a BasicBlock from a + ControlFlowGraph. 
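+
+    The analysis iterates to a fixed point: whenever a node's after value
+    changes, the nodes that follow it (in the direction of the analysis) are
+    added back to the worklist and revisited.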
+ """ + to_visit = collections.deque([node]) + while to_visit: + node = to_visit.popleft() + + before_nodes = node.prev if self.forward else node.next + after_nodes = node.next if self.forward else node.prev + previous_after_values = [ + before_node.get_label(self.after_label) + for before_node in before_nodes + if before_node.has_label(self.after_label)] + + if node.has_label(self.after_label): + initial_after_value_hash = hash(node.get_label(self.after_label)) + else: + initial_after_value_hash = None + before_value = self.aggregate_previous_after_values(previous_after_values) + node.set_label(self.before_label, before_value) + after_value = self.compute_after_value(node, before_value) + node.set_label(self.after_label, after_value) + if hash(after_value) != initial_after_value_hash: + for after_node in after_nodes: + to_visit.append(after_node) + + +def get_while_loop_variables(node, graph=None): + """Gets the set of loop variables used for while loop rewriting. + + This is the set of variables used for rewriting a while loop into its + functional form. + + Args: + node: An ast.While AST node. + graph: (Optional) The ControlFlowGraph of the function or program containing + the while loop. If not present, the control flow graph for the while loop + will be computed. + Returns: + The set of variable identifiers that are live at the start of the loop's + test and at the start of the loop's body. + """ + graph = graph or control_flow.get_control_flow_graph(node) + test_block = graph.get_block_by_ast_node(node.test) + + for block in graph.get_exit_blocks(): + analysis = LivenessAnalysis() + analysis.visit(block) + # TODO(dbieber): Move this logic into the Analysis class to avoid the use of + # magic strings. + live_variables = test_block.get_label('liveness_in') + written_variables = { + write.id + for write in instruction_module.get_writes_from_ast_node(node) + if isinstance(write, ast.Name) + } + return live_variables & written_variables + + +class LivenessAnalysis(Analysis): + """Liveness analysis by basic block. + + In the liveness analysis, the in_value of a block is the set of variables + that are live at the start of a block. "Live" means that the current value of + the variable may be used later in the execution. The out_value of a block is + the set of variable identifiers that are live at the end of the block. + + Since this is a backward analysis, the "before_value" is the out_value and the + "after_value" is the in_value. + """ + + def __init__(self): + super(LivenessAnalysis, self).__init__(label='liveness', forward=False) + + def aggregate_previous_after_values(self, previous_after_values): + """Computes the out_value (before_value) of a block. + + Args: + previous_after_values: A list of the sets of live variables at the start + of each of the blocks following the current block. + Returns: + The set of live variables at the end of the current block. This is the + union of live variable sets at the start of each subsequent block. + """ + result = set() + for before_value in previous_after_values: + result |= before_value + return frozenset(result) + + def compute_after_value(self, block, before_value): + """Computes the liveness analysis gen and kill sets for a basic block. + + The gen set is the set of variables read by the block before they are + written to. + The kill set is the set of variables written to by the basic block. + + Args: + block: The BasicBlock to analyze. + before_value: The out_value for block (the set of variables live at the + end of the block.) 
+ Returns: + The in_value for block (the set of variables live at the start of the + block). + """ + gen = set() + kill = set() + for control_flow_node in block.control_flow_nodes: + instruction = control_flow_node.instruction + for read in instruction.get_read_names(): + if read not in kill: + gen.add(read) + kill.update(instruction.get_write_names()) + return frozenset((before_value - kill) | gen) + + +class FrozenDict(dict): + + def __hash__(self): + return hash(tuple(sorted(self.items()))) + + +class LastAccessAnalysis(Analysis): + """Computes for each variable its possible last reads and last writes.""" + + def __init__(self): + super(LastAccessAnalysis, self).__init__(label='last_access', forward=True) + + def aggregate_previous_after_values(self, previous_after_values): + result = collections.defaultdict(frozenset) + for previous_after_value in previous_after_values: + for key, value in previous_after_value.items(): + result[key] |= value + return FrozenDict(result) + + def compute_after_value(self, node, before_value): + result = before_value.copy() + for access in node.instruction.accesses: + kind_and_name = instruction_module.access_kind_and_name(access) + result[kind_and_name] = frozenset([access]) + return FrozenDict(result) diff --git a/python_graphs/data_flow_test.py b/python_graphs/data_flow_test.py new file mode 100644 index 0000000..a3849ee --- /dev/null +++ b/python_graphs/data_flow_test.py @@ -0,0 +1,138 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for data_flow.py.""" + +import inspect + +from absl import logging # pylint: disable=unused-import +from absl.testing import absltest +import gast as ast + +from python_graphs import control_flow +from python_graphs import control_flow_test_components as tc +from python_graphs import data_flow +from python_graphs import program_utils + + +class DataFlowTest(absltest.TestCase): + + def test_get_while_loop_variables(self): + root = program_utils.program_to_ast(tc.nested_while_loops) + graph = control_flow.get_control_flow_graph(root) + + # node = graph.get_ast_node_by_type(ast.While) + # TODO(dbieber): data_flow.get_while_loop_variables(node, graph) + + analysis = data_flow.LivenessAnalysis() + for block in graph.get_exit_blocks(): + analysis.visit(block) + + for block in graph.get_blocks_by_ast_node_type_and_label( + ast.While, 'test_block'): + logging.info(block.get_label('liveness_out')) + + def test_liveness_simple_while_loop(self): + def simple_while_loop(): + a = 2 + b = 10 + x = 1 + while x < b: + tmp = x + a + x = tmp + 1 + + program_node = program_utils.program_to_ast(simple_while_loop) + graph = control_flow.get_control_flow_graph(program_node) + + # TODO(dbieber): Use unified query system. 
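+    # Until then, locate the ast.While node by walking the AST directly.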
+ while_node = [ + node for node in ast.walk(program_node) + if isinstance(node, ast.While)][0] + loop_variables = data_flow.get_while_loop_variables(while_node, graph) + self.assertEqual(loop_variables, {'x'}) + + def test_data_flow_nested_loops(self): + def fn(): + count = 0 + for x in range(10): + for y in range(10): + if x == y: + count += 1 + return count + + program_node = program_utils.program_to_ast(fn) + graph = control_flow.get_control_flow_graph(program_node) + + # Perform the analysis. + analysis = data_flow.LastAccessAnalysis() + analysis.visit(graph.start_block.control_flow_nodes[0]) + for node in graph.get_enter_control_flow_nodes(): + analysis.visit(node) + + # Verify correctness. + node = graph.get_control_flow_node_by_source('count += 1') + last_accesses_in = node.get_label('last_access_in') + last_accesses_out = node.get_label('last_access_out') + self.assertLen(last_accesses_in['write-count'], 2) # += 1, = 0 + self.assertLen(last_accesses_in['read-count'], 1) # += 1 + self.assertLen(last_accesses_out['write-count'], 1) # += 1 + self.assertLen(last_accesses_out['read-count'], 1) # += 1 + + def test_last_accesses_analysis(self): + root = program_utils.program_to_ast(tc.nested_while_loops) + graph = control_flow.get_control_flow_graph(root) + + analysis = data_flow.LastAccessAnalysis() + analysis.visit(graph.start_block.control_flow_nodes[0]) + + for node in graph.get_enter_control_flow_nodes(): + analysis.visit(node) + + for block in graph.blocks: + for cfn in block.control_flow_nodes: + self.assertTrue(cfn.has_label('last_access_in')) + self.assertTrue(cfn.has_label('last_access_out')) + + node = graph.get_control_flow_node_by_source('y += 5') + last_accesses = node.get_label('last_access_out') + # TODO(dbieber): Add asserts that these are the correct accesses. + self.assertLen(last_accesses['write-x'], 2) # x = 1, x += 6 + self.assertLen(last_accesses['read-x'], 1) # x < 2 + + node = graph.get_control_flow_node_by_source('return x') + last_accesses = node.get_label('last_access_out') + self.assertLen(last_accesses['write-x'], 2) # x = 1, x += 6 + self.assertLen(last_accesses['read-x'], 1) # x < 2 + + def test_liveness_analysis_all_test_components(self): + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + root = program_utils.program_to_ast(fn) + graph = control_flow.get_control_flow_graph(root) + + analysis = data_flow.LivenessAnalysis() + for block in graph.get_exit_blocks(): + analysis.visit(block) + + def test_last_access_analysis_all_test_components(self): + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + root = program_utils.program_to_ast(fn) + graph = control_flow.get_control_flow_graph(root) + + analysis = data_flow.LastAccessAnalysis() + for node in graph.get_enter_control_flow_nodes(): + analysis.visit(node) + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/examples/control_flow_example.py b/python_graphs/examples/control_flow_example.py new file mode 100644 index 0000000..01dfee1 --- /dev/null +++ b/python_graphs/examples/control_flow_example.py @@ -0,0 +1,57 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example generating a control flow graph from a Python function. + +Generates an image visualizing the control flow graph for each of the functions +in control_flow_test_components.py. Saves the resulting images to the directory +`out`. + +Usage: +python -m python_graphs.examples.control_flow_example +""" + +import inspect +import os + +from absl import app + +from python_graphs import control_flow +from python_graphs import control_flow_graphviz +from python_graphs import control_flow_test_components as tc +from python_graphs import program_utils + + +def plot_control_flow_graph(fn, path): + graph = control_flow.get_control_flow_graph(fn) + source = program_utils.getsource(fn) + control_flow_graphviz.render(graph, include_src=source, path=path) + + +def main(argv) -> None: + del argv # Unused + + # Create the output directory. + os.makedirs('out', exist_ok=True) + + # For each function in control_flow_test_components.py, visualize its + # control flow graph. Save the results in the output directory. + for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + path = f'out/{name}_cfg.png' + plot_control_flow_graph(fn, path) + print('Done. See the `out` directory for the results.') + + +if __name__ == '__main__': + app.run(main) diff --git a/python_graphs/examples/cyclomatic_complexity_example.py b/python_graphs/examples/cyclomatic_complexity_example.py new file mode 100644 index 0000000..3077a73 --- /dev/null +++ b/python_graphs/examples/cyclomatic_complexity_example.py @@ -0,0 +1,46 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example computing the cyclomatic complexity of various Python functions. + +For each of the functions in control_flow_test_components.py, this computes and +prints the function's cyclomatic complexity. + +Usage: +python -m python_graphs.examples.cyclomatic_complexity_example +""" + +import inspect + +from absl import app + +from python_graphs import control_flow +from python_graphs import control_flow_test_components as tc +from python_graphs import cyclomatic_complexity + + +def main(argv) -> None: + del argv # Unused + + # For each function in control_flow_test_components.py, compute its cyclomatic + # complexity and print the result. 
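+  # For a graph with a single connected component, the value printed is
+  # E - N + 2, where E and N are the edge and node counts of the function's
+  # control flow graph (see cyclomatic_complexity.py).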
+ for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + print(f'{name}: ', end='') + graph = control_flow.get_control_flow_graph(fn) + value = cyclomatic_complexity.cyclomatic_complexity(graph) + print(value) + + +if __name__ == '__main__': + app.run(main) diff --git a/python_graphs/examples/program_graph_example.py b/python_graphs/examples/program_graph_example.py new file mode 100644 index 0000000..b54a720 --- /dev/null +++ b/python_graphs/examples/program_graph_example.py @@ -0,0 +1,50 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example generating a complete program graph from a Python function. + +Generates an image visualizing the complete program graph for each function +in program_graph_test_components.py. Saves the resulting images to the directory +`out`. + +Usage: +python -m python_graphs.examples.program_graph_example +""" + +import inspect +import os + +from absl import app +from python_graphs import program_graph +from python_graphs import program_graph_graphviz +from python_graphs import program_graph_test_components as tc + + +def main(argv) -> None: + del argv # Unused + + # Create the output directory. + os.makedirs('out', exist_ok=True) + + # For each function in program_graph_test_components.py, visualize its + # program graph. Save the results in the output directory. + for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + path = f'out/{name}-program-graph.png' + graph = program_graph.get_program_graph(fn) + program_graph_graphviz.render(graph, path=path) + print('Done. See the `out` directory for the results.') + + +if __name__ == '__main__': + app.run(main) diff --git a/python_graphs/instruction.py b/python_graphs/instruction.py new file mode 100644 index 0000000..18d5d73 --- /dev/null +++ b/python_graphs/instruction.py @@ -0,0 +1,400 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An Instruction represents an executable unit of a Python program. + +Almost all simple statements correspond to Instructions, except for statements +likes pass, continue, and break, whose effects are already represented in the +structure of the control-flow graph. + +In addition to simple statements, assignments that take place outside of simple +statements such as implicitly in a function or class definition also correspond +to Instructions. + +The complete set of places where Instructions occur in source are listed here: + +1. 
<simple statement> (Any node in INSTRUCTION_AST_NODES used as a statement.)
+2. if <test>: ... (elif is the same.)
+3+4. for <target> in <iter>: ...
+5. while <test>: ...
+6. try: ... except <exception>: ...
+7. TODO(dbieber): Test for "with <expr>:"...
+
+In the code:
+
+@decorator
+def fn(args=defaults):
+  body
+
+Outside of the function definition, we get the following instructions:
+8. Each decorator is an Instruction.
+9. Each default is an Instruction.
+10. The assignment of the function def to the function name is an Instruction.
+Inside the function definition, we get the following instructions:
+11. An Instruction for the assignment of values to the arguments.
+(1, again) And then the body can consist of multiple Instructions too.
+
+Likewise in the code:
+
+@decorator
+class C(object):
+  body
+
+The following are Instructions:
+(8, again) Each decorator is an Instruction.
+12. The assignment of the class to the variable C is an Instruction.
+(1, again) And then the body can consist of multiple Instructions too.
+13. TODO(dbieber): The base class (object) is an Instruction too.
+"""
+
+import gast as ast
+import six
+
+# Types of accesses:
+READ = 'read'
+WRITE = 'write'
+
+# Context lists
+WRITE_CONTEXTS = (ast.Store, ast.Del, ast.Param, ast.AugStore)
+READ_CONTEXTS = (ast.Load, ast.AugLoad)
+
+# Sources of implicit writes:
+CLASS = 'class'
+FUNCTION = 'function'
+ARGS = 'args'
+KWARG = 'kwarg'
+KWONLYARGS = 'kwonlyargs'
+VARARG = 'vararg'
+ITERATOR = 'iter'
+EXCEPTION = 'exception'
+
+INSTRUCTION_AST_NODES = (
+    ast.Expr,  # expression_stmt
+    ast.Assert,  # assert_stmt
+    ast.Assign,  # assignment_stmt
+    ast.AugAssign,  # augmented_assignment_stmt
+    ast.Delete,  # del_stmt
+    ast.Print,  # print_stmt
+    ast.Return,  # return_stmt
+    # ast.Yield,  # yield_stmt. ast.Yield nodes are contained in ast.Expr nodes.
+    ast.Raise,  # raise_stmt
+    ast.Import,  # import_stmt
+    ast.ImportFrom,
+    ast.Global,  # global_stmt
+    ast.Exec,  # exec_stmt
+)
+
+# https://docs.python.org/2/reference/simple_stmts.html
+SIMPLE_STATEMENT_AST_NODES = INSTRUCTION_AST_NODES + (
+    ast.Pass,  # pass_stmt
+    ast.Break,  # break_stmt
+    ast.Continue,  # continue_stmt
+)
+
+
+def _canonicalize(node):
+  if isinstance(node, list) and len(node) == 1:
+    return _canonicalize(node[0])
+  if isinstance(node, ast.Module):
+    return _canonicalize(node.body)
+  if isinstance(node, ast.Expr):
+    return _canonicalize(node.value)
+  return node
+
+
+def represent_same_program(node1, node2):
+  """Whether AST nodes node1 and node2 represent the same program syntactically.
+
+  Two programs are the same syntactically if they have equivalent ASTs, up to
+  some small changes. The context field of Name nodes can change without the
+  syntax represented by the AST changing. This allows, for example, the short
+  program 'x' (a read) to match a subprogram 'x' of 'x = 3' (in which x is
+  a write), since these two programs are the same syntactically ('x' and 'x').
+
+  Except for the context field of Name nodes, the two nodes are recursively
+  checked for exact equality.
+
+  Args:
+    node1: An AST node. This can be an ast.AST object, a primitive, or a list of
+      AST nodes (primitives or ast.AST objects).
+    node2: An AST node. This can be an ast.AST object, a primitive, or a list of
+      AST nodes (primitives or ast.AST objects).
+
+  Returns:
+    Whether the two nodes represent equivalent programs.
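+
+  Example (an illustrative sketch based on the cases in instruction_test.py):
+
+    represent_same_program(ast.parse('x + 1'), ast.parse('x + 1'))  # True
+    represent_same_program(ast.parse('x + 1'), ast.parse('x + 2'))  # False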
+  """
+  node1 = _canonicalize(node1)
+  node2 = _canonicalize(node2)
+
+  if type(node1) != type(node2):  # pylint: disable=unidiomatic-typecheck
+    return False
+  if not isinstance(node1, ast.AST):
+    return node1 == node2
+
+  fields1 = list(ast.iter_fields(node1))
+  fields2 = list(ast.iter_fields(node2))
+  if len(fields1) != len(fields2):
+    return False
+
+  for (field1, value1), (field2, value2) in zip(fields1, fields2):
+    if field1 == 'ctx':
+      continue
+    if field1 != field2 or type(value1) is not type(value2):
+      return False
+    if isinstance(value1, list):
+      for item1, item2 in zip(value1, value2):
+        if not represent_same_program(item1, item2):
+          return False
+    elif not represent_same_program(value1, value2):
+      return False
+
+  return True
+
+
+class AccessVisitor(ast.NodeVisitor):
+  """Visitor that computes an ordered list of accesses.
+
+  Accesses are ordered based on a depth-first traversal of the AST, using the
+  order of fields defined in `gast`, except for Assign nodes, for which the RHS
+  is ordered before the LHS.
+
+  This may differ from Python execution semantics in two ways:
+
+  - Both branches of short-circuit `and`/`or` expressions or conditional
+    `X if Y else Z` expressions are considered to be evaluated, even if one of
+    them is actually skipped at runtime.
+  - For AST nodes whose field order doesn't match the Python interpreter's
+    evaluation order, the field order is used instead. Most AST nodes match
+    execution order, but some differ (e.g. for dictionary literals, the
+    interpreter alternates evaluating keys and values, but the field order has
+    all keys and then all values). Assignments are a special case; the
+    AccessVisitor evaluates the RHS first even though the LHS occurs first in
+    the expression.
+
+  Attributes:
+    accesses: List of accesses encountered by the visitor.
+  """
+
+  # TODO(dbieber): Include accesses of ast.Subscript and ast.Attribute targets.
+
+  def __init__(self):
+    self.accesses = []
+
+  def visit_Name(self, node):
+    """Visit a Name, adding it to the list of accesses."""
+    self.accesses.append(node)
+
+  def visit_Assign(self, node):
+    """Visit an Assign, ordering RHS accesses before LHS accesses."""
+    self.visit(node.value)
+    for target in node.targets:
+      self.visit(target)
+
+  def visit_AugAssign(self, node):
+    """Visit an AugAssign, which contains both a read and a write."""
+    # An AugAssign is a read as well as a write, even with the ctx of a write.
+    self.visit(node.value)
+    # Add a read access if we are assigning to a name.
+    if isinstance(node.target, ast.Name):
+      # TODO(dbieber): Use a proper type instead of a tuple for accesses.
+      self.accesses.append(('read', node.target, node))
+    # Add the write access as normal.
+    self.visit(node.target)
+
+
+def get_accesses_from_ast_node(node):
+  """Get all accesses for an AST node, in depth-first AST field order."""
+  visitor = AccessVisitor()
+  visitor.visit(node)
+  return visitor.accesses
+
+
+def get_reads_from_ast_node(ast_node):
+  """Get all reads for an AST node, in depth-first AST field order.
+
+  Args:
+    ast_node: The AST node of interest.
+
+  Returns:
+    A list of reads performed by that AST node.
+  """
+  return [
+      access for access in get_accesses_from_ast_node(ast_node)
+      if access_is_read(access)
+  ]
+
+
+def get_writes_from_ast_node(ast_node):
+  """Get all writes for an AST node, in depth-first AST field order.
+
+  Args:
+    ast_node: The AST node of interest.
+
+  Returns:
+    A list of writes performed by that AST node.
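+
+  Example (an illustrative sketch; `node` is assumed to be an ast.While node,
+  mirroring the usage in get_while_loop_variables in data_flow.py):
+
+    written_names = {
+        write.id for write in get_writes_from_ast_node(node)
+        if isinstance(write, ast.Name)
+    }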
+ """ + return [ + access for access in get_accesses_from_ast_node(ast_node) + if access_is_write(access) + ] + + +def create_writes(node, parent=None): + # TODO(dbieber): Use a proper type instead of a tuple for accesses. + if isinstance(node, ast.AST): + return [ + ('write', n, parent) for n in ast.walk(node) if isinstance(n, ast.Name) + ] + else: + return [('write', node, parent)] + + +def access_is_read(access): + if isinstance(access, ast.AST): + assert isinstance(access, ast.Name), access + return isinstance(access.ctx, READ_CONTEXTS) + else: + return access[0] == 'read' + + +def access_is_write(access): + if isinstance(access, ast.AST): + assert isinstance(access, ast.Name), access + return isinstance(access.ctx, WRITE_CONTEXTS) + else: + return access[0] == 'write' + + +def access_name(access): + if isinstance(access, ast.AST): + return access.id + elif isinstance(access, tuple): + if isinstance(access[1], six.string_types): + return access[1] + elif isinstance(access[1], ast.Name): + return access[1].id + raise ValueError('Unexpected access type.', access) + + +def access_kind(access): + if access_is_read(access): + return 'read' + elif access_is_write(access): + return 'write' + + +def access_kind_and_name(access): + return '{}-{}'.format(access_kind(access), access_name(access)) + + +def access_identifier(name, kind): + return '{}-{}'.format(kind, name) + + +class Instruction(object): + # pyformat:disable + """Represents an executable unit of a Python program. + + An Instruction is a part of an AST corresponding to a simple statement or + assignment, not corresponding to control flow. The part of the AST is not + necessarily an AST node. It may be an AST node, or it may instead be a string + (such as a variable name). + + Instructions play an important part in control flow graphs. An Instruction + is the smallest unit of a control flow graph (wrapped in a ControlFlowNode). + A control flow graph consists of basic blocks which represent a sequence of + Instructions that are executed in a straight-line manner, or not at all. + + Conceptually an Instruction is immutable. This means that while Python does + permit the mutation of an Instruction, in practice an Instruction object + should not be modified once it is created. + + Note that an Instruction may be interrupted by an exception mid-execution. + This is captured in control flow graphs via interrupting exits from basic + blocks to either exception handlers or special 'raises' blocks. + + In addition to pure simple statements, an Instruction can represent a number + of different parts of code. These are all listed explicitly in the module + docstring. + + In the common case, the accesses made by an Instruction are given by the Name + AST nodes contained in the Instruction's AST node. In some cases, when the + instruction.source field is not None, the accesses made by an Instruction are + not simply the Name AST nodes of the Instruction's node. For example, in a + function definition, the only access is the assignment of the function def to + the variable with the function's name; the Name nodes contained in the + function definition are not part of the function definition Instruction, and + instead are part of other Instructions that make up the function. The set of + accesses made by an Instruction is computed when the Instruction is created + and available via the accesses attribute of the Instruction. + + Attributes: + node: The AST node corresponding to the instruction. 
+ accesses: (optional) An ordered list of all reads and writes made by this + instruction. Each item in `accesses` is one of either: + - A 3-tuple with fields (kind, node, parent). kind is either 'read' or + 'write'. node is either a string or Name AST node. parent is an AST + node where node occurs. + - A Name AST node + # TODO(dbieber): Use a single type for all accesses. + source: (optional) The source of the writes. For example in the for loop + `for x in items: pass` there is a instruction for the Name node "x". Its + source is ITERATOR, indicating that this instruction corresponds to x + being assigned a value from an iterator. When source is not None, the + Python code corresponding to the instruction does not coincide with the + Python code corresponding to the instruction's node. + """ + # pyformat:enable + + def __init__(self, node, accesses=None, source=None): + if not isinstance(node, ast.AST): + raise TypeError('node must be an instance of ast.AST.', node) + self.node = node + if accesses is None: + accesses = get_accesses_from_ast_node(node) + self.accesses = accesses + self.source = source + + def contains_subprogram(self, node): + """Whether this Instruction contains the given AST as a subprogram. + + Computes whether `node` is a subtree of this Instruction's AST. + If the Instruction represents an implied write, then the node must match + against the Instruction's writes. + + Args: + node: The node to check the instruction against for a match. + + Returns: + (bool) Whether or not this Instruction contains the node, syntactically. + """ + if self.source is not None: + # Only exact matches are permissible if source is not None. + return represent_same_program(node, self.node) + for subtree in ast.walk(self.node): + if represent_same_program(node, subtree): + return True + return False + + def get_reads(self): + return {access for access in self.accesses if access_is_read(access)} + + def get_read_names(self): + return {access_name(access) for access in self.get_reads()} + + def get_writes(self): + return {access for access in self.accesses if access_is_write(access)} + + def get_write_names(self): + return {access_name(access) for access in self.get_writes()} diff --git a/python_graphs/instruction_test.py b/python_graphs/instruction_test.py new file mode 100644 index 0000000..1f26b6e --- /dev/null +++ b/python_graphs/instruction_test.py @@ -0,0 +1,123 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for instruction module.""" + +from absl.testing import absltest +import gast as ast +from python_graphs import instruction as instruction_module + + +def create_instruction(source): + node = ast.parse(source) + node = instruction_module._canonicalize(node) + return instruction_module.Instruction(node) + + +class InstructionTest(absltest.TestCase): + + def test_instruction(self): + self.assertIsNotNone(instruction_module.Instruction) + + def test_represent_same_program_basic_positive_case(self): + program1 = ast.parse('x + 1') + program2 = ast.parse('x + 1') + self.assertTrue( + instruction_module.represent_same_program(program1, program2)) + + def test_represent_same_program_basic_negative_case(self): + program1 = ast.parse('x + 1') + program2 = ast.parse('x + 2') + self.assertFalse( + instruction_module.represent_same_program(program1, program2)) + + def test_represent_same_program_different_contexts(self): + full_program1 = ast.parse('y = x + 1') # y is a write + program1 = full_program1.body[0].targets[0] # 'y' + program2 = ast.parse('y') # y is a read + self.assertTrue( + instruction_module.represent_same_program(program1, program2)) + + def test_get_accesses(self): + instruction = create_instruction('x + 1') + self.assertEqual(instruction.get_read_names(), {'x'}) + self.assertEqual(instruction.get_write_names(), set()) + + instruction = create_instruction('return x + y + z') + self.assertEqual(instruction.get_read_names(), {'x', 'y', 'z'}) + self.assertEqual(instruction.get_write_names(), set()) + + instruction = create_instruction('fn(a, b, c)') + self.assertEqual(instruction.get_read_names(), {'a', 'b', 'c', 'fn'}) + self.assertEqual(instruction.get_write_names(), set()) + + instruction = create_instruction('c = fn(a, b, c)') + self.assertEqual(instruction.get_read_names(), {'a', 'b', 'c', 'fn'}) + self.assertEqual(instruction.get_write_names(), {'c'}) + + def test_get_accesses_augassign(self): + instruction = create_instruction('x += 1') + self.assertEqual(instruction.get_read_names(), {'x'}) + self.assertEqual(instruction.get_write_names(), {'x'}) + + instruction = create_instruction('x *= y') + self.assertEqual(instruction.get_read_names(), {'x', 'y'}) + self.assertEqual(instruction.get_write_names(), {'x'}) + + def test_get_accesses_augassign_subscript(self): + instruction = create_instruction('x[0] *= y') + # This is not currently considered a write of x. It is a read of x. + self.assertEqual(instruction.get_read_names(), {'x', 'y'}) + self.assertEqual(instruction.get_write_names(), set()) + + def test_get_accesses_augassign_attribute(self): + instruction = create_instruction('x.attribute *= y') + # This is not currently considered a write of x. It is a read of x. + self.assertEqual(instruction.get_read_names(), {'x', 'y'}) + self.assertEqual(instruction.get_write_names(), set()) + + def test_get_accesses_subscript(self): + instruction = create_instruction('x[0] = y') + # This is not currently considered a write of x. It is a read of x. + self.assertEqual(instruction.get_read_names(), {'x', 'y'}) + self.assertEqual(instruction.get_write_names(), set()) + + def test_get_accesses_attribute(self): + instruction = create_instruction('x.attribute = y') + # This is not currently considered a write of x. It is a read of x. 
+ self.assertEqual(instruction.get_read_names(), {'x', 'y'}) + self.assertEqual(instruction.get_write_names(), set()) + + def test_access_ordering(self): + instruction = create_instruction('c = fn(a, b + c, d / a)') + access_names_and_kinds = [(instruction_module.access_name(access), + instruction_module.access_kind(access)) + for access in instruction.accesses] + self.assertEqual(access_names_and_kinds, [('fn', 'read'), ('a', 'read'), + ('b', 'read'), ('c', 'read'), + ('d', 'read'), ('a', 'read'), + ('c', 'write')]) + + instruction = create_instruction('c += fn(a, b + c, d / a)') + access_names_and_kinds = [(instruction_module.access_name(access), + instruction_module.access_kind(access)) + for access in instruction.accesses] + self.assertEqual(access_names_and_kinds, [('fn', 'read'), ('a', 'read'), + ('b', 'read'), ('c', 'read'), + ('d', 'read'), ('a', 'read'), + ('c', 'read'), ('c', 'write')]) + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/program_graph.py b/python_graphs/program_graph.py new file mode 100644 index 0000000..4da2abd --- /dev/null +++ b/python_graphs/program_graph.py @@ -0,0 +1,963 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Creates ProgramGraphs from a program or function's AST. + +A ProgramGraph represents a Python program or function. The nodes in a +ProgramGraph represent an Instruction (see instruction.py), an AST node, or a +piece of syntax from the program. The edges in a ProgramGraph represent the +relationships between these nodes. +""" + +import codecs +import collections +import os + +from absl import logging +import astunparse +from astunparse import unparser +import gast as ast +from python_graphs import control_flow +from python_graphs import data_flow +from python_graphs import instruction as instruction_module +from python_graphs import program_graph_dataclasses as pb +from python_graphs import program_utils +import six +from six.moves import builtins +from six.moves import filter + +NEWLINE_TOKEN = '#NEWLINE#' +UNINDENT_TOKEN = '#UNINDENT#' +INDENT_TOKEN = '#INDENT#' + + +class ProgramGraph(object): + """A ProgramGraph represents a Python program or function. + + Attributes: + root_id: The id of the root ProgramGraphNode. + nodes: Maps from node id to the ProgramGraphNode with that id. + edges: A list of the edges (from_node.id, to_node.id, edge type) in the + graph. + child_map: Maps from node id to a list of that node's AST children node ids. + parent_map: Maps from node id to that node's AST parent node id. + neighbors_map: Maps from node id to a list of that node's neighboring edges. + ast_id_to_program_graph_node: Maps from an AST node's object id to the + corresponding AST program graph node, if it exists. + root: The root ProgramGraphNode. 
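+
+  Example (an illustrative sketch; `fn` stands for any Python function, as in
+  program_graph_example.py):
+
+    graph = get_program_graph(fn)
+    root = graph.root
+    children = list(graph.children(root))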
+ """ + + def __init__(self): + """Constructs an empty ProgramGraph with no root.""" + self.root_id = None + + self.nodes = {} + # TODO(charlessutton): Seems odd to have Edge proto objects as part of the + # program graph object if node protos aren't. Consider a more consistent + # treatment. + self.edges = [] + + self.ast_id_to_program_graph_node = {} + self.child_map = collections.defaultdict(list) + self.parent_map = collections.defaultdict(lambda: None) + self.neighbors_map = collections.defaultdict(list) + + # Accessors + @property + def root(self): + if self.root_id not in self.nodes: + raise ValueError('Graph has no root node.') + return self.nodes[self.root_id] + + def all_nodes(self): + return self.nodes.values() + + def get_node(self, obj): + """Returns the node in the program graph corresponding to an object. + + Arguments: + obj: Can be an integer, AST node, ProgramGraphNode, or program graph node + protobuf. + + Raises: + ValueError: no node exists in the program graph matching obj. + """ + if isinstance(obj, six.integer_types) and obj in self.nodes: + return self.get_node_by_id(obj) + elif isinstance(obj, ProgramGraphNode): + # assert obj in self.nodes.values() + return obj + elif isinstance(obj, pb.Node): + return self.get_node_by_id(obj.id) + elif isinstance(obj, (ast.AST, list)): + return self.get_node_by_ast_node(obj) + else: + raise ValueError('Unexpected value for obj.', obj) + + def get_node_by_id(self, obj): + """Gets a ProgramGraph node for the given integer id.""" + return self.nodes[obj] + + def get_node_by_access(self, access): + """Gets a ProgramGraph node for the given read or write.""" + if isinstance(access, ast.Name): + return self.get_node(access) + else: + assert isinstance(access, tuple) + if isinstance(access[1], ast.Name): + return self.get_node(access[1]) + else: + return self.get_node(access[2]) + raise ValueError('Could not find node for access.', access) + + def get_nodes_by_source(self, source): + """Generates the nodes in the program graph containing the query source. + + Args: + source: The query source. + + Returns: + A generator of all nodes in the program graph with an Instruction with + source that includes the query source. + """ + module = ast.parse(source, mode='exec') # TODO(dbieber): Factor out 4 lines + # TODO(dbieber): Use statements beyond the first statement from source. + node = module.body[0] + # If the query source is an Expression, and the matching instruction matches + # the value field of that Expression, then the matching instruction is + # considered a match. This allows us to match subexpressions which appear in + # ast.Expr nodes in the query but not in the parent. + if isinstance(node, ast.Expr): + node = node.value + + def matches_source(pg_node): + if pg_node.has_instruction(): + return pg_node.instruction.contains_subprogram(node) + else: + return instruction_module.represent_same_program(pg_node.ast_node, node) + + return filter(matches_source, self.nodes.values()) + + def get_node_by_source(self, node): + # We use min since nodes can contain each other and we want the most + # specific one. 
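+    # The node with the shortest ast.dump corresponds to the smallest, most
+    # specific matching subtree.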
+ return min( + self.get_nodes_by_source(node), key=lambda x: len(ast.dump(x.node))) + + def get_nodes_by_function_name(self, name): + return filter( + lambda n: n.has_instance_of(ast.FunctionDef) and n.node.name == name, + self.nodes.values()) + + def get_node_by_function_name(self, name): + return next(self.get_nodes_by_function_name(name)) + + def get_node_by_ast_node(self, ast_node): + return self.ast_id_to_program_graph_node[id(ast_node)] + + def contains_ast_node(self, ast_node): + return id(ast_node) in self.ast_id_to_program_graph_node + + def get_ast_nodes_of_type(self, ast_type): + for node in six.itervalues(self.nodes): + if node.node_type == pb.NodeType.AST_NODE and node.ast_type == ast_type: + yield node + + # TODO(dbieber): Unify selectors across program_graph and control_flow. + def get_nodes_by_source_and_identifier(self, source, name): + for pg_node in self.get_nodes_by_source(source): + for node in ast.walk(pg_node.node): + if isinstance(node, ast.Name) and node.id == name: + if self.contains_ast_node(node): + yield self.get_node_by_ast_node(node) + + def get_node_by_source_and_identifier(self, source, name): + return next(self.get_nodes_by_source_and_identifier(source, name)) + + # Graph Construction Methods + def add_node(self, node): + """Adds a ProgramGraphNode to this graph. + + Args: + node: The ProgramGraphNode that should be added. + + Returns: + The node that was added. + + Raises: + ValueError: the node has already been added to this graph. + """ + assert isinstance(node, ProgramGraphNode), 'Not a ProgramGraphNode' + if node.id in self.nodes: + raise ValueError('Already contains node', self.nodes[node.id], node.id) + if node.ast_node is not None: + if self.contains_ast_node(node.ast_node): + raise ValueError('Already contains ast node', node.ast_node) + self.ast_id_to_program_graph_node[id(node.ast_node)] = node + self.nodes[node.id] = node + return node + + def add_node_from_instruction(self, instruction): + """Adds a node to the program graph.""" + node = make_node_from_instruction(instruction) + return self.add_node(node) + + def add_edge(self, edge): + """Adds an edge between two nodes in the graph. + + Args: + edge: The edge, a pb.Edge proto. + """ + assert isinstance(edge, pb.Edge), 'Not a pb.Edge' + self.edges.append(edge) + + n1 = self.get_node_by_id(edge.id1) + n2 = self.get_node_by_id(edge.id2) + if edge.type == pb.EdgeType.FIELD: # An AST node. + self.child_map[edge.id1].append(edge.id2) + # TODO(charlessutton): Add the below sanity check back once Instruction + # updates are complete. + # pylint: disable=line-too-long + # other_parent_id = self.parent_map[edge.id2] + # if other_parent_id and other_parent_id != edge.id1: + # raise Exception('Node {} {} with two parents\n {} {}\n {} {}' + # .format(edge.id2, dump_node(self.get_node(edge.id2)), + # edge.id1, dump_node(self.get_node(edge.id1)), + # other_parent_id, dump_node(self.get_node(other_parent_id)))) + # pylint: enable=line-too-long + self.parent_map[n2.id] = edge.id1 + self.neighbors_map[n1.id].append((edge, edge.id2)) + self.neighbors_map[n2.id].append((edge, edge.id1)) + + def remove_edge(self, edge): + """Removes an edge from the graph. + + If there are multiple copies of the same edge, only one copy is removed. + + Args: + edge: The edge, a pb.Edge proto. + """ + self.edges.remove(edge) + + n1 = self.get_node_by_id(edge.id1) + n2 = self.get_node_by_id(edge.id2) + + if edge.type == pb.EdgeType.FIELD: # An AST node. 
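+      # Undo the child/parent bookkeeping that add_edge performs for AST
+      # field edges.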
+ self.child_map[edge.id1].remove(edge.id2) + del self.parent_map[n2.id] + + self.neighbors_map[n1.id].remove((edge, edge.id2)) + self.neighbors_map[n2.id].remove((edge, edge.id1)) + + def add_new_edge(self, n1, n2, edge_type=None, field_name=None): + """Adds a new edge between two nodes in the graph. + + Both nodes must already be part of the graph. + + Args: + n1: Specifies the from node of the edge. Can be any object type accepted + by get_node. + n2: Specifies the to node of the edge. Can be any object type accepted by + get_node. + edge_type: The type of edge. Can be any integer in the pb.Edge enum. + field_name: For AST edges, a string describing the Python AST field + + Returns: + The new edge. + """ + n1 = self.get_node(n1) + n2 = self.get_node(n2) + new_edge = pb.Edge( + id1=n1.id, id2=n2.id, type=edge_type, field_name=field_name) + self.add_edge(new_edge) + return new_edge + + # AST Methods + # TODO(charlessutton): Consider whether AST manipulation should be moved + # e.g., to a more general graph object. + def to_ast(self, node=None): + """Convert the program graph to a Python AST.""" + if node is None: + node = self.root + return self._build_ast(node=node, update_references=False) + + def reconstruct_ast(self): + """Reconstruct all internal ProgramGraphNode.ast_node references. + + After calling this method, all nodes of type AST_NODE will have their + `ast_node` property refer to subtrees of a reconstructed AST object, and + self.ast_id_to_program_graph_node will contain only entries from this new + AST. + + Note that only AST nodes reachable by fields from the root node will be + converted; this should be all of them but this is not checked. + """ + self.ast_id_to_program_graph_node.clear() + self._build_ast(node=self.root, update_references=True) + + def _build_ast(self, node, update_references): + """Helper method: builds an AST and optionally sets ast_node references. + + Args: + node: Program graph node to build an AST for. + update_references: Whether to modify this node and all of its children so + that they point to the reconstructed AST node. + + Returns: + AST node corresponding to the program graph node. 
+ """ + if node.node_type == pb.NodeType.AST_NODE: + ast_node = getattr(ast, node.ast_type)() + adjacent_edges = self.neighbors_map[node.id] + for edge, other_node_id in adjacent_edges: + if other_node_id == edge.id1: # it's an incoming edge + continue + if edge.type == pb.EdgeType.FIELD: + child_id = other_node_id + child = self.get_node_by_id(child_id) + setattr( + ast_node, edge.field_name, + self._build_ast(node=child, update_references=update_references)) + if update_references: + node.ast_node = ast_node + self.ast_id_to_program_graph_node[id(ast_node)] = node + return ast_node + elif node.node_type == pb.NodeType.AST_LIST: + list_items = {} + adjacent_edges = self.neighbors_map[node.id] + for edge, other_node_id in adjacent_edges: + if other_node_id == edge.id1: # it's an incoming edge + continue + if edge.type == pb.EdgeType.FIELD: + child_id = other_node_id + child = self.get_node_by_id(child_id) + unused_field_name, index = parse_list_field_name(edge.field_name) + list_items[index] = self._build_ast( + node=child, update_references=update_references) + + ast_list = [] + for index in six.moves.range(len(list_items)): + ast_list.append(list_items[index]) + return ast_list + elif node.node_type == pb.NodeType.AST_VALUE: + return node.ast_value + else: + raise ValueError('This ProgramGraphNode does not correspond to a node in' + ' an AST.') + + def walk_ast_descendants(self, node=None): + """Yields the nodes that correspond to the descendants of node in the AST. + + Args: + node: the node in the program graph corresponding to the root of the AST + subtree that should be walked. If None, defaults to the root of the + program graph. + + Yields: + All nodes corresponding to descendants of node in the AST. + """ + if node is None: + node = self.root + frontier = [node] + while frontier: + current = frontier.pop() + for child_id in reversed(self.child_map[current.id]): + frontier.append(self.get_node_by_id(child_id)) + yield current + + def parent(self, node): + """Returns the AST parent of an AST program graph node. + + Args: + node: A ProgramGraphNode. + + Returns: + The node's AST parent, which is also a ProgramGraphNode. + """ + parent_id = self.parent_map[node.id] + if parent_id is None: + return None + else: + return self.get_node_by_id(parent_id) + + def children(self, node): + """Yields the (direct) AST children of an AST program graph node. + + Args: + node: A ProgramGraphNode. + + Yields: + The AST children of node, which are ProgramGraphNode objects. + """ + for child_id in self.child_map[node.id]: + yield self.get_node_by_id(child_id) + + def neighbors(self, node, edge_type=None): + """Returns the incoming and outgoing neighbors of a program graph node. + + Args: + node: A ProgramGraphNode. + edge_type: If provided, only edges of this type are considered. + + Returns: + The incoming and outgoing neighbors of node, which are ProgramGraphNode + objects but not necessarily AST nodes. + """ + adj_edges = self.neighbors_map[node.id] + if edge_type is None: + ids = list(tup[1] for tup in adj_edges) + else: + ids = list(tup[1] for tup in adj_edges if tup[0].type == edge_type) + return [self.get_node_by_id(id0) for id0 in ids] + + def incoming_neighbors(self, node, edge_type=None): + """Returns the incoming neighbors of a program graph node. + + Args: + node: A ProgramGraphNode. + edge_type: If provided, only edges of this type are considered. + + Returns: + The incoming neighbors of node, which are ProgramGraphNode objects but not + necessarily AST nodes. 
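+
+    Example (an illustrative sketch; `graph` and `node` are assumed to already
+    exist):
+
+      cfg_predecessors = graph.incoming_neighbors(node, pb.EdgeType.CFG_NEXT)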
+ """ + adj_edges = self.neighbors_map[node.id] + result = [] + for edge, neighbor_id in adj_edges: + if edge.id2 == node.id: + if (edge_type is None) or (edge.type == edge_type): + result.append(self.get_node_by_id(neighbor_id)) + return result + + def outgoing_neighbors(self, node, edge_type=None): + """Returns the outgoing neighbors of a program graph node. + + Args: + node: A ProgramGraphNode. + edge_type: If provided, only edges of this type are considered. + + Returns: + The outgoing neighbors of node, which are ProgramGraphNode objects but not + necessarily AST nodes. + """ + adj_edges = self.neighbors_map[node.id] + result = [] + for edge, neighbor_id in adj_edges: + if edge.id1 == node.id: + if (edge_type is None) or (edge.type == edge_type): + result.append(self.get_node_by_id(neighbor_id)) + return result + + def dump_tree(self, start_node=None): + """Returns a string representation for debugging.""" + + def dump_tree_recurse(node, indent, all_lines): + """Create a string representation for a subtree.""" + indent_str = ' ' + ('--' * indent) + node_str = dump_node(node) + line = ' '.join([indent_str, node_str, '\n']) + all_lines.append(line) + # output long distance edges + for edge, neighbor_id in self.neighbors_map[node.id]: + if (not is_ast_edge(edge) and not is_syntax_edge(edge) and + node.id == edge.id1): + type_str = edge.type.name + line = [indent_str, '--((', type_str, '))-->', str(neighbor_id), '\n'] + all_lines.append(' '.join(line)) + for child in self.children(node): + dump_tree_recurse(child, indent + 1, all_lines) + return all_lines + + if start_node is None: + start_node = self.root + return ''.join(dump_tree_recurse(start_node, 0, [])) + + # TODO(charlessutton): Consider whether this belongs in ProgramGraph + # or in make_synthesis_problems. + def copy_with_placeholder(self, node): + """Returns a new program graph in which the subtree of NODE is removed. + + In the new graph, the subtree headed by NODE is replaced by a single + node of type PLACEHOLDER, which is connected to the AST parent of NODE + by the same edge type as in the original graph. + + The new program graph will share structure (i.e. the ProgramGraphNode + objects) with the original graph. + + Args: + node: A node in this program graph + + Returns: + A new ProgramGraph object with NODE replaced + """ + descendant_ids = {n.id for n in self.walk_ast_descendants(node)} + new_graph = ProgramGraph() + new_graph.add_node(self.root) + new_graph.root_id = self.root_id + for edge in self.edges: + v1 = self.nodes[edge.id1] + v2 = self.nodes[edge.id2] + # Omit edges that are adjacent to the subtree rooted at `node` UNLESS this + # is the AST edge to the root of the subtree. + # In that case, create an edge to a new placeholder node + adj_bad_subtree = ((edge.id1 in descendant_ids) or + (edge.id2 in descendant_ids)) + if adj_bad_subtree: + if edge.id2 == node.id and is_ast_edge(edge): + placeholder = ProgramGraphNode() + placeholder.node_type = pb.NodeType.PLACEHOLDER + placeholder.id = node.id + new_graph.add_node(placeholder) + new_graph.add_new_edge(v1, placeholder, edge_type=edge.type) + else: + # nodes on the edge have not been added yet + if edge.id1 not in new_graph.nodes: + new_graph.add_node(v1) + if edge.id2 not in new_graph.nodes: + new_graph.add_node(v2) + new_graph.add_new_edge(v1, v2, edge_type=edge.type) + return new_graph + + def copy_subgraph(self, node): + """Returns a new program graph containing only the subtree rooted at NODE. 
+ + All edges that connect nodes in the subtree are included, both AST edges + and other types of edges. + + Args: + node: A node in this program graph + + Returns: + A new ProgramGraph object whose root is NODE + """ + descendant_ids = {n.id for n in self.walk_ast_descendants(node)} + new_graph = ProgramGraph() + new_graph.add_node(node) + new_graph.root_id = node.id + for edge in self.edges: + v1 = self.nodes[edge.id1] + v2 = self.nodes[edge.id2] + # Omit edges that are adjacent to the subtree rooted at NODE + # UNLESS this is the AST edge to the root of the subtree. + # In that case, create an edge to a new placeholder node + good_edge = ((edge.id1 in descendant_ids) and + (edge.id2 in descendant_ids)) + if good_edge: + if edge.id1 not in new_graph.nodes: + new_graph.add_node(v1) + if edge.id2 not in new_graph.nodes: + new_graph.add_node(v2) + new_graph.add_new_edge(v1, v2, edge_type=edge.type) + return new_graph + + +def is_ast_node(node): + return node.node_type == pb.NodeType.AST_NODE + + +def is_ast_edge(edge): + # TODO(charlessutton): Expand to enumerate edge types in gast. + return edge.type == pb.EdgeType.FIELD + + +def is_syntax_edge(edge): + return edge.type == pb.EdgeType.SYNTAX + + +def dump_node(node): + type_str = '[' + node.node_type.name + ']' + elements = [type_str, str(node.id), node.ast_type] + if node.ast_value: + elements.append(str(node.ast_value)) + if node.syntax: + elements.append(str(node.syntax)) + return ' '.join(elements) + + +def get_program_graph(program): + """Constructs a program graph to represent the given program.""" + program_node = program_utils.program_to_ast(program) # An AST node. + + # TODO(dbieber): Refactor sections of graph building into separate functions. + program_graph = ProgramGraph() + + # Perform control flow analysis. + control_flow_graph = control_flow.get_control_flow_graph(program_node) + + # Add AST_NODE program graph nodes corresponding to Instructions in the + # control flow graph. + for control_flow_node in control_flow_graph.get_control_flow_nodes(): + program_graph.add_node_from_instruction(control_flow_node.instruction) + + # Add AST_NODE program graph nodes corresponding to AST nodes. + for ast_node in ast.walk(program_node): + if not program_graph.contains_ast_node(ast_node): + pg_node = make_node_from_ast_node(ast_node) + program_graph.add_node(pg_node) + + root = program_graph.get_node_by_ast_node(program_node) + program_graph.root_id = root.id + + # Add AST edges (FIELD). Also add AST_LIST and AST_VALUE program graph nodes. + for ast_node in ast.walk(program_node): + for field_name, value in ast.iter_fields(ast_node): + if isinstance(value, list): + pg_node = make_node_for_ast_list() + program_graph.add_node(pg_node) + program_graph.add_new_edge( + ast_node, pg_node, pb.EdgeType.FIELD, field_name) + for index, item in enumerate(value): + list_field_name = make_list_field_name(field_name, index) + if isinstance(item, ast.AST): + program_graph.add_new_edge(pg_node, item, pb.EdgeType.FIELD, + list_field_name) + else: + item_node = make_node_from_ast_value(item) + program_graph.add_node(item_node) + program_graph.add_new_edge(pg_node, item_node, pb.EdgeType.FIELD, + list_field_name) + elif isinstance(value, ast.AST): + program_graph.add_new_edge( + ast_node, value, pb.EdgeType.FIELD, field_name) + else: + pg_node = make_node_from_ast_value(value) + program_graph.add_node(pg_node) + program_graph.add_new_edge( + ast_node, pg_node, pb.EdgeType.FIELD, field_name) + + # Add SYNTAX_NODE nodes. 
Also add NEXT_SYNTAX and LAST_LEXICAL_USE edges. + # Add these edges using a custom AST unparser to visit leaf nodes in preorder. + SyntaxNodeUnparser(program_node, program_graph) + + # Perform data flow analysis. + analysis = data_flow.LastAccessAnalysis() + for node in control_flow_graph.get_enter_control_flow_nodes(): + analysis.visit(node) + + # Add control flow edges (CFG_NEXT). + for control_flow_node in control_flow_graph.get_control_flow_nodes(): + instruction = control_flow_node.instruction + for next_control_flow_node in control_flow_node.next: + next_instruction = next_control_flow_node.instruction + program_graph.add_new_edge( + instruction.node, next_instruction.node, + edge_type=pb.EdgeType.CFG_NEXT) + + # Add data flow edges (LAST_READ and LAST_WRITE). + for control_flow_node in control_flow_graph.get_control_flow_nodes(): + # Start with the most recent accesses before this instruction. + last_accesses = control_flow_node.get_label('last_access_in').copy() + for access in control_flow_node.instruction.accesses: + # Extract the node and identifiers for the current access. + pg_node = program_graph.get_node_by_access(access) + access_name = instruction_module.access_name(access) + read_identifier = instruction_module.access_identifier( + access_name, 'read') + write_identifier = instruction_module.access_identifier( + access_name, 'write') + # Find previous reads. + for read in last_accesses.get(read_identifier, []): + read_pg_node = program_graph.get_node_by_access(read) + program_graph.add_new_edge( + pg_node, read_pg_node, edge_type=pb.EdgeType.LAST_READ) + # Find previous writes. + for write in last_accesses.get(write_identifier, []): + write_pg_node = program_graph.get_node_by_access(write) + program_graph.add_new_edge( + pg_node, write_pg_node, edge_type=pb.EdgeType.LAST_WRITE) + # Update the state to refer to this access as the most recent one. + if instruction_module.access_is_read(access): + last_accesses[read_identifier] = [access] + elif instruction_module.access_is_write(access): + last_accesses[write_identifier] = [access] + + # Add COMPUTED_FROM edges. + for node in ast.walk(program_node): + if isinstance(node, ast.Assign): + for value_node in ast.walk(node.value): + if isinstance(value_node, ast.Name): + # TODO(dbieber): If possible, improve precision of these edges. + for target in node.targets: + program_graph.add_new_edge( + target, value_node, edge_type=pb.EdgeType.COMPUTED_FROM) + + # Add CALLS, FORMAL_ARG_NAME and RETURNS_TO edges. + for node in ast.walk(program_node): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + # TODO(dbieber): Use data flow analysis instead of all function defs. + func_defs = list(program_graph.get_nodes_by_function_name(node.func.id)) + # For any possible last writes that are a function definition, add the + # formal_arg_name and returns_to edges. + if not func_defs: + # TODO(dbieber): Add support for additional classes of functions, + # such as attributes of known objects and builtins. + if node.func.id in dir(builtins): + message = 'Function is builtin.' + else: + message = 'Cannot statically determine the function being called.' + logging.debug('%s (%s)', message, node.func.id) + for func_def in func_defs: + fn_node = func_def.node + # Add calls edge from the call node to the function definition. + program_graph.add_new_edge(node, fn_node, edge_type=pb.EdgeType.CALLS) + # Add returns_to edges from the function's return statements to the + # call node. 
+ for inner_node in ast.walk(func_def.node): + # TODO(dbieber): Determine if the returns_to should instead go to + # the next instruction after the Call node instead. + if isinstance(inner_node, ast.Return): + program_graph.add_new_edge( + inner_node, node, edge_type=pb.EdgeType.RETURNS_TO) + + # Add formal_arg_name edges from the args of the Call node to the + # args in the FunctionDef. + for index, arg in enumerate(node.args): + formal_arg = None + if index < len(fn_node.args.args): + formal_arg = fn_node.args.args[index] + elif fn_node.args.vararg: + # Since args.vararg is a string, we use the arguments node. + # TODO(dbieber): Use a node specifically for the vararg. + formal_arg = fn_node.args + if formal_arg is not None: + # Note: formal_arg can be an AST node or a string. + program_graph.add_new_edge( + arg, formal_arg, edge_type=pb.EdgeType.FORMAL_ARG_NAME) + else: + # TODO(dbieber): If formal_arg is None, then remove all + # formal_arg_name edges for this FunctionDef. + logging.debug('formal_arg is None') + for keyword in node.keywords: + name = keyword.arg + formal_arg = None + for arg in fn_node.args.args: + if isinstance(arg, ast.Name) and arg.id == name: + formal_arg = arg + break + else: + if fn_node.args.kwarg: + # Since args.kwarg is a string, we use the arguments node. + # TODO(dbieber): Use a node specifically for the kwarg. + formal_arg = fn_node.args + if formal_arg is not None: + program_graph.add_new_edge( + keyword.value, formal_arg, + edge_type=pb.EdgeType.FORMAL_ARG_NAME) + else: + # TODO(dbieber): If formal_arg is None, then remove all + # formal_arg_name edges for this FunctionDef. + logging.debug('formal_arg is None') + else: + # TODO(dbieber): Add a special case for Attributes. + logging.debug( + 'Cannot statically determine the function being called. (%s)', + astunparse.unparse(node.func).strip()) + + return program_graph + + +class SyntaxNodeUnparser(unparser.Unparser): + """An Unparser class helpful for creating Syntax Token nodes for fn graphs.""" + + def __init__(self, ast_node, graph): + self.graph = graph + + self.current_ast_node = None # The AST node currently being unparsed. + self.last_syntax_node = None + self.last_lexical_uses = {} + self.last_indent = 0 + + with codecs.open(os.devnull, 'w', encoding='utf-8') as devnull: + super(SyntaxNodeUnparser, self).__init__(ast_node, file=devnull) + + def dispatch(self, ast_node): + """Dispatcher function, dispatching tree type T to method _T.""" + tmp_ast_node = self.current_ast_node + self.current_ast_node = ast_node + super(SyntaxNodeUnparser, self).dispatch(ast_node) + self.current_ast_node = tmp_ast_node + + def fill(self, text=''): + """Indent a piece of text, according to the current indentation level.""" + text_with_whitespace = NEWLINE_TOKEN + if self.last_indent > self._indent: + text_with_whitespace += UNINDENT_TOKEN * (self.last_indent - self._indent) + elif self.last_indent < self._indent: + text_with_whitespace += INDENT_TOKEN * (self._indent - self.last_indent) + self.last_indent = self._indent + text_with_whitespace += text + self._add_syntax_node(text_with_whitespace) + super(SyntaxNodeUnparser, self).fill(text) + + def write(self, text): + """Append a piece of text to the current line.""" + if isinstance(text, ast.AST): # text may be a Name, Tuple, or List node. 
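+      # Dispatching the nested node makes it the current_ast_node while it is
+      # unparsed, so its syntax tokens get SYNTAX edges to that nested node
+      # rather than to the node currently being written.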
+ return self.dispatch(text) + self._add_syntax_node(text) + super(SyntaxNodeUnparser, self).write(text) + + def _add_syntax_node(self, text): + text = text.strip() + if not text: + return + syntax_node = make_node_from_syntax(six.text_type(text)) + self.graph.add_node(syntax_node) + self.graph.add_new_edge( + self.current_ast_node, syntax_node, edge_type=pb.EdgeType.SYNTAX) + if self.last_syntax_node: + self.graph.add_new_edge( + self.last_syntax_node, syntax_node, edge_type=pb.EdgeType.NEXT_SYNTAX) + self.last_syntax_node = syntax_node + + def _Name(self, node): + if node.id in self.last_lexical_uses: + self.graph.add_new_edge( + node, + self.last_lexical_uses[node.id], + edge_type=pb.EdgeType.LAST_LEXICAL_USE) + self.last_lexical_uses[node.id] = node + super(SyntaxNodeUnparser, self)._Name(node) + + +class ProgramGraphNode(object): + """A single node in a Program Graph. + + Corresponds to either a SyntaxNode or an Instruction (as in a + ControlFlowGraph). + + Attributes: + node_type: One of the node types from pb.NodeType. + id: A unique id for the node. + instruction: If applicable, the corresponding Instruction. + ast_node: If available, the AST node corresponding to the ProgramGraphNode. + ast_type: If available, the type of the AST node, as a string. + ast_value: If available, the primitive Python value corresponding to the + node. + syntax: For SYNTAX_NODEs, the syntax information stored in the node. + node: If available, the AST node for this program graph node or its + instruction. + """ + + def __init__(self): + self.node_type = None + self.id = None + + self.instruction = None + self.ast_node = None + self.ast_type = '' + self.ast_value = '' + self.syntax = '' + + def has_instruction(self): + return self.instruction is not None + + def has_instance_of(self, t): + """Whether the node's instruction is an instance of type `t`.""" + if self.instruction is None: + return False + return isinstance(self.instruction.node, t) + + @property + def node(self): + if self.ast_node is not None: + return self.ast_node + if self.instruction is None: + return None + return self.instruction.node + + def __repr__(self): + return str(self.id) + ' ' + str(self.ast_type) + + +def make_node_from_syntax(text): + node = ProgramGraphNode() + node.node_type = pb.NodeType.SYNTAX_NODE + node.id = program_utils.unique_id() + node.syntax = text + return node + + +def make_node_from_instruction(instruction): + """Creates a ProgramGraphNode corresponding to an existing Instruction. + + Args: + instruction: An Instruction object. + + Returns: + A ProgramGraphNode corresponding to that instruction. + """ + ast_node = instruction.node + node = make_node_from_ast_node(ast_node) + node.instruction = instruction + return node + + +def make_node_from_ast_node(ast_node): + """Creates a program graph node for the provided AST node. + + This is only called when the AST node doesn't already correspond to an + Instruction in the program's control flow graph. + + Args: + ast_node: An AST node from the program being analyzed. + + Returns: + A node in the program graph corresponding to the AST node. 
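+
+  Example (illustrative sketch only; assumes `ast` is the gast module, as
+  used throughout this file):
+    pass_node = ast.parse('pass').body[0]
+    pg_node = make_node_from_ast_node(pass_node)  # pg_node.ast_type == 'Pass'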
+ """ + node = ProgramGraphNode() + node.node_type = pb.NodeType.AST_NODE + node.id = program_utils.unique_id() + node.ast_node = ast_node + node.ast_type = type(ast_node).__name__ + return node + + +def make_node_for_ast_list(): + node = ProgramGraphNode() + node.node_type = pb.NodeType.AST_LIST + node.id = program_utils.unique_id() + return node + + +def make_node_from_ast_value(value): + """Creates a ProgramGraphNode for the provided value. + + `value` is a primitive value appearing in a Python AST. + + For example, the number 1 in Python has AST Num(n=1). In this, the value '1' + is a primitive appearing in the AST. It gets its own ProgramGraphNode with + node_type AST_VALUE. + + Args: + value: A primitive value appearing in an AST. + + Returns: + A ProgramGraphNode corresponding to the provided value. + """ + node = ProgramGraphNode() + node.node_type = pb.NodeType.AST_VALUE + node.id = program_utils.unique_id() + node.ast_value = value + return node + + +def make_list_field_name(field_name, index): + return '{}:{}'.format(field_name, index) + + +def parse_list_field_name(list_field_name): + field_name, index = list_field_name.split(':') + index = int(index) + return field_name, index diff --git a/python_graphs/program_graph_dataclasses.py b/python_graphs/program_graph_dataclasses.py new file mode 100644 index 0000000..750ed6a --- /dev/null +++ b/python_graphs/program_graph_dataclasses.py @@ -0,0 +1,82 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The dataclasses for representing a Program Graph.""" + +import enum +from typing import List, Optional, Text +import dataclasses + + +class NodeType(enum.Enum): + UNSPECIFIED = 0 + AST_NODE = 1 + AST_LIST = 2 + AST_VALUE = 3 + SYNTAX_NODE = 4 + PLACEHOLDER = 5 + + +@dataclasses.dataclass +class Node: + """Represents a node in a program graph.""" + id: int + type: NodeType + + # If an AST node, a string that identifies what type of AST node, + # e.g. "Num" or "Expr". These are defined by the underlying AST for the + # language. + ast_type: Optional[Text] = "" + + # Primitive valued AST node, such as: + # - the name of an identifier for a Name node + # - the number attached to a Num node + # The corresponding ast_type value is the Python type of ast_value, not the + # type of the parent AST node. + ast_value_repr: Optional[Text] = "" + + # For syntax nodes, the syntax attached to the node. + syntax: Optional[Text] = "" + + +class EdgeType(enum.Enum): + """The different kinds of edges that can appear in a program graph.""" + UNSPECIFIED = 0 + CFG_NEXT = 1 + LAST_READ = 2 + LAST_WRITE = 3 + COMPUTED_FROM = 4 + RETURNS_TO = 5 + FORMAL_ARG_NAME = 6 + FIELD = 7 + SYNTAX = 8 + NEXT_SYNTAX = 9 + LAST_LEXICAL_USE = 10 + CALLS = 11 + + +@dataclasses.dataclass +class Edge: + id1: int + id2: int + type: EdgeType + field_name: Optional[Text] = None # For FIELD edges, the field name. 
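+  # Assumed semantics (not specified in this change): True for edges that
+  # point "backwards" in the program, e.g. a CFG_NEXT edge from the end of a
+  # loop body back to the loop header.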
+ has_back_edge: bool = False + + +@dataclasses.dataclass +class Graph: + nodes: List[Node] + edges: List[Edge] + root_id: int diff --git a/python_graphs/program_graph_graphviz.py b/python_graphs/program_graph_graphviz.py new file mode 100644 index 0000000..02e8e4a --- /dev/null +++ b/python_graphs/program_graph_graphviz.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Graphviz visualizations of Program Graphs.""" + +from absl import logging # pylint: disable=unused-import +import pygraphviz +from python_graphs import program_graph_dataclasses as pb +import six + + +def to_graphviz(graph): + """Creates a graphviz representation of a ProgramGraph. + + Args: + graph: A ProgramGraph object to visualize. + Returns: + A pygraphviz object representing the ProgramGraph. + """ + g = pygraphviz.AGraph(strict=False, directed=True) + for unused_key, node in graph.nodes.items(): + node_attrs = {} + if node.ast_type: + node_attrs['label'] = six.ensure_binary(node.ast_type, 'utf-8') + else: + node_attrs['shape'] = 'point' + node_type_colors = { + } + if node.node_type in node_type_colors: + node_attrs['color'] = node_type_colors[node.node_type] + node_attrs['colorscheme'] = 'svg' + + g.add_node(node.id, **node_attrs) + for edge in graph.edges: + edge_attrs = {} + edge_attrs['label'] = edge.type.name + edge_colors = { + pb.EdgeType.LAST_READ: 'red', + pb.EdgeType.LAST_WRITE: 'red', + } + if edge.type in edge_colors: + edge_attrs['color'] = edge_colors[edge.type] + edge_attrs['colorscheme'] = 'svg' + g.add_edge(edge.id1, edge.id2, **edge_attrs) + return g + + +def render(graph, path='/tmp/graph.png'): + g = to_graphviz(graph) + g.draw(path, prog='dot') diff --git a/python_graphs/program_graph_graphviz_test.py b/python_graphs/program_graph_graphviz_test.py new file mode 100644 index 0000000..56ece83 --- /dev/null +++ b/python_graphs/program_graph_graphviz_test.py @@ -0,0 +1,34 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for program_graph_graphviz.py.""" + +import inspect + +from absl.testing import absltest +from python_graphs import control_flow_test_components as tc +from python_graphs import program_graph +from python_graphs import program_graph_graphviz + + +class ControlFlowGraphvizTest(absltest.TestCase): + + def test_to_graphviz_for_all_test_components(self): + for unused_name, fn in inspect.getmembers(tc, predicate=inspect.isfunction): + graph = program_graph.get_program_graph(fn) + program_graph_graphviz.to_graphviz(graph) + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/program_graph_test.py b/python_graphs/program_graph_test.py new file mode 100644 index 0000000..6fb4e61 --- /dev/null +++ b/python_graphs/program_graph_test.py @@ -0,0 +1,292 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for program_graph.py.""" + +import collections +import inspect +import time + +from absl import logging +from absl.testing import absltest +import gast as ast + +from python_graphs import control_flow_test_components as cftc +from python_graphs import program_graph +from python_graphs import program_graph_dataclasses as pb +from python_graphs import program_graph_test_components as pgtc +from python_graphs import program_utils + + +def test_components(): + """Generates functions from two sets of test components. + + Yields: + Functions from the program graph and control flow test components files. + """ + for unused_name, fn in inspect.getmembers(pgtc, predicate=inspect.isfunction): + yield fn + + for unused_name, fn in inspect.getmembers(cftc, predicate=inspect.isfunction): + yield fn + + +class ProgramGraphTest(absltest.TestCase): + + def assertEdge(self, graph, n1, n2, edge_type): + """Asserts that an edge of type edge_type exists from n1 to n2 in graph.""" + edge = pb.Edge(id1=n1.id, id2=n2.id, type=edge_type) + self.assertIn(edge, graph.edges) + + def assertNoEdge(self, graph, n1, n2, edge_type): + """Asserts that no edge of type edge_type exists from n1 to n2 in graph.""" + edge = pb.Edge(id1=n1.id, id2=n2.id, type=edge_type) + self.assertNotIn(edge, graph.edges) + + def test_get_program_graph_test_components(self): + self.analyze_get_program_graph(test_components(), start=0) + + def analyze_get_program_graph(self, program_generator, start=0): + # TODO(dbieber): Remove the counting and logging logic from this method, + # and instead just get_program_graph for each program in the generator. + # The counting and logging logic is for development purposes only. 
+ num_edges = 0 + num_edges_by_type = collections.defaultdict(int) + num_nodes = 0 + num_graphs = 1 + times = {} + for index, program in enumerate(program_generator): + if index < start: + continue + start_time = time.time() + graph = program_graph.get_program_graph(program) + end_time = time.time() + times[index] = end_time - start_time + num_edges += len(graph.edges) + for edge in graph.edges: + num_edges_by_type[edge.type] += 1 + num_nodes += len(graph.nodes) + num_graphs += 1 + if index % 100 == 0: + logging.debug(sorted(times.items(), key=lambda kv: -kv[1])[:10]) + logging.info('%d %d %d', num_edges, num_nodes, num_graphs) + logging.info('%f %f', num_edges / num_graphs, num_nodes / num_graphs) + for edge_type in num_edges_by_type: + logging.info('%s %f', edge_type, + num_edges_by_type[edge_type] / num_graphs) + + logging.info(times) + logging.info(sorted(times.items(), key=lambda kv: -kv[1])[:10]) + + def test_last_lexical_use_edges_function_call(self): + graph = program_graph.get_program_graph(pgtc.function_call) + read = graph.get_node_by_source_and_identifier('return z', 'z') + write = graph.get_node_by_source_and_identifier( + 'z = function_call_helper(x, y)', 'z') + self.assertEdge(graph, read, write, pb.EdgeType.LAST_LEXICAL_USE) + + def test_last_write_edges_function_call(self): + graph = program_graph.get_program_graph(pgtc.function_call) + write_z = graph.get_node_by_source_and_identifier( + 'z = function_call_helper(x, y)', 'z') + read_z = graph.get_node_by_source_and_identifier('return z', 'z') + self.assertEdge(graph, read_z, write_z, pb.EdgeType.LAST_WRITE) + + write_y = graph.get_node_by_source_and_identifier('y = 2', 'y') + read_y = graph.get_node_by_source_and_identifier( + 'z = function_call_helper(x, y)', 'y') + self.assertEdge(graph, read_y, write_y, pb.EdgeType.LAST_WRITE) + + def test_last_read_edges_assignments(self): + graph = program_graph.get_program_graph(pgtc.assignments) + write_a0 = graph.get_node_by_source_and_identifier('a, b = 0, 0', 'a') + read_a0 = graph.get_node_by_source_and_identifier('c = 2 * a + 1', 'a') + write_a1 = graph.get_node_by_source_and_identifier('a = c + 3', 'a') + self.assertEdge(graph, write_a1, read_a0, pb.EdgeType.LAST_READ) + self.assertNoEdge(graph, write_a0, read_a0, pb.EdgeType.LAST_READ) + + read_a1 = graph.get_node_by_source_and_identifier('return a, b, c, d', 'a') + self.assertEdge(graph, read_a1, read_a0, pb.EdgeType.LAST_READ) + + def test_last_read_last_write_edges_repeated_identifier(self): + graph = program_graph.get_program_graph(pgtc.repeated_identifier) + write_x0 = graph.get_node_by_source_and_identifier('x = 0', 'x') + + stmt1 = graph.get_node_by_source('x = x + 1').ast_node + read_x0 = graph.get_node_by_ast_node(stmt1.value.left) + write_x1 = graph.get_node_by_ast_node(stmt1.targets[0]) + + stmt2 = graph.get_node_by_source('x = (x + (x + x)) + x').ast_node + read_x1 = graph.get_node_by_ast_node(stmt2.value.left.left) + read_x2 = graph.get_node_by_ast_node(stmt2.value.left.right.left) + read_x3 = graph.get_node_by_ast_node(stmt2.value.left.right.right) + read_x4 = graph.get_node_by_ast_node(stmt2.value.right) + write_x2 = graph.get_node_by_ast_node(stmt2.targets[0]) + + read_x5 = graph.get_node_by_source_and_identifier('return x', 'x') + + self.assertEdge(graph, write_x1, read_x0, pb.EdgeType.LAST_READ) + self.assertEdge(graph, read_x1, read_x0, pb.EdgeType.LAST_READ) + self.assertEdge(graph, read_x2, read_x1, pb.EdgeType.LAST_READ) + self.assertEdge(graph, read_x3, read_x2, pb.EdgeType.LAST_READ) + 
self.assertEdge(graph, read_x4, read_x3, pb.EdgeType.LAST_READ) + self.assertEdge(graph, write_x2, read_x4, pb.EdgeType.LAST_READ) + self.assertEdge(graph, read_x5, read_x4, pb.EdgeType.LAST_READ) + + self.assertEdge(graph, read_x0, write_x0, pb.EdgeType.LAST_WRITE) + self.assertEdge(graph, write_x1, write_x0, pb.EdgeType.LAST_WRITE) + self.assertEdge(graph, read_x2, write_x1, pb.EdgeType.LAST_WRITE) + self.assertEdge(graph, read_x3, write_x1, pb.EdgeType.LAST_WRITE) + self.assertEdge(graph, read_x4, write_x1, pb.EdgeType.LAST_WRITE) + self.assertEdge(graph, write_x2, write_x1, pb.EdgeType.LAST_WRITE) + self.assertEdge(graph, read_x5, write_x2, pb.EdgeType.LAST_WRITE) + + def test_computed_from_edges(self): + graph = program_graph.get_program_graph(pgtc.assignments) + target_c = graph.get_node_by_source_and_identifier('c = 2 * a + 1', 'c') + from_a = graph.get_node_by_source_and_identifier('c = 2 * a + 1', 'a') + self.assertEdge(graph, target_c, from_a, pb.EdgeType.COMPUTED_FROM) + + target_d = graph.get_node_by_source_and_identifier('d = b - c + 2', 'd') + from_b = graph.get_node_by_source_and_identifier('d = b - c + 2', 'b') + from_c = graph.get_node_by_source_and_identifier('d = b - c + 2', 'c') + self.assertEdge(graph, target_d, from_b, pb.EdgeType.COMPUTED_FROM) + self.assertEdge(graph, target_d, from_c, pb.EdgeType.COMPUTED_FROM) + + def test_calls_edges(self): + graph = program_graph.get_program_graph(pgtc) + call = graph.get_node_by_source('function_call_helper(x, y)') + self.assertIsInstance(call.node, ast.Call) + function_call_helper_def = graph.get_node_by_function_name( + 'function_call_helper') + assignments_def = graph.get_node_by_function_name('assignments') + self.assertEdge(graph, call, function_call_helper_def, pb.EdgeType.CALLS) + self.assertNoEdge(graph, call, assignments_def, pb.EdgeType.CALLS) + + def test_formal_arg_name_edges(self): + graph = program_graph.get_program_graph(pgtc) + x = graph.get_node_by_source_and_identifier('function_call_helper(x, y)', + 'x') + y = graph.get_node_by_source_and_identifier('function_call_helper(x, y)', + 'y') + function_call_helper_def = graph.get_node_by_function_name( + 'function_call_helper') + arg0_ast_node = function_call_helper_def.node.args.args[0] + arg0 = graph.get_node_by_ast_node(arg0_ast_node) + arg1_ast_node = function_call_helper_def.node.args.args[1] + arg1 = graph.get_node_by_ast_node(arg1_ast_node) + self.assertEdge(graph, x, arg0, pb.EdgeType.FORMAL_ARG_NAME) + self.assertEdge(graph, y, arg1, pb.EdgeType.FORMAL_ARG_NAME) + self.assertNoEdge(graph, x, arg1, pb.EdgeType.FORMAL_ARG_NAME) + self.assertNoEdge(graph, y, arg0, pb.EdgeType.FORMAL_ARG_NAME) + + def test_returns_to_edges(self): + graph = program_graph.get_program_graph(pgtc) + call = graph.get_node_by_source('function_call_helper(x, y)') + return_stmt = graph.get_node_by_source('return arg0 + arg1') + self.assertEdge(graph, return_stmt, call, pb.EdgeType.RETURNS_TO) + + def test_syntax_information(self): + # TODO(dbieber): Test that program graphs correctly capture syntax + # information. Do this once representation of syntax in program graphs + # stabilizes. + pass + + def test_ast_acyclic(self): + for name, fn in inspect.getmembers(cftc, predicate=inspect.isfunction): + graph = program_graph.get_program_graph(fn) + ast_nodes = set() + worklist = [graph.root] + while worklist: + current = worklist.pop() + self.assertNotIn( + current, ast_nodes, + 'ProgramGraph AST cyclic. 
Function {}\nAST {}'.format( + name, graph.dump_tree())) + ast_nodes.add(current) + worklist.extend(graph.children(current)) + + def test_neighbors_children_consistent(self): + for unused_name, fn in inspect.getmembers( + cftc, predicate=inspect.isfunction): + graph = program_graph.get_program_graph(fn) + for node in graph.all_nodes(): + if node.node_type == pb.NodeType.AST_NODE: + children0 = set(graph.outgoing_neighbors(node, pb.EdgeType.FIELD)) + children1 = set(graph.children(node)) + self.assertEqual(children0, children1) + + def test_walk_ast_descendants(self): + for unused_name, fn in inspect.getmembers( + cftc, predicate=inspect.isfunction): + graph = program_graph.get_program_graph(fn) + for node in graph.walk_ast_descendants(): + self.assertIn(node, graph.all_nodes()) + + def test_roundtrip_ast(self): + for unused_name, fn in inspect.getmembers( + cftc, predicate=inspect.isfunction): + ast_representation = program_utils.program_to_ast(fn) + graph = program_graph.get_program_graph(fn) + ast_reproduction = graph.to_ast() + self.assertEqual(ast.dump(ast_representation), ast.dump(ast_reproduction)) + + def test_reconstruct_missing_ast(self): + for unused_name, fn in inspect.getmembers( + cftc, predicate=inspect.isfunction): + graph = program_graph.get_program_graph(fn) + ast_original = graph.root.ast_node + # Remove the AST. + for node in graph.all_nodes(): + node.ast_node = None + # Reconstruct it. + graph.reconstruct_ast() + ast_reproduction = graph.root.ast_node + # Check reconstruction. + self.assertEqual(ast.dump(ast_original), ast.dump(ast_reproduction)) + # Check that all AST_NODE nodes are set. + for node in graph.all_nodes(): + if node.node_type == pb.NodeType.AST_NODE: + self.assertIsInstance(node.ast_node, ast.AST) + self.assertIs(graph.get_node_by_ast_node(node.ast_node), node) + # Check that old AST nodes are no longer referenced. + self.assertFalse(graph.contains_ast_node(ast_original)) + + def test_remove(self): + graph = program_graph.get_program_graph(pgtc.assignments) + + for edge in list(graph.edges)[:]: + # Remove the edge. + graph.remove_edge(edge) + self.assertNotIn(edge, graph.edges) + self.assertNotIn((edge, edge.id2), graph.neighbors_map[edge.id1]) + self.assertNotIn((edge, edge.id1), graph.neighbors_map[edge.id2]) + + if edge.type == pb.EdgeType.FIELD: + self.assertNotIn(edge.id2, graph.child_map[edge.id1]) + self.assertNotIn(edge.id2, graph.parent_map) + + # Add the edge again. + graph.add_edge(edge) + self.assertIn(edge, graph.edges) + self.assertIn((edge, edge.id2), graph.neighbors_map[edge.id1]) + self.assertIn((edge, edge.id1), graph.neighbors_map[edge.id2]) + + if edge.type == pb.EdgeType.FIELD: + self.assertIn(edge.id2, graph.child_map[edge.id1]) + self.assertIn(edge.id2, graph.parent_map) + + +if __name__ == '__main__': + absltest.main() diff --git a/python_graphs/program_graph_test_components.py b/python_graphs/program_graph_test_components.py new file mode 100644 index 0000000..a396da4 --- /dev/null +++ b/python_graphs/program_graph_test_components.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test components for testing program graphs.""" + + +# pylint: disable=missing-docstring +# pylint: disable=pointless-statement,undefined-variable +# pylint: disable=unused-variable,unused-argument +# pylint: disable=bare-except,lost-exception,unreachable +# pylint: disable=global-variable-undefined +def function_call(): + x = 1 + y = 2 + z = function_call_helper(x, y) + return z + + +def function_call_helper(arg0, arg1): + return arg0 + arg1 + + +def assignments(): + a, b = 0, 0 + c = 2 * a + 1 + d = b - c + 2 + a = c + 3 + return a, b, c, d + + +def fn_with_globals(): + global global_a, global_b, global_c + global_a = 10 + global_b = 20 + global_c = 30 + return global_a + global_b + global_c + + +def fn_with_inner_fn(): + + def inner_fn(): + while True: + pass + + +def repeated_identifier(): + x = 0 + x = x + 1 + x = (x + (x + x)) + x + return x diff --git a/python_graphs/program_graph_visualizer.py b/python_graphs/program_graph_visualizer.py new file mode 100644 index 0000000..7af130b --- /dev/null +++ b/python_graphs/program_graph_visualizer.py @@ -0,0 +1,51 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Create program graph visualizations for the test components. + + +Usage: +python -m python_graphs.program_graph_visualizer +""" + +import inspect + +from absl import app +from absl import logging # pylint: disable=unused-import + +from python_graphs import control_flow_test_components as tc +from python_graphs import program_graph +from python_graphs import program_graph_graphviz + + +def render_functions(functions): + for name, function in functions: + logging.info(name) + graph = program_graph.get_program_graph(function) + path = '/tmp/program_graphs/{}.png'.format(name) + program_graph_graphviz.render(graph, path=path) + + +def main(argv): + del argv # Unused. + + functions = [ + (name, fn) + for name, fn in inspect.getmembers(tc, predicate=inspect.isfunction) + ] + render_functions(functions) + + +if __name__ == '__main__': + app.run(main) diff --git a/python_graphs/program_utils.py b/python_graphs/program_utils.py new file mode 100644 index 0000000..74117e6 --- /dev/null +++ b/python_graphs/program_utils.py @@ -0,0 +1,62 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Program utility functions.""" + +import inspect +import textwrap +import uuid + +import gast as ast +import six + + +def getsource(obj): + """Gets the source for the given object. + + Args: + obj: A module, class, method, function, traceback, frame, or code object. + Returns: + The source of the object, if available. + """ + if inspect.ismethod(obj): + func = obj.__func__ + else: + func = obj + source = inspect.getsource(func) + return textwrap.dedent(source) + + +def program_to_ast(program): + """Convert a program to its AST. + + Args: + program: Either an AST node, source string, or a function. + Returns: + The root AST node of the AST representing the program. + """ + if isinstance(program, ast.AST): + return program + if isinstance(program, six.string_types): + source = program + else: + source = getsource(program) + module_node = ast.parse(source, mode='exec') + return module_node + + +def unique_id(): + """Returns a unique id that is suitable for identifying graph nodes.""" + return uuid.uuid4().int & ((1 << 64) - 1) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9c558e3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +. diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a3abdfe --- /dev/null +++ b/setup.py @@ -0,0 +1,81 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The setup.py file for python_graphs.""" + +from setuptools import setup + +LONG_DESCRIPTION = """ +python_graphs is a static analysis tool for performing control flow and data +flow analyses on Python programs, and for constructing Program Graphs. +Python Program Graphs are graph representations of Python programs suitable +for use with graph neural networks. 
+""".strip() + +SHORT_DESCRIPTION = """ +A library for generating graph representations of Python programs.""".strip() + +DEPENDENCIES = [ + 'absl-py', + 'astunparse', + 'gast', + 'six', + 'pygraphviz', +] + +TEST_DEPENDENCIES = [ +] + +VERSION = '1.0.0' +URL = 'https://github.com/google-research/python-graphs' + +setup( + name='python_graphs', + version=VERSION, + description=SHORT_DESCRIPTION, + long_description=LONG_DESCRIPTION, + url=URL, + + author='David Bieber', + author_email='dbieber@google.com', + license='Apache Software License', + + classifiers=[ + 'Development Status :: 4 - Beta', + + 'Intended Audience :: Developers', + 'Topic :: Software Development :: Libraries :: Python Modules', + + 'License :: OSI Approved :: Apache Software License', + + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + + 'Operating System :: OS Independent', + 'Operating System :: POSIX', + 'Operating System :: MacOS', + 'Operating System :: Unix', + ], + + keywords='python program control flow data flow graph neural network', + + packages=['python_graphs'], + + install_requires=DEPENDENCIES, + tests_require=TEST_DEPENDENCIES, +)