Skip to content

Commit

Permalink
Ent: IngestArtifacts optimized using concurrently (#1596)
Browse files: browse the repository at this point in the history
Signed-off-by: mrizzi <mrizzi@redhat.com>
  • Loading branch information
mrizzi committed Dec 21, 2023
1 parent a599888 commit 7a05b7e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 54 deletions.
94 changes: 40 additions & 54 deletions pkg/assembler/backends/ent/backend/artifact.go
Expand Up @@ -17,6 +17,7 @@ package backend

import (
"context"
stdsql "database/sql"
"strconv"
"strings"

Expand All @@ -25,7 +26,9 @@ import (
"github.com/guacsec/guac/pkg/assembler/backends/ent/artifact"
"github.com/guacsec/guac/pkg/assembler/backends/ent/predicate"
"github.com/guacsec/guac/pkg/assembler/graphql/model"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/v2/gqlerror"
"golang.org/x/sync/errgroup"
)

func (b *EntBackend) Artifacts(ctx context.Context, artifactSpec *model.ArtifactSpec) ([]*model.Artifact, error) {
Expand Down Expand Up @@ -65,72 +68,55 @@ func toLowerPtr(s *string) *string {

// IngestArtifacts ingests every entry of artifacts and returns their IDs as
// strings; any failure is reported as a gqlerror prefixed with the function name.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost, so
// the pre-commit and post-commit bodies are interleaved below. The
// "appears removed"/"appears added" labels are inferred from syntax and from the
// page's "40 additions & 54 deletions" summary -- confirm against the commit.
func (b *EntBackend) IngestArtifacts(ctx context.Context, artifacts []*model.ArtifactInputSpec) ([]string, error) {
funcName := "IngestArtifacts"
// --- appears removed: old body, one WithinTX call delegating to ingestArtifacts ---
records, err := WithinTX(ctx, b.client, func(ctx context.Context) (*[]string, error) {
client := ent.TxFromContext(ctx)
slc, err := ingestArtifacts(ctx, client, artifacts)
if err != nil {
return nil, err
}

return slc, nil
})

if err != nil {
// --- appears added: new body, one IngestArtifact call per input ---
// One result slot per input; slot i receives the ID for artifacts[i].
artsID := make([]string, len(artifacts))
// The returned ctx shadows the parameter and is the one passed to IngestArtifact.
eg, ctx := errgroup.WithContext(ctx)
for i := range artifacts {
// Per-iteration copies so the closure below does not share the loop variable.
index := i
art := artifacts[index]
// NOTE(review): concurrently is defined elsewhere; presumably it schedules
// the function on eg -- confirm.
concurrently(eg, func() error {
a, err := b.IngestArtifact(ctx, art)
if err == nil {
// Each closure writes only its own index of artsID.
artsID[index] = a
}
return err
})
}
if err := eg.Wait(); err != nil {
// Shared by both versions: wrap the error with the function name.
return nil, gqlerror.Errorf("%v :: %s", funcName, err)
}
// (appears removed) old return value
return *records, nil
// (appears added) new return value
return artsID, nil
}

// IngestArtifact ingests a single artifact and returns its ID as a string.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost, so
// pre-commit and post-commit lines are interleaved. The labels below are
// inferred from syntax -- confirm against the commit.
func (b *EntBackend) IngestArtifact(ctx context.Context, art *model.ArtifactInputSpec) (string, error) {
// (appears removed) old body delegated to IngestArtifacts with a one-element slice
records, err := b.IngestArtifacts(ctx, []*model.ArtifactInputSpec{art})
// (appears added) new body calls upsertArtifact inside WithinTX
// NOTE(review): WithinTX is defined elsewhere; presumably it runs the callback
// in a transaction exposed through ent.TxFromContext -- confirm.
id, err := WithinTX(ctx, b.client, func(ctx context.Context) (*int, error) {
client := ent.TxFromContext(ctx)
return upsertArtifact(ctx, client, art)
})
if err != nil {
return "", err
}

// (appears removed) old empty-result guard and first-record return
if len(records) == 0 {
return "", Errorf("no records returned")
}

return records[0], nil
// (appears added) new return: the integer ID formatted in decimal
return strconv.Itoa(*id), nil
}

// ingestArtifacts (pre-commit batch helper) and upsertArtifact (post-commit
// single-artifact helper), shown interleaved.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost.
// The "appears removed"/"appears added" labels are inferred from syntax and
// from the page's "40 additions & 54 deletions" summary -- confirm against
// the commit.
//
// --- appears removed: start of the old ingestArtifacts ---
// It splits the input via chunk(artifacts, 100), bulk-creates each batch with
// an on-conflict update keyed on the digest column, then queries the IDs back.
func ingestArtifacts(ctx context.Context, client *ent.Tx, artifacts []*model.ArtifactInputSpec) (*[]string, error) {
batches := chunk(artifacts, 100)
ids := make([]int, 0)

for _, artifacts := range batches {
creates := make([]*ent.ArtifactCreate, len(artifacts))
predicates := make([]predicate.Artifact, len(artifacts))
for i, art := range artifacts {
creates[i] = client.Artifact.Create().
SetAlgorithm(strings.ToLower(art.Algorithm)).
SetDigest(strings.ToLower(art.Digest))
}

err := client.Artifact.CreateBulk(creates...).
OnConflict(
sql.ConflictColumns(artifact.FieldDigest),
).
UpdateNewValues().
Exec(ctx)
if err != nil {
return nil, err
}

for i, art := range artifacts {
predicates[i] = artifactQueryInputPredicates(*art)
// --- appears added: the new upsertArtifact ---
// upsertArtifact creates one artifact with lowercased algorithm and digest,
// doing nothing on a digest conflict, and returns a pointer to the row ID.
func upsertArtifact(ctx context.Context, client *ent.Tx, art *model.ArtifactInputSpec) (*int, error) {
id, err := client.Artifact.Create().
SetAlgorithm(strings.ToLower(art.Algorithm)).
SetDigest(strings.ToLower(art.Digest)).
OnConflict(
sql.ConflictColumns(artifact.FieldDigest),
).
DoNothing().
ID(ctx)
if err != nil {
// Any error other than sql.ErrNoRows is returned wrapped; ErrNoRows falls
// through to the lookup below.
// NOTE(review): presumably ErrNoRows is what ID reports when DoNothing skips
// an already-existing row -- confirm against the ent documentation.
if err != stdsql.ErrNoRows {
return nil, errors.Wrap(err, "upsert artifact")
}

// (appears removed) old bulk lookup of IDs for the whole batch
newRecords, err := client.Artifact.Query().Where(artifact.Or(predicates...)).IDs(ctx)
// (appears added) new lookup of the single ID matching this input
id, err = client.Artifact.Query().
Where(artifactQueryInputPredicates(*art)).
OnlyID(ctx)
if err != nil {
// (appears removed) old unwrapped error return
return nil, err
// (appears added) new wrapped error return
return nil, errors.Wrap(err, "get artifact")
}

// --- appears removed: tail of the old ingestArtifacts, converting IDs to strings ---
ids = append(ids, newRecords...)
}
result := make([]string, len(ids))
for i := range ids {
result[i] = strconv.Itoa(ids[i])
}
return &result, nil
// (appears added) return of upsertArtifact
return &id, nil
}
3 changes: 3 additions & 0 deletions pkg/assembler/backends/ent/backend/artifact_test.go
Expand Up @@ -18,6 +18,8 @@
package backend

import (
"sort"

"github.com/google/go-cmp/cmp"
"github.com/guacsec/guac/pkg/assembler/graphql/model"
)
Expand Down Expand Up @@ -54,6 +56,7 @@ func (s *Suite) Test_IngestArtifacts() {
s.T().Errorf("demoClient.IngestArtifacts() error = %v, wantErr %v", err, tt.wantErr)
return
}
sort.Strings(got)
if diff := cmp.Diff(tt.want, got, ignoreID); diff != "" {
s.T().Errorf("Unexpected results. (-want +got):\n%s", diff)
}
Expand Down

0 comments on commit 7a05b7e

Please sign in to comment.