From 467e492c609b314126e880f26a0c116253d4a48c Mon Sep 17 00:00:00 2001 From: Liran BG Date: Mon, 25 Mar 2024 21:03:34 +0200 Subject: [PATCH] [API] Enhance log collector [1.6.x] (#5321) --- server/api/db/base.py | 1 + server/api/db/sqldb/db.py | 4 + server/api/main.py | 29 ++++- server/api/utils/clients/log_collector.py | 35 +++++- server/log-collector/cmd/logcollector/main.go | 4 +- server/log-collector/go.mod | 2 + server/log-collector/go.sum | 4 + server/log-collector/pkg/common/consts.go | 9 +- .../logcollector/logcollector_test.go | 98 +++++++++++++--- .../pkg/services/logcollector/server.go | 108 +++++++++++++++++- .../logcollector/test/logcollector_test.go | 3 +- .../pkg/services/logcollector/test/nop/nop.go | 15 +++ .../log-collector/proto/log_collector.proto | 9 ++ tests/api/test_collect_runs_logs.py | 36 ++++-- tests/api/utils/clients/test_log_collector.py | 55 ++++++--- 15 files changed, 360 insertions(+), 52 deletions(-) diff --git a/server/api/db/base.py b/server/api/db/base.py index af1d98dd805..2cae70f11dc 100644 --- a/server/api/db/base.py +++ b/server/api/db/base.py @@ -71,6 +71,7 @@ def list_distinct_runs_uids( only_uids: bool = False, last_update_time_from: datetime.datetime = None, states: List[str] = None, + specific_uids: List[str] = None, ): pass diff --git a/server/api/db/sqldb/db.py b/server/api/db/sqldb/db.py index 2b2e1c0b369..0817dbc5370 100644 --- a/server/api/db/sqldb/db.py +++ b/server/api/db/sqldb/db.py @@ -247,6 +247,7 @@ def list_distinct_runs_uids( only_uids=True, last_update_time_from: datetime = None, states: typing.List[str] = None, + specific_uids: List[str] = None, ) -> typing.Union[typing.List[str], RunList]: """ List all runs uids in the DB @@ -277,6 +278,9 @@ def list_distinct_runs_uids( if requested_logs_modes is not None: query = query.filter(Run.requested_logs.in_(requested_logs_modes)) + if specific_uids: + query = query.filter(Run.uid.in_(specific_uids)) + if not only_uids: # group_by allows us to have a row per uid with the whole record rather than just the uid (as distinct does) # note we cannot promise that the same row will be returned each time per uid as the order is not guaranteed diff --git a/server/api/main.py b/server/api/main.py index 95b15c6e826..797173c1deb 100644 --- a/server/api/main.py +++ b/server/api/main.py @@ -511,12 +511,36 @@ async def _start_periodic_stop_logs(): async def _verify_log_collection_stopped_on_startup(): """ - Pulls runs from DB that are in terminal state and have logs requested, and call stop logs for them. + First, list runs that are currently being collected in the log collector. + Second, query the DB for those runs that are also in terminal state and have logs requested. + Lastly, call stop logs for the runs that met all of the above conditions. This is done so that the log collector won't keep trying to collect logs for runs that are already in terminal state. """ + logger.debug("Listing runs currently being log collected") + log_collector_client = server.api.utils.clients.log_collector.LogCollectorClient() + run_uids_in_progress = [] + failed_listing = False + try: + runs_in_progress_response_stream = log_collector_client.list_runs_in_progress() + # collate the run uids from the response stream to a list + async for run_uids in runs_in_progress_response_stream: + run_uids_in_progress.extend(run_uids) + except Exception as exc: + failed_listing = True + logger.warning( + "Failed listing runs currently being log collected", + exc=err_to_str(exc), + traceback=traceback.format_exc(), + ) + + if len(run_uids_in_progress) == 0 and not failed_listing: + logger.debug("No runs currently being log collected") + return + logger.debug( - "Getting all runs which have reached terminal state and already have logs requested", + "Getting current log collected runs which have reached terminal state and already have logs requested", + run_uids_in_progress_count=len(run_uids_in_progress), ) db_session = await fastapi.concurrency.run_in_threadpool(create_session) try: @@ -531,6 +555,7 @@ async def _verify_log_collection_stopped_on_startup(): # usually it happens when run pods get preempted mlrun.runtimes.constants.RunStates.unknown, ], + specific_uids=run_uids_in_progress, ) if len(runs) > 0: diff --git a/server/api/utils/clients/log_collector.py b/server/api/utils/clients/log_collector.py index 4ee268fa96d..e92ebada5e7 100644 --- a/server/api/utils/clients/log_collector.py +++ b/server/api/utils/clients/log_collector.py @@ -313,7 +313,40 @@ async def delete_logs( if verbose: logger.warning(msg, error=response.errorMessage) - def _retryable_error(self, error_message, retryable_error_patterns) -> bool: + async def list_runs_in_progress( + self, + project: str = None, + verbose: bool = True, + raise_on_error: bool = True, + ) -> typing.AsyncIterable[str]: + """ + List runs in progress from the log collector service + :param project: A project name to filter the runs by. If not provided, all runs in progress will be listed + :param verbose: Whether to log errors + :param raise_on_error: Whether to raise an exception on error + :return: A list of run uids + """ + request = self._log_collector_pb2.ListRunsRequest( + project=project, + ) + + response_stream = self._call_stream("ListRunsInProgress", request) + try: + async for chunk in response_stream: + yield chunk.runUIDs + except Exception as exc: + msg = "Failed to list runs in progress" + if raise_on_error: + raise LogCollectorErrorCode.map_error_code_to_mlrun_error( + LogCollectorErrorCode.ErrCodeInternal.value, + mlrun.errors.err_to_str(exc), + msg, + ) + if verbose: + logger.warning(msg, error=mlrun.errors.err_to_str(exc)) + + @staticmethod + def _retryable_error(error_message, retryable_error_patterns) -> bool: """ Check if the error is retryable :param error_message: The error message diff --git a/server/log-collector/cmd/logcollector/main.go b/server/log-collector/cmd/logcollector/main.go index 7826f6b6245..2fd5384893a 100644 --- a/server/log-collector/cmd/logcollector/main.go +++ b/server/log-collector/cmd/logcollector/main.go @@ -45,6 +45,7 @@ func StartServer() error { getLogsBufferSizeBytes := flag.Int("get-logs-buffer-buffer-size-bytes", common.GetEnvOrDefaultInt("MLRUN_LOG_COLLECTOR__GET_LOGS_BUFFER_SIZE_BYTES", common.DefaultGetLogsBufferSize), "Size of buffers in the buffer pool for getting logs, in bytes (default: 3.75MB)") logTimeUpdateBytesInterval := flag.Int("log-time-update-bytes-interval", common.GetEnvOrDefaultInt("MLRUN_LOG_COLLECTOR__LOG_TIME_UPDATE_BYTES_INTERVAL", common.LogTimeUpdateBytesInterval), "Amount of logs to read between updates of the last log time in the 'in memory' state, in bytes (default: 4KB)") clusterizationRole := flag.String("clusterization-role", common.GetEnvOrDefaultString("MLRUN_HTTPDB__CLUSTERIZATION__ROLE", "chief"), "The role of the log collector in the cluster (chief, worker)") + listRunsChunkSize := flag.Int("list-runs-chunk-size", common.GetEnvOrDefaultInt("MLRUN_LOG_COLLECTOR__LIST_RUNS_CHUNK_SIZE", common.DefaultListRunsChunkSize), "The chunk size for listing runs in progress") // if namespace is not passed, it will be taken from env namespace := flag.String("namespace", "", "The namespace to collect logs from") @@ -80,7 +81,8 @@ func StartServer() error { *logCollectionBufferSizeBytes, *getLogsBufferSizeBytes, *logTimeUpdateBytesInterval, - *advancedLogLevel) + *advancedLogLevel, + *listRunsChunkSize) if err != nil { return errors.Wrap(err, "Failed to create log collector server") } diff --git a/server/log-collector/go.mod b/server/log-collector/go.mod index 6bf0293470d..dd5762b8605 100644 --- a/server/log-collector/go.mod +++ b/server/log-collector/go.mod @@ -8,6 +8,7 @@ require ( github.com/nuclio/errors v0.0.4 github.com/nuclio/logger v0.0.1 github.com/nuclio/loggerus v0.0.6 + github.com/samber/lo v1.39.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 golang.org/x/sync v0.6.0 @@ -41,6 +42,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/net v0.22.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect diff --git a/server/log-collector/go.sum b/server/log-collector/go.sum index a2948ca6b94..9729fd468df 100644 --- a/server/log-collector/go.sum +++ b/server/log-collector/go.sum @@ -102,6 +102,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= +github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.8.0/go.mod h1:4GuYW9TZmE769R5STWrRakJc4UqQ3+QQ95fyz7ENv1A= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -128,6 +130,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= diff --git a/server/log-collector/pkg/common/consts.go b/server/log-collector/pkg/common/consts.go index 987044e8b4a..a28cd173227 100644 --- a/server/log-collector/pkg/common/consts.go +++ b/server/log-collector/pkg/common/consts.go @@ -23,9 +23,6 @@ const ( ErrCodeBadRequest ) -const DefaultErrorStackDepth = 3 - -// Buffer sizes const ( // DefaultLogCollectionBufferSize is the default buffer size for collecting logs from pods DefaultLogCollectionBufferSize int = 10 * 1024 * 1024 // 10MB @@ -37,6 +34,12 @@ const ( // LogTimeUpdateBytesInterval is the bytes amount to read between updates of the // last log time in the in memory state LogTimeUpdateBytesInterval int = 4 * 1024 // 4KB + + // DefaultListRunsChunkSize is the default chunk size for listing runs + DefaultListRunsChunkSize int = 10 + + // DefaultErrorStackDepth is the default stack depth for errors + DefaultErrorStackDepth = 3 ) // Custom errors diff --git a/server/log-collector/pkg/services/logcollector/logcollector_test.go b/server/log-collector/pkg/services/logcollector/logcollector_test.go index de1415d1cf8..29b1db28faf 100644 --- a/server/log-collector/pkg/services/logcollector/logcollector_test.go +++ b/server/log-collector/pkg/services/logcollector/logcollector_test.go @@ -68,6 +68,7 @@ func (suite *LogCollectorTestSuite) SetupSuite() { bufferSizeBytes := 100 clusterizationRole := "chief" advancedLogLevel := 0 + listRunsChunkSize := 10 // create base dir suite.baseDir = path.Join(os.TempDir(), "/log_collector_test") @@ -88,7 +89,8 @@ func (suite *LogCollectorTestSuite) SetupSuite() { bufferSizeBytes, bufferSizeBytes, common.LogTimeUpdateBytesInterval, - advancedLogLevel) + advancedLogLevel, + listRunsChunkSize) suite.Require().NoError(err, "Failed to create log collector server") suite.logger.InfoWith("Setup complete") @@ -628,24 +630,7 @@ func (suite *LogCollectorTestSuite) TestStopLog() { projectNum := 2 var projectToRuns = map[string][]string{} - for i := 0; i < projectNum; i++ { - projectName := fmt.Sprintf("project-%d", i) - - // add log item to the server's states, so no error will be returned - for j := 0; j < logItemsNum; j++ { - runUID := uuid.New().String() - projectToRuns[projectName] = append(projectToRuns[projectName], runUID) - selector := fmt.Sprintf("run=%s", runUID) - - // Add state to the log collector's state manifest - err = suite.logCollectorServer.stateManifest.AddLogItem(suite.ctx, runUID, selector, projectName) - suite.Require().NoError(err, "Failed to add log item to the state manifest") - - // Add state to the log collector's current state - err = suite.logCollectorServer.currentState.AddLogItem(suite.ctx, runUID, selector, projectName) - suite.Require().NoError(err, "Failed to add log item to the current state") - } - } + suite.createLogItems(projectNum, logItemsNum, projectToRuns) // write state err = suite.logCollectorServer.stateManifest.WriteState(suite.logCollectorServer.stateManifest.GetState()) @@ -802,6 +787,81 @@ func (suite *LogCollectorTestSuite) TestGetLogFilePath() { suite.Require().Equal(runFilePath, logFilePath, "Expected log file path to be the same as the run file path") } +func (suite *LogCollectorTestSuite) TestListRunsInProgress() { + listRunsInProgress := func(request *log_collector.ListRunsRequest) []string { + nopStream := &nop.ListRunsResponseStreamNop{} + err := suite.logCollectorServer.ListRunsInProgress(request, nopStream) + suite.Require().NoError(err, "Failed to list runs in progress") + return nopStream.RunUIDs + + } + + verifyRuns := func(expectedRunUIDs []string, responseRunUIDs []string) { + suite.Require().Equal(len(expectedRunUIDs), len(responseRunUIDs)) + for _, runUID := range responseRunUIDs { + suite.Require().Contains(expectedRunUIDs, runUID, "Expected runUID to be in the expected list") + } + } + + // list runs without any runs in progress + runsInProgress := listRunsInProgress(&log_collector.ListRunsRequest{}) + suite.Require().Empty(runsInProgress, "Expected no runs in progress") + + // create log items in progress + projectNum := 5 + logItemsNum := 5 + var projectToRuns = map[string][]string{} + suite.createLogItems(projectNum, logItemsNum, projectToRuns) + defer func() { + // remove projects from state manifest and current state when test is done to avoid conflicts with other tests + for project := range projectToRuns { + err := suite.logCollectorServer.stateManifest.RemoveProject(project) + suite.Assert().NoError(err, "Failed to remove project from state manifest") + + err = suite.logCollectorServer.currentState.RemoveProject(project) + suite.Assert().NoError(err, "Failed to remove project from current state") + } + }() + + var expectedRunUIDs []string + for _, runs := range projectToRuns { + expectedRunUIDs = append(expectedRunUIDs, runs...) + } + + // list runs in progress for all projects + runsInProgress = listRunsInProgress(&log_collector.ListRunsRequest{}) + verifyRuns(expectedRunUIDs, runsInProgress) + + // list runs in progress for a specific project + projectName := "project-1" + expectedRunUIDs = projectToRuns[projectName] + runsInProgress = listRunsInProgress(&log_collector.ListRunsRequest{Project: projectName}) + verifyRuns(expectedRunUIDs, runsInProgress) +} + +// createLogItems creates `logItemsNum` log items for `projectNum` projects, and adds them to the server's states +func (suite *LogCollectorTestSuite) createLogItems(projectNum int, logItemsNum int, projectToRuns map[string][]string) { + var err error + for i := 0; i < projectNum; i++ { + projectName := fmt.Sprintf("project-%d", i) + + // add log item to the server's states, so no error will be returned + for j := 0; j < logItemsNum; j++ { + runUID := uuid.New().String() + projectToRuns[projectName] = append(projectToRuns[projectName], runUID) + selector := fmt.Sprintf("run=%s", runUID) + + // Add state to the log collector's state manifest + err = suite.logCollectorServer.stateManifest.AddLogItem(suite.ctx, runUID, selector, projectName) + suite.Require().NoError(err, "Failed to add log item to the state manifest") + + // Add state to the log collector's current state + err = suite.logCollectorServer.currentState.AddLogItem(suite.ctx, runUID, selector, projectName) + suite.Require().NoError(err, "Failed to add log item to the current state") + } + } +} + func TestLogCollectorTestSuite(t *testing.T) { suite.Run(t, new(LogCollectorTestSuite)) } diff --git a/server/log-collector/pkg/services/logcollector/server.go b/server/log-collector/pkg/services/logcollector/server.go index 376012901e7..50065b4cf47 100644 --- a/server/log-collector/pkg/services/logcollector/server.go +++ b/server/log-collector/pkg/services/logcollector/server.go @@ -37,6 +37,7 @@ import ( "github.com/nuclio/errors" "github.com/nuclio/logger" + "github.com/samber/lo" "golang.org/x/sync/errgroup" "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -71,6 +72,8 @@ type Server struct { // interval durations readLogWaitTime time.Duration monitoringInterval time.Duration + + listRunsChunkSize int } // NewLogCollectorServer creates a new log collector server @@ -87,7 +90,8 @@ func NewLogCollectorServer(logger logger.Logger, logCollectionBufferSizeBytes, getLogsBufferSizeBytes, logTimeUpdateBytesInterval, - advancedLogLevel int) (*Server, error) { + advancedLogLevel, + listRunsChunkSize int) (*Server, error) { abstractServer, err := framework.NewAbstractMlrunGRPCServer(logger, nil) if err != nil { return nil, errors.Wrap(err, "Failed to create abstract server") @@ -145,6 +149,10 @@ func NewLogCollectorServer(logger logger.Logger, logCollectionBufferPool := bufferpool.NewSizedBytePool(logCollectionBufferPoolSize, logCollectionBufferSizeBytes) getLogsBufferPool := bufferpool.NewSizedBytePool(getLogsBufferPoolSize, getLogsBufferSizeBytes) + if listRunsChunkSize <= 0 { + listRunsChunkSize = common.DefaultListRunsChunkSize + } + return &Server{ AbstractMlrunGRPCServer: abstractServer, namespace: namespace, @@ -163,6 +171,7 @@ func NewLogCollectorServer(logger logger.Logger, startLogsFindingPodsInterval: 3 * time.Second, startLogsFindingPodsTimeout: 15 * time.Second, advancedLogLevel: advancedLogLevel, + listRunsChunkSize: listRunsChunkSize, }, nil } @@ -634,6 +643,82 @@ func (s *Server) DeleteLogs(ctx context.Context, request *protologcollector.Stop return s.successfulBaseResponse(), nil } +// ListRunsInProgress returns a list of runs that are currently being collected +func (s *Server) ListRunsInProgress(request *protologcollector.ListRunsRequest, responseStream protologcollector.LogCollector_ListRunsInProgressServer) error { + ctx := responseStream.Context() + + s.Logger.DebugWithCtx(ctx, + "Received list runs in progress request", + "project", request.Project) + + // get all runs in progress from the state manifest + logItemsInProgress, err := s.stateManifest.GetItemsInProgress() + if err != nil { + message := "Failed to list runs in progress from state manifest" + s.Logger.ErrorWithCtx(ctx, message) + return errors.Wrap(err, message) + } + + runsInProgress, err := s.getRunUIDsInProgress(ctx, logItemsInProgress, request.Project) + if err != nil { + message := "Failed to list runs in progress" + s.Logger.ErrorWithCtx(ctx, message) + return errors.Wrap(err, message) + + } + + // get all runs in progress from the current state, and add merge them with the runs from the state manifest + // this can only happen if some voodoo occurred after the server restarted + logItemsInProgressCurrentState, err := s.currentState.GetItemsInProgress() + if err != nil { + message := "Failed to get ms in progress from current state" + s.Logger.ErrorWithCtx(ctx, message) + return errors.Wrap(err, message) + } + + runsInProgressCurrentState, err := s.getRunUIDsInProgress(ctx, logItemsInProgressCurrentState, request.Project) + if err != nil { + message := "Failed to list runs in progress from current state" + s.Logger.ErrorWithCtx(ctx, message) + return errors.Wrap(err, message) + + } + + // merge the two maps + for _, runUID := range runsInProgressCurrentState { + if !lo.Contains[string](runsInProgress, runUID) { + runsInProgress = append(runsInProgress, runUID) + } + } + + // send empty response if no runs are in progress + if len(runsInProgress) == 0 { + s.Logger.DebugWithCtx(ctx, "No runs in progress to list") + if err := responseStream.Send(&protologcollector.ListRunsResponse{ + RunUIDs: []string{}, + }); err != nil { + return errors.Wrapf(err, "Failed to send empty response to stream") + } + return nil + } + + // send each run in progress to the stream in chunks of 10 due to gRPC message size limit + for i := 0; i < len(runsInProgress); i += s.listRunsChunkSize { + endIndex := i + s.listRunsChunkSize + if endIndex > len(runsInProgress) { + endIndex = len(runsInProgress) + } + + if err := responseStream.Send(&protologcollector.ListRunsResponse{ + RunUIDs: runsInProgress[i:endIndex], + }); err != nil { + return errors.Wrapf(err, "Failed to send runs in progress to stream") + } + } + + return nil +} + // startLogStreaming streams logs from a pod and writes them into a file func (s *Server) startLogStreaming(ctx context.Context, runUID, @@ -1212,3 +1297,24 @@ func (s *Server) deleteProjectLogs(project string) error { } return nil } + +func (s *Server) getRunUIDsInProgress(ctx context.Context, inProgressMap *sync.Map, project string) ([]string, error) { + var runUIDs []string + + inProgressMap.Range(func(projectKey, runUIDsToLogItemsValue interface{}) bool { + // if a project was provided, only return runUIDs for that project + if project != "" && project != projectKey { + return true + } + + runUIDsToLogItems := runUIDsToLogItemsValue.(*sync.Map) + runUIDsToLogItems.Range(func(key, value interface{}) bool { + runUID := key.(string) + runUIDs = append(runUIDs, runUID) + return true + }) + return true + }) + + return runUIDs, nil +} diff --git a/server/log-collector/pkg/services/logcollector/test/logcollector_test.go b/server/log-collector/pkg/services/logcollector/test/logcollector_test.go index 73fb25ab57f..ea2fb4b29ff 100644 --- a/server/log-collector/pkg/services/logcollector/test/logcollector_test.go +++ b/server/log-collector/pkg/services/logcollector/test/logcollector_test.go @@ -98,7 +98,8 @@ func (suite *LogCollectorTestSuite) SetupSuite() { suite.bufferSizeBytes, /* logCollectionBufferSizeBytes */ suite.bufferSizeBytes, /* getLogsBufferSizeBytes */ common.LogTimeUpdateBytesInterval, - 0) /* advancedLogLevel */ + 0, /* advancedLogLevel */ + 10) /* listRunsChunkSize */ suite.Require().NoError(err, "Failed to create log collector server") // start log collector server in a goroutine, so it won't block the test diff --git a/server/log-collector/pkg/services/logcollector/test/nop/nop.go b/server/log-collector/pkg/services/logcollector/test/nop/nop.go index 2816f61109e..f1ef8143339 100644 --- a/server/log-collector/pkg/services/logcollector/test/nop/nop.go +++ b/server/log-collector/pkg/services/logcollector/test/nop/nop.go @@ -36,3 +36,18 @@ func (m *GetLogsResponseStreamNop) Send(response *log_collector.GetLogsResponse) func (m *GetLogsResponseStreamNop) Context() context.Context { return context.Background() } + +// ListRunsResponseStreamNop is a nop implementation of the protologcollector.LogCollector_ListRunsServer interface +type ListRunsResponseStreamNop struct { + grpc.ServerStream + RunUIDs []string +} + +func (m *ListRunsResponseStreamNop) Send(response *log_collector.ListRunsResponse) error { + m.RunUIDs = append(m.RunUIDs, response.RunUIDs...) + return nil +} + +func (m *ListRunsResponseStreamNop) Context() context.Context { + return context.Background() +} diff --git a/server/log-collector/proto/log_collector.proto b/server/log-collector/proto/log_collector.proto index 341bb423070..35726b49bfe 100644 --- a/server/log-collector/proto/log_collector.proto +++ b/server/log-collector/proto/log_collector.proto @@ -24,6 +24,7 @@ service LogCollector { rpc GetLogSize(GetLogSizeRequest) returns (GetLogSizeResponse) {} rpc StopLogs(StopLogsRequest) returns (BaseResponse) {} rpc DeleteLogs(StopLogsRequest) returns (BaseResponse) {} + rpc ListRunsInProgress(ListRunsRequest) returns (stream ListRunsResponse) {} } message BaseResponse { @@ -71,6 +72,14 @@ message StopLogsRequest { repeated string runUIDs = 2; } +message ListRunsRequest { + string project = 1; +} + +message ListRunsResponse { + repeated string runUIDs = 1; +} + // StringArray is a wrapper around a repeated string field, used in map values. message StringArray { repeated string values = 1; diff --git a/tests/api/test_collect_runs_logs.py b/tests/api/test_collect_runs_logs.py index d548e4b0638..8df59af501a 100644 --- a/tests/api/test_collect_runs_logs.py +++ b/tests/api/test_collect_runs_logs.py @@ -25,7 +25,10 @@ import server.api.main import server.api.utils.clients.log_collector import server.api.utils.singletons.db -from tests.api.utils.clients.test_log_collector import BaseLogCollectorResponse +from tests.api.utils.clients.test_log_collector import ( + BaseLogCollectorResponse, + ListRunsResponse, +) class TestCollectRunSLogs: @@ -482,6 +485,9 @@ async def test_verify_stop_logs_on_startup( run_uids = [run_uid for run_uid, _ in run_uids_to_state] + # the first run is not currently being log collected + run_uids_log_collected = run_uids[1:] + # update requested logs field to True server.api.utils.singletons.db.get_db().update_runs_requested_logs( db, run_uids, True @@ -492,24 +498,29 @@ async def test_verify_stop_logs_on_startup( requested_logs_modes=[True], only_uids=False, ) - assert len(runs) == 5 + assert len(runs) == len(run_uids) log_collector._call = unittest.mock.AsyncMock(return_value=None) + log_collector._call_stream = unittest.mock.MagicMock( + return_value=ListRunsResponse(run_uids=run_uids_log_collected) + ) await server.api.main._verify_log_collection_stopped_on_startup() + assert log_collector._call_stream.call_count == 1 + assert log_collector._call_stream.call_args[0][0] == "ListRunsInProgress" assert log_collector._call.call_count == 1 assert log_collector._call.call_args[0][0] == "StopLogs" stop_log_request = log_collector._call.call_args[0][1] assert stop_log_request.project == project_name # one of the runs is in running state - run_uids = run_uids[: len(run_uids) - 1] - assert len(stop_log_request.runUIDs) == len(run_uids) + expected_run_uids = run_uids_log_collected[:-1] + assert len(stop_log_request.runUIDs) == len(expected_run_uids) assert ( deepdiff.DeepDiff( list(stop_log_request.runUIDs), - run_uids, + expected_run_uids, ignore_order=True, ) == {} @@ -517,7 +528,7 @@ async def test_verify_stop_logs_on_startup( # update requested logs field to False for one run server.api.utils.singletons.db.get_db().update_runs_requested_logs( - db, [run_uids[0]], False + db, [run_uids[1]], False ) runs = server.api.utils.singletons.db.get_db().list_distinct_runs_uids( @@ -527,17 +538,26 @@ async def test_verify_stop_logs_on_startup( ) assert len(runs) == 4 + # mock it again so the stream will run again + log_collector._call_stream = unittest.mock.MagicMock( + return_value=ListRunsResponse(run_uids=run_uids_log_collected) + ) + await server.api.main._verify_log_collection_stopped_on_startup() assert log_collector._call.call_count == 2 assert log_collector._call.call_args[0][0] == "StopLogs" stop_log_request = log_collector._call.call_args[0][1] assert stop_log_request.project == project_name - assert len(stop_log_request.runUIDs) == 3 + assert len(stop_log_request.runUIDs) == 2 + + # the first run is not currently being log collected, second run has requested logs set to False + # and the last run is in running state + expected_run_uids = run_uids_log_collected[1:-1] assert ( deepdiff.DeepDiff( list(stop_log_request.runUIDs), - run_uids[1:], + expected_run_uids, ignore_order=True, ) == {} diff --git a/tests/api/utils/clients/test_log_collector.py b/tests/api/utils/clients/test_log_collector.py index 2499fadb284..04279f69a77 100644 --- a/tests/api/utils/clients/test_log_collector.py +++ b/tests/api/utils/clients/test_log_collector.py @@ -16,9 +16,7 @@ import unittest.mock import deepdiff -import fastapi.testclient import pytest -import sqlalchemy.orm.session import mlrun import mlrun.common.schemas @@ -66,6 +64,23 @@ def __init__(self, success, error, log_size=None): self.logSize = log_size +class ListRunsResponse: + def __init__(self, run_uids=None, total_calls=1): + self.runUIDs = run_uids or [] + self.total_calls = total_calls + self.current_calls = 0 + + # the following methods are required for the async iterator protocol + def __aiter__(self): + return self + + async def __anext__(self): + if self.current_calls < self.total_calls: + self.current_calls += 1 + return self + raise StopAsyncIteration + + mlrun.mlconf.log_collector.address = "http://localhost:8080" mlrun.mlconf.log_collector.mode = mlrun.common.schemas.LogsCollectorMode.sidecar @@ -74,8 +89,6 @@ class TestLogCollector: @pytest.mark.asyncio async def test_start_log( self, - db: sqlalchemy.orm.session.Session, - client: fastapi.testclient.TestClient, monkeypatch, ): run_uid = "123" @@ -108,9 +121,7 @@ async def test_start_log( assert success is False and error == "Failed to start logs" @pytest.mark.asyncio - async def test_get_logs( - self, db: sqlalchemy.orm.session.Session, client: fastapi.testclient.TestClient - ): + async def test_get_logs(self): run_uid = "123" project_name = "some-project" log_collector = server.api.utils.clients.log_collector.LogCollectorClient() @@ -151,9 +162,7 @@ async def test_get_logs( assert log == b"" @pytest.mark.asyncio - async def test_get_log_with_retryable_error( - self, db: sqlalchemy.orm.session.Session, client: fastapi.testclient.TestClient - ): + async def test_get_log_with_retryable_error(self): run_uid = "123" project_name = "some-project" log_collector = server.api.utils.clients.log_collector.LogCollectorClient() @@ -186,9 +195,7 @@ async def test_get_log_with_retryable_error( assert log == b"" # should not get here @pytest.mark.asyncio - async def test_stop_logs( - self, db: sqlalchemy.orm.session.Session, client: fastapi.testclient.TestClient - ): + async def test_stop_logs(self): run_uids = ["123"] project_name = "some-project" log_collector = server.api.utils.clients.log_collector.LogCollectorClient() @@ -213,9 +220,7 @@ async def test_stop_logs( await log_collector.stop_logs(run_uids=run_uids, project=project_name) @pytest.mark.asyncio - async def test_delete_logs( - self, db: sqlalchemy.orm.session.Session, client: fastapi.testclient.TestClient - ): + async def test_delete_logs(self): run_uids = None project_name = "some-project" log_collector = server.api.utils.clients.log_collector.LogCollectorClient() @@ -247,6 +252,24 @@ async def test_delete_logs( assert stop_log_request.project == project_name assert stop_log_request.runUIDs == run_uids + @pytest.mark.asyncio + async def test_list_runs_in_progress(self): + project_name = "some-project" + log_collector = server.api.utils.clients.log_collector.LogCollectorClient() + + async def _verify_runs(run_uids_stream): + async for run_uid_list in run_uids_stream: + for run_uid in run_uid_list: + assert run_uid in run_uids + + # mock a short response for ListRunsInProgress + run_uids = [f"{str(i)}" for i in range(10)] + log_collector._call_stream = unittest.mock.MagicMock( + return_value=ListRunsResponse(run_uids=run_uids) + ) + run_uids_stream = log_collector.list_runs_in_progress(project=project_name) + await _verify_runs(run_uids_stream) + @pytest.mark.parametrize( "error_code,expected_mlrun_error", [