diff --git a/README.md b/README.md index aa7af29..31097e0 100644 --- a/README.md +++ b/README.md @@ -3093,9 +3093,11 @@ Get training jobs. ```python training_job = client.execute_training_job( dataset_name="dataset_name", - base_model_name="fastlabel_object_detection_light", // "fastlabel_object_detection_light" or "fastlabel_object_detection_high_accuracy" + base_model_name="fastlabel_object_detection_light", // "fastlabel_object_detection_light" or "fastlabel_object_detection_high_accuracy" or "fastlabel_u_net_general" epoch=300, - use_dataset_train_val=True + use_dataset_train_val=True, + resize_option="fixed", // optional, "fixed" or "none" + resize_dimension=1024, // optional, 512 or 1024 ) ``` diff --git a/fastlabel/__init__.py b/fastlabel/__init__.py index 995ed36..340ed13 100644 --- a/fastlabel/__init__.py +++ b/fastlabel/__init__.py @@ -6,7 +6,7 @@ import re from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union import aiohttp import cv2 @@ -4248,6 +4248,8 @@ def execute_training_job( instance_type: str = "ml.p3.2xlarge", batch_size: int = None, learning_rate: float = None, + resize_option: Optional[Literal["fixed", "none"]] = None, + resize_dimension: Optional[int] = None, ) -> list: """ Returns a list of training jobs. @@ -4266,6 +4268,8 @@ def execute_training_job( "instanceType": instance_type, "batchSize": batch_size, "learningRate": learning_rate, + "resizeOption": resize_option, + "resizeDimension": resize_dimension, } return self.api.post_request(