diff --git a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/JUnit5Context.java b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/JUnit5Context.java index 069d0dc780b..bd6d44942dc 100644 --- a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/JUnit5Context.java +++ b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/JUnit5Context.java @@ -10,56 +10,103 @@ package org.junit.gen5.engine.junit5; -import java.util.HashMap; -import java.util.Map; - import org.junit.gen5.engine.Context; import org.junit.gen5.engine.junit5.execution.TestExtensionRegistry; public class JUnit5Context implements Context { - private final Map map; + private final State state; public JUnit5Context() { - this(new HashMap<>()); + this(new State()); } - private JUnit5Context(Map map) { - this.map = map; + private JUnit5Context(State state) { + this.state = state; } - public JUnit5Context withTestInstanceProvider(TestInstanceProvider testInstanceProvider) { - return with(TestInstanceProvider.class.getName(), testInstanceProvider); + public TestInstanceProvider getTestInstanceProvider() { + return state.testInstanceProvider; } - public TestInstanceProvider getTestInstanceProvider() { - return get(TestInstanceProvider.class.getName(), TestInstanceProvider.class); + public BeforeEachCallback getBeforeEachCallback() { + return state.beforeEachCallback; } - public JUnit5Context withBeforeEachCallback(BeforeEachCallback beforeEachCallback) { - return with("beforeEachCallback", beforeEachCallback); + public TestExtensionRegistry getTestExtensionRegistry() { + return state.testExtensionRegistry; } - public BeforeEachCallback getBeforeEachCallback() { - return get("beforeEachCallback", BeforeEachCallback.class); + public Builder extend() { + return builder(this); } - public JUnit5Context withTestExtensionRegistry(TestExtensionRegistry testExtensionRegistry) { - return with("testExtensionRegistry", testExtensionRegistry); + public static Builder builder() { + return new Builder(null, new State()); } - public TestExtensionRegistry getTestExtensionRegistry() { - return get("testExtensionRegistry", TestExtensionRegistry.class); + public static Builder builder(JUnit5Context context) { + return new Builder(context.state, null); } - private JUnit5Context with(String key, Object value) { - Map newMap = new HashMap<>(map); - newMap.put(key, value); - return new JUnit5Context(newMap); + private static final class State implements Cloneable { + + TestInstanceProvider testInstanceProvider; + BeforeEachCallback beforeEachCallback; + TestExtensionRegistry testExtensionRegistry; + + @Override + public State clone() { + try { + return (State) super.clone(); + } + catch (CloneNotSupportedException e) { + throw new RuntimeException("State could not be cloned", e); + } + } + } - private T get(String key, Class clazz) { - return clazz.cast(map.get(key)); + public static class Builder { + + private State originalState; + private State newState; + + private Builder(State originalState, State state) { + this.originalState = originalState; + this.newState = state; + } + + public Builder withTestInstanceProvider(TestInstanceProvider testInstanceProvider) { + newState().testInstanceProvider = testInstanceProvider; + return this; + } + + public Builder withBeforeEachCallback(BeforeEachCallback beforeEachCallback) { + newState().beforeEachCallback = beforeEachCallback; + return this; + } + + public Builder withTestExtensionRegistry(TestExtensionRegistry testExtensionRegistry) { + newState().testExtensionRegistry = testExtensionRegistry; + return this; + } + + public JUnit5Context build() { + if (newState != null) { + originalState = newState; + newState = null; + } + return new JUnit5Context(originalState); + } + + private State newState() { + if (newState == null) { + this.newState = originalState.clone(); + } + return newState; + } + } } diff --git a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/ClassTestDescriptor.java b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/ClassTestDescriptor.java index 20875dfd6f7..5ffff1347e2 100644 --- a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/ClassTestDescriptor.java +++ b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/ClassTestDescriptor.java @@ -79,10 +79,11 @@ public boolean isContainer() { @Override public JUnit5Context beforeAll(JUnit5Context context) { // @formatter:off - return context + return context.extend() .withTestInstanceProvider(testInstanceProvider(context)) .withBeforeEachCallback(beforeEachCallback(context)) - .withTestExtensionRegistry(populateNewTestExtensionRegistryFromExtendWith(testClass, context.getTestExtensionRegistry())); + .withTestExtensionRegistry(populateNewTestExtensionRegistryFromExtendWith(testClass, context.getTestExtensionRegistry())) + .build(); // @formatter:on } diff --git a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/JUnit5EngineDescriptor.java b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/JUnit5EngineDescriptor.java index 0070ef9e565..b8b843ade5a 100644 --- a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/JUnit5EngineDescriptor.java +++ b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/JUnit5EngineDescriptor.java @@ -24,7 +24,7 @@ public JUnit5EngineDescriptor(TestEngine engine) { @Override public JUnit5Context beforeAll(JUnit5Context context) { - return context.withTestExtensionRegistry(new TestExtensionRegistry()); + return context.extend().withTestExtensionRegistry(new TestExtensionRegistry()).build(); } } diff --git a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/MethodTestDescriptor.java b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/MethodTestDescriptor.java index 840b87c0c23..add91f4c903 100644 --- a/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/MethodTestDescriptor.java +++ b/junit5-engine/src/main/java/org/junit/gen5/engine/junit5/descriptor/MethodTestDescriptor.java @@ -87,8 +87,8 @@ public boolean isContainer() { @Override public JUnit5Context execute(JUnit5Context context) throws Throwable { - JUnit5Context myContext = context.withTestExtensionRegistry( - populateNewTestExtensionRegistryFromExtendWith(testMethod, context.getTestExtensionRegistry())); + JUnit5Context myContext = context.extend().withTestExtensionRegistry( + populateNewTestExtensionRegistryFromExtendWith(testMethod, context.getTestExtensionRegistry())).build(); TestInstanceProvider provider = context.getTestInstanceProvider(); Object testInstance = provider.getTestInstance();