diff --git a/pkg/repositories/gormimpl/common.go b/pkg/repositories/gormimpl/common.go index bd0a65f75..737b83664 100644 --- a/pkg/repositories/gormimpl/common.go +++ b/pkg/repositories/gormimpl/common.go @@ -19,6 +19,7 @@ const Description = "description" const ResourceType = "resource_type" const State = "state" const ID = "id" +const CreatedAt = "created_at" const executionTableName = "executions" const namedEntityMetadataTableName = "named_entity_metadata" @@ -30,7 +31,7 @@ const taskTableName = "tasks" const limit = "limit" const filters = "filters" -var identifierGroupBy = fmt.Sprintf("%s, %s, %s", Project, Domain, Name) +var identifierGroupBy = fmt.Sprintf("%s, %s, %s, %s", Project, Domain, Name, CreatedAt) var entityToTableName = map[common.Entity]string{ common.Execution: "executions", diff --git a/pkg/repositories/gormimpl/launch_plan_repo.go b/pkg/repositories/gormimpl/launch_plan_repo.go index dc379ed03..a1462feca 100644 --- a/pkg/repositories/gormimpl/launch_plan_repo.go +++ b/pkg/repositories/gormimpl/launch_plan_repo.go @@ -166,7 +166,7 @@ func (r *LaunchPlanRepo) ListLaunchPlanIdentifiers(ctx context.Context, input in // Scan the results into a list of launch plans var launchPlans []models.LaunchPlan timer := r.metrics.ListIdentifiersDuration.Start() - tx.Select([]string{Project, Domain, Name}).Group(identifierGroupBy).Scan(&launchPlans) + tx.Distinct([]string{Project, Domain, Name}).Group(identifierGroupBy).Scan(&launchPlans) timer.Stop() if tx.Error != nil { return interfaces.LaunchPlanCollectionOutput{}, r.errorTransformer.ToFlyteAdminError(tx.Error) diff --git a/pkg/repositories/gormimpl/named_entity_repo.go b/pkg/repositories/gormimpl/named_entity_repo.go index 8e02390dd..3e7813284 100644 --- a/pkg/repositories/gormimpl/named_entity_repo.go +++ b/pkg/repositories/gormimpl/named_entity_repo.go @@ -4,15 +4,15 @@ import ( "context" "fmt" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "google.golang.org/grpc/codes" - "github.com/flyteorg/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + + "google.golang.org/grpc/codes" "gorm.io/gorm" ) @@ -65,10 +65,10 @@ var resourceTypeToMetadataJoin = map[core.ResourceType]string{ core.ResourceType_TASK: leftJoinTaskNameToMetadata, } -var getGroupByForNamedEntity = fmt.Sprintf("%s.%s, %s.%s, %s.%s, %s.%s, %s.%s", +var getGroupByForNamedEntity = fmt.Sprintf("%s.%s, %s.%s, %s.%s, %s.%s, %s.%s, %s.%s", innerJoinTableAlias, Project, innerJoinTableAlias, Domain, innerJoinTableAlias, Name, namedEntityMetadataTableName, Description, - namedEntityMetadataTableName, State) + namedEntityMetadataTableName, State, namedEntityMetadataTableName, CreatedAt) func getSelectForNamedEntity(tableName string, resourceType core.ResourceType) []string { return []string{ @@ -198,7 +198,7 @@ func (r *NamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEn var entities []models.NamedEntity timer := r.metrics.ListDuration.Start() - tx.Select(getSelectForNamedEntity(innerJoinTableAlias, input.ResourceType)).Table(namedEntityMetadataTableName).Group(getGroupByForNamedEntity).Scan(&entities) + tx.Distinct(getSelectForNamedEntity(innerJoinTableAlias, input.ResourceType)).Table(namedEntityMetadataTableName).Group(getGroupByForNamedEntity).Scan(&entities) timer.Stop() diff --git a/pkg/repositories/gormimpl/named_entity_repo_test.go b/pkg/repositories/gormimpl/named_entity_repo_test.go index d1867eca0..7824b3db8 100644 --- a/pkg/repositories/gormimpl/named_entity_repo_test.go +++ b/pkg/repositories/gormimpl/named_entity_repo_test.go @@ -2,18 +2,20 @@ package gormimpl import ( "context" + "fmt" "testing" "github.com/flyteorg/flyteadmin/pkg/common" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - - mocket "github.com/Selvatico/go-mocket" + adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" mockScope "github.com/flyteorg/flytestdlib/promutils" + + mocket "github.com/Selvatico/go-mocket" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" ) func getMockNamedEntityResponseFromDb(expected models.NamedEntity) map[string]interface{} { @@ -155,7 +157,7 @@ func TestListNamedEntity(t *testing.T) { mockQuery := GlobalMock.NewMock() mockQuery.WithQuery( - `SELECT entities.project,entities.domain,entities.name,'2' AS resource_type,named_entity_metadata.description,named_entity_metadata.state FROM "named_entity_metadata" RIGHT JOIN (SELECT project,domain,name FROM "workflows" WHERE "domain" = $1 AND "project" = $2 GROUP BY project, domain, name ORDER BY name desc LIMIT 20) AS entities ON named_entity_metadata.resource_type = 2 AND named_entity_metadata.project = entities.project AND named_entity_metadata.domain = entities.domain AND named_entity_metadata.name = entities.name GROUP BY entities.project, entities.domain, entities.name, named_entity_metadata.description, named_entity_metadata.state ORDER BY name desc`).WithReply(results) + `SELECT entities.project,entities.domain,entities.name,'2' AS resource_type,named_entity_metadata.description,named_entity_metadata.state FROM "named_entity_metadata" RIGHT JOIN (SELECT project,domain,name FROM "workflows" WHERE "domain" = $1 AND "project" = $2 GROUP BY project, domain, name, created_at ORDER BY name desc LIMIT 20) AS entities ON named_entity_metadata.resource_type = 2 AND named_entity_metadata.project = entities.project AND named_entity_metadata.domain = entities.domain AND named_entity_metadata.name = entities.name GROUP BY entities.project, entities.domain, entities.name, named_entity_metadata.description, named_entity_metadata.state, named_entity_metadata.created_at ORDER BY name desc`).WithReply(results) sortParameter, _ := common.NewSortParameter(admin.Sort{ Direction: admin.Sort_DESCENDING, @@ -173,3 +175,105 @@ func TestListNamedEntity(t *testing.T) { assert.NoError(t, err) assert.Len(t, output.Entities, 1) } + +func TestListNamedEntityTxErrorCases(t *testing.T) { + metadataRepo := NewNamedEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + mockQuery := GlobalMock.NewMock() + + mockQuery.WithQuery( + `SELECT entities.project,entities.domain,entities.name,'2' AS resource_type,named_entity_metadata.description,named_entity_metadata.state FROM "named_entity_metadata" RIGHT JOIN (SELECT project,domain,name FROM "workflows" WHERE "domain" = $1 AND "project" = $2 GROUP BY project, domain, name, created_at ORDER BY name desc LIMIT 20) AS entities ON named_entity_metadata.resource_type = 2 AND named_entity_metadata.project = entities.project AND named_entity_metadata.domain = entities.domain AND named_entity_metadata.name = entities.name GROUP BY entities.project, entities.domain, entities.name, named_entity_metadata.description, named_entity_metadata.state, named_entity_metadata.created_at ORDER BY name desc%`).WithError(fmt.Errorf("failed")) + + sortParameter, _ := common.NewSortParameter(admin.Sort{ + Direction: admin.Sort_DESCENDING, + Key: "name", + }) + output, err := metadataRepo.List(context.Background(), interfaces.ListNamedEntityInput{ + ResourceType: resourceType, + Project: "admintests", + Domain: "development", + ListResourceInput: interfaces.ListResourceInput{ + Limit: 20, + SortParameter: sortParameter, + }, + }) + assert.Equal(t, "Test transformer failed to find transformation to apply", err.Error()) + assert.Len(t, output.Entities, 0) +} + +func TestListNamedEntityInputErrorCases(t *testing.T) { + type test struct { + input interfaces.ListNamedEntityInput + wantedError error + wantedLength int + } + + sortParameter, _ := common.NewSortParameter(admin.Sort{ + Direction: admin.Sort_DESCENDING, + Key: "name", + }) + + tests := []test{ + { + input: interfaces.ListNamedEntityInput{ + ResourceType: resourceType, + Project: "", + Domain: "development", + ListResourceInput: interfaces.ListResourceInput{ + Limit: 20, + SortParameter: sortParameter, + }, + }, + wantedError: errors.GetInvalidInputError(Project), + wantedLength: 0, + }, + { + input: interfaces.ListNamedEntityInput{ + ResourceType: resourceType, + Project: "project", + Domain: "", + ListResourceInput: interfaces.ListResourceInput{ + Limit: 20, + SortParameter: sortParameter, + }, + }, + wantedError: errors.GetInvalidInputError(Domain), + wantedLength: 0, + }, + { + input: interfaces.ListNamedEntityInput{ + ResourceType: resourceType, + Project: "project", + Domain: "development", + ListResourceInput: interfaces.ListResourceInput{ + Limit: 0, + SortParameter: sortParameter, + }, + }, + wantedError: errors.GetInvalidInputError(limit), + wantedLength: 0, + }, + { + input: interfaces.ListNamedEntityInput{ + ResourceType: -1, + Project: "project", + Domain: "development", + ListResourceInput: interfaces.ListResourceInput{ + Limit: 20, + SortParameter: sortParameter, + }, + }, + wantedError: adminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, + "Cannot list entity names for resource type: %v", -1), + wantedLength: 0, + }, + } + + metadataRepo := NewNamedEntityRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + for _, tc := range tests { + output, err := metadataRepo.List(context.Background(), tc.input) + assert.Len(t, output.Entities, tc.wantedLength) + assert.Equal(t, tc.wantedError, err) + } +} diff --git a/pkg/repositories/gormimpl/task_repo.go b/pkg/repositories/gormimpl/task_repo.go index f48c6ca11..ecfbc8a28 100644 --- a/pkg/repositories/gormimpl/task_repo.go +++ b/pkg/repositories/gormimpl/task_repo.go @@ -113,7 +113,7 @@ func (r *TaskRepo) ListTaskIdentifiers(ctx context.Context, input interfaces.Lis // Scan the results into a list of tasks var tasks []models.Task timer := r.metrics.ListIdentifiersDuration.Start() - tx.Select([]string{Project, Domain, Name}).Group(identifierGroupBy).Scan(&tasks) + tx.Distinct([]string{Project, Domain, Name}).Group(identifierGroupBy).Scan(&tasks) timer.Stop() if tx.Error != nil { return interfaces.TaskCollectionOutput{}, r.errorTransformer.ToFlyteAdminError(tx.Error) diff --git a/pkg/repositories/gormimpl/workflow_repo.go b/pkg/repositories/gormimpl/workflow_repo.go index 2c78cb2c4..69405b367 100644 --- a/pkg/repositories/gormimpl/workflow_repo.go +++ b/pkg/repositories/gormimpl/workflow_repo.go @@ -110,7 +110,7 @@ func (r *WorkflowRepo) ListIdentifiers(ctx context.Context, input interfaces.Lis // Scan the results into a list of workflows var workflows []models.Workflow timer := r.metrics.ListIdentifiersDuration.Start() - tx.Select([]string{Project, Domain, Name}).Group(identifierGroupBy).Scan(&workflows) + tx.Distinct([]string{Project, Domain, Name}).Group(identifierGroupBy).Scan(&workflows) timer.Stop() if tx.Error != nil { return interfaces.WorkflowCollectionOutput{}, r.errorTransformer.ToFlyteAdminError(tx.Error)