Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dropping some CompletableFuture allocations #3233

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 19 additions & 26 deletions src/main/java/graphql/execution/Async.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

@Internal
Expand Down Expand Up @@ -56,7 +57,16 @@ public void add(CompletableFuture<T> completableFuture) {
@Override
public CompletableFuture<List<T>> await() {
Assert.assertTrue(ix == 0, () -> "expected size was " + 0 + " got " + ix);
return CompletableFuture.completedFuture(Collections.emptyList());
return typedEmpty();
}


// implementation details: infer the type of Completable<List<T>> from a singleton empty
private static final CompletableFuture<List<?>> EMPTY = CompletableFuture.completedFuture(Collections.emptyList());

@SuppressWarnings("unchecked")
private static <T> CompletableFuture<T> typedEmpty() {
return (CompletableFuture<T>) EMPTY;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty good pattern for CFs that are really a static value - eg one that is completed and the same value

}

Expand All @@ -75,18 +85,7 @@ public void add(CompletableFuture<T> completableFuture) {
@Override
public CompletableFuture<List<T>> await() {
Assert.assertTrue(ix == 1, () -> "expected size was " + 1 + " got " + ix);

CompletableFuture<List<T>> overallResult = new CompletableFuture<>();
completableFuture
.whenComplete((ignored, exception) -> {
if (exception != null) {
overallResult.completeExceptionally(exception);
return;
}
List<T> results = Collections.singletonList(completableFuture.join());
overallResult.complete(results);
});
return overallResult;
return completableFuture.thenApply(Collections::singletonList);
}
}

Expand Down Expand Up @@ -128,18 +127,12 @@ public CompletableFuture<List<T>> await() {

}

@FunctionalInterface
public interface CFFactory<T, U> {
CompletableFuture<U> apply(T input, int index, List<U> previousResults);
}

public static <T, U> CompletableFuture<List<U>> each(Collection<T> list, BiFunction<T, Integer, CompletableFuture<U>> cfFactory) {
public static <T, U> CompletableFuture<List<U>> each(Collection<T> list, Function<T, CompletableFuture<U>> cfFactory) {
CombinedBuilder<U> futures = ofExpectedSize(list.size());
int index = 0;
for (T t : list) {
CompletableFuture<U> cf;
try {
cf = cfFactory.apply(t, index++);
cf = cfFactory.apply(t);
Assert.assertNotNull(cf, () -> "cfFactory must return a non null value");
} catch (Exception e) {
cf = new CompletableFuture<>();
Expand All @@ -151,20 +144,20 @@ public static <T, U> CompletableFuture<List<U>> each(Collection<T> list, BiFunct
return futures.await();
}

public static <T, U> CompletableFuture<List<U>> eachSequentially(Iterable<T> list, CFFactory<T, U> cfFactory) {
public static <T, U> CompletableFuture<List<U>> eachSequentially(Iterable<T> list, BiFunction<T, List<U>, CompletableFuture<U>> cfFactory) {
CompletableFuture<List<U>> result = new CompletableFuture<>();
eachSequentiallyImpl(list.iterator(), cfFactory, 0, new ArrayList<>(), result);
eachSequentiallyImpl(list.iterator(), cfFactory, new ArrayList<>(), result);
return result;
}

private static <T, U> void eachSequentiallyImpl(Iterator<T> iterator, CFFactory<T, U> cfFactory, int index, List<U> tmpResult, CompletableFuture<List<U>> overallResult) {
private static <T, U> void eachSequentiallyImpl(Iterator<T> iterator, BiFunction<T, List<U>, CompletableFuture<U>> cfFactory, List<U> tmpResult, CompletableFuture<List<U>> overallResult) {
if (!iterator.hasNext()) {
overallResult.complete(tmpResult);
return;
}
CompletableFuture<U> cf;
try {
cf = cfFactory.apply(iterator.next(), index, tmpResult);
cf = cfFactory.apply(iterator.next(), tmpResult);
Assert.assertNotNull(cf, () -> "cfFactory must return a non null value");
} catch (Exception e) {
cf = new CompletableFuture<>();
Expand All @@ -176,7 +169,7 @@ private static <T, U> void eachSequentiallyImpl(Iterator<T> iterator, CFFactory<
return;
}
tmpResult.add(cfResult);
eachSequentiallyImpl(iterator, cfFactory, index + 1, tmpResult, overallResult);
eachSequentiallyImpl(iterator, cfFactory, tmpResult, overallResult);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public CompletableFuture<ExecutionResult> execute(ExecutionContext executionCont
MergedSelectionSet fields = parameters.getFields();
ImmutableList<String> fieldNames = ImmutableList.copyOf(fields.keySet());

CompletableFuture<List<ExecutionResult>> resultsFuture = Async.eachSequentially(fieldNames, (fieldName, index, prevResults) -> {
CompletableFuture<List<ExecutionResult>> resultsFuture = Async.eachSequentially(fieldNames, (fieldName, prevResults) -> {
MergedField currentField = fields.getSubField(fieldName);
ResultPath fieldPath = parameters.getPath().segment(mkNameForPath(currentField));
ExecutionStrategyParameters newParameters = parameters
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/graphql/execution/ExecutionStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext,
index++;
}

CompletableFuture<List<ExecutionResult>> resultsFuture = Async.each(fieldValueInfos, (item, i) -> item.getFieldValue());
CompletableFuture<List<ExecutionResult>> resultsFuture = Async.each(fieldValueInfos, FieldValueInfo::getFieldValue);

CompletableFuture<ExecutionResult> overallResult = new CompletableFuture<>();
completeListCtx.onDispatched(overallResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ public CompletableFuture<ExecutionResult> instrumentExecutionResult(ExecutionRes
@NotNull
@Override
public CompletableFuture<ExecutionResult> instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters, InstrumentationState state) {
CompletableFuture<List<ExecutionResult>> resultsFuture = Async.eachSequentially(instrumentations, (instrumentation, index, prevResults) -> {
CompletableFuture<List<ExecutionResult>> resultsFuture = Async.eachSequentially(instrumentations, (instrumentation, prevResults) -> {
InstrumentationState specificState = getSpecificState(instrumentation, state);
ExecutionResult lastResult = prevResults.size() > 0 ? prevResults.get(prevResults.size() - 1) : executionResult;
return instrumentation.instrumentExecutionResult(lastResult, parameters, specificState);
Expand Down
38 changes: 19 additions & 19 deletions src/test/groovy/graphql/execution/AsyncTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import spock.lang.Specification

import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionException
import java.util.function.Function
import java.util.function.BiFunction

import static java.util.concurrent.CompletableFuture.completedFuture
Expand All @@ -13,7 +14,7 @@ class AsyncTest extends Specification {
def "eachSequentially test"() {
given:
def input = ['a', 'b', 'c']
def cfFactory = Mock(Async.CFFactory)
def cfFactory = Mock(BiFunction)
def cf1 = new CompletableFuture()
def cf2 = new CompletableFuture()
def cf3 = new CompletableFuture()
Expand All @@ -23,21 +24,21 @@ class AsyncTest extends Specification {

then:
!result.isDone()
1 * cfFactory.apply('a', 0, []) >> cf1
1 * cfFactory.apply('a', []) >> cf1

when:
cf1.complete('x')

then:
!result.isDone()
1 * cfFactory.apply('b', 1, ['x']) >> cf2
1 * cfFactory.apply('b', ['x']) >> cf2

when:
cf2.complete('y')

then:
!result.isDone()
1 * cfFactory.apply('c', 2, ['x', 'y']) >> cf3
1 * cfFactory.apply('c', ['x', 'y']) >> cf3

when:
cf3.complete('z')
Expand All @@ -50,9 +51,9 @@ class AsyncTest extends Specification {
def "eachSequentially propagates exception"() {
given:
def input = ['a', 'b', 'c']
def cfFactory = Mock(Async.CFFactory)
cfFactory.apply('a', 0, _) >> completedFuture("x")
cfFactory.apply('b', 1, _) >> {
def cfFactory = Mock(BiFunction)
cfFactory.apply('a', _) >> completedFuture("x")
cfFactory.apply('b', _) >> {
def cf = new CompletableFuture<>()
cf.completeExceptionally(new RuntimeException("some error"))
cf
Expand All @@ -74,9 +75,9 @@ class AsyncTest extends Specification {
def "eachSequentially catches factory exception"() {
given:
def input = ['a', 'b', 'c']
def cfFactory = Mock(Async.CFFactory)
cfFactory.apply('a', 0, _) >> completedFuture("x")
cfFactory.apply('b', 1, _) >> { throw new RuntimeException("some error") }
def cfFactory = Mock(BiFunction)
cfFactory.apply('a', _) >> completedFuture("x")
cfFactory.apply('b', _) >> { throw new RuntimeException("some error") }

when:
def result = Async.eachSequentially(input, cfFactory)
Expand All @@ -94,10 +95,10 @@ class AsyncTest extends Specification {
def "each works for mapping function"() {
given:
def input = ['a', 'b', 'c']
def cfFactory = Mock(BiFunction)
cfFactory.apply('a', 0) >> completedFuture('x')
cfFactory.apply('b', 1) >> completedFuture('y')
cfFactory.apply('c', 2) >> completedFuture('z')
def cfFactory = Mock(Function)
cfFactory.apply('a') >> completedFuture('x')
cfFactory.apply('b') >> completedFuture('y')
cfFactory.apply('c') >> completedFuture('z')


when:
Expand All @@ -111,16 +112,15 @@ class AsyncTest extends Specification {
def "each with mapping function propagates factory exception"() {
given:
def input = ['a', 'b', 'c']
def cfFactory = Mock(BiFunction)

def cfFactory = Mock(Function)

when:
def result = Async.each(input, cfFactory)

then:
1 * cfFactory.apply('a', 0) >> completedFuture('x')
1 * cfFactory.apply('b', 1) >> { throw new RuntimeException('some error') }
1 * cfFactory.apply('c', 2) >> completedFuture('z')
1 * cfFactory.apply('a') >> completedFuture('x')
1 * cfFactory.apply('b') >> { throw new RuntimeException('some error') }
1 * cfFactory.apply('c') >> completedFuture('z')
result.isCompletedExceptionally()
Throwable exception
result.exceptionally({ e ->
Expand Down
39 changes: 20 additions & 19 deletions src/test/java/benchmark/BenchMark.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package benchmark;

import graphql.Assert;
import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.execution.ExecutionStepInfo;
import graphql.execution.instrumentation.tracing.TracingInstrumentation;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLSchema;
import graphql.schema.TypeResolver;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.SchemaParser;
Expand All @@ -24,10 +26,10 @@
import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring;

/**
* See https://github.com/openjdk/jmh/tree/master/jmh-samples/src/main/java/org/openjdk/jmh/samples/ for more samples
* on what you can do with JMH
* See <a href="https://github.com/openjdk/jmh/tree/master/jmh-samples/src/main/java/org/openjdk/jmh/samples/">this link</a> for more samples
* on what you can do with JMH.
* <p>
* You MUST have the JMH plugin for IDEA in place for this to work : https://github.com/artyushov/idea-jmh-plugin
* You MUST have the JMH plugin for IDEA in place for this to work : <a href="https://github.com/artyushov/idea-jmh-plugin">idea-jmh-plugin</a>
* <p>
* Install it and then just hit "Run" on a certain benchmark method
*/
Expand All @@ -36,44 +38,41 @@
public class BenchMark {

private static final int NUMBER_OF_FRIENDS = 10 * 100;

static GraphQL graphQL = buildGraphQL();
private static final GraphQL GRAPHQL = buildGraphQL();

@Benchmark
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
public void benchMarkSimpleQueriesThroughput() {
executeQuery();
public ExecutionResult benchMarkSimpleQueriesThroughput() {
return executeQuery();
}

@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void benchMarkSimpleQueriesAvgTime() {
executeQuery();
public ExecutionResult benchMarkSimpleQueriesAvgTime() {
return executeQuery();
}

public static void executeQuery() {
public static ExecutionResult executeQuery() {
String query = "{ hero { name friends { name friends { name } } } }";
graphQL.execute(query);
return GRAPHQL.execute(query);
}

private static GraphQL buildGraphQL() {
TypeDefinitionRegistry definitionRegistry = new SchemaParser().parse(BenchmarkUtils.loadResource("starWarsSchema.graphqls"));

DataFetcher heroDataFetcher = environment -> CharacterDTO.mkCharacter(environment, "r2d2", NUMBER_OF_FRIENDS);
DataFetcher<CharacterDTO> heroDataFetcher = environment -> CharacterDTO.mkCharacter(environment, "r2d2", NUMBER_OF_FRIENDS);
TypeResolver typeResolver = env -> env.getSchema().getObjectType("Human");

RuntimeWiring runtimeWiring = RuntimeWiring.newRuntimeWiring()
.type(
newTypeWiring("QueryType").dataFetcher("hero", heroDataFetcher))
.type(newTypeWiring("Character").typeResolver(
env -> env.getSchema().getObjectType("Human")
))
.type(newTypeWiring("QueryType").dataFetcher("hero", heroDataFetcher))
.type(newTypeWiring("Character").typeResolver(typeResolver))
.build();

GraphQLSchema graphQLSchema = new SchemaGenerator().makeExecutableSchema(definitionRegistry, runtimeWiring);

return GraphQL.newGraphQL(graphQLSchema)
.instrumentation(new TracingInstrumentation())
.build();
}

Expand All @@ -96,7 +95,9 @@ public List<CharacterDTO> getFriends() {

public static CharacterDTO mkCharacter(DataFetchingEnvironment environment, String name, int friendCount) {
Object sideEffect = environment.getArgument("episode");
Assert.assertNull(sideEffect);
ExecutionStepInfo anotherSideEffect = environment.getExecutionStepInfo();
Assert.assertNotNull(anotherSideEffect);
List<CharacterDTO> friends = new ArrayList<>(friendCount);
for (int i = 0; i < friendCount; i++) {
friends.add(mkCharacter(environment, "friend" + i, 0));
Expand Down