Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions runner/internal/runner/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ func (s *Server) metricsGetHandler(w http.ResponseWriter, r *http.Request) (inte
return metrics, nil
}

// submitPostHandler must be called first
// It's safe to call it more than once
func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
s.executor.Lock()
defer s.executor.Unlock()
state := s.executor.GetRunnerState()
if state != executor.WaitSubmit {
if state == executor.WaitRun {
log.Warning(r.Context(), "Job already submitted, submitting again", "current_state", state)
} else if state != executor.WaitSubmit {
log.Warning(r.Context(), "Executor doesn't wait submit", "current_state", state)
return nil, &api.Error{Status: http.StatusConflict}
}
Expand All @@ -52,20 +56,19 @@ func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (inte
log.Error(r.Context(), "Failed to decode submit body", "err", err)
return nil, err
}
// todo go-playground/validator

s.executor.SetJob(body)
s.jobBarrierCh <- nil // notify server that job submitted
s.executor.SetRunnerState(executor.WaitRun)

return nil, nil
}

// uploadArchivePostHandler may be called 0 or more times, and must be called after submitPostHandler
// and before uploadCodePostHandler
// If uploadArchivePostHandler is called, it must be called after submitPostHandler and before runPostHandler
// It's safe to call it more than once with the same archive
func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
s.executor.Lock()
defer s.executor.Unlock()
if s.executor.GetRunnerState() != executor.WaitCode {
if s.executor.GetRunnerState() != executor.WaitRun {
return nil, &api.Error{Status: http.StatusConflict}
}

Expand Down Expand Up @@ -123,10 +126,12 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
return nil, nil
}

// If uploadCodePostHandler is called, it must be called after submitPostHandler and before runPostHandler
// It's safe to call it more than once
func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
s.executor.Lock()
defer s.executor.Unlock()
if s.executor.GetRunnerState() != executor.WaitCode {
if s.executor.GetRunnerState() != executor.WaitRun {
return nil, &api.Error{Status: http.StatusConflict}
}

Expand All @@ -139,8 +144,6 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (
return nil, fmt.Errorf("copy request body: %w", err)
}

s.executor.SetRunnerState(executor.WaitRun)

return nil, nil
}

Expand All @@ -151,6 +154,7 @@ func (s *Server) runPostHandler(w http.ResponseWriter, r *http.Request) (interfa
return nil, &api.Error{Status: http.StatusConflict}
}
s.executor.SetRunnerState(executor.ServeLogs)
s.jobBarrierCh <- nil // notify server that job started
s.executor.Unlock()

var runCtx context.Context
Expand Down
10 changes: 5 additions & 5 deletions runner/internal/runner/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ type Server struct {
pullDoneCh chan interface{} // Closed then /api/pull gave everything
wsDoneCh chan interface{} // Closed then /logs_ws gave everything

submitWaitDuration time.Duration
logsWaitDuration time.Duration
startWaitDuration time.Duration
logsWaitDuration time.Duration

executor executor.Executor
cancelRun context.CancelFunc
Expand Down Expand Up @@ -51,8 +51,8 @@ func NewServer(ctx context.Context, address string, version string, ex executor.
pullDoneCh: make(chan interface{}),
wsDoneCh: make(chan interface{}),

submitWaitDuration: 5 * time.Minute,
logsWaitDuration: 5 * time.Minute,
startWaitDuration: 5 * time.Minute,
logsWaitDuration: 5 * time.Minute,

executor: ex,

Expand Down Expand Up @@ -82,7 +82,7 @@ func (s *Server) Run(ctx context.Context) error {

select {
case <-s.jobBarrierCh: // job started
case <-time.After(s.submitWaitDuration):
case <-time.After(s.startWaitDuration):
log.Error(ctx, "Job didn't start in time, shutting down")
return errors.New("no job submitted")
case <-ctx.Done():
Expand Down
17 changes: 12 additions & 5 deletions runner/internal/runner/executor/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,29 @@ import (
)

type Executor interface {
// It must be safe to call SetJob more than once
SetJob(job schemas.SubmitBody)
// It must be safe to call WriteFileArchive more than once with the same archive
WriteFileArchive(id string, src io.Reader) error
// It must be safe to call WriteRepoBlob more than once
WriteRepoBlob(src io.Reader) error
Run(ctx context.Context) error

GetHistory(timestamp int64) *schemas.PullResponse
GetJobWsLogsHistory() []schemas.LogEvent

GetRunnerState() string
SetRunnerState(state string)

GetJobInfo(ctx context.Context) (username string, workingDir string, err error)
Run(ctx context.Context) error
SetJob(job schemas.SubmitBody)
SetJobState(ctx context.Context, state schemas.JobState)
SetJobStateWithTerminationReason(
ctx context.Context,
state schemas.JobState,
terminationReason types.TerminationReason,
terminationMessage string,
)
SetRunnerState(state string)
WriteFileArchive(id string, src io.Reader) error
WriteRepoBlob(src io.Reader) error

Lock()
RLock()
RUnlock()
Expand Down
1 change: 0 additions & 1 deletion runner/internal/runner/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) {
ex.secrets = body.Secrets
ex.repoCredentials = body.RepoCredentials
ex.jobLogs.SetQuota(body.LogQuotaHour)
ex.state = WaitCode
}

func (ex *RunExecutor) SetJobState(ctx context.Context, state schemas.JobState) {
Expand Down
23 changes: 14 additions & 9 deletions runner/internal/runner/executor/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
return fmt.Errorf("prepare git repo: %w", err)
}
case "local", "virtual":
log.Trace(ctx, "Extracting tar archive")
if err := ex.prepareArchive(ctx); err != nil {
return fmt.Errorf("prepare archive: %w", err)
if err := ex.extractCodeArchive(ctx); err != nil {
return fmt.Errorf("extract code archive: %w", err)
}
default:
return fmt.Errorf("unknown RepoType: %s", ex.getRepoData().RepoType)
Expand Down Expand Up @@ -164,26 +163,32 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
return fmt.Errorf("set repo config: %w", err)
}

if ex.repoBlobPath == "" {
log.Trace(ctx, "No diff to apply")
return nil
}
log.Trace(ctx, "Applying diff")
repoDiff, err := os.ReadFile(ex.repoBlobPath)
if err != nil {
return fmt.Errorf("read repo diff: %w", err)
}
if len(repoDiff) > 0 {
if err := repo.ApplyDiff(ctx, ex.repoDir, string(repoDiff)); err != nil {
return fmt.Errorf("apply diff: %w", err)
}
if err := repo.ApplyDiff(ctx, ex.repoDir, string(repoDiff)); err != nil {
return fmt.Errorf("apply diff: %w", err)
}
return nil
}

func (ex *RunExecutor) prepareArchive(ctx context.Context) error {
func (ex *RunExecutor) extractCodeArchive(ctx context.Context) error {
if ex.repoBlobPath == "" {
log.Trace(ctx, "No code archive to extract")
return nil
}
log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir)
file, err := os.Open(ex.repoBlobPath)
if err != nil {
return fmt.Errorf("open code archive: %w", err)
}
defer func() { _ = file.Close() }()
log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir)
if err := extract.Tar(ctx, file, ex.repoDir, nil); err != nil {
return fmt.Errorf("extract tar archive: %w", err)
}
Expand Down
1 change: 0 additions & 1 deletion runner/internal/runner/executor/states.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package executor

const (
WaitSubmit = "wait_submit"
WaitCode = "wait_code"
WaitRun = "wait_run"
ServeLogs = "serve_logs"
WaitLogsFinished = "wait_logs_finished"
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/core/models/repos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class Repo(ABC):
repo_dir: Optional[str]
run_repo_data: "repos.AnyRunRepoData"

@abstractmethod
def has_code_to_write(self) -> bool:
    """Report whether this repo has any code to write.

    Abstract: concrete repo classes must provide the implementation.
    """

@abstractmethod
def write_code_file(self, fp: BinaryIO) -> str:
pass
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/core/models/repos/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def __init__(
self.repo_id = repo_id
self.run_repo_data = repo_data

def has_code_to_write(self) -> bool:
    """Always report that there is code to write.

    LocalRepo is deprecated, so no real implementation is needed.
    """
    return True

def write_code_file(self, fp: BinaryIO) -> str:
repo_path = Path(self.run_repo_data.repo_dir)
with tarfile.TarFile(mode="w", fileobj=fp) as t:
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/core/models/repos/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ def __init__(
self.repo_id = repo_id
self.run_repo_data = repo_data

def has_code_to_write(self) -> bool:
    """Return True only when ``run_repo_data.repo_diff`` is truthy.

    repo_diff is:
      * None for RemoteRepo.from_url()
      * an empty string for RemoteRepo.from_dir() if there are no changes ("clean" state)
      * a non-empty string for RemoteRepo.from_dir() if there are changes ("dirty" state)
    """
    diff = self.run_repo_data.repo_diff
    if diff:
        return True
    return False

def write_code_file(self, fp: BinaryIO) -> str:
if self.run_repo_data.repo_diff is not None:
fp.write(self.run_repo_data.repo_diff.encode())
Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/core/models/repos/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def add_file(self, path: str, content: bytes):

self.files[resolve_relative_path(path).as_posix()] = content

def has_code_to_write(self) -> bool:
    """Return True when ``self.files`` holds at least one entry."""
    file_count = len(self.files)
    return file_count > 0

def write_code_file(self, fp: BinaryIO) -> str:
with tarfile.TarFile(mode="w", fileobj=fp) as t:
for path, content in sorted(self.files.items()):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ def _submit_job_to_runner(
job: Job,
jrd: Optional[JobRuntimeData],
cluster_info: ClusterInfo,
code: bytes,
code: Optional[bytes],
file_archives: Iterable[tuple[uuid.UUID, bytes]],
secrets: Dict[str, str],
repo_credentials: Optional[RemoteRepoCreds],
Expand Down Expand Up @@ -1352,11 +1352,15 @@ def _submit_job_to_runner(
repo_credentials=repo_credentials,
instance_env=instance_env,
)
logger.debug("%s: uploading file archive(s)", fmt(job_model))
for archive_id, archive in file_archives:
logger.debug("%s: uploading file archive: %s", fmt(job_model), archive_id)
runner_client.upload_archive(archive_id, archive)
logger.debug("%s: uploading code", fmt(job_model))
runner_client.upload_code(code)
if code is None and not runner_client.is_code_upload_optional():
# Old runner, we must call `/api/upload_code` to proceed
code = b""
if code is not None:
logger.debug("%s: uploading code", fmt(job_model))
runner_client.upload_code(code)
logger.debug("%s: starting job", fmt(job_model))
job_info = runner_client.run_job()
if job_info is not None:
Expand Down Expand Up @@ -1520,18 +1524,20 @@ def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]:
return job.job_spec.repo_code_hash


async def _get_job_code(project: ProjectModel, repo: RepoModel, code_hash: Optional[str]) -> bytes:
async def _get_job_code(
project: ProjectModel, repo: RepoModel, code_hash: Optional[str]
) -> Optional[bytes]:
if code_hash is None:
return b""
return None
async with get_session_ctx() as session:
code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash)
if code_model is None:
return b""
return None
if code_model.blob is not None:
return code_model.blob
storage = get_default_storage()
if storage is None:
return b""
return None
blob = await run_async(
storage.get_code,
project.name,
Expand All @@ -1542,7 +1548,7 @@ async def _get_job_code(project: ProjectModel, repo: RepoModel, code_hash: Optio
logger.error(
"Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name
)
return b""
return None
return blob


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1173,17 +1173,17 @@ def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]:

async def _get_job_code(
session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: Optional[str]
) -> bytes:
) -> Optional[bytes]:
if code_hash is None:
return b""
return None
code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash)
if code_model is None:
return b""
return None
if code_model.blob is not None:
return code_model.blob
storage = get_default_storage()
if storage is None:
return b""
return None
blob = await common_utils.run_async(
storage.get_code,
project.name,
Expand All @@ -1194,7 +1194,7 @@ async def _get_job_code(
logger.error(
"Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name
)
return b""
return None
return blob


Expand Down Expand Up @@ -1243,7 +1243,7 @@ def _submit_job_to_runner(
job_model: JobModel,
job: Job,
cluster_info: ClusterInfo,
code: bytes,
code: Optional[bytes],
file_archives: Iterable[tuple[uuid.UUID, bytes]],
secrets: Dict[str, str],
repo_credentials: Optional[RemoteRepoCreds],
Expand Down Expand Up @@ -1285,11 +1285,15 @@ def _submit_job_to_runner(
repo_credentials=repo_credentials,
instance_env=instance_env,
)
logger.debug("%s: uploading file archive(s)", fmt(job_model))
for archive_id, archive in file_archives:
logger.debug("%s: uploading file archive: %s", fmt(job_model), archive_id)
runner_client.upload_archive(archive_id, archive)
logger.debug("%s: uploading code", fmt(job_model))
runner_client.upload_code(code)
if code is None and not runner_client.is_code_upload_optional():
# Old runner, we must call `/api/upload_code` to proceed
code = b""
if code is not None:
logger.debug("%s: uploading code", fmt(job_model))
runner_client.upload_code(code)
logger.debug("%s: starting job", fmt(job_model))
job_info = runner_client.run_job()
if job_info is not None:
Expand Down
Loading
Loading