Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Msg variants of the Request() functions for headers #574

Merged
merged 1 commit into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 28 additions & 6 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// +build go1.7

// A Go client for the NATS messaging system (https://nats.io).
package nats

Expand All @@ -21,9 +19,33 @@ import (
"reflect"
)

// RequestMsgWithContext takes a context, a subject and payload
// in bytes and request expecting a single response.
func (nc *Conn) RequestMsgWithContext(ctx context.Context, msg *Msg) (*Msg, error) {
var hdr []byte
var err error

if len(msg.Header) > 0 {
if !nc.info.Headers {
return nil, ErrHeadersNotSupported
}

hdr, err = msg.headerBytes()
if err != nil {
return nil, err
}
}

return nc.requestWithContext(ctx, msg.Subject, hdr, msg.Data)
}

// RequestWithContext takes a context, a subject and payload
// in bytes and request expecting a single response.
func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) {
return nc.requestWithContext(ctx, subj, nil, data)
}

func (nc *Conn) requestWithContext(ctx context.Context, subj string, hdr, data []byte) (*Msg, error) {
if ctx == nil {
return nil, ErrInvalidContext
}
Expand All @@ -40,10 +62,10 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte
// If user wants the old style.
if nc.Opts.UseOldRequestStyle {
nc.mu.Unlock()
return nc.oldRequestWithContext(ctx, subj, data)
return nc.oldRequestWithContext(ctx, subj, hdr, data)
}

mch, token, err := nc.createNewRequestAndSend(subj, data)
mch, token, err := nc.createNewRequestAndSend(subj, hdr, data)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +89,7 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte
}

// oldRequestWithContext utilizes inbox and subscription per request.
func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) {
func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, hdr, data []byte) (*Msg, error) {
inbox := NewInbox()
ch := make(chan *Msg, RequestChanLen)

Expand All @@ -78,7 +100,7 @@ func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []b
s.AutoUnsubscribe(1)
defer s.Unsubscribe()

err = nc.PublishRequest(subj, inbox, data)
err = nc.publish(subj, inbox, hdr, data)
if err != nil {
return nil, err
}
Expand Down
97 changes: 84 additions & 13 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,31 @@ type Msg struct {
barrier *barrierInfo
}

func (m *Msg) headerBytes() ([]byte, error) {
var hdr []byte
if len(m.Header) == 0 {
return hdr, nil
}

var b bytes.Buffer
_, err := b.WriteString(hdrLine)
if err != nil {
return nil, ErrBadHeaderMsg
}

err = m.Header.Write(&b)
if err != nil {
return nil, ErrBadHeaderMsg
}

_, err = b.WriteString(crlf)
if err != nil {
return nil, ErrBadHeaderMsg
}

return b.Bytes(), nil
}

type barrierInfo struct {
refs int64
f func()
Expand Down Expand Up @@ -2687,18 +2712,21 @@ func (nc *Conn) PublishMsg(m *Msg) error {
if m == nil {
return ErrInvalidMsg
}

var hdr []byte
var err error

if len(m.Header) > 0 {
if !nc.info.Headers {
return ErrHeadersNotSupported
}
// FIXME(dlc) - Optimize
var b bytes.Buffer
b.WriteString(hdrLine)
m.Header.Write(&b)
b.WriteString(crlf)
hdr = b.Bytes()

hdr, err = m.headerBytes()
if err != nil {
return err
}
}

return nc.publish(m.Subject, m.Reply, hdr, m.Data)
}

Expand Down Expand Up @@ -2874,7 +2902,7 @@ func (nc *Conn) respHandler(m *Msg) {
}

// Helper to setup and send new request style requests. Return the chan to receive the response.
func (nc *Conn) createNewRequestAndSend(subj string, data []byte) (chan *Msg, string, error) {
func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Msg, string, error) {
// Do setup for the new style if needed.
if nc.respMap == nil {
nc.initNewResp()
Expand All @@ -2898,28 +2926,55 @@ func (nc *Conn) createNewRequestAndSend(subj string, data []byte) (chan *Msg, st
}
nc.mu.Unlock()

if err := nc.PublishRequest(subj, respInbox, data); err != nil {
if err := nc.publish(subj, respInbox, hdr, data); err != nil {
return nil, token, err
}

return mch, token, nil
}

// RequestMsg will send a request payload including optional headers and deliver
// the response message, or an error, including a timeout if no message was received properly.
func (nc *Conn) RequestMsg(msg *Msg, timeout time.Duration) (*Msg, error) {
ripienaar marked this conversation as resolved.
Show resolved Hide resolved
var hdr []byte
var err error

if len(msg.Header) > 0 {
if !nc.info.Headers {
return nil, ErrHeadersNotSupported
}

hdr, err = msg.headerBytes()
if err != nil {
return nil, err
}
}

return nc.request(msg.Subject, hdr, msg.Data, timeout)
}

// Request will send a request payload and deliver the response message,
// or an error, including a timeout if no message was received properly.
func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, error) {
return nc.request(subj, nil, data, timeout)
}

func (nc *Conn) request(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) {
if nc == nil {
return nil, ErrInvalidConnection
}

nc.mu.Lock()
// If user wants the old style.
if nc.Opts.UseOldRequestStyle {
nc.mu.Unlock()
return nc.oldRequest(subj, data, timeout)
return nc.oldRequest(subj, hdr, data, timeout)
}

mch, token, err := nc.createNewRequestAndSend(subj, data)
return nc.newRequest(subj, hdr, data, timeout)
}

func (nc *Conn) newRequest(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) {
mch, token, err := nc.createNewRequestAndSend(subj, hdr, data)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2948,7 +3003,7 @@ func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg,
// oldRequest will create an Inbox and perform a Request() call
// with the Inbox reply and return the first reply received.
// This is optimized for the case of multiple responses.
func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Msg, error) {
func (nc *Conn) oldRequest(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) {
inbox := NewInbox()
ch := make(chan *Msg, RequestChanLen)

Expand All @@ -2959,10 +3014,11 @@ func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Ms
s.AutoUnsubscribe(1)
defer s.Unsubscribe()

err = nc.PublishRequest(subj, inbox, data)
err = nc.publish(subj, inbox, hdr, data)
if err != nil {
return nil, err
}

return s.NextMsg(timeout)
}

Expand Down Expand Up @@ -3653,6 +3709,21 @@ func (m *Msg) Respond(data []byte) error {
return nc.Publish(m.Reply, data)
}

// RespondMsg allows a convenient way to respond to requests in service based subscriptions that might include headers
func (m *Msg) RespondMsg(msg *Msg) error {
if m == nil || m.Sub == nil {
return ErrMsgNotBound
}
if m.Reply == "" {
return ErrMsgNoReply
}
m.Sub.mu.Lock()
nc := m.Sub.conn
m.Sub.mu.Unlock()
// No need to check the connection here since the call to publish will do all the checking.
return nc.PublishMsg(msg)
}

// FIXME: This is a hack
// removeFlushEntry is needed when we need to discard queued up responses
// for our pings as part of a flush call. This happens when we have a flush
Expand Down
24 changes: 24 additions & 0 deletions test/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ func testContextRequestWithTimeout(t *testing.T, nc *nats.Conn) {
nc.Subscribe("fast", func(m *nats.Msg) {
nc.Publish(m.Reply, []byte("OK"))
})
nc.Subscribe("hdrs", func(m *nats.Msg) {
if m.Header.Get("Hdr-Test") != "1" {
m.Respond([]byte("-ERR"))
}

r := nats.NewMsg(m.Reply)
r.Header = m.Header
r.Data = []byte("+OK")
m.RespondMsg(r)
})

ctx, cancelCB := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancelCB() // should always be called, not discarded, to prevent context leak
Expand Down Expand Up @@ -90,6 +100,20 @@ func testContextRequestWithTimeout(t *testing.T, nc *nats.Conn) {
if err == nil {
t.Fatal("Expected request with context to fail")
}

// now test headers make it all the way back
msg := nats.NewMsg("hdrs")
msg.Header.Add("Hdr-Test", "1")
resp, err = nc.RequestMsgWithContext(context.Background(), msg)
if err != nil {
t.Fatalf("Expected request to be published: %v", err)
}
if string(resp.Data) != "+OK" {
t.Fatalf("Headers were not published to the requestor")
}
if resp.Header.Get("Hdr-Test") != "1" {
t.Fatalf("Did not receive header in response")
}
}

func TestContextRequestWithTimeout(t *testing.T) {
Expand Down
45 changes: 45 additions & 0 deletions test/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

natsserver "github.com/nats-io/nats-server/v2/test"

"github.com/nats-io/nats.go"
)

Expand Down Expand Up @@ -57,6 +58,46 @@ func TestBasicHeaders(t *testing.T) {
}
}

func TestRequestMsg(t *testing.T) {
s := RunServerOnPort(-1)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Error connecting to server: %v", err)
}
defer nc.Close()

subject := "headers.test"
sub, err := nc.Subscribe(subject, func(m *nats.Msg) {
if m.Header.Get("Hdr-Test") != "1" {
m.Respond([]byte("-ERR"))
}

r := nats.NewMsg(m.Reply)
r.Header = m.Header
r.Data = []byte("+OK")
m.RespondMsg(r)
})
if err != nil {
t.Fatalf("subscribe failed: %v", err)
}
defer sub.Unsubscribe()

msg := nats.NewMsg(subject)
msg.Header.Add("Hdr-Test", "1")
resp, err := nc.RequestMsg(msg, time.Second)
if err != nil {
t.Fatalf("Expected request to be published: %v", err)
}
if string(resp.Data) != "+OK" {
t.Fatalf("Headers were not published to the requestor")
}
if resp.Header.Get("Hdr-Test") != "1" {
t.Fatalf("Did not receive header in response")
}
}

func TestNoHeaderSupport(t *testing.T) {
opts := natsserver.DefaultTestOptions
opts.Port = -1
Expand All @@ -77,4 +118,8 @@ func TestNoHeaderSupport(t *testing.T) {
if err := nc.PublishMsg(m); err != nats.ErrHeadersNotSupported {
t.Fatalf("Expected an error, got %v", err)
}

if _, err := nc.RequestMsg(m, time.Second); err != nats.ErrHeadersNotSupported {
t.Fatalf("Expected an error, got %v", err)
}
}