diff --git a/src/main/java/graphql/GraphQLContext.java b/src/main/java/graphql/GraphQLContext.java index 081c17725f..8b913919d3 100644 --- a/src/main/java/graphql/GraphQLContext.java +++ b/src/main/java/graphql/GraphQLContext.java @@ -5,7 +5,9 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Stream; import static graphql.Assert.assertNotNull; @@ -171,6 +173,52 @@ public GraphQLContext putAll(Consumer contextBuilderCons return putAll(builder); } + /** + * Attempts to compute a mapping for the specified key and its + * current mapped value (or null if there is no current mapping). + * + * @param key key with which the specified value is to be associated + * @param remappingFunction the function to compute a value + * + * @return the new value associated with the specified key, or null if none + * @param for two + */ + public T compute(Object key, BiFunction remappingFunction) { + assertNotNull(remappingFunction); + return (T) map.compute(assertNotNull(key), (k, v) -> remappingFunction.apply(k, (T) v)); + } + + /** + * If the specified key is not already associated with a value (or is mapped to null), + * attempts to compute its value using the given mapping function and enters it into this map unless null. + * + * @param key key with which the specified value is to be associated + * @param mappingFunction the function to compute a value + * + * @return the current (existing or computed) value associated with the specified key, or null if the computed value is null + * @param for two + */ + + public T computeIfAbsent(Object key, Function mappingFunction) { + return (T) map.computeIfAbsent(assertNotNull(key), assertNotNull(mappingFunction)); + } + + /** + * If the value for the specified key is present and non-null, + * attempts to compute a new mapping given the key and its current mapped value. + * + * @param key key with which the specified value is to be associated + * @param remappingFunction the function to compute a value + * + * @return the new value associated with the specified key, or null if none + * @param for two + */ + + public T computeIfPresent(Object key, BiFunction remappingFunction) { + assertNotNull(remappingFunction); + return (T) map.computeIfPresent(assertNotNull(key), (k, v) -> remappingFunction.apply(k, (T) v)); + } + /** * @return a stream of entries in the context */ diff --git a/src/test/groovy/graphql/GraphQLContextTest.groovy b/src/test/groovy/graphql/GraphQLContextTest.groovy index f409721363..8eebb17653 100644 --- a/src/test/groovy/graphql/GraphQLContextTest.groovy +++ b/src/test/groovy/graphql/GraphQLContextTest.groovy @@ -168,6 +168,52 @@ class GraphQLContextTest extends Specification { !context.hasKey("k3") } + def "compute works"() { + def context + when: + context = buildContext([k1: "foo"]) + then: + context.compute("k1", (k, v) -> v ? v + "bar" : "default") == "foobar" + context.get("k1") == "foobar" + context.compute("k2", (k, v) -> v ? "new" : "default") == "default" + context.get("k2") == "default" + !context.compute("k3", (k, v) -> null) + !context.hasKey("k3") + sizeOf(context) == 2 + } + + def "computeIfAbsent works"() { + def context + when: + context = buildContext([k1: "v1", k2: "v2"]) + then: + context.computeIfAbsent("k1", k -> "default") == "v1" + context.get("k1") == "v1" + context.computeIfAbsent("k2", k -> null) == "v2" + context.get("k2") == "v2" + context.computeIfAbsent("k3", k -> "default") == "default" + context.get("k3") == "default" + !context.computeIfAbsent("k4", k -> null) + !context.hasKey("k4") + sizeOf(context) == 3 + } + + def "computeIfPresent works"() { + def context + when: + context = buildContext([k1: "foo", k2: "v2"]) + then: + context.computeIfPresent("k1", (k, v) -> v + "bar") == "foobar" + context.get("k1") == "foobar" + !context.computeIfPresent("k2", (k, v) -> null) + !context.hasKey("k2") + !context.computeIfPresent("k3", (k, v) -> v + "bar") + !context.hasKey("k3") + !context.computeIfPresent("k4", (k, v) -> null) + !context.hasKey("k4") + sizeOf(context) == 1 + } + def "getOrDefault works"() { def context when: