From 62286378baedb7141ca8b2cf9d6c2ec72d11813c Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 16 Oct 2024 11:01:00 +0200 Subject: [PATCH] Fix running NVIDIA NIM images Temporary fix to fetching image configs from NVIDIA NGC registry. Waiting for a permanent fix in upstream dependency (python-dxf). --- setup.py | 2 +- .../_internal/server/services/docker.py | 3 +- src/dstack/_internal/utils/dxf.py | 87 +++++++++++++++++++ 3 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 src/dstack/_internal/utils/dxf.py diff --git a/setup.py b/setup.py index 313608d96c..11f4f4db4a 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ def get_long_description(): "python-multipart", "filelock", "docker>=6.0.0", - "python-dxf>=11.0.0", + "python-dxf==12.1.0", "cachetools", "dnspython", "grpcio>=1.50", # indirect diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py index a7fb70e3bc..ec9d42c62e 100644 --- a/src/dstack/_internal/server/services/docker.py +++ b/src/dstack/_internal/server/services/docker.py @@ -11,6 +11,7 @@ from dstack._internal.core.errors import DockerRegistryError from dstack._internal.core.models.common import CoreModel, RegistryAuth from dstack._internal.server.utils.common import join_byte_stream_checked +from dstack._internal.utils.dxf import PatchedDXF DEFAULT_PLATFORM = "linux/amd64" DEFAULT_REGISTRY = "index.docker.io" @@ -65,7 +66,7 @@ class ImageManifest(CoreModel): def get_image_config(image_name: str, registry_auth: Optional[RegistryAuth]) -> ImageConfigObject: image = parse_image_name(image_name) - registry_client = DXF( + registry_client = PatchedDXF( host=image.registry or DEFAULT_REGISTRY, repo=image.repo, auth=DXFAuthAdapter(registry_auth), diff --git a/src/dstack/_internal/utils/dxf.py b/src/dstack/_internal/utils/dxf.py new file mode 100644 index 0000000000..567d951718 --- /dev/null +++ b/src/dstack/_internal/utils/dxf.py @@ -0,0 +1,87 @@ +""" +Temporary patch to python-dxf. +TODO(#1828): remove once https://github.com/davedoesdev/dxf/issues/57 is resolved. +""" + +import base64 +import urllib.parse as urlparse +import warnings +from typing import List, Optional +from urllib.parse import urlencode + +import requests +import www_authenticate +from dxf import DXF, _ignore_warnings, _raise_for_status, _to_bytes_2and3, exceptions + + +class PatchedDXF(DXF): + # copied from python-dxf + this bugfix: https://github.com/davedoesdev/dxf/pull/58 + def authenticate( + self, + username: Optional[str] = None, + password: Optional[str] = None, + actions: Optional[List[str]] = None, + response: Optional[requests.Response] = None, + authorization: Optional[str] = None, + user_agent: str = "Docker-Client/19.03.2 (linux)", + ) -> Optional[str]: + if response is None: + with warnings.catch_warnings(): + _ignore_warnings(self) + response = self._sessions[0].get( + self._base_url, verify=self._tlsverify, timeout=self._timeout + ) + + if not self._response_needs_auth(response): + return None + + if self._insecure: + raise exceptions.DXFAuthInsecureError() + + parsed = www_authenticate.parse(response.headers["www-authenticate"]) + + if username is not None and password is not None: + headers = { + "Authorization": "Basic " + + base64.b64encode(_to_bytes_2and3(username + ":" + password)).decode("utf-8") + } + elif authorization is not None: + headers = {"Authorization": authorization} + else: + headers = {} + headers["User-Agent"] = user_agent + + if "bearer" in parsed: + info = parsed["bearer"] + if actions and self._repo: + scope = "repository:" + self._repo + ":" + ",".join(actions) + elif "scope" in info: + scope = info["scope"] + elif not self._repo: + # Issue #28: gcr.io doesn't return scope for non-repo requests + scope = "registry:catalog:*" + else: + scope = "" + url_parts = list(urlparse.urlparse(info["realm"])) + query = urlparse.parse_qsl(url_parts[4]) + if "service" in info: + query.append(("service", info["service"])) + query.extend(("scope", s) for s in scope.split()) + url_parts[4] = urlencode(query, True) + url_parts[0] = "https" + if self._auth_host: + url_parts[1] = self._auth_host + auth_url = urlparse.urlunparse(url_parts) + with warnings.catch_warnings(): + _ignore_warnings(self) + r = self._sessions[0].get( + auth_url, headers=headers, verify=self._tlsverify, timeout=self._timeout + ) + _raise_for_status(r) + rjson = r.json() + # Use 'access_token' value if present and not empty, else 'token' value. + self.token = rjson.get("access_token") or rjson["token"] + return self._token + + self._headers = headers + return None