From b3096f5c842b280b91ef0b9127b0e968137c77d8 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 11 May 2023 09:54:30 +0400 Subject: [PATCH 1/2] Replace backend with hub_client in providers --- cli/dstack/providers/_torchrun/main.py | 6 +++--- cli/dstack/providers/bash/main.py | 6 +++--- cli/dstack/providers/code/main.py | 6 +++--- cli/dstack/providers/docker/main.py | 6 +++--- cli/dstack/providers/lab/main.py | 6 +++--- cli/dstack/providers/notebook/main.py | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cli/dstack/providers/_torchrun/main.py b/cli/dstack/providers/_torchrun/main.py index 6409bf54fc..54de88d275 100644 --- a/cli/dstack/providers/_torchrun/main.py +++ b/cli/dstack/providers/_torchrun/main.py @@ -4,7 +4,7 @@ from rich_argparse import RichHelpFormatter -from dstack.backend.base import Backend +import dstack.api.hub as hub from dstack.core.job import GpusRequirements, JobSpec, Requirements from dstack.providers import Provider @@ -24,14 +24,14 @@ def __init__(self): def load( self, - backend: Backend, + hub_client: "hub.HubClient", args: Optional[Namespace], workflow_name: Optional[str], provider_data: Dict[str, Any], run_name: str, ssh_key_pub: Optional[str] = None, ): - super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub) + super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) self.script = self.provider_data.get("script") or self.provider_data.get("file") self.setup = self._get_list_data("setup") or self._get_list_data("before_run") self.python = self._safe_python_version("python") diff --git a/cli/dstack/providers/bash/main.py b/cli/dstack/providers/bash/main.py index af7c4fa431..4c054ae6c4 100644 --- a/cli/dstack/providers/bash/main.py +++ b/cli/dstack/providers/bash/main.py @@ -3,8 +3,8 @@ from rich_argparse import RichHelpFormatter +import dstack.api.hub as hub from dstack import version -from dstack.backend.base import Backend from dstack.core.app import AppSpec from dstack.core.job import JobSpec from dstack.providers import Provider @@ -26,14 +26,14 @@ def __init__(self): def load( self, - backend: Backend, + hub_client: "hub.HubClient", args: Optional[Namespace], workflow_name: Optional[str], provider_data: Dict[str, Any], run_name: str, ssh_key_pub: Optional[str] = None, ): - super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub) + super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) self.python = self._safe_python_version("python") self.commands = self._get_list_data("commands") self.env = self._env() diff --git a/cli/dstack/providers/code/main.py b/cli/dstack/providers/code/main.py index d2a4df323a..d5d68cb8ce 100644 --- a/cli/dstack/providers/code/main.py +++ b/cli/dstack/providers/code/main.py @@ -4,8 +4,8 @@ from rich_argparse import RichHelpFormatter +import dstack.api.hub as hub from dstack import version -from dstack.backend.base import Backend from dstack.core.app import AppSpec from dstack.core.job import JobSpec from dstack.providers import Provider @@ -28,14 +28,14 @@ def __init__(self): def load( self, - backend: Backend, + hub_client: "hub.HubClient", args: Optional[Namespace], workflow_name: Optional[str], provider_data: Dict[str, Any], run_name: str, ssh_key_pub: Optional[str] = None, ): - super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub) + super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) self.setup = self._get_list_data("setup") or self._get_list_data("before_run") self.ports = self.provider_data.get("ports") self.python = self._safe_python_version("python") diff --git a/cli/dstack/providers/docker/main.py b/cli/dstack/providers/docker/main.py index 8ee6f04eb4..2ed4934282 100644 --- a/cli/dstack/providers/docker/main.py +++ b/cli/dstack/providers/docker/main.py @@ -3,7 +3,7 @@ from rich_argparse import RichHelpFormatter -from dstack.backend.base import Backend +import dstack.api.hub as hub from dstack.core.app import AppSpec from dstack.core.job import JobSpec from dstack.providers import Provider @@ -24,14 +24,14 @@ def __init__(self): def load( self, - backend: Backend, + hub_client: "hub.HubClient", args: Optional[Namespace], workflow_name: Optional[str], provider_data: Dict[str, Any], run_name: str, ssh_key_pub: Optional[str] = None, ): - super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub) + super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) self.image_name = self.provider_data["image"] self.registry_auth = self.provider_data.get("registry_auth") self.commands = self._get_list_data("commands") diff --git a/cli/dstack/providers/lab/main.py b/cli/dstack/providers/lab/main.py index 128e9fbf1a..aafe11aa0f 100644 --- a/cli/dstack/providers/lab/main.py +++ b/cli/dstack/providers/lab/main.py @@ -4,8 +4,8 @@ from rich_argparse import RichHelpFormatter +import dstack.api.hub as hub from dstack import version -from dstack.backend.base import Backend from dstack.core.app import AppSpec from dstack.core.job import JobSpec from dstack.providers import Provider @@ -26,14 +26,14 @@ def __init__(self): def load( self, - backend: Backend, + hub_client: "hub.HubClient", args: Optional[Namespace], workflow_name: Optional[str], provider_data: Dict[str, Any], run_name: str, ssh_key_pub: Optional[str] = None, ): - super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub) + super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) self.setup = self._get_list_data("setup") or self._get_list_data("before_run") self.python = self._safe_python_version("python") self.version = self.provider_data.get("version") diff --git a/cli/dstack/providers/notebook/main.py b/cli/dstack/providers/notebook/main.py index d26222bba0..de27fa4699 100644 --- a/cli/dstack/providers/notebook/main.py +++ b/cli/dstack/providers/notebook/main.py @@ -4,8 +4,8 @@ from rich_argparse import RichHelpFormatter +import dstack.api.hub as hub from dstack import version -from dstack.backend.base import Backend from dstack.core.app import AppSpec from dstack.core.job import JobSpec from dstack.providers import Provider @@ -26,14 +26,14 @@ def __init__(self): def load( self, - backend: Backend, + hub_client: "hub.HubClient", args: Optional[Namespace], workflow_name: Optional[str], provider_data: Dict[str, Any], run_name: str, ssh_key_pub: Optional[str] = None, ): - super().load(backend, args, workflow_name, provider_data, run_name, ssh_key_pub) + super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) self.setup = self._get_list_data("setup") or self._get_list_data("before_run") self.python = self._safe_python_version("python") self.version = self.provider_data.get("version") From 18586ea41a979dc2b2eba41f8ab6789c05fd5b32 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 11 May 2023 12:22:46 +0400 Subject: [PATCH 2/2] Pass home_dir to runner, mount git credentials into the container --- cli/dstack/core/job.py | 3 +++ cli/dstack/providers/__init__.py | 1 + cli/dstack/providers/docker/main.py | 1 + cli/dstack/schemas/workflows.json | 8 +++++++ runner/internal/executor/executor.go | 35 ++++++++++++++++++++++++++++ runner/internal/models/backend.go | 1 + 6 files changed, 49 insertions(+) diff --git a/cli/dstack/core/job.py b/cli/dstack/core/job.py index 5e6ade7645..ce1b028f8f 100644 --- a/cli/dstack/core/job.py +++ b/cli/dstack/core/job.py @@ -186,6 +186,7 @@ class Job(JobHead): commands: Optional[List[str]] entrypoint: Optional[List[str]] env: Optional[Dict[str, str]] + home_dir: Optional[str] working_dir: Optional[str] artifact_specs: Optional[List[ArtifactSpec]] cache_specs: List[CacheSpec] @@ -252,6 +253,7 @@ def serialize(self) -> dict: "commands": self.commands or [], "entrypoint": self.entrypoint, "env": self.env or {}, + "home_dir": self.home_dir or "", "working_dir": self.working_dir or "", "artifacts": artifacts, "cache": [item.dict() for item in self.cache_specs], @@ -394,6 +396,7 @@ def unserialize(job_data: dict): commands=job_data.get("commands") or None, entrypoint=job_data.get("entrypoint") or None, env=job_data["env"] or None, + home_dir=job_data.get("home_dir") or None, working_dir=job_data.get("working_dir") or None, artifact_specs=artifact_specs, cache_specs=[CacheSpec(**item) for item in job_data.get("cache", [])], diff --git a/cli/dstack/providers/__init__.py b/cli/dstack/providers/__init__.py index d4d661137f..13074579ab 100644 --- a/cli/dstack/providers/__init__.py +++ b/cli/dstack/providers/__init__.py @@ -256,6 +256,7 @@ def submit_jobs(self, hub_client: "hub.HubClient", tag_name: str) -> List[Job]: commands=job_spec.commands, entrypoint=job_spec.entrypoint, env=job_spec.env, + home_dir=self.home_dir, working_dir=job_spec.working_dir, artifact_specs=job_spec.artifact_specs, cache_specs=self.cache_specs, diff --git a/cli/dstack/providers/docker/main.py b/cli/dstack/providers/docker/main.py index 2ed4934282..e498f7b000 100644 --- a/cli/dstack/providers/docker/main.py +++ b/cli/dstack/providers/docker/main.py @@ -40,6 +40,7 @@ def load( self.entrypoint = ["/bin/sh", "-i", "-c"] self.artifact_specs = self._artifact_specs() self.env = self.provider_data.get("env") + self.home_dir = self.provider_data.get("home_dir") self.working_dir = self.provider_data.get("working_dir") self.ports = self.provider_data.get("ports") self.resources = self._resources() diff --git a/cli/dstack/schemas/workflows.json b/cli/dstack/schemas/workflows.json index d9b517c22c..b025ffdf58 100644 --- a/cli/dstack/schemas/workflows.json +++ b/cli/dstack/schemas/workflows.json @@ -168,6 +168,11 @@ "type": "integer", "minimum": 1 }, + "home_dir": { + "description": "The absolute path to the home directory inside the container", + "type": "string", + "minLength": 1 + }, "working_dir": { "description": "The absolute or relative path to the working directory where to run the workflow", "type": "string", @@ -283,6 +288,9 @@ "registry_auth": { "$ref": "#/definitions/registry_auth" }, + "home_dir": { + "$ref": "#/definitions/home_dir" + }, "working_dir": { "$ref": "#/definitions/working_dir" }, diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index e7de8e4ee2..1d0cf1898a 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -525,6 +525,41 @@ func (ex *Executor) processJob(ctx context.Context, stoppedCh chan struct{}) err } bindings = append(bindings, art...) } + if job.RepoType == "remote" && job.HomeDir != "" { + cred := ex.backend.GitCredentials(ctx) + if cred != nil { + log.Trace(ctx, "Trying to mount git credentials") + credPath := path.Join(ex.backend.GetTMPDir(ctx), consts.RUNS_DIR, job.RunName, "credentials") + credMountPath := "" + switch cred.Protocol { + case "ssh": + if cred.PrivateKey != nil { + credMountPath = path.Join(job.HomeDir, ".ssh/id_rsa") + if err := os.WriteFile(credPath, []byte(*cred.PrivateKey), 0600); err != nil { + log.Error(ctx, "Failed writing credentials", "err", err) + } + } + case "https": + if cred.OAuthToken != nil { + credMountPath = path.Join(job.HomeDir, ".config/gh/hosts.yml") + ghHost := fmt.Sprintf("%s:\n oauth_token: \"%s\"\n", job.RepoHostName, *cred.OAuthToken) + if err := os.WriteFile(credPath, []byte(ghHost), 0644); err != nil { + log.Error(ctx, "Failed writing credentials", "err", err) + } + } + default: + } + if credMountPath != "" { + defer os.Remove(credPath) + log.Trace(ctx, "Mounting git credentials", "target", credMountPath) + bindings = append(bindings, mount.Mount{ + Type: mount.TypeBind, + Source: credPath, + Target: credMountPath, + }) + } + } + } logger := ex.backend.CreateLogger(ctx, fmt.Sprintf("/dstack/jobs/%s/%s", ex.backend.Bucket(ctx), job.RepoId), job.RunName) secrets, err := ex.backend.Secrets(ctx) if err != nil { diff --git a/runner/internal/models/backend.go b/runner/internal/models/backend.go index b5b6675579..4bdf5a14b5 100644 --- a/runner/internal/models/backend.go +++ b/runner/internal/models/backend.go @@ -55,6 +55,7 @@ type Job struct { InstanceType string `yaml:"instance_type"` //Variables map[string]interface{} `yaml:"variables"` WorkflowName string `yaml:"workflow_name"` + HomeDir string `yaml:"home_dir"` WorkingDir string `yaml:"working_dir"` RegistryAuth RegistryAuth `yaml:"registry_auth"`