Skip to content

Commit

Permalink
Merge 433a844 into 9f57f78
Browse files Browse the repository at this point in the history
  • Loading branch information
levb committed Mar 27, 2023
2 parents 9f57f78 + 433a844 commit 7a9fc6c
Showing 1 changed file with 38 additions and 62 deletions.
100 changes: 38 additions & 62 deletions micro/service.go
Expand Up @@ -325,7 +325,7 @@ func (s Verb) String() string {
// A service name, version and Endpoint configuration are required to add a service.
// AddService returns a [Service] interface, allowing service management.
// Each service is assigned a unique ID.
func AddService(nc *nats.Conn, config Config) (Service, error) {
func AddService(nc *nats.Conn, config Config) (_ Service, err error) {
if err := config.valid(); err != nil {
return nil, err
}
Expand All @@ -342,9 +342,16 @@ func AddService(nc *nats.Conn, config Config) (Service, error) {
endpoints: make([]*Endpoint, 0),
}

svc.setupAsyncCallbacks()

go svc.asyncDispatcher.asyncCBDispatcher()
// Add connection event (closed, error) wrapper handlers. If the service has
// custom callbacks, the events are queued and invoked by the same
// goroutine, starting now.
go svc.asyncDispatcher.run()
defer func() {
if err != nil {
svc.asyncDispatcher.close()
}
}()
svc.wrapConnectionEventCallbacks()

if config.Endpoint != nil {
opts := []EndpointOpt{WithEndpointSubject(config.Endpoint.Subject)}
Expand All @@ -355,70 +362,39 @@ func AddService(nc *nats.Conn, config Config) (Service, error) {
opts = append(opts, WithEndpointMetadata(config.Endpoint.Metadata))
}
if err := svc.AddEndpoint("default", config.Endpoint.Handler, opts...); err != nil {
svc.asyncDispatcher.close()
return nil, err
}
}

// Setup internal subscriptions.
infoHandler := func(req Request) {
response, _ := json.Marshal(svc.Info())
if err := req.Respond(response); err != nil {
if err := req.Error("500", fmt.Sprintf("Error handling INFO request: %s", err), nil); err != nil && config.ErrorHandler != nil {
svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) })
}
}
}

ping := Ping{
pingResponse := Ping{
ServiceIdentity: svc.serviceIdentity(),
Type: PingResponseType,
}
pingHandler := func(req Request) {
response, _ := json.Marshal(ping)
if err := req.Respond(response); err != nil {
if err := req.Error("500", fmt.Sprintf("Error handling PING request: %s", err), nil); err != nil && config.ErrorHandler != nil {
svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) })
}
}
}

statsHandler := func(req Request) {
response, _ := json.Marshal(svc.Stats())
if err := req.Respond(response); err != nil {
if err := req.Error("500", fmt.Sprintf("Error handling STATS request: %s", err), nil); err != nil && config.ErrorHandler != nil {
svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) })
handleVerb := func(verb Verb, valuef func() any) func(req Request) {
return func(req Request) {
response, _ := json.Marshal(valuef())
if err := req.Respond(response); err != nil {
if err := req.Error("500", fmt.Sprintf("Error handling %s request: %s", verb, err), nil); err != nil && config.ErrorHandler != nil {
svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) })
}
}
}
}

schemaHandler := func(req Request) {
response, _ := json.Marshal(svc.schema())
if err := req.Respond(response); err != nil {
if err := req.Error("500", fmt.Sprintf("Error handling SCHEMA request: %s", err), nil); err != nil && config.ErrorHandler != nil {
svc.asyncDispatcher.push(func() { config.ErrorHandler(svc, &NATSError{req.Subject(), err.Error()}) })
}
for verb, source := range map[Verb]func() any{
InfoVerb: func() any { return svc.Info() },
PingVerb: func() any { return pingResponse },
StatsVerb: func() any { return svc.Stats() },
SchemaVerb: func() any { return svc.schema() },
} {
handler := handleVerb(verb, source)
if err := svc.addVerbHandlers(nc, verb, handler); err != nil {
return nil, err
}
}

if err := svc.verbHandlers(nc, InfoVerb, infoHandler); err != nil {
svc.asyncDispatcher.close()
return nil, err
}
if err := svc.verbHandlers(nc, PingVerb, pingHandler); err != nil {
svc.asyncDispatcher.close()
return nil, err
}
if err := svc.verbHandlers(nc, StatsVerb, statsHandler); err != nil {
svc.asyncDispatcher.close()
return nil, err
}

if err := svc.verbHandlers(nc, SchemaVerb, schemaHandler); err != nil {
svc.asyncDispatcher.close()
return nil, err
}

svc.started = time.Now().UTC()
return svc, nil
}
Expand Down Expand Up @@ -481,7 +457,7 @@ func (s *service) AddGroup(name string) Group {
}

// dispatch is responsible for calling any async callbacks
func (ac *asyncCallbacksHandler) asyncCBDispatcher() {
func (ac *asyncCallbacksHandler) run() {
for {
f := <-ac.cbQueue
if f == nil {
Expand Down Expand Up @@ -513,7 +489,7 @@ func (c *Config) valid() error {
return nil
}

func (s *service) setupAsyncCallbacks() {
func (s *service) wrapConnectionEventCallbacks() {
s.m.Lock()
defer s.m.Unlock()
s.natsHandlers.closed = s.nc.ClosedHandler()
Expand Down Expand Up @@ -574,6 +550,11 @@ func (s *service) setupAsyncCallbacks() {
}
}

func unwrapConnectionEventCallbacks(nc *nats.Conn, handlers handlers) {
nc.SetClosedHandler(handlers.closed)
nc.SetErrorHandler(handlers.asyncErr)
}

func (s *service) matchSubscriptionSubject(subj string) (*Endpoint, bool) {
s.m.Lock()
defer s.m.Unlock()
Expand Down Expand Up @@ -607,11 +588,11 @@ func matchEndpointSubject(endpointSubject, literalSubject string) bool {
return true
}

// verbHandlers generates control handlers for a specific verb.
// addVerbHandlers generates control handlers for a specific verb.
// Each request generates 3 subscriptions, one for the general verb
// affecting all services written with the framework, one that handles
// all services of a particular kind, and finally a specific service instance.
func (svc *service) verbHandlers(nc *nats.Conn, verb Verb, handler HandlerFunc) error {
func (svc *service) addVerbHandlers(nc *nats.Conn, verb Verb, handler HandlerFunc) error {
name := fmt.Sprintf("%s-all", verb.String())
if err := svc.addInternalHandler(nc, verb, "", "", name, handler); err != nil {
return err
Expand Down Expand Up @@ -680,7 +661,7 @@ func (s *service) Stop() error {
for _, key := range keys {
delete(s.verbSubs, key)
}
restoreAsyncHandlers(s.nc, s.natsHandlers)
unwrapConnectionEventCallbacks(s.nc, s.natsHandlers)
s.stopped = true
if s.DoneHandler != nil {
s.asyncDispatcher.push(func() { s.DoneHandler(s) })
Expand All @@ -689,11 +670,6 @@ func (s *service) Stop() error {
return nil
}

func restoreAsyncHandlers(nc *nats.Conn, handlers handlers) {
nc.SetClosedHandler(handlers.closed)
nc.SetErrorHandler(handlers.asyncErr)
}

func (s *service) serviceIdentity() ServiceIdentity {
return ServiceIdentity{
Name: s.Config.Name,
Expand Down

0 comments on commit 7a9fc6c

Please sign in to comment.