Skip to content

Commit

Permalink
[v14] Fix Assume Start Time Validation (#39324)
Browse files Browse the repository at this point in the history
* Fix Assume Start Time Validation (#39008)

* Remove checking for max assume start time test (same as max duration)
  • Loading branch information
kimlisa committed Mar 23, 2024
1 parent 6056cab commit 880ae37
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 25 deletions.
23 changes: 23 additions & 0 deletions api/types/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
)

Expand Down Expand Up @@ -826,3 +827,25 @@ func NewAccessRequestAllowedPromotions(promotions []*AccessRequestAllowedPromoti
Promotions: promotions,
}
}

// ValidateAssumeStartTime returns error if start time is in an invalid range.
func ValidateAssumeStartTime(assumeStartTime time.Time, accessExpiry time.Time, creationTime time.Time) error {
// Guard against requesting a start time before the request creation time.
if assumeStartTime.Before(creationTime) {
return trace.BadParameter("assume start time has to be after %v", creationTime.Format(time.RFC3339))
}
// Guard against requesting a start time after access expiry.
if assumeStartTime.After(accessExpiry) || assumeStartTime.Equal(accessExpiry) {
return trace.BadParameter("assume start time must be prior to access expiry time at %v",
accessExpiry.Format(time.RFC3339))
}
// Access expiry can be greater than constants.MaxAssumeStartDuration, but start time
// should be on or before constants.MaxAssumeStartDuration.
maxAssumableStartTime := creationTime.Add(constants.MaxAssumeStartDuration)
if maxAssumableStartTime.Before(accessExpiry) && assumeStartTime.After(maxAssumableStartTime) {
return trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumableStartTime.Format(time.RFC3339))
}

return nil
}
55 changes: 55 additions & 0 deletions api/types/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,67 @@ package types

import (
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
)

func TestAssertAccessRequestImplementsResourceWithLabels(t *testing.T) {
ar, err := NewAccessRequest("test", "test", "test")
require.NoError(t, err)
require.Implements(t, (*ResourceWithLabels)(nil), ar)
}

func TestValidateAssumeStartTime(t *testing.T) {
creation := time.Now().UTC()
const day = 24 * time.Hour

expiry := creation.Add(12 * day)
maxAssumeStartDuration := creation.Add(constants.MaxAssumeStartDuration)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "start time too far in the future",
startTime: creation.Add(constants.MaxAssumeStartDuration + day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(tt, err, trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumeStartDuration.Format(time.RFC3339)))
},
},
{
name: "expired start time",
startTime: creation.Add(100 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time must be prior to access expiry time at %v",
expiry.Format(time.RFC3339)))
},
},
{
name: "before creation start time",
startTime: creation.Add(-10 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time has to be after %v",
creation.Format(time.RFC3339)))
},
},
{
name: "valid start time",
startTime: creation.Add(6 * day),
errCheck: require.NoError,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateAssumeStartTime(tc.startTime, expiry, creation)
tc.errCheck(t, err)
})
}
}
178 changes: 178 additions & 0 deletions lib/auth/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ func newAccessRequestTestPack(ctx context.Context, t *testing.T) *accessRequestT
Request: &types.AccessRequestConditions{
Roles: []string{"admins", "superadmins"},
SearchAsRoles: []string{"admins", "superadmins"},
MaxDuration: types.Duration(services.MaxAccessDuration),
},
},
},
Expand Down Expand Up @@ -1215,3 +1216,180 @@ func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) {
})
}
}

func TestAssumeStartTime_CreateAccessRequestV2(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req, err := services.NewAccessRequest(s.requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(s.maxDuration)
req.SetAssumeStartTime(tc.startTime)
_, err = s.requesterClient.CreateAccessRequestV2(ctx, req)
tc.errCheck(t, err)
})
}
}

func TestAssumeStartTime_SubmitAccessReview(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid submission",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
review := types.AccessReviewSubmission{
RequestID: s.createdRequest.GetName(),
Review: types.AccessReview{
Author: "admin",
ProposedState: types.RequestState_APPROVED,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
review.Review.AssumeStartTime = &tc.startTime
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
tc.errCheck(t, err)
if err == nil {
require.Equal(t, tc.startTime, *resp.GetAssumeStartTime())
}
})
}
}

func TestAssumeStartTime_SetAccessRequestState(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid set state",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
update := types.AccessRequestUpdate{
RequestID: s.createdRequest.GetName(),
State: types.RequestState_APPROVED,
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
update.AssumeStartTime = &tc.startTime
err := s.testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
tc.errCheck(t, err)
if err == nil {
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, tc.startTime, *resp[0].GetAssumeStartTime())
}
})
}
}

type accessRequestWithStartTime struct {
testPack *accessRequestTestPack
requesterClient *Client
invalidMaxedAssumeStartTime time.Time
invalidExpiredAssumeStartTime time.Time
validStartTime time.Time
maxDuration time.Time
requesterUserName string
createdRequest types.AccessRequest
}

func createAccessRequestWithStartTime(t *testing.T) accessRequestWithStartTime {
t.Helper()
clock := clockwork.NewFakeClock()
now := clock.Now().UTC()

modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

testPack := newAccessRequestTestPack(ctx, t)

const requesterUserName = "requester"
requester := TestUser(requesterUserName)
requesterClient, err := testPack.tlsServer.NewClient(requester)
require.NoError(t, err)

t.Cleanup(func() { require.NoError(t, requesterClient.Close()) })

day := 24 * time.Hour

maxDuration := now.Add(services.MaxAccessDuration)

invalidMaxedAssumeStartTime := now.Add(constants.MaxAssumeStartDuration + (1 * day))
invalidExpiredAssumeStartTime := now.Add(100 * day)
validStartTime := now.Add(2 * day)

// create the access request object
req, err := services.NewAccessRequest(requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(maxDuration)

req.SetAssumeStartTime(validStartTime)
createdReq, err := requesterClient.CreateAccessRequestV2(ctx, req)
require.NoError(t, err)
require.Equal(t, validStartTime, *createdReq.GetAssumeStartTime())

return accessRequestWithStartTime{
testPack: testPack,
requesterClient: requesterClient,
invalidMaxedAssumeStartTime: invalidMaxedAssumeStartTime,
invalidExpiredAssumeStartTime: invalidExpiredAssumeStartTime,
validStartTime: validStartTime,
maxDuration: maxDuration,
requesterUserName: requesterUserName,
createdRequest: createdReq,
}
}
25 changes: 16 additions & 9 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ const maxAccessRequestReasonSize = 4096
// A day is sometimes 23 hours, sometimes 25 hours, usually 24 hours.
const day = 24 * time.Hour

// maxAccessDuration is the maximum duration that an access request can be
// MaxAccessDuration is the maximum duration that an access request can be
// granted for.
const maxAccessDuration = 7 * day
const MaxAccessDuration = 7 * day

// ValidateAccessRequest validates the AccessRequest and sets default values
func ValidateAccessRequest(ar types.AccessRequest) error {
Expand Down Expand Up @@ -365,8 +365,8 @@ func ValidateAccessPredicates(role types.Role) error {
}

if maxDuration := role.GetAccessRequestConditions(types.Allow).MaxDuration; maxDuration.Duration() != 0 &&
maxDuration.Duration() > maxAccessDuration {
return trace.BadParameter("max access duration must be less or equal 7 days")
maxDuration.Duration() > MaxAccessDuration {
return trace.BadParameter("max access duration must be less than or equal to %v", MaxAccessDuration)
}

return nil
Expand Down Expand Up @@ -414,8 +414,8 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U
req.SetReviews(append(req.GetReviews(), rev))

if rev.AssumeStartTime != nil {
if rev.AssumeStartTime.After(req.GetAccessExpiry()) {
return trace.BadParameter("request start time is after expiry")
if err := types.ValidateAssumeStartTime(*rev.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
req.SetAssumeStartTime(*rev.AssumeStartTime)
}
Expand Down Expand Up @@ -1210,6 +1210,13 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
req.SetAccessExpiry(accessTTL)
// Adjusted max access duration is equal to the access expiry time.
req.SetMaxDuration(accessTTL)

if req.GetAssumeStartTime() != nil {
assumeStartTime := *req.GetAssumeStartTime()
if err := types.ValidateAssumeStartTime(assumeStartTime, accessTTL, req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
}
}

return nil
Expand Down Expand Up @@ -1240,13 +1247,13 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest) (
// This prevents the time drift that can occur as the value is set on the client side.
// TODO(jakule): Replace with MaxAccessDuration that is a duration (5h, 4d etc), and not a point in time.
if req.GetDryRun() {
maxDuration = maxAccessDuration
maxDuration = MaxAccessDuration
} else if maxDuration < 0 {
return 0, trace.BadParameter("invalid maxDuration: must be greater than creation time")
}

if maxDuration > maxAccessDuration {
return 0, trace.BadParameter("max_duration must be less or equal 7 days")
if maxDuration > MaxAccessDuration {
return 0, trace.BadParameter("max_duration must be less than or equal to %v", MaxAccessDuration)
}

minAdjDuration := maxDuration
Expand Down
2 changes: 1 addition & 1 deletion lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ func TestReviewThresholds(t *testing.T) {
propose: approve,
assumeStartTime: clock.Now().UTC().Add(10000 * time.Hour),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.ErrorIs(tt, err, trace.BadParameter("request start time is after expiry"), i...)
require.ErrorContains(tt, err, "assume start time must be prior to access expiry time", i...)
},
},
},
Expand Down
7 changes: 7 additions & 0 deletions lib/services/local/dynamic_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, params
req.SetRoles(params.Roles)
}

if params.AssumeStartTime != nil {
if err := types.ValidateAssumeStartTime(*params.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return nil, trace.Wrap(err)
}
req.SetAssumeStartTime(*params.AssumeStartTime)
}

// approved requests should have a resource expiry which matches
// the underlying access expiry.
if params.State.IsApproved() {
Expand Down
5 changes: 0 additions & 5 deletions tool/tctl/common/access_request_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/asciitable"
"github.com/gravitational/teleport/lib/auth"
Expand Down Expand Up @@ -236,10 +235,6 @@ func (c *AccessRequestCommand) Approve(ctx context.Context, client *auth.Client)
if err != nil {
return trace.BadParameter("parsing assume-start-time (required format RFC3339 e.g 2023-12-12T23:20:50.52Z): %v", err)
}
if time.Until(parsedAssumeStartTime) > constants.MaxAssumeStartDuration {
return trace.BadParameter("assume-start-time too far in future: latest date %q",
parsedAssumeStartTime.Add(constants.MaxAssumeStartDuration).Format(time.RFC3339))
}
assumeStartTime = &parsedAssumeStartTime
}
for _, reqID := range strings.Split(c.reqIDs, ",") {
Expand Down

0 comments on commit 880ae37

Please sign in to comment.