diff --git a/js.go b/js.go index 4b8bc322a..6996ca8ae 100644 --- a/js.go +++ b/js.go @@ -20,7 +20,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strconv" "strings" "sync" @@ -283,7 +282,7 @@ func (js *js) PublishMsg(m *Msg, opts ...PubOpt) (*PubAck, error) { var o pubOpts if len(opts) > 0 { if m.Header == nil { - m.Header = http.Header{} + m.Header = Header{} } for _, opt := range opts { if err := opt.configurePublish(&o); err != nil { @@ -584,7 +583,7 @@ func (js *js) PublishMsgAsync(m *Msg, opts ...PubOpt) (PubAckFuture, error) { var o pubOpts if len(opts) > 0 { if m.Header == nil { - m.Header = http.Header{} + m.Header = Header{} } for _, opt := range opts { if err := opt.configurePublish(&o); err != nil { diff --git a/jsm.go b/jsm.go index 85ee3817c..b6a3f8b7a 100644 --- a/jsm.go +++ b/jsm.go @@ -18,7 +18,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "time" ) @@ -709,7 +708,7 @@ type apiMsgGetRequest struct { type RawStreamMsg struct { Subject string Sequence uint64 - Header http.Header + Header Header Data []byte Time time.Time } @@ -765,7 +764,7 @@ func (js *js) GetMsg(name string, seq uint64, opts ...JSOpt) (*RawStreamMsg, err msg := resp.Message - var hdr http.Header + var hdr Header if msg.Header != nil { hdr, err = decodeHeadersMsg(msg.Header) if err != nil { diff --git a/nats.go b/nats.go index 158c29ae7..8eb1f508a 100644 --- a/nats.go +++ b/nats.go @@ -582,7 +582,7 @@ type Subscription struct { type Msg struct { Subject string Reply string - Header http.Header + Header Header Data []byte Sub *Subscription next *Msg @@ -602,7 +602,7 @@ func (m *Msg) headerBytes() ([]byte, error) { return nil, ErrBadHeaderMsg } - err = m.Header.Write(&b) + err = http.Header(m.Header).Write(&b) if err != nil { return nil, ErrBadHeaderMsg } @@ -2605,7 +2605,7 @@ func (nc *Conn) processMsg(data []byte) { copy(msgPayload, data) // Check if we have headers encoded here. - var h http.Header + var h Header var err error var ctrl bool var hasFC bool @@ -3001,11 +3001,52 @@ func (nc *Conn) Publish(subj string, data []byte) error { return nc.publish(subj, _EMPTY_, nil, data) } +// Header represents the optional Header for a NATS message, +// based on the implementation of http.Header. +type Header map[string][]string + +// Add adds the key, value pair to the header. It is case-sensitive +// and appends to any existing values associated with key. +func (h Header) Add(key, value string) { + h[key] = append(h[key], value) +} + +// Set sets the header entries associated with key to the single +// element value. It is case-sensitive and replaces any existing +// values associated with key. +func (h Header) Set(key, value string) { + h[key] = []string{value} +} + +// Get gets the first value associated with the given key. +// It is case-sensitive. +func (h Header) Get(key string) string { + if h == nil { + return _EMPTY_ + } + if v := h[key]; v != nil { + return v[0] + } + return _EMPTY_ +} + +// Values returns all values associated with the given key. +// It is case-sensitive. +func (h Header) Values(key string) []string { + return h[key] +} + +// Del deletes the values associated with a key. +// It is case-sensitive. +func (h Header) Del(key string) { + delete(h, key) +} + // NewMsg creates a message for publishing that will use headers. func NewMsg(subject string) *Msg { return &Msg{ Subject: subject, - Header: make(http.Header), + Header: make(Header), } } @@ -3024,7 +3065,7 @@ const ( ) // decodeHeadersMsg will decode and headers. -func decodeHeadersMsg(data []byte) (http.Header, error) { +func decodeHeadersMsg(data []byte) (Header, error) { tp := textproto.NewReader(bufio.NewReader(bytes.NewReader(data))) l, err := tp.ReadLine() if err != nil || len(l) < hdrPreEnd || l[:hdrPreEnd] != hdrLine[:hdrPreEnd] { @@ -3049,7 +3090,7 @@ func decodeHeadersMsg(data []byte) (http.Header, error) { mh.Add(descrHdr, description) } } - return http.Header(mh), nil + return Header(mh), nil } // readMIMEHeader returns a MIMEHeader that preserves the diff --git a/nats_test.go b/nats_test.go index efa26c807..7f8c840bc 100644 --- a/nats_test.go +++ b/nats_test.go @@ -25,6 +25,7 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "net/url" "os" "reflect" @@ -2482,6 +2483,52 @@ func TestHeaderParser(t *testing.T) { checkStatus("NATS/1.0 404 No Messages", 404, "No Messages") } +func TestHeaderMultiLine(t *testing.T) { + m := NewMsg("foo") + m.Header = Header{ + "CorrelationID": []string{"123"}, + "Msg-ID": []string{"456"}, + "X-NATS-Keys": []string{"A", "B", "C"}, + "X-Test-Keys": []string{"D", "E", "F"}, + } + // Users can opt-in to canonicalize like http.Header does + // by using http.Header#Set or http.Header#Add. + http.Header(m.Header).Set("accept-encoding", "json") + http.Header(m.Header).Add("AUTHORIZATION", "s3cr3t") + + // Multi Value Header becomes represented as multi-lines in the wire + // since internally using same Write from http stdlib. + m.Header.Set("X-Test", "First") + m.Header.Add("X-Test", "Second") + m.Header.Add("X-Test", "Third") + + b, err := m.headerBytes() + if err != nil { + t.Fatal(err) + } + result := string(b) + + expectedHeader := `NATS/1.0 +Accept-Encoding: json +Authorization: s3cr3t +CorrelationID: 123 +Msg-ID: 456 +X-NATS-Keys: A +X-NATS-Keys: B +X-NATS-Keys: C +X-Test: First +X-Test: Second +X-Test: Third +X-Test-Keys: D +X-Test-Keys: E +X-Test-Keys: F + +` + if strings.Replace(expectedHeader, "\n", "\r\n", -1) != result { + t.Fatalf("Expected: %q, got: %q", expectedHeader, result) + } +} + func TestLameDuckMode(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/test/headers_test.go b/test/headers_test.go index e0e88cc9d..24b962326 100644 --- a/test/headers_test.go +++ b/test/headers_test.go @@ -17,6 +17,7 @@ import ( "fmt" "net/http" "reflect" + "sort" "testing" "time" @@ -146,29 +147,40 @@ func TestMsgHeadersCasePreserving(t *testing.T) { m := nats.NewMsg(subject) - // Avoid canonicalizing headers by creating headers manually. - // - // To not use canonical keys, Go recommends accessing the map directly. - // https://golang.org/pkg/net/http/#Header.Set - m.Header = http.Header{ + // http.Header preserves the original keys and allows case-sensitive + // lookup by accessing the map directly. + hdr := http.Header{ "CorrelationID": []string{"123"}, "Msg-ID": []string{"456"}, "X-NATS-Keys": []string{"A", "B", "C"}, "X-Test-Keys": []string{"D", "E", "F"}, } - // Users can opt-in to canonicalize an http.Header - // by using http.Header.Add() - m.Header.Add("Accept-Encoding", "json") - m.Header.Add("Authorization", "s3cr3t") + // Validate that can be used interchangeably with http.Header + type HeaderInterface interface { + Add(key, value string) + Del(key string) + Get(key string) string + Set(key, value string) + Values(key string) []string + } + var _ HeaderInterface = http.Header{} + var _ HeaderInterface = nats.Header{} + + // A NATS Header is the same type as http.Header so simple casting + // works to use canonical form used in Go HTTP servers if needed, + // and it also preserves the same original keys like Go HTTP requests. + m.Header = nats.Header(hdr) + http.Header(m.Header).Set("accept-encoding", "json") + http.Header(m.Header).Add("AUTHORIZATION", "s3cr3t") - // Multi Value Header + // Multi Value using the same matching key. m.Header.Set("X-Test", "First") m.Header.Add("X-Test", "Second") m.Header.Add("X-Test", "Third") + m.Data = []byte("Simple Headers") nc.PublishMsg(m) - msg, err := sub.NextMsg(time.Second) if err != nil { t.Fatalf("Did not receive response: %v", err) @@ -176,30 +188,40 @@ func TestMsgHeadersCasePreserving(t *testing.T) { // Blank out the sub since its not present in the original. msg.Sub = nil + + // Confirm that received message is just like the one originally sent. if !reflect.DeepEqual(m, msg) { t.Fatalf("Messages did not match! \n%+v\n%+v\n", m, msg) } for _, test := range []struct { - Header string - Values []string - Canonical bool + Header string + Values []string }{ - {"Accept-Encoding", []string{"json"}, true}, - {"Authorization", []string{"s3cr3t"}, true}, - {"X-Test", []string{"First", "Second", "Third"}, true}, - {"CorrelationID", []string{"123"}, false}, - {"Msg-ID", []string{"456"}, false}, - {"X-NATS-Keys", []string{"A", "B", "C"}, false}, - {"X-Test-Keys", []string{"D", "E", "F"}, true}, + {"Accept-Encoding", []string{"json"}}, + {"Authorization", []string{"s3cr3t"}}, + {"X-Test", []string{"First", "Second", "Third"}}, + {"CorrelationID", []string{"123"}}, + {"Msg-ID", []string{"456"}}, + {"X-NATS-Keys", []string{"A", "B", "C"}}, + {"X-Test-Keys", []string{"D", "E", "F"}}, } { // Accessing directly will always work. - v, ok := msg.Header[test.Header] + v1, ok := msg.Header[test.Header] if !ok { t.Errorf("Expected %v to be present", test.Header) } - if len(v) != len(test.Values) { - t.Errorf("Expected %v values in header, got: %v", len(test.Values), len(v)) + if len(v1) != len(test.Values) { + t.Errorf("Expected %v values in header, got: %v", len(test.Values), len(v1)) + } + + // Exact match is preferred and fastest for Get. + v2 := msg.Header.Get(test.Header) + if v2 == "" { + t.Errorf("Expected %v to be present", test.Header) + } + if v1[0] != v2 { + t.Errorf("Expected: %s, got: %v", v1, v2) } for k, val := range test.Values { @@ -209,14 +231,6 @@ func TestMsgHeadersCasePreserving(t *testing.T) { t.Errorf("Expected %v values in header, got: %v", val, vv) } } - - // Only canonical version of headers can be fetched with Add/Get/Values. - // Need to access the map directly to get the non canonicalized version - // as per the Go docs of textproto package. - if !test.Canonical { - continue - } - if len(test.Values) > 1 { if !reflect.DeepEqual(test.Values, msg.Header.Values(test.Header)) { t.Fatalf("Headers did not match! \n%+v\n%+v\n", test.Values, msg.Header.Values(test.Header)) @@ -234,7 +248,6 @@ func TestMsgHeadersCasePreserving(t *testing.T) { errCh := make(chan error, 2) msgCh := make(chan *nats.Msg, 1) sub, err = nc.Subscribe("nats.svc.A", func(msg *nats.Msg) { - //lint:ignore SA1008 non canonical form test hdr := msg.Header["x-trace-id"] hdr = append(hdr, "A") msg.Header["x-trace-id"] = hdr @@ -260,7 +273,6 @@ func TestMsgHeadersCasePreserving(t *testing.T) { defer sub.Unsubscribe() sub, err = nc.Subscribe("nats.svc.B", func(msg *nats.Msg) { - //lint:ignore SA1008 non canonical form test hdr := msg.Header["x-trace-id"] hdr = append(hdr, "B") msg.Header["x-trace-id"] = hdr @@ -281,7 +293,7 @@ func TestMsgHeadersCasePreserving(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { msg := nats.NewMsg("nats.svc.A") - msg.Header = r.Header.Clone() + msg.Header = nats.Header(r.Header.Clone()) msg.Header["x-trace-id"] = []string{"S"} msg.Header["Result-ID"] = []string{"OK"} resp, err := nc.RequestMsg(msg, 2*time.Second) @@ -295,7 +307,7 @@ func TestMsgHeadersCasePreserving(t *testing.T) { w.Header()[k] = v } - // Remove Date for testing. + // Remove Date from response header for testing. w.Header()["Date"] = nil w.WriteHeader(200) @@ -308,7 +320,7 @@ func TestMsgHeadersCasePreserving(t *testing.T) { t.Fatal(err) } - client := &http.Client{Timeout: 3 * time.Second} + client := &http.Client{Timeout: 2 * time.Second} resp, err := client.Do(req) if err != nil { t.Fatal(err) @@ -335,7 +347,6 @@ func TestMsgHeadersCasePreserving(t *testing.T) { t.Errorf("Wrong number of headers in NATS message, got: %v", len(msg.Header)) } - //lint:ignore SA1008 non canonical form test v, ok := msg.Header["x-trace-id"] if !ok { t.Fatal("Missing headers in message") @@ -343,4 +354,87 @@ func TestMsgHeadersCasePreserving(t *testing.T) { if !reflect.DeepEqual(v, []string{"S", "A", "B"}) { t.Fatal("Missing headers in message") } + for _, key := range []string{"x-trace-id"} { + v = msg.Header.Values(key) + if v == nil { + t.Fatal("Missing headers in message") + } + if !reflect.DeepEqual(v, []string{"S", "A", "B"}) { + t.Fatal("Missing headers in message") + } + } + + t.Run("multi value header", func(t *testing.T) { + getHeader := func() nats.Header { + return nats.Header{ + "foo": []string{"A"}, + "Foo": []string{"B"}, + "FOO": []string{"C"}, + } + } + + hdr := getHeader() + got := hdr.Get("foo") + expected := "A" + if got != expected { + t.Errorf("Expected: %v, got: %v", expected, got) + } + got = hdr.Get("Foo") + expected = "B" + if got != expected { + t.Errorf("Expected: %v, got: %v", expected, got) + } + got = hdr.Get("FOO") + expected = "C" + if got != expected { + t.Errorf("Expected: %v, got: %v", expected, got) + } + + // No match. + got = hdr.Get("fOo") + if got != "" { + t.Errorf("Unexpected result, got: %v", got) + } + + // Only match explicitly. + for _, test := range []struct { + key string + expectedValues []string + }{ + {"foo", []string{"A"}}, + {"Foo", []string{"B"}}, + {"FOO", []string{"C"}}, + {"fOO", nil}, + {"foO", nil}, + } { + t.Run("", func(t *testing.T) { + hdr := getHeader() + result := hdr.Values(test.key) + sort.Strings(result) + + if !reflect.DeepEqual(result, test.expectedValues) { + t.Errorf("Expected: %+v, got: %+v", test.expectedValues, result) + } + if hdr.Get(test.key) == "" { + return + } + + // Cleanup all the matching keys. + hdr.Del(test.key) + + got := len(hdr) + expected := 2 + if got != expected { + t.Errorf("Expected: %v, got: %v", expected, got) + } + result = hdr.Values(test.key) + if result != nil { + t.Errorf("Expected to cleanup all matching keys, got: %+v", result) + } + if v := hdr.Get(test.key); v != "" { + t.Errorf("Expected to cleanup all matching keys, got: %v", v) + } + }) + } + }) } diff --git a/test/js_test.go b/test/js_test.go index 1282714c2..00ca85d9d 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io/ioutil" "net" - "net/http" "os" "reflect" "strings" @@ -1780,7 +1779,7 @@ func testJetStreamManagement_GetMsg(t *testing.T, srvs ...*jsServer) { msg := nats.NewMsg("foo.A") data := fmt.Sprintf("A:%d", i) msg.Data = []byte(data) - msg.Header = http.Header{ + msg.Header = nats.Header{ "X-NATS-Key": []string{"123"}, } msg.Header.Add("X-Nats-Test-Data", data) @@ -1898,7 +1897,7 @@ func testJetStreamManagement_GetMsg(t *testing.T, srvs ...*jsServer) { "X-Nats-Test-Data": {"A:1"}, "X-NATS-Key": {"123"}, } - if !reflect.DeepEqual(streamMsg.Header, http.Header(expectedMap)) { + if !reflect.DeepEqual(streamMsg.Header, nats.Header(expectedMap)) { t.Errorf("Expected %v, got: %v", expectedMap, streamMsg.Header) } @@ -1910,7 +1909,7 @@ func testJetStreamManagement_GetMsg(t *testing.T, srvs ...*jsServer) { if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(msg.Header, http.Header(expectedMap)) { + if !reflect.DeepEqual(msg.Header, nats.Header(expectedMap)) { t.Errorf("Expected %v, got: %v", expectedMap, msg.Header) } })