Skip to content

Commit

Permalink
Implement running OCI Marketplace images (#1288)
Browse files Browse the repository at this point in the history
Part of the OCI backend implementation. This change is gated behind the
OCI_BACKEND feature flag.
  • Loading branch information
jvstme committed May 31, 2024
1 parent 4d0d88f commit 22ffcdb
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 52 deletions.
43 changes: 9 additions & 34 deletions src/dstack/_internal/core/backends/oci/compute.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import cached_property
from typing import Dict, List, Optional, Set
from typing import List, Optional

import oci

Expand Down Expand Up @@ -33,34 +30,9 @@
]


@dataclass
class PreConfiguredResources:
    """Per-region image IDs supplied through environment variables.

    Instances are normally built with ``load``, which reads and validates
    the JSON-encoded environment variables.
    """

    # TODO(#1194): remove this class and teach dstack to create or discover all
    # necessary resources automatically

    # Mapping of region name to image ID.
    standard_image_ids: Dict[str, str]
    # Mapping of region name to image ID.
    cuda_image_ids: Dict[str, str]

    @staticmethod
    def load(required_regions: Set[str]) -> "PreConfiguredResources":
        """Build the resources from environment variables.

        Reads ``DSTACK_OCI_STANDARD_IMAGE_IDS`` and ``DSTACK_OCI_CUDA_IMAGE_IDS``;
        each must hold a non-empty JSON object keyed by region name.

        Args:
            required_regions: region names that must be present as keys in
                every mapping.

        Returns:
            A ``PreConfiguredResources`` holding both parsed mappings.

        Raises:
            ValueError: if a variable is unset, is not valid JSON, is not a
                non-empty JSON object, or lacks a required region.
        """
        env_var_by_param = {
            "standard_image_ids": "DSTACK_OCI_STANDARD_IMAGE_IDS",
            "cuda_image_ids": "DSTACK_OCI_CUDA_IMAGE_IDS",
        }
        params = {}
        for param, env_var in env_var_by_param.items():
            msg = (
                f"Invalid OCI parameter {param!r}. Make sure you set the corresponding"
                " environment variable when running dstack server"
            )
            try:
                # An unset variable parses to None and is rejected below.
                value = json.loads(os.getenv(env_var, "null"))
            except json.JSONDecodeError as e:
                raise ValueError(msg) from e
            # Require a JSON object: a list of region names would otherwise
            # pass the key check and only fail later on lookup.
            if not isinstance(value, dict) or not value or required_regions - set(value):
                raise ValueError(msg)
            params[param] = value
        return PreConfiguredResources(**params)


class OCICompute(Compute):
def __init__(self, config: OCIConfig):
self.config = config
self.pre_conf = PreConfiguredResources.load(set(config.regions or []))
self.regions = make_region_clients_map(config.regions or [], config.creds)

@cached_property
Expand Down Expand Up @@ -132,10 +104,13 @@ def create_instance(
if availability_domain is None:
raise NoCapacityError("Shape unavailable in all availability domains")

if len(instance_offer.instance.resources.gpus) > 0:
image_id = self.pre_conf.cuda_image_ids[instance_offer.region]
else:
image_id = self.pre_conf.standard_image_ids[instance_offer.region]
listing, package = resources.get_marketplace_listing_and_package(
cuda=len(instance_offer.instance.resources.gpus) > 0,
client=region.marketplace_client,
)
resources.accept_marketplace_listing_agreements(
listing, self.config.compartment_id, region.marketplace_client
)

try:
instance = resources.launch_instance(
Expand All @@ -147,7 +122,7 @@ def create_instance(
cloud_init_user_data=get_user_data(instance_config.get_public_keys()),
shape=instance_offer.instance.name,
disk_size_gb=round(instance_offer.instance.resources.disk.size_mib / 1024),
image_id=image_id,
image_id=package.image_id,
)
except oci.exceptions.ServiceError as e:
if e.code in ("LimitExceeded", "QuotaExceeded"):
Expand Down
10 changes: 7 additions & 3 deletions src/dstack/_internal/core/backends/oci/region.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Dict, Iterable
from typing import Dict, Iterable, Set

import oci
from typing_extensions import Any, List, Mapping
Expand All @@ -25,6 +25,10 @@ def compute_client(self) -> oci.core.ComputeClient:
def identity_client(self) -> oci.identity.IdentityClient:
return oci.identity.IdentityClient(self.client_config)

@cached_property
def marketplace_client(self) -> oci.marketplace.MarketplaceClient:
    """Marketplace API client built from this object's ``client_config``."""
    config = self.client_config
    return oci.marketplace.MarketplaceClient(config)

@cached_property
def virtual_network_client(self) -> oci.core.VirtualNetworkClient:
    """Virtual network API client built from this object's ``client_config``."""
    config = self.client_config
    return oci.core.VirtualNetworkClient(config)
Expand Down Expand Up @@ -54,7 +58,7 @@ def make_region_clients_map(

@dataclass
class SubscribedRegions:
names: List[str]
names: Set[str]
home_region_name: str


Expand All @@ -65,7 +69,7 @@ def get_subscribed_regions(creds: AnyOCICreds) -> SubscribedRegions:
subscriptions: List[oci.identity.models.RegionSubscription] = (
region.identity_client.list_region_subscriptions(config["tenancy"]).data
)
names = [s.region_name for s in subscriptions if s.status == s.STATUS_READY]
names = {s.region_name for s in subscriptions if s.status == s.STATUS_READY}
home_region_name = next(s.region_name for s in subscriptions if s.is_home_region)

return SubscribedRegions(names=names, home_region_name=home_region_name)
63 changes: 62 additions & 1 deletion src/dstack/_internal/core/backends/oci/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from concurrent.futures import Executor, ThreadPoolExecutor, as_completed
from functools import reduce
from itertools import islice
from typing import Dict, Iterable, List, Mapping, Optional, Set
from typing import Dict, Iterable, List, Mapping, Optional, Set, Tuple

import oci

from dstack import version
from dstack._internal.core.backends.oci.region import OCIRegionClient
from dstack._internal.core.errors import BackendError
from dstack._internal.core.models.instances import InstanceOffer
Expand Down Expand Up @@ -279,6 +280,66 @@ def terminate_instance_if_exists(client: oci.core.ComputeClient, instance_id: st
raise


def get_marketplace_listing_and_package(
    cuda: bool, client: oci.marketplace.MarketplaceClient
) -> Tuple[oci.marketplace.models.Listing, oci.marketplace.models.ImageListingPackage]:
    """Find the dstack community listing and its default package.

    The listing name is ``dstack-<base_image>`` or, when ``cuda`` is true,
    ``dstack-cuda-<base_image>``, where ``<base_image>`` is ``version.base_image``.

    Args:
        cuda: whether to look up the CUDA variant of the listing.
        client: marketplace client used for all lookups.

    Returns:
        The listing and the package for its ``default_package_version``.

    Raises:
        BackendError: if the number of listings with exactly that name is not 1.
    """
    name_prefix = "dstack-cuda-" if cuda else "dstack-"
    listing_name = f"{name_prefix}{version.base_image}"

    response = client.list_listings(
        name=listing_name,
        listing_types=[oci.marketplace.models.Listing.LISTING_TYPE_COMMUNITY],
        limit=1000,
    )
    # list_listings seems to filter by substring, so keep exact name matches only
    exact_matches: List[oci.marketplace.models.ListingSummary] = [
        summary for summary in response.data if summary.name == listing_name
    ]

    match_count = len(exact_matches)
    if match_count != 1:
        raise BackendError(
            f"Expected to find 1 listing by name {listing_name}, found {match_count}"
        )

    (summary,) = exact_matches
    listing: oci.marketplace.models.Listing = client.get_listing(summary.id).data
    package = client.get_package(listing.id, listing.default_package_version).data
    return listing, package


def accept_marketplace_listing_agreements(
    listing: oci.marketplace.models.Listing,
    compartment_id: str,
    client: oci.marketplace.MarketplaceClient,
) -> None:
    """Accept every not-yet-accepted agreement of the listing's default package.

    Args:
        listing: listing whose ``default_package_version`` agreements are handled.
        compartment_id: compartment in which acceptance is looked up and recorded.
        client: marketplace client used for all calls.
    """
    package_version = listing.default_package_version

    accepted: List[oci.marketplace.models.AcceptedAgreementSummary] = (
        client.list_accepted_agreements(
            compartment_id=compartment_id,
            listing_id=listing.id,
            package_version=package_version,
        ).data
    )
    already_accepted_ids = {item.agreement_id for item in accepted}

    all_agreements: List[oci.marketplace.models.AgreementSummary] = client.list_agreements(
        listing.id, package_version
    ).data
    pending = [summary for summary in all_agreements if summary.id not in already_accepted_ids]

    for summary in pending:
        # The full agreement is fetched because its signature is needed to accept it.
        agreement: oci.marketplace.models.Agreement = client.get_agreement(
            listing_id=listing.id,
            package_version=package_version,
            agreement_id=summary.id,
            compartment_id=compartment_id,
        ).data
        details = oci.marketplace.models.CreateAcceptedAgreementDetails(
            compartment_id=compartment_id,
            listing_id=listing.id,
            package_version=package_version,
            agreement_id=summary.id,
            signature=agreement.signature,
        )
        client.create_accepted_agreement(details)


def get_or_create_compartment(
name: str, parent_compartment_id: str, client: oci.identity.IdentityClient
) -> oci.identity.models.Compartment:
Expand Down
44 changes: 34 additions & 10 deletions src/dstack/_internal/server/services/backends/configurators/oci.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Tuple
from typing import Dict, Iterable, List, Set, Tuple

from dstack._internal.core.backends.oci import OCIBackend, auth, resources
from dstack._internal.core.backends.oci.config import OCIConfig
Expand Down Expand Up @@ -34,6 +34,15 @@
)
from dstack._internal.settings import FeatureFlags

# Regions where dstack images are published
SUPPORTED_REGIONS = frozenset(
    (
        "eu-frankfurt-1",
        "me-dubai-1",
        "us-ashburn-1",
    )
)


class OCIConfigurator(Configurator):
if FeatureFlags.OCI_BACKEND:
Expand All @@ -42,12 +51,12 @@ class OCIConfigurator(Configurator):
def get_default_configs(self) -> List[OCIConfigInfoWithCreds]:
creds = OCIDefaultCreds()
try:
regions = get_subscribed_regions(creds).names
subscribed_regions = get_subscribed_regions(creds).names
except any_oci_exception:
return []
return [
OCIConfigInfoWithCreds(
regions=regions,
regions=list(subscribed_regions & SUPPORTED_REGIONS),
creds=creds,
)
]
Expand All @@ -66,14 +75,14 @@ def get_config_values(self, config: OCIConfigInfoWithCredsPartial) -> OCIConfigV
raise_invalid_credentials_error(fields=[["creds"]])

try:
available_regions = get_subscribed_regions(config.creds).names
available_regions = get_subscribed_regions(config.creds).names & SUPPORTED_REGIONS
except any_oci_exception:
raise_invalid_credentials_error(fields=[["creds"]])

if config.regions:
selected_regions = [r for r in config.regions if r in available_regions]
else:
selected_regions = available_regions
selected_regions = list(available_regions)

config_values.regions = self._get_regions_element(
available=available_regions,
Expand All @@ -90,10 +99,9 @@ def create_backend(
raise_invalid_credentials_error(fields=[["creds"]])

if config.regions is None:
config.regions = subscribed_regions.names
elif unsubscribed_regions := set(config.regions) - set(subscribed_regions.names):
msg = f"Regions {unsubscribed_regions} are configured but not subscribed to in OCI"
raise ServerClientError(msg, fields=[["regions"]])
config.regions = list(subscribed_regions.names & SUPPORTED_REGIONS)
else:
_raise_if_regions_unavailable(config.regions, subscribed_regions.names)

compartment_id, subnet_ids_per_region = _create_resources(
project, config, subscribed_regions.home_region_name
Expand Down Expand Up @@ -127,14 +135,30 @@ def _get_backend_config(self, model: BackendModel) -> OCIConfig:
)

def _get_regions_element(
self, available: List[str], selected: List[str]
self, available: Iterable[str], selected: List[str]
) -> ConfigMultiElement:
element = ConfigMultiElement(selected=selected)
for region in available:
element.values.append(ConfigElementValue(value=region, label=region))
return element


def _raise_if_regions_unavailable(
    region_names: Iterable[str], subscribed_region_names: Set[str]
) -> None:
    """Reject configured regions that are unsupported or not subscribed to.

    Support is checked first, against ``SUPPORTED_REGIONS``; subscription is
    checked second, against ``subscribed_region_names``.

    Raises:
        ServerClientError: for the first failing check, with the ``regions``
            field marked as the source of the error.
    """
    requested = set(region_names)

    unsupported_regions = requested - SUPPORTED_REGIONS
    if unsupported_regions:
        raise ServerClientError(
            f"Regions {unsupported_regions} are configured but not supported by dstack yet. "
            f"Only these regions are supported: {set(SUPPORTED_REGIONS)}. "
            "Please contact dstack if a region you need is missing.",
            fields=[["regions"]],
        )

    unsubscribed_regions = requested - subscribed_region_names
    if unsubscribed_regions:
        raise ServerClientError(
            f"Regions {unsubscribed_regions} are configured but not subscribed to in OCI",
            fields=[["regions"]],
        )


def _create_resources(
project: ProjectModel, config: OCIConfigInfoWithCreds, home_region: str
) -> Tuple[str, Dict[str, str]]:
Expand Down
12 changes: 8 additions & 4 deletions src/tests/_internal/server/routers/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from operator import itemgetter
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -35,7 +36,7 @@
}
SAMPLE_OCI_COMPARTMENT_ID = "ocid1.compartment.oc1..aaaaaaaa"
SAMPLE_OCI_SUBSCRIBED_REGIONS = oci_region.SubscribedRegions(
names=["me-dubai-1", "eu-frankfurt-1"], home_region_name="eu-frankfurt-1"
names={"me-dubai-1", "eu-frankfurt-1"}, home_region_name="eu-frankfurt-1"
)
SAMPLE_OCI_SUBNETS = {
"me-dubai-1": "ocid1.subnet.oc1.me-dubai-1.aaaaaaaa",
Expand Down Expand Up @@ -707,15 +708,18 @@ async def test_returns_config_on_valid_creds(self, test_db, session: AsyncSessio
)
default_creds_available_mock.assert_called()
get_regions_mock.assert_called()
body = response.json()
body["regions"]["selected"].sort()
body["regions"]["values"].sort(key=itemgetter("value"))
assert response.status_code == 200, response.json()
assert response.json() == {
assert body == {
"type": "oci",
"default_creds": True,
"regions": {
"selected": ["me-dubai-1", "eu-frankfurt-1"],
"selected": ["eu-frankfurt-1", "me-dubai-1"],
"values": [
{"label": "me-dubai-1", "value": "me-dubai-1"},
{"label": "eu-frankfurt-1", "value": "eu-frankfurt-1"},
{"label": "me-dubai-1", "value": "me-dubai-1"},
],
},
"compartment_id": None,
Expand Down

0 comments on commit 22ffcdb

Please sign in to comment.