Skip to content

Commit

Permalink
fix: add checkpoint when enqueue scan tasks for scan all (#18680)
Browse files Browse the repository at this point in the history
Fix the scanAll cannot be stopped in case of large number of artifacts,
add the checkpoint before submit scan tasks, mark the scanAll stopped
flag in the redis.

Fixes: #18044

Signed-off-by: chlins <chenyuzh@vmware.com>
  • Loading branch information
chlins committed Jun 5, 2023
1 parent 9d28d1f commit fbeeaa7
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 19 deletions.
8 changes: 7 additions & 1 deletion src/controller/artifact/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ func Iterator(ctx context.Context, chunkSize int, query *q.Query, option *Option
}

for _, artifact := range artifacts {
ch <- artifact
select {
case <-ctx.Done():
log.G(ctx).Errorf("context done, list artifacts exited, error: %v", ctx.Err())
return
case ch <- artifact:
continue
}
}

if len(artifacts) < chunkSize {
Expand Down
63 changes: 61 additions & 2 deletions src/controller/scan/base_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"reflect"
"strings"
"sync"
"time"

"github.com/google/uuid"

Expand All @@ -30,6 +31,7 @@ import (
sc "github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/controller/tag"
"github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/cache"
"github.com/goharbor/harbor/src/lib/config"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
Expand All @@ -50,8 +52,12 @@ import (
"github.com/goharbor/harbor/src/pkg/task"
)

// DefaultController is a default singleton scan API controller.
var DefaultController = NewController()
var (
// DefaultController is a default singleton scan API controller.
DefaultController = NewController()

errScanAllStopped = errors.New("scanAll stopped")
)

// const definitions
const (
Expand All @@ -74,6 +80,9 @@ type uuidGenerator func() (string, error)
// utility methods.
type configGetter func(cfg string) (string, error)

// cacheGetter returns cache
type cacheGetter func() cache.Cache

// launchScanJobParam is a param to launch scan job.
type launchScanJobParam struct {
ExecutionID int64
Expand Down Expand Up @@ -109,6 +118,8 @@ type basicController struct {
taskMgr task.Manager
// Converter for V1 report to V2 report
reportConverter postprocessors.NativeScanReportConverter
// cache stores the stop scan all marks
cache cacheGetter
}

// NewController news a scan API controller
Expand Down Expand Up @@ -154,6 +165,9 @@ func NewController() Controller {
taskMgr: task.Mgr,
// Get the scan V1 to V2 report converters
reportConverter: postprocessors.Converter,
cache: func() cache.Cache {
return cache.Default()
},
}
}

Expand Down Expand Up @@ -368,6 +382,44 @@ func (bc *basicController) ScanAll(ctx context.Context, trigger string, async bo
return executionID, nil
}

func (bc *basicController) StopScanAll(ctx context.Context, executionID int64, async bool) error {
stopScanAll := func(ctx context.Context, executionID int64) error {
// mark scan all stopped
if err := bc.markScanAllStopped(ctx, executionID); err != nil {
return err
}
// stop the execution and sub tasks
return bc.execMgr.Stop(ctx, executionID)
}

if async {
go func() {
if err := stopScanAll(ctx, executionID); err != nil {
log.Errorf("failed to stop scan all, error: %v", err)
}
}()
return nil
}

return stopScanAll(ctx, executionID)
}

func scanAllStoppedKey(execID int64) string {
return fmt.Sprintf("scan_all:execution_id:%d:stopped", execID)
}

func (bc *basicController) markScanAllStopped(ctx context.Context, execID int64) error {
// set the expire time to 2 hours, the duration should be large enough
// for controller to capture the stop flag, leverage the key recycled
// by redis TTL, no need to clean by scan controller as the new scan all
// will have a new unique execution id, the old key has no effects to anything.
return bc.cache().Save(ctx, scanAllStoppedKey(execID), "", 2*time.Hour)
}

func (bc *basicController) isScanAllStopped(ctx context.Context, execID int64) bool {
return bc.cache().Contains(ctx, scanAllStoppedKey(execID))
}

func (bc *basicController) startScanAll(ctx context.Context, executionID int64) error {
batchSize := 50

Expand All @@ -379,8 +431,15 @@ func (bc *basicController) startScanAll(ctx context.Context, executionID int64)
UnsupportCount int `json:"unsupport_count"`
UnknowCount int `json:"unknow_count"`
}{}
// with cancel function to signal downstream worker
ctx, cancel := context.WithCancel(ctx)
defer cancel()

for artifact := range ar.Iterator(ctx, batchSize, nil, nil) {
if bc.isScanAllStopped(ctx, executionID) {
return errScanAllStopped
}

summary.TotalCount++

scan := func(ctx context.Context) error {
Expand Down
37 changes: 29 additions & 8 deletions src/controller/scan/base_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/goharbor/harbor/src/common/rbac"
"github.com/goharbor/harbor/src/controller/artifact"
"github.com/goharbor/harbor/src/controller/robot"
"github.com/goharbor/harbor/src/lib/cache"
"github.com/goharbor/harbor/src/lib/config"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
Expand All @@ -49,6 +50,7 @@ import (
robottesting "github.com/goharbor/harbor/src/testing/controller/robot"
scannertesting "github.com/goharbor/harbor/src/testing/controller/scanner"
tagtesting "github.com/goharbor/harbor/src/testing/controller/tag"
mockcache "github.com/goharbor/harbor/src/testing/lib/cache"
ormtesting "github.com/goharbor/harbor/src/testing/lib/orm"
"github.com/goharbor/harbor/src/testing/mock"
accessorytesting "github.com/goharbor/harbor/src/testing/pkg/accessory"
Expand Down Expand Up @@ -77,6 +79,7 @@ type ControllerTestSuite struct {
ar artifact.Controller
c Controller
reportConverter *postprocessorstesting.ScanReportV1ToV2Converter
cache *mockcache.Cache
}

// TestController is the entry point of ControllerTestSuite.
Expand Down Expand Up @@ -271,6 +274,8 @@ func (suite *ControllerTestSuite) SetupSuite() {

suite.taskMgr = &tasktesting.Manager{}

suite.cache = &mockcache.Cache{}

suite.c = &basicController{
manager: mgr,
ar: suite.ar,
Expand Down Expand Up @@ -298,6 +303,7 @@ func (suite *ControllerTestSuite) SetupSuite() {
execMgr: suite.execMgr,
taskMgr: suite.taskMgr,
reportConverter: &postprocessorstesting.ScanReportV1ToV2Converter{},
cache: func() cache.Cache { return suite.cache },
}
}

Expand Down Expand Up @@ -522,25 +528,25 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() {
func (suite *ControllerTestSuite) TestScanAll() {
{
// no artifacts found when scan all
ctx := context.TODO()

executionID := int64(1)

suite.execMgr.On(
"Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE",
"Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE",
).Return(executionID, nil).Once()

mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once()

mock.OnAnything(suite.artifactCtl, "List").Return([]*artifact.Artifact{}, nil).Once()

suite.taskMgr.On("Count", ctx, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once()
suite.taskMgr.On("Count", mock.Anything, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once()

mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()

suite.execMgr.On("MarkDone", ctx, executionID, mock.Anything).Return(nil).Once()
suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once()

_, err := suite.c.ScanAll(ctx, "SCHEDULE", false)
suite.cache.On("Contains", mock.Anything, scanAllStoppedKey(1)).Return(false).Once()

_, err := suite.c.ScanAll(context.TODO(), "SCHEDULE", false)
suite.NoError(err)
}

Expand All @@ -551,7 +557,7 @@ func (suite *ControllerTestSuite) TestScanAll() {
executionID := int64(1)

suite.execMgr.On(
"Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE",
"Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE",
).Return(executionID, nil).Once()

mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once()
Expand All @@ -568,13 +574,28 @@ func (suite *ControllerTestSuite) TestScanAll() {
mock.OnAnything(suite.reportMgr, "Create").Return("uuid", nil).Once()
mock.OnAnything(suite.taskMgr, "Create").Return(int64(0), fmt.Errorf("failed")).Once()
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkError", ctx, executionID, mock.Anything).Return(nil).Once()
suite.execMgr.On("MarkError", mock.Anything, executionID, mock.Anything).Return(nil).Once()

_, err := suite.c.ScanAll(ctx, "SCHEDULE", false)
suite.NoError(err)
}
}

func (suite *ControllerTestSuite) TestStopScanAll() {
mockExecID := int64(100)
// mock error case
mockErr := fmt.Errorf("stop scan all error")
suite.cache.On("Save", mock.Anything, scanAllStoppedKey(mockExecID), mock.Anything, mock.Anything).Return(mockErr).Once()
err := suite.c.StopScanAll(context.TODO(), mockExecID, false)
suite.EqualError(err, mockErr.Error())

// mock normal case
suite.cache.On("Save", mock.Anything, scanAllStoppedKey(mockExecID), mock.Anything, mock.Anything).Return(nil).Once()
suite.execMgr.On("Stop", mock.Anything, mockExecID).Return(nil).Once()
err = suite.c.StopScanAll(context.TODO(), mockExecID, false)
suite.NoError(err)
}

func (suite *ControllerTestSuite) TestDeleteReports() {
suite.reportMgr.On("DeleteByDigests", context.TODO(), "digest").Return(nil).Once()

Expand Down
2 changes: 1 addition & 1 deletion src/controller/scan/callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (suite *CallbackTestSuite) TestScanAllCallback() {

mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()

suite.execMgr.On("MarkDone", context.TODO(), executionID, mock.Anything).Return(nil).Once()
suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once()

suite.NoError(scanAllCallback(context.TODO(), ""))
}
Expand Down
10 changes: 10 additions & 0 deletions src/controller/scan/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ type Controller interface {
// error : non nil error if any errors occurred
ScanAll(ctx context.Context, trigger string, async bool) (int64, error)

// StopScanAll stops the scanAll
//
// Arguments:
// ctx context.Context : the context for this method
// executionID int64 : the id of scan all execution
// async bool : stop scan all in background
// Returns:
// error : non nil error if any errors occurred
StopScanAll(ctx context.Context, executionID int64, async bool) error

// GetVulnerable returns the vulnerable of the artifact for the allowlist
//
// Arguments:
Expand Down
11 changes: 4 additions & 7 deletions src/server/v2.0/handler/scan_all.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scheduler"
Expand Down Expand Up @@ -74,12 +73,10 @@ func (s *scanAllAPI) StopScanAll(ctx context.Context, params operation.StopScanA
if execution == nil {
return s.SendError(ctx, errors.BadRequestError(nil).WithMessage("no scan all job is found currently"))
}
go func(ctx context.Context, eid int64) {
err := s.execMgr.Stop(ctx, eid)
if err != nil {
log.Errorf("failed to stop the execution of executionID=%+v", execution.ID)
}
}(s.makeCtx(), execution.ID)

if err = s.scanCtl.StopScanAll(s.makeCtx(), execution.ID, true); err != nil {
return s.SendError(ctx, err)
}

return operation.NewStopScanAllAccepted()
}
Expand Down
1 change: 1 addition & 0 deletions src/server/v2.0/handler/scan_all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ func (suite *ScanAllTestSuite) TestStopScanAll() {
times := 3
suite.Security.On("IsAuthenticated").Return(true).Times(times)
suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(times)
mock.OnAnything(suite.scanCtl, "StopScanAll").Return(nil).Times(times)
mock.OnAnything(suite.scannerCtl, "ListRegistrations").Return([]*scanner.Registration{{ID: int64(1)}}, nil).Times(times)

{
Expand Down
14 changes: 14 additions & 0 deletions src/testing/controller/scan/controller.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fbeeaa7

Please sign in to comment.