diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 64b71b4f83ae..e5410a0a9e90 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -275,6 +275,7 @@ def main(): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + data_dir=args.train_data_dir ) else: data_files = {} diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index f0d83d55e9bf..d621858e3d2c 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -768,6 +768,7 @@ def load_model_hook(models, input_dir): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + data_dir=args.train_data_dir ) else: data_files = {}