Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
public class EntitlementInitialization {

private static final String POLICY_FILE_NAME = "entitlement-policy.yaml";
private static final Module ENTITLEMENTS_MODULE = PolicyManager.class.getModule();

private static ElasticsearchEntitlementChecker manager;

Expand Down Expand Up @@ -92,7 +93,7 @@ private static PolicyManager createPolicyManager() throws IOException {
"server",
List.of(new Scope("org.elasticsearch.server", List.of(new ExitVMEntitlement(), new CreateClassLoaderEntitlement())))
);
return new PolicyManager(serverPolicy, pluginPolicies, EntitlementBootstrap.bootstrapArgs().pluginResolver());
return new PolicyManager(serverPolicy, pluginPolicies, EntitlementBootstrap.bootstrapArgs().pluginResolver(), ENTITLEMENTS_MODULE);
}

private static Map<String, Policy> createPluginPolicies(Collection<EntitlementBootstrap.PluginData> pluginData) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;

import java.lang.StackWalker.StackFrame;
import java.lang.module.ModuleFinder;
import java.lang.module.ModuleReference;
import java.util.ArrayList;
Expand All @@ -29,6 +30,10 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.lang.StackWalker.Option.RETAIN_CLASS_REFERENCE;
import static java.util.Objects.requireNonNull;
import static java.util.function.Predicate.not;

public class PolicyManager {
private static final Logger logger = LogManager.getLogger(ElasticsearchEntitlementChecker.class);

Expand Down Expand Up @@ -63,6 +68,11 @@ public <E extends Entitlement> Stream<E> getEntitlements(Class<E> entitlementCla

private static final Set<Module> systemModules = findSystemModules();

/**
* Frames originating from this module are ignored in the permission logic.
*/
private final Module entitlementsModule;

private static Set<Module> findSystemModules() {
var systemModulesDescriptors = ModuleFinder.ofSystem()
.findAll()
Expand All @@ -77,13 +87,18 @@ private static Set<Module> findSystemModules() {
.collect(Collectors.toUnmodifiableSet());
}

public PolicyManager(Policy defaultPolicy, Map<String, Policy> pluginPolicies, Function<Class<?>, String> pluginResolver) {
this.serverEntitlements = buildScopeEntitlementsMap(Objects.requireNonNull(defaultPolicy));
this.pluginsEntitlements = Objects.requireNonNull(pluginPolicies)
.entrySet()
public PolicyManager(
Policy defaultPolicy,
Map<String, Policy> pluginPolicies,
Function<Class<?>, String> pluginResolver,
Module entitlementsModule
) {
this.serverEntitlements = buildScopeEntitlementsMap(requireNonNull(defaultPolicy));
this.pluginsEntitlements = requireNonNull(pluginPolicies).entrySet()
.stream()
.collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, e -> buildScopeEntitlementsMap(e.getValue())));
this.pluginResolver = pluginResolver;
this.entitlementsModule = entitlementsModule;
}

private static Map<String, List<Entitlement>> buildScopeEntitlementsMap(Policy policy) {
Expand Down Expand Up @@ -185,29 +200,51 @@ private static boolean isServerModule(Module requestingModule) {
return requestingModule.isNamed() && requestingModule.getLayer() == ModuleLayer.boot();
}

private static Module requestingModule(Class<?> callerClass) {
/**
* Walks the stack to determine which module's entitlements should be checked.
*
* @param callerClass when non-null will be used if its module is suitable;
* this is a fast-path check that can avoid the stack walk
* in cases where the caller class is available.
* @return the requesting module, or {@code null} if the entire call stack
* comes from modules that are trusted.
*/
Module requestingModule(Class<?> callerClass) {
if (callerClass != null) {
Module callerModule = callerClass.getModule();
if (systemModules.contains(callerModule) == false) {
// fast path
return callerModule;
}
}
int framesToSkip = 1 // getCallingClass (this method)
+ 1 // the checkXxx method
+ 1 // the runtime config method
+ 1 // the instrumented method
;
Optional<Module> module = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE)
.walk(
s -> s.skip(framesToSkip)
.map(f -> f.getDeclaringClass().getModule())
.filter(m -> systemModules.contains(m) == false)
.findFirst()
);
Optional<Module> module = StackWalker.getInstance(RETAIN_CLASS_REFERENCE)
.walk(frames -> findRequestingModule(frames.map(StackFrame::getDeclaringClass)));
return module.orElse(null);
}

/**
* Given a stream of classes corresponding to the frames from a {@link StackWalker},
* returns the module whose entitlements should be checked.
*
* @throws NullPointerException if the requesting module is {@code null}
*/
Optional<Module> findRequestingModule(Stream<Class<?>> classes) {
return classes.map(Objects::requireNonNull)
.map(PolicyManager::moduleOf)
.filter(m -> m != entitlementsModule) // Ignore the entitlements library itself
.filter(not(systemModules::contains)) // Skip trusted JDK modules
.findFirst();
}

private static Module moduleOf(Class<?> c) {
var result = c.getModule();
if (result == null) {
throw new NullPointerException("Entitlements system does not support non-modular class [" + c.getName() + "]");
} else {
return result;
}
}

private static boolean isTriviallyAllowed(Module requestingModule) {
if (requestingModule == null) {
logger.debug("Entitlement trivially allowed: entire call stack is in composed of classes in system modules");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;

import static java.util.Map.entry;
import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.ALL_UNNAMED;
Expand All @@ -37,11 +38,14 @@
@ESTestCase.WithoutSecurityManager
public class PolicyManagerTests extends ESTestCase {

private static final Module NO_ENTITLEMENTS_MODULE = null;

public void testGetEntitlementsThrowsOnMissingPluginUnnamedModule() {
var policyManager = new PolicyManager(
createEmptyTestServerPolicy(),
Map.of("plugin1", createPluginPolicy("plugin.module")),
c -> "plugin1"
c -> "plugin1",
NO_ENTITLEMENTS_MODULE
);

// Any class from the current module (unnamed) will do
Expand All @@ -62,7 +66,7 @@ public void testGetEntitlementsThrowsOnMissingPluginUnnamedModule() {
}

public void testGetEntitlementsThrowsOnMissingPolicyForPlugin() {
var policyManager = new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "plugin1");
var policyManager = new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "plugin1", NO_ENTITLEMENTS_MODULE);

// Any class from the current module (unnamed) will do
var callerClass = this.getClass();
Expand All @@ -82,7 +86,7 @@ public void testGetEntitlementsThrowsOnMissingPolicyForPlugin() {
}

public void testGetEntitlementsFailureIsCached() {
var policyManager = new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "plugin1");
var policyManager = new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "plugin1", NO_ENTITLEMENTS_MODULE);

// Any class from the current module (unnamed) will do
var callerClass = this.getClass();
Expand All @@ -103,7 +107,8 @@ public void testGetEntitlementsReturnsEntitlementsForPluginUnnamedModule() {
var policyManager = new PolicyManager(
createEmptyTestServerPolicy(),
Map.ofEntries(entry("plugin2", createPluginPolicy(ALL_UNNAMED))),
c -> "plugin2"
c -> "plugin2",
NO_ENTITLEMENTS_MODULE
);

// Any class from the current module (unnamed) will do
Expand All @@ -115,7 +120,7 @@ public void testGetEntitlementsReturnsEntitlementsForPluginUnnamedModule() {
}

public void testGetEntitlementsThrowsOnMissingPolicyForServer() throws ClassNotFoundException {
var policyManager = new PolicyManager(createTestServerPolicy("example"), Map.of(), c -> null);
var policyManager = new PolicyManager(createTestServerPolicy("example"), Map.of(), c -> null, NO_ENTITLEMENTS_MODULE);

// Tests do not run modular, so we cannot use a server class.
// But we know that in production code the server module and its classes are in the boot layer.
Expand All @@ -138,7 +143,7 @@ public void testGetEntitlementsThrowsOnMissingPolicyForServer() throws ClassNotF
}

public void testGetEntitlementsReturnsEntitlementsForServerModule() throws ClassNotFoundException {
var policyManager = new PolicyManager(createTestServerPolicy("jdk.httpserver"), Map.of(), c -> null);
var policyManager = new PolicyManager(createTestServerPolicy("jdk.httpserver"), Map.of(), c -> null, NO_ENTITLEMENTS_MODULE);

// Tests do not run modular, so we cannot use a server class.
// But we know that in production code the server module and its classes are in the boot layer.
Expand All @@ -155,12 +160,13 @@ public void testGetEntitlementsReturnsEntitlementsForServerModule() throws Class
public void testGetEntitlementsReturnsEntitlementsForPluginModule() throws IOException, ClassNotFoundException {
final Path home = createTempDir();

Path jar = creteMockPluginJar(home);
Path jar = createMockPluginJar(home);

var policyManager = new PolicyManager(
createEmptyTestServerPolicy(),
Map.of("mock-plugin", createPluginPolicy("org.example.plugin")),
c -> "mock-plugin"
c -> "mock-plugin",
NO_ENTITLEMENTS_MODULE
);

var layer = createLayerForJar(jar, "org.example.plugin");
Expand All @@ -179,7 +185,8 @@ public void testGetEntitlementsResultIsCached() {
var policyManager = new PolicyManager(
createEmptyTestServerPolicy(),
Map.ofEntries(entry("plugin2", createPluginPolicy(ALL_UNNAMED))),
c -> "plugin2"
c -> "plugin2",
NO_ENTITLEMENTS_MODULE
);

// Any class from the current module (unnamed) will do
Expand All @@ -197,6 +204,73 @@ public void testGetEntitlementsResultIsCached() {
assertThat(entitlementsAgain, sameInstance(cachedResult));
}

public void testRequestingModuleFastPath() throws IOException, ClassNotFoundException {
var callerClass = makeClassInItsOwnModule();
assertEquals(callerClass.getModule(), policyManagerWithEntitlementsModule(NO_ENTITLEMENTS_MODULE).requestingModule(callerClass));
}

public void testRequestingModuleWithStackWalk() throws IOException, ClassNotFoundException {
var requestingClass = makeClassInItsOwnModule();
var runtimeClass = makeClassInItsOwnModule(); // A class in the entitlements library itself
var ignorableClass = makeClassInItsOwnModule();
var systemClass = Object.class;

var policyManager = policyManagerWithEntitlementsModule(runtimeClass.getModule());

var requestingModule = requestingClass.getModule();

assertEquals(
"Skip one system frame",
requestingModule,
policyManager.findRequestingModule(Stream.of(systemClass, requestingClass, ignorableClass)).orElse(null)
);
assertEquals(
"Skip multiple system frames",
requestingModule,
policyManager.findRequestingModule(Stream.of(systemClass, systemClass, systemClass, requestingClass, ignorableClass))
.orElse(null)
);
assertEquals(
"Skip system frame between runtime frames",
requestingModule,
policyManager.findRequestingModule(Stream.of(runtimeClass, systemClass, runtimeClass, requestingClass, ignorableClass))
.orElse(null)
);
assertEquals(
"Skip runtime frame between system frames",
requestingModule,
policyManager.findRequestingModule(Stream.of(systemClass, runtimeClass, systemClass, requestingClass, ignorableClass))
.orElse(null)
);
assertEquals(
"No system frames",
requestingModule,
policyManager.findRequestingModule(Stream.of(requestingClass, ignorableClass)).orElse(null)
);
assertEquals(
"Skip runtime frames up to the first system frame",
requestingModule,
policyManager.findRequestingModule(Stream.of(runtimeClass, runtimeClass, systemClass, requestingClass, ignorableClass))
.orElse(null)
);
assertThrows(
"Non-modular caller frames are not supported",
NullPointerException.class,
() -> policyManager.findRequestingModule(Stream.of(systemClass, null))
);
}

private static Class<?> makeClassInItsOwnModule() throws IOException, ClassNotFoundException {
final Path home = createTempDir();
Path jar = createMockPluginJar(home);
var layer = createLayerForJar(jar, "org.example.plugin");
return layer.findLoader("org.example.plugin").loadClass("q.B");
}

private static PolicyManager policyManagerWithEntitlementsModule(Module entitlementsModule) {
return new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "test", entitlementsModule);
}

private static Policy createEmptyTestServerPolicy() {
return new Policy("server", List.of());
}
Expand All @@ -219,7 +293,7 @@ private static Policy createPluginPolicy(String... pluginModules) {
);
}

private static Path creteMockPluginJar(Path home) throws IOException {
private static Path createMockPluginJar(Path home) throws IOException {
Path jar = home.resolve("mock-plugin.jar");

Map<String, CharSequence> sources = Map.ofEntries(
Expand Down