Skip to content

Commit

Permalink
Always force org in database where clauses (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Feb 9, 2024
1 parent 47d71ac commit bdbfb5e
Show file tree
Hide file tree
Showing 26 changed files with 132 additions and 109 deletions.
29 changes: 21 additions & 8 deletions flyteadmin/pkg/manager/impl/project_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,23 @@ func getMockApplicationConfigForProjectManagerTest() runtimeInterfaces.Applicati
return &mockApplicationConfig
}

func testListProjects(request admin.ProjectListRequest, token string, orderExpr string, queryExpr *common.GormQueryExpr, t *testing.T) {
func expectedOrgQueryExpr() *common.GormQueryExpr {
return &common.GormQueryExpr{
Query: "org = ?",
Args: "",
}
}

func testListProjects(request admin.ProjectListRequest, token string, orderExpr string, queryExprs []*common.GormQueryExpr, t *testing.T) {
repository := repositoryMocks.NewMockRepository()
repository.ProjectRepo().(*repositoryMocks.MockProjectRepo).ListProjectsFunction = func(
ctx context.Context, input interfaces.ListResourceInput) ([]models.Project, error) {
if len(input.InlineFilters) != 0 {
q, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, *queryExpr, q)
for idx, inlineFilter := range input.InlineFilters {
q, _ := inlineFilter.GetGormQueryExpr()
assert.Equal(t, *queryExprs[idx], q)
}

}
assert.Equal(t, orderExpr, input.SortParameter.GetGormOrderExpr())
activeState := int32(admin.Project_ACTIVE)
Expand Down Expand Up @@ -82,7 +92,7 @@ func TestListProjects_NoFilters_LimitOne(t *testing.T) {
testListProjects(admin.ProjectListRequest{
Token: "1",
Limit: 1,
}, "2", "identifier asc", nil, t)
}, "2", "identifier asc", []*common.GormQueryExpr{expectedOrgQueryExpr()}, t)
}

func TestListProjects_HighLimit_SortBy_Filter(t *testing.T) {
Expand All @@ -94,14 +104,17 @@ func TestListProjects_HighLimit_SortBy_Filter(t *testing.T) {
Key: "name",
Direction: admin.Sort_DESCENDING,
},
}, "", "name desc", &common.GormQueryExpr{
Query: "name = ?",
Args: "foo",
}, "", "name desc", []*common.GormQueryExpr{
expectedOrgQueryExpr(),
{
Query: "name = ?",
Args: "foo",
},
}, t)
}

func TestListProjects_NoToken_NoLimit(t *testing.T) {
testListProjects(admin.ProjectListRequest{}, "", "identifier asc", nil, t)
testListProjects(admin.ProjectListRequest{}, "", "identifier asc", []*common.GormQueryExpr{expectedOrgQueryExpr()}, t)
}

func TestProjectManager_CreateProject(t *testing.T) {
Expand Down
11 changes: 5 additions & 6 deletions flyteadmin/pkg/manager/impl/util/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,12 @@ type FilterSpec struct {
func getIdentifierFilters(entity common.Entity, spec FilterSpec) ([]common.InlineFilter, error) {
filters := make([]common.InlineFilter, 0)

if spec.Org != "" {
orgFilter, err := GetSingleValueEqualityFilter(entity, shared.Org, spec.Org)
if err != nil {
return nil, err
}
filters = append(filters, orgFilter)
// Always apply the org filter even when it's omitted
orgFilter, err := GetSingleValueEqualityFilter(entity, shared.Org, spec.Org)
if err != nil {
return nil, err
}
filters = append(filters, orgFilter)

if spec.Project != "" {
projectFilter, err := GetSingleValueEqualityFilter(entity, shared.Project, spec.Project)
Expand Down
2 changes: 2 additions & 0 deletions flyteadmin/pkg/manager/impl/util/filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,14 @@ func TestGetDbFilters(t *testing.T) {
assert.NoError(t, err)

// Init expected values for filters.
orgFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Org, "")
projectFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Project, "project")
domainFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Domain, "domain")
nameFilter, _ := GetSingleValueEqualityFilter(common.LaunchPlan, shared.Name, "name")
versionFilter, _ := common.NewSingleValueFilter(common.LaunchPlan, common.NotEqual, shared.Version, "TheWorst")
workflowNameFilter, _ := common.NewSingleValueFilter(common.Workflow, common.Equal, shared.Name, "workflow")
expectedFilters := []common.InlineFilter{
orgFilter,
projectFilter,
domainFilter,
nameFilter,
Expand Down
13 changes: 13 additions & 0 deletions flyteadmin/pkg/repositories/gormimpl/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,16 @@ func applyScopedFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFil
}
return tx, nil
}

const (
orgColumn = "org"
executionOrgColumn = "execution_org"
)

func getOrgFilter(org string) map[string]interface{} {
return map[string]interface{}{orgColumn: org}
}

func getExecutionOrgFilter(executionOrg string) map[string]interface{} {
return map[string]interface{}{executionOrgColumn: executionOrg}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,33 @@ func TestGetDescriptionEntity(t *testing.T) {
assert.Equal(t, shortDescription, output.ShortDescription)
}

func TestGetDescriptionEntityNoOrg(t *testing.T) {
descriptionEntityRepo := NewDescriptionEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope())

descriptionEntities := make([]map[string]interface{}, 0)
descriptionEntity := getMockDescriptionEntityResponseFromDb(version, []byte{1, 2})
descriptionEntities = append(descriptionEntities, descriptionEntity)

GlobalMock := mocket.Catcher.Reset()
GlobalMock.Logging = true
// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(`SELECT * FROM "description_entities" WHERE project = $1 AND domain = $2 AND name = $3 AND version = $4 AND org = $5 LIMIT 1`).
WithReply(descriptionEntities)
output, err := descriptionEntityRepo.Get(context.Background(), interfaces.GetDescriptionEntityInput{
ResourceType: resourceType,
Project: project,
Domain: domain,
Name: name,
Version: version,
})
assert.Empty(t, err)
assert.Equal(t, project, output.Project)
assert.Equal(t, domain, output.Domain)
assert.Equal(t, name, output.Name)
assert.Equal(t, version, output.Version)
assert.Equal(t, shortDescription, output.ShortDescription)
}

func TestListDescriptionEntities(t *testing.T) {
descriptionEntityRepo := NewDescriptionEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope())

Expand All @@ -65,7 +92,8 @@ func TestListDescriptionEntities(t *testing.T) {
}

GlobalMock := mocket.Catcher.Reset()
GlobalMock.NewMock().WithReply(descriptionEntities)
GlobalMock.Logging = true
GlobalMock.NewMock().WithQuery("SELECT * FROM \"description_entities\" WHERE project = $1 AND domain = $2 AND name = $3 AND org = $4").WithReply(descriptionEntities)

collection, err := descriptionEntityRepo.List(context.Background(), interfaces.ListResourceInput{})
assert.Equal(t, 0, len(collection.Entities))
Expand All @@ -76,6 +104,7 @@ func TestListDescriptionEntities(t *testing.T) {
getEqualityFilter(common.Workflow, "project", project),
getEqualityFilter(common.Workflow, "domain", domain),
getEqualityFilter(common.Workflow, "name", name),
getEqualityFilter(common.Workflow, "org", ""),
},
Limit: 20,
})
Expand Down
5 changes: 2 additions & 3 deletions flyteadmin/pkg/repositories/gormimpl/execution_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (m
Project: input.Project,
Domain: input.Domain,
Name: input.Name,
Org: input.Org,
},
}).Take(&execution)
}).Where(getExecutionOrgFilter(input.Org)).Take(&execution)
timer.Stop()

if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) {
Expand All @@ -60,7 +59,7 @@ func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (m

func (r *ExecutionRepo) Update(ctx context.Context, execution models.Execution) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).Model(&execution).Updates(execution)
tx := r.db.WithContext(ctx).Model(&execution).Where(getExecutionOrgFilter(execution.Org)).Updates(execution)
timer.Stop()
if err := tx.Error; err != nil {
return r.errorTransformer.ToFlyteAdminError(err)
Expand Down
8 changes: 2 additions & 6 deletions flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ func TestUpdateExecution(t *testing.T) {
updated := false

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,` +
`"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,` +
`"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE "` +
`execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16`).WithCallback(
GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE "execution_org" = $14 AND "execution_project" = $15 AND "execution_domain" = $16 AND "execution_name" = $17`).WithCallback(
func(s string, values []driver.NamedValue) {
updated = true
},
Expand Down Expand Up @@ -129,13 +126,12 @@ func TestGetExecution(t *testing.T) {
GlobalMock.Logging = true

// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(`SELECT * FROM "executions" WHERE "executions"."execution_project" = $1 AND "executions"."execution_domain" = $2 AND "executions"."execution_name" = $3 AND "executions"."execution_org" = $4 LIMIT 1`).WithReply(executions)
GlobalMock.NewMock().WithQuery(`SELECT * FROM "executions" WHERE "executions"."execution_project" = $1 AND "executions"."execution_domain" = $2 AND "executions"."execution_name" = $3 AND "execution_org" = $4 LIMIT 1`).WithReply(executions)

output, err := executionRepo.Get(context.Background(), interfaces.Identifier{
Project: "project",
Domain: "domain",
Name: "1",
Org: testOrg,
})
assert.NoError(t, err)
assert.EqualValues(t, expectedExecution, output)
Expand Down
9 changes: 4 additions & 5 deletions flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (r *LaunchPlanRepo) Create(ctx context.Context, input models.LaunchPlan) er

func (r *LaunchPlanRepo) Update(ctx context.Context, input models.LaunchPlan) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).Model(&input).Updates(input)
tx := r.db.WithContext(ctx).Model(&input).Where(getOrgFilter(input.Org)).Updates(input)
timer.Stop()
if err := tx.Error; err != nil {
return r.errorTransformer.ToFlyteAdminError(err)
Expand All @@ -58,9 +58,8 @@ func (r *LaunchPlanRepo) Get(ctx context.Context, input interfaces.Identifier) (
Domain: input.Domain,
Name: input.Name,
Version: input.Version,
Org: input.Org,
},
}).Take(&launchPlan)
}).Where(getOrgFilter(input.Org)).Take(&launchPlan)
timer.Stop()

if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) {
Expand Down Expand Up @@ -90,15 +89,15 @@ func (r *LaunchPlanRepo) SetActive(

// There is a launch plan to disable as part of this transaction
if toDisable != nil {
tx.Model(&toDisable).UpdateColumns(toDisable)
tx.Model(&toDisable).Where(getOrgFilter(toDisable.Org)).UpdateColumns(toDisable)
if err := tx.Error; err != nil {
tx.Rollback()
return r.errorTransformer.ToFlyteAdminError(err)
}
}

// And update the desired version.
tx.Model(&toEnable).UpdateColumns(toEnable)
tx.Model(&toEnable).Where(getOrgFilter(toEnable.Org)).UpdateColumns(toEnable)
if err := tx.Error; err != nil {
tx.Rollback()
return r.errorTransformer.ToFlyteAdminError(err)
Expand Down
11 changes: 4 additions & 7 deletions flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ func TestGetLaunchPlan(t *testing.T) {
Domain: domain,
Name: name,
Version: version,
Org: testOrg,
},
Spec: launchPlanSpec,
WorkflowID: workflowID,
Expand All @@ -77,20 +76,18 @@ func TestGetLaunchPlan(t *testing.T) {
GlobalMock.Logging = true
// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "launch_plans" WHERE "launch_plans"."project" = $1 AND "launch_plans"."domain" = $2 AND "launch_plans"."name" = $3 AND "launch_plans"."version" = $4 AND "launch_plans"."org" = $5 LIMIT 1`).WithReply(launchPlans)
`SELECT * FROM "launch_plans" WHERE "launch_plans"."project" = $1 AND "launch_plans"."domain" = $2 AND "launch_plans"."name" = $3 AND "launch_plans"."version" = $4 AND "org" = $5 LIMIT 1`).WithReply(launchPlans)
output, err := launchPlanRepo.Get(context.Background(), interfaces.Identifier{
Project: project,
Domain: domain,
Name: name,
Version: version,
Org: testOrg,
})
assert.NoError(t, err)
assert.Equal(t, project, output.Project)
assert.Equal(t, domain, output.Domain)
assert.Equal(t, name, output.Name)
assert.Equal(t, version, output.Version)
assert.Equal(t, testOrg, output.Org)
assert.Equal(t, launchPlanSpec, output.Spec)
}

Expand All @@ -102,7 +99,7 @@ func TestSetInactiveLaunchPlan(t *testing.T) {
mockDb := GlobalMock.NewMock()
updated := false
mockDb.WithQuery(
`UPDATE "launch_plans" SET "id"=$1,"updated_at"=$2,"project"=$3,"domain"=$4,"name"=$5,"version"=$6,"closure"=$7,"state"=$8 WHERE "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12`).WithCallback(
`UPDATE "launch_plans" SET "id"=$1,"updated_at"=$2,"project"=$3,"domain"=$4,"name"=$5,"version"=$6,"closure"=$7,"state"=$8 WHERE "org" = $9 AND "project" = $10 AND "domain" = $11 AND "name" = $12 AND "version" = $13`).WithCallback(
func(s string, values []driver.NamedValue) {
updated = true
},
Expand Down Expand Up @@ -133,7 +130,7 @@ func TestSetActiveLaunchPlan(t *testing.T) {
mockQuery := GlobalMock.NewMock()
updated := false
mockQuery.WithQuery(
`UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "project" = $8 AND "domain" = $9 AND "name" = $10 AND "version" = $11`).WithCallback(
`UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "org" = $8 AND "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12`).WithCallback(
func(s string, values []driver.NamedValue) {
updated = true
},
Expand Down Expand Up @@ -176,7 +173,7 @@ func TestSetActiveLaunchPlan_NoCurrentlyActiveLaunchPlan(t *testing.T) {
mockQuery := GlobalMock.NewMock()
updated := false
mockQuery.WithQuery(
`UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "project" = $8 AND "domain" = $9 AND "name" = $10 AND "version" = $11`).WithCallback(
`UPDATE "launch_plans" SET "id"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"closure"=$6,"state"=$7 WHERE "org" = $8 AND "project" = $9 AND "domain" = $10 AND "name" = $11 AND "version" = $12`).WithCallback(
func(s string, values []driver.NamedValue) {
updated = true
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ func TestGetNamedEntity(t *testing.T) {
Project: project,
Domain: domain,
Name: name,
Org: testOrg,
},
NamedEntityMetadataFields: models.NamedEntityMetadataFields{
Description: description,
Expand All @@ -62,7 +61,6 @@ func TestGetNamedEntity(t *testing.T) {
assert.Equal(t, name, output.Name)
assert.Equal(t, resourceType, output.ResourceType)
assert.Equal(t, description, output.Description)
assert.Equal(t, testOrg, output.Org)
}

func TestUpdateNamedEntity_WithExisting(t *testing.T) {
Expand Down
32 changes: 6 additions & 26 deletions flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
adminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/errors"
"github.com/flyteorg/flyte/flyteadmin/pkg/repositories/interfaces"
"github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytestdlib/promutils"
)

Expand Down Expand Up @@ -41,23 +40,14 @@ func (r *NodeExecutionRepo) Get(ctx context.Context, input interfaces.NodeExecut
Project: input.NodeExecutionIdentifier.ExecutionId.Project,
Domain: input.NodeExecutionIdentifier.ExecutionId.Domain,
Name: input.NodeExecutionIdentifier.ExecutionId.Name,
Org: input.NodeExecutionIdentifier.ExecutionId.Org,
},
},
}).Take(&nodeExecution)
}).Where(getExecutionOrgFilter(input.NodeExecutionIdentifier.ExecutionId.Org)).Take(&nodeExecution)
timer.Stop()

if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return models.NodeExecution{},
adminErrors.GetMissingEntityError("node execution", &core.NodeExecutionIdentifier{
NodeId: input.NodeExecutionIdentifier.NodeId,
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: input.NodeExecutionIdentifier.ExecutionId.Project,
Domain: input.NodeExecutionIdentifier.ExecutionId.Domain,
Name: input.NodeExecutionIdentifier.ExecutionId.Name,
Org: input.NodeExecutionIdentifier.ExecutionId.Org,
},
})
adminErrors.GetMissingEntityError("node execution", &input.NodeExecutionIdentifier)
} else if tx.Error != nil {
return models.NodeExecution{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}
Expand All @@ -75,23 +65,14 @@ func (r *NodeExecutionRepo) GetWithChildren(ctx context.Context, input interface
Project: input.NodeExecutionIdentifier.ExecutionId.Project,
Domain: input.NodeExecutionIdentifier.ExecutionId.Domain,
Name: input.NodeExecutionIdentifier.ExecutionId.Name,
Org: input.NodeExecutionIdentifier.ExecutionId.Org,
},
},
}).Preload("ChildNodeExecutions").Take(&nodeExecution)
}).Where(getExecutionOrgFilter(input.NodeExecutionIdentifier.ExecutionId.Org)).Preload("ChildNodeExecutions").Take(&nodeExecution)
timer.Stop()

if tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return models.NodeExecution{},
adminErrors.GetMissingEntityError("node execution", &core.NodeExecutionIdentifier{
NodeId: input.NodeExecutionIdentifier.NodeId,
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: input.NodeExecutionIdentifier.ExecutionId.Project,
Domain: input.NodeExecutionIdentifier.ExecutionId.Domain,
Name: input.NodeExecutionIdentifier.ExecutionId.Name,
Org: input.NodeExecutionIdentifier.ExecutionId.Org,
},
})
adminErrors.GetMissingEntityError("node execution", &input.NodeExecutionIdentifier)
} else if tx.Error != nil {
return models.NodeExecution{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}
Expand All @@ -101,7 +82,7 @@ func (r *NodeExecutionRepo) GetWithChildren(ctx context.Context, input interface

func (r *NodeExecutionRepo) Update(ctx context.Context, nodeExecution *models.NodeExecution) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).Model(&nodeExecution).Updates(nodeExecution)
tx := r.db.WithContext(ctx).Model(&nodeExecution).Where(getExecutionOrgFilter(nodeExecution.Org)).Updates(nodeExecution)
timer.Stop()
if err := tx.Error; err != nil {
return r.errorTransformer.ToFlyteAdminError(err)
Expand Down Expand Up @@ -155,10 +136,9 @@ func (r *NodeExecutionRepo) Exists(ctx context.Context, input interfaces.NodeExe
Project: input.NodeExecutionIdentifier.ExecutionId.Project,
Domain: input.NodeExecutionIdentifier.ExecutionId.Domain,
Name: input.NodeExecutionIdentifier.ExecutionId.Name,
Org: input.NodeExecutionIdentifier.ExecutionId.Org,
},
},
}).Take(&nodeExecution)
}).Where(getExecutionOrgFilter(input.NodeExecutionIdentifier.ExecutionId.Org)).Take(&nodeExecution)
timer.Stop()
if tx.Error != nil {
return false, r.errorTransformer.ToFlyteAdminError(tx.Error)
Expand Down
Loading

0 comments on commit bdbfb5e

Please sign in to comment.