Skip to content

Commit

Permalink
Fix Assume Start Time Validation (#39008) (#39322)
Browse files Browse the repository at this point in the history
  • Loading branch information
kimlisa committed Mar 14, 2024
1 parent 740fa8e commit de8b1e5
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 25 deletions.
23 changes: 23 additions & 0 deletions api/types/access_request.go
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
)

Expand Down Expand Up @@ -826,3 +827,25 @@ func NewAccessRequestAllowedPromotions(promotions []*AccessRequestAllowedPromoti
Promotions: promotions,
}
}

// ValidateAssumeStartTime returns error if start time is in an invalid range.
func ValidateAssumeStartTime(assumeStartTime time.Time, accessExpiry time.Time, creationTime time.Time) error {
// Guard against requesting a start time before the request creation time.
if assumeStartTime.Before(creationTime) {
return trace.BadParameter("assume start time has to be after %v", creationTime.Format(time.RFC3339))
}
// Guard against requesting a start time after access expiry.
if assumeStartTime.After(accessExpiry) || assumeStartTime.Equal(accessExpiry) {
return trace.BadParameter("assume start time must be prior to access expiry time at %v",
accessExpiry.Format(time.RFC3339))
}
// Access expiry can be greater than constants.MaxAssumeStartDuration, but start time
// should be on or before constants.MaxAssumeStartDuration.
maxAssumableStartTime := creationTime.Add(constants.MaxAssumeStartDuration)
if maxAssumableStartTime.Before(accessExpiry) && assumeStartTime.After(maxAssumableStartTime) {
return trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumableStartTime.Format(time.RFC3339))
}

return nil
}
55 changes: 55 additions & 0 deletions api/types/access_request_test.go
Expand Up @@ -18,12 +18,67 @@ package types

import (
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
)

func TestAssertAccessRequestImplementsResourceWithLabels(t *testing.T) {
ar, err := NewAccessRequest("test", "test", "test")
require.NoError(t, err)
require.Implements(t, (*ResourceWithLabels)(nil), ar)
}

func TestValidateAssumeStartTime(t *testing.T) {
creation := time.Now().UTC()
const day = 24 * time.Hour

expiry := creation.Add(12 * day)
maxAssumeStartDuration := creation.Add(constants.MaxAssumeStartDuration)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "start time too far in the future",
startTime: creation.Add(constants.MaxAssumeStartDuration + day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(tt, err, trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumeStartDuration.Format(time.RFC3339)))
},
},
{
name: "expired start time",
startTime: creation.Add(100 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time must be prior to access expiry time at %v",
expiry.Format(time.RFC3339)))
},
},
{
name: "before creation start time",
startTime: creation.Add(-10 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time has to be after %v",
creation.Format(time.RFC3339)))
},
},
{
name: "valid start time",
startTime: creation.Add(6 * day),
errCheck: require.NoError,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateAssumeStartTime(tc.startTime, expiry, creation)
tc.errCheck(t, err)
})
}
}
201 changes: 201 additions & 0 deletions lib/auth/access_request_test.go
Expand Up @@ -118,6 +118,7 @@ func newAccessRequestTestPack(ctx context.Context, t *testing.T) *accessRequestT
Request: &types.AccessRequestConditions{
Roles: []string{"admins", "superadmins"},
SearchAsRoles: []string{"admins", "superadmins"},
MaxDuration: types.Duration(services.MaxAccessDuration),
},
},
},
Expand Down Expand Up @@ -1224,3 +1225,203 @@ func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) {
})
}
}

func TestAssumeStartTime_CreateAccessRequestV2(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req, err := services.NewAccessRequest(s.requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(s.maxDuration)
req.SetAssumeStartTime(tc.startTime)
_, err = s.requesterClient.CreateAccessRequestV2(ctx, req)
tc.errCheck(t, err)
})
}
}

func TestAssumeStartTime_SubmitAccessReview(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid submission",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
review := types.AccessReviewSubmission{
RequestID: s.createdRequest.GetName(),
Review: types.AccessReview{
Author: "admin",
ProposedState: types.RequestState_APPROVED,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
review.Review.AssumeStartTime = &tc.startTime
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
tc.errCheck(t, err)
if err == nil {
require.Equal(t, tc.startTime, *resp.GetAssumeStartTime())
}
})
}
}

func TestAssumeStartTime_SetAccessRequestState(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid set state",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
update := types.AccessRequestUpdate{
RequestID: s.createdRequest.GetName(),
State: types.RequestState_APPROVED,
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
update.AssumeStartTime = &tc.startTime
err := s.testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
tc.errCheck(t, err)
if err == nil {
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, tc.startTime, *resp[0].GetAssumeStartTime())
}
})
}
}

type accessRequestWithStartTime struct {
testPack *accessRequestTestPack
requesterClient *Client
invalidMaxedAssumeStartTime time.Time
invalidExpiredAssumeStartTime time.Time
validStartTime time.Time
maxDuration time.Time
requesterUserName string
createdRequest types.AccessRequest
}

func createAccessRequestWithStartTime(t *testing.T) accessRequestWithStartTime {
t.Helper()

modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

testPack := newAccessRequestTestPack(ctx, t)

const requesterUserName = "requester"
requester := TestUser(requesterUserName)
requesterClient, err := testPack.tlsServer.NewClient(requester)
require.NoError(t, err)

t.Cleanup(func() { require.NoError(t, requesterClient.Close()) })

now := time.Now().UTC()
day := 24 * time.Hour

maxDuration := time.Now().UTC().Add(12 * day)

invalidMaxedAssumeStartTime := now.Add(constants.MaxAssumeStartDuration + (1 * day))
invalidExpiredAssumeStartTime := now.Add(100 * day)
validStartTime := now.Add(6 * day)

// create the access request object
req, err := services.NewAccessRequest(requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(maxDuration)

req.SetAssumeStartTime(validStartTime)
createdReq, err := requesterClient.CreateAccessRequestV2(ctx, req)
require.NoError(t, err)
require.Equal(t, validStartTime, *createdReq.GetAssumeStartTime())

return accessRequestWithStartTime{
testPack: testPack,
requesterClient: requesterClient,
invalidMaxedAssumeStartTime: invalidMaxedAssumeStartTime,
invalidExpiredAssumeStartTime: invalidExpiredAssumeStartTime,
validStartTime: validStartTime,
maxDuration: maxDuration,
requesterUserName: requesterUserName,
createdRequest: createdReq,
}
}
23 changes: 15 additions & 8 deletions lib/services/access_request.go
Expand Up @@ -49,7 +49,7 @@ const day = 24 * time.Hour

// maxAccessDuration is the maximum duration that an access request can be
// granted for.
const maxAccessDuration = 14 * day
const MaxAccessDuration = 14 * day

// ValidateAccessRequest validates the AccessRequest and sets default values
func ValidateAccessRequest(ar types.AccessRequest) error {
Expand Down Expand Up @@ -368,8 +368,8 @@ func ValidateAccessPredicates(role types.Role) error {
}

if maxDuration := role.GetAccessRequestConditions(types.Allow).MaxDuration; maxDuration.Duration() != 0 &&
maxDuration.Duration() > maxAccessDuration {
return trace.BadParameter("max access duration must be less than or equal to %v", maxAccessDuration)
maxDuration.Duration() > MaxAccessDuration {
return trace.BadParameter("max access duration must be less than or equal to %v", MaxAccessDuration)
}

return nil
Expand Down Expand Up @@ -417,8 +417,8 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U
req.SetReviews(append(req.GetReviews(), rev))

if rev.AssumeStartTime != nil {
if rev.AssumeStartTime.After(req.GetAccessExpiry()) {
return trace.BadParameter("request start time is after expiry")
if err := types.ValidateAssumeStartTime(*rev.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
req.SetAssumeStartTime(*rev.AssumeStartTime)
}
Expand Down Expand Up @@ -1213,6 +1213,13 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
req.SetAccessExpiry(accessTTL)
// Adjusted max access duration is equal to the access expiry time.
req.SetMaxDuration(accessTTL)

if req.GetAssumeStartTime() != nil {
assumeStartTime := *req.GetAssumeStartTime()
if err := types.ValidateAssumeStartTime(assumeStartTime, accessTTL, req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
}
}

return nil
Expand All @@ -1233,13 +1240,13 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest) (
// For dry run requests, use the maximum possible duration.
// This prevents the time drift that can occur as the value is set on the client side.
if req.GetDryRun() {
maxDuration = maxAccessDuration
maxDuration = MaxAccessDuration
} else if maxDuration < 0 {
return 0, trace.BadParameter("invalid maxDuration: must be greater than creation time")
}

if maxDuration > maxAccessDuration {
return 0, trace.BadParameter("max_duration must be less than or equal to %v", maxAccessDuration)
if maxDuration > MaxAccessDuration {
return 0, trace.BadParameter("max_duration must be less than or equal to %v", MaxAccessDuration)
}

minAdjDuration := maxDuration
Expand Down
2 changes: 1 addition & 1 deletion lib/services/access_request_test.go
Expand Up @@ -644,7 +644,7 @@ func TestReviewThresholds(t *testing.T) {
propose: approve,
assumeStartTime: clock.Now().UTC().Add(10000 * time.Hour),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.ErrorIs(tt, err, trace.BadParameter("request start time is after expiry"), i...)
require.ErrorContains(tt, err, "assume start time must be prior to access expiry time", i...)
},
},
},
Expand Down
7 changes: 7 additions & 0 deletions lib/services/local/dynamic_access.go
Expand Up @@ -109,6 +109,13 @@ func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, params
req.SetRoles(params.Roles)
}

if params.AssumeStartTime != nil {
if err := types.ValidateAssumeStartTime(*params.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return nil, trace.Wrap(err)
}
req.SetAssumeStartTime(*params.AssumeStartTime)
}

// approved requests should have a resource expiry which matches
// the underlying access expiry.
if params.State.IsApproved() {
Expand Down

0 comments on commit de8b1e5

Please sign in to comment.