Skip to content

Commit

Permalink
Read the bearer token over WS endpoints
Browse files Browse the repository at this point in the history
use the request context, not session

Dont pass websocket by context

lint

resolve some comments

Add TestWSAuthenticateRequest

Close ws in handler

deprecation notices, doc

resolve comments

resolve comments

give a longer read/write deadline

dont set write deadline, ws endpoints never did before and it breaks things

convert frontend to use ws access token

Resolove comments, move to using an explicit state

fix ci

reset read deadline

prettier

update connectToHost

linter

read errors from websocket

missing /ws on ttyWsAddr and fix wrong onmessage

fix race in test

lint

skip TestTerminal as it takes 11 seconds to run

dont skip the test

resolve apiserver comments

Add an AuthenticatedWebSocket class

convert other clients to use AuthenticatedWebSocket

Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket` (#37699)

* Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket`
that automatically goes through Teleport's custom authentication process
before facilitating any caller-defined communication.

This also reverts previous-`WebSocket` users to their original state
(sans the code for passing the bearer token in the query string),
swapping in `AuthenticatedWebSocket` in place of `WebSocket`.

Create a single authnWsUpgrader with a comment justifying why we turn off CORS

recieving to receiving

resolve comments
  • Loading branch information
lxea committed Feb 8, 2024
1 parent 502026f commit 1c5b5b6
Show file tree
Hide file tree
Showing 16 changed files with 672 additions and 105 deletions.
209 changes: 201 additions & 8 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitational/oxy/ratelimit"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -146,6 +147,11 @@ type Handler struct {

// tracer is used to create spans.
tracer oteltrace.Tracer

// wsIODeadline is used to set a deadline for receiving a message from
// an authenticated websocket so unauthenticated sockets dont get left
// open.
wsIODeadline time.Duration
}

// HandlerOption is a functional argument - an option that can be passed
Expand Down Expand Up @@ -348,6 +354,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
ClusterFeatures: cfg.ClusterFeatures,
healthCheckAppServer: cfg.HealthCheckAppServer,
tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb),
wsIODeadline: wsIODeadline,
}

// Check for self-hosted vs Cloud.
Expand Down Expand Up @@ -695,7 +702,10 @@ func (h *Handler) bindDefaultEndpoints() {
h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock))

// active sessions handlers
h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket)
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket)
h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket)
h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions

// Audit events handlers.
Expand Down Expand Up @@ -800,7 +810,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet))
h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle))
// GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=<bearer_token>&username=<username>&width=<width>&height=<height>
h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle))
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle))
// GET /webapi/sites/:site/desktops/:desktopName/connect?username=<username>&width=<width>&height=<height>
h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle))
// GET /webapi/sites/:site/desktopplayback/:sid?access_token=<bearer_token>
h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle))
h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive))
Expand Down Expand Up @@ -858,7 +872,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups))

// WebSocket endpoint for the chat conversation
h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant))
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant))
// WebSocket endpoint for the chat conversation, websocket auth
h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant))

// Sets the title for the conversation.
h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle))
Expand All @@ -877,7 +895,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID))

// Allows executing an arbitrary command on multiple nodes.
h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand))
// Deprecated: The execute/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand))
// Allows executing an arbitrary command on multiple nodes, websocket auth.
h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand))

// Fetches the user's preferences
h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences))
Expand Down Expand Up @@ -2942,6 +2964,7 @@ func (h *Handler) siteNodeConnect(
p httprouter.Params,
sessionCtx *SessionContext,
site reversetunnelclient.RemoteSite,
ws *websocket.Conn,
) (interface{}, error) {
q := r.URL.Query()
params := q.Get("params")
Expand Down Expand Up @@ -3034,6 +3057,7 @@ func (h *Handler) siteNodeConnect(
PROXYSigner: h.cfg.PROXYSigner,
Tracker: tracker,
Clock: h.clock,
WebsocketConn: ws,
}

term, err := NewTerminal(ctx, terminalConfig)
Expand Down Expand Up @@ -3752,6 +3776,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa
// ClusterHandler is a authenticated handler that is called for some existing remote cluster
type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error)

// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster
type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error)

// WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy
// (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster
// or a remote trusted cluster) as specified by the ":site" url parameter.
Expand All @@ -3766,12 +3793,107 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle {
})
}

func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) {
if err == nil {
return
}
errEnvelope := Envelope{
Type: defaults.WebsocketError,
Payload: trace.UserMessage(err),
}
env, err := errEnvelope.Marshal()
if err != nil {
h.log.WithError(err).Error("error marshaling proto")
return
}
if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil {
h.log.WithError(err).Error("error writing proto")
return
}
}

// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket.
// This makes our lives easier in our automated tests. While ordinarily this would be
// used to enforce the same-origin policy, we don't need to worry about that for authenticated
// websockets, which also require a valid bearer token sent over the websocket after upgrade.
// Therefore even if an attacker were to connect to the websocket and trick the browser into
// sending the session cookie, they would still fail to send the bearer token needed to authenticate.
var authnWsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
}

// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated
// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as
// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster)
// as specified by the ":site" url parameter.
//
// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed
func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) {
if websocketAuth {
sctx, ws, site, err := h.authenticateWSRequestWithCluster(w, r, p)
if err != nil {
return nil, trace.Wrap(err)
}
// WS protocol requires the server send a close message
// which should be done by downstream users
defer ws.Close()

if _, err := fn(w, r, p, sctx, site, ws); err != nil {
h.writeErrToWebSocket(ws, err)
}
return nil, nil
}

sctx, site, err := h.authenticateRequestWithCluster(w, r, p)
if err != nil {
return nil, trace.Wrap(err)
}
ws, err := authnWsUpgrader.Upgrade(w, r, nil)
if err != nil {
const errMsg = "Error upgrading to websocket"
h.log.WithError(err).Error(errMsg)
http.Error(w, errMsg, http.StatusInternalServerError)
return nil, nil
}
// WS protocol requires the server send a close message
// which should be done by downstream users
defer ws.Close()
if _, err := fn(w, r, p, sctx, site, ws); err != nil {
h.writeErrToWebSocket(ws, err)
}
return nil, nil
})
}

// authenticateWSRequestWithCluster ensures that a request is
// authenticated to this proxy via websocket, returning the
// *SessionContext (same as AuthenticateRequest), and also grabs the
// remoteSite (which can represent this local cluster or a remote
// trusted cluster) as specified by the ":site" url parameter.
func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) {
sctx, ws, err := h.AuthenticateRequestWS(w, r)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}

site, err := h.getSiteByParams(sctx, p)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}

return sctx, ws, site, nil
}

// authenticateRequestWithCluster ensures that a request is authenticated
// to this proxy, returning the *SessionContext (same as AuthenticateRequest),
// and also grabs the remoteSite (which can represent this local cluster or a
// remote trusted cluster) as specified by the ":site" url parameter.
func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) {
sctx, err := h.AuthenticateRequest(w, r, true)

if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -4089,9 +4211,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error {
return trace.Wrap(err)
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) {
const missingCookieMsg = "missing session cookie"
cookie, err := r.Cookie(websession.CookieName)
if err != nil || (cookie != nil && cookie.Value == "") {
Expand All @@ -4101,11 +4221,22 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch
if err != nil {
return nil, trace.AccessDenied("failed to decode cookie")
}
ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID)
sctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID)
if err != nil {
websession.ClearCookie(w)
return nil, trace.AccessDenied("need auth")
}

return sctx, nil
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
ctx, err := h.validateCookie(w, r)
if err != nil {
return nil, trace.Wrap(err)
}
if checkBearerToken {
creds, err := roundtrip.ParseAuthHeaders(r)
if err != nil {
Expand All @@ -4118,6 +4249,68 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch
return ctx, nil
}

type wsBearerToken struct {
Token string `json:"token"`
}

type wsStatus struct {
Type string `json:"type"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}

// wsIODeadline is used to set a deadline for receiving a message from
// an authenticated websocket so unauthenticated sockets dont get left
// open.
const wsIODeadline = time.Second * 4

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token retrieved from a websocket
func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) {
sctx, err := h.validateCookie(w, r)
if err != nil {
return nil, nil, trace.Wrap(err)
}
ws, err := authnWsUpgrader.Upgrade(w, r, nil)
if err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err)
}

var t wsBearerToken
if err := ws.ReadJSON(&t); err != nil {
return nil, nil, trace.Wrap(err)
}
if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil {
writeErr := ws.WriteJSON(wsStatus{
Type: "create_session_response",
Status: "error",
Message: "invalid token",
})
if writeErr != nil {
log.Errorf("Error while writing invalid token error to websocket: %s", writeErr)
}

return nil, nil, trace.Wrap(err)
}

if err := ws.WriteJSON(wsStatus{
Type: "create_session_response",
Status: "ok",
}); err != nil {
return nil, nil, trace.Wrap(err)
}

// unset the deadline as downstream consumers should handle this themselves.
if err := ws.SetReadDeadline(time.Time{}); err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err)
}

return sctx, ws, nil
}

// ProxyWithRoles returns a reverse tunnel proxy verifying the permissions
// of the given user.
func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) {
Expand Down

0 comments on commit 1c5b5b6

Please sign in to comment.