Skip to content

Commit

Permalink
Fixes to RespondMsg usage when crossing accounts
Browse files Browse the repository at this point in the history
Signed-off-by: Waldemar Quevedo <wally@synadia.com>
  • Loading branch information
wallyqs committed Jul 8, 2021
1 parent c7fc3c7 commit 373cf79
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 2 deletions.
18 changes: 16 additions & 2 deletions nats.go
Expand Up @@ -4190,12 +4190,26 @@ func (m *Msg) RespondMsg(msg *Msg) error {
if m.Reply == "" {
return ErrMsgNoReply
}
msg.Subject = m.Reply
m.Sub.mu.Lock()
nc := m.Sub.conn
m.Sub.mu.Unlock()

resp := &Msg{
Subject: m.Reply,
Header: msg.Header,
Data: msg.Data,
}
// Discard the reply inbox unless it is set to be a different one,
// for example when using it like this:
//
// m.Data = []byte("response")
// m.Respond(m)
//
if msg.Reply != m.Reply {
resp.Reply = msg.Reply
}
// No need to check the connection here since the call to publish will do all the checking.
return nc.PublishMsg(msg)
return nc.PublishMsg(resp)
}

// FIXME: This is a hack
Expand Down
236 changes: 236 additions & 0 deletions test/auth_test.go
Expand Up @@ -15,6 +15,7 @@ package test

import (
"fmt"
"os"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -359,3 +360,238 @@ func TestPermViolation(t *testing.T) {
t.Fatal("Connection should be not be closed")
}
}

func TestCrossAccountRespond(t *testing.T) {
conf := createConfFile(t, []byte(`
listen: 127.0.0.1:-1
accounts: {
A: {
users: [ { user: a, password: a,
permissions = {
subscribe = ["foo", "bar", "quux", "_INBOX.>"]
publish = ["_INBOX.>", "_R_.>"]
}
}
]
exports [
{ service: "foo" }
{ service: "bar" }
# Multiple responses
{ service: "quux", response: "stream", threshold: "1s" }
]
},
B: {
users: [ { user: b, password: b } ]
imports [
{ service: { subject: "foo", account: A } }
{ service: { subject: "bar", account: A } }
{ service: { subject: "quux", account: A } }
]
},
}
`))
defer os.Remove(conf)

s, _ := RunServerWithConfig(conf)
defer s.Shutdown()

errCh := make(chan error, 1)
ncA, err := nats.Connect(s.ClientURL(), nats.UserInfo("a", "a"), nats.ErrorHandler(func(c *nats.Conn, sub *nats.Subscription, err error) {
t.Logf("Connection A: WARN: %s", err)
errCh <- err
}))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer ncA.Close()

ncB, err := nats.Connect(s.ClientURL(), nats.UserInfo("b", "b"), nats.ErrorHandler(func(c *nats.Conn, sub *nats.Subscription, err error) {
t.Logf("Connection B: WARN: %s", err)
errCh <- err
}))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer ncB.Close()

got := ""
expect := "ok"
t.Run("with NewMsg", func(t *testing.T) {
ncA.Subscribe("foo", func(m *nats.Msg) {
// NewMsg works with the side effect that it will reset the headers.
msg := nats.NewMsg(m.Reply)
msg.Data = []byte("ok")
msg.Header["X-NATS-Result"] = []string{"ok"}
err := m.RespondMsg(msg)
if err != nil {
errCh <- err
}
})
ncA.Flush()

msg := nats.NewMsg("foo")
msg.Header = nats.Header{
"X-NATS-ID": []string{"1"},
}
msg.Data = []byte("ping")
resp, err := ncB.RequestMsg(msg, 2*time.Second)
if err != nil {
t.Fatal(err)
}
got = string(resp.Data)
if got != expect {
t.Errorf("Expected %v, got: %v", expect, got)
}

if len(resp.Header) != 1 {
t.Errorf("Expected single header, got: %d", len(resp.Header))
}
_, ok := resp.Header["X-NATS-Result"]
if !ok {
t.Error("Missing header in response")
}
})

t.Run("single RespondMsg", func(t *testing.T) {
ncA.Subscribe("bar", func(m *nats.Msg) {
// RespondMsg will keep the original headers of the request.
m.Data = []byte("ok")
m.Header["X-NATS-Result"] = []string{"ok"}
err := m.RespondMsg(m)
if err != nil {
errCh <- err
}
})
ncA.Flush()

resp, err := ncB.Request("bar", []byte("ping"), 2*time.Second)
if err != nil {
t.Fatal(err)
}
got = string(resp.Data)
if got != expect {
t.Errorf("Expected %v, got: %v", expect, got)
}

if len(resp.Header) != 2 {
t.Errorf("Expected original headers as well, got: %d", len(resp.Header))
}
_, ok := resp.Header["X-NATS-Result"]
if !ok {
t.Error("Missing header in response")
}

// Server injects this header.
_, ok = resp.Header["Nats-Request-Info"]
if !ok {
t.Error("Missing header in response")
}
})

t.Run("multiple RespondMsg stream", func(t *testing.T) {
_, err := ncA.Subscribe("quux", func(m *nats.Msg) {
m.Data = []byte("start")
m.Header["Task-Progress"] = []string{"0%"}
err = m.RespondMsg(m)
if err != nil {
errCh <- err
return
}
m.Header["Task-Progress"] = []string{"50%"}
m.Data = []byte("wip")
err = m.RespondMsg(m)
if err != nil {
errCh <- err
return
}
m.Header["Task-Progress"] = []string{"100%"}
m.Data = []byte("done")
err = m.RespondMsg(m)
if err != nil {
errCh <- err
return
}
})
if err != nil {
t.Fatal(err)
}
ncA.Flush()

inbox := nats.NewInbox()
responses, err := ncB.SubscribeSync(inbox)
if err != nil {
errCh <- err
return
}
ncB.Flush()
err = ncB.PublishRequest("quux", inbox, []byte("start"))
if err != nil {
t.Error(err)
}

getNext := func(t *testing.T, sub *nats.Subscription) *nats.Msg {
resp, err := sub.NextMsg(2 * time.Second)
if err != nil {
t.Fatal(err)
}
if len(resp.Header) != 2 {
t.Errorf("Expected original headers as well, got: %d", len(resp.Header))
}
return resp
}

resp := getNext(t, responses)
got = string(resp.Data)
expect = "start"
if got != expect {
t.Errorf("Expected %v, got: %v", expect, got)
}
v, ok := resp.Header["Task-Progress"]
if !ok {
t.Error("Missing header in response")
}
got = v[0]
if got != "0%" {
t.Errorf("Unexpected value in header, got: %v", got)
}

resp = getNext(t, responses)
got = string(resp.Data)
expect = "wip"
if got != expect {
t.Errorf("Expected %v, got: %v", expect, got)
}
v, ok = resp.Header["Task-Progress"]
if !ok {
t.Error("Missing header in response")
}
got = v[0]
if got != "50%" {
t.Errorf("Unexpected value in header, got: %v", got)
}

resp = getNext(t, responses)
got = string(resp.Data)
expect = "done"
if got != expect {
t.Errorf("Expected %v, got: %v", expect, got)
}
v, ok = resp.Header["Task-Progress"]
if !ok {
t.Error("Missing header in response")
}
got = v[0]
if got != "100%" {
t.Errorf("Unexpected value in header, got: %v", got)
}
})

select {
case err := <-errCh:
if err != nil {
t.Error(err)
}
default:
}
}

0 comments on commit 373cf79

Please sign in to comment.