From eaf6ffa89c51f8f970e41929ab9b9a03ed3c0c2d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 7 Nov 2025 08:29:22 +0500 Subject: [PATCH] Support --since arg for dstack attach --logs command --- runner/internal/runner/api/ws.go | 51 +++++++++++++++++++-- runner/internal/schemas/schemas.go | 2 +- src/dstack/_internal/cli/commands/attach.py | 12 ++++- src/dstack/_internal/cli/commands/logs.py | 15 +----- src/dstack/_internal/cli/utils/common.py | 13 +++++- src/dstack/_internal/utils/common.py | 26 ++--------- src/dstack/api/_public/runs.py | 16 ++++--- 7 files changed, 85 insertions(+), 50 deletions(-) diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go index ebb0caea2..2dbd81f07 100644 --- a/runner/internal/runner/api/ws.go +++ b/runner/internal/runner/api/ws.go @@ -2,6 +2,7 @@ package api import ( "context" + "errors" "net/http" "time" @@ -9,6 +10,10 @@ import ( "github.com/gorilla/websocket" ) +type logsWsRequestParams struct { + startTimestamp int64 +} + var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -20,18 +25,54 @@ func (s *Server) logsWsGetHandler(w http.ResponseWriter, r *http.Request) (inter if err != nil { return nil, err } + requestParams, err := parseRequestParams(r) + if err != nil { + _ = conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseUnsupportedData, err.Error()), + ) + _ = conn.Close() + return nil, nil + } // todo memorize clientId? 
- go s.streamJobLogs(conn) + go s.streamJobLogs(r.Context(), conn, requestParams) return nil, nil } -func (s *Server) streamJobLogs(conn *websocket.Conn) { - currentPos := 0 +func parseRequestParams(r *http.Request) (logsWsRequestParams, error) { + query := r.URL.Query() + startTimeStr := query.Get("start_time") + var startTimestamp int64 + if startTimeStr != "" { + t, err := time.Parse(time.RFC3339, startTimeStr) + if err != nil { + return logsWsRequestParams{}, errors.New("Failed to parse start_time value") + } + startTimestamp = t.Unix() + } + return logsWsRequestParams{startTimestamp: startTimestamp}, nil +} + +func (s *Server) streamJobLogs(ctx context.Context, conn *websocket.Conn, params logsWsRequestParams) { defer func() { _ = conn.WriteMessage(websocket.CloseMessage, nil) _ = conn.Close() }() - + currentPos := 0 + startTimestampMs := params.startTimestamp * 1000 + if startTimestampMs != 0 { + // TODO: Replace currentPos linear search with binary search + s.executor.RLock() + jobLogsWsHistory := s.executor.GetJobWsLogsHistory() + for _, logEntry := range jobLogsWsHistory { + if logEntry.Timestamp < startTimestampMs { + currentPos += 1 + } else { + break + } + } + s.executor.RUnlock() + } for { s.executor.RLock() jobLogsWsHistory := s.executor.GetJobWsLogsHistory() @@ -52,7 +93,7 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) { for currentPos < len(jobLogsWsHistory) { if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsWsHistory[currentPos].Message); err != nil { s.executor.RUnlock() - log.Error(context.TODO(), "Failed to write message", "err", err) + log.Error(ctx, "Failed to write message", "err", err) return } currentPos++ diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 389ed4c9e..4a92702ee 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -16,7 +16,7 @@ type JobStateEvent struct { type LogEvent struct { Message []byte `json:"message"` - Timestamp int64 
`json:"timestamp"` + Timestamp int64 `json:"timestamp"` // milliseconds } type SubmitBody struct { diff --git a/src/dstack/_internal/cli/commands/attach.py b/src/dstack/_internal/cli/commands/attach.py index 367e65232..e005723e1 100644 --- a/src/dstack/_internal/cli/commands/attach.py +++ b/src/dstack/_internal/cli/commands/attach.py @@ -11,7 +11,7 @@ get_run_exit_code, print_finished_message, ) -from dstack._internal.cli.utils.common import console +from dstack._internal.cli.utils.common import console, get_start_time from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT from dstack._internal.core.errors import CLIError from dstack._internal.utils.common import get_or_error @@ -61,6 +61,14 @@ def _register(self): type=int, default=0, ) + self._parser.add_argument( + "--since", + help=( + "Show only logs newer than the specified date." + " Can be a duration (e.g. 10s, 5m, 1d) or an RFC 3339 string (e.g. 2023-09-24T15:30:00Z)." + ), + type=str, + ) self._parser.add_argument("run_name").completer = RunNameCompleter() # type: ignore[attr-defined] def _command(self, args: argparse.Namespace): @@ -86,7 +94,9 @@ def _command(self, args: argparse.Namespace): job_num=args.job, ) if args.logs: + start_time = get_start_time(args.since) logs = run.logs( + start_time=start_time, replica_num=args.replica, job_num=args.job, ) diff --git a/src/dstack/_internal/cli/commands/logs.py b/src/dstack/_internal/cli/commands/logs.py index 1dd1316df..78cde52f4 100644 --- a/src/dstack/_internal/cli/commands/logs.py +++ b/src/dstack/_internal/cli/commands/logs.py @@ -1,12 +1,10 @@ import argparse import sys -from datetime import datetime -from typing import Optional from dstack._internal.cli.commands import APIBaseCommand from dstack._internal.cli.services.completion import RunNameCompleter +from dstack._internal.cli.utils.common import get_start_time from dstack._internal.core.errors import CLIError -from dstack._internal.utils.common import parse_since from 
dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -49,7 +47,7 @@ def _command(self, args: argparse.Namespace): if run is None: raise CLIError(f"Run {args.run_name} not found") - start_time = _get_start_time(args.since) + start_time = get_start_time(args.since) logs = run.logs( start_time=start_time, diagnose=args.diagnose, @@ -62,12 +60,3 @@ def _command(self, args: argparse.Namespace): sys.stdout.buffer.flush() except KeyboardInterrupt: pass - - -def _get_start_time(since: Optional[str]) -> Optional[datetime]: - if since is None: - return None - try: - return parse_since(since) - except ValueError as e: - raise CLIError(e.args[0]) diff --git a/src/dstack/_internal/cli/utils/common.py b/src/dstack/_internal/cli/utils/common.py index b319b837d..c75f08b81 100644 --- a/src/dstack/_internal/cli/utils/common.py +++ b/src/dstack/_internal/cli/utils/common.py @@ -1,7 +1,7 @@ import logging from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union from rich.console import Console from rich.prompt import Confirm @@ -11,7 +11,7 @@ from dstack._internal import settings from dstack._internal.cli.utils.rich import DstackRichHandler from dstack._internal.core.errors import CLIError, DstackError -from dstack._internal.utils.common import get_dstack_dir +from dstack._internal.utils.common import get_dstack_dir, parse_since _colors = { "secondary": "grey58", @@ -110,3 +110,12 @@ def warn(message: str): # Additional blank line for better visibility if there are more than one warning message = f"{message}\n" console.print(f"[warning][bold]{message}[/]") + + +def get_start_time(since: Optional[str]) -> Optional[datetime]: + if since is None: + return None + try: + return parse_since(since) + except ValueError as e: + raise CLIError(e.args[0]) diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index fec6a6d48..28becc936 100644 
--- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -12,6 +12,8 @@ from typing_extensions import ParamSpec +from dstack._internal.core.models.common import Duration + P = ParamSpec("P") R = TypeVar("R") @@ -150,20 +152,16 @@ def parse_since(value: str) -> datetime: or a duration (e.g. 10s, 5m, 1d) between the timestamp and now. """ try: - seconds = parse_pretty_duration(value) + seconds = Duration.parse(value) return get_current_datetime() - timedelta(seconds=seconds) except ValueError: pass try: res = datetime.fromisoformat(value) except ValueError: - pass + raise ValueError("Invalid datetime format") else: return check_time_offset_aware(res) - try: - return datetime.fromtimestamp(int(value), tz=timezone.utc) - except Exception: - raise ValueError("Invalid datetime format") def check_time_offset_aware(time: datetime) -> datetime: @@ -172,22 +170,6 @@ def check_time_offset_aware(time: datetime) -> datetime: return time -def parse_pretty_duration(duration: str) -> int: - regex = re.compile(r"(?P<amount>\d+)(?P<unit>s|m|h|d|w)$") - re_match = regex.match(duration) - if not re_match: - raise ValueError(f"Cannot parse the duration {duration}") - amount, unit = int(re_match.group("amount")), re_match.group("unit") - multiplier = { - "s": 1, - "m": 60, - "h": 3600, - "d": 24 * 3600, - "w": 7 * 24 * 3600, - }[unit] - return amount * multiplier - - DURATION_UNITS_DESC = [ ("w", 7 * 24 * 3600), ("d", 24 * 3600), diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index b583e676e..0823aa087 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -10,7 +10,7 @@ from datetime import datetime from pathlib import Path from typing import BinaryIO, Dict, Iterable, List, Optional, Union -from urllib.parse import urlparse +from urllib.parse import urlencode, urlparse from websocket import WebSocketApp @@ -136,9 +136,7 @@ def service_model(self) -> Optional["ServiceModel"]: ), ) - def _attached_logs( - self, 
- ) -> Iterable[bytes]: + def _attached_logs(self, start_time: Optional[datetime] = None) -> Iterable[bytes]: q = queue.Queue() _done = object() @@ -150,8 +148,14 @@ def ws_thread(): logger.debug("WebSocket logs are done for %s", self.name) q.put(_done) + url = f"ws://localhost:{self.ports[DSTACK_RUNNER_HTTP_PORT]}/logs_ws" + query_params = {} + if start_time is not None: + query_params["start_time"] = start_time.isoformat() + if query_params: + url = f"{url}?{urlencode(query_params)}" ws = WebSocketApp( - f"ws://localhost:{self.ports[DSTACK_RUNNER_HTTP_PORT]}/logs_ws", + url=url, on_open=lambda _: logger.debug("WebSocket logs are connected to %s", self.name), on_close=lambda _, status_code, msg: logger.debug( "WebSocket logs are disconnected. status_code: %s; message: %s", @@ -215,7 +219,7 @@ def logs( Log messages. """ if diagnose is False and self._ssh_attach is not None: - yield from self._attached_logs() + yield from self._attached_logs(start_time=start_time) else: job = self._find_job(replica_num=replica_num, job_num=job_num) if job is None: