Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Assume Start Time Validation #39008

Merged
merged 5 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 23 additions & 0 deletions api/types/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
)

Expand Down Expand Up @@ -826,3 +827,25 @@ func NewAccessRequestAllowedPromotions(promotions []*AccessRequestAllowedPromoti
Promotions: promotions,
}
}

// ValidateAssumeStartTime returns error if start time is in an invalid range.
func ValidateAssumeStartTime(assumeStartTime time.Time, accessExpiry time.Time, creationTime time.Time) error {
// Guard against requesting a start time before the request creation time.
if assumeStartTime.Before(creationTime) {
return trace.BadParameter("assume start time has to be after %v", creationTime.Format(time.RFC3339))
}
// Guard against requesting a start time after access expiry.
if assumeStartTime.After(accessExpiry) || assumeStartTime.Equal(accessExpiry) {
return trace.BadParameter("assume start time must be prior to access expiry time at %v",
accessExpiry.Format(time.RFC3339))
}
// Access expiry can be greater than constants.MaxAssumeStartDuration, but start time
// should be on or before constants.MaxAssumeStartDuration.
maxAssumableStartTime := creationTime.Add(constants.MaxAssumeStartDuration)
if maxAssumableStartTime.Before(accessExpiry) && assumeStartTime.After(maxAssumableStartTime) {
return trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumableStartTime.Format(time.RFC3339))
}

return nil
}
55 changes: 55 additions & 0 deletions api/types/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,67 @@ package types

import (
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
)

func TestAssertAccessRequestImplementsResourceWithLabels(t *testing.T) {
ar, err := NewAccessRequest("test", "test", "test")
require.NoError(t, err)
require.Implements(t, (*ResourceWithLabels)(nil), ar)
}

func TestValidateAssumeStartTime(t *testing.T) {
creation := time.Now().UTC()
const day = 24 * time.Hour

expiry := creation.Add(12 * day)
maxAssumeStartDuration := creation.Add(constants.MaxAssumeStartDuration)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "start time too far in the future",
startTime: creation.Add(constants.MaxAssumeStartDuration + day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(tt, err, trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumeStartDuration.Format(time.RFC3339)))
},
},
{
name: "expired start time",
startTime: creation.Add(100 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time must be prior to access expiry time at %v",
expiry.Format(time.RFC3339)))
},
},
{
name: "before creation start time",
startTime: creation.Add(-10 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time has to be after %v",
creation.Format(time.RFC3339)))
},
},
{
name: "valid start time",
startTime: creation.Add(6 * day),
errCheck: require.NoError,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateAssumeStartTime(tc.startTime, expiry, creation)
tc.errCheck(t, err)
})
}
}
201 changes: 201 additions & 0 deletions lib/auth/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func newAccessRequestTestPack(ctx context.Context, t *testing.T) *accessRequestT
Request: &types.AccessRequestConditions{
Roles: []string{"admins", "superadmins"},
SearchAsRoles: []string{"admins", "superadmins"},
MaxDuration: types.Duration(services.MaxAccessDuration),
},
},
},
Expand Down Expand Up @@ -1518,3 +1519,203 @@ func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) {
})
}
}

func TestAssumeStartTime_CreateAccessRequestV2(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req, err := services.NewAccessRequest(s.requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(s.maxDuration)
req.SetAssumeStartTime(tc.startTime)
_, err = s.requesterClient.CreateAccessRequestV2(ctx, req)
tc.errCheck(t, err)
})
}
}

func TestAssumeStartTime_SubmitAccessReview(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid submission",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
review := types.AccessReviewSubmission{
RequestID: s.createdRequest.GetName(),
Review: types.AccessReview{
Author: "admin",
ProposedState: types.RequestState_APPROVED,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
review.Review.AssumeStartTime = &tc.startTime
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
tc.errCheck(t, err)
if err == nil {
require.Equal(t, tc.startTime, *resp.GetAssumeStartTime())
}
})
}
}

func TestAssumeStartTime_SetAccessRequestState(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid set state",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
update := types.AccessRequestUpdate{
RequestID: s.createdRequest.GetName(),
State: types.RequestState_APPROVED,
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
update.AssumeStartTime = &tc.startTime
err := s.testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
tc.errCheck(t, err)
if err == nil {
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, tc.startTime, *resp[0].GetAssumeStartTime())
}
})
}
}

type accessRequestWithStartTime struct {
testPack *accessRequestTestPack
requesterClient *Client
invalidMaxedAssumeStartTime time.Time
invalidExpiredAssumeStartTime time.Time
validStartTime time.Time
maxDuration time.Time
requesterUserName string
createdRequest types.AccessRequest
}

func createAccessRequestWithStartTime(t *testing.T) accessRequestWithStartTime {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: You can add t.Helper() to mark this as a helper function.

t.Helper()

modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

testPack := newAccessRequestTestPack(ctx, t)

const requesterUserName = "requester"
requester := TestUser(requesterUserName)
requesterClient, err := testPack.tlsServer.NewClient(requester)
require.NoError(t, err)

t.Cleanup(func() { require.NoError(t, requesterClient.Close()) })

now := time.Now().UTC()
day := 24 * time.Hour

maxDuration := time.Now().UTC().Add(12 * day)

invalidMaxedAssumeStartTime := now.Add(constants.MaxAssumeStartDuration + (1 * day))
invalidExpiredAssumeStartTime := now.Add(100 * day)
validStartTime := now.Add(6 * day)

// create the access request object
req, err := services.NewAccessRequest(requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(maxDuration)

req.SetAssumeStartTime(validStartTime)
createdReq, err := requesterClient.CreateAccessRequestV2(ctx, req)
require.NoError(t, err)
require.Equal(t, validStartTime, *createdReq.GetAssumeStartTime())

return accessRequestWithStartTime{
testPack: testPack,
requesterClient: requesterClient,
invalidMaxedAssumeStartTime: invalidMaxedAssumeStartTime,
invalidExpiredAssumeStartTime: invalidExpiredAssumeStartTime,
validStartTime: validStartTime,
maxDuration: maxDuration,
requesterUserName: requesterUserName,
createdRequest: createdReq,
}
}
25 changes: 16 additions & 9 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ const (
// A day is sometimes 23 hours, sometimes 25 hours, usually 24 hours.
day = 24 * time.Hour

// maxAccessDuration is the maximum duration that an access request can be
// MaxAccessDuration is the maximum duration that an access request can be
// granted for.
maxAccessDuration = 14 * day
MaxAccessDuration = 14 * day

// requestTTL is the the TTL for an access request, i.e. the amount of time that
// the access request can be reviewed. Defaults to 1 week.
Expand Down Expand Up @@ -377,8 +377,8 @@ func ValidateAccessPredicates(role types.Role) error {
}

if maxDuration := role.GetAccessRequestConditions(types.Allow).MaxDuration; maxDuration.Duration() != 0 &&
maxDuration.Duration() > maxAccessDuration {
return trace.BadParameter("max access duration must be less than or equal to %v", maxAccessDuration)
maxDuration.Duration() > MaxAccessDuration {
return trace.BadParameter("max access duration must be less than or equal to %v", MaxAccessDuration)
}

return nil
Expand Down Expand Up @@ -426,8 +426,8 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U
req.SetReviews(append(req.GetReviews(), rev))

if rev.AssumeStartTime != nil {
if rev.AssumeStartTime.After(req.GetAccessExpiry()) {
return trace.BadParameter("request start time is after expiry")
if err := types.ValidateAssumeStartTime(*rev.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
req.SetAssumeStartTime(*rev.AssumeStartTime)
}
Expand Down Expand Up @@ -1222,6 +1222,13 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
req.SetAccessExpiry(accessTTL)
// Adjusted max access duration is equal to the access expiry time.
req.SetMaxDuration(accessTTL)

if req.GetAssumeStartTime() != nil {
assumeStartTime := *req.GetAssumeStartTime()
if err := types.ValidateAssumeStartTime(assumeStartTime, accessTTL, req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
}
}

return nil
Expand All @@ -1242,13 +1249,13 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest) (
// For dry run requests, use the maximum possible duration.
// This prevents the time drift that can occur as the value is set on the client side.
if req.GetDryRun() {
maxDuration = maxAccessDuration
maxDuration = MaxAccessDuration
} else if maxDuration < 0 {
return 0, trace.BadParameter("invalid maxDuration: must be greater than creation time")
}

if maxDuration > maxAccessDuration {
return 0, trace.BadParameter("max_duration must be less than or equal to %v", maxAccessDuration)
if maxDuration > MaxAccessDuration {
return 0, trace.BadParameter("max_duration must be less than or equal to %v", MaxAccessDuration)
}

minAdjDuration := maxDuration
Expand Down
2 changes: 1 addition & 1 deletion lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ func TestReviewThresholds(t *testing.T) {
propose: approve,
assumeStartTime: clock.Now().UTC().Add(10000 * time.Hour),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.ErrorIs(tt, err, trace.BadParameter("request start time is after expiry"), i...)
require.ErrorContains(tt, err, "assume start time must be prior to access expiry time", i...)
},
},
},
Expand Down