diff --git a/micro/example_package_test.go b/micro/example_package_test.go index 113e8825d..7a93e99ec 100644 --- a/micro/example_package_test.go +++ b/micro/example_package_test.go @@ -34,16 +34,15 @@ func Example() { // Service handler is a function which takes Service.Request as argument. // req.Respond or req.Error should be used to respond to the request. - incrementHandler := func(req *Request) error { - val, err := strconv.Atoi(string(req.Data)) + incrementHandler := func(req *Request) { + val, err := strconv.Atoi(string(req.Data())) if err != nil { req.Error("400", "request data should be a number", nil) - return nil + return } responseData := val + 1 req.Respond([]byte(strconv.Itoa(responseData))) - return nil } config := Config{ diff --git a/micro/example_test.go b/micro/example_test.go index 3e50497ac..be34a9670 100644 --- a/micro/example_test.go +++ b/micro/example_test.go @@ -28,9 +28,8 @@ func ExampleAddService() { } defer nc.Close() - echoHandler := func(req *Request) error { - req.Respond(req.Data) - return nil + echoHandler := func(req *Request) { + req.Respond(req.Data()) } config := Config{ @@ -73,7 +72,7 @@ func ExampleService_Info() { Name: "EchoService", Endpoint: Endpoint{ Subject: "echo", - Handler: func(*Request) error { return nil }, + Handler: func(*Request) {}, }, } @@ -101,7 +100,7 @@ func ExampleService_Stats() { Version: "0.1.0", Endpoint: Endpoint{ Subject: "echo", - Handler: func(*Request) error { return nil }, + Handler: func(*Request) {}, }, } @@ -127,7 +126,7 @@ func ExampleService_Stop() { Version: "0.1.0", Endpoint: Endpoint{ Subject: "echo", - Handler: func(*Request) error { return nil }, + Handler: func(*Request) {}, }, } @@ -158,7 +157,7 @@ func ExampleService_Stopped() { Version: "0.1.0", Endpoint: Endpoint{ Subject: "echo", - Handler: func(*Request) error { return nil }, + Handler: func(*Request) {}, }, } @@ -187,7 +186,7 @@ func ExampleService_Reset() { Version: "0.1.0", Endpoint: Endpoint{ Subject: "echo", - Handler: func(*Request) error { return nil }, + Handler: func(*Request) {}, }, } @@ -220,14 +219,14 @@ func ExampleControlSubject() { // Output: // $SRV.PING - // $SRV.PING.COOLSERVICE - // $SRV.PING.COOLSERVICE.123 + // $SRV.PING.CoolService + // $SRV.PING.CoolService.123 } func ExampleRequest_Respond() { handler := func(req *Request) { // respond to the request - if err := req.Respond(req.Data); err != nil { + if err := req.Respond(req.Data()); err != nil { log.Fatal(err) } } @@ -254,13 +253,12 @@ func ExampleRequest_RespondJSON() { } func ExampleRequest_Error() { - handler := func(req *Request) error { + handler := func(req *Request) { // respond with an error // Error sets Nats-Service-Error and Nats-Service-Error-Code headers in the response if err := req.Error("400", "bad request", []byte(`{"error": "value should be a number"}`)); err != nil { - return err + log.Fatal(err) } - return nil } fmt.Printf("%T", handler) diff --git a/micro/request.go b/micro/request.go index f7cc4a580..321f0d88e 100644 --- a/micro/request.go +++ b/micro/request.go @@ -22,17 +22,19 @@ import ( ) type ( + // Request represents service request available in the service handler. + // It exposes methods to respond to the request, as well as + // getting the request data and headers. Request struct { - *nats.Msg + msg *nats.Msg + respondError error } // RequestHandler is a function used as a Handler for a service. - // It takes a request, which contains the data (payload and headers) of the request, - // as well as exposes methods to respond to the request. - // - // RequestHandler returns an error - if returned, the request will be accounted form in stats (in num_requests), - // and last_error will be set with the value. - RequestHandler func(*Request) error + RequestHandler func(*Request) + + // Headers is a wrapper around [*nats.Header] + Headers nats.Header ) var ( @@ -41,27 +43,37 @@ var ( ErrArgRequired = errors.New("argument required") ) -func (r *Request) Respond(response []byte) error { - if err := r.Msg.Respond(response); err != nil { - return fmt.Errorf("%w: %s", ErrRespond, err) +// RespondOpt is a +type RespondOpt func(*nats.Msg) + +func (r *Request) Respond(response []byte, opts ...RespondOpt) error { + respMsg := &nats.Msg{ + Data: response, + } + for _, opt := range opts { + opt(respMsg) + } + + if err := r.msg.RespondMsg(respMsg); err != nil { + r.respondError = fmt.Errorf("%w: %s", ErrRespond, err) + return r.respondError } return nil } -func (r *Request) RespondJSON(response interface{}) error { +func (r *Request) RespondJSON(response interface{}, opts ...RespondOpt) error { resp, err := json.Marshal(response) if err != nil { return ErrMarshalResponse } - - return r.Respond(resp) + return r.Respond(resp, opts...) } // Error prepares and publishes error response from a handler. // A response error should be set containing an error code and description. // Optionally, data can be set as response payload. -func (r *Request) Error(code, description string, data []byte) error { +func (r *Request) Error(code, description string, data []byte, opts ...RespondOpt) error { if code == "" { return fmt.Errorf("%w: error code", ErrArgRequired) } @@ -74,6 +86,47 @@ func (r *Request) Error(code, description string, data []byte) error { ErrorCodeHeader: []string{code}, }, } + for _, opt := range opts { + opt(response) + } + response.Data = data - return r.RespondMsg(response) + if err := r.msg.RespondMsg(response); err != nil { + r.respondError = err + return err + } + return nil +} + +func WithHeaders(headers Headers) RespondOpt { + return func(m *nats.Msg) { + if m.Header == nil { + m.Header = nats.Header(headers) + return + } + + for k, v := range headers { + m.Header[k] = v + } + } +} + +func (r *Request) Data() []byte { + return r.msg.Data +} + +func (r *Request) Headers() Headers { + return Headers(r.msg.Header) +} + +// Get gets the first value associated with the given key. +// It is case-sensitive. +func (h Headers) Get(key string) string { + return nats.Header(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case-sensitive. +func (h Headers) Values(key string) []string { + return nats.Header(h).Values(key) } diff --git a/micro/service.go b/micro/service.go index 6943f170b..5e0c2ad71 100644 --- a/micro/service.go +++ b/micro/service.go @@ -18,7 +18,6 @@ import ( "errors" "fmt" "regexp" - "strings" "sync" "time" @@ -252,7 +251,7 @@ func AddService(nc *nats.Conn, config Config) (Service, error) { var err error svc.reqSub, err = nc.QueueSubscribe(config.Endpoint.Subject, QG, func(m *nats.Msg) { - svc.reqHandler(&Request{Msg: m}) + svc.reqHandler(&Request{msg: m}) }) if err != nil { svc.asyncDispatcher.close() @@ -261,48 +260,44 @@ func AddService(nc *nats.Conn, config Config) (Service, error) { ping := Ping(svcIdentity) - infoHandler := func(req *Request) error { + infoHandler := func(req *Request) { response, _ := json.Marshal(svc.Info()) if err := req.Respond(response); err != nil { if err := req.Error("500", fmt.Sprintf("Error handling INFO request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject, err.Error()}) }) + svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.msg.Subject, err.Error()}) }) } } - return nil } - pingHandler := func(req *Request) error { + pingHandler := func(req *Request) { response, _ := json.Marshal(ping) if err := req.Respond(response); err != nil { if err := req.Error("500", fmt.Sprintf("Error handling PING request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject, err.Error()}) }) + svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.msg.Subject, err.Error()}) }) } } - return nil } - statsHandler := func(req *Request) error { + statsHandler := func(req *Request) { response, _ := json.Marshal(svc.Stats()) if err := req.Respond(response); err != nil { if err := req.Error("500", fmt.Sprintf("Error handling STATS request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject, err.Error()}) }) + svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.msg.Subject, err.Error()}) }) } } - return nil } schema := SchemaResp{ ServiceIdentity: svcIdentity, Schema: config.Schema, } - schemaHandler := func(req *Request) error { + schemaHandler := func(req *Request) { response, _ := json.Marshal(schema) if err := req.Respond(response); err != nil { if err := req.Error("500", fmt.Sprintf("Error handling SCHEMA request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject, err.Error()}) }) + svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.msg.Subject, err.Error()}) }) } } - return nil } if err := svc.verbHandlers(nc, InfoVerb, infoHandler); err != nil { @@ -368,6 +363,8 @@ func (e *Endpoint) valid() error { } func (svc *service) setupAsyncCallbacks() { + svc.m.Lock() + defer svc.m.Unlock() svc.natsHandlers.closed = svc.conn.ClosedHandler() if svc.natsHandlers.closed != nil { svc.conn.SetClosedHandler(func(c *nats.Conn) { @@ -392,6 +389,10 @@ func (svc *service) setupAsyncCallbacks() { Description: err.Error(), }) } + svc.m.Lock() + svc.stats.NumErrors++ + svc.stats.LastError = err.Error() + svc.m.Unlock() svc.Stop() svc.natsHandlers.asyncErr(c, s, err) }) @@ -406,6 +407,10 @@ func (svc *service) setupAsyncCallbacks() { Description: err.Error(), }) } + svc.m.Lock() + svc.stats.NumErrors++ + svc.stats.LastError = err.Error() + svc.m.Unlock() svc.Stop() }) } @@ -448,7 +453,7 @@ func (s *service) addInternalHandler(nc *nats.Conn, verb Verb, kind, id, name st } s.verbSubs[name], err = nc.Subscribe(subj, func(msg *nats.Msg) { - handler(&Request{Msg: msg}) + handler(&Request{msg: msg}) }) if err != nil { s.Stop() @@ -457,19 +462,19 @@ func (s *service) addInternalHandler(nc *nats.Conn, verb Verb, kind, id, name st return nil } -// reqHandler itself +// reqHandller invokes the service request handler and modifies service stats func (s *service) reqHandler(req *Request) { start := time.Now() - err := s.Endpoint.Handler(req) + s.Endpoint.Handler(req) s.m.Lock() s.stats.NumRequests++ s.stats.ProcessingTime += time.Since(start) avgProcessingTime := s.stats.ProcessingTime.Nanoseconds() / int64(s.stats.NumRequests) s.stats.AverageProcessingTime = time.Duration(avgProcessingTime) - if err != nil { + if req.respondError != nil { s.stats.NumErrors++ - s.stats.LastError = err.Error() + s.stats.LastError = req.respondError.Error() } s.m.Unlock() } @@ -577,7 +582,6 @@ func ControlSubject(verb Verb, name, id string) (string, error) { if name == "" && id != "" { return "", ErrServiceNameRequired } - name = strings.ToUpper(name) if name == "" && id == "" { return fmt.Sprintf("%s.%s", APIPrefix, verbStr), nil } diff --git a/micro/service_test.go b/micro/service_test.go index 4cf4ccd0c..89e1d1882 100644 --- a/micro/service_test.go +++ b/micro/service_test.go @@ -39,14 +39,12 @@ func TestServiceBasics(t *testing.T) { defer nc.Close() // Stub service. - doAdd := func(req *Request) error { + doAdd := func(req *Request) { if rand.Intn(10) == 0 { - if err := req.Error("400", "client error!", nil); err != nil { + if err := req.Error("500", "Unexpected error!", nil); err != nil { t.Fatalf("Unexpected error when sending error response: %v", err) } - - // for client-side errors, return nil to avoid tracking the errors in stats - return nil + return } // Happy Path. // Random delay between 5-10ms @@ -55,9 +53,8 @@ func TestServiceBasics(t *testing.T) { if err := req.Error("500", "Unexpected error!", nil); err != nil { t.Fatalf("Unexpected error when sending error response: %v", err) } - return err + return } - return nil } var svcs []Service @@ -194,7 +191,7 @@ func TestServiceBasics(t *testing.T) { } func TestAddService(t *testing.T) { - testHandler := func(*Request) error { return nil } + testHandler := func(*Request) {} errNats := make(chan struct{}) errService := make(chan struct{}) closedNats := make(chan struct{}) @@ -533,7 +530,7 @@ func TestMonitoringHandlers(t *testing.T) { Version: "0.1.0", Endpoint: Endpoint{ Subject: "test.sub", - Handler: func(*Request) error { return nil }, + Handler: func(*Request) {}, }, Schema: Schema{ Request: "some_schema", @@ -570,7 +567,7 @@ func TestMonitoringHandlers(t *testing.T) { }, { name: "PING name", - subject: "$SRV.PING.TEST_SERVICE", + subject: "$SRV.PING.test_service", expectedResponse: Ping{ Name: "test_service", Version: "0.1.0", @@ -579,7 +576,7 @@ func TestMonitoringHandlers(t *testing.T) { }, { name: "PING ID", - subject: fmt.Sprintf("$SRV.PING.TEST_SERVICE.%s", info.ID), + subject: fmt.Sprintf("$SRV.PING.test_service.%s", info.ID), expectedResponse: Ping{ Name: "test_service", Version: "0.1.0", @@ -600,7 +597,7 @@ func TestMonitoringHandlers(t *testing.T) { }, { name: "INFO name", - subject: "$SRV.INFO.TEST_SERVICE", + subject: "$SRV.INFO.test_service", expectedResponse: Info{ ServiceIdentity: ServiceIdentity{ Name: "test_service", @@ -612,7 +609,7 @@ func TestMonitoringHandlers(t *testing.T) { }, { name: "INFO ID", - subject: fmt.Sprintf("$SRV.INFO.TEST_SERVICE.%s", info.ID), + subject: fmt.Sprintf("$SRV.INFO.test_service.%s", info.ID), expectedResponse: Info{ ServiceIdentity: ServiceIdentity{ Name: "test_service", @@ -638,7 +635,7 @@ func TestMonitoringHandlers(t *testing.T) { }, { name: "SCHEMA name", - subject: "$SRV.SCHEMA.TEST_SERVICE", + subject: "$SRV.SCHEMA.test_service", expectedResponse: SchemaResp{ ServiceIdentity: ServiceIdentity{ Name: "test_service", @@ -652,7 +649,7 @@ func TestMonitoringHandlers(t *testing.T) { }, { name: "SCHEMA ID", - subject: fmt.Sprintf("$SRV.SCHEMA.TEST_SERVICE.%s", info.ID), + subject: fmt.Sprintf("$SRV.SCHEMA.test_service.%s", info.ID), expectedResponse: SchemaResp{ ServiceIdentity: ServiceIdentity{ Name: "test_service", @@ -727,19 +724,8 @@ func TestMonitoringHandlers(t *testing.T) { } func TestServiceStats(t *testing.T) { - handler := func(r *Request) error { - if bytes.Equal(r.Data, []byte("err")) { - r.Error("500", "oops", nil) - return fmt.Errorf("oops") - } - - // client errors (validation etc.) should not be accounted for in stats - if bytes.Equal(r.Data, []byte("client_err")) { - r.Error("400", "bad request", nil) - return nil - } + handler := func(r *Request) { r.Respond([]byte("ok")) - return nil } tests := []struct { name string @@ -835,15 +821,16 @@ func TestServiceStats(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } } - if _, err := nc.Request(srv.Info().Subject, []byte("client_err"), time.Second); err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if _, err := nc.Request(srv.Info().Subject, []byte("err"), time.Second); err != nil { + + // Malformed request, missing reply subjtct + // This should be reflected in errors + if err := nc.Publish(srv.Info().Subject, []byte("err")); err != nil { t.Fatalf("Unexpected error: %v", err) } + time.Sleep(10 * time.Millisecond) info := srv.Info() - resp, err := nc.Request(fmt.Sprintf("$SRV.STATS.TEST_SERVICE.%s", info.ID), nil, 1*time.Second) + resp, err := nc.Request(fmt.Sprintf("$SRV.STATS.test_service.%s", info.ID), nil, 1*time.Second) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -859,11 +846,11 @@ func TestServiceStats(t *testing.T) { if stats.ID != info.ID { t.Errorf("Unexpected service name; want: %s; got: %s", info.ID, stats.ID) } - if stats.NumRequests != 12 { - t.Errorf("Unexpected num_requests; want: 12; got: %d", stats.NumRequests) + if stats.NumRequests != 11 { + t.Errorf("Unexpected num_requests; want: 11; got: %d", stats.NumRequests) } if stats.NumErrors != 1 { - t.Errorf("Unexpected num_requests; want: 1; got: %d", stats.NumErrors) + t.Errorf("Unexpected num_errors; want: 1; got: %d", stats.NumErrors) } if test.expectedStats != nil { var data map[string]interface{} @@ -887,6 +874,7 @@ func TestRequestRespond(t *testing.T) { tests := []struct { name string respondData interface{} + respondHeaders Headers errDescription string errCode string errData []byte @@ -900,6 +888,12 @@ func TestRequestRespond(t *testing.T) { respondData: []byte("OK"), expectedResponse: []byte("OK"), }, + { + name: "byte response, with headers", + respondHeaders: Headers{"key": []string{"value"}}, + respondData: []byte("OK"), + expectedResponse: []byte("OK"), + }, { name: "byte response, connection closed", respondData: []byte("OK"), @@ -923,6 +917,15 @@ func TestRequestRespond(t *testing.T) { expectedMessage: "oops", expectedCode: "500", }, + { + name: "generic error, with headers", + respondHeaders: Headers{"key": []string{"value"}}, + errDescription: "oops", + errCode: "500", + errData: []byte("error!"), + expectedMessage: "oops", + expectedCode: "500", + }, { name: "error without response payload", errDescription: "oops", @@ -958,57 +961,58 @@ func TestRequestRespond(t *testing.T) { errCode := test.errCode errDesc := test.errDescription errData := test.errData - // Stub service. - handler := func(req *Request) error { + handler := func(req *Request) { if errors.Is(test.withRespondError, ErrRespond) { nc.Close() } + if val := req.Headers().Get("key"); val != "value" { + t.Fatalf("Expected headers in the request") + } if errCode == "" && errDesc == "" { if resp, ok := respData.([]byte); ok { - err := req.Respond(resp) + err := req.Respond(resp, WithHeaders(test.respondHeaders)) if respError != nil { if !errors.Is(err, respError) { t.Fatalf("Expected error: %v; got: %v", respError, err) } - return nil + return } if err != nil { t.Fatalf("Unexpected error when sending response: %v", err) } } else { - err := req.RespondJSON(respData) + err := req.RespondJSON(respData, WithHeaders(test.respondHeaders)) if respError != nil { if !errors.Is(err, respError) { t.Fatalf("Expected error: %v; got: %v", respError, err) } - return nil + return } if err != nil { t.Fatalf("Unexpected error when sending response: %v", err) } } - return nil + return } - err := req.Error(errCode, errDesc, errData) + err := req.Error(errCode, errDesc, errData, WithHeaders(test.respondHeaders)) if respError != nil { if !errors.Is(err, respError) { t.Fatalf("Expected error: %v; got: %v", respError, err) } - return nil + return } if err != nil { t.Fatalf("Unexpected error when sending response: %v", err) } - return nil } svc, err := AddService(nc, Config{ Name: "CoolService", Version: "0.1.0", - Description: "Erroring service", + Description: "test service", Endpoint: Endpoint{ - Subject: "svc.fail", + Subject: "svc.test", Handler: handler, }, }) @@ -1017,7 +1021,11 @@ func TestRequestRespond(t *testing.T) { } defer svc.Stop() - resp, err := nc.Request("svc.fail", nil, 50*time.Millisecond) + resp, err := nc.RequestMsg(&nats.Msg{ + Subject: svc.Info().Subject, + Data: nil, + Header: nats.Header{"key": []string{"value"}}, + }, 50*time.Millisecond) if test.withRespondError != nil { return } @@ -1030,12 +1038,15 @@ func TestRequestRespond(t *testing.T) { if description != test.expectedMessage { t.Fatalf("Invalid response message; want: %q; got: %q", test.expectedMessage, description) } - code := resp.Header.Get("Nats-Service-Error-Code") - if code != test.expectedCode { - t.Fatalf("Invalid response code; want: %q; got: %q", test.expectedCode, code) + expectedHeaders := Headers{ + "Nats-Service-Error-Code": []string{resp.Header.Get("Nats-Service-Error-Code")}, + "Nats-Service-Error": []string{resp.Header.Get("Nats-Service-Error")}, } - if !bytes.Equal(resp.Data, test.errData) { - t.Fatalf("Invalid response payload; want: %q; got: %q", string(test.errData), resp.Data) + for k, v := range test.respondHeaders { + expectedHeaders[k] = v + } + if !reflect.DeepEqual(expectedHeaders, Headers(resp.Header)) { + t.Fatalf("Invalid response headers; want: %v; got: %v", test.respondHeaders, resp.Header) } return } @@ -1047,6 +1058,10 @@ func TestRequestRespond(t *testing.T) { if !bytes.Equal(bytes.TrimSpace(resp.Data), bytes.TrimSpace(test.expectedResponse)) { t.Fatalf("Invalid response; want: %s; got: %s", string(test.expectedResponse), string(resp.Data)) } + + if !reflect.DeepEqual(test.respondHeaders, Headers(resp.Header)) { + t.Fatalf("Invalid response headers; want: %v; got: %v", test.respondHeaders, resp.Header) + } }) } } @@ -1079,14 +1094,14 @@ func TestControlSubject(t *testing.T) { name: "PING name", verb: PingVerb, srvName: "test", - expectedSubject: "$SRV.PING.TEST", + expectedSubject: "$SRV.PING.test", }, { name: "PING id", verb: PingVerb, srvName: "test", id: "123", - expectedSubject: "$SRV.PING.TEST.123", + expectedSubject: "$SRV.PING.test.123", }, { name: "invalid verb",