Skip to content

Commit

Permalink
Validate Project Requests on Create/Update (flyteorg#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
narape committed Nov 12, 2020
1 parent 463bfb6 commit 0fcb7cb
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 19 deletions.
5 changes: 2 additions & 3 deletions pkg/manager/impl/project_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ func (m *ProjectManager) UpdateProject(ctx context.Context, projectUpdate admin.
return nil, err
}

// Run validation on the request, specifically checking for labels, and return err if validation does not succeed.
err = validation.ValidateProjectLabels(projectUpdate)
if err != nil {
// Run validation on the request and return err if validation does not succeed.
if err := validation.ValidateProject(projectUpdate); err != nil {
return nil, err
}

Expand Down
72 changes: 72 additions & 0 deletions pkg/manager/impl/project_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"testing"

"github.com/lyft/flyteadmin/pkg/manager/impl/shared"

"github.com/lyft/flyteadmin/pkg/common"
"github.com/lyft/flyteadmin/pkg/manager/impl/testutils"
repositoryMocks "github.com/lyft/flyteadmin/pkg/repositories/mocks"
Expand Down Expand Up @@ -160,3 +162,73 @@ func TestProjectManager_CreateProjectErrorDueToBadLabels(t *testing.T) {
})
assert.EqualError(t, err, "invalid label value [#badlabel]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]")
}

func TestProjectManager_UpdateProject(t *testing.T) {
mockRepository := repositoryMocks.NewMockRepository()
var updateFuncCalled bool
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return models.Project{Identifier: "project-id", Name: "old-project-name", Description: "old-project-description"}, nil
}
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func(
ctx context.Context, projectUpdate models.Project) error {
updateFuncCalled = true
assert.Equal(t, "project-id", projectUpdate.Identifier)
assert.Equal(t, "new-project-name", projectUpdate.Name)
assert.Equal(t, "new-project-description", projectUpdate.Description)
return nil
}
projectManager := NewProjectManager(mockRepository,
runtimeMocks.NewMockConfigurationProvider(
getMockApplicationConfigForProjectManagerTest(), nil, nil, nil, nil, nil))
_, err := projectManager.UpdateProject(context.Background(), admin.Project{
Id: "project-id",
Name: "new-project-name",
Description: "new-project-description",
})
assert.Nil(t, err)
assert.True(t, updateFuncCalled)
}

func TestProjectManager_UpdateProject_ErrorDueToProjectNotFound(t *testing.T) {
mockRepository := repositoryMocks.NewMockRepository()
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return models.Project{}, errors.New(projectID + " not found")
}
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func(
ctx context.Context, projectUpdate models.Project) error {
assert.Fail(t, "No calls to UpdateProject were expected")
return nil
}
projectManager := NewProjectManager(mockRepository,
runtimeMocks.NewMockConfigurationProvider(
getMockApplicationConfigForProjectManagerTest(), nil, nil, nil, nil, nil))
_, err := projectManager.UpdateProject(context.Background(), admin.Project{
Id: "not-found-project-id",
Name: "not-found-project-name",
Description: "not-found-project-description",
})
assert.Equal(t, errors.New("not-found-project-id not found"), err)
}

func TestProjectManager_UpdateProject_ErrorDueToInvalidProjectName(t *testing.T) {
mockRepository := repositoryMocks.NewMockRepository()
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return models.Project{Identifier: "project-id", Name: "old-project-name", Description: "old-project-description"}, nil
}
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func(
ctx context.Context, projectUpdate models.Project) error {
assert.Fail(t, "No calls to UpdateProject were expected")
return nil
}
projectManager := NewProjectManager(mockRepository,
runtimeMocks.NewMockConfigurationProvider(
getMockApplicationConfigForProjectManagerTest(), nil, nil, nil, nil, nil))
_, err := projectManager.UpdateProject(context.Background(), admin.Project{
Id: "project-id",
// No project name
})
assert.Equal(t, shared.GetMissingArgumentError("project_name"), err)
}
44 changes: 29 additions & 15 deletions pkg/manager/impl/validation/project_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package validation
import (
"context"

"github.com/lyft/flyteadmin/pkg/errors"
"github.com/lyft/flyteadmin/pkg/manager/impl/shared"

"github.com/lyft/flyteadmin/pkg/errors"
"github.com/lyft/flyteadmin/pkg/repositories"
runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin"
Expand All @@ -15,36 +16,52 @@ import (
const projectID = "project_id"
const projectName = "project_name"
const projectDescription = "project_description"
const labels = "labels"
const maxNameLength = 64
const maxDescriptionLength = 300
const maxLabelArrayLength = 16

func ValidateProjectRegisterRequest(request admin.ProjectRegisterRequest) error {
if request.Project == nil {
return shared.GetMissingArgumentError(shared.Project)
}
if err := ValidateEmptyStringField(request.Project.Id, projectID); err != nil {
return ValidateProject(*request.Project)
}

func ValidateProject(project admin.Project) error {
if err := ValidateEmptyStringField(project.Id, projectID); err != nil {
return err
}
if err := ValidateProjectLabels(*request.Project); err != nil {
if err := validateProjectLabels(project); err != nil {
return err
}
if errs := validation.IsDNS1123Label(request.Project.Id); len(errs) > 0 {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid project id [%s]: %v", request.Project.Id, errs)
if errs := validation.IsDNS1123Label(project.Id); len(errs) > 0 {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid project id [%s]: %v", project.Id, errs)
}
if err := ValidateEmptyStringField(project.Name, projectName); err != nil {
return err
}
if err := ValidateEmptyStringField(request.Project.Name, projectName); err != nil {
if err := ValidateMaxLengthStringField(project.Name, projectName, maxNameLength); err != nil {
return err
}
if err := ValidateMaxLengthStringField(request.Project.Description, projectDescription, maxDescriptionLength); err != nil {
if err := ValidateMaxLengthStringField(project.Description, projectDescription, maxDescriptionLength); err != nil {
return err
}
if request.Project.Domains != nil {
if project.Domains != nil {
return errors.NewFlyteAdminError(codes.InvalidArgument,
"Domains are currently only set system wide. Please retry without domains included in your request.")
}
return nil
}

func ValidateProjectLabels(request admin.Project) error {
if err := ValidateProjectLabelsAlphanumeric(request); err != nil {
func validateProjectLabels(project admin.Project) error {
if project.Labels == nil || len(project.Labels.Values) == 0 {
return nil
}
if err := ValidateMaxMapLengthField(project.Labels.Values, labels, maxLabelArrayLength); err != nil {
return err
}
if err := validateProjectLabelsAlphanumeric(project.Labels); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -79,11 +96,8 @@ func ValidateProjectAndDomain(

// Given an admin.Project, checks if the project has labels and if it does, checks if the labels are K8s compliant,
// i.e. alphanumeric + - and _
func ValidateProjectLabelsAlphanumeric(request admin.Project) error {
if request.Labels == nil || len(request.Labels.Values) == 0 {
return nil
}
for key, value := range request.Labels.Values {
func validateProjectLabelsAlphanumeric(labels *admin.Labels) error {
for key, value := range labels.Values {
if errs := validation.IsDNS1123Label(key); len(errs) > 0 {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid label key [%s]: %v", key, errs)
}
Expand Down
158 changes: 158 additions & 0 deletions pkg/manager/impl/validation/project_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package validation
import (
"context"
"errors"
"strconv"
"testing"

"github.com/lyft/flyteadmin/pkg/manager/impl/testutils"
Expand Down Expand Up @@ -64,6 +65,15 @@ func TestValidateProjectRegisterRequest(t *testing.T) {
},
expectedError: "missing project_name",
},
{
request: admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "proj",
Name: "longnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamel",
},
},
expectedError: "project_name cannot exceed 64 characters",
},
{
request: admin.ProjectRegisterRequest{
Project: &admin.Project{
Expand Down Expand Up @@ -93,6 +103,36 @@ func TestValidateProjectRegisterRequest(t *testing.T) {
},
expectedError: "project_description cannot exceed 300 characters",
},
{
request: admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "proj",
Name: "name",
Labels: &admin.Labels{
Values: map[string]string{
"#badkey": "foo",
"bar": "baz",
},
},
},
},
expectedError: "invalid label key [#badkey]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]",
},
{
request: admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "proj",
Name: "name",
Labels: &admin.Labels{
Values: map[string]string{
"foo": ".bad-label-value",
"bar": "baz",
},
},
},
},
expectedError: "invalid label value [.bad-label-value]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]",
},
}

for _, val := range testValues {
Expand All @@ -102,6 +142,124 @@ func TestValidateProjectRegisterRequest(t *testing.T) {
}
}

func TestValidateProject_ValidProject(t *testing.T) {
assert.Nil(t, ValidateProject(admin.Project{
Id: "proj",
Name: "proj",
Description: "An amazing description for this project",
Labels: &admin.Labels{
Values: map[string]string{
"foo": "bar",
},
},
}))
}

func TestValidateProject(t *testing.T) {
type testValue struct {
project admin.Project
expectedError string
}
testValues := []testValue{
{
project: admin.Project{
Name: "proj",
Domains: []*admin.Domain{
{
Id: "foo",
Name: "foo",
},
},
},
expectedError: "missing project_id",
},
{
project: admin.Project{
Id: "%)(*&",
Name: "proj",
},
expectedError: "invalid project id [%)(*&]: [a DNS-1123 label must consist of lower case alphanumeric " +
"characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or " +
"'123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]",
},
{
project: admin.Project{
Id: "proj",
},
expectedError: "missing project_name",
},
{
project: admin.Project{
Id: "proj",
Name: "longnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamel",
},
expectedError: "project_name cannot exceed 64 characters",
},
{
project: admin.Project{
Id: "proj",
Name: "proj",
Domains: []*admin.Domain{
{
Id: "foo",
Name: "foo",
},
{
Id: "foo",
},
},
},
expectedError: "Domains are currently only set system wide. Please retry without domains included in your request.",
},
{
project: admin.Project{
Id: "proj",
Name: "name",
// 301 character string
Description: "longnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongn",
},
expectedError: "project_description cannot exceed 300 characters",
},
{
project: admin.Project{
Id: "proj",
Name: "name",
Labels: &admin.Labels{
Values: createLabelsMap(17),
},
},
expectedError: "labels map cannot exceed 16 entries",
},
{
project: admin.Project{
Id: "proj",
Name: "name",
Labels: &admin.Labels{
Values: map[string]string{
"#badkey": "foo",
"bar": "baz",
},
},
},
expectedError: "invalid label key [#badkey]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]",
},
}

for _, val := range testValues {
t.Run(val.expectedError, func(t *testing.T) {
assert.EqualError(t, ValidateProject(val.project), val.expectedError)
})
}
}

func createLabelsMap(size int) map[string]string {
result := make(map[string]string, size)
for i := 0; i < size; i++ {
result["key-"+strconv.Itoa(i)] = "value"
}
return result
}

func TestValidateProjectAndDomain(t *testing.T) {
mockRepo := repositoryMocks.NewMockRepository()
mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
Expand Down
8 changes: 8 additions & 0 deletions pkg/manager/impl/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ func ValidateMaxLengthStringField(field string, fieldName string, limit int) err
return nil
}

// Validates that a map field does not exceed a certain amount of entries
func ValidateMaxMapLengthField(m map[string]string, fieldName string, limit int) error {
if len(m) > limit {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "%s map cannot exceed %d entries", fieldName, limit)
}
return nil
}

func ValidateIdentifierFieldsSet(id *core.Identifier) error {
if id == nil {
return shared.GetMissingArgumentError(shared.ID)
Expand Down
Loading

0 comments on commit 0fcb7cb

Please sign in to comment.