Skip to content

Commit

Permalink
Merge pull request #574 from ripienaar/request_headers
Browse files Browse the repository at this point in the history
add Msg variants of the Request() functions for headers
  • Loading branch information
derekcollison committed Jun 3, 2020
2 parents e93e18d + c4ab892 commit fdbf7d6
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 19 deletions.
34 changes: 28 additions & 6 deletions context.go
Expand Up @@ -11,8 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// +build go1.7

// A Go client for the NATS messaging system (https://nats.io).
package nats

Expand All @@ -21,9 +19,33 @@ import (
"reflect"
)

// RequestMsgWithContext takes a context, a subject and payload
// in bytes and request expecting a single response.
func (nc *Conn) RequestMsgWithContext(ctx context.Context, msg *Msg) (*Msg, error) {
var hdr []byte
var err error

if len(msg.Header) > 0 {
if !nc.info.Headers {
return nil, ErrHeadersNotSupported
}

hdr, err = msg.headerBytes()
if err != nil {
return nil, err
}
}

return nc.requestWithContext(ctx, msg.Subject, hdr, msg.Data)
}

// RequestWithContext takes a context, a subject and payload
// in bytes and request expecting a single response.
func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) {
return nc.requestWithContext(ctx, subj, nil, data)
}

func (nc *Conn) requestWithContext(ctx context.Context, subj string, hdr, data []byte) (*Msg, error) {
if ctx == nil {
return nil, ErrInvalidContext
}
Expand All @@ -40,10 +62,10 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte
// If user wants the old style.
if nc.Opts.UseOldRequestStyle {
nc.mu.Unlock()
return nc.oldRequestWithContext(ctx, subj, data)
return nc.oldRequestWithContext(ctx, subj, hdr, data)
}

mch, token, err := nc.createNewRequestAndSend(subj, data)
mch, token, err := nc.createNewRequestAndSend(subj, hdr, data)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +89,7 @@ func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte
}

// oldRequestWithContext utilizes inbox and subscription per request.
func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) {
func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, hdr, data []byte) (*Msg, error) {
inbox := NewInbox()
ch := make(chan *Msg, RequestChanLen)

Expand All @@ -78,7 +100,7 @@ func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []b
s.AutoUnsubscribe(1)
defer s.Unsubscribe()

err = nc.PublishRequest(subj, inbox, data)
err = nc.publish(subj, inbox, hdr, data)
if err != nil {
return nil, err
}
Expand Down
97 changes: 84 additions & 13 deletions nats.go
Expand Up @@ -510,6 +510,31 @@ type Msg struct {
barrier *barrierInfo
}

func (m *Msg) headerBytes() ([]byte, error) {
var hdr []byte
if len(m.Header) == 0 {
return hdr, nil
}

var b bytes.Buffer
_, err := b.WriteString(hdrLine)
if err != nil {
return nil, ErrBadHeaderMsg
}

err = m.Header.Write(&b)
if err != nil {
return nil, ErrBadHeaderMsg
}

_, err = b.WriteString(crlf)
if err != nil {
return nil, ErrBadHeaderMsg
}

return b.Bytes(), nil
}

type barrierInfo struct {
refs int64
f func()
Expand Down Expand Up @@ -2687,18 +2712,21 @@ func (nc *Conn) PublishMsg(m *Msg) error {
if m == nil {
return ErrInvalidMsg
}

var hdr []byte
var err error

if len(m.Header) > 0 {
if !nc.info.Headers {
return ErrHeadersNotSupported
}
// FIXME(dlc) - Optimize
var b bytes.Buffer
b.WriteString(hdrLine)
m.Header.Write(&b)
b.WriteString(crlf)
hdr = b.Bytes()

hdr, err = m.headerBytes()
if err != nil {
return err
}
}

return nc.publish(m.Subject, m.Reply, hdr, m.Data)
}

Expand Down Expand Up @@ -2874,7 +2902,7 @@ func (nc *Conn) respHandler(m *Msg) {
}

// Helper to setup and send new request style requests. Return the chan to receive the response.
func (nc *Conn) createNewRequestAndSend(subj string, data []byte) (chan *Msg, string, error) {
func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Msg, string, error) {
// Do setup for the new style if needed.
if nc.respMap == nil {
nc.initNewResp()
Expand All @@ -2898,28 +2926,55 @@ func (nc *Conn) createNewRequestAndSend(subj string, data []byte) (chan *Msg, st
}
nc.mu.Unlock()

if err := nc.PublishRequest(subj, respInbox, data); err != nil {
if err := nc.publish(subj, respInbox, hdr, data); err != nil {
return nil, token, err
}

return mch, token, nil
}

// RequestMsg will send a request payload including optional headers and deliver
// the response message, or an error, including a timeout if no message was received properly.
func (nc *Conn) RequestMsg(msg *Msg, timeout time.Duration) (*Msg, error) {
var hdr []byte
var err error

if len(msg.Header) > 0 {
if !nc.info.Headers {
return nil, ErrHeadersNotSupported
}

hdr, err = msg.headerBytes()
if err != nil {
return nil, err
}
}

return nc.request(msg.Subject, hdr, msg.Data, timeout)
}

// Request will send a request payload and deliver the response message,
// or an error, including a timeout if no message was received properly.
func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, error) {
return nc.request(subj, nil, data, timeout)
}

func (nc *Conn) request(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) {
if nc == nil {
return nil, ErrInvalidConnection
}

nc.mu.Lock()
// If user wants the old style.
if nc.Opts.UseOldRequestStyle {
nc.mu.Unlock()
return nc.oldRequest(subj, data, timeout)
return nc.oldRequest(subj, hdr, data, timeout)
}

mch, token, err := nc.createNewRequestAndSend(subj, data)
return nc.newRequest(subj, hdr, data, timeout)
}

func (nc *Conn) newRequest(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) {
mch, token, err := nc.createNewRequestAndSend(subj, hdr, data)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2948,7 +3003,7 @@ func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg,
// oldRequest will create an Inbox and perform a Request() call
// with the Inbox reply and return the first reply received.
// This is optimized for the case of multiple responses.
func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Msg, error) {
func (nc *Conn) oldRequest(subj string, hdr, data []byte, timeout time.Duration) (*Msg, error) {
inbox := NewInbox()
ch := make(chan *Msg, RequestChanLen)

Expand All @@ -2959,10 +3014,11 @@ func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Ms
s.AutoUnsubscribe(1)
defer s.Unsubscribe()

err = nc.PublishRequest(subj, inbox, data)
err = nc.publish(subj, inbox, hdr, data)
if err != nil {
return nil, err
}

return s.NextMsg(timeout)
}

Expand Down Expand Up @@ -3653,6 +3709,21 @@ func (m *Msg) Respond(data []byte) error {
return nc.Publish(m.Reply, data)
}

// RespondMsg allows a convenient way to respond to requests in service based subscriptions that might include headers
func (m *Msg) RespondMsg(msg *Msg) error {
if m == nil || m.Sub == nil {
return ErrMsgNotBound
}
if m.Reply == "" {
return ErrMsgNoReply
}
m.Sub.mu.Lock()
nc := m.Sub.conn
m.Sub.mu.Unlock()
// No need to check the connection here since the call to publish will do all the checking.
return nc.PublishMsg(msg)
}

// FIXME: This is a hack
// removeFlushEntry is needed when we need to discard queued up responses
// for our pings as part of a flush call. This happens when we have a flush
Expand Down
24 changes: 24 additions & 0 deletions test/context_test.go
Expand Up @@ -49,6 +49,16 @@ func testContextRequestWithTimeout(t *testing.T, nc *nats.Conn) {
nc.Subscribe("fast", func(m *nats.Msg) {
nc.Publish(m.Reply, []byte("OK"))
})
nc.Subscribe("hdrs", func(m *nats.Msg) {
if m.Header.Get("Hdr-Test") != "1" {
m.Respond([]byte("-ERR"))
}

r := nats.NewMsg(m.Reply)
r.Header = m.Header
r.Data = []byte("+OK")
m.RespondMsg(r)
})

ctx, cancelCB := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancelCB() // should always be called, not discarded, to prevent context leak
Expand Down Expand Up @@ -90,6 +100,20 @@ func testContextRequestWithTimeout(t *testing.T, nc *nats.Conn) {
if err == nil {
t.Fatal("Expected request with context to fail")
}

// now test headers make it all the way back
msg := nats.NewMsg("hdrs")
msg.Header.Add("Hdr-Test", "1")
resp, err = nc.RequestMsgWithContext(context.Background(), msg)
if err != nil {
t.Fatalf("Expected request to be published: %v", err)
}
if string(resp.Data) != "+OK" {
t.Fatalf("Headers were not published to the requestor")
}
if resp.Header.Get("Hdr-Test") != "1" {
t.Fatalf("Did not receive header in response")
}
}

func TestContextRequestWithTimeout(t *testing.T) {
Expand Down
45 changes: 45 additions & 0 deletions test/headers_test.go
Expand Up @@ -19,6 +19,7 @@ import (
"time"

natsserver "github.com/nats-io/nats-server/v2/test"

"github.com/nats-io/nats.go"
)

Expand Down Expand Up @@ -57,6 +58,46 @@ func TestBasicHeaders(t *testing.T) {
}
}

func TestRequestMsg(t *testing.T) {
s := RunServerOnPort(-1)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Error connecting to server: %v", err)
}
defer nc.Close()

subject := "headers.test"
sub, err := nc.Subscribe(subject, func(m *nats.Msg) {
if m.Header.Get("Hdr-Test") != "1" {
m.Respond([]byte("-ERR"))
}

r := nats.NewMsg(m.Reply)
r.Header = m.Header
r.Data = []byte("+OK")
m.RespondMsg(r)
})
if err != nil {
t.Fatalf("subscribe failed: %v", err)
}
defer sub.Unsubscribe()

msg := nats.NewMsg(subject)
msg.Header.Add("Hdr-Test", "1")
resp, err := nc.RequestMsg(msg, time.Second)
if err != nil {
t.Fatalf("Expected request to be published: %v", err)
}
if string(resp.Data) != "+OK" {
t.Fatalf("Headers were not published to the requestor")
}
if resp.Header.Get("Hdr-Test") != "1" {
t.Fatalf("Did not receive header in response")
}
}

func TestNoHeaderSupport(t *testing.T) {
opts := natsserver.DefaultTestOptions
opts.Port = -1
Expand All @@ -77,4 +118,8 @@ func TestNoHeaderSupport(t *testing.T) {
if err := nc.PublishMsg(m); err != nats.ErrHeadersNotSupported {
t.Fatalf("Expected an error, got %v", err)
}

if _, err := nc.RequestMsg(m, time.Second); err != nats.ErrHeadersNotSupported {
t.Fatalf("Expected an error, got %v", err)
}
}

0 comments on commit fdbf7d6

Please sign in to comment.