diff --git a/internal/gengapic/custom_operation.go b/internal/gengapic/custom_operation.go index 51686a60..f4af5804 100644 --- a/internal/gengapic/custom_operation.go +++ b/internal/gengapic/custom_operation.go @@ -192,8 +192,10 @@ func (g *generator) customOperationType() error { p("}") p("") g.imports[pbinfo.ImportSpec{Path: "context"}] = true - g.imports[pbinfo.ImportSpec{Name: "gax", Path: "github.com/googleapis/gax-go/v2"}] = true g.imports[pbinfo.ImportSpec{Path: "time"}] = true + g.imports[pbinfo.ImportSpec{Name: "gax", Path: "github.com/googleapis/gax-go/v2"}] = true + g.imports[pbinfo.ImportSpec{Path: "github.com/googleapis/gax-go/v2/apierror"}] = true + g.imports[pbinfo.ImportSpec{Path: "google.golang.org/api/googleapi"}] = true for _, handle := range op.handles { pollingParams := op.pollingParams[handle] @@ -204,6 +206,17 @@ func (g *generator) customOperationType() error { poll := operationPollingMethod(handle) pollReq := g.descInfo.Type[poll.GetInputType()].(*descriptor.DescriptorProto) pollNameField := operationResponseField(pollReq, opNameField.GetName()) + // Look up the fields for error code and error message. + errorCodeField := operationField(op.message, extendedops.OperationResponseMapping_ERROR_CODE) + if errorCodeField == nil { + return fmt.Errorf("field %s not found in %T", extendedops.OperationResponseMapping_ERROR_CODE, op) + } + errorCode := snakeToCamel(errorCodeField.GetName()) + errorMessageField := operationField(op.message, extendedops.OperationResponseMapping_ERROR_MESSAGE) + if errorMessageField == nil { + return fmt.Errorf("field %s not found in %T", extendedops.OperationResponseMapping_ERROR_MESSAGE, op) + } + errorMessage := snakeToCamel(errorMessageField.GetName()) // type p("// Implements the %s interface for %s.", handleInt, handle.GetName()) @@ -229,6 +242,19 @@ func (g *generator) customOperationType() error { p(" return err") p(" }") p(" h.proto = resp") + p(" if resp.%[1]s != nil && (resp.Get%[1]s() < 200 || resp.Get%[1]s() > 299) {", errorCode) + p(" aErr := &googleapi.Error{") + p(" Code: int(resp.Get%s()),", errorCode) + if hasField(op.message, "error") { + g.imports[pbinfo.ImportSpec{Path: "fmt"}] = true + p(` Message: fmt.Sprintf("%%s: %%v", resp.Get%s(), resp.GetError()),`, errorMessage) + } else { + p(" Message: resp.Get%s(),", errorMessage) + } + p(" }") + p(" err, _ := apierror.FromError(aErr)") + p(" return err") + p(" }") p(" return nil") p("}") p("") diff --git a/internal/gengapic/custom_operation_test.go b/internal/gengapic/custom_operation_test.go index a030e5fe..c383cd49 100644 --- a/internal/gengapic/custom_operation_test.go +++ b/internal/gengapic/custom_operation_test.go @@ -121,6 +121,38 @@ func TestCustomOpInit(t *testing.T) { } func TestCustomOperationType(t *testing.T) { + errorType := &descriptor.DescriptorProto{ + Name: proto.String("Error"), + Field: []*descriptor.FieldDescriptorProto{ + { + Name: proto.String("nested"), + Type: typep(descriptor.FieldDescriptorProto_TYPE_STRING), + Label: labelp(descriptor.FieldDescriptorProto_LABEL_OPTIONAL), + }, + }, + } + errorField := &descriptor.FieldDescriptorProto{ + Name: proto.String("error"), + Type: typep(descriptor.FieldDescriptorProto_TYPE_MESSAGE), + TypeName: proto.String("Error"), + } + + errorCodeOpts := &descriptor.FieldOptions{} + proto.SetExtension(errorCodeOpts, extendedops.E_OperationField, extendedops.OperationResponseMapping_ERROR_CODE) + errorCodeField := &descriptor.FieldDescriptorProto{ + Name: proto.String("http_error_status_code"), + Type: descriptor.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: errorCodeOpts, + } + + errorMessageOpts := &descriptor.FieldOptions{} + proto.SetExtension(errorMessageOpts, extendedops.E_OperationField, extendedops.OperationResponseMapping_ERROR_MESSAGE) + errorMessageField := &descriptor.FieldDescriptorProto{ + Name: proto.String("http_error_message"), + Type: descriptor.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: errorMessageOpts, + } + nameOpts := &descriptor.FieldOptions{} proto.SetExtension(nameOpts, extendedops.E_OperationField, extendedops.OperationResponseMapping_NAME) nameField := &descriptor.FieldDescriptorProto{ @@ -213,25 +245,31 @@ func TestCustomOperationType(t *testing.T) { }, Type: map[string]pbinfo.ProtoType{ statusEnumField.GetTypeName(): statusEnum, + ".google.cloud.foo.v1.Error": errorType, ".google.cloud.foo.v1.GetFooOperationRequest": getInput, }, }, imports: map[pbinfo.ImportSpec]bool{}, } for _, tst := range []struct { - name string - st *descriptor.FieldDescriptorProto + name string + st *descriptor.FieldDescriptorProto + errorField bool }{ { name: "enum", st: statusEnumField, }, { - name: "bool", - st: statusBoolField, + name: "bool", + st: statusBoolField, + errorField: true, }, } { - op.Field = []*descriptor.FieldDescriptorProto{nameField, tst.st} + op.Field = []*descriptor.FieldDescriptorProto{errorCodeField, errorMessageField, nameField, tst.st} + if tst.errorField { + op.Field = append(op.Field, errorField) + } err := g.customOperationType() if err != nil { t.Fatal(err) diff --git a/internal/gengapic/helpers.go b/internal/gengapic/helpers.go index 67b1d84c..0d5424db 100644 --- a/internal/gengapic/helpers.go +++ b/internal/gengapic/helpers.go @@ -110,6 +110,17 @@ func grpcClientField(reducedServName string) string { return lowerFirst(reducedServName + "Client") } +// hasField returns true if the target DescriptorProto has the given field, +// otherwise it returns false. +func hasField(m *descriptor.DescriptorProto, field string) bool { + for _, f := range m.GetField() { + if f.GetName() == field { + return true + } + } + return false +} + // hasMethod reports if the given service defines an RPC with the same name as // the given simple method name. func hasMethod(service *descriptor.ServiceDescriptorProto, method string) bool { diff --git a/internal/gengapic/helpers_test.go b/internal/gengapic/helpers_test.go index ab0bbc35..332b6d7c 100644 --- a/internal/gengapic/helpers_test.go +++ b/internal/gengapic/helpers_test.go @@ -144,6 +144,26 @@ func TestStrContains(t *testing.T) { } } +func TestHasField(t *testing.T) { + msg := &descriptor.DescriptorProto{ + Field: []*descriptor.FieldDescriptorProto{ + {Name: proto.String("foo")}, + {Name: proto.String("bar")}, + }, + } + for _, tst := range []struct { + in string + want bool + }{ + {in: "foo", want: true}, + {in: "baz"}, + } { + if got := hasField(msg, tst.in); !cmp.Equal(got, tst.want) { + t.Errorf("TestHasField got %v want %v", got, tst.want) + } + } +} + func TestHasMethod(t *testing.T) { serv := &descriptor.ServiceDescriptorProto{ Method: []*descriptor.MethodDescriptorProto{ diff --git a/internal/gengapic/testdata/custom_op_type_bool.want b/internal/gengapic/testdata/custom_op_type_bool.want index 14c90252..18bc8ff8 100644 --- a/internal/gengapic/testdata/custom_op_type_bool.want +++ b/internal/gengapic/testdata/custom_op_type_bool.want @@ -61,6 +61,14 @@ func (h *fooOperationsHandle) Poll(ctx context.Context, opts ...gax.CallOption) return err } h.proto = resp + if resp.HttpErrorStatusCode != nil && (resp.GetHttpErrorStatusCode() < 200 || resp.GetHttpErrorStatusCode() > 299) { + aErr := &googleapi.Error{ + Code: int(resp.GetHttpErrorStatusCode()), + Message: fmt.Sprintf("%s: %v", resp.GetHttpErrorMessage(), resp.GetError()), + } + err, _ := apierror.FromError(aErr) + return err + } return nil } diff --git a/internal/gengapic/testdata/custom_op_type_enum.want b/internal/gengapic/testdata/custom_op_type_enum.want index b747113a..3bb1e3a5 100644 --- a/internal/gengapic/testdata/custom_op_type_enum.want +++ b/internal/gengapic/testdata/custom_op_type_enum.want @@ -61,6 +61,14 @@ func (h *fooOperationsHandle) Poll(ctx context.Context, opts ...gax.CallOption) return err } h.proto = resp + if resp.HttpErrorStatusCode != nil && (resp.GetHttpErrorStatusCode() < 200 || resp.GetHttpErrorStatusCode() > 299) { + aErr := &googleapi.Error{ + Code: int(resp.GetHttpErrorStatusCode()), + Message: resp.GetHttpErrorMessage(), + } + err, _ := apierror.FromError(aErr) + return err + } return nil } diff --git a/rules_go_gapic/go_gapic.bzl b/rules_go_gapic/go_gapic.bzl index 08166b99..f1a52d99 100644 --- a/rules_go_gapic/go_gapic.bzl +++ b/rules_go_gapic/go_gapic.bzl @@ -155,6 +155,7 @@ def go_gapic_library( actual_deps = deps + [ "@com_github_googleapis_gax_go_v2//:go_default_library", + "@com_github_googleapis_gax_go_v2//apierror:go_default_library", "@org_golang_google_api//googleapi:go_default_library", "@org_golang_google_api//option:go_default_library", "@org_golang_google_api//option/internaloption:go_default_library",