diff --git a/address/granter.go b/address/granter.go index afa17a7..187de52 100644 --- a/address/granter.go +++ b/address/granter.go @@ -12,6 +12,14 @@ import ( "gopkg.in/m-lab/pipe.v3" ) +// Manager manages access to a device by IP and port. +type Manager interface { + Start(port, device string) error + Grant(ip net.IP) error + Revoke(ip net.IP) error + Stop() ([]byte, error) +} + // IPManager supports granting IP subnet access using iptables or ip6tables. type IPManager struct { *semaphore.Weighted @@ -75,3 +83,26 @@ func cmdForIP(ip net.IP) (string, string) { } return iptables, "/64" } + +// NullManager implements the address.Manager interface while doing nothing. +type NullManager struct{} + +// Grant does nothing with the given ip. +func (r *NullManager) Grant(ip net.IP) error { + return nil +} + +// Revoke does nothing with the given ip. +func (r *NullManager) Revoke(ip net.IP) error { + return nil +} + +// Start does nothing to the given port or device. +func (r *NullManager) Start(port, device string) error { + return nil +} + +// Stop does nothing. +func (r *NullManager) Stop() ([]byte, error) { + return nil, nil +} diff --git a/address/granter_test.go b/address/granter_test.go index c3b22c5..2a9e5a6 100644 --- a/address/granter_test.go +++ b/address/granter_test.go @@ -101,3 +101,22 @@ func TestIPManager(t *testing.T) { } wg.Wait() } + +// TestNullManager verifies that the NullManager does nothing. +func TestNullManager(t *testing.T) { + t.Run("null-manager", func(t *testing.T) { + r := &NullManager{} + if err := r.Grant(net.ParseIP("127.0.0.1")); err != nil { + t.Errorf("NullManager.Grant() error = %v, want nil", err) + } + if err := r.Revoke(net.ParseIP("127.0.0.1")); err != nil { + t.Errorf("NullManager.Revoke() error = %v, want nil", err) + } + if err := r.Start("1234", "eth0"); err != nil { + t.Errorf("NullManager.Start() error = %v, want nil", err) + } + if _, err := r.Stop(); err != nil { + t.Errorf("NullManager.Stop() error = %v, want nil", err) + } + }) +} diff --git a/chanio/reader.go b/chanio/reader.go new file mode 100644 index 0000000..c83667c --- /dev/null +++ b/chanio/reader.go @@ -0,0 +1,16 @@ +package chanio + +import "io" + +// ReadOnce reads from the given reader once and closes the returned channel. All data is discarded. +func ReadOnce(r io.Reader) <-chan struct{} { + c := make(chan struct{}) + go func() { + b := make([]byte, 1) + // Block on read. Will return on EOF or when client sends data (which is discarded). + r.Read(b) + // Close channel to send reader a signal. + close(c) + }() + return c +} diff --git a/chanio/reader_test.go b/chanio/reader_test.go new file mode 100644 index 0000000..42d8b54 --- /dev/null +++ b/chanio/reader_test.go @@ -0,0 +1,25 @@ +package chanio + +import ( + "bytes" + "context" + "testing" + "time" +) + +func TestReadOnce(t *testing.T) { + t.Run("okay", func(t *testing.T) { + b := bytes.NewBufferString("message") + got := ReadOnce(b) + // Absolute timeout. Should never be reached. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + select { + case <-got: + // success + case <-ctx.Done(): + t.Errorf("ReadOnce() = context should never timeout") + } + }) +} diff --git a/cmd/envelope/README.md b/cmd/envelope/README.md index c1587f7..064e9cd 100644 --- a/cmd/envelope/README.md +++ b/cmd/envelope/README.md @@ -94,3 +94,26 @@ eyJhdWQiOlsibWxhYjEubGdhMDMiXSwiZXhwIjoxNTg0NTAyMjEyLCJpc3MiOiJsb2NhdGUubWVhc3Vy nRsYWIubmV0Iiwic3ViIjoiMTI3LjAuMC4yIn0.FZSjjDjWJVGSKzJKJP5Cbaacp8PNqGX5_zETe3SQsXvhlo hGlAlKLdhDkjBDIKttXkO3BL5xyQ09cVGfmbelDA ``` + +### Local development without access tokens + +Start the access envelope server, without requiring access tokens (and +without iptables management; by default these are both required). + +```sh +~/bin/envelope -envelope.token-required=false +``` + +Connect to the local access envelope using `curl`. When tokens are not +required, the default timeout is 60s. After this timeout, the server will +hangup automatically. + +```sh +curl --no-buffer \ + --header "Connection: Upgrade" \ + --header "Upgrade: websocket" \ + --header "Sec-WebSocket-Protocol: net.measurementlab.envelope" \ + --header "Sec-WebSocket-Version: 13" \ + --header "Sec-WebSocket-Key: aGVsbG8K" \ + http://localhost:8880/v0/envelope/access +``` diff --git a/cmd/envelope/main.go b/cmd/envelope/main.go index 5ed1297..6d054a9 100644 --- a/cmd/envelope/main.go +++ b/cmd/envelope/main.go @@ -12,9 +12,12 @@ import ( "time" "github.com/gorilla/handlers" + "github.com/gorilla/websocket" "github.com/justinas/alice" + "gopkg.in/square/go-jose.v2/jwt" "github.com/m-lab/access/address" + "github.com/m-lab/access/chanio" "github.com/m-lab/access/controller" "github.com/m-lab/access/token" "github.com/m-lab/go/flagx" @@ -82,44 +85,32 @@ func customFormat(w io.Writer, p handlers.LogFormatterParams) { } func (env *envelopeHandler) AllowRequest(rw http.ResponseWriter, req *http.Request) { - // AllowRequest is a state-changing POST method. - if req.Method != http.MethodPost { + // Websocket requests must be GET. Also note that AllowRequest is a + // state-changing operation. + if req.Method != http.MethodGet { rw.WriteHeader(http.StatusMethodNotAllowed) return } - cl := controller.GetClaim(req.Context()) - if cl == nil { - // This could happen if the TokenController is disabled. - logx.Debug.Println("missing claim") - rw.WriteHeader(http.StatusInternalServerError) - return - } - - if cl.Subject != env.subject { - logx.Debug.Println("wrong subject claim") - rw.WriteHeader(http.StatusBadRequest) - return - } - - // Tests may run (possibly repeatedly) until the claim expires. - deadline := cl.Expiry.Time() - if deadline.Before(time.Now()) { - logx.Debug.Println("already past expiration") + // Use client remote address as the basis of granting temporary subnet access. + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + logx.Debug.Println("failed to split remote addr:", err) rw.WriteHeader(http.StatusBadRequest) return } - // Use client remote address as the basis of granting temporary subnet access. - host, _, err := net.SplitHostPort(req.RemoteAddr) + // Get deadline based on token claim. + cl := controller.GetClaim(req.Context()) + deadline, err := env.getDeadline(cl) if err != nil { - logx.Debug.Println("failed to split remote addr") + logx.Debug.Println("failed to get deadline:", err) rw.WriteHeader(http.StatusBadRequest) return } - allow := net.ParseIP(host) - err = env.Grant(allow) + remote := net.ParseIP(host) + err = env.Grant(remote) switch { case err == address.ErrMaxConcurrent: logx.Debug.Println("grant limit reached") @@ -131,19 +122,90 @@ func (env *envelopeHandler) AllowRequest(rw http.ResponseWriter, req *http.Reque return } - ctx, cancel := context.WithDeadline(req.Context(), deadline) - defer cancel() - // Keep the lease until: - // * client disconnects. - // * timeout expires. - // * parent context is cancelled. - <-ctx.Done() + conn := setupConn(rw, req) + if conn == nil { + logx.Debug.Println("setup websocket conn failed") + rw.WriteHeader(http.StatusInternalServerError) + // TODO: handle panic. + rtx.PanicOnError(env.Revoke(remote), "Failed to remove rule for "+remote.String()) + return + } + + // At this point, we want to wait for either the deadline (when the envelope + // service closes the connection) or the client to close the websocket conn + // (to signal completion). + env.wait(req.Context(), conn, deadline) + // TODO: handle panic. - rtx.PanicOnError(env.Revoke(allow), "Failed to remove rule for "+allow.String()) + rtx.PanicOnError(env.Revoke(remote), "Failed to remove rule for "+remote.String()) +} + +func (env *envelopeHandler) getDeadline(cl *jwt.Claims) (time.Time, error) { + if cl == nil && requireTokens { + logx.Debug.Println("missing claim") + return time.Time{}, fmt.Errorf("missing claim when tokens required") + } + + if cl == nil { + // This could happen if tokens are not required. + return time.Now().Add(time.Minute), nil + } + + if cl.Subject != env.subject { + logx.Debug.Println("wrong subject claim") + return time.Time{}, fmt.Errorf("wrong claim subject") + } + + // Tests may run (possibly repeatedly) until the claim expires. + deadline := cl.Expiry.Time() + if deadline.Before(time.Now()) { + logx.Debug.Println("already past expiration") + return time.Time{}, fmt.Errorf("already past claim expiration") + } + return deadline, nil +} + +func setupConn(writer http.ResponseWriter, request *http.Request) *websocket.Conn { + headers := http.Header{} + headers.Add("Sec-WebSocket-Protocol", "net.measurementlab.envelope") + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + // Allow cross origin resource sharing + return true + }, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + conn, err := upgrader.Upgrade(writer, request, headers) + if err != nil { + logx.Debug.Println("failed to upgrade", err) + return nil + } + return conn +} + +func (env *envelopeHandler) wait(ctx context.Context, c *websocket.Conn, dl time.Time) { + // NOTE: we are explicitly ignoring the error value from SetDeadline. + // Any error there will show up on read below. + c.SetReadDeadline(dl) + c.SetWriteDeadline(dl) + ctxdl, cancel := context.WithDeadline(ctx, dl) + defer cancel() + // Clean up client connection upon return. + defer c.Close() + + // Keep the client connection open and the IP grant enabled until: + // * parent context expires. + // * context deadline expires. + // * client disconnects (or writes data that we don't expect). + select { + case <-ctxdl.Done(): + case <-chanio.ReadOnce(c.UnderlyingConn()): + } } var mainCtx, mainCancel = context.WithCancel(context.Background()) -var getEnvelopeHandler = func(subject string, mgr *address.IPManager) envelopeHandler { +var getEnvelopeHandler = func(subject string, mgr address.Manager) envelopeHandler { return envelopeHandler{ manager: mgr, subject: subject, @@ -161,7 +223,12 @@ func main() { verify, err := token.NewVerifier(verifyKeys.Get()...) rtx.Must(err, "Failed to create token verifier") - mgr := address.NewIPManager(maxIPs) + var mgr address.Manager + if requireTokens { + mgr = address.NewIPManager(maxIPs) + } else { + mgr = &address.NullManager{} + } env := getEnvelopeHandler(subject, mgr) ctl, _ := controller.Setup(mainCtx, verify, requireTokens, machine) // Handle all requests using the alice http handler chaining library. @@ -172,6 +239,12 @@ func main() { srv := &http.Server{ Addr: listenAddr, Handler: ac.Then(mux), + + // NOTE: prevent connections from staying open indefinitely. + // And, these timeouts are reset for individual clients that + // negotiate the websocket connection. + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, } _, port, err := net.SplitHostPort(listenAddr) rtx.Must(err, "failed to split listen address: %q", listenAddr) diff --git a/cmd/envelope/main_test.go b/cmd/envelope/main_test.go index dd088ee..8898032 100644 --- a/cmd/envelope/main_test.go +++ b/cmd/envelope/main_test.go @@ -12,10 +12,13 @@ import ( "net/http/httptest" "net/url" "os" + "strings" "testing" "time" "github.com/gorilla/handlers" + "github.com/gorilla/websocket" + "github.com/justinas/alice" "gopkg.in/square/go-jose.v2/jwt" "github.com/m-lab/access/address" @@ -60,6 +63,7 @@ func Test_main(t *testing.T) { mainCtx, mainCancel = context.WithCancel(context.Background()) certFile = "testdata/insecure-cert.pem" keyFile = "testdata/insecure-key.pem" + requireTokens = false // use NullManager. mainCancel() main() } @@ -75,31 +79,48 @@ func (f *fakeManager) Revoke(ip net.IP) error { return nil } -func Test_envelopeHandler_AllowRequest(t *testing.T) { +// Test_envelopeHandler_AllowRequest_Errors exercises error paths that cannot be +// reached using the websocket client package directly. +func Test_envelopeHandler_AllowRequest_Errors(t *testing.T) { subject := "envelope" tests := []struct { - name string - param string - method string - remote string - code int - claim *jwt.Claims - grantErr error + name string + method string + remote string + code int + allowEmptyClaim bool + claim *jwt.Claims + grantErr error }{ { name: "error-bad-method", - method: http.MethodGet, + method: http.MethodPost, code: http.StatusMethodNotAllowed, }, { name: "error-no-claim-found", - method: http.MethodPost, - code: http.StatusInternalServerError, + method: http.MethodGet, + code: http.StatusBadRequest, + remote: "127.0.0.2:1234", + }, + { + name: "error-allow-empty-claim", + method: http.MethodGet, + code: http.StatusBadRequest, + allowEmptyClaim: true, + remote: "127.0.0.2:1234", + }, + { + name: "error-remote-host-corrupt", + method: http.MethodGet, + code: http.StatusBadRequest, + remote: "thisisnotanip-1234", }, { name: "error-claim-subject-is-invalid", - method: http.MethodPost, + method: http.MethodGet, code: http.StatusBadRequest, + remote: "127.0.0.2:1234", claim: &jwt.Claims{ Issuer: "locate", Subject: "wrong-subject", @@ -107,7 +128,7 @@ func Test_envelopeHandler_AllowRequest(t *testing.T) { }, { name: "error-claim-is-already-expired", - method: http.MethodPost, + method: http.MethodGet, code: http.StatusBadRequest, remote: "127.0.0.2:1234", claim: &jwt.Claims{ @@ -118,7 +139,7 @@ func Test_envelopeHandler_AllowRequest(t *testing.T) { }, { name: "error-grant-ip-failure-max-concurrent", - method: http.MethodPost, + method: http.MethodGet, code: http.StatusServiceUnavailable, remote: "127.0.0.2:1234", claim: &jwt.Claims{ @@ -128,21 +149,9 @@ func Test_envelopeHandler_AllowRequest(t *testing.T) { }, grantErr: address.ErrMaxConcurrent, }, - { - name: "error-split-host-port-failure", - method: http.MethodPost, - code: http.StatusBadRequest, - remote: "corrupt-remote-ip", - claim: &jwt.Claims{ - Issuer: "locate", - Subject: subject, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), - }, - grantErr: address.ErrMaxConcurrent, - }, { name: "error-grant-ip-failure-", - method: http.MethodPost, + method: http.MethodGet, code: http.StatusInternalServerError, remote: "127.0.0.2:1234", claim: &jwt.Claims{ @@ -152,31 +161,22 @@ func Test_envelopeHandler_AllowRequest(t *testing.T) { }, grantErr: errors.New("generic grant error"), }, - { - name: "success", - method: http.MethodPost, - code: http.StatusOK, - remote: "127.0.0.2:1234", - claim: &jwt.Claims{ - Issuer: "locate", - Subject: subject, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Second)), - }, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rw := httptest.NewRecorder() - req := httptest.NewRequest(tt.method, "/v0/envelope/access"+tt.param, nil) + req := httptest.NewRequest(tt.method, "/v0/envelope/access", nil) env := &envelopeHandler{ manager: &fakeManager{ grantErr: tt.grantErr, }, subject: "envelope", } + requireTokens = !tt.allowEmptyClaim if tt.claim != nil { req = req.Clone(controller.SetClaim(req.Context(), tt.claim)) } + req.RemoteAddr = tt.remote env.AllowRequest(rw, req) @@ -187,6 +187,77 @@ func Test_envelopeHandler_AllowRequest(t *testing.T) { } } +func Test_envelopeHandler_AllowRequest_Websocket(t *testing.T) { + subject := "envelope" + tests := []struct { + name string + code int + sleep time.Duration + claim *jwt.Claims + }{ + { + name: "success-exit-fast", + code: http.StatusSwitchingProtocols, + claim: &jwt.Claims{ + Issuer: "locate", + Subject: subject, + // Expiry: set below. + }, + }, + { + name: "success-wait-for-timeout", + code: http.StatusSwitchingProtocols, + claim: &jwt.Claims{ + Issuer: "locate", + Subject: subject, + // Expiry: set below. + }, + sleep: 2 * time.Second, // Force delay to create timeout. + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + env := &envelopeHandler{ + manager: &fakeManager{}, + subject: "envelope", + } + requireTokens = true + // Create a synthetic token claim handler that adds the unit test + // claim to the request context. It is simpler to inject the claim + // instead of invoking the PKI needed to sign and verify a real claim. + addClaims := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Unconditionally assign the unit test claim to the request context. + tt.claim.Expiry = jwt.NewNumericDate(time.Now().Add(time.Second)) + next.ServeHTTP(w, r.Clone(controller.SetClaim(r.Context(), tt.claim))) + }) + } + // Create a handler chain that adds a claim (above) and then handles the request. + ac := alice.New(addClaims).Then(http.HandlerFunc(env.AllowRequest)) + // Setup the fake server. + mux := http.NewServeMux() + mux.Handle("/v0/envelope/access", ac) + srv := httptest.NewServer(mux) + defer srv.Close() + + // Dial a websocket connection. + headers := http.Header{} + headers.Add("Sec-WebSocket-Protocol", "net.measurementlab.envelope") + c, resp, _ := websocket.DefaultDialer.Dial( + strings.Replace(srv.URL, "http", "ws", 1)+"/v0/envelope/access", headers) + + // Check the response code. + if tt.code != resp.StatusCode { + t.Errorf("AllowRequest() wrong status code; got %d, want %d", resp.StatusCode, tt.code) + } + if c != nil { + time.Sleep(tt.sleep) + c.Close() + } + }) + } +} + func Test_customFormat(t *testing.T) { tests := []struct { name string