diff --git a/src/main/java/io/iworkflow/core/Client.java b/src/main/java/io/iworkflow/core/Client.java index 316c00c3..b13b71d5 100644 --- a/src/main/java/io/iworkflow/core/Client.java +++ b/src/main/java/io/iworkflow/core/Client.java @@ -37,7 +37,7 @@ public class Client { final ClientOptions clientOptions; /** - * return a full featured client. If you don't have the workflow Registry, you should use {@link UnregisteredClient} instead + * return a full-featured client. If you don't have the workflow Registry, you should use {@link UnregisteredClient} instead * * @param registry registry is required so that this client can perform some validation checks (workflow types, channel names) * @param clientOptions is for configuring the client @@ -686,6 +686,19 @@ public O invokeRPC(RpcDefinitions.RpcFunc1 rpcStubMethod, I input) return rpcStubMethod.execute(null, input, null, null); } + /** + * invoking the RPC through RPC stub + * + * @param rpcStubMethod the RPC method from stub created by {@link #newRpcStub(Class, String, String)} + * @param input the input of the RPC method + * @param the input type + * @param the output type + * @return output + */ + public O invokeRPC(RpcDefinitions.RpcFunc1NoPersistence rpcStubMethod, I input) { + return rpcStubMethod.execute(null, input, null); + } + /** * invoking the RPC through RPC stub * @@ -697,6 +710,17 @@ public O invokeRPC(RpcDefinitions.RpcFunc0 rpcStubMethod) { return rpcStubMethod.execute(null, null, null); } + /** + * invoking the RPC through RPC stub + * + * @param rpcStubMethod the RPC method from stub created by {@link #newRpcStub(Class, String, String)} + * @param the output type + * @return output + */ + public O invokeRPC(RpcDefinitions.RpcFunc0NoPersistence rpcStubMethod) { + return rpcStubMethod.execute(null, null); + } + /** * invoking the RPC through RPC stub * @@ -708,6 +732,17 @@ public void invokeRPC(RpcDefinitions.RpcProc1 rpcStubMethod, I input) { rpcStubMethod.execute(null, input, null, null); } + /** + * invoking the RPC through RPC stub + * + * @param rpcStubMethod the RPC method from stub created by {@link #newRpcStub(Class, String, String)} + * @param input the input of the RPC method + * @param the input type + */ + public void invokeRPC(RpcDefinitions.RpcProc1NoPersistence rpcStubMethod, I input) { + rpcStubMethod.execute(null, input, null); + } + /** * invoking the RPC through RPC stub * @@ -717,6 +752,15 @@ public void invokeRPC(RpcDefinitions.RpcProc0 rpcStubMethod) { rpcStubMethod.execute(null, null, null); } + /** + * invoking the RPC through RPC stub + * + * @param rpcStubMethod the RPC method from stub created by {@link #newRpcStub(Class, String, String)} + */ + public void invokeRPC(RpcDefinitions.RpcProc0NoPersistence rpcStubMethod) { + rpcStubMethod.execute(null, null); + } + /** * Get specified search attributes (by attributeKeys) of a workflow * diff --git a/src/main/java/io/iworkflow/core/RpcDefinitions.java b/src/main/java/io/iworkflow/core/RpcDefinitions.java index b2b6b2da..43ff9870 100644 --- a/src/main/java/io/iworkflow/core/RpcDefinitions.java +++ b/src/main/java/io/iworkflow/core/RpcDefinitions.java @@ -1,18 +1,21 @@ package io.iworkflow.core; +import com.google.common.collect.ImmutableMap; import io.iworkflow.core.communication.Communication; import io.iworkflow.core.persistence.Persistence; import java.io.Serializable; import java.lang.reflect.Method; +import java.util.Map; public final class RpcDefinitions { private RpcDefinitions() { } /** - * RPC with input and output - * + * RPC definition + * with: input, output, persistence, communication + * without: NA * @param input type * @param output type */ @@ -22,8 +25,21 @@ public interface RpcFunc1 extends Serializable { } /** - * RPC with output only - * + * RPC definition + * with: input, output, communication + * without: persistence + * @param input type + * @param output type + */ + @FunctionalInterface + public interface RpcFunc1NoPersistence extends Serializable { + O execute(Context context, I input, Communication communication); + } + + /** + * RPC definition + * with: output, persistence, communication + * without: input * @param output type */ @FunctionalInterface @@ -32,8 +48,20 @@ public interface RpcFunc0 extends Serializable { } /** - * RPC with input only - * + * RPC definition + * with: output, communication + * without: input, persistence + * @param output type + */ + @FunctionalInterface + public interface RpcFunc0NoPersistence extends Serializable { + O execute(Context context, Communication communication); + } + + /** + * RPC definition + * with: input, persistence, communication + * without: output * @param input type */ @FunctionalInterface @@ -42,34 +70,42 @@ public interface RpcProc1 extends Serializable { } /** - * RPC without input or output + * RPC definition + * with: input, communication + * without: output, persistence + * @param input type + */ + @FunctionalInterface + public interface RpcProc1NoPersistence extends Serializable { + void execute(Context context, I input, Communication communication); + } + + /** + * RPC definition + * with: persistence, communication + * without: input, output */ @FunctionalInterface public interface RpcProc0 extends Serializable { void execute(Context context, Persistence persistence, Communication communication); } - public static final int PARAMETERS_WITH_INPUT = 4; - public static final int PARAMETERS_NO_INPUT = 3; + /** + * RPC definition + * with: communication + * without: input, output, persistence + */ + @FunctionalInterface + public interface RpcProc0NoPersistence extends Serializable { + void execute(Context context, Communication communication); + } - public static final int INDEX_OF_INPUT_PARAMETER = 1; + public static final String ERROR_MESSAGE = "An RPC method must be in the form of one of {@link RpcDefinitions}"; public static void validateRpcMethod(final Method method) { - final Class[] paramTypes = method.getParameterTypes(); - final Class persistenceType, communicationType, contextType; - if (paramTypes.length == PARAMETERS_NO_INPUT) { - contextType = paramTypes[0]; - persistenceType = paramTypes[1]; - communicationType = paramTypes[2]; - } else if (paramTypes.length == PARAMETERS_WITH_INPUT) { - contextType = paramTypes[0]; - persistenceType = paramTypes[2]; - communicationType = paramTypes[3]; - } else { - throw new WorkflowDefinitionException("An RPC method must be in the form of one of {@link RpcDefinitions}"); - } - if (!persistenceType.equals(Persistence.class) || !communicationType.equals(Communication.class)) { - throw new WorkflowDefinitionException("An RPC method must be in the form of one of {@link RpcDefinitions}"); + RpcMethodMetadata methodMetadata = RpcMethodMatcher.match(method); + if (methodMetadata == null) { + throw new WorkflowDefinitionException(ERROR_MESSAGE); } } -} \ No newline at end of file +} diff --git a/src/main/java/io/iworkflow/core/RpcInvocationHandler.java b/src/main/java/io/iworkflow/core/RpcInvocationHandler.java index 50c8e5ae..41c9e5a2 100644 --- a/src/main/java/io/iworkflow/core/RpcInvocationHandler.java +++ b/src/main/java/io/iworkflow/core/RpcInvocationHandler.java @@ -2,6 +2,7 @@ import io.iworkflow.core.persistence.PersistenceOptions; import io.iworkflow.gen.models.PersistenceLoadingPolicy; +import io.iworkflow.gen.models.PersistenceLoadingType; import io.iworkflow.gen.models.SearchAttributeKeyAndType; import net.bytebuddy.implementation.bind.annotation.AllArguments; import net.bytebuddy.implementation.bind.annotation.Origin; @@ -11,9 +12,7 @@ import java.util.Arrays; import java.util.List; -import static io.iworkflow.core.RpcDefinitions.INDEX_OF_INPUT_PARAMETER; -import static io.iworkflow.core.RpcDefinitions.PARAMETERS_WITH_INPUT; -import static io.iworkflow.core.RpcDefinitions.validateRpcMethod; +import static io.iworkflow.core.RpcDefinitions.*; public class RpcInvocationHandler { @@ -41,11 +40,12 @@ public Object intercept(@AllArguments Object[] allArguments, if (rpcAnno == null) { throw new WorkflowDefinitionException("An RPC method must be annotated by RPC annotation"); } - validateRpcMethod(method); - Object input = null; - if (method.getParameterTypes().length == PARAMETERS_WITH_INPUT) { - input = allArguments[INDEX_OF_INPUT_PARAMETER]; + + RpcMethodMetadata metadata = RpcMethodMatcher.match(method); + if (metadata == null) { + throw new WorkflowDefinitionException("An RPC method must be annotated by RPC annotation"); } + Object input = metadata.hasInput() ? allArguments[metadata.getInputIndex()] : null; final Class outputType = method.getReturnType(); @@ -53,18 +53,41 @@ public Object intercept(@AllArguments Object[] allArguments, if (rpcAnno.bypassCachingForStrongConsistency()) { useMemo = false; } - final Object output = unregisteredClient.invokeRpc(outputType, input, workflowId, workflowRunId, method.getName(), rpcAnno.timeoutSeconds(), - new PersistenceLoadingPolicy() - .persistenceLoadingType(rpcAnno.dataAttributesLoadingType()) - .partialLoadingKeys(Arrays.asList(rpcAnno.dataAttributesPartialLoadingKeys())) - .lockingKeys(Arrays.asList(rpcAnno.dataAttributesLockingKeys())), - new PersistenceLoadingPolicy() - .persistenceLoadingType(rpcAnno.searchAttributesLoadingType()) - .lockingKeys(Arrays.asList(rpcAnno.searchAttributesLockingKeys())) - .partialLoadingKeys(Arrays.asList(rpcAnno.searchAttributesPartialLoadingKeys())), - useMemo, - searchAttributeKeyAndTypes - ); - return output; + + if (metadata.usesPersistence()) { + return unregisteredClient.invokeRpc( + outputType, + input, + workflowId, + workflowRunId, + method.getName(), + rpcAnno.timeoutSeconds(), + new PersistenceLoadingPolicy() + .persistenceLoadingType(rpcAnno.dataAttributesLoadingType()) + .partialLoadingKeys(Arrays.asList(rpcAnno.dataAttributesPartialLoadingKeys())) + .lockingKeys(Arrays.asList(rpcAnno.dataAttributesLockingKeys())), + new PersistenceLoadingPolicy() + .persistenceLoadingType(rpcAnno.searchAttributesLoadingType()) + .lockingKeys(Arrays.asList(rpcAnno.searchAttributesLockingKeys())) + .partialLoadingKeys(Arrays.asList(rpcAnno.searchAttributesPartialLoadingKeys())), + useMemo, + searchAttributeKeyAndTypes + ); + } else { + return unregisteredClient.invokeRpc( + outputType, + input, + workflowId, + workflowRunId, + method.getName(), + rpcAnno.timeoutSeconds(), + new PersistenceLoadingPolicy() + .persistenceLoadingType(PersistenceLoadingType.NONE), + new PersistenceLoadingPolicy() + .persistenceLoadingType(PersistenceLoadingType.NONE), + useMemo, + null); + } + } } \ No newline at end of file diff --git a/src/main/java/io/iworkflow/core/RpcMethodMatcher.java b/src/main/java/io/iworkflow/core/RpcMethodMatcher.java new file mode 100644 index 00000000..367edbf2 --- /dev/null +++ b/src/main/java/io/iworkflow/core/RpcMethodMatcher.java @@ -0,0 +1,109 @@ +package io.iworkflow.core; + +import com.google.common.collect.ImmutableMap; +import io.iworkflow.core.communication.Communication; +import io.iworkflow.core.persistence.Persistence; + +import java.lang.reflect.Method; +import java.util.Map; + +class RpcMethodMatcher { + public static final Map> RPC_WITH_INPUT_PERSISTENCE_PARAM_TYPES = + new ImmutableMap.Builder>() + .put(0, Context.class) + .put(2, Persistence.class) + .put(3, Communication.class) + .build(); + public static final RpcMethodMetadata METADATA_RPC_WITH_INPUT_PERSISTENCE = + ImmutableRpcMethodMetadata.builder() + .hasInput(true) + .inputIndex(1) + .usesPersistence(true) + .build(); + + public static final Map> RPC_WITH_INPUT_PARAM_TYPES = + new ImmutableMap.Builder>() + .put(0, Context.class) + .put(2, Communication.class) + .build(); + public static final RpcMethodMetadata METADATA_RPC_WITH_INPUT = + ImmutableRpcMethodMetadata.builder() + .hasInput(true) + .inputIndex(1) + .usesPersistence(false) + .build(); + + public static final Map> RPC_WITH_PERSISTENCE_PARAM_TYPES = + new ImmutableMap.Builder>() + .put(0, Context.class) + .put(1, Persistence.class) + .put(2, Communication.class) + .build(); + public static final RpcMethodMetadata METADATA_RPC_WITH_PERSISTENCE = + ImmutableRpcMethodMetadata.builder() + .hasInput(false) + .inputIndex(-1) + .usesPersistence(true) + .build(); + + public static final Map> RPC_WITHOUT_INPUT_PERSISTENCE_PARAM_TYPES = + new ImmutableMap.Builder>() + .put(0, Context.class) + .put(1, Communication.class) + .build(); + public static final RpcMethodMetadata METADATA_RPC_WITHOUT_INPUT_PERSISTENCE = + ImmutableRpcMethodMetadata.builder() + .hasInput(false) + .inputIndex(-1) + .usesPersistence(false) + .build(); + + private static final int RPC_PARAM_COUNT_MAX = 4; + private static final int RPC_PARAM_COUNT_MIN = 2; + + public static RpcMethodMetadata match(Method method) { + final Class[] paramTypes = method.getParameterTypes(); + if (paramTypes.length < RPC_PARAM_COUNT_MIN || paramTypes.length > RPC_PARAM_COUNT_MAX) { + return null; + } + + switch (paramTypes.length) { + case 2: + if (validateInputParameters(paramTypes, RPC_WITHOUT_INPUT_PERSISTENCE_PARAM_TYPES)) { + return METADATA_RPC_WITHOUT_INPUT_PERSISTENCE; + } else { + return null; + } + case 3: + if (validateInputParameters(paramTypes, RPC_WITH_PERSISTENCE_PARAM_TYPES)) { + return METADATA_RPC_WITH_PERSISTENCE; + } else if (validateInputParameters(paramTypes, RPC_WITH_INPUT_PARAM_TYPES)) { + return METADATA_RPC_WITH_INPUT; + } else { + return null; + } + case 4: + if (validateInputParameters(paramTypes, RPC_WITH_INPUT_PERSISTENCE_PARAM_TYPES)) { + return METADATA_RPC_WITH_INPUT_PERSISTENCE; + } else { + return null; + } + } + + return null; + } + + private static boolean validateInputParameters(Class[] paramTypes, Map> expectedInputParamTypes) { + for (Map.Entry> entry: expectedInputParamTypes.entrySet()) { + if (entry.getKey() >= paramTypes.length) { + return false; + } + + if (!paramTypes[entry.getKey()].equals(entry.getValue())) { + return false; + } + } + + return true; + } +} diff --git a/src/main/java/io/iworkflow/core/RpcMethodMetadata.java b/src/main/java/io/iworkflow/core/RpcMethodMetadata.java new file mode 100644 index 00000000..e70adffc --- /dev/null +++ b/src/main/java/io/iworkflow/core/RpcMethodMetadata.java @@ -0,0 +1,10 @@ +package io.iworkflow.core; + +import org.immutables.value.Value; + +@Value.Immutable +abstract class RpcMethodMetadata { + public abstract boolean hasInput(); + public abstract int getInputIndex(); + public abstract boolean usesPersistence(); +} diff --git a/src/main/java/io/iworkflow/core/UnregisteredClient.java b/src/main/java/io/iworkflow/core/UnregisteredClient.java index 3be53338..02ac531a 100644 --- a/src/main/java/io/iworkflow/core/UnregisteredClient.java +++ b/src/main/java/io/iworkflow/core/UnregisteredClient.java @@ -46,6 +46,13 @@ public class UnregisteredClient { private final ClientOptions clientOptions; + private WorkflowRpcRequest outgoingWorkflowRpcRequest; + + // for testing purpose + public WorkflowRpcRequest getLastOutgoingWorkflowRpcRequest() { + return this.outgoingWorkflowRpcRequest; + } + public UnregisteredClient(final ClientOptions clientOptions) { this.clientOptions = clientOptions; @@ -647,18 +654,18 @@ public T invokeRpc( final List allSearchAttributes) { try { final EncodedObject encodedInput = this.clientOptions.getObjectEncoder().encode(input); - final WorkflowRpcResponse response = defaultApi.apiV1WorkflowRpcPost( - new WorkflowRpcRequest() - .input(encodedInput) - .workflowId(workflowId) - .workflowRunId(workflowRunId) - .rpcName(rpcName) - .timeoutSeconds(timeoutSeconds) - .dataAttributesLoadingPolicy(dataAttributesLoadingPolicy) - .searchAttributesLoadingPolicy(searchAttributesLoadingPolicy) - .useMemoForDataAttributes(usingMemoForDataAttributes) - .searchAttributes(allSearchAttributes) - ); + WorkflowRpcRequest request = new WorkflowRpcRequest() + .input(encodedInput) + .workflowId(workflowId) + .workflowRunId(workflowRunId) + .rpcName(rpcName) + .timeoutSeconds(timeoutSeconds) + .dataAttributesLoadingPolicy(dataAttributesLoadingPolicy) + .searchAttributesLoadingPolicy(searchAttributesLoadingPolicy) + .useMemoForDataAttributes(usingMemoForDataAttributes) + .searchAttributes(allSearchAttributes); + this.outgoingWorkflowRpcRequest = request; + final WorkflowRpcResponse response = defaultApi.apiV1WorkflowRpcPost(request); return this.clientOptions.getObjectEncoder().decode(response.getOutput(), valueClass); } catch (FeignException.FeignClientException exp) { throw IwfHttpException.fromFeignException(clientOptions.getObjectEncoder(), exp); diff --git a/src/main/java/io/iworkflow/core/WorkerService.java b/src/main/java/io/iworkflow/core/WorkerService.java index 19d6aaad..120f6ee2 100644 --- a/src/main/java/io/iworkflow/core/WorkerService.java +++ b/src/main/java/io/iworkflow/core/WorkerService.java @@ -32,9 +32,6 @@ import java.util.Optional; import java.util.stream.Collectors; -import static io.iworkflow.core.RpcDefinitions.INDEX_OF_INPUT_PARAMETER; -import static io.iworkflow.core.RpcDefinitions.PARAMETERS_WITH_INPUT; - public class WorkerService { public static final String WORKFLOW_STATE_WAIT_UNTIL_API_PATH = "/api/v1/workflowState/start"; @@ -54,10 +51,15 @@ public WorkerService(Registry registry, WorkerOptions workerOptions) { public WorkflowWorkerRpcResponse handleWorkflowWorkerRpc(final WorkflowWorkerRpcRequest req) { final ObjectWorkflow workflow = registry.getWorkflow(req.getWorkflowType()); final Method method = registry.getWorkflowRpcMethod(req.getWorkflowType(), req.getRpcName()); + + RpcMethodMetadata methodMetadata = RpcMethodMatcher.match(method); + if (methodMetadata == null) { + throw new WorkflowDefinitionException("An RPC method must be annotated by RPC annotation and matches one of the RPC definitions"); + } Object input = null; - if (method.getParameterTypes().length == PARAMETERS_WITH_INPUT) { + if (methodMetadata.hasInput()) { // the second one will be input - Class inputType = method.getParameterTypes()[INDEX_OF_INPUT_PARAMETER]; + Class inputType = method.getParameterTypes()[methodMetadata.getInputIndex()]; input = workerOptions.getObjectEncoder().decode(req.getInput(), inputType); } @@ -78,20 +80,38 @@ public WorkflowWorkerRpcResponse handleWorkflowWorkerRpc(final WorkflowWorkerRpc Object output = null; try { - if (method.getParameterTypes().length == PARAMETERS_WITH_INPUT) { - output = method.invoke( - workflow, - context, - input, - persistence, - communication); + if (methodMetadata.usesPersistence()) { + if (methodMetadata.hasInput()) { + output = method.invoke( + workflow, + context, + input, + persistence, + communication + ); + } else { + output = method.invoke( + workflow, + context, + persistence, + communication + ); + } } else { - // without input - output = method.invoke( - workflow, - context, - persistence, - communication); + if (methodMetadata.hasInput()) { + output = method.invoke( + workflow, + context, + input, + communication + ); + } else { + output = method.invoke( + workflow, + context, + communication + ); + } } } catch (IllegalAccessException e) { throw new RuntimeException(e); diff --git a/src/test/java/io/iworkflow/integ/RpcTest.java b/src/test/java/io/iworkflow/integ/RpcTest.java index 466ee140..2b36bbed 100644 --- a/src/test/java/io/iworkflow/integ/RpcTest.java +++ b/src/test/java/io/iworkflow/integ/RpcTest.java @@ -6,9 +6,7 @@ import io.iworkflow.core.ClientSideException; import io.iworkflow.core.ImmutableStopWorkflowOptions; import io.iworkflow.core.ImmutableWorkflowOptions; -import io.iworkflow.gen.models.ErrorResponse; -import io.iworkflow.gen.models.WorkflowConfig; -import io.iworkflow.gen.models.WorkflowStopType; +import io.iworkflow.gen.models.*; import io.iworkflow.integ.persistence.BasicPersistenceWorkflow; import io.iworkflow.integ.rpc.NoStateWorkflow; import io.iworkflow.integ.rpc.RpcWorkflow; @@ -100,6 +98,28 @@ public void testRPCLocking() throws InterruptedException, ExecutionException { client.stopWorkflow(wfId, null); } + @Test + public void testRpcNoPersistence() { + final Client client = new Client(WorkflowRegistry.registry, ClientOptions.localDefault); + final String wfId = "testRpcWithNoPersistence" + System.currentTimeMillis() / 1000; + final String runId = client.startWorkflow( + RpcWorkflow.class, wfId, 10, 999); + + final RpcWorkflow rpcStub = client.newRpcStub(RpcWorkflow.class, wfId, "" ); + client.invokeRPC(rpcStub::testRpcNoPersistence); + WorkflowRpcRequest request = client.getUnregisteredClient().getLastOutgoingWorkflowRpcRequest(); + Assertions.assertNotNull(request.getDataAttributesLoadingPolicy()); + Assertions.assertEquals(PersistenceLoadingType.NONE, + request.getDataAttributesLoadingPolicy().getPersistenceLoadingType()); + Assertions.assertNotNull(request.getSearchAttributesLoadingPolicy()); + Assertions.assertEquals(PersistenceLoadingType.NONE, + request.getSearchAttributesLoadingPolicy().getPersistenceLoadingType()); + + final Integer output = client.getSimpleWorkflowResultWithWait(Integer.class, wfId); + RpcWorkflowState2.resetCounter(); + Assertions.assertEquals(2, output); + } + @Test public void testRPCWorkflowFunc1() throws InterruptedException { final Client client = new Client(WorkflowRegistry.registry, ClientOptions.localDefault); diff --git a/src/test/java/io/iworkflow/integ/rpc/RpcWorkflow.java b/src/test/java/io/iworkflow/integ/rpc/RpcWorkflow.java index 2af96767..e7abf152 100644 --- a/src/test/java/io/iworkflow/integ/rpc/RpcWorkflow.java +++ b/src/test/java/io/iworkflow/integ/rpc/RpcWorkflow.java @@ -54,6 +54,15 @@ public List getPersistenceSchema() { ); } + @RPC + public void testRpcNoPersistence(Context context, Communication communication) { + if (context.getWorkflowId().isEmpty() || context.getWorkflowRunId().isEmpty()) { + throw new RuntimeException("invalid context"); + } + communication.publishInternalChannel(INTERNAL_CHANNEL_NAME, null); + communication.triggerStateMovements(StateMovement.create(RpcWorkflowState2.class)); + } + @RPC public Long testRpcFunc1(Context context, String input, Persistence persistence, Communication communication) { if (context.getWorkflowId().isEmpty() || context.getWorkflowRunId().isEmpty()) {