diff --git a/rewrite-python/src/main/java/org/openrewrite/python/RequirementsTxtParser.java b/rewrite-python/src/main/java/org/openrewrite/python/RequirementsTxtParser.java index 91485fa7c02..876d65f91b5 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/RequirementsTxtParser.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/RequirementsTxtParser.java @@ -19,6 +19,7 @@ import org.openrewrite.ExecutionContext; import org.openrewrite.Parser; import org.openrewrite.SourceFile; +import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.marker.PythonResolutionResult; import org.openrewrite.python.marker.PythonResolutionResult.Dependency; import org.openrewrite.python.marker.PythonResolutionResult.PackageManager; @@ -85,7 +86,8 @@ public Stream parseInputs(Iterable sources, @Nullable Path re return sf; } - List deps = dependenciesFromResolved(resolvedDeps); + List deps = dependenciesFromResolved(resolvedDeps, + parseDeclaredPackageNames(text.getText())); PythonResolutionResult marker = new PythonResolutionResult( randomId(), @@ -142,7 +144,11 @@ static List parseFreezeLines(String freezeContent) { * are treated as direct so that client code traversing {@code getDependencies()} finds every package. */ public static List dependenciesFromResolved(List resolved) { - // Collect all packages that appear as a transitive dependency of another package + return dependenciesFromResolved(resolved, Collections.emptySet()); + } + + public static List dependenciesFromResolved(List resolved, + Set declaredPackageNames) { Set transitive = new HashSet<>(); for (ResolvedDependency r : resolved) { if (r.getDependencies() != null) { @@ -154,13 +160,31 @@ public static List dependenciesFromResolved(List List deps = new ArrayList<>(); for (ResolvedDependency r : resolved) { - if (transitive.isEmpty() || !transitive.contains(PythonResolutionResult.normalizeName(r.getName()))) { + String normalizedName = PythonResolutionResult.normalizeName(r.getName()); + if (transitive.isEmpty() || + !transitive.contains(normalizedName) || + declaredPackageNames.contains(normalizedName)) { deps.add(new Dependency(r.getName(), "==" + r.getVersion(), null, null, r)); } } return deps; } + static Set parseDeclaredPackageNames(String requirementsTxtContent) { + Set names = new HashSet<>(); + for (String line : requirementsTxtContent.split("\n")) { + String trimmed = line.trim(); + if (trimmed.isEmpty() || trimmed.startsWith("#") || trimmed.startsWith("-")) { + continue; + } + String name = PyProjectHelper.extractPackageName(trimmed); + if (name != null) { + names.add(PythonResolutionResult.normalizeName(name)); + } + } + return names; + } + /** * Link transitive dependencies by reading installed package METADATA files from site-packages. * Uses a two-pass approach: first builds a name→entry map, then reads each package's diff --git a/rewrite-python/src/main/java/org/openrewrite/python/internal/PyProjectHelper.java b/rewrite-python/src/main/java/org/openrewrite/python/internal/PyProjectHelper.java index 733b502636d..5d5268f0724 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/internal/PyProjectHelper.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/internal/PyProjectHelper.java @@ -65,7 +65,7 @@ public static String normalizeVersionConstraint(String version) { int end = 0; while (end < trimmed.length()) { char c = trimmed.charAt(end); - if (c == '[' || c == '>' || c == '<' || c == '=' || c == '!' || c == '~' || c == ';' || c == ' ') { + if (c == '[' || c == '>' || c == '<' || c == '=' || c == '!' || c == '~' || c == ';' || c == ' ' || c == '@') { break; } end++; diff --git a/rewrite-python/src/test/java/org/openrewrite/python/RequirementsTxtParserTest.java b/rewrite-python/src/test/java/org/openrewrite/python/RequirementsTxtParserTest.java index f10249e8bd9..76b5fac4d60 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/RequirementsTxtParserTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/RequirementsTxtParserTest.java @@ -33,6 +33,7 @@ import java.nio.file.Paths; import java.util.Collections; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -107,6 +108,57 @@ void dependenciesFromResolvedExcludesTransitives() { assertThat(deps.get(0).getResolved()).isSameAs(requests); } + @Test + void dependenciesFromResolvedTreatsDeclaredPackagesAsDirect() { + ResolvedDependency certifi = new ResolvedDependency("certifi", "2024.2.2", null, null); + ResolvedDependency requests = new ResolvedDependency("requests", "2.31.0", null, List.of(certifi)); + + List resolved = List.of(certifi, requests); + Set declared = RequirementsTxtParser.parseDeclaredPackageNames( + "certifi==2024.2.2\nrequests==2.31.0\n"); + List deps = RequirementsTxtParser.dependenciesFromResolved(resolved, declared); + + assertThat(deps).hasSize(2); + assertThat(deps.get(0).getName()).isEqualTo("certifi"); + assertThat(deps.get(1).getName()).isEqualTo("requests"); + } + + @Test + void declaredPackagesAreDirectAndUndeclaredTransitivesAreExcluded() { + ResolvedDependency urllib3 = new ResolvedDependency("urllib3", "2.2.1", null, null); + ResolvedDependency charsetNormalizer = new ResolvedDependency("charset-normalizer", "3.3.2", null, null); + ResolvedDependency certifi = new ResolvedDependency("certifi", "2024.2.2", null, null); + ResolvedDependency requests = new ResolvedDependency("requests", "2.31.0", null, + List.of(certifi, urllib3, charsetNormalizer)); + + List resolved = List.of(urllib3, charsetNormalizer, certifi, requests); + Set declared = RequirementsTxtParser.parseDeclaredPackageNames( + "requests==2.31.0\ncertifi==2024.2.2\n"); + List deps = RequirementsTxtParser.dependenciesFromResolved(resolved, declared); + + assertThat(deps).hasSize(2); + assertThat(deps.stream().map(Dependency::getName)) + .containsExactly("certifi", "requests"); + } + + @Test + void parseDeclaredPackageNamesExtractsNames() { + Set names = RequirementsTxtParser.parseDeclaredPackageNames(""" + # This is a comment + requests>=2.28.0 + certifi==2024.2.2 + charset-normalizer<4,>=2 + Jinja2~=3.1.5 + -r other-requirements.txt + aiohttp==3.13.3 + + langchain-core==1.2.12 + """); + assertThat(names).containsExactlyInAnyOrder( + "requests", "certifi", "charset_normalizer", "jinja2", + "aiohttp", "langchain_core"); + } + @Test void linkDependenciesFromMetadataBuildsGraph(@TempDir Path tempDir) throws IOException { // Create a fake site-packages with METADATA files