Skip to content

Commit

Permalink
chore: Made it easier to request the latest framework version when ca…
Browse files Browse the repository at this point in the history
…lling Model.upload_*

Goal: make it easier for the calling functions to ask for the latest framework version.
It turns out that the current approach complicates calling the upload_* functions, because there is no values of the *_version parameter that mean "latest version".
This leads to the need to have different code paths depending on whether the user has given an explicit version value.
Example of the problem:

```python
def my_import_xgboost_model(
    model_path: str,
    xgboost_version: Optional[str] = None,
):
    if xgboost_version:
            aiplatfor.Model.upload_xgboost_model(model_path=model_path, xgboost_version=xgboost_version)
    else:
            aiplatfor.Model.upload_xgboost_model(model_path=model_path)
```

With this small change, the call is simplified:

```python
def my_import_xgboost_model(
    model_path: str,
    xgboost_version: Optional[str] = None,
):
    aiplatfor.Model.upload_xgboost_model(model_path=model_path, xgboost_version=xgboost_version)
```

P.S. It would be great if the `get_prebuilt_prediction_container_uri` function supported the "latest version" feature.
  • Loading branch information
Ark-kun committed Feb 3, 2022
1 parent 9ebc972 commit e69b58e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2534,7 +2534,7 @@ def export_model(
def upload_xgboost_model_file(
cls,
model_file_path: str,
xgboost_version: str = "1.4",
xgboost_version: Optional[str] = None,
display_name: str = "XGBoost model",
description: Optional[str] = None,
instance_schema_uri: Optional[str] = None,
Expand Down Expand Up @@ -2674,7 +2674,7 @@ def upload_xgboost_model_file(
container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
region=location,
framework="xgboost",
framework_version=xgboost_version,
framework_version=xgboost_version or "1.4",
accelerator="cpu",
)

Expand Down Expand Up @@ -2729,7 +2729,7 @@ def upload_xgboost_model_file(
def upload_scikit_learn_model_file(
cls,
model_file_path: str,
sklearn_version: str = "1.0",
sklearn_version: Optional[str] = None,
display_name: str = "Scikit-learn model",
description: Optional[str] = None,
instance_schema_uri: Optional[str] = None,
Expand Down Expand Up @@ -2869,7 +2869,7 @@ def upload_scikit_learn_model_file(
container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
region=location,
framework="sklearn",
framework_version=sklearn_version,
framework_version=sklearn_version or "1.0",
accelerator="cpu",
)

Expand Down Expand Up @@ -2923,7 +2923,7 @@ def upload_scikit_learn_model_file(
def upload_tensorflow_saved_model(
cls,
saved_model_dir: str,
tensorflow_version: str = "2.7",
tensorflow_version: Optional[str] = None,
use_gpu: bool = False,
display_name: str = "Tensorflow model",
description: Optional[str] = None,
Expand Down Expand Up @@ -3061,7 +3061,7 @@ def upload_tensorflow_saved_model(
container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
region=location,
framework="tensorflow",
framework_version=tensorflow_version,
framework_version=tensorflow_version or "2.7",
accelerator="gpu" if use_gpu else "cpu",
)

Expand Down

0 comments on commit e69b58e

Please sign in to comment.