Skip to content

Commit

Permalink
[v9] Validate token for node join script (#14946)
Browse files Browse the repository at this point in the history
Validate token for node join script (#14726)

The token value is provided via the HTTP request
and fed into the node join script. This could allow an attacker
to generate a node-join script with malicious code included.

Fix this by validating that tokens are valid and exist in the backend.

Additionally, we recently added the ability to specify labels via the
node-labels query parameter, which is also user-controlled. Since this
functionality was never integrated in the UI, we remove it here and
will add an alternative implementation in the future.

Also use single quotes in script to prevent expansion
  • Loading branch information
zmb3 committed Jul 27, 2022
1 parent 61322ce commit 5d87b7f
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 98 deletions.
8 changes: 6 additions & 2 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -1439,8 +1439,12 @@ func (a *ServerWithRoles) GetTokens(ctx context.Context, opts ...services.Marsha
}

func (a *ServerWithRoles) GetToken(ctx context.Context, token string) (types.ProvisionToken, error) {
if err := a.action(apidefaults.Namespace, types.KindToken, types.VerbRead); err != nil {
return nil, trace.Wrap(err)
// The Proxy has permission to look up tokens by name in order to validate
// attempts to use the node join script.
if isProxy := a.hasBuiltinRole(types.RoleProxy); !isProxy {
if err := a.action(apidefaults.Namespace, types.KindToken, types.VerbRead); err != nil {
return nil, trace.Wrap(err)
}
}
return a.authServer.GetToken(ctx, token)
}
Expand Down
40 changes: 24 additions & 16 deletions lib/web/join_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request
joinMethod: r.URL.Query().Get("method"),
}

script, err := getJoinScript(settings, h.GetProxyClient())
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
if err != nil {
log.WithError(err).Info("Failed to return the node install script.")
w.Write(scripts.ErrorBashScript)
Expand Down Expand Up @@ -195,7 +195,7 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request,
appURI: uri,
}

script, err := getJoinScript(settings, h.GetProxyClient())
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
if err != nil {
log.WithError(err).Info("Failed to return the app install script.")
w.Write(scripts.ErrorBashScript)
Expand Down Expand Up @@ -228,21 +228,27 @@ func createJoinToken(ctx context.Context, m nodeAPIGetter, roles types.SystemRol
}, nil
}

func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
// Skip decoding validation for IAM tokens since they are generated with a different method
if settings.joinMethod != string(types.JoinMethodIAM) {
// This token does not need to be validated against the backend because it's not used to
// reveal any sensitive information. However, we still need to perform a simple input
// validation check by verifying that the token was auto-generated.
// Auto-generated tokens must be encoded and must have an expected length.
func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter) (string, error) {
switch types.JoinMethod(settings.joinMethod) {
case types.JoinMethodUnspecified, types.JoinMethodToken:
decodedToken, err := hex.DecodeString(settings.token)
if err != nil {
return "", trace.Wrap(err)
}

if len(decodedToken) != auth.TokenLenBytes {
return "", trace.BadParameter("invalid token length")
return "", trace.BadParameter("invalid token %q", decodedToken)
}

case types.JoinMethodIAM:
default:
return "", trace.BadParameter("join method %q is not supported via script", settings.joinMethod)
}

// The provided token can be attacker controlled, so we must validate
// it with the backend before using it to generate the script.
_, err := m.GetToken(ctx, settings.token)
if err != nil {
return "", trace.BadParameter("invalid token")
}

// Get hostname and port from proxy server address.
Expand Down Expand Up @@ -274,9 +280,9 @@ func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
var buf bytes.Buffer
// If app install mode is requested but parameters are blank for some reason,
// we need to return an error.
if settings.appInstallMode == true {
if settings.appInstallMode {
if errs := validation.IsDNS1035Label(settings.appName); len(errs) > 0 {
return "", trace.BadParameter("appName %q must be a valid DNS subdomain: https://gravitational.com/teleport/docs/application-access/#application-name", settings.appName)
return "", trace.BadParameter("appName %q must be a valid DNS subdomain: https://goteleport.com/docs/application-access/guides/connecting-apps/#application-name", settings.appName)
}
if !appURIPattern.MatchString(settings.appURI) {
return "", trace.BadParameter("appURI %q contains invalid characters", settings.appURI)
Expand Down Expand Up @@ -346,16 +352,18 @@ func isSameRuleSet(r1 []*types.TokenRule, r2 []*types.TokenRule) bool {
}

type nodeAPIGetter interface {
// GenerateToken creates a special provisioning token for a new SSH server
// that is valid for ttl period seconds.
// GenerateToken creates a special provisioning token for a new SSH server.
//
// This token is used by SSH server to authenticate with Auth server
// and get signed certificate and private key from the auth server.
// and get a signed certificate.
//
// If token is not supplied, it will be auto generated and returned.
// If TTL is not supplied, token will be valid until removed.
GenerateToken(ctx context.Context, req auth.GenerateTokenRequest) (string, error)

// GetToken looks up a provisioning token.
GetToken(ctx context.Context, token string) (types.ProvisionToken, error)

// GetClusterCACert returns the CAs for the local cluster without signing keys.
GetClusterCACert() (*auth.LocalCAResponse, error)

Expand Down
189 changes: 123 additions & 66 deletions lib/web/join_tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package web

import (
"context"
"encoding/hex"
"testing"
"time"

Expand All @@ -31,6 +32,7 @@ import (
)

func TestCreateNodeJoinToken(t *testing.T) {
t.Parallel()
m := &mockedNodeAPIGetter{}
m.mockGenerateToken = func(ctx context.Context, req auth.GenerateTokenRequest) (string, error) {
return "some-token-id", nil
Expand All @@ -47,6 +49,7 @@ func TestCreateNodeJoinToken(t *testing.T) {
}

func TestGenerateIAMTokenName(t *testing.T) {
t.Parallel()
rule1 := types.TokenRule{
AWSAccount: "100000000000",
AWSARN: "arn:aws:iam:1",
Expand Down Expand Up @@ -84,6 +87,7 @@ func TestGenerateIAMTokenName(t *testing.T) {
}

func TestSortRules(t *testing.T) {
t.Parallel()
tt := []struct {
name string
rules []*types.TokenRule
Expand Down Expand Up @@ -277,81 +281,126 @@ func TestSortRules(t *testing.T) {
}
}

func TestGetNodeJoinScript(t *testing.T) {
m := &mockedNodeAPIGetter{}
m.mockGetProxyServers = func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")

return []types.Server{&s}, nil
}
m.mockGetClusterCACert = func() (*auth.LocalCAResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &auth.LocalCAResponse{TLSCA: fakeBytes}, nil
}
func toHex(s string) string { return hex.EncodeToString([]byte(s)) }

nilTokenLength := scriptSettings{
token: "",
}
func TestGetNodeJoinScript(t *testing.T) {
validToken := "f18da1c9f6630a51e8daf121e7451daa"
validIAMToken := "valid-iam-token"

shortTokenLength := scriptSettings{
token: "f18da1c9f6630a51e8daf121e7451d",
}
m := &mockedNodeAPIGetter{
mockGetProxyServers: func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")

testTokenID := "f18da1c9f6630a51e8daf121e7451daa"
validTokenLength := scriptSettings{
token: testTokenID,
return []types.Server{&s}, nil
},
mockGetClusterCACert: func() (*auth.LocalCAResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &auth.LocalCAResponse{TLSCA: fakeBytes}, nil
},
mockGetToken: func(_ context.Context, token string) (types.ProvisionToken, error) {
if token == validToken || token == validIAMToken {
return &types.ProvisionTokenV2{
Metadata: types.Metadata{
Name: token,
},
}, nil
}
return nil, trace.NotFound("token does not exist")
},
}

// Test zero-value initialization.
script, err := getJoinScript(scriptSettings{}, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

// Test bad token lengths.
script, err = getJoinScript(nilTokenLength, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

script, err = getJoinScript(shortTokenLength, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

// Test valid token format.
script, err = getJoinScript(validTokenLength, m)
require.NoError(t, err)

require.Contains(t, script, testTokenID)
require.Contains(t, script, "test-host")
require.Contains(t, script, "12345678")
require.Contains(t, script, "sha256:")
require.NotContains(t, script, "JOIN_METHOD=\"iam\"")
for _, test := range []struct {
desc string
settings scriptSettings
errAssert require.ErrorAssertionFunc
extraAssertions func(script string)
}{
{
desc: "zero value",
settings: scriptSettings{},
errAssert: require.Error,
},
{
desc: "short token length",
settings: scriptSettings{token: toHex("f18da1c9f6630a51e8daf121e7451d")},
errAssert: require.Error,
},
{
desc: "valid length but does not exist",
settings: scriptSettings{token: toHex("xxxxxxx9f6630a51e8daf121exxxxxxx")},
errAssert: require.Error,
},
{
desc: "valid",
settings: scriptSettings{token: validToken},
errAssert: require.NoError,
extraAssertions: func(script string) {
require.Contains(t, script, validToken)
require.Contains(t, script, "test-host")
require.Contains(t, script, "12345678")
require.Contains(t, script, "sha256:")
require.NotContains(t, script, "JOIN_METHOD='iam'")
},
},
{
desc: "invalid IAM",
settings: scriptSettings{
token: toHex("invalid-iam-token"),
joinMethod: string(types.JoinMethodIAM),
},
errAssert: require.Error,
},
{
desc: "valid iam",
settings: scriptSettings{
token: validIAMToken,
joinMethod: string(types.JoinMethodIAM),
},
errAssert: require.NoError,
extraAssertions: func(script string) {
require.Contains(t, script, "JOIN_METHOD='iam'")
},
},
} {
t.Run(test.desc, func(t *testing.T) {
script, err := getJoinScript(context.Background(), test.settings, m)
test.errAssert(t, err)
if err != nil {
require.Empty(t, script)
}

// Test iam method script
iamToken := scriptSettings{
token: "token length doesnt matter in this case",
joinMethod: string(types.JoinMethodIAM),
if test.extraAssertions != nil {
test.extraAssertions(script)
}
})
}

script, err = getJoinScript(iamToken, m)
require.NoError(t, err)
require.Contains(t, script, "JOIN_METHOD=\"iam\"")
}

func TestGetAppJoinScript(t *testing.T) {
m := &mockedNodeAPIGetter{}
m.mockGetProxyServers = func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")
testTokenID := "f18da1c9f6630a51e8daf121e7451daa"
m := &mockedNodeAPIGetter{
mockGetToken: func(_ context.Context, token string) (types.ProvisionToken, error) {
if token == testTokenID {
return &types.ProvisionTokenV2{
Metadata: types.Metadata{
Name: token,
},
}, nil
}
return nil, trace.NotFound("token does not exist")
},
mockGetProxyServers: func() ([]types.Server, error) {
var s types.ServerV2
s.SetPublicAddr("test-host:12345678")

return []types.Server{&s}, nil
}
m.mockGetClusterCACert = func() (*auth.LocalCAResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &auth.LocalCAResponse{TLSCA: fakeBytes}, nil
return []types.Server{&s}, nil
},
mockGetClusterCACert: func() (*auth.LocalCAResponse, error) {
fakeBytes := []byte(fixtures.SigningCertPEM)
return &auth.LocalCAResponse{TLSCA: fakeBytes}, nil
},
}

testTokenID := "f18da1c9f6630a51e8daf121e7451daa"
badAppName := scriptSettings{
token: testTokenID,
appInstallMode: true,
Expand All @@ -367,11 +416,11 @@ func TestGetAppJoinScript(t *testing.T) {
}

// Test invalid app data.
script, err := getJoinScript(badAppName, m)
script, err := getJoinScript(context.Background(), badAppName, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

script, err = getJoinScript(badAppURI, m)
script, err = getJoinScript(context.Background(), badAppURI, m)
require.Empty(t, script)
require.True(t, trace.IsBadParameter(err))

Expand Down Expand Up @@ -500,7 +549,7 @@ func TestGetAppJoinScript(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
script, err = getJoinScript(tc.settings, m)
script, err = getJoinScript(context.Background(), tc.settings, m)
if tc.shouldError {
require.NotNil(t, err)
require.Equal(t, script, "")
Expand Down Expand Up @@ -622,6 +671,7 @@ type mockedNodeAPIGetter struct {
mockGenerateToken func(ctx context.Context, req auth.GenerateTokenRequest) (string, error)
mockGetProxyServers func() ([]types.Server, error)
mockGetClusterCACert func() (*auth.LocalCAResponse, error)
mockGetToken func(ctx context.Context, token string) (types.ProvisionToken, error)
}

func (m *mockedNodeAPIGetter) GenerateToken(ctx context.Context, req auth.GenerateTokenRequest) (string, error) {
Expand All @@ -647,3 +697,10 @@ func (m *mockedNodeAPIGetter) GetClusterCACert() (*auth.LocalCAResponse, error)

return nil, trace.NotImplemented("mockGetClusterCACert not implemented")
}

func (m *mockedNodeAPIGetter) GetToken(ctx context.Context, token string) (types.ProvisionToken, error) {
if m.mockGetToken != nil {
return m.mockGetToken(ctx, token)
}
return nil, trace.NotImplemented("mockGetToken not implemented")
}
Loading

0 comments on commit 5d87b7f

Please sign in to comment.