From 31a595358dc1ceeb39c3c476c2424d013a402bee Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Fri, 31 Mar 2023 15:38:38 +0000
Subject: [PATCH] use custom local dataset

---
 examples/controlnet/train_controlnet.py      | 13 +++++--------
 examples/controlnet/train_controlnet_flax.py | 13 +++++--------
 2 files changed, 10 insertions(+), 16 deletions(-)

diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 6c14e8ca10db..33250f3aa244 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -543,16 +543,13 @@ def make_train_dataset(args, tokenizer, accelerator):
             cache_dir=args.cache_dir,
         )
     else:
-        data_files = {}
         if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-        dataset = load_dataset(
-            "imagefolder",
-            data_files=data_files,
-            cache_dir=args.cache_dir,
-        )
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index f409a539667c..8a1e7fd4031a 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -478,16 +478,13 @@ def make_train_dataset(args, tokenizer, batch_size=None):
             streaming=args.streaming,
         )
     else:
-        data_files = {}
         if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-        dataset = load_dataset(
-            "imagefolder",
-            data_files=data_files,
-            cache_dir=args.cache_dir,
-        )
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
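
With this change, --train_data_dir is passed straight to load_dataset(),
so it is expected to point at a local directory containing a Hugging Face
dataset loading script rather than a plain image folder handled by the
generic "imagefolder" builder. Below is a minimal sketch of the resulting
call; the directory name "./fill50k" and its script are hypothetical
illustrations, not part of the patch, while the image/conditioning_image/text
column names match the defaults the training scripts read.

from datasets import load_dataset

# Hypothetical --train_data_dir value. With datasets v2.x, a directory that
# contains a loading script (e.g. fill50k/fill50k.py) makes load_dataset()
# execute that script instead of the generic "imagefolder" builder.
dataset = load_dataset(
    "./fill50k",
    cache_dir=None,  # the patched scripts pass args.cache_dir here
)

# The ControlNet training scripts then consume these columns from the
# "train" split (default column names from the scripts' CLI arguments):
example = dataset["train"][0]
print(example["image"], example["conditioning_image"], example["text"])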