Skip to content

Commit

Permalink
[isort]: support submodules in known_(first|third)_party config opt…
Browse files Browse the repository at this point in the history
…ions (#3768)
  • Loading branch information
astaric committed Mar 29, 2023
1 parent 5501fc9 commit b6f1fed
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 90 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import sys
import baz
from foo import bar, baz
from foo.bar import blah, blub
from foo.bar.baz import something
import foo
import foo.bar
import foo.bar.baz
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,15 @@ pub fn typing_only_runtime_import(
// Extract the module base and level from the full name.
// Ex) `foo.bar.baz` -> `foo`, `0`
// Ex) `.foo.bar.baz` -> `foo`, `1`
let module_base = full_name.split('.').next().unwrap();
let level = full_name.chars().take_while(|c| *c == '.').count();

// Categorize the import.
match categorize(
module_base,
full_name,
Some(&level),
&settings.src,
package,
&settings.isort.known_first_party,
&settings.isort.known_third_party,
&settings.isort.known_local_folder,
&settings.isort.extra_standard_library,
&settings.isort.known_modules,
settings.target_version,
) {
ImportType::LocalFolder | ImportType::FirstParty => Some(Diagnostic::new(
Expand Down
141 changes: 102 additions & 39 deletions crates/ruff/src/rules/isort/categorize.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{BTreeMap, BTreeSet};
use std::fs;
use std::path::{Path, PathBuf};
use std::{fs, iter};

use log::debug;
use schemars::JsonSchema;
Expand Down Expand Up @@ -52,29 +52,21 @@ enum Reason<'a> {

#[allow(clippy::too_many_arguments)]
pub fn categorize(
module_base: &str,
module_name: &str,
level: Option<&usize>,
src: &[PathBuf],
package: Option<&Path>,
known_first_party: &BTreeSet<String>,
known_third_party: &BTreeSet<String>,
known_local_folder: &BTreeSet<String>,
extra_standard_library: &BTreeSet<String>,
known_modules: &KnownModules,
target_version: PythonVersion,
) -> ImportType {
let module_base = module_name.split('.').next().unwrap();
let (import_type, reason) = {
if level.map_or(false, |level| *level > 0) {
(ImportType::LocalFolder, Reason::NonZeroLevel)
} else if known_first_party.contains(module_base) {
(ImportType::FirstParty, Reason::KnownFirstParty)
} else if known_third_party.contains(module_base) {
(ImportType::ThirdParty, Reason::KnownThirdParty)
} else if known_local_folder.contains(module_base) {
(ImportType::LocalFolder, Reason::KnownLocalFolder)
} else if extra_standard_library.contains(module_base) {
(ImportType::StandardLibrary, Reason::ExtraStandardLibrary)
} else if module_base == "__future__" {
(ImportType::Future, Reason::Future)
} else if let Some((import_type, reason)) = known_modules.categorize(module_name) {
(import_type, reason)
} else if KNOWN_STANDARD_LIBRARY
.get(&target_version.as_tuple())
.unwrap()
Expand All @@ -91,7 +83,7 @@ pub fn categorize(
};
debug!(
"Categorized '{}' as {:?} ({:?})",
module_base, import_type, reason
module_name, import_type, reason
);
import_type
}
Expand Down Expand Up @@ -121,24 +113,18 @@ pub fn categorize_imports<'a>(
block: ImportBlock<'a>,
src: &[PathBuf],
package: Option<&Path>,
known_first_party: &BTreeSet<String>,
known_third_party: &BTreeSet<String>,
known_local_folder: &BTreeSet<String>,
extra_standard_library: &BTreeSet<String>,
known_modules: &KnownModules,
target_version: PythonVersion,
) -> BTreeMap<ImportType, ImportBlock<'a>> {
let mut block_by_type: BTreeMap<ImportType, ImportBlock> = BTreeMap::default();
// Categorize `StmtKind::Import`.
for (alias, comments) in block.import {
let import_type = categorize(
&alias.module_base(),
&alias.module_name(),
None,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -150,14 +136,11 @@ pub fn categorize_imports<'a>(
// Categorize `StmtKind::ImportFrom` (without re-export).
for (import_from, aliases) in block.import_from {
let classification = categorize(
&import_from.module_base(),
&import_from.module_name(),
import_from.level,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -169,14 +152,11 @@ pub fn categorize_imports<'a>(
// Categorize `StmtKind::ImportFrom` (with re-export).
for ((import_from, alias), aliases) in block.import_from_as {
let classification = categorize(
&import_from.module_base(),
&import_from.module_name(),
import_from.level,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -188,14 +168,11 @@ pub fn categorize_imports<'a>(
// Categorize `StmtKind::ImportFrom` (with star).
for (import_from, comments) in block.import_from_star {
let classification = categorize(
&import_from.module_base(),
&import_from.module_name(),
import_from.level,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -206,3 +183,89 @@ pub fn categorize_imports<'a>(
}
block_by_type
}

/// User-provided module lists, one set per import category, consulted when categorizing an
/// import by module name.
///
/// Prefer building this via [`KnownModules::new`], which also computes the private
/// `has_submodules` flag from the contents of the four sets. A `Default` instance has four
/// empty sets and `has_submodules == false`.
// NOTE(review): the four sets are `pub` while `has_submodules` is private and only computed in
// `KnownModules::new`; inserting a dotted name into a set after construction would leave the
// flag stale — confirm no caller mutates the sets directly.
#[derive(Debug, Default, CacheKey)]
pub struct KnownModules {
/// A set of user-provided first-party modules.
pub first_party: BTreeSet<String>,
/// A set of user-provided third-party modules.
pub third_party: BTreeSet<String>,
/// A set of user-provided local folder modules.
pub local_folder: BTreeSet<String>,
/// A set of user-provided standard library modules.
pub standard_library: BTreeSet<String>,
/// Whether any of the known modules are submodules (e.g., `foo.bar`, as opposed to `foo`).
// When `false`, lookups only need to test the first dotted segment of a module name; when
// `true`, every dotted prefix of the name has to be tested.
has_submodules: bool,
}

impl KnownModules {
    /// Build a [`KnownModules`] from the user-provided module lists.
    ///
    /// Each list is collected into a `BTreeSet` (so duplicates collapse), and the
    /// `has_submodules` flag is set when at least one entry in any list contains a `.`.
    pub fn new(
        first_party: Vec<String>,
        third_party: Vec<String>,
        local_folder: Vec<String>,
        standard_library: Vec<String>,
    ) -> Self {
        let first_party: BTreeSet<String> = first_party.into_iter().collect();
        let third_party: BTreeSet<String> = third_party.into_iter().collect();
        let local_folder: BTreeSet<String> = local_folder.into_iter().collect();
        let standard_library: BTreeSet<String> = standard_library.into_iter().collect();

        // A dotted entry anywhere forces the slower prefix-by-prefix lookup in `categorize`.
        let any_dotted = |names: &BTreeSet<String>| names.iter().any(|name| name.contains('.'));
        let has_submodules = any_dotted(&first_party)
            || any_dotted(&third_party)
            || any_dotted(&local_folder)
            || any_dotted(&standard_library);

        Self {
            first_party,
            third_party,
            local_folder,
            standard_library,
            has_submodules,
        }
    }

    /// Return the [`ImportType`] (and the reason for it) for `module_name`, if the user listed
    /// the module, or one of its dotted prefixes, as a known module. Returns `None` otherwise.
    fn categorize(&self, module_name: &str) -> Option<(ImportType, Reason)> {
        if !self.has_submodules {
            // Fast path: no dotted entries were configured, so only the first segment of the
            // module name can possibly match.
            let module_base = module_name
                .split_once('.')
                .map_or(module_name, |(base, _)| base);
            return self.lookup_exact(module_base);
        }

        // Candidate prefix lengths: the byte offset of every `.`, plus the full length. Walking
        // them in reverse tries `foo.bar.baz`, then `foo.bar`, then `foo`, so the most precise
        // configured entry wins.
        module_name
            .match_indices('.')
            .map(|(dot, _)| dot)
            .chain(iter::once(module_name.len()))
            .rev()
            .find_map(|end| self.lookup_exact(&module_name[..end]))
    }

    /// Look up `name` verbatim in each set, in precedence order: first-party, third-party,
    /// local-folder, then standard-library.
    fn lookup_exact(&self, name: &str) -> Option<(ImportType, Reason)> {
        if self.first_party.contains(name) {
            return Some((ImportType::FirstParty, Reason::KnownFirstParty));
        }
        if self.third_party.contains(name) {
            return Some((ImportType::ThirdParty, Reason::KnownThirdParty));
        }
        if self.local_folder.contains(name) {
            return Some((ImportType::LocalFolder, Reason::KnownLocalFolder));
        }
        if self.standard_library.contains(name) {
            return Some((ImportType::StandardLibrary, Reason::ExtraStandardLibrary));
        }
        None
    }
}

0 comments on commit b6f1fed

Please sign in to comment.