From a857655af5beb4273cf1ab2724077f88307f429e Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 12 Feb 2024 16:22:35 +0530 Subject: [PATCH] fixed type annotations --- examples/text_to_image/train_text_to_image.py | 4 ++-- examples/text_to_image/train_text_to_image_lora.py | 4 +++- examples/text_to_image/train_text_to_image_sdxl.py | 12 ++++++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6cc8db6fb215..6fb8b17944eb 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -67,8 +67,8 @@ def save_model_card( args, repo_id: str, - images=None, - repo_folder=None, + images: list = None, + repo_folder: str = None, ): img_str = "" if len(images) > 0: diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 73d0470522fa..47e67f695b08 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -56,7 +56,9 @@ logger = get_logger(__name__, log_level="INFO") -def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): +def save_model_card( + repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None +): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 8a5948350d50..292e52bca0f8 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -66,12 +66,12 @@ def save_model_card( repo_id: str, - images=None, - validation_prompt=None, - base_model=str, - dataset_name=str, - repo_folder=None, - vae_path=None, + images: list = None, + validation_prompt: str = None, + base_model: str = None, + dataset_name: str = None, + repo_folder: str = None, + vae_path: str = None, ): img_str = "" for i, image in enumerate(images):