diff --git a/multistream.go b/multistream.go index 4ac48fe..23d2c9f 100644 --- a/multistream.go +++ b/multistream.go @@ -23,25 +23,16 @@ const ProtocolID = "/multistream/1.0.0" // handle a protocol/stream. type HandlerFunc func(protocol string, rwc io.ReadWriteCloser) error -// Handler is a wrapper to HandlerFunc which attaches a name (protocol) and a -// match function which can optionally be used to select a handler by other -// means than the name. -type Handler struct { - MatchFunc func(string) bool - Handle HandlerFunc - AddName string -} - // MultistreamMuxer is a muxer for multistream. Depending on the stream // protocol tag it will select the right handler and hand the stream off to it. type MultistreamMuxer struct { handlerlock sync.Mutex - handlers []Handler + handlers map[string]HandlerFunc } // NewMultistreamMuxer creates a muxer. func NewMultistreamMuxer() *MultistreamMuxer { - return new(MultistreamMuxer) + return &MultistreamMuxer{handlers: make(map[string]HandlerFunc)} } func writeUvarint(w io.Writer, i uint64) error { @@ -107,54 +98,26 @@ func Ls(rw io.ReadWriter) ([]string, error) { return out, nil } -func fulltextMatch(s string) func(string) bool { - return func(a string) bool { - return a == s - } -} - // AddHandler attaches a new protocol handler to the muxer. func (msm *MultistreamMuxer) AddHandler(protocol string, handler HandlerFunc) { - msm.AddHandlerWithFunc(protocol, fulltextMatch(protocol), handler) -} - -// AddHandlerWithFunc attaches a new protocol handler to the muxer with a match. -// If the match function returns true for a given protocol tag, the protocol -// will be selected even if the handler name and protocol tags are different. -func (msm *MultistreamMuxer) AddHandlerWithFunc(protocol string, match func(string) bool, handler HandlerFunc) { msm.handlerlock.Lock() - msm.removeHandler(protocol) - msm.handlers = append(msm.handlers, Handler{ - MatchFunc: fulltextMatch(protocol), - Handle: handler, - AddName: protocol, - }) + msm.handlers[protocol] = handler msm.handlerlock.Unlock() } -// RemoveHandler removes the handler with the given name from the muxer. +// RemoveHandler removes the protocol handler from the muxer. func (msm *MultistreamMuxer) RemoveHandler(protocol string) { msm.handlerlock.Lock() - defer msm.handlerlock.Unlock() - - msm.removeHandler(protocol) -} - -func (msm *MultistreamMuxer) removeHandler(protocol string) { - for i, h := range msm.handlers { - if h.AddName == protocol { - msm.handlers = append(msm.handlers[:i], msm.handlers[i+1:]...) - return - } - } + delete(msm.handlers, protocol) + msm.handlerlock.Unlock() } // Protocols returns the list of handler-names added to this this muxer. func (msm *MultistreamMuxer) Protocols() []string { var out []string msm.handlerlock.Lock() - for _, h := range msm.handlers { - out = append(out, h.AddName) + for k := range msm.handlers { + out = append(out, k) } msm.handlerlock.Unlock() return out @@ -164,17 +127,11 @@ func (msm *MultistreamMuxer) Protocols() []string { // fails because of a ProtocolID mismatch. var ErrIncorrectVersion = errors.New("client connected with incorrect version") -func (msm *MultistreamMuxer) findHandler(proto string) *Handler { +func (msm *MultistreamMuxer) findHandler(proto string) (HandlerFunc, bool) { msm.handlerlock.Lock() defer msm.handlerlock.Unlock() - - for _, h := range msm.handlers { - if h.MatchFunc(proto) { - return &h - } - } - - return nil + f, ok := msm.handlers[proto] + return f, ok } // NegotiateLazy performs protocol selection and returns @@ -243,8 +200,8 @@ loop: return nil, "", nil, err } default: - h := msm.findHandler(tok) - if h == nil { + h, ok := msm.findHandler(tok) + if !ok { select { case pval <- "na": case err := <-writeErr: @@ -262,7 +219,7 @@ loop: } // hand off processing to the sub-protocol handler - return lzc, tok, h.Handle, nil + return lzc, tok, h, nil } } } @@ -301,8 +258,8 @@ loop: return "", nil, err } default: - h := msm.findHandler(tok) - if h == nil { + h, ok := msm.findHandler(tok) + if !ok { err := delimWriteBuffered(rwc, []byte("na")) if err != nil { return "", nil, err @@ -316,7 +273,7 @@ loop: } // hand off processing to the sub-protocol handler - return tok, h.Handle, nil + return tok, h, nil } } @@ -332,8 +289,8 @@ func (msm *MultistreamMuxer) Ls(w io.Writer) error { return err } - for _, h := range msm.handlers { - err := delimWrite(buf, []byte(h.AddName)) + for k := range msm.handlers { + err := delimWrite(buf, []byte(k)) if err != nil { msm.handlerlock.Unlock() return err