Skip to content

Commit

Permalink
Allow certain native functions in initializers
Browse files Browse the repository at this point in the history
RELNOTES: The following functions on `native` can now be used in rule
initializers: `repo_name`, `package_name`, `package_relative_label`,
`module_name`, `module_version`.
  • Loading branch information
fmeum committed Mar 20, 2024
1 parent 75a2b65 commit 0670b11
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/main/java/com/google/devtools/build/lib/analysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/packages",
"//src/main/java/com/google/devtools/build/lib/packages:configured_attribute_mapper",
"//src/main/java/com/google/devtools/build/lib/packages:exec_group",
"//src/main/java/com/google/devtools/build/lib/packages:initializer_starlark_context",
"//src/main/java/com/google/devtools/build/lib/packages/semantics",
"//src/main/java/com/google/devtools/build/lib/profiler",
"//src/main/java/com/google/devtools/build/lib/profiler:google-auto-profiler-utils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import com.google.devtools.build.lib.packages.FunctionSplitTransitionAllowlist;
import com.google.devtools.build.lib.packages.ImplicitOutputsFunction.StarlarkImplicitOutputsFunctionWithCallback;
import com.google.devtools.build.lib.packages.ImplicitOutputsFunction.StarlarkImplicitOutputsFunctionWithMap;
import com.google.devtools.build.lib.packages.InitializerStarlarkContext;
import com.google.devtools.build.lib.packages.LabelConverter;
import com.google.devtools.build.lib.packages.MacroClass;
import com.google.devtools.build.lib.packages.MacroInstance;
Expand Down Expand Up @@ -1213,14 +1214,16 @@ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwarg

ImmutableSet<String> legacyAnyTypeAttrs = getLegacyAnyTypeAttrs(ruleClass);

// Remove {@link BazelStarlarkContext} to prevent calls to load and analysis time functions.
// Use a special BazelStarlarkContext to only allow certain functions on the native module.
// Mutating values in initializers is mostly not a problem, because the attribute values are
// copied before calling the initializers (<-TODO) and before they are set on the target.
// Exception is a legacy case allowing arbitrary type of parameter values. In that case the
// values may be mutated by the initializer, but they are still copied when set on the target.
BazelStarlarkContext bazelStarlarkContext = BazelStarlarkContext.fromOrFail(thread);
try {
thread.setThreadLocal(BazelStarlarkContext.class, null);
thread.setThreadLocal(
BazelStarlarkContext.class,
new InitializerStarlarkContext(pkgBuilder));
thread.setUncheckedExceptionContext(() -> "an initializer");

// We call all the initializers of the rule and its ancestor rules, proceeding from child to
Expand Down Expand Up @@ -1308,6 +1311,7 @@ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwarg
}
}
} finally {
thread.setThreadLocal(BazelStarlarkContext.class, null);
bazelStarlarkContext.storeInThread(thread);
}

Expand Down
10 changes: 10 additions & 0 deletions src/main/java/com/google/devtools/build/lib/packages/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ java_library(
"ExecGroup.java",
"ConfiguredAttributeMapper.java",
"LabelPrinter.java",
"InitializerStarlarkContext.java",
],
),
deps = [
Expand Down Expand Up @@ -146,3 +147,12 @@ java_library(
"//src/main/java/net/starlark/java/eval",
],
)

java_library(
name = "initializer_starlark_context",
srcs = ["InitializerStarlarkContext.java"],
deps = [
":packages",
"//src/main/java/com/google/devtools/build/lib/cmdline",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ public class BazelStarlarkContext implements StarlarkThread.UncheckedExceptionCo
public enum Phase {
WORKSPACE,
LOADING,
/** Evaluation for a rule initializer, which does not allow all loading phase operations. */
INITIALIZER,
ANALYSIS
}

Expand Down Expand Up @@ -167,6 +169,26 @@ public static void checkLoadingPhase(StarlarkThread thread, String function)
}
}

/**
* Checks that the current StarlarkThread is in the loading phase or in a rule initializer.
*
* @param function name of a function that requires this check
*/
public static void checkLoadingPhaseOrInitializer(StarlarkThread thread, String function)
throws EvalException {
BazelStarlarkContext ctx = thread.getThreadLocal(BazelStarlarkContext.class);
if (ctx == null) {
throw Starlark.errorf(
"'%s' cannot be called from %s", function, thread.getContextDescription());
}
if (ctx.phase != Phase.LOADING && ctx.phase != Phase.INITIALIZER) {
throw Starlark.errorf(
"'%s' can only be called from a BUILD file, a macro invoked from a BUILD file, or a"
+ " rule initializer",
function);
}
}

/**
* Checks that the current StarlarkThread is in the workspace phase.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.google.devtools.build.lib.packages;

import com.google.devtools.build.lib.cmdline.PackageIdentifier;
import com.google.devtools.build.lib.packages.LabelConverter;
import com.google.devtools.build.lib.packages.Package;
import com.google.devtools.build.lib.packages.SymbolGenerator;
import com.google.devtools.build.lib.packages.TargetDefinitionContext;

import java.util.Optional;

public final class InitializerStarlarkContext extends TargetDefinitionContext {
private final Package.Builder pkgBuilder;

public InitializerStarlarkContext(Package.Builder pkgBuilder) {
super(Phase.INITIALIZER, new SymbolGenerator<>(new Object()));
this.pkgBuilder = pkgBuilder;
}

@Override
PackageIdentifier getPackageIdentifier() {
return pkgBuilder.getPackageIdentifier();
}

@Override
Optional<String> getAssociatedModuleName() {
return pkgBuilder.getAssociatedModuleName();
}

@Override
Optional<String> getAssociatedModuleVersion() {
return pkgBuilder.getAssociatedModuleVersion();
}

@Override
public LabelConverter getLabelConverter() {
return pkgBuilder.getLabelConverter();
}
}
14 changes: 4 additions & 10 deletions src/main/java/com/google/devtools/build/lib/packages/Package.java
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ public static Builder fromOrFail(StarlarkThread thread, String what) throws Eval
return (Builder) ctx;
}

@Override
PackageIdentifier getPackageIdentifier() {
return pkg.getPackageIdentifier();
}
Expand All @@ -1146,20 +1147,12 @@ String getPackageWorkspaceName() {
return pkg.getWorkspaceName();
}

/**
* Returns the name of the Bzlmod module associated with the repo this package is in. If this
* package is not from a Bzlmod repo, this is empty. For repos generated by module extensions,
* this is the name of the module hosting the extension.
*/
@Override
Optional<String> getAssociatedModuleName() {
return pkg.metadata.associatedModuleName;
}

/**
* Returns the version of the Bzlmod module associated with the repo this package is in. If this
* package is not from a Bzlmod repo, this is empty. For repos generated by module extensions,
* this is the version of the module hosting the extension.
*/
@Override
Optional<String> getAssociatedModuleVersion() {
return pkg.metadata.associatedModuleVersion;
}
Expand Down Expand Up @@ -1201,6 +1194,7 @@ Builder addRepositoryMappings(Package aPackage) {
return this;
}

@Override
public LabelConverter getLabelConverter() {
return labelConverter;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ public NoneType exportsFiles(

@Override
public String packageName(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "native.package_name");
PackageIdentifier packageId = getContext(thread).getPackageIdentifier();
BazelStarlarkContext.checkLoadingPhaseOrInitializer(thread, "native.package_name");
PackageIdentifier packageId = getTargetDefinitionContext(thread).getPackageIdentifier();
return packageId.getPackageFragment().getPathString();
}

Expand All @@ -622,19 +622,19 @@ public String repositoryName(StarlarkThread thread) throws EvalException {

@Override
public String repoName(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "native.repo_name");
return getContext(thread).getPackageIdentifier().getRepository().getName();
BazelStarlarkContext.checkLoadingPhaseOrInitializer(thread, "native.repo_name");
return getTargetDefinitionContext(thread).getPackageIdentifier().getRepository().getName();
}

@Override
public Label packageRelativeLabel(Object input, StarlarkThread thread) throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "native.package_relative_label");
BazelStarlarkContext.checkLoadingPhaseOrInitializer(thread, "native.package_relative_label");
if (input instanceof Label) {
return (Label) input;
}
try {
String s = (String) input;
return getContext(thread).getLabelConverter().convert(s);
return getTargetDefinitionContext(thread).getLabelConverter().convert(s);
} catch (LabelSyntaxException e) {
throw Starlark.errorf("invalid label in native.package_relative_label: %s", e.getMessage());
}
Expand All @@ -643,15 +643,15 @@ public Label packageRelativeLabel(Object input, StarlarkThread thread) throws Ev
@Override
@Nullable
public String moduleName(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "native.module_name");
return getContext(thread).getAssociatedModuleName().orElse(null);
BazelStarlarkContext.checkLoadingPhaseOrInitializer(thread, "native.module_name");
return getTargetDefinitionContext(thread).getAssociatedModuleName().orElse(null);
}

@Override
@Nullable
public String moduleVersion(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "native.module_version");
return getContext(thread).getAssociatedModuleVersion().orElse(null);
BazelStarlarkContext.checkLoadingPhaseOrInitializer(thread, "native.module_version");
return getTargetDefinitionContext(thread).getAssociatedModuleVersion().orElse(null);
}

private static Dict<String, Object> getRuleDict(Rule rule, Mutability mu) throws EvalException {
Expand Down Expand Up @@ -902,4 +902,17 @@ private List<String> runGlobOperation(
}
}
}

private static TargetDefinitionContext getTargetDefinitionContext(StarlarkThread thread)
throws EvalException {
var value = TargetDefinitionContext.fromOrNull(thread);
if (value == null) {
// if TargetDefinitionContext is missing, we're not called from a BUILD file. This happens if
// someone uses native.some_func() in the wrong place.
throw Starlark.errorf(
"The native module can be accessed only from a BUILD thread. "
+ "Wrap the function in a macro and call it from a BUILD file");
}
return value;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@

package com.google.devtools.build.lib.packages;

import com.google.devtools.build.lib.cmdline.PackageIdentifier;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import javax.annotation.Nullable;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Starlark;
import net.starlark.java.eval.StarlarkThread;

import java.util.Optional;

/**
* A context object, usually stored in a {@link StarlarkThread}, upon which rules and symbolic
* macros can be instantiated.
Expand Down Expand Up @@ -65,4 +68,22 @@ public static TargetDefinitionContext fromOrFail(StarlarkThread thread, String w
}
return (TargetDefinitionContext) ctx;
}

public abstract LabelConverter getLabelConverter();

abstract PackageIdentifier getPackageIdentifier();

/**
* Returns the name of the Bzlmod module associated with the repo this target is in. If this
* package is not from a Bzlmod repo, this is empty. For repos generated by module extensions,
* this is the name of the module hosting the extension.
*/
abstract Optional<String> getAssociatedModuleName();

/**
* Returns the version of the Bzlmod module associated with the repo this target is in. If this
* package is not from a Bzlmod repo, this is empty. For repos generated by module extensions,
* this is the version of the module hosting the extension.
*/
abstract Optional<String> getAssociatedModuleVersion();
}
Original file line number Diff line number Diff line change
Expand Up @@ -3592,6 +3592,53 @@ public void initializer_withFails() throws Exception {
ev.assertContainsError("target 'my_target' not declared in package 'initializer_testing'");
}

@Test
@SuppressWarnings("unchecked")
public void initializer_nativeModule() throws Exception {
scratch.appendFile("MODULE.bazel", "module(name = 'my_mod', version = '1.2.3')");
scratch.file(
"initializer_testing/b.bzl",
"MyInfo = provider()",
"def initializer(name, **kwargs):",
" return {'props': {",
" 'module': native.module_name() + '@' + native.module_version(),",
" 'repo_name': native.repo_name(),",
" 'package_name': native.package_name(),",
" 'package_relative_label': str(native.package_relative_label(':target')),",
" }}",
"def impl(ctx): ",
" return [MyInfo(props = ctx.attr.props)]",
"my_rule = rule(impl,",
" initializer = initializer,",
" attrs = {",
" 'props': attr.string_dict(),",
" })");
scratch.file(
"initializer_testing/BUILD", //
"load(':b.bzl','my_rule')",
"my_rule(name = 'my_target')");

invalidatePackages();

ConfiguredTarget myTarget = getConfiguredTarget("//initializer_testing:my_target");
StructImpl info =
(StructImpl)
myTarget.get(
new StarlarkProvider.Key(
Label.parseCanonical("//initializer_testing:b.bzl"), "MyInfo"));

assertThat((Map<String, String>) info.getValue("props"))
.containsExactly(
"module",
"my_mod@1.2.3",
"repo_name",
"",
"package_name",
"initializer_testing",
"package_relative_label",
"@@//initializer_testing:target");
}

private void scratchParentRule(String rule, String... ruleArgs) throws IOException {
scratch.file("extend_rule_testing/parent/BUILD");
scratch.file(
Expand Down

0 comments on commit 0670b11

Please sign in to comment.