diff --git a/docs/images/import-loops.png b/docs/images/import-loops.png new file mode 100644 index 000000000..e8dea1ef0 Binary files /dev/null and b/docs/images/import-loops.png differ diff --git a/docs/images/large-import-loop.png b/docs/images/large-import-loop.png new file mode 100644 index 000000000..6cd3ce710 Binary files /dev/null and b/docs/images/large-import-loop.png differ diff --git a/docs/images/problematic-import-loop.png b/docs/images/problematic-import-loop.png new file mode 100644 index 000000000..d920c3d9d Binary files /dev/null and b/docs/images/problematic-import-loop.png differ diff --git a/docs/images/valid-import-loop.png b/docs/images/valid-import-loop.png new file mode 100644 index 000000000..384961507 Binary files /dev/null and b/docs/images/valid-import-loop.png differ diff --git a/docs/introduction/getting-started.mdx b/docs/introduction/getting-started.mdx index cf7b8a1c6..bf2f0e309 100644 --- a/docs/introduction/getting-started.mdx +++ b/docs/introduction/getting-started.mdx @@ -17,7 +17,7 @@ uv tool install codegen ## Quick Start with Jupyter -The [codgen notebook](/cli/notebook) command creates a virtual environment and opens a Jupyter notebook for quick prototyping. This is often the fastest way to get up and running. +The [codegen notebook](/cli/notebook) command creates a virtual environment and opens a Jupyter notebook for quick prototyping. This is often the fastest way to get up and running. 
```bash # Launch Jupyter with a demo notebook diff --git a/docs/mint.json b/docs/mint.json index 7bf73f9bf..bf30e7086 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -103,6 +103,7 @@ "tutorials/react-modernization", "tutorials/unittest-to-pytest", "tutorials/sqlalchemy-1.6-to-2.0", + "tutorials/fixing-import-loops-in-pytorch", "tutorials/python2-to-python3", "tutorials/flask-to-fastapi" ] diff --git a/docs/tutorials/fixing-import-loops-in-pytorch.mdx b/docs/tutorials/fixing-import-loops-in-pytorch.mdx new file mode 100644 index 000000000..e196f72f1 --- /dev/null +++ b/docs/tutorials/fixing-import-loops-in-pytorch.mdx @@ -0,0 +1,260 @@ +--- +title: "Fixing Import Loops" +description: "Learn how to identify and fix problematic import loops using Codegen." +icon: "arrows-rotate" +iconType: "solid" +--- + + + + + +Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. + +In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen. + + +You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-examples/tree/main/examples/removing_import_loops_in_pytorch). + + +## Overview + +The steps to identify and fix import loops are as follows: +1. Detect import loops +2. Visualize them +3. Identify problematic cycles with mixed static/dynamic imports +4. 
Fix these cycles using Codegen + +# Step 1: Detect Import Loops +- Create a graph +- Loop through imports in the codebase and add edges between the import files +- Find strongly connected components using NetworkX (the import loops) +```python +G = nx.MultiDiGraph() + +# Add all edges to the graph +for imp in codebase.imports: + if imp.from_file and imp.to_file: + edge_color = "red" if imp.is_dynamic else "black" + edge_label = "dynamic" if imp.is_dynamic else "static" + + # Store the import statement and its metadata + G.add_edge( + imp.to_file.filepath, + imp.from_file.filepath, + color=edge_color, + label=edge_label, + is_dynamic=imp.is_dynamic, + import_statement=imp, # Store the whole import object + key=id(imp.import_statement), + ) +# Find strongly connected components +cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1] + +print(f"🔄 Found {len(cycles)} import cycles:") +for i, cycle in enumerate(cycles, 1): + print(f"\nCycle #{i}:") + print(f"Size: {len(cycle)} files") + + # Create subgraph for this cycle to count edges + cycle_subgraph = G.subgraph(cycle) + + # Count total edges + total_edges = cycle_subgraph.number_of_edges() + print(f"Total number of imports in cycle: {total_edges}") + + # Count dynamic and static imports separately + dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "red") + static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "black") + + print(f"Number of dynamic imports: {dynamic_imports}") + print(f"Number of static imports: {static_imports}") +``` + + +## Understanding Import Cycles + +Not all import cycles are problematic! Here's an example of a cycle that one might expect to cause an error, but it does not because it uses a dynamic import. 
+ +```python +# top-level import in APoT_tensor.py +from quantizer.py import objectA +``` + +```python +# dynamic import in quantizer.py +def some_func(): + # dynamic import (evaluated when some_func() is called) + from APoT_tensor.py import objectB +``` + + + +A dynamic import is an import defined inside a function, method, or any other executable body of code, which delays the import's execution until that function, method, or body of code is called. + +You can use `imp.is_dynamic` to check if the import is dynamic, allowing you to investigate imports that are handled more intentionally. + +# Step 2: Visualize Import Loops +- Create a new subgraph to visualize one cycle +- Color and label the edges based on their type (dynamic/static) +- Visualize the cycle graph using `codebase.visualize(graph)` + +```python +cycle = cycles[0] + +def create_single_loop_graph(cycle): + cycle_graph = nx.MultiDiGraph() # Changed to MultiDiGraph to support multiple edges + cycle = list(cycle) + for i in range(len(cycle)): + for j in range(len(cycle)): + # Get all edges between these nodes from original graph + edge_data_dict = G.get_edge_data(cycle[i], cycle[j]) + if edge_data_dict: + # For each edge between these nodes + for edge_key, edge_data in edge_data_dict.items(): + # Add edge with all its attributes to cycle graph + cycle_graph.add_edge(cycle[i], cycle[j], **edge_data) + return cycle_graph + + +cycle_graph = create_single_loop_graph(cycle) +codebase.visualize(cycle_graph) +``` + + + + + + +# Step 3: Identify problematic cycles with mixed static & dynamic imports + +The import loops that we are really concerned about are those that have mixed static/dynamic imports. + +Here's an example of a problematic cycle that we want to fix: + +```python +# In flex_decoding.py +from .flex_attention import ( + compute_forward_block_mn, + compute_forward_inner, + # ... 
more static imports +) + +# Also in flex_decoding.py +def create_flex_decoding_kernel(*args, **kwargs): + from .flex_attention import set_head_dim_values # dynamic import +``` + +Here there is both a top-level import and a dynamic import from the *same* module, which can cause issues if not handled carefully. + + + +Let's find these problematic cycles: + +```python +def find_problematic_import_loops(G, sccs): + """Find cycles where files have both static and dynamic imports between them.""" + problematic_cycles = [] + + for i, scc in enumerate(sccs): + if i == 2: # skip the import loop at index 2 (the third one), as it's incredibly long (it's also invalid) + continue + mixed_import_files = {} # (from_file, to_file) -> {dynamic: count, static: count} + + # Check all file pairs in the cycle + for from_file in scc: + for to_file in scc: + if G.has_edge(from_file, to_file): + # Get all edges between these files + edges = G.get_edge_data(from_file, to_file) + + # Count imports by type + dynamic_count = sum(1 for e in edges.values() if e["color"] == "red") + static_count = sum(1 for e in edges.values() if e["color"] == "black") + + # If we have both types between same files, this is problematic + if dynamic_count > 0 and static_count > 0: + mixed_import_files[(from_file, to_file)] = {"dynamic": dynamic_count, "static": static_count, "edges": edges} + + if mixed_import_files: + problematic_cycles.append({"files": scc, "mixed_imports": mixed_import_files, "index": i}) + + # Print findings + print(f"Found {len(problematic_cycles)} cycles with mixed imports:") + for i, cycle in enumerate(problematic_cycles): + print(f"\n⚠️ Problematic Cycle #{i + 1}:") + print(f"\n⚠️ Index #{cycle['index']}:") + print(f"Size: {len(cycle['files'])} files") + + for (from_file, to_file), data in cycle["mixed_imports"].items(): + print("\n📁 Mixed imports detected:") + print(f" From: {from_file}") + print(f" To: {to_file}") + print(f" Dynamic imports: {data['dynamic']}") + 
print(f" Static imports: {data['static']}") + + return problematic_cycles + +problematic_cycles = find_problematic_import_loops(G, cycles) +``` + +# Step 4: Fix the loop by moving the shared symbols to a separate `utils.py` file +One common way to break this cycle is to move all the shared symbols to a separate `utils.py` file. We can do this using the method `symbol.move_to_file`: + +```python +# Create new utils file +utils_file = codebase.create_file("torch/_inductor/kernel/flex_utils.py") + +# Get the two files involved in the import cycle +decoding_file = codebase.get_file("torch/_inductor/kernel/flex_decoding.py") +attention_file = codebase.get_file("torch/_inductor/kernel/flex_attention.py") +attention_file_path = "torch/_inductor/kernel/flex_attention.py" +decoding_file_path = "torch/_inductor/kernel/flex_decoding.py" + +# Track symbols to move +symbols_to_move = set() + +# Find imports from flex_attention in flex_decoding +for imp in decoding_file.imports: + if imp.from_file and imp.from_file.filepath == attention_file_path: + # Get the actual symbol from flex_attention + if imp.imported_symbol: + symbols_to_move.add(imp.imported_symbol) + +# Move identified symbols to utils file +for symbol in symbols_to_move: + symbol.move_to_file(utils_file) + +print(f"🔄 Moved {len(symbols_to_move)} symbols to flex_utils.py") +for symbol in symbols_to_move: + print(symbol.name) +``` + +```python +# run this command to have the changes take effect in the codebase +codebase.commit() +``` + +# Next Steps +Verify all tests pass after the migration and fix other problematic import loops using the suggested strategies: + 1. Move the shared symbols to a separate file + 2. If a module needs imports only for type hints, consider using `if TYPE_CHECKING` from the `typing` module + 3. 
Use lazy imports using `importlib` to load imports dynamically \ No newline at end of file diff --git a/docs/tutorials/migrating-apis.mdx b/docs/tutorials/migrating-apis.mdx index 57dea5fd4..a0e68da95 100644 --- a/docs/tutorials/migrating-apis.mdx +++ b/docs/tutorials/migrating-apis.mdx @@ -1,7 +1,7 @@ --- title: "Migrating APIs" sidebarTitle: "API Migrations" -icon: "arrows-rotate" +icon: "webhook" iconType: "solid" --- diff --git a/src/codegen/cli/utils/notebooks.py b/src/codegen/cli/utils/notebooks.py index 2414eeff7..2311cabdf 100644 --- a/src/codegen/cli/utils/notebooks.py +++ b/src/codegen/cli/utils/notebooks.py @@ -20,7 +20,7 @@ ] DEMO_CELLS = [ - ##### [ CODGEN DEMO ] ##### + ##### [ CODEGEN DEMO ] ##### { "cell_type": "markdown", "source": """# Codegen Demo: FastAPI diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore b/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore index 09f6195d9..930dd1b95 100644 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore +++ b/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore @@ -1,3 +1,4 @@ # Typescript Analyzer Specific GitIgnores node_modules dist +package-lock.json diff --git a/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py b/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py index 9ac955db6..317b3b2c8 100644 --- a/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py +++ b/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py @@ -38,6 +38,7 @@ def setup(): @pytest.mark.skip("Skipping this test for now") @pytest.mark.timeout(5, func_only=True) +@pytest.mark.skip(reason="Test is timing out and needs investigation") # Skip this test for now @pytest.mark.parametrize("extension", ["txt", "py"]) def test_codebase_reset_correctness(extension: str, tmp_path): codebase, files = setup_codebase(NUM_FILES, extension, tmp_path)