Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v15] Read the bearer token over websocket endpoints instead of query parameter #37915

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
218 changes: 210 additions & 8 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitational/oxy/ratelimit"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -151,6 +152,11 @@ type Handler struct {

// tracer is used to create spans.
tracer oteltrace.Tracer

// wsIODeadline is used to set a deadline for receiving a message from
// an authenticated websocket so unauthenticated sockets dont get left
// open.
wsIODeadline time.Duration
}

// HandlerOption is a functional argument - an option that can be passed
Expand Down Expand Up @@ -365,6 +371,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
ClusterFeatures: cfg.ClusterFeatures,
healthCheckAppServer: cfg.HealthCheckAppServer,
tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb),
wsIODeadline: wsIODeadline,
}

// Check for self-hosted vs Cloud.
Expand Down Expand Up @@ -720,7 +727,10 @@ func (h *Handler) bindDefaultEndpoints() {
h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock))

// active sessions handlers
h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket)
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket)
h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket)
h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions

// Audit events handlers.
Expand Down Expand Up @@ -828,9 +838,17 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet))
h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle))
// GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=<bearer_token>&username=<username>&width=<width>&height=<height>
h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle))
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle))
// GET /webapi/sites/:site/desktops/:desktopName/connect?username=<username>&width=<width>&height=<height>
h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle))
// GET /webapi/sites/:site/desktopplayback/:sid?access_token=<bearer_token>
h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle))
// Deprecated: The desktopplayback/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle))
// // GET /webapi/sites/:site/desktopplayback/:sid/ws
h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle))
h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive))

// GET a Connection Diagnostics by its name
Expand Down Expand Up @@ -889,7 +907,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups))

// WebSocket endpoint for the chat conversation
h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant))
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant))
// WebSocket endpoint for the chat conversation, websocket auth
h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant))

// Sets the title for the conversation.
h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle))
Expand All @@ -908,7 +930,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID))

// Allows executing an arbitrary command on multiple nodes.
h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand))
// Deprecated: The execute/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand))
// Allows executing an arbitrary command on multiple nodes, websocket auth.
h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand))

// Fetches the user's preferences
h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences))
Expand Down Expand Up @@ -2941,6 +2967,7 @@ func (h *Handler) siteNodeConnect(
p httprouter.Params,
sessionCtx *SessionContext,
site reversetunnelclient.RemoteSite,
ws *websocket.Conn,
) (interface{}, error) {
q := r.URL.Query()
params := q.Get("params")
Expand Down Expand Up @@ -3033,6 +3060,7 @@ func (h *Handler) siteNodeConnect(
PROXYSigner: h.cfg.PROXYSigner,
Tracker: tracker,
PresenceChecker: h.cfg.PresenceChecker,
WebsocketConn: ws,
}

term, err := NewTerminal(ctx, terminalConfig)
Expand Down Expand Up @@ -3731,6 +3759,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa
// ClusterHandler is a authenticated handler that is called for some existing remote cluster
type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error)

// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster
type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error)

// WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy
// (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster
// or a remote trusted cluster) as specified by the ":site" url parameter.
Expand All @@ -3745,12 +3776,108 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle {
})
}

func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) {
if err == nil {
return
}
errEnvelope := Envelope{
Type: defaults.WebsocketError,
Payload: trace.UserMessage(err),
}
env, err := errEnvelope.Marshal()
if err != nil {
h.log.WithError(err).Error("error marshaling proto")
return
}
if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil {
h.log.WithError(err).Error("error writing proto")
return
}
}

// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket.
// This makes our lives easier in our automated tests. While ordinarily this would be
// used to enforce the same-origin policy, we don't need to worry about that for authenticated
// websockets, which also require a valid bearer token sent over the websocket after upgrade.
// Therefore even if an attacker were to connect to the websocket and trick the browser into
// sending the session cookie, they would still fail to send the bearer token needed to authenticate.
var authnWsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
}

// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated
// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as
// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster)
// as specified by the ":site" url parameter.
//
// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed
func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) {
var sctx *SessionContext
var ws *websocket.Conn
var site reversetunnelclient.RemoteSite
var err error

if websocketAuth {
sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p)
} else {
sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p)
}

if err != nil {
return nil, trace.Wrap(err)
}
// WS protocol requires the server send a close message
// which should be done by downstream users
defer ws.Close()
if _, err := fn(w, r, p, sctx, site, ws); err != nil {
h.writeErrToWebSocket(ws, err)
}
return nil, nil
})
}

// authenticateWSRequestWithCluster ensures that a request is
// authenticated to this proxy via websocket, returning the
// *SessionContext (same as AuthenticateRequest), and also grabs the
// remoteSite (which can represent this local cluster or a remote
// trusted cluster) as specified by the ":site" url parameter.
func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) {
sctx, ws, err := h.AuthenticateRequestWS(w, r)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}

site, err := h.getSiteByParams(sctx, p)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}

return sctx, ws, site, nil
}

// TODO(lxea): remove once the deprecated websocket handlers are removed
func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) {
sctx, site, err := h.authenticateRequestWithCluster(w, r, p)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}
ws, err := authnWsUpgrader.Upgrade(w, r, nil)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}
return sctx, ws, site, nil
}

// authenticateRequestWithCluster ensures that a request is authenticated
// to this proxy, returning the *SessionContext (same as AuthenticateRequest),
// and also grabs the remoteSite (which can represent this local cluster or a
// remote trusted cluster) as specified by the ":site" url parameter.
func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) {
sctx, err := h.AuthenticateRequest(w, r, true)

if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -4068,9 +4195,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error {
return trace.Wrap(err)
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) {
const missingCookieMsg = "missing session cookie"
cookie, err := r.Cookie(websession.CookieName)
if err != nil || (cookie != nil && cookie.Value == "") {
Expand All @@ -4085,6 +4210,17 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch
clearSessionCookies((w))
return nil, trace.AccessDenied("need auth")
}

return sctx, nil
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
sctx, err := h.validateCookie(w, r)
if err != nil {
return nil, trace.Wrap(err)
}
if checkBearerToken {
creds, err := roundtrip.ParseAuthHeaders(r)
if err != nil {
Expand Down Expand Up @@ -4137,6 +4273,72 @@ func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader
return ctx, nil
}

type wsBearerToken struct {
Token string `json:"token"`
}

type wsStatus struct {
Type string `json:"type"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}

// wsIODeadline is used to set a deadline for receiving a message from
// an authenticated websocket so unauthenticated sockets dont get left
// open.
const wsIODeadline = time.Second * 4

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token retrieved from a websocket
func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) {
sctx, err := h.validateCookie(w, r)
if err != nil {
return nil, nil, trace.Wrap(err)
}
ws, err := authnWsUpgrader.Upgrade(w, r, nil)
if err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err)
}

var t wsBearerToken
if err := ws.ReadJSON(&t); err != nil {
return nil, nil, trace.Wrap(err)
}
if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil {
writeErr := ws.WriteJSON(wsStatus{
Type: "create_session_response",
Status: "error",
Message: "invalid token",
})
if writeErr != nil {
log.Errorf("Error while writing invalid token error to websocket: %s", writeErr)
}

return nil, nil, trace.Wrap(err)
}

if err := ws.WriteJSON(wsStatus{
Type: "create_session_response",
Status: "ok",
}); err != nil {
return nil, nil, trace.Wrap(err)
}

// unset the deadline as downstream consumers should handle this themselves.
if err := ws.SetReadDeadline(time.Time{}); err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err)
}

if err := parseMFAResponseFromRequest(r); err != nil {
return nil, nil, trace.Wrap(err)
}

return sctx, ws, nil
}

// ProxyWithRoles returns a reverse tunnel proxy verifying the permissions
// of the given user.
func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) {
Expand Down