diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 4d921900c059..0ccde2fc5e46 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -454,6 +454,7 @@ def __init__( tokenizer, class_data_root=None, class_prompt=None, + class_num=None, size=512, center_crop=False, ): @@ -474,7 +475,10 @@ def __init__( self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: @@ -814,6 +818,7 @@ def load_model_hook(models, input_dir): instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, + class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 9dcd20939c45..46edd5399e88 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -231,6 +231,7 @@ def __init__( tokenizer, class_data_root=None, class_prompt=None, + class_num=None, size=512, center_crop=False, ): @@ -251,7 +252,10 @@ def __init__( self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: @@ -419,6 +423,7 @@ def main(): instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, + class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index c932198232d3..92d08b64b638 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -417,6 +417,7 @@ def __init__( tokenizer, class_data_root=None, class_prompt=None, + class_num=None, size=512, center_crop=False, ): @@ -437,7 +438,10 @@ def __init__( self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: @@ -771,6 +775,7 @@ def main(args): instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, + class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop,