diff --git a/context.go b/context.go index c921d6be7..769f88a01 100644 --- a/context.go +++ b/context.go @@ -11,8 +11,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.7 - // A Go client for the NATS messaging system (https://nats.io). package nats @@ -21,9 +19,33 @@ import ( "reflect" ) +// RequestMsgWithContext takes a context, a subject and payload +// in bytes and request expecting a single response. +func (nc *Conn) RequestMsgWithContext(ctx context.Context, msg *Msg) (*Msg, error) { + var hdr []byte + var err error + + if len(msg.Header) > 0 { + if !nc.info.Headers { + return nil, ErrHeadersNotSupported + } + + hdr, err = msg.headerBytes() + if err != nil { + return nil, err + } + } + + return nc.requestWithContext(ctx, msg.Subject, hdr, msg.Data) +} + // RequestWithContext takes a context, a subject and payload // in bytes and request expecting a single response. func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) { + return nc.requestWithContext(ctx, subj, nil, data) +} + +func (nc *Conn) requestWithContext(ctx context.Context, subj string, hdr, data []byte) (*Msg, error) { if ctx == nil { return nil, ErrInvalidContext } @@ -40,10 +62,10 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte // If user wants the old style. if nc.Opts.UseOldRequestStyle { nc.mu.Unlock() - return nc.oldRequestWithContext(ctx, subj, data) + return nc.oldRequestWithContext(ctx, subj, hdr, data) } - mch, token, err := nc.createNewRequestAndSend(subj, data) + mch, token, err := nc.createNewRequestAndSend(subj, hdr, data) if err != nil { return nil, err } @@ -67,7 +89,7 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte } // oldRequestWithContext utilizes inbox and subscription per request. -func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) { +func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, hdr, data []byte) (*Msg, error) { inbox := NewInbox() ch := make(chan *Msg, RequestChanLen) @@ -78,7 +100,7 @@ func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []b s.AutoUnsubscribe(1) defer s.Unsubscribe() - err = nc.PublishRequest(subj, inbox, data) + err = nc.publish(subj, inbox, hdr, data) if err != nil { return nil, err } diff --git a/nats.go b/nats.go index d3fa2b991..4184ada94 100644 --- a/nats.go +++ b/nats.go @@ -510,6 +510,31 @@ type Msg struct { barrier *barrierInfo } +func (m *Msg) headerBytes() ([]byte, error) { + var hdr []byte + if len(m.Header) == 0 { + return hdr, nil + } + + var b bytes.Buffer + _, err := b.WriteString(hdrLine) + if err != nil { + return nil, ErrBadHeaderMsg + } + + err = m.Header.Write(&b) + if err != nil { + return nil, ErrBadHeaderMsg + } + + _, err = b.WriteString(crlf) + if err != nil { + return nil, ErrBadHeaderMsg + } + + return b.Bytes(), nil +} + type barrierInfo struct { refs int64 f func() @@ -2687,18 +2712,21 @@ func (nc *Conn) PublishMsg(m *Msg) error { if m == nil { return ErrInvalidMsg } + var hdr []byte + var err error + if len(m.Header) > 0 { if !nc.info.Headers { return ErrHeadersNotSupported } - // FIXME(dlc) - Optimize - var b bytes.Buffer - b.WriteString(hdrLine) - m.Header.Write(&b) - b.WriteString(crlf) - hdr = b.Bytes() + + hdr, err = m.headerBytes() + if err != nil { + return err + } } + return nc.publish(m.Subject, m.Reply, hdr, m.Data) } @@ -2874,7 +2902,7 @@ func (nc *Conn) respHandler(m *Msg) { } // Helper to setup and send new request style requests. Return the chan to receive the response. -func (nc *Conn) createNewRequestAndSend(subj string, data []byte) (chan *Msg, string, error) { +func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Msg, string, error) { // Do setup for the new style if needed. if nc.respMap == nil { nc.initNewResp() @@ -2898,28 +2926,55 @@ func (nc *Conn) createNewRequestAndSend(subj string, data []byte) (chan *Msg, st } nc.mu.Unlock() - if err := nc.PublishRequest(subj, respInbox, data); err != nil { + if err := nc.publish(subj, respInbox, hdr, data); err != nil { return nil, token, err } return mch, token, nil } +// RequestMsg will send a request payload including optional headers and deliver +// the response message, or an error, including a timeout if no message was received properly. +func (nc *Conn) RequestMsg(msg *Msg, timeout time.Duration) (*Msg, error) { + var hdr []byte + var err error + + if len(msg.Header) > 0 { + if !nc.info.Headers { + return nil, ErrHeadersNotSupported + } + + hdr, err = msg.headerBytes() + if err != nil { + return nil, err + } + } + + return nc.request(msg.Subject, hdr, msg.Data, timeout) +} + // Request will send a request payload and deliver the response message, // or an error, including a timeout if no message was received properly. func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, error) { + return nc.request(subj, nil, data, timeout) +} + +func (nc *Conn) request(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) { if nc == nil { return nil, ErrInvalidConnection } nc.mu.Lock() - // If user wants the old style. if nc.Opts.UseOldRequestStyle { nc.mu.Unlock() - return nc.oldRequest(subj, data, timeout) + return nc.oldRequest(subj, hdr, data, timeout) } - mch, token, err := nc.createNewRequestAndSend(subj, data) + return nc.newRequest(subj, hdr, data, timeout) +} + +func (nc *Conn) newRequest(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) { + mch, token, err := nc.createNewRequestAndSend(subj, hdr, data) if err != nil { return nil, err } @@ -2948,7 +3003,7 @@ func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, // oldRequest will create an Inbox and perform a Request() call // with the Inbox reply and return the first reply received. // This is optimized for the case of multiple responses. -func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Msg, error) { +func (nc *Conn) oldRequest(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) { inbox := NewInbox() ch := make(chan *Msg, RequestChanLen) @@ -2959,10 +3014,11 @@ func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Ms s.AutoUnsubscribe(1) defer s.Unsubscribe() - err = nc.PublishRequest(subj, inbox, data) + err = nc.publish(subj, inbox, hdr, data) if err != nil { return nil, err } + return s.NextMsg(timeout) } @@ -3653,6 +3709,21 @@ func (m *Msg) Respond(data []byte) error { return nc.Publish(m.Reply, data) } +// RespondMsg allows a convenient way to respond to requests in service based subscriptions that might include headers +func (m *Msg) RespondMsg(msg *Msg) error { + if m == nil || m.Sub == nil { + return ErrMsgNotBound + } + if m.Reply == "" { + return ErrMsgNoReply + } + m.Sub.mu.Lock() + nc := m.Sub.conn + m.Sub.mu.Unlock() + // No need to check the connection here since the call to publish will do all the checking. + return nc.PublishMsg(msg) +} + // FIXME: This is a hack // removeFlushEntry is needed when we need to discard queued up responses // for our pings as part of a flush call. This happens when we have a flush diff --git a/test/context_test.go b/test/context_test.go index 6ef03a4f6..fdba58c69 100644 --- a/test/context_test.go +++ b/test/context_test.go @@ -49,6 +49,16 @@ func testContextRequestWithTimeout(t *testing.T, nc *nats.Conn) { nc.Subscribe("fast", func(m *nats.Msg) { nc.Publish(m.Reply, []byte("OK")) }) + nc.Subscribe("hdrs", func(m *nats.Msg) { + if m.Header.Get("Hdr-Test") != "1" { + m.Respond([]byte("-ERR")) + } + + r := nats.NewMsg(m.Reply) + r.Header = m.Header + r.Data = []byte("+OK") + m.RespondMsg(r) + }) ctx, cancelCB := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancelCB() // should always be called, not discarded, to prevent context leak @@ -90,6 +100,20 @@ func testContextRequestWithTimeout(t *testing.T, nc *nats.Conn) { if err == nil { t.Fatal("Expected request with context to fail") } + + // now test headers make it all the way back + msg := nats.NewMsg("hdrs") + msg.Header.Add("Hdr-Test", "1") + resp, err = nc.RequestMsgWithContext(context.Background(), msg) + if err != nil { + t.Fatalf("Expected request to be published: %v", err) + } + if string(resp.Data) != "+OK" { + t.Fatalf("Headers were not published to the requestor") + } + if resp.Header.Get("Hdr-Test") != "1" { + t.Fatalf("Did not receive header in response") + } } func TestContextRequestWithTimeout(t *testing.T) { diff --git a/test/headers_test.go b/test/headers_test.go index 2135883c9..105d75a5e 100644 --- a/test/headers_test.go +++ b/test/headers_test.go @@ -19,6 +19,7 @@ import ( "time" natsserver "github.com/nats-io/nats-server/v2/test" + "github.com/nats-io/nats.go" ) @@ -57,6 +58,46 @@ func TestBasicHeaders(t *testing.T) { } } +func TestRequestMsg(t *testing.T) { + s := RunServerOnPort(-1) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Error connecting to server: %v", err) + } + defer nc.Close() + + subject := "headers.test" + sub, err := nc.Subscribe(subject, func(m *nats.Msg) { + if m.Header.Get("Hdr-Test") != "1" { + m.Respond([]byte("-ERR")) + } + + r := nats.NewMsg(m.Reply) + r.Header = m.Header + r.Data = []byte("+OK") + m.RespondMsg(r) + }) + if err != nil { + t.Fatalf("subscribe failed: %v", err) + } + defer sub.Unsubscribe() + + msg := nats.NewMsg(subject) + msg.Header.Add("Hdr-Test", "1") + resp, err := nc.RequestMsg(msg, time.Second) + if err != nil { + t.Fatalf("Expected request to be published: %v", err) + } + if string(resp.Data) != "+OK" { + t.Fatalf("Headers were not published to the requestor") + } + if resp.Header.Get("Hdr-Test") != "1" { + t.Fatalf("Did not receive header in response") + } +} + func TestNoHeaderSupport(t *testing.T) { opts := natsserver.DefaultTestOptions opts.Port = -1 @@ -77,4 +118,8 @@ func TestNoHeaderSupport(t *testing.T) { if err := nc.PublishMsg(m); err != nats.ErrHeadersNotSupported { t.Fatalf("Expected an error, got %v", err) } + + if _, err := nc.RequestMsg(m, time.Second); err != nats.ErrHeadersNotSupported { + t.Fatalf("Expected an error, got %v", err) + } }