Skip to content

Commit

Permalink
Address CRs
Browse files Browse the repository at this point in the history
  • Loading branch information
kimlisa committed Mar 6, 2024
1 parent 002199c commit 9852ae8
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 134 deletions.
4 changes: 2 additions & 2 deletions api/types/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,11 +832,11 @@ func NewAccessRequestAllowedPromotions(promotions []*AccessRequestAllowedPromoti
func ValidateAssumeStartTime(assumeStartTime time.Time, accessExpiry time.Time, creationTime time.Time) error {
// Guard against requesting a start time before the request creation time.
if assumeStartTime.Before(creationTime) {
return trace.BadParameter("assume start time has to be greater than: %q", creationTime.Format(time.RFC3339))
return trace.BadParameter("assume start time has to be after: %q", creationTime.Format(time.RFC3339))
}
// Guard against requesting a start time after access expiry.
if assumeStartTime.After(accessExpiry) || assumeStartTime.Equal(accessExpiry) {
return trace.BadParameter("assume start time cannot equal or exceed access expiry time at: %q",
return trace.BadParameter("assume start time must be prior to access expiry time at: %q",
accessExpiry.Format(time.RFC3339))
}
// Access expiry can be greater than constants.MaxAssumeStartDuration, but start time
Expand Down
72 changes: 45 additions & 27 deletions api/types/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
Expand All @@ -34,36 +33,55 @@ func TestAssertAccessRequestImplementsResourceWithLabels(t *testing.T) {
}

func TestValidateAssumeStartTime(t *testing.T) {
clock := clockwork.NewFakeClock()
creation := clock.Now().UTC()
creation := time.Now().UTC()
day := 24 * time.Hour

expiry := creation.Add(12 * day)
maxAssumeStartDuration := creation.Add(constants.MaxAssumeStartDuration)

// Start time too far in the future.
invalidMaxedAssumeStartTime := creation.Add(constants.MaxAssumeStartDuration + (1 * day))
err := ValidateAssumeStartTime(invalidMaxedAssumeStartTime, expiry, creation)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorIs(t, err, trace.BadParameter("assume start time is too far in the future, latest time allowed %q",
maxAssumeStartDuration.Format(time.RFC3339)))
testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "start time too far in the future",
startTime: creation.Add(constants.MaxAssumeStartDuration + (1 * day)),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorIs(tt, err, trace.BadParameter("assume start time is too far in the future, latest time allowed %q",
maxAssumeStartDuration.Format(time.RFC3339)))
},
},
{
name: "expired start time",
startTime: creation.Add(100 * day),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorIs(t, err, trace.BadParameter("assume start time must be prior to access expiry time at: %q",
expiry.Format(time.RFC3339)))
},
},
{
name: "before creation start time",
startTime: creation.Add(-10 * day),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorIs(t, err, trace.BadParameter("assume start time has to be after: %q",
creation.Format(time.RFC3339)))
},
},
{
name: "valid start time",
startTime: creation.Add(6 * day),
errCheck: require.NoError,
},
}

// Expired start time.
invalidExpiredAssumeStartTime := creation.Add(100 * day)
err = ValidateAssumeStartTime(invalidExpiredAssumeStartTime, expiry, creation)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorIs(t, err, trace.BadParameter("assume start time cannot equal or exceed access expiry time at: %q",
expiry.Format(time.RFC3339)))

// Before creation start time.
invalidBeforeCreationStartTime := creation.Add(-10 * day)
err = ValidateAssumeStartTime(invalidBeforeCreationStartTime, expiry, creation)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorIs(t, err, trace.BadParameter("assume start time has to be greater than: %q",
creation.Format(time.RFC3339)))

// Valid start time.
validStartTime := creation.Add(6 * day)
err = ValidateAssumeStartTime(validStartTime, expiry, creation)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateAssumeStartTime(tc.startTime, expiry, creation)
tc.errCheck(t, err)
})
}
}
264 changes: 168 additions & 96 deletions lib/auth/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,153 @@ func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) {
}
}

func TestAccessRequest_AssumeStartTime(t *testing.T) {
func TestAssumeStartTime_CreateAccessRequestV2(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time must be prior to access expiry time")
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, err := services.NewAccessRequest(s.requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(s.maxDuration)
req.SetAssumeStartTime(tc.startTime)
_, err = s.requesterClient.CreateAccessRequestV2(ctx, req)
tc.errCheck(t, err)
})
}
}

func TestAssumeStartTime_SubmitAccessReview(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time must be prior to access expiry time")
},
},
{
name: "valid submission",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
review := types.AccessReviewSubmission{
RequestID: s.createdRequest.GetName(),
Review: types.AccessReview{
Author: "admin",
ProposedState: types.RequestState_APPROVED,
},
}
for _, tc := range testCases {
review.Review.AssumeStartTime = &tc.startTime
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
tc.errCheck(t, err)
if err == nil {
require.Equal(t, tc.startTime, *resp.GetAssumeStartTime())
}
}
}

func TestAssumeStartTime_SetAccessRequestState(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "too far in the future",
startTime: s.invalidMaxedAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time is too far in the future")
},
},
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time must be prior to access expiry time")
},
},
{
name: "valid set state",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
update := types.AccessRequestUpdate{
RequestID: s.createdRequest.GetName(),
State: types.RequestState_APPROVED,
}
for _, tc := range testCases {
update.AssumeStartTime = &tc.startTime
err := s.testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
tc.errCheck(t, err)
if err == nil {
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, tc.startTime, *resp[0].GetAssumeStartTime())
}
}
}

type accessRequestWithStartTime struct {
testPack *accessRequestTestPack
requesterClient *Client
invalidMaxedAssumeStartTime time.Time
invalidExpiredAssumeStartTime time.Time
validStartTime time.Time
maxDuration time.Time
requesterUserName string
createdRequest types.AccessRequest
}

func createAccessRequestWithStartTime(t *testing.T) accessRequestWithStartTime {
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
Expand All @@ -1534,107 +1680,33 @@ func TestAccessRequest_AssumeStartTime(t *testing.T) {

t.Cleanup(func() { require.NoError(t, requesterClient.Close()) })

clock := clockwork.NewFakeClock()
now := clock.Now().UTC()
now := time.Now().UTC()
day := 24 * time.Hour

maxDuration := clock.Now().UTC().Add(12 * day)
maxDuration := time.Now().UTC().Add(12 * day)

invalidMaxedAssumeStartTime := now.Add(constants.MaxAssumeStartDuration + (1 * day))
invalidExpiredAssumeStartTime := now.Add(100 * day)
validStartTime := now.Add(6 * day)

var createdReq types.AccessRequest

t.Run("CreateAccessRequest, request a specific start time", func(t *testing.T) {
// create the access request object
req, err := services.NewAccessRequest(requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(maxDuration)

// invalid, greater than constants.MaxAssumeStartDuration
req.SetAssumeStartTime(invalidMaxedAssumeStartTime)
_, err = requesterClient.CreateAccessRequestV2(ctx, req)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time is too far in the future")

// invalid, after access expiry time
req.SetAssumeStartTime(invalidExpiredAssumeStartTime)
_, err = requesterClient.CreateAccessRequestV2(ctx, req)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time cannot equal or exceed access expiry time")

// valid start time
req.SetAssumeStartTime(validStartTime)
createdReq, err = requesterClient.CreateAccessRequestV2(ctx, req)
require.NoError(t, err)
require.Equal(t, validStartTime, *createdReq.GetAssumeStartTime())
})

var changedStartTime time.Time
t.Run("SubmitAccessReview, initial change requested start time", func(t *testing.T) {
review := types.AccessReviewSubmission{
RequestID: createdReq.GetName(),
Review: types.AccessReview{
Author: "admin",
ProposedState: types.RequestState_APPROVED,
},
}

// invalid, greater than constants.MaxAssumeStartDuration
review.Review.AssumeStartTime = &invalidMaxedAssumeStartTime
_, err := testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time is too far in the future")

// invalid, after access expiry time
review.Review.AssumeStartTime = &invalidExpiredAssumeStartTime
_, err = testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time cannot equal or exceed access expiry time")

// valid, changed start time
changedStartTime = validStartTime.Add(-day * 2)
review.Review.AssumeStartTime = &changedStartTime
resp, err := testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
require.NoError(t, err)
require.Equal(t, changedStartTime, *resp.GetAssumeStartTime())
})

t.Run("SetAccessRequestState, subsequent change changed start time", func(t *testing.T) {
// double check current assume start time was from previous results.
resp, err := testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, changedStartTime, *resp[0].GetAssumeStartTime())

update := types.AccessRequestUpdate{
RequestID: createdReq.GetName(),
State: types.RequestState_APPROVED,
}

// invalid, greater than constants.MaxAssumeStartDuration
update.AssumeStartTime = &invalidMaxedAssumeStartTime
err = testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time is too far in the future")

// invalid, after access expiry time
update.AssumeStartTime = &invalidExpiredAssumeStartTime
err = testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.Contains(t, err.Error(), "assume start time cannot equal or exceed access expiry time")

// valid, changed again start time
changedAgainStartTime := changedStartTime.Add(-day * 2)
update.AssumeStartTime = &changedAgainStartTime
err = testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
require.NoError(t, err)
// create the access request object
req, err := services.NewAccessRequest(requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(maxDuration)

// double check access request was updated.
resp, err = testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, changedAgainStartTime, *resp[0].GetAssumeStartTime())
})
req.SetAssumeStartTime(validStartTime)
createdReq, err := requesterClient.CreateAccessRequestV2(ctx, req)
require.NoError(t, err)
require.Equal(t, validStartTime, *createdReq.GetAssumeStartTime())

return accessRequestWithStartTime{
testPack: testPack,
requesterClient: requesterClient,
invalidMaxedAssumeStartTime: invalidMaxedAssumeStartTime,
invalidExpiredAssumeStartTime: invalidExpiredAssumeStartTime,
validStartTime: validStartTime,
maxDuration: maxDuration,
requesterUserName: requesterUserName,
createdRequest: createdReq,
}
}
2 changes: 1 addition & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4867,7 +4867,7 @@ func (a *Server) submitAccessReview(

// final permission checks and review application must be done by the local backend
// service, as their validity depends upon optimistic locking.
req, err := a.ApplyAccessReview(ctx, params, checker, a.clock)
req, err := a.ApplyAccessReview(ctx, params, checker)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down

0 comments on commit 9852ae8

Please sign in to comment.