From a707cbc593daf5b46ec4a42800ad4f7dfcf4287c Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Thu, 2 Dec 2021 13:57:46 -0600 Subject: [PATCH] Expose gRPC (#651) Fixes #641 --- client/client.go | 5 + internal/client.go | 15 +++ internal/grpc_dialer.go | 4 + internal/grpc_dialer_test.go | 53 +++++++++++ internal/internal_workflow_client.go | 5 + mocks/Client.go | 134 +++++++++++++++------------ 6 files changed, 155 insertions(+), 61 deletions(-) diff --git a/client/client.go b/client/client.go index a2a638cc8..fed84a69b 100644 --- a/client/client.go +++ b/client/client.go @@ -364,6 +364,11 @@ type ( // RequestId is used to deduplicate requests. It will be autogenerated if not set. ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) + // WorkflowService provides access to the underlying gRPC service. This should only be used for advanced use cases + // that cannot be accomplished via other Client methods. Unlike calls to other Client methods, calls directly to the + // service are not configured with internal semantics such as automatic retries. + WorkflowService() workflowservice.WorkflowServiceClient + // Close client and clean up underlying resources. Close() } diff --git a/internal/client.go b/internal/client.go index 245ba97be..feba7c550 100644 --- a/internal/client.go +++ b/internal/client.go @@ -341,6 +341,11 @@ type ( // RequestId is used to deduplicate requests. It will be autogenerated if not set. ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) + // WorkflowService provides access to the underlying gRPC service. This should only be used for advanced use cases + // that cannot be accomplished via other Client methods. Unlike calls to other Client methods, calls directly to the + // service are not configured with internal semantics such as automatic retries. + WorkflowService() workflowservice.WorkflowServiceClient + // Close client and clean up underlying resources. Close() } @@ -498,6 +503,16 @@ type ( // MaxPayloadSize is a number of bytes that gRPC would allow to travel to and from server. Defaults to 64 MB. MaxPayloadSize int + + // Advanced dial options for gRPC connections. These are applied after the internal default dial options are + // applied. Therefore any dial options here may override internal ones. + // + // For gRPC interceptors, internal interceptors such as error handling, metrics, and retrying are done via + // grpc.WithChainUnaryInterceptor. Therefore to add inner interceptors that are wrapped by those, a + // grpc.WithChainUnaryInterceptor can be added as an option here. To add a single outer interceptor, a + // grpc.WithUnaryInterceptor option can be added since grpc.WithUnaryInterceptor is prepended to chains set with + // grpc.WithChainUnaryInterceptor. + DialOptions []grpc.DialOption } // StartWorkflowOptions configuration parameters for starting a workflow execution. diff --git a/internal/grpc_dialer.go b/internal/grpc_dialer.go index ef3a5d6d8..e933e08f1 100644 --- a/internal/grpc_dialer.go +++ b/internal/grpc_dialer.go @@ -121,6 +121,10 @@ func dial(params dialParameters) (*grpc.ClientConn, error) { } opts = append(opts, grpc.WithKeepaliveParams(kap)) } + + // Append any user-supplied options + opts = append(opts, params.UserConnectionOptions.DialOptions...) + return grpc.Dial(params.HostPort, opts...) } diff --git a/internal/grpc_dialer_test.go b/internal/grpc_dialer_test.go index 8c2b606a6..5c176d7c5 100644 --- a/internal/grpc_dialer_test.go +++ b/internal/grpc_dialer_test.go @@ -29,6 +29,7 @@ import ( "errors" "log" "net" + "strings" "testing" "github.com/gogo/status" @@ -124,6 +125,58 @@ func TestHeadersProvider_IncludedWithHeadersProvider(t *testing.T) { require.Equal(t, 6, len(interceptors)) } +func TestDialOptions(t *testing.T) { + // Start an unimplemented gRPC server + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + srv := grpc.NewServer() + workflowservice.RegisterWorkflowServiceServer(srv, &workflowservice.UnimplementedWorkflowServiceServer{}) + healthServer := health.NewServer() + healthServer.SetServingStatus(healthCheckServiceName, grpc_health_v1.HealthCheckResponse_SERVING) + grpc_health_v1.RegisterHealthServer(srv, healthServer) + defer srv.Stop() + go func() { _ = srv.Serve(l) }() + + // Connect with unary outer and unary inner interceptors + var trace []string + tracer := func(name string) grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req interface{}, + reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + if strings.HasSuffix(method, "/SignalWorkflowExecution") { + trace = append(trace, "begin "+name) + defer func() { trace = append(trace, "end "+name) }() + } + return invoker(ctx, method, req, reply, cc, opts...) + } + } + client, err := NewClient(ClientOptions{ + HostPort: l.Addr().String(), + ConnectionOptions: ConnectionOptions{ + DialOptions: []grpc.DialOption{ + grpc.WithUnaryInterceptor(tracer("outer")), + grpc.WithChainUnaryInterceptor(tracer("inner1"), tracer("inner2")), + }, + }, + }) + require.NoError(t, err) + defer client.Close() + + // Make call we know will error (ignore error) + _, _ = client.WorkflowService().SignalWorkflowExecution(context.TODO(), + &workflowservice.SignalWorkflowExecutionRequest{}) + + // Confirm trace + expected := []string{"begin outer", "begin inner1", "begin inner2", "end inner2", "end inner1", "end outer"} + require.Equal(t, expected, trace) +} + func TestCustomResolver(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/internal/internal_workflow_client.go b/internal/internal_workflow_client.go index af231faa6..91d245374 100644 --- a/internal/internal_workflow_client.go +++ b/internal/internal_workflow_client.go @@ -730,6 +730,11 @@ func (wc *WorkflowClient) ResetWorkflowExecution(ctx context.Context, request *w return resp, nil } +// WorkflowService implements Client.WorkflowService. +func (wc *WorkflowClient) WorkflowService() workflowservice.WorkflowServiceClient { + return wc.workflowService +} + // Close client and clean up underlying resources. func (wc *WorkflowClient) Close() { if wc.connectionCloser == nil { diff --git a/mocks/Client.go b/mocks/Client.go index 20e97ff56..1dcdba46f 100644 --- a/mocks/Client.go +++ b/mocks/Client.go @@ -25,6 +25,7 @@ // Code generated by mockery v1.0.0. // Modified manually for type alias to work correctly. // https://github.com/vektra/mockery/issues/236 + package mocks import ( @@ -33,10 +34,8 @@ import ( "github.com/stretchr/testify/mock" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/workflowservice/v1" - "go.temporal.io/sdk/client" "go.temporal.io/sdk/converter" - "go.temporal.io/sdk/internal" ) // Client is an autogenerated mock type for the Client type @@ -58,6 +57,11 @@ func (_m *Client) CancelWorkflow(ctx context.Context, workflowID string, runID s return r0 } +// Close provides a mock function with given fields: +func (_m *Client) Close() { + _m.Called() +} + // CompleteActivity provides a mock function with given fields: ctx, taskToken, result, err func (_m *Client) CompleteActivity(ctx context.Context, taskToken []byte, result interface{}, err error) error { ret := _m.Called(ctx, taskToken, result, err) @@ -162,17 +166,17 @@ func (_m *Client) ExecuteWorkflow(ctx context.Context, options client.StartWorkf _ca = append(_ca, args...) ret := _m.Called(_ca...) - var r0 internal.WorkflowRun - if rf, ok := ret.Get(0).(func(context.Context, internal.StartWorkflowOptions, interface{}, ...interface{}) internal.WorkflowRun); ok { + var r0 client.WorkflowRun + if rf, ok := ret.Get(0).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) client.WorkflowRun); ok { r0 = rf(ctx, options, workflow, args...) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(internal.WorkflowRun) + r0 = ret.Get(0).(client.WorkflowRun) } } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, internal.StartWorkflowOptions, interface{}, ...interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) error); ok { r1 = rf(ctx, options, workflow, args...) } else { r1 = ret.Error(1) @@ -224,18 +228,41 @@ func (_m *Client) GetWorkflow(ctx context.Context, workflowID string, runID stri func (_m *Client) GetWorkflowHistory(ctx context.Context, workflowID string, runID string, isLongPoll bool, filterType enumspb.HistoryEventFilterType) client.HistoryEventIterator { ret := _m.Called(ctx, workflowID, runID, isLongPoll, filterType) - var r0 internal.HistoryEventIterator - if rf, ok := ret.Get(0).(func(context.Context, string, string, bool, enumspb.HistoryEventFilterType) internal.HistoryEventIterator); ok { + var r0 client.HistoryEventIterator + if rf, ok := ret.Get(0).(func(context.Context, string, string, bool, enumspb.HistoryEventFilterType) client.HistoryEventIterator); ok { r0 = rf(ctx, workflowID, runID, isLongPoll, filterType) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(internal.HistoryEventIterator) + r0 = ret.Get(0).(client.HistoryEventIterator) } } return r0 } +// ListArchivedWorkflow provides a mock function with given fields: ctx, request +func (_m *Client) ListArchivedWorkflow(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *workflowservice.ListArchivedWorkflowExecutionsResponse + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) *workflowservice.ListArchivedWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListArchivedWorkflowExecutionsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListClosedWorkflow provides a mock function with given fields: ctx, request func (_m *Client) ListClosedWorkflow(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error) { ret := _m.Called(ctx, request) @@ -305,29 +332,6 @@ func (_m *Client) ListWorkflow(ctx context.Context, request *workflowservice.Lis return r0, r1 } -// ListArchivedWorkflow provides a mock function with given fields: ctx, request -func (_m *Client) ListArchivedWorkflow(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) { - ret := _m.Called(ctx, request) - - var r0 *workflowservice.ListArchivedWorkflowExecutionsResponse - if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) *workflowservice.ListArchivedWorkflowExecutionsResponse); ok { - r0 = rf(ctx, request) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*workflowservice.ListArchivedWorkflowExecutionsResponse) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) error); ok { - r1 = rf(ctx, request) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // QueryWorkflow provides a mock function with given fields: ctx, workflowID, runID, queryType, args func (_m *Client) QueryWorkflow(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{}) (converter.EncodedValue, error) { var _ca []interface{} @@ -356,9 +360,7 @@ func (_m *Client) QueryWorkflow(ctx context.Context, workflowID string, runID st // QueryWorkflowWithOptions provides a mock function with given fields: ctx, request func (_m *Client) QueryWorkflowWithOptions(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error) { - var _ca []interface{} - _ca = append(_ca, ctx, request) - ret := _m.Called(_ca...) + ret := _m.Called(ctx, request) var r0 *client.QueryWorkflowWithOptionsResponse if rf, ok := ret.Get(0).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) *client.QueryWorkflowWithOptionsResponse); ok { @@ -413,6 +415,29 @@ func (_m *Client) RecordActivityHeartbeatByID(ctx context.Context, namespace str return r0 } +// ResetWorkflowExecution provides a mock function with given fields: ctx, request +func (_m *Client) ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *workflowservice.ResetWorkflowExecutionResponse + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) *workflowservice.ResetWorkflowExecutionResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ResetWorkflowExecutionResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ScanWorkflow provides a mock function with given fields: ctx, request func (_m *Client) ScanWorkflow(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error) { ret := _m.Called(ctx, request) @@ -478,11 +503,14 @@ func (_m *Client) SignalWorkflow(ctx context.Context, workflowID string, runID s // TerminateWorkflow provides a mock function with given fields: ctx, workflowID, runID, reason, details func (_m *Client) TerminateWorkflow(ctx context.Context, workflowID string, runID string, reason string, details ...interface{}) error { - ret := _m.Called(ctx, workflowID, runID, reason, details) + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, runID, reason) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) error); ok { - r0 = rf(ctx, workflowID, runID, reason, details) + r0 = rf(ctx, workflowID, runID, reason, details...) } else { r0 = ret.Error(0) } @@ -490,34 +518,18 @@ func (_m *Client) TerminateWorkflow(ctx context.Context, workflowID string, runI return r0 } -// ResetWorkflowExecution provides a mock function with given fields: request -func (_m *Client) ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) { - ret := _m.Called(ctx, request) +// WorkflowService provides a mock function with given fields: +func (_m *Client) WorkflowService() workflowservice.WorkflowServiceClient { + ret := _m.Called() - var r0 *workflowservice.ResetWorkflowExecutionResponse - if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) *workflowservice.ResetWorkflowExecutionResponse); ok { - r0 = rf(ctx, request) + var r0 workflowservice.WorkflowServiceClient + if rf, ok := ret.Get(0).(func() workflowservice.WorkflowServiceClient); ok { + r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*workflowservice.ResetWorkflowExecutionResponse) + r0 = ret.Get(0).(workflowservice.WorkflowServiceClient) } } - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) error); ok { - r1 = rf(ctx, request) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Close provides a mock function without given fields -func (_m *Client) Close() { - ret := _m.Called() - - if rf, ok := ret.Get(0).(func()); ok { - rf() - } + return r0 }