diff --git a/pkg/assembler/backends/ent/backend/certifyVuln_test.go b/pkg/assembler/backends/ent/backend/certifyVuln_test.go index c4398e8446..d493047a13 100644 --- a/pkg/assembler/backends/ent/backend/certifyVuln_test.go +++ b/pkg/assembler/backends/ent/backend/certifyVuln_test.go @@ -1023,9 +1023,6 @@ func (s *Suite) TestIngestCertifyVulns() { }, }, } - ignoreID := cmp.FilterPath(func(p cmp.Path) bool { - return strings.Compare(".ID", p[len(p)-1].String()) == 0 - }, cmp.Ignore()) ctx := context.Background() for _, test := range tests { s.Run(test.Name, func() { @@ -1072,7 +1069,7 @@ func (s *Suite) TestIngestCertifyVulns() { if err != nil { return } - if diff := cmp.Diff(test.ExpVuln, got, ignoreID); diff != "" { + if diff := cmp.Diff(test.ExpVuln, got, IngestPredicatesCmpOpts...); diff != "" { t.Errorf("Unexpected results. (-want +got):\n%s", diff) } }) diff --git a/pkg/assembler/backends/ent/backend/concurrently.go b/pkg/assembler/backends/ent/backend/concurrently.go new file mode 100644 index 0000000000..608f1d12d5 --- /dev/null +++ b/pkg/assembler/backends/ent/backend/concurrently.go @@ -0,0 +1,55 @@ +// +// Copyright 2023 The GUAC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "context" + "os" + "strconv" + + "github.com/guacsec/guac/pkg/logging" + "golang.org/x/sync/errgroup" +) + +var concurrent chan struct{} + +const MaxConcurrentBulkIngestionString string = "MAX_CONCURRENT_BULK_INGESTION" +const defaultMaxConcurrentBulkIngestion int = 50 + +func init() { + logger := logging.FromContext(context.Background()) + size := defaultMaxConcurrentBulkIngestion + maxConcurrentBulkIngestionEnv, found := os.LookupEnv(MaxConcurrentBulkIngestionString) + if found { + maxConcurrentBulkIngestion, err := strconv.Atoi(maxConcurrentBulkIngestionEnv) + if err != nil { + logger.Warnf("failed to convert %v value %v to integer. Default value %v will be applied", MaxConcurrentBulkIngestionString, maxConcurrentBulkIngestionEnv, defaultMaxConcurrentBulkIngestion) + size = defaultMaxConcurrentBulkIngestion + } else { + size = maxConcurrentBulkIngestion + } + } + concurrent = make(chan struct{}, size) +} + +func concurrently(eg *errgroup.Group, fn func() error) { + eg.Go(func() error { + concurrent <- struct{}{} + err := fn() + <-concurrent + return err + }) +} diff --git a/pkg/assembler/backends/ent/backend/dependency.go b/pkg/assembler/backends/ent/backend/dependency.go index 27c5d0b871..a613a7587f 100644 --- a/pkg/assembler/backends/ent/backend/dependency.go +++ b/pkg/assembler/backends/ent/backend/dependency.go @@ -23,6 +23,7 @@ import ( "github.com/guacsec/guac/pkg/assembler/backends/ent/dependency" "github.com/guacsec/guac/pkg/assembler/graphql/model" "github.com/pkg/errors" + "golang.org/x/sync/errgroup" ) func (b *EntBackend) IsDependency(ctx context.Context, spec *model.IsDependencySpec) ([]*model.IsDependency, error) { @@ -74,13 +75,24 @@ func (b *EntBackend) IsDependency(ctx context.Context, spec *model.IsDependencyS func (b *EntBackend) IngestDependencies(ctx context.Context, pkgs []*model.PkgInputSpec, depPkgs []*model.PkgInputSpec, depPkgMatchType model.MatchFlags, dependencies []*model.IsDependencyInputSpec) ([]*model.IsDependency, error) { // TODO: This looks like a good candidate for using BulkCreate() - var modelIsDependencies []*model.IsDependency + var modelIsDependencies = make([]*model.IsDependency, len(dependencies)) + eg, ctx := errgroup.WithContext(ctx) for i := range dependencies { - isDependency, err := b.IngestDependency(ctx, *pkgs[i], *depPkgs[i], depPkgMatchType, *dependencies[i]) - if err != nil { - return nil, Errorf("IngestDependency failed with err: %v", err) - } - modelIsDependencies = append(modelIsDependencies, isDependency) + index := i + pkg := *pkgs[index] + depPkg := *depPkgs[index] + dpmt := depPkgMatchType + dep := *dependencies[index] + concurrently(eg, func() error { + p, err := b.IngestDependency(ctx, pkg, depPkg, dpmt, dep) + if err == nil { + modelIsDependencies[index] = &model.IsDependency{ID: p.ID} + } + return err + }) + } + if err := eg.Wait(); err != nil { + return nil, err } return modelIsDependencies, nil } diff --git a/pkg/assembler/backends/ent/backend/dependency_test.go b/pkg/assembler/backends/ent/backend/dependency_test.go index 232a60ecdf..f3b346e780 100644 --- a/pkg/assembler/backends/ent/backend/dependency_test.go +++ b/pkg/assembler/backends/ent/backend/dependency_test.go @@ -726,9 +726,6 @@ func (s *Suite) TestIngestDependencies() { }, }, } - ignoreID := cmp.FilterPath(func(p cmp.Path) bool { - return strings.Compare(".ID", p[len(p)-1].String()) == 0 - }, cmp.Ignore()) ctx := s.Ctx for _, test := range tests { s.Run(test.Name, func() { @@ -760,7 +757,7 @@ func (s *Suite) TestIngestDependencies() { if err != nil { return } - if diff := cmp.Diff(test.ExpID, got, ignoreID); diff != "" { + if diff := cmp.Diff(test.ExpID, got, IngestPredicatesCmpOpts...); diff != "" { t.Errorf("Unexpected results. (-want +got):\n%s", diff) } }) diff --git a/pkg/assembler/backends/ent/backend/helpers_test.go b/pkg/assembler/backends/ent/backend/helpers_test.go index 34e4f1eb19..d2b74f6447 100644 --- a/pkg/assembler/backends/ent/backend/helpers_test.go +++ b/pkg/assembler/backends/ent/backend/helpers_test.go @@ -22,6 +22,9 @@ import ( "strings" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/guacsec/guac/pkg/assembler/graphql/model" + "github.com/guacsec/guac/pkg/assembler/helpers" ) func ptr[T any](s T) *T { @@ -39,3 +42,25 @@ var ignoreEmptySlices = cmp.FilterValues(func(x, y interface{}) bool { } return false }, cmp.Ignore()) + +var IngestPredicatesCmpOpts = []cmp.Option{ + ignoreID, + cmpopts.EquateEmpty(), + cmpopts.SortSlices(isDependencyLess), + cmpopts.SortSlices(packageLess), + cmpopts.SortSlices(certifyVulnLess), +} + +func isDependencyLess(e1, e2 *model.IsDependency) bool { + return packageLess(e1.Package, e2.Package) +} + +func packageLess(e1, e2 *model.Package) bool { + purl1 := helpers.PkgToPurl(e1.Type, e1.Namespaces[0].Namespace, e1.Namespaces[0].Names[0].Name, e1.Namespaces[0].Names[0].Versions[0].Version, e1.Namespaces[0].Names[0].Versions[0].Subpath, nil) + purl2 := helpers.PkgToPurl(e2.Type, e2.Namespaces[0].Namespace, e2.Namespaces[0].Names[0].Name, e2.Namespaces[0].Names[0].Versions[0].Version, e2.Namespaces[0].Names[0].Versions[0].Subpath, nil) + return purl1 < purl2 +} + +func certifyVulnLess(e1, e2 *model.CertifyVuln) bool { + return packageLess(e1.Package, e2.Package) +} diff --git a/pkg/assembler/backends/ent/backend/package.go b/pkg/assembler/backends/ent/backend/package.go index bdceb73f70..f80a60a620 100644 --- a/pkg/assembler/backends/ent/backend/package.go +++ b/pkg/assembler/backends/ent/backend/package.go @@ -34,6 +34,7 @@ import ( "github.com/guacsec/guac/pkg/assembler/backends/helper" "github.com/guacsec/guac/pkg/assembler/graphql/model" "github.com/pkg/errors" + "golang.org/x/sync/errgroup" ) func (b *EntBackend) Packages(ctx context.Context, pkgSpec *model.PkgSpec) ([]*model.Package, error) { @@ -103,12 +104,20 @@ func (b *EntBackend) Packages(ctx context.Context, pkgSpec *model.PkgSpec) ([]*m func (b *EntBackend) IngestPackages(ctx context.Context, pkgs []*model.PkgInputSpec) ([]*model.Package, error) { // FIXME: (ivanvanderbyl) This will be suboptimal because we can't batch insert relations with upserts. See Readme. models := make([]*model.Package, len(pkgs)) - for i, pkg := range pkgs { - p, err := b.IngestPackage(ctx, *pkg) - if err != nil { - return nil, err - } - models[i] = p + eg, ctx := errgroup.WithContext(ctx) + for i := range pkgs { + index := i + pkg := pkgs[index] + concurrently(eg, func() error { + p, err := b.IngestPackage(ctx, *pkg) + if err == nil { + models[index] = p + } + return err + }) + } + if err := eg.Wait(); err != nil { + return nil, err } return models, nil } diff --git a/pkg/assembler/backends/ent/backend/package_test.go b/pkg/assembler/backends/ent/backend/package_test.go index 86d7494384..9ed49e8fb6 100644 --- a/pkg/assembler/backends/ent/backend/package_test.go +++ b/pkg/assembler/backends/ent/backend/package_test.go @@ -227,7 +227,7 @@ func (s *Suite) Test_IngestPackages() { s.T().Errorf("demoClient.IngestPackages() error = %v, wantErr %v", err, tt.wantErr) return } - if diff := cmp.Diff(tt.want, got, ignoreID); diff != "" { + if diff := cmp.Diff(tt.want, got, IngestPredicatesCmpOpts...); diff != "" { s.T().Errorf("Unexpected results. (-want +got):\n%s", diff) } }) diff --git a/pkg/assembler/backends/ent/backend/pkgequal_test.go b/pkg/assembler/backends/ent/backend/pkgequal_test.go index 6202d1ad61..9094243d7c 100644 --- a/pkg/assembler/backends/ent/backend/pkgequal_test.go +++ b/pkg/assembler/backends/ent/backend/pkgequal_test.go @@ -742,7 +742,7 @@ func (s *Suite) TestIngestPkgEquals() { if err != nil { return } - if diff := cmp.Diff(test.ExpHE, got, ignoreID); diff != "" { + if diff := cmp.Diff(test.ExpHE, got, IngestPredicatesCmpOpts...); diff != "" { t.Errorf("Unexpected results. (-want +got):\n%s", diff) } }) diff --git a/pkg/assembler/backends/ent/testutils/suite.go b/pkg/assembler/backends/ent/testutils/suite.go index 4b55d4b2f2..06348eecf9 100644 --- a/pkg/assembler/backends/ent/testutils/suite.go +++ b/pkg/assembler/backends/ent/testutils/suite.go @@ -41,6 +41,10 @@ func init() { } txdb.Register("txdb", "postgres", db) + err := os.Setenv("MAX_CONCURRENT_BULK_INGESTION", "1") + if err != nil { + log.Fatal(err) + } } type Suite struct {