diff --git a/Cargo.lock b/Cargo.lock index bb408ae3..7d06db8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,12 +17,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "anyhow" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" - [[package]] name = "arc-swap" version = "1.7.1" @@ -204,9 +198,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -522,7 +516,7 @@ checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "ruff_python_ast" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff?tag=v0.4.2#77c93fd63c1c072501d297082aa59c741b2d5466" +source = "git+https://github.com/astral-sh/ruff?tag=v0.4.9#4f49e918a9154de16145d77217a4af2b8ce38297" dependencies = [ "aho-corasick", "bitflags", @@ -538,15 +532,13 @@ dependencies = [ [[package]] name = "ruff_python_parser" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff?tag=v0.4.2#77c93fd63c1c072501d297082aa59c741b2d5466" +source = "git+https://github.com/astral-sh/ruff?tag=v0.4.9#4f49e918a9154de16145d77217a4af2b8ce38297" dependencies = [ - "anyhow", "bitflags", "bstr", - "is-macro", - "itertools", "memchr", "ruff_python_ast", + "ruff_python_trivia", "ruff_text_size", "rustc-hash", "static_assertions", @@ -558,7 +550,7 @@ dependencies = [ [[package]] name = "ruff_python_trivia" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff?tag=v0.4.2#77c93fd63c1c072501d297082aa59c741b2d5466" +source = "git+https://github.com/astral-sh/ruff?tag=v0.4.9#4f49e918a9154de16145d77217a4af2b8ce38297" dependencies = [ "itertools", "ruff_source_file", @@ -569,7 +561,7 @@ dependencies = [ [[package]] name = "ruff_source_file" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff?tag=v0.4.2#77c93fd63c1c072501d297082aa59c741b2d5466" +source = "git+https://github.com/astral-sh/ruff?tag=v0.4.9#4f49e918a9154de16145d77217a4af2b8ce38297" dependencies = [ "memchr", "once_cell", @@ -579,7 +571,7 @@ dependencies = [ [[package]] name = "ruff_text_size" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff?tag=v0.4.2#77c93fd63c1c072501d297082aa59c741b2d5466" +source = "git+https://github.com/astral-sh/ruff?tag=v0.4.9#4f49e918a9154de16145d77217a4af2b8ce38297" [[package]] name = "rustc-hash" diff --git a/Cargo.toml b/Cargo.toml index 94400ed4..3f32db63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,10 +18,10 @@ pyo3 = { version = "0.21.2", features = ["abi3-py38"] } pyo3-log = "0.10.0" rayon = "1.10.0" regex = "1.10.5" -ruff_python_ast = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.2" } -ruff_python_parser = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.2" } -ruff_source_file = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.2" } -ruff_text_size = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.2" } +ruff_python_ast = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.9" } +ruff_python_parser = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.9" } +ruff_source_file = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.9" } +ruff_text_size = { git = "https://github.com/astral-sh/ruff", tag = "v0.4.9" } serde_json = "1.0.117" [profile.release] diff --git a/src/imports/ipynb.rs b/src/imports/ipynb.rs index 5bed38af..cec16400 100644 --- a/src/imports/ipynb.rs +++ b/src/imports/ipynb.rs @@ -48,8 +48,8 @@ fn _get_imports_from_ipynb_file(path_str: &str) -> PyResult) -> PyRe /// Used internally by both parallel and single file processing functions. fn _get_imports_from_py_file(path_str: &str) -> PyResult>> { let file_content = read_file(path_str)?; - let ast = shared::get_ast_from_file_content(&file_content)?; - let imported_modules = shared::extract_imports_from_ast(ast); + let ast = shared::parse_file_content(&file_content)?; + let imported_modules = shared::extract_imports_from_parsed_file_content(ast); Ok(shared::convert_imports_with_textranges_to_location_objects( imported_modules, path_str, diff --git a/src/imports/shared.rs b/src/imports/shared.rs index cb6c6433..c0a01bb4 100644 --- a/src/imports/shared.rs +++ b/src/imports/shared.rs @@ -6,8 +6,8 @@ use pyo3::exceptions::PySyntaxError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use ruff_python_ast::visitor::Visitor; -use ruff_python_ast::Mod; -use ruff_python_parser::{parse, Mode}; +use ruff_python_ast::{Mod, ModModule}; +use ruff_python_parser::{parse, Mode, Parsed}; use ruff_source_file::LineIndex; use ruff_text_size::TextRange; use std::collections::HashMap; @@ -21,20 +21,22 @@ pub struct ThreadResult { pub result: PyResult, } -/// Parses the content of a Python file into an abstract syntax tree (AST). -pub fn get_ast_from_file_content(file_content: &str) -> PyResult { - let ast = +/// Parses the content of a Python file into a parsed source code. +pub fn parse_file_content(file_content: &str) -> PyResult> { + let parsed = parse(file_content, Mode::Module).map_err(|e| PySyntaxError::new_err(e.to_string()))?; - Ok(ast) + Ok(parsed) } -/// Iterates through an AST to identify and collect import statements, and returns them together with their -/// respective `TextRange` for each occurrence. -pub fn extract_imports_from_ast(ast: Mod) -> HashMap> { +/// Iterates through a parsed source code to identify and collect import statements, and returns them +/// together with their respective `TextRange` for each occurrence. +pub fn extract_imports_from_parsed_file_content( + parsed: Parsed, +) -> HashMap> { let mut visitor = ImportVisitor::new(); - if let Mod::Module(module) = ast { - for stmt in module.body { + if let Mod::Module(ModModule { body, .. }) = parsed.into_syntax() { + for stmt in body { visitor.visit_stmt(&stmt); } } diff --git a/src/visitor.rs b/src/visitor.rs index ac9a4cdb..e6ed0e97 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -34,7 +34,7 @@ impl<'a> Visitor<'a> for ImportVisitor { } Stmt::ImportFrom(import_from_stmt) => { if let Some(module) = &import_from_stmt.module { - if import_from_stmt.level == Some(0) { + if import_from_stmt.level == 0 { self.imports .entry(get_top_level_module_name(module.as_str())) .or_default()