Skip to content

Commit

Permalink
fix: use random string instead of incremental suffix, add concurrency…
Browse files Browse the repository at this point in the history
… testing, adjust transaction isolation level
  • Loading branch information
corban-beaird committed May 10, 2024
1 parent 6ef64b5 commit a165316
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 71 deletions.
50 changes: 32 additions & 18 deletions master/internal/api_project.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package internal
import (
"context"
"fmt"
"regexp"
"sort"
"strings"

"github.com/uptrace/bun"

Expand All @@ -29,6 +29,11 @@ import (
"github.com/determined-ai/determined/proto/pkg/workspacev1"
)

const (
// ProjectKeyRegex is the regex pattern for a project key.
ProjectKeyRegex = "^[A-Z0-9]{5}$"
)

func (a *apiServer) GetProjectByID(
ctx context.Context, id int32, curUser model.User,
) (*projectv1.Project, error) {
Expand Down Expand Up @@ -521,6 +526,19 @@ func (a *apiServer) getProjectNumericMetricsRange(
return metricsValues, searcherMetricsValue, nil
}

func validateProjectKey(key string) error {
switch {
case len(key) > project.MaxProjectKeyLength:
return errors.Errorf("project key cannot be longer than %d characters", project.MaxProjectKeyLength)
case len(key) < 1:
return errors.New("project key cannot be empty")
case !regexp.MustCompile(ProjectKeyRegex).MatchString(key):
return errors.Errorf("project key can only contain alphanumeric characters")
default:
return nil
}
}

func (a *apiServer) PostProject(
ctx context.Context, req *apiv1.PostProjectRequest,
) (*apiv1.PostProjectResponse, error) {
Expand All @@ -536,28 +554,24 @@ func (a *apiServer) PostProject(
return nil, status.Error(codes.PermissionDenied, err.Error())
}

var projectKey string
if req.Key == nil {
projectKey, err = project.GenerateProjectKey(ctx, req.Name)
if err != nil {
return nil, fmt.Errorf("error generating project key: %w", err)
if req.Key != nil {
if err = validateProjectKey(*req.Key); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
} else {
projectKey = *req.Key
}

p := &projectv1.Project{}
err = a.m.db.QueryProto("insert_project", p, req.Name, req.Description,
req.WorkspaceId, curUser.ID, projectKey)

if err != nil && strings.Contains(err.Error(), db.CodeUniqueViolation) {
if strings.Contains(err.Error(), "projects_key_key") {
return nil,
status.Errorf(codes.AlreadyExists, "project with key %s already exists", projectKey)
}
p := &model.Project{
Name: req.Name,
Description: req.Description,
WorkspaceID: int(req.WorkspaceId),
UserID: int(curUser.ID),
Username: curUser.Username,
}

return &apiv1.PostProjectResponse{Project: p},
if err = project.InsertProject(ctx, p, req.Key); err != nil {
return nil, err
}
return &apiv1.PostProjectResponse{Project: p.Proto()},
errors.Wrapf(err, "error creating project %s in database", req.Name)
}

Expand Down
51 changes: 35 additions & 16 deletions master/internal/api_project_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package internal
import (
"context"
"fmt"
"strings"
"testing"

"google.golang.org/grpc/codes"
Expand All @@ -24,6 +25,7 @@ import (
"github.com/determined-ai/determined/master/internal/mocks"
"github.com/determined-ai/determined/master/internal/project"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/syncx/errgroupx"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/projectv1"
"github.com/determined-ai/determined/proto/pkg/rbacv1"
Expand Down Expand Up @@ -378,21 +380,12 @@ func TestCreateProjectWithoutProjectKey(t *testing.T) {
require.NoError(t, werr)

projectName := "test-project" + uuid.New().String()
projectKeyPrefix := projectName[:3]
projectKeyPrefix := strings.ToUpper(projectName[:project.MaxProjectKeyPrefixLength])
resp, err := api.PostProject(ctx, &apiv1.PostProjectRequest{
Name: projectName, WorkspaceId: wresp.Workspace.Id,
})
require.NoError(t, err)

// Check that the project key is generated correctly.
countPostFix := 0
err = db.Bun().NewSelect().
ColumnExpr("COUNT(*)").
Table("projects").
Where("key ILIKE ?", (projectKeyPrefix+"%")).
Scan(ctx, &countPostFix)
require.NoError(t, err)
require.Equal(t, (projectKeyPrefix + fmt.Sprintf("%d", countPostFix)), resp.Project.Key)
require.Equal(t, projectKeyPrefix, resp.Project.Key[:project.MaxProjectKeyPrefixLength])
}

func TestCreateProjectWithProjectKey(t *testing.T) {
Expand All @@ -401,7 +394,7 @@ func TestCreateProjectWithProjectKey(t *testing.T) {
require.NoError(t, werr)

projectName := "test-project" + uuid.New().String()
projectKey := uuid.New().String()[:5]
projectKey := uuid.New().String()[:project.MaxProjectKeyLength]
resp, err := api.PostProject(ctx, &apiv1.PostProjectRequest{
Name: projectName, WorkspaceId: wresp.Workspace.Id, Key: &projectKey,
})
Expand All @@ -423,7 +416,7 @@ func TestCreateProjectWithDuplicateProjectKey(t *testing.T) {
require.NoError(t, werr)

projectName := "test-project" + uuid.New().String()
projectKey := uuid.New().String()[:5]
projectKey := uuid.New().String()[:project.MaxProjectKeyLength]
_, err := api.PostProject(ctx, &apiv1.PostProjectRequest{
Name: projectName, WorkspaceId: wresp.Workspace.Id, Key: &projectKey,
})
Expand All @@ -442,17 +435,43 @@ func TestCreateProjectWithDefaultKeyAndDuplicatePrefix(t *testing.T) {
require.NoError(t, werr)

projectName := uuid.New().String()
projectKeyPrefix := projectName[:3]
projectKeyPrefix := strings.ToUpper(projectName[:project.MaxProjectKeyPrefixLength])
resp1, err := api.PostProject(ctx, &apiv1.PostProjectRequest{
Name: projectName, WorkspaceId: wresp.Workspace.Id,
})
require.NoError(t, err)
require.Equal(t, (projectKeyPrefix + "1"), resp1.Project.Key)
require.Equal(t, projectKeyPrefix, resp1.Project.Key[:project.MaxProjectKeyPrefixLength])

resp2, err := api.PostProject(ctx, &apiv1.PostProjectRequest{
Name: projectName + "2", WorkspaceId: wresp.Workspace.Id,
})
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, (projectKeyPrefix + "2"), resp2.Project.Key)
require.Equal(t, projectKeyPrefix, resp2.Project.Key[:project.MaxProjectKeyPrefixLength])
}

func TestConcurrentProjectKeyGenerationAttempts(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)
wresp, werr := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, werr)
for x := 0; x < 20; x++ {
errgrp := errgroupx.WithContext(ctx)
for i := 0; i < 20; i++ {
projectName := "test-project" + uuid.New().String()
errgrp.Go(func(context.Context) error {
_, err := api.PostProject(ctx, &apiv1.PostProjectRequest{
Name: projectName, WorkspaceId: wresp.Workspace.Id,
})
require.NoError(t, err)
return err
})
}
require.NoError(t, errgrp.Wait())
t.Cleanup(func() {
_, err := db.Bun().NewDelete().Table("projects").Where("workspace_id = ?", wresp.Workspace.Id).Exec(ctx)
require.NoError(t, err)
_, err = db.Bun().NewDelete().Table("workspaces").Where("id = ?", wresp.Workspace.Id).Exec(ctx)
require.NoError(t, err)
})
}
}
4 changes: 4 additions & 0 deletions master/internal/db/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ const (
// insert/update violates a foreign key constraint. Obtained from:
// https://www.postgresql.org/docs/10/errcodes-appendix.html
CodeForeignKeyViolation = "23503"
// CodeSerializationFailure is the error code that Postgres uses to indicate that a transaction
// failed due to a serialization failure. Obtained from:
// https://www.postgresql.org/docs/10/errcodes-appendix.html
CodeSerializationFailure = "40001"
)

// Close closes the underlying pq connection.
Expand Down
76 changes: 70 additions & 6 deletions master/internal/project/postgres_project.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,25 @@ import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/uptrace/bun"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/workspace"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/random"
)

const (
// MaxProjectKeyLength is the maximum length of a project key.
MaxProjectKeyLength = 5
// MaxProjectKeyPrefixLength is the maximum length of a project key prefix.
MaxProjectKeyPrefixLength = 3
// MaxRetries is the maximum number of retries for transaction conflicts.
MaxRetries = 5
)

// ProjectByName returns a project's ID if it exists in the given workspace and is not archived.
Expand Down Expand Up @@ -52,11 +68,59 @@ func ProjectIDByName(ctx context.Context, workspaceID int, projectName string) (
}

// GenerateProjectKey generates a unique project key for a project based on its name.
func GenerateProjectKey(ctx context.Context, projectName string) (string, error) {
generatedKey := ""
err := db.Bun().NewRaw("SELECT function_generate_project_key(?)", projectName).Scan(ctx, &generatedKey)
if err != nil {
return "", err
func generateProjectKey(ctx context.Context, tx bun.Tx, projectName string) (string, error) {
var key string
found := true
for i := 0; i < MaxRetries && found; i++ {
prefixLength := min(len(projectName), MaxProjectKeyPrefixLength)
prefix := projectName[:prefixLength]
suffix := random.String(MaxProjectKeyLength - prefixLength)
key = strings.ToUpper(prefix + suffix)
err := tx.NewSelect().Model(&model.Project{}).Where("key = ?", key).For("UPDATE").Scan(ctx)
found = err == nil
}
if found {
return "", fmt.Errorf("could not generate a unique project key")
}
return key, nil
}

// InsertProject inserts a new project into the database.
func InsertProject(
ctx context.Context,
p *model.Project,
requestedKey *string,
) (err error) {
RetryLoop:
for i := 0; i < MaxRetries; i++ {
err = db.Bun().RunInTx(ctx, &sql.TxOptions{Isolation: sql.LevelRepeatableRead},
func(ctx context.Context, tx bun.Tx) error {
var err error
if requestedKey == nil {
p.Key, err = generateProjectKey(ctx, tx, p.Name)
if err != nil {
return err
}
} else {
p.Key = *requestedKey
}
_, err = tx.NewInsert().Model(p).Exec(ctx)
if err != nil {
return err
}
return nil
},
)

switch {
case err == nil:
break RetryLoop
case requestedKey == nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint"):
log.Debugf("retrying project (%s) insertion due to generated key conflict (%s)", p.Name, p.Key)
continue // retry
default:
break RetryLoop
}
}
return generatedKey, nil
return errors.Wrapf(err, "error inserting project %s into database", p.Name)
}
14 changes: 8 additions & 6 deletions master/pkg/model/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ type Project struct {
CreatedAt time.Time `bun:"created_at,scanonly"`
Archived bool `bun:"archived"`
WorkspaceID int `bun:"workspace_id"`
WorkspaceName string `bun:"workspace_name"`
WorkspaceName string `bun:"workspace_name,scanonly"`
UserID int `bun:"user_id"`
Username string `bun:"username"`
Username string `bun:"username,scanonly"`
Immutable bool `bun:"immutable"`
Description string `bun:"description"`
Notes []*projectv1.Note `bun:"notes,type:jsonb"`
NumActiveExperiments int32 `bun:"num_active_experiments"`
NumExperiments int32 `bun:"num_experiments"`
State WorkspaceState `bun:"state"`
NumActiveExperiments int32 `bun:"num_active_experiments,scanonly"`
NumExperiments int32 `bun:"num_experiments,scanonly"`
State WorkspaceState `bun:"state,default:'UNSPECIFIED'::workspace_state"`
ErrorMessage string `bun:"error_message"`
LastExperimentStartedAt time.Time `bun:"last_experiment_started_at"`
LastExperimentStartedAt time.Time `bun:"last_experiment_started_at,scanonly"`
Key string `bun:"key"`
}

// Projects is an array of project instances.
Expand Down Expand Up @@ -55,6 +56,7 @@ func (p Project) Proto() *projectv1.Project {
NumActiveExperiments: p.NumActiveExperiments,
Notes: p.Notes,
LastExperimentStartedAt: lastExperimentStartedAt,
Key: p.Key,
}
}

Expand Down
31 changes: 31 additions & 0 deletions master/pkg/random/string.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package random

import (
"crypto/rand"
)

const (
// DefaultChars is the default character set used for generating random strings.
DefaultChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)

// String generates a random string of length n, using the characters in the charset string
// if provided, or using the default charset if not.
func String(n int, charset ...string) string {
var chars string
if len(charset) == 0 {
chars = DefaultChars
} else {
chars = charset[0]
}

bytes := make([]byte, n)
_, err := rand.Read(bytes)
if err != nil {
panic(err)
}
for i, b := range bytes {
bytes[i] = chars[b%byte(len(chars))]
}
return string(bytes)
}

This file was deleted.

This file was deleted.

Loading

0 comments on commit a165316

Please sign in to comment.