Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cli/dstack/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class Job(JobHead):
commands: Optional[List[str]]
entrypoint: Optional[List[str]]
env: Optional[Dict[str, str]]
home_dir: Optional[str]
working_dir: Optional[str]
artifact_specs: Optional[List[ArtifactSpec]]
cache_specs: List[CacheSpec]
Expand Down Expand Up @@ -252,6 +253,7 @@ def serialize(self) -> dict:
"commands": self.commands or [],
"entrypoint": self.entrypoint,
"env": self.env or {},
"home_dir": self.home_dir or "",
"working_dir": self.working_dir or "",
"artifacts": artifacts,
"cache": [item.dict() for item in self.cache_specs],
Expand Down Expand Up @@ -394,6 +396,7 @@ def unserialize(job_data: dict):
commands=job_data.get("commands") or None,
entrypoint=job_data.get("entrypoint") or None,
env=job_data["env"] or None,
home_dir=job_data.get("home_dir") or None,
working_dir=job_data.get("working_dir") or None,
artifact_specs=artifact_specs,
cache_specs=[CacheSpec(**item) for item in job_data.get("cache", [])],
Expand Down
1 change: 1 addition & 0 deletions cli/dstack/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def submit_jobs(self, hub_client: "hub.HubClient", tag_name: str) -> List[Job]:
commands=job_spec.commands,
entrypoint=job_spec.entrypoint,
env=job_spec.env,
home_dir=self.home_dir,
working_dir=job_spec.working_dir,
artifact_specs=job_spec.artifact_specs,
cache_specs=self.cache_specs,
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/providers/_torchrun/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from rich_argparse import RichHelpFormatter

from dstack.backend.base import Backend
import dstack.api.hub as hub
from dstack.core.job import GpusRequirements, JobSpec, Requirements
from dstack.providers import Provider

Expand All @@ -24,14 +24,14 @@ def __init__(self):

def load(
self,
backend: Backend,
hub_client: "hub.HubClient",
args: Optional[Namespace],
workflow_name: Optional[str],
provider_data: Dict[str, Any],
run_name: str,
ssh_key_pub: Optional[str] = None,
):
super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub)
super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub)
self.script = self.provider_data.get("script") or self.provider_data.get("file")
self.setup = self._get_list_data("setup") or self._get_list_data("before_run")
self.python = self._safe_python_version("python")
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/providers/bash/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from rich_argparse import RichHelpFormatter

import dstack.api.hub as hub
from dstack import version
from dstack.backend.base import Backend
from dstack.core.app import AppSpec
from dstack.core.job import JobSpec
from dstack.providers import Provider
Expand All @@ -26,14 +26,14 @@ def __init__(self):

def load(
self,
backend: Backend,
hub_client: "hub.HubClient",
args: Optional[Namespace],
workflow_name: Optional[str],
provider_data: Dict[str, Any],
run_name: str,
ssh_key_pub: Optional[str] = None,
):
super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub)
super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub)
self.python = self._safe_python_version("python")
self.commands = self._get_list_data("commands")
self.env = self._env()
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/providers/code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from rich_argparse import RichHelpFormatter

import dstack.api.hub as hub
from dstack import version
from dstack.backend.base import Backend
from dstack.core.app import AppSpec
from dstack.core.job import JobSpec
from dstack.providers import Provider
Expand All @@ -28,14 +28,14 @@ def __init__(self):

def load(
self,
backend: Backend,
hub_client: "hub.HubClient",
args: Optional[Namespace],
workflow_name: Optional[str],
provider_data: Dict[str, Any],
run_name: str,
ssh_key_pub: Optional[str] = None,
):
super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub)
super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub)
self.setup = self._get_list_data("setup") or self._get_list_data("before_run")
self.ports = self.provider_data.get("ports")
self.python = self._safe_python_version("python")
Expand Down
7 changes: 4 additions & 3 deletions cli/dstack/providers/docker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from rich_argparse import RichHelpFormatter

from dstack.backend.base import Backend
import dstack.api.hub as hub
from dstack.core.app import AppSpec
from dstack.core.job import JobSpec
from dstack.providers import Provider
Expand All @@ -24,14 +24,14 @@ def __init__(self):

def load(
self,
backend: Backend,
hub_client: "hub.HubClient",
args: Optional[Namespace],
workflow_name: Optional[str],
provider_data: Dict[str, Any],
run_name: str,
ssh_key_pub: Optional[str] = None,
):
super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub)
super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub)
self.image_name = self.provider_data["image"]
self.registry_auth = self.provider_data.get("registry_auth")
self.commands = self._get_list_data("commands")
Expand All @@ -40,6 +40,7 @@ def load(
self.entrypoint = ["/bin/sh", "-i", "-c"]
self.artifact_specs = self._artifact_specs()
self.env = self.provider_data.get("env")
self.home_dir = self.provider_data.get("home_dir")
self.working_dir = self.provider_data.get("working_dir")
self.ports = self.provider_data.get("ports")
self.resources = self._resources()
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/providers/lab/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from rich_argparse import RichHelpFormatter

import dstack.api.hub as hub
from dstack import version
from dstack.backend.base import Backend
from dstack.core.app import AppSpec
from dstack.core.job import JobSpec
from dstack.providers import Provider
Expand All @@ -26,14 +26,14 @@ def __init__(self):

def load(
self,
backend: Backend,
hub_client: "hub.HubClient",
args: Optional[Namespace],
workflow_name: Optional[str],
provider_data: Dict[str, Any],
run_name: str,
ssh_key_pub: Optional[str] = None,
):
super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub)
super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub)
self.setup = self._get_list_data("setup") or self._get_list_data("before_run")
self.python = self._safe_python_version("python")
self.version = self.provider_data.get("version")
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/providers/notebook/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from rich_argparse import RichHelpFormatter

import dstack.api.hub as hub
from dstack import version
from dstack.backend.base import Backend
from dstack.core.app import AppSpec
from dstack.core.job import JobSpec
from dstack.providers import Provider
Expand All @@ -26,14 +26,14 @@ def __init__(self):

def load(
self,
backend: Backend,
hub_client: "hub.HubClient",
args: Optional[Namespace],
workflow_name: Optional[str],
provider_data: Dict[str, Any],
run_name: str,
ssh_key_pub: Optional[str] = None,
):
super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub)
super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub)
self.setup = self._get_list_data("setup") or self._get_list_data("before_run")
self.python = self._safe_python_version("python")
self.version = self.provider_data.get("version")
Expand Down
8 changes: 8 additions & 0 deletions cli/dstack/schemas/workflows.json
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@
"type": "integer",
"minimum": 1
},
"home_dir": {
"description": "The absolute path to the home directory inside the container",
"type": "string",
"minLength": 1
},
"working_dir": {
"description": "The absolute or relative path to the working directory where to run the workflow",
"type": "string",
Expand Down Expand Up @@ -283,6 +288,9 @@
"registry_auth": {
"$ref": "#/definitions/registry_auth"
},
"home_dir": {
"$ref": "#/definitions/home_dir"
},
"working_dir": {
"$ref": "#/definitions/working_dir"
},
Expand Down
35 changes: 35 additions & 0 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,41 @@ func (ex *Executor) processJob(ctx context.Context, stoppedCh chan struct{}) err
}
bindings = append(bindings, art...)
}
if job.RepoType == "remote" && job.HomeDir != "" {
cred := ex.backend.GitCredentials(ctx)
if cred != nil {
log.Trace(ctx, "Trying to mount git credentials")
credPath := path.Join(ex.backend.GetTMPDir(ctx), consts.RUNS_DIR, job.RunName, "credentials")
credMountPath := ""
switch cred.Protocol {
case "ssh":
if cred.PrivateKey != nil {
credMountPath = path.Join(job.HomeDir, ".ssh/id_rsa")
if err := os.WriteFile(credPath, []byte(*cred.PrivateKey), 0600); err != nil {
log.Error(ctx, "Failed writing credentials", "err", err)
}
}
case "https":
if cred.OAuthToken != nil {
credMountPath = path.Join(job.HomeDir, ".config/gh/hosts.yml")
ghHost := fmt.Sprintf("%s:\n oauth_token: \"%s\"\n", job.RepoHostName, *cred.OAuthToken)
if err := os.WriteFile(credPath, []byte(ghHost), 0644); err != nil {
log.Error(ctx, "Failed writing credentials", "err", err)
}
}
default:
}
if credMountPath != "" {
defer os.Remove(credPath)
log.Trace(ctx, "Mounting git credentials", "target", credMountPath)
bindings = append(bindings, mount.Mount{
Type: mount.TypeBind,
Source: credPath,
Target: credMountPath,
})
}
}
}
logger := ex.backend.CreateLogger(ctx, fmt.Sprintf("/dstack/jobs/%s/%s", ex.backend.Bucket(ctx), job.RepoId), job.RunName)
secrets, err := ex.backend.Secrets(ctx)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions runner/internal/models/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type Job struct {
InstanceType string `yaml:"instance_type"`
//Variables map[string]interface{} `yaml:"variables"`
WorkflowName string `yaml:"workflow_name"`
HomeDir string `yaml:"home_dir"`
WorkingDir string `yaml:"working_dir"`

RegistryAuth RegistryAuth `yaml:"registry_auth"`
Expand Down