From 335fdb9d3aa7835eb3d21b25ab330059210ec82a Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Thu, 30 May 2024 13:38:17 +0200 Subject: [PATCH] Implement running OCI Marketplace images Part of the OCI backend implementation. Behind the OCI_BACKEND feature flag --- .../_internal/core/backends/oci/compute.py | 43 +++---------- .../_internal/core/backends/oci/region.py | 10 ++- .../_internal/core/backends/oci/resources.py | 63 ++++++++++++++++++- .../services/backends/configurators/oci.py | 44 ++++++++++--- .../_internal/server/routers/test_backends.py | 12 ++-- 5 files changed, 120 insertions(+), 52 deletions(-) diff --git a/src/dstack/_internal/core/backends/oci/compute.py b/src/dstack/_internal/core/backends/oci/compute.py index de713e0b1..1b5e0eca0 100644 --- a/src/dstack/_internal/core/backends/oci/compute.py +++ b/src/dstack/_internal/core/backends/oci/compute.py @@ -1,9 +1,6 @@ -import json -import os from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass from functools import cached_property -from typing import Dict, List, Optional, Set +from typing import List, Optional import oci @@ -33,34 +30,9 @@ ] -@dataclass -class PreConfiguredResources: - # TODO(#1194): remove this class and teach dstack to create or discover all - # necessary resources automatically - - standard_image_ids: Dict[str, str] - cuda_image_ids: Dict[str, str] - - @staticmethod - def load(required_regions: Set[str]) -> "PreConfiguredResources": - params = dict( - standard_image_ids=json.loads(os.getenv("DSTACK_OCI_STANDARD_IMAGE_IDS", "null")), - cuda_image_ids=json.loads(os.getenv("DSTACK_OCI_CUDA_IMAGE_IDS", "null")), - ) - for param, value in params.items(): - if not value or required_regions - set(value): - msg = ( - f"Invalid OCI parameter {param!r}. Make sure you set the corresponding" - " environment variable when running dstack server" - ) - raise ValueError(msg) - return PreConfiguredResources(**params) - - class OCICompute(Compute): def __init__(self, config: OCIConfig): self.config = config - self.pre_conf = PreConfiguredResources.load(set(config.regions or [])) self.regions = make_region_clients_map(config.regions or [], config.creds) @cached_property @@ -132,10 +104,13 @@ def create_instance( if availability_domain is None: raise NoCapacityError("Shape unavailable in all availability domains") - if len(instance_offer.instance.resources.gpus) > 0: - image_id = self.pre_conf.cuda_image_ids[instance_offer.region] - else: - image_id = self.pre_conf.standard_image_ids[instance_offer.region] + listing, package = resources.get_marketplace_listing_and_package( + cuda=len(instance_offer.instance.resources.gpus) > 0, + client=region.marketplace_client, + ) + resources.accept_marketplace_listing_agreements( + listing, self.config.compartment_id, region.marketplace_client + ) try: instance = resources.launch_instance( @@ -147,7 +122,7 @@ def create_instance( cloud_init_user_data=get_user_data(instance_config.get_public_keys()), shape=instance_offer.instance.name, disk_size_gb=round(instance_offer.instance.resources.disk.size_mib / 1024), - image_id=image_id, + image_id=package.image_id, ) except oci.exceptions.ServiceError as e: if e.code in ("LimitExceeded", "QuotaExceeded"): diff --git a/src/dstack/_internal/core/backends/oci/region.py b/src/dstack/_internal/core/backends/oci/region.py index 81873df7d..201ec11ab 100644 --- a/src/dstack/_internal/core/backends/oci/region.py +++ b/src/dstack/_internal/core/backends/oci/region.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import cached_property -from typing import Dict, Iterable +from typing import Dict, Iterable, Set import oci from typing_extensions import Any, List, Mapping @@ -25,6 +25,10 @@ def compute_client(self) -> oci.core.ComputeClient: def identity_client(self) -> oci.identity.IdentityClient: return oci.identity.IdentityClient(self.client_config) + @cached_property + def marketplace_client(self) -> oci.marketplace.MarketplaceClient: + return oci.marketplace.MarketplaceClient(self.client_config) + @cached_property def virtual_network_client(self) -> oci.core.VirtualNetworkClient: return oci.core.VirtualNetworkClient(self.client_config) @@ -54,7 +58,7 @@ def make_region_clients_map( @dataclass class SubscribedRegions: - names: List[str] + names: Set[str] home_region_name: str @@ -65,7 +69,7 @@ def get_subscribed_regions(creds: AnyOCICreds) -> SubscribedRegions: subscriptions: List[oci.identity.models.RegionSubscription] = ( region.identity_client.list_region_subscriptions(config["tenancy"]).data ) - names = [s.region_name for s in subscriptions if s.status == s.STATUS_READY] + names = {s.region_name for s in subscriptions if s.status == s.STATUS_READY} home_region_name = next(s.region_name for s in subscriptions if s.is_home_region) return SubscribedRegions(names=names, home_region_name=home_region_name) diff --git a/src/dstack/_internal/core/backends/oci/resources.py b/src/dstack/_internal/core/backends/oci/resources.py index 2f0244d96..6f840ad5c 100644 --- a/src/dstack/_internal/core/backends/oci/resources.py +++ b/src/dstack/_internal/core/backends/oci/resources.py @@ -3,10 +3,11 @@ from concurrent.futures import Executor, ThreadPoolExecutor, as_completed from functools import reduce from itertools import islice -from typing import Dict, Iterable, List, Mapping, Optional, Set +from typing import Dict, Iterable, List, Mapping, Optional, Set, Tuple import oci +from dstack import version from dstack._internal.core.backends.oci.region import OCIRegionClient from dstack._internal.core.errors import BackendError from dstack._internal.core.models.instances import InstanceOffer @@ -279,6 +280,66 @@ def terminate_instance_if_exists(client: oci.core.ComputeClient, instance_id: st raise +def get_marketplace_listing_and_package( + cuda: bool, client: oci.marketplace.MarketplaceClient +) -> Tuple[oci.marketplace.models.Listing, oci.marketplace.models.ImageListingPackage]: + listing_name = f"dstack-{version.base_image}" + if cuda: + listing_name = f"dstack-cuda-{version.base_image}" + + listing_summaries: List[oci.marketplace.models.ListingSummary] = client.list_listings( + name=listing_name, + listing_types=[oci.marketplace.models.Listing.LISTING_TYPE_COMMUNITY], + limit=1000, + ).data + # filter by exact match, as list_listings seems to filter by substring + listing_summaries = [s for s in listing_summaries if s.name == listing_name] + + if len(listing_summaries) != 1: + msg = f"Expected to find 1 listing by name {listing_name}, found {len(listing_summaries)}" + raise BackendError(msg) + + listing: oci.marketplace.models.Listing = client.get_listing(listing_summaries[0].id).data + package = client.get_package(listing.id, listing.default_package_version).data + return listing, package + + +def accept_marketplace_listing_agreements( + listing: oci.marketplace.models.Listing, + compartment_id: str, + client: oci.marketplace.MarketplaceClient, +) -> None: + accepted_agreements: List[oci.marketplace.models.AcceptedAgreementSummary] = ( + client.list_accepted_agreements( + compartment_id=compartment_id, + listing_id=listing.id, + package_version=listing.default_package_version, + ).data + ) + accepted_agreement_ids = {a.agreement_id for a in accepted_agreements} + agreement_summaries: List[oci.marketplace.models.AgreementSummary] = client.list_agreements( + listing.id, listing.default_package_version + ).data + for agreement_summary in agreement_summaries: + if agreement_summary.id in accepted_agreement_ids: + continue + agreement: oci.marketplace.models.Agreement = client.get_agreement( + listing_id=listing.id, + package_version=listing.default_package_version, + agreement_id=agreement_summary.id, + compartment_id=compartment_id, + ).data + client.create_accepted_agreement( + oci.marketplace.models.CreateAcceptedAgreementDetails( + compartment_id=compartment_id, + listing_id=listing.id, + package_version=listing.default_package_version, + agreement_id=agreement_summary.id, + signature=agreement.signature, + ) + ) + + def get_or_create_compartment( name: str, parent_compartment_id: str, client: oci.identity.IdentityClient ) -> oci.identity.models.Compartment: diff --git a/src/dstack/_internal/server/services/backends/configurators/oci.py b/src/dstack/_internal/server/services/backends/configurators/oci.py index de9f7a745..42555dd5b 100644 --- a/src/dstack/_internal/server/services/backends/configurators/oci.py +++ b/src/dstack/_internal/server/services/backends/configurators/oci.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Tuple +from typing import Dict, Iterable, List, Set, Tuple from dstack._internal.core.backends.oci import OCIBackend, auth, resources from dstack._internal.core.backends.oci.config import OCIConfig @@ -34,6 +34,15 @@ ) from dstack._internal.settings import FeatureFlags +# where dstack images are published +SUPPORTED_REGIONS = frozenset( + [ + "eu-frankfurt-1", + "me-dubai-1", + "us-ashburn-1", + ] +) + class OCIConfigurator(Configurator): if FeatureFlags.OCI_BACKEND: @@ -42,12 +51,12 @@ class OCIConfigurator(Configurator): def get_default_configs(self) -> List[OCIConfigInfoWithCreds]: creds = OCIDefaultCreds() try: - regions = get_subscribed_regions(creds).names + subscribed_regions = get_subscribed_regions(creds).names except any_oci_exception: return [] return [ OCIConfigInfoWithCreds( - regions=regions, + regions=list(subscribed_regions & SUPPORTED_REGIONS), creds=creds, ) ] @@ -66,14 +75,14 @@ def get_config_values(self, config: OCIConfigInfoWithCredsPartial) -> OCIConfigV raise_invalid_credentials_error(fields=[["creds"]]) try: - available_regions = get_subscribed_regions(config.creds).names + available_regions = get_subscribed_regions(config.creds).names & SUPPORTED_REGIONS except any_oci_exception: raise_invalid_credentials_error(fields=[["creds"]]) if config.regions: selected_regions = [r for r in config.regions if r in available_regions] else: - selected_regions = available_regions + selected_regions = list(available_regions) config_values.regions = self._get_regions_element( available=available_regions, @@ -90,10 +99,9 @@ def create_backend( raise_invalid_credentials_error(fields=[["creds"]]) if config.regions is None: - config.regions = subscribed_regions.names - elif unsubscribed_regions := set(config.regions) - set(subscribed_regions.names): - msg = f"Regions {unsubscribed_regions} are configured but not subscribed to in OCI" - raise ServerClientError(msg, fields=[["regions"]]) + config.regions = list(subscribed_regions.names & SUPPORTED_REGIONS) + else: + _raise_if_regions_unavailable(config.regions, subscribed_regions.names) compartment_id, subnet_ids_per_region = _create_resources( project, config, subscribed_regions.home_region_name @@ -127,7 +135,7 @@ def _get_backend_config(self, model: BackendModel) -> OCIConfig: ) def _get_regions_element( - self, available: List[str], selected: List[str] + self, available: Iterable[str], selected: List[str] ) -> ConfigMultiElement: element = ConfigMultiElement(selected=selected) for region in available: @@ -135,6 +143,22 @@ def _get_regions_element( return element +def _raise_if_regions_unavailable( + region_names: Iterable[str], subscribed_region_names: Set[str] +) -> None: + region_names = set(region_names) + if unsupported_regions := region_names - SUPPORTED_REGIONS: + msg = ( + f"Regions {unsupported_regions} are configured but not supported by dstack yet. " + f"Only these regions are supported: {set(SUPPORTED_REGIONS)}. " + "Please contact dstack if a region you need is missing." + ) + raise ServerClientError(msg, fields=[["regions"]]) + if unsubscribed_regions := region_names - subscribed_region_names: + msg = f"Regions {unsubscribed_regions} are configured but not subscribed to in OCI" + raise ServerClientError(msg, fields=[["regions"]]) + + def _create_resources( project: ProjectModel, config: OCIConfigInfoWithCreds, home_region: str ) -> Tuple[str, Dict[str, str]]: diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py index be09a8782..dad6f2a42 100644 --- a/src/tests/_internal/server/routers/test_backends.py +++ b/src/tests/_internal/server/routers/test_backends.py @@ -1,4 +1,5 @@ import json +from operator import itemgetter from unittest.mock import Mock, patch import pytest @@ -35,7 +36,7 @@ } SAMPLE_OCI_COMPARTMENT_ID = "ocid1.compartment.oc1..aaaaaaaa" SAMPLE_OCI_SUBSCRIBED_REGIONS = oci_region.SubscribedRegions( - names=["me-dubai-1", "eu-frankfurt-1"], home_region_name="eu-frankfurt-1" + names={"me-dubai-1", "eu-frankfurt-1"}, home_region_name="eu-frankfurt-1" ) SAMPLE_OCI_SUBNETS = { "me-dubai-1": "ocid1.subnet.oc1.me-dubai-1.aaaaaaaa", @@ -707,15 +708,18 @@ async def test_returns_config_on_valid_creds(self, test_db, session: AsyncSessio ) default_creds_available_mock.assert_called() get_regions_mock.assert_called() + body = response.json() + body["regions"]["selected"].sort() + body["regions"]["values"].sort(key=itemgetter("value")) assert response.status_code == 200, response.json() - assert response.json() == { + assert body == { "type": "oci", "default_creds": True, "regions": { - "selected": ["me-dubai-1", "eu-frankfurt-1"], + "selected": ["eu-frankfurt-1", "me-dubai-1"], "values": [ - {"label": "me-dubai-1", "value": "me-dubai-1"}, {"label": "eu-frankfurt-1", "value": "eu-frankfurt-1"}, + {"label": "me-dubai-1", "value": "me-dubai-1"}, ], }, "compartment_id": None,