Skip to content

Commit

Permalink
[v10] Validate token for node join script (#14944)
Browse files Browse the repository at this point in the history
The token value is provided via the HTTP request
and fed into the node join script. This could allow an attacker
to generate a node-join script with malicious code included.

Fix this by validating that tokens are valid and exist in the backend.

Additionally, we recently added the ability to specify labels via the
node-labels query parameter, which is also user-controlled. Since this
functionality was never integrated in the UI, we remove it here and
will add an alternative implementation in the future.

Also use single quotes in script to prevent expansion.
  • Loading branch information
zmb3 committed Jul 27, 2022
1 parent 0d782ee commit 055d531
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 110 deletions.
8 changes: 6 additions & 2 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -1537,8 +1537,12 @@ func (a *ServerWithRoles) GetTokens(ctx context.Context) ([]types.ProvisionToken
}

func (a *ServerWithRoles) GetToken(ctx context.Context, token string) (types.ProvisionToken, error) {
if err := a.action(apidefaults.Namespace, types.KindToken, types.VerbRead); err != nil {
return nil, trace.Wrap(err)
// The Proxy has permission to look up tokens by name in order to validate
// attempts to use the node join script.
if isProxy := a.hasBuiltinRole(types.RoleProxy); !isProxy {
if err := a.action(apidefaults.Namespace, types.KindToken, types.VerbRead); err != nil {
return nil, trace.Wrap(err)
}
}
return a.authServer.GetToken(ctx, token)
}
Expand Down
51 changes: 24 additions & 27 deletions lib/web/join_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ type scriptSettings struct {
appName string
appURI string
joinMethod string
nodeLabels string
}

func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params, ctx *SessionContext) (interface{}, error) {
Expand Down Expand Up @@ -149,23 +148,14 @@ func (h *Handler) createNodeTokenHandle(w http.ResponseWriter, r *http.Request,

func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) {
scripts.SetScriptHeaders(w.Header())
queryValues := r.URL.Query()

nodeLabels, err := url.QueryUnescape(queryValues.Get("node-labels"))
if err != nil {
log.WithField("query-param", "node-labels").WithError(err).Debug("Failed to return the app install script.")
w.Write(scripts.ErrorBashScript)
return nil, nil
}

settings := scriptSettings{
token: params.ByName("token"),
appInstallMode: false,
joinMethod: r.URL.Query().Get("method"),
nodeLabels: nodeLabels,
}

script, err := getJoinScript(settings, h.GetProxyClient())
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
if err != nil {
log.WithError(err).Info("Failed to return the node install script.")
w.Write(scripts.ErrorBashScript)
Expand Down Expand Up @@ -206,7 +196,7 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request,
appURI: uri,
}

script, err := getJoinScript(settings, h.GetProxyClient())
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
if err != nil {
log.WithError(err).Info("Failed to return the app install script.")
w.Write(scripts.ErrorBashScript)
Expand Down Expand Up @@ -239,21 +229,27 @@ func createJoinToken(ctx context.Context, m nodeAPIGetter, roles types.SystemRol
}, nil
}

func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
// Skip decoding validation for IAM tokens since they are generated with a different method
if settings.joinMethod != string(types.JoinMethodIAM) {
// This token does not need to be validated against the backend because it's not used to
// reveal any sensitive information. However, we still need to perform a simple input
// validation check by verifying that the token was auto-generated.
// Auto-generated tokens must be encoded and must have an expected length.
func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter) (string, error) {
switch types.JoinMethod(settings.joinMethod) {
case types.JoinMethodUnspecified, types.JoinMethodToken:
decodedToken, err := hex.DecodeString(settings.token)
if err != nil {
return "", trace.Wrap(err)
}

if len(decodedToken) != auth.TokenLenBytes {
return "", trace.BadParameter("invalid token length")
return "", trace.BadParameter("invalid token %q", decodedToken)
}

case types.JoinMethodIAM:
default:
return "", trace.BadParameter("join method %q is not supported via script", settings.joinMethod)
}

// The provided token can be attacker controlled, so we must validate
// it with the backend before using it to generate the script.
_, err := m.GetToken(ctx, settings.token)
if err != nil {
return "", trace.BadParameter("invalid token")
}

// Get hostname and port from proxy server address.
Expand Down Expand Up @@ -285,9 +281,9 @@ func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
var buf bytes.Buffer
// If app install mode is requested but parameters are blank for some reason,
// we need to return an error.
if settings.appInstallMode == true {
if settings.appInstallMode {
if errs := validation.IsDNS1035Label(settings.appName); len(errs) > 0 {
return "", trace.BadParameter("appName %q must be a valid DNS subdomain: https://gravitational.com/teleport/docs/application-access/#application-name", settings.appName)
return "", trace.BadParameter("appName %q must be a valid DNS subdomain: https://goteleport.com/docs/application-access/guides/connecting-apps/#application-name", settings.appName)
}
if !appURIPattern.MatchString(settings.appURI) {
return "", trace.BadParameter("appURI %q contains invalid characters", settings.appURI)
Expand All @@ -311,7 +307,6 @@ func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
"appName": settings.appName,
"appURI": settings.appURI,
"joinMethod": settings.joinMethod,
"nodeLabels": settings.nodeLabels,
})
if err != nil {
return "", trace.Wrap(err)
Expand Down Expand Up @@ -364,16 +359,18 @@ func isSameRuleSet(r1 []*types.TokenRule, r2 []*types.TokenRule) bool {
}

type nodeAPIGetter interface {
// GenerateToken creates a special provisioning token for a new SSH server
// that is valid for ttl period seconds.
// GenerateToken creates a special provisioning token for a new SSH server.
//
// This token is used by SSH server to authenticate with Auth server
// and get signed certificate and private key from the auth server.
// and get a signed certificate.
//
// If token is not supplied, it will be auto generated and returned.
// If TTL is not supplied, token will be valid until removed.
GenerateToken(ctx context.Context, req *proto.GenerateTokenRequest) (string, error)

// GetToken looks up a provisioning token.
GetToken(ctx context.Context, token string) (types.ProvisionToken, error)

// GetClusterCACert returns the CAs for the local cluster without signing keys.
GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error)

Expand Down
189 changes: 123 additions & 66 deletions lib/web/join_tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package web

import (
"context"
"encoding/hex"
"testing"
"time"

Expand All @@ -31,6 +32,7 @@ import (
)

func TestCreateNodeJoinToken(t *testing.T) {
t.Parallel()
m := &mockedNodeAPIGetter{}
m.mockGenerateToken = func(ctx context.Context, req *proto.GenerateTokenRequest) (string, error) {
return "some-token-id", nil
Expand All @@ -47,6 +49,7 @@ func TestCreateNodeJoinToken(t *testing.T) {
}

func TestGenerateIAMTokenName(t *testing.T) {
t.Parallel()
rule1 := types.TokenRule{
AWSAccount: "100000000000",
AWSARN: "arn:aws:iam:1",
Expand Down Expand Up @@ -84,6 +87,7 @@ func TestGenerateIAMTokenName(t *testing.T) {
}

func TestSortRules(t *testing.T) {
t.Parallel()
tt := []struct {
name string
rules []*types.TokenRule
Expand Down Expand Up @@ -277,81 +281,126 @@ func TestSortRules(t *testing.T) {
}
}

func TestGetNodeJoinScript(t *testing.T) {
m := &mockedNodeAPIGetter{}
m.mockGetProxyServers = func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")

return []types.Server{&s}, nil
}
m.mockGetClusterCACert = func(ctx context.Context) (*proto.GetClusterCACertResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &proto.GetClusterCACertResponse{TLSCA: fakeBytes}, nil
}
func toHex(s string) string { return hex.EncodeToString([]byte(s)) }

nilTokenLength := scriptSettings{
token: "",
}
func TestGetNodeJoinScript(t *testing.T) {
validToken := "f18da1c9f6630a51e8daf121e7451daa"
validIAMToken := "valid-iam-token"

shortTokenLength := scriptSettings{
token: "f18da1c9f6630a51e8daf121e7451d",
}
m := &mockedNodeAPIGetter{
mockGetProxyServers: func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")

testTokenID := "f18da1c9f6630a51e8daf121e7451daa"
validTokenLength := scriptSettings{
token: testTokenID,
return []types.Server{&s}, nil
},
mockGetClusterCACert: func(context.Context) (*proto.GetClusterCACertResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &proto.GetClusterCACertResponse{TLSCA: fakeBytes}, nil
},
mockGetToken: func(_ context.Context, token string) (types.ProvisionToken, error) {
if token == validToken || token == validIAMToken {
return &types.ProvisionTokenV2{
Metadata: types.Metadata{
Name: token,
},
}, nil
}
return nil, trace.NotFound("token does not exist")
},
}

// Test zero-value initialization.
script, err := getJoinScript(scriptSettings{}, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

// Test bad token lengths.
script, err = getJoinScript(nilTokenLength, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

script, err = getJoinScript(shortTokenLength, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

// Test valid token format.
script, err = getJoinScript(validTokenLength, m)
require.NoError(t, err)

require.Contains(t, script, testTokenID)
require.Contains(t, script, "test-host")
require.Contains(t, script, "12345678")
require.Contains(t, script, "sha256:")
require.NotContains(t, script, "JOIN_METHOD=\"iam\"")
for _, test := range []struct {
desc string
settings scriptSettings
errAssert require.ErrorAssertionFunc
extraAssertions func(script string)
}{
{
desc: "zero value",
settings: scriptSettings{},
errAssert: require.Error,
},
{
desc: "short token length",
settings: scriptSettings{token: toHex("f18da1c9f6630a51e8daf121e7451d")},
errAssert: require.Error,
},
{
desc: "valid length but does not exist",
settings: scriptSettings{token: toHex("xxxxxxx9f6630a51e8daf121exxxxxxx")},
errAssert: require.Error,
},
{
desc: "valid",
settings: scriptSettings{token: validToken},
errAssert: require.NoError,
extraAssertions: func(script string) {
require.Contains(t, script, validToken)
require.Contains(t, script, "test-host")
require.Contains(t, script, "12345678")
require.Contains(t, script, "sha256:")
require.NotContains(t, script, "JOIN_METHOD='iam'")
},
},
{
desc: "invalid IAM",
settings: scriptSettings{
token: toHex("invalid-iam-token"),
joinMethod: string(types.JoinMethodIAM),
},
errAssert: require.Error,
},
{
desc: "valid iam",
settings: scriptSettings{
token: validIAMToken,
joinMethod: string(types.JoinMethodIAM),
},
errAssert: require.NoError,
extraAssertions: func(script string) {
require.Contains(t, script, "JOIN_METHOD='iam'")
},
},
} {
t.Run(test.desc, func(t *testing.T) {
script, err := getJoinScript(context.Background(), test.settings, m)
test.errAssert(t, err)
if err != nil {
require.Empty(t, script)
}

// Test iam method script
iamToken := scriptSettings{
token: "token length doesnt matter in this case",
joinMethod: string(types.JoinMethodIAM),
if test.extraAssertions != nil {
test.extraAssertions(script)
}
})
}

script, err = getJoinScript(iamToken, m)
require.NoError(t, err)
require.Contains(t, script, "JOIN_METHOD=\"iam\"")
}

func TestGetAppJoinScript(t *testing.T) {
m := &mockedNodeAPIGetter{}
m.mockGetProxyServers = func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")
testTokenID := "f18da1c9f6630a51e8daf121e7451daa"
m := &mockedNodeAPIGetter{
mockGetToken: func(_ context.Context, token string) (types.ProvisionToken, error) {
if token == testTokenID {
return &types.ProvisionTokenV2{
Metadata: types.Metadata{
Name: token,
},
}, nil
}
return nil, trace.NotFound("token does not exist")
},
mockGetProxyServers: func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")

return []types.Server{&s}, nil
}
m.mockGetClusterCACert = func(ctx context.Context) (*proto.GetClusterCACertResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &proto.GetClusterCACertResponse{TLSCA: fakeBytes}, nil
return []types.Server{&s}, nil
},
mockGetClusterCACert: func(context.Context) (*proto.GetClusterCACertResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &proto.GetClusterCACertResponse{TLSCA: fakeBytes}, nil
},
}

testTokenID := "f18da1c9f6630a51e8daf121e7451daa"
badAppName := scriptSettings{
token: testTokenID,
appInstallMode: true,
Expand All @@ -367,11 +416,11 @@ func TestGetAppJoinScript(t *testing.T) {
}

// Test invalid app data.
script, err := getJoinScript(badAppName, m)
script, err := getJoinScript(context.Background(), badAppName, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

script, err = getJoinScript(badAppURI, m)
script, err = getJoinScript(context.Background(), badAppURI, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

Expand Down Expand Up @@ -500,7 +549,7 @@ func TestGetAppJoinScript(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
script, err = getJoinScript(tc.settings, m)
script, err = getJoinScript(context.Background(), tc.settings, m)
if tc.shouldError {
require.NotNil(t, err)
require.Equal(t, script, "")
Expand Down Expand Up @@ -622,6 +671,7 @@ type mockedNodeAPIGetter struct {
mockGenerateToken func(ctx context.Context, req *proto.GenerateTokenRequest) (string, error)
mockGetProxyServers func() ([]types.Server, error)
mockGetClusterCACert func(ctx context.Context) (*proto.GetClusterCACertResponse, error)
mockGetToken func(ctx context.Context, token string) (types.ProvisionToken, error)
}

func (m *mockedNodeAPIGetter) GenerateToken(ctx context.Context, req *proto.GenerateTokenRequest) (string, error) {
Expand All @@ -647,3 +697,10 @@ func (m *mockedNodeAPIGetter) GetClusterCACert(ctx context.Context) (*proto.GetC

return nil, trace.NotImplemented("mockGetClusterCACert not implemented")
}

func (m *mockedNodeAPIGetter) GetToken(ctx context.Context, token string) (types.ProvisionToken, error) {
if m.mockGetToken != nil {
return m.mockGetToken(ctx, token)
}
return nil, trace.NotImplemented("mockGetToken not implemented")
}
Loading

0 comments on commit 055d531

Please sign in to comment.