Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions runner/internal/runner/api/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package api

import (
"context"
"errors"
"net/http"
"time"

"github.com/dstackai/dstack/runner/internal/log"
"github.com/gorilla/websocket"
)

// logsWsRequestParams holds the options parsed from the query string of a
// logs WebSocket request.
type logsWsRequestParams struct {
// startTimestamp is the requested start of the log stream as Unix time in
// seconds (parseRequestParams fills it from Time.Unix). It stays zero when
// the start_time query parameter is absent or empty; streamJobLogs converts
// it to milliseconds and skips the history search when it is zero.
startTimestamp int64
}

var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
Expand All @@ -20,18 +25,54 @@ func (s *Server) logsWsGetHandler(w http.ResponseWriter, r *http.Request) (inter
if err != nil {
return nil, err
}
requestParams, err := parseRequestParams(r)
if err != nil {
_ = conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseUnsupportedData, err.Error()),
)
_ = conn.Close()
return nil, nil
}
// todo memorize clientId?
go s.streamJobLogs(conn)
go s.streamJobLogs(r.Context(), conn, requestParams)
return nil, nil
}

func (s *Server) streamJobLogs(conn *websocket.Conn) {
currentPos := 0
// parseRequestParams reads the log streaming options from the query string
// of r.
//
// The optional start_time parameter must be an RFC 3339 timestamp; it is
// stored as Unix seconds, so any sub-second part is dropped. When the
// parameter is missing or empty the zero-valued params are returned. A value
// that does not parse as RFC 3339 yields a fixed, short error message (the
// caller forwards it to the client in a WebSocket close frame).
func parseRequestParams(r *http.Request) (logsWsRequestParams, error) {
	raw := r.URL.Query().Get("start_time")
	if raw == "" {
		return logsWsRequestParams{}, nil
	}
	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return logsWsRequestParams{}, errors.New("Failed to parse start_time value")
	}
	return logsWsRequestParams{startTimestamp: parsed.Unix()}, nil
}

func (s *Server) streamJobLogs(ctx context.Context, conn *websocket.Conn, params logsWsRequestParams) {
defer func() {
_ = conn.WriteMessage(websocket.CloseMessage, nil)
_ = conn.Close()
}()

currentPos := 0
startTimestampMs := params.startTimestamp * 1000
if startTimestampMs != 0 {
// TODO: Replace currentPos linear search with binary search
s.executor.RLock()
jobLogsWsHistory := s.executor.GetJobWsLogsHistory()
for _, logEntry := range jobLogsWsHistory {
if logEntry.Timestamp < startTimestampMs {
currentPos += 1
} else {
break
}
}
s.executor.RUnlock()
}
for {
s.executor.RLock()
jobLogsWsHistory := s.executor.GetJobWsLogsHistory()
Expand All @@ -52,7 +93,7 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) {
for currentPos < len(jobLogsWsHistory) {
if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsWsHistory[currentPos].Message); err != nil {
s.executor.RUnlock()
log.Error(context.TODO(), "Failed to write message", "err", err)
log.Error(ctx, "Failed to write message", "err", err)
return
}
currentPos++
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/schemas/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type JobStateEvent struct {

type LogEvent struct {
Message []byte `json:"message"`
Timestamp int64 `json:"timestamp"`
Timestamp int64 `json:"timestamp"` // milliseconds
}

type SubmitBody struct {
Expand Down
12 changes: 11 additions & 1 deletion src/dstack/_internal/cli/commands/attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_run_exit_code,
print_finished_message,
)
from dstack._internal.cli.utils.common import console
from dstack._internal.cli.utils.common import console, get_start_time
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
from dstack._internal.core.errors import CLIError
from dstack._internal.utils.common import get_or_error
Expand Down Expand Up @@ -61,6 +61,14 @@ def _register(self):
type=int,
default=0,
)
self._parser.add_argument(
"--since",
help=(
"Show only logs newer than the specified date."
" Can be a duration (e.g. 10s, 5m, 1d) or an RFC 3339 string (e.g. 2023-09-24T15:30:00Z)."
),
type=str,
)
self._parser.add_argument("run_name").completer = RunNameCompleter() # type: ignore[attr-defined]

def _command(self, args: argparse.Namespace):
Expand All @@ -86,7 +94,9 @@ def _command(self, args: argparse.Namespace):
job_num=args.job,
)
if args.logs:
start_time = get_start_time(args.since)
logs = run.logs(
start_time=start_time,
replica_num=args.replica,
job_num=args.job,
)
Expand Down
15 changes: 2 additions & 13 deletions src/dstack/_internal/cli/commands/logs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import argparse
import sys
from datetime import datetime
from typing import Optional

from dstack._internal.cli.commands import APIBaseCommand
from dstack._internal.cli.services.completion import RunNameCompleter
from dstack._internal.cli.utils.common import get_start_time
from dstack._internal.core.errors import CLIError
from dstack._internal.utils.common import parse_since
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -49,7 +47,7 @@ def _command(self, args: argparse.Namespace):
if run is None:
raise CLIError(f"Run {args.run_name} not found")

start_time = _get_start_time(args.since)
start_time = get_start_time(args.since)
logs = run.logs(
start_time=start_time,
diagnose=args.diagnose,
Expand All @@ -62,12 +60,3 @@ def _command(self, args: argparse.Namespace):
sys.stdout.buffer.flush()
except KeyboardInterrupt:
pass


def _get_start_time(since: Optional[str]) -> Optional[datetime]:
    """Turn the raw ``--since`` argument into a start time.

    Returns ``None`` when ``since`` is ``None``; otherwise returns whatever
    ``parse_since`` produces for it.

    Raises:
        CLIError: if ``parse_since`` raises ``ValueError``; the new error
            carries the first argument of the original exception.
    """
    if since is not None:
        try:
            return parse_since(since)
        except ValueError as err:
            raise CLIError(err.args[0])
    return None
13 changes: 11 additions & 2 deletions src/dstack/_internal/cli/utils/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

from rich.console import Console
from rich.prompt import Confirm
Expand All @@ -11,7 +11,7 @@
from dstack._internal import settings
from dstack._internal.cli.utils.rich import DstackRichHandler
from dstack._internal.core.errors import CLIError, DstackError
from dstack._internal.utils.common import get_dstack_dir
from dstack._internal.utils.common import get_dstack_dir, parse_since

_colors = {
"secondary": "grey58",
Expand Down Expand Up @@ -110,3 +110,12 @@ def warn(message: str):
# Additional blank line for better visibility if there are more than one warning
message = f"{message}\n"
console.print(f"[warning][bold]{message}[/]")


def get_start_time(since: Optional[str]) -> Optional[datetime]:
    """Convert the value of a ``--since`` CLI option into a datetime.

    Args:
        since: the raw option value, or ``None`` when the option was omitted.

    Returns:
        ``None`` if ``since`` is ``None``, else the result of
        ``parse_since(since)``.

    Raises:
        CLIError: when ``parse_since`` rejects the value with ``ValueError``;
            the first argument of that exception is passed to ``CLIError``.
    """
    if since is None:
        return None

    try:
        start_time = parse_since(since)
    except ValueError as exc:
        # Surface parsing problems as a CLI-level error.
        raise CLIError(exc.args[0])
    return start_time
26 changes: 4 additions & 22 deletions src/dstack/_internal/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from typing_extensions import ParamSpec

from dstack._internal.core.models.common import Duration

P = ParamSpec("P")
R = TypeVar("R")

Expand Down Expand Up @@ -150,20 +152,16 @@ def parse_since(value: str) -> datetime:
or a duration (e.g. 10s, 5m, 1d) between the timestamp and now.
"""
try:
seconds = parse_pretty_duration(value)
seconds = Duration.parse(value)
return get_current_datetime() - timedelta(seconds=seconds)
except ValueError:
pass
try:
res = datetime.fromisoformat(value)
except ValueError:
pass
raise ValueError("Invalid datetime format")
else:
return check_time_offset_aware(res)
try:
return datetime.fromtimestamp(int(value), tz=timezone.utc)
except Exception:
raise ValueError("Invalid datetime format")


def check_time_offset_aware(time: datetime) -> datetime:
Expand All @@ -172,22 +170,6 @@ def check_time_offset_aware(time: datetime) -> datetime:
return time


def parse_pretty_duration(duration: str) -> int:
    """Convert a short duration string such as ``10s`` or ``2w`` to seconds.

    The accepted form is one or more digits followed by exactly one unit
    letter: ``s`` (seconds), ``m`` (minutes), ``h`` (hours), ``d`` (days) or
    ``w`` (weeks).

    Args:
        duration: the string to parse.

    Returns:
        The duration expressed in whole seconds.

    Raises:
        ValueError: if ``duration`` does not have the form described above.
    """
    # fullmatch (rather than match with a trailing "$") is required here: "$"
    # also matches just before a final newline, so "10s\n" used to be accepted.
    re_match = re.fullmatch(r"(?P<amount>\d+)(?P<unit>s|m|h|d|w)", duration)
    if re_match is None:
        raise ValueError(f"Cannot parse the duration {duration}")
    amount, unit = int(re_match.group("amount")), re_match.group("unit")
    multiplier = {
        "s": 1,
        "m": 60,
        "h": 3600,
        "d": 24 * 3600,
        "w": 7 * 24 * 3600,
    }[unit]
    return amount * multiplier


DURATION_UNITS_DESC = [
("w", 7 * 24 * 3600),
("d", 24 * 3600),
Expand Down
16 changes: 10 additions & 6 deletions src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import datetime
from pathlib import Path
from typing import BinaryIO, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
from urllib.parse import urlencode, urlparse

from websocket import WebSocketApp

Expand Down Expand Up @@ -136,9 +136,7 @@ def service_model(self) -> Optional["ServiceModel"]:
),
)

def _attached_logs(
self,
) -> Iterable[bytes]:
def _attached_logs(self, start_time: Optional[datetime] = None) -> Iterable[bytes]:
q = queue.Queue()
_done = object()

Expand All @@ -150,8 +148,14 @@ def ws_thread():
logger.debug("WebSocket logs are done for %s", self.name)
q.put(_done)

url = f"ws://localhost:{self.ports[DSTACK_RUNNER_HTTP_PORT]}/logs_ws"
query_params = {}
if start_time is not None:
query_params["start_time"] = start_time.isoformat()
if query_params:
url = f"{url}?{urlencode(query_params)}"
ws = WebSocketApp(
f"ws://localhost:{self.ports[DSTACK_RUNNER_HTTP_PORT]}/logs_ws",
url=url,
on_open=lambda _: logger.debug("WebSocket logs are connected to %s", self.name),
on_close=lambda _, status_code, msg: logger.debug(
"WebSocket logs are disconnected. status_code: %s; message: %s",
Expand Down Expand Up @@ -215,7 +219,7 @@ def logs(
Log messages.
"""
if diagnose is False and self._ssh_attach is not None:
yield from self._attached_logs()
yield from self._attached_logs(start_time=start_time)
else:
job = self._find_job(replica_num=replica_num, job_num=job_num)
if job is None:
Expand Down