diff --git a/micro/service.go b/micro/service.go index ed151b893..6fca3862a 100644 --- a/micro/service.go +++ b/micro/service.go @@ -325,7 +325,7 @@ func (s Verb) String() string { // A service name, version and Endpoint configuration are required to add a service. // AddService returns a [Service] interface, allowing service management. // Each service is assigned a unique ID. -func AddService(nc *nats.Conn, config Config) (Service, error) { +func AddService(nc *nats.Conn, config Config) (_ Service, err error) { if err := config.valid(); err != nil { return nil, err } @@ -342,9 +342,16 @@ func AddService(nc *nats.Conn, config Config) (Service, error) { endpoints: make([]*Endpoint, 0), } - svc.setupAsyncCallbacks() - - go svc.asyncDispatcher.asyncCBDispatcher() + // Add connection event (closed, error) wrapper handlers. If the service has + // custom callbacks, the events are queued and invoked by the same + // goroutine, starting now. + go svc.asyncDispatcher.run() + defer func() { + if err != nil { + svc.asyncDispatcher.close() + } + }() + svc.wrapConnectionEventCallbacks() if config.Endpoint != nil { opts := []EndpointOpt{WithEndpointSubject(config.Endpoint.Subject)} @@ -355,70 +362,39 @@ func AddService(nc *nats.Conn, config Config) (Service, error) { opts = append(opts, WithEndpointMetadata(config.Endpoint.Metadata)) } if err := svc.AddEndpoint("default", config.Endpoint.Handler, opts...); err != nil { - svc.asyncDispatcher.close() return nil, err } } // Setup internal subscriptions. - infoHandler := func(req Request) { - response, _ := json.Marshal(svc.Info()) - if err := req.Respond(response); err != nil { - if err := req.Error("500", fmt.Sprintf("Error handling INFO request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) }) - } - } - } - - ping := Ping{ + pingResponse := Ping{ ServiceIdentity: svc.serviceIdentity(), Type: PingResponseType, } - pingHandler := func(req Request) { - response, _ := json.Marshal(ping) - if err := req.Respond(response); err != nil { - if err := req.Error("500", fmt.Sprintf("Error handling PING request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) }) - } - } - } - statsHandler := func(req Request) { - response, _ := json.Marshal(svc.Stats()) - if err := req.Respond(response); err != nil { - if err := req.Error("500", fmt.Sprintf("Error handling STATS request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) }) + handleVerb := func(verb Verb, valuef func() any) func(req Request) { + return func(req Request) { + response, _ := json.Marshal(valuef()) + if err := req.Respond(response); err != nil { + if err := req.Error("500", fmt.Sprintf("Error handling %s request: %s", verb, err), nil); err != nil && config.ErrorHandler != nil { + svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) }) + } } } } - schemaHandler := func(req Request) { - response, _ := json.Marshal(svc.schema()) - if err := req.Respond(response); err != nil { - if err := req.Error("500", fmt.Sprintf("Error handling SCHEMA request: %s", err), nil); err != nil && config.ErrorHandler != nil { - svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) }) - } + for verb, source := range map[Verb]func() any{ + InfoVerb: func() any { return svc.Info() }, + PingVerb: func() any { return pingResponse }, + StatsVerb: func() any { return svc.Stats() }, + SchemaVerb: func() any { return svc.schema() }, + } { + handler := handleVerb(verb, source) + if err := svc.addVerbHandlers(nc, verb, handler); err != nil { + return nil, err } } - if err := svc.verbHandlers(nc, InfoVerb, infoHandler); err != nil { - svc.asyncDispatcher.close() - return nil, err - } - if err := svc.verbHandlers(nc, PingVerb, pingHandler); err != nil { - svc.asyncDispatcher.close() - return nil, err - } - if err := svc.verbHandlers(nc, StatsVerb, statsHandler); err != nil { - svc.asyncDispatcher.close() - return nil, err - } - - if err := svc.verbHandlers(nc, SchemaVerb, schemaHandler); err != nil { - svc.asyncDispatcher.close() - return nil, err - } - svc.started = time.Now().UTC() return svc, nil } @@ -481,7 +457,7 @@ func (s *service) AddGroup(name string) Group { } // dispatch is responsible for calling any async callbacks -func (ac *asyncCallbacksHandler) asyncCBDispatcher() { +func (ac *asyncCallbacksHandler) run() { for { f := <-ac.cbQueue if f == nil { @@ -513,7 +489,7 @@ func (c *Config) valid() error { return nil } -func (s *service) setupAsyncCallbacks() { +func (s *service) wrapConnectionEventCallbacks() { s.m.Lock() defer s.m.Unlock() s.natsHandlers.closed = s.nc.ClosedHandler() @@ -574,6 +550,11 @@ func (s *service) setupAsyncCallbacks() { } } +func unwrapConnectionEventCallbacks(nc *nats.Conn, handlers handlers) { + nc.SetClosedHandler(handlers.closed) + nc.SetErrorHandler(handlers.asyncErr) +} + func (s *service) matchSubscriptionSubject(subj string) (*Endpoint, bool) { s.m.Lock() defer s.m.Unlock() @@ -607,11 +588,11 @@ func matchEndpointSubject(endpointSubject, literalSubject string) bool { return true } -// verbHandlers generates control handlers for a specific verb. +// addVerbHandlers generates control handlers for a specific verb. // Each request generates 3 subscriptions, one for the general verb // affecting all services written with the framework, one that handles // all services of a particular kind, and finally a specific service instance. -func (svc *service) verbHandlers(nc *nats.Conn, verb Verb, handler HandlerFunc) error { +func (svc *service) addVerbHandlers(nc *nats.Conn, verb Verb, handler HandlerFunc) error { name := fmt.Sprintf("%s-all", verb.String()) if err := svc.addInternalHandler(nc, verb, "", "", name, handler); err != nil { return err @@ -680,7 +661,7 @@ func (s *service) Stop() error { for _, key := range keys { delete(s.verbSubs, key) } - restoreAsyncHandlers(s.nc, s.natsHandlers) + unwrapConnectionEventCallbacks(s.nc, s.natsHandlers) s.stopped = true if s.DoneHandler != nil { s.asyncDispatcher.push(func() { s.DoneHandler(s) }) @@ -689,11 +670,6 @@ func (s *service) Stop() error { return nil } -func restoreAsyncHandlers(nc *nats.Conn, handlers handlers) { - nc.SetClosedHandler(handlers.closed) - nc.SetErrorHandler(handlers.asyncErr) -} - func (s *service) serviceIdentity() ServiceIdentity { return ServiceIdentity{ Name: s.Config.Name,