Skip to content

Commit

Permalink
nsqd: support for multiple Auth HTTP Methods
Browse files Browse the repository at this point in the history
Adds simple config option and flag to allow for auth to occur via POST
request in addition to GET. Rationale: Errors from net/http requests are
bubbled to nsqd when there is an error during authentication, such as if
the nsq authentication server is unavailable. These errors include the
full path, including any GET parameter, thus causing the authentication
secret to be logged. This does not occur by default for the POST body
thus helping protect secrets in transit between nsqd and the
authentication server.
  • Loading branch information
danrjohnson authored and mreiferson committed May 12, 2024
1 parent 62fa868 commit 0db445c
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 28 deletions.
1 change: 1 addition & 0 deletions apps/nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet {

authHTTPAddresses := app.StringArray{}
flagSet.Var(&authHTTPAddresses, "auth-http-address", "<addr>:<port> or a full url to query auth server (may be given multiple times)")
flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests")
flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)")
flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)")
flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)")
Expand Down
23 changes: 15 additions & 8 deletions internal/auth/authorizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool {
}

func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var retErr error
start := rand.Int()
n := len(authd)
for i := 0; i < n; i++ {
a := authd[(i+start)%n]
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout)
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod)
if err != nil {
es := fmt.Sprintf("failed to auth against %s - %s", a, err)
if retErr != nil {
Expand All @@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName
}

func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var authState State
v := url.Values{}
v.Set("remote_ip", remoteIP)
if tlsEnabled {
Expand All @@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin

var endpoint string
if strings.Contains(authd, "://") {
endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
endpoint = authd
} else {
endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
endpoint = fmt.Sprintf("http://%s/auth", authd)
}

var authState State
client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
if httpRequestMethod == "post" {
if err := client.POSTV1(endpoint, v, &authState); err != nil {
return nil, err
}
} else {
endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode())
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
}
}

// validation on response
Expand Down
4 changes: 2 additions & 2 deletions internal/clusterinfo/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ func (c *ClusterInfo) nsqlookupdPOST(addrs []string, uri string, qs string) erro
for _, addr := range addrs {
endpoint := fmt.Sprintf("http://%s/%s?%s", addr, uri, qs)
c.logf("CI: querying nsqlookupd %s", endpoint)
err := c.client.POSTV1(endpoint)
err := c.client.POSTV1(endpoint, nil, nil)
if err != nil {
errs = append(errs, err)
}
Expand All @@ -894,7 +894,7 @@ func (c *ClusterInfo) producersPOST(pl Producers, uri string, qs string) error {
for _, p := range pl {
endpoint := fmt.Sprintf("http://%s/%s?%s", p.HTTPAddress(), uri, qs)
c.logf("CI: querying nsqd %s", endpoint)
err := c.client.POSTV1(endpoint)
err := c.client.POSTV1(endpoint, nil, nil)
if err != nil {
errs = append(errs, err)
}
Expand Down
21 changes: 19 additions & 2 deletions internal/http_api/api_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_api

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -86,14 +87,26 @@ retry:

// PostV1 is a helper function to perform a V1 HTTP request
// and parse our NSQ daemon's expected response format, with deadlines.
func (c *Client) POSTV1(endpoint string) error {
func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error {
retry:
req, err := http.NewRequest("POST", endpoint, nil)
var reqBody io.Reader
if data != nil {
js, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint)
}
reqBody = bytes.NewBuffer(js)
}

req, err := http.NewRequest("POST", endpoint, reqBody)
if err != nil {
return err
}

req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
if reqBody != nil {
req.Header.Add("Content-Type", "application/json")
}

resp, err := c.c.Do(req)
if err != nil {
Expand All @@ -116,6 +129,10 @@ retry:
return fmt.Errorf("got response %s %q", resp.Status, body)
}

if v != nil {
return json.Unmarshal(body, &v)
}

return nil
}

Expand Down
4 changes: 3 additions & 1 deletion nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error {
remoteIP, tlsEnabled, commonName, c.AuthSecret,
c.nsqd.clientTLSConfig,
c.nsqd.getOpts().HTTPClientConnectTimeout,
c.nsqd.getOpts().HTTPClientRequestTimeout)
c.nsqd.getOpts().HTTPClientRequestTimeout,
c.nsqd.getOpts().AuthHTTPRequestMethod,
)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions nsqd/nsqd.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) {
}
n.clientTLSConfig = clientTLSConfig

if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" {
return nil, errors.New("--auth-http-request-method must be post or get")
}

for _, v := range opts.E2EProcessingLatencyPercentiles {
if v <= 0 || v > 1 {
return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v)
Expand Down
6 changes: 3 additions & 3 deletions nsqd/nsqd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,11 @@ func TestCluster(t *testing.T) {
test.Nil(t, err)

url := fmt.Sprintf("http://%s/topic/create?topic=%s", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

url = fmt.Sprintf("http://%s/channel/create?topic=%s&channel=ch", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

// allow some time for nsqd to push info to nsqlookupd
Expand Down Expand Up @@ -394,7 +394,7 @@ func TestCluster(t *testing.T) {
test.Equal(t, "ch", lr.Channels[0])

url = fmt.Sprintf("http://%s/topic/delete?topic=%s", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

// allow some time for nsqd to push info to nsqlookupd
Expand Down
2 changes: 2 additions & 0 deletions nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Options struct {
BroadcastHTTPPort int `flag:"broadcast-http-port"`
NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"`
AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"`
AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"`
HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"`
HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"`

Expand Down Expand Up @@ -110,6 +111,7 @@ func NewOptions() *Options {

NSQLookupdTCPAddresses: make([]string, 0),
AuthHTTPAddresses: make([]string, 0),
AuthHTTPRequestMethod: "get",

HTTPClientConnectTimeout: 2 * time.Second,
HTTPClientRequestTimeout: 5 * time.Second,
Expand Down
38 changes: 29 additions & 9 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -1476,24 +1477,30 @@ func TestClientAuth(t *testing.T) {
authSuccess := ""
tlsEnabled := false
commonName := ""
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
httpAuthRequestMethod := "get"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// now one that will succeed
authResponse = `{"ttl":10, "authorizations":
[{"topic":"test", "channels":[".*"], "permissions":["subscribe","publish"]}]
}`
authError = ""
authSuccess = `{"identity":"","identity_url":"","permission_count":1}`
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// one with TLS enabled
tlsEnabled = true
commonName = "test.local"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// test POST based authentication
httpAuthRequestMethod = "post"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

}

func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string,
authSuccess string, tlsEnabled bool, commonName string) {
authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) {
var err error
var expectedRemoteIP string
expectedTLS := "false"
Expand All @@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError

authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("in test auth handler %s", r.RequestURI)
r.ParseForm()
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedTLS, r.Form.Get("tls"))
test.Equal(t, commonName, r.Form.Get("common_name"))
test.Equal(t, authSecret, r.Form.Get("secret"))
test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method))

var values url.Values

if r.Method == "POST" {
err = json.NewDecoder(r.Body).Decode(&values)
if err != nil {
t.Error(err)
}
} else {
r.ParseForm()
values = r.Form
}
test.Equal(t, expectedRemoteIP, values.Get("remote_ip"))
test.Equal(t, expectedTLS, values.Get("tls"))
test.Equal(t, commonName, values.Get("common_name"))
test.Equal(t, authSecret, values.Get("secret"))
fmt.Fprint(w, authResponse)
}))
defer authd.Close()
Expand All @@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError
opts.Logger = test.NewTestLogger(t)
opts.LogLevel = LOG_DEBUG
opts.AuthHTTPAddresses = []string{addr.Host}
opts.AuthHTTPRequestMethod = httpAuthRequestMethod
if tlsEnabled {
opts.TLSCert = "./test/certs/server.pem"
opts.TLSKey = "./test/certs/server.key"
Expand Down
6 changes: 3 additions & 3 deletions nsqlookupd/nsqlookupd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func TestTombstoneRecover(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

pr := ProducersDoc{}
Expand Down Expand Up @@ -263,7 +263,7 @@ func TestTombstoneUnregister(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

pr := ProducersDoc{}
Expand Down Expand Up @@ -348,7 +348,7 @@ func TestTombstonedNodes(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

producers, _ = ci.GetLookupdProducers(lookupdHTTPAddrs)
Expand Down

0 comments on commit 0db445c

Please sign in to comment.