Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions ads/opctl/backend/ads_ml_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
import shutil
import tempfile
import time
import re
from distutils import dir_util
from typing import Dict, Tuple, Union

from ads.common.auth import AuthContext, AuthType, create_signer
from ads.common.oci_client import OCIClientFactory
from ads.config import (
CONDA_BUCKET_NAME,
CONDA_BUCKET_NS,
)
from ads.jobs import (
ContainerRuntime,
DataScienceJob,
Expand Down Expand Up @@ -65,6 +70,32 @@ def __init__(self, config: Dict) -> None:
self.auth_type = config["execution"].get("auth")
self.profile = config["execution"].get("oci_profile", None)
self.client = OCIClientFactory(**self.oci_auth).data_science
self.object_storage = OCIClientFactory(**self.oci_auth).object_storage

def _get_latest_conda_pack(self,
prefix,
python_version,
base_conda) -> str:
"""
get the latest conda pack.
"""
try:
objects = self.object_storage.list_objects(namespace_name=CONDA_BUCKET_NS,
bucket_name=CONDA_BUCKET_NAME,
prefix=prefix).data.objects
py_str = python_version.replace(".", "")
py_filter = [obj for obj in objects if f"p{py_str}" in obj.name]

def extract_version(obj_name):
match = re.search(rf"{prefix}([\d.]+)/", obj_name)
return tuple(map(int, match.group(1).split("."))) if match else (0,)

latest_obj = max(py_filter, key=lambda obj: extract_version(obj.name))
return latest_obj.name.split("/")[-1]
except Exception as e:
logger.warning(f"Error while fetching latest conda pack: {e}")
return base_conda


def init(
self,
Expand Down Expand Up @@ -100,6 +131,16 @@ def init(
or ""
).lower()

# If a tag is present
if ":" in conda_slug:
base_conda = conda_slug.split(":")[0]
conda_slug = self._get_latest_conda_pack(
self.config["prefix"],
self.config["python_version"],
base_conda
)
logger.info(f"Proceeding with the {conda_slug} conda pack.")

# if conda slug contains '/' then the assumption is that it is a custom conda pack
# the conda prefix needs to be added
if "/" in conda_slug:
Expand Down
2 changes: 2 additions & 0 deletions ads/opctl/operator/common/backend_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ def _init_backend_config(
if operator_info.conda_type == PACK_TYPE.SERVICE
else operator_info.conda_prefix,
"freeform_tags": freeform_tags,
"python_version": operator_info.python_version,
"prefix": operator_info.prefix,
}
},
{
Expand Down
2 changes: 2 additions & 0 deletions ads/opctl/operator/common/operator_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class OperatorInfo(DataClassSerializable):
description: str = ""
version: str = ""
conda: str = ""
prefix: str = ""
python_version: str = "3.11"
conda_type: str = ""
path: str = ""
keywords: List[str] = None
Expand Down
4 changes: 3 additions & 1 deletion ads/opctl/operator/lowcode/anomaly/MLoperator
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ type: anomaly
version: v1
conda_type: service
name: Anomaly Detection Operator
conda: anomaly_p310_cpu_x86_64_v1
conda: anomaly_p311_cpu_x86_64_v2:latest
prefix: service_pack/cpu/AI_Anomaly_Detection_Operator/
python_version: "3.11"
gpu: no
keywords:
- Anomaly Detection
Expand Down
4 changes: 3 additions & 1 deletion ads/opctl/operator/lowcode/forecast/MLoperator
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ type: forecast
version: v1
name: Forecasting Operator
conda_type: service
conda: forecast_p310_cpu_x86_64_v4
conda: forecast_p311_cpu_x86_64_v10:latest
prefix: service_pack/cpu/AI_Forecasting_Operator/
python_version: "3.11"
gpu: no
jobs_default_params:
shape_name: VM.Standard.E4.Flex
Expand Down