From a009f1d1fe03fe622b57de5e53cbe283257f91ec Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 25 Mar 2023 09:37:05 +0530 Subject: [PATCH 1/9] improve stable unclip doc. --- .../source/en/api/pipelines/stable_unclip.mdx | 58 +++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_unclip.mdx b/docs/source/en/api/pipelines/stable_unclip.mdx index c8b5d58705ba..372242ae2dff 100644 --- a/docs/source/en/api/pipelines/stable_unclip.mdx +++ b/docs/source/en/api/pipelines/stable_unclip.mdx @@ -42,12 +42,9 @@ Coming soon! ### Text guided Image-to-Image Variation ```python -import requests -import torch -from PIL import Image -from io import BytesIO - from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" @@ -55,12 +52,10 @@ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( pipe = pipe.to("cuda") url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" - -response = requests.get(url) -init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = load_image(url) images = pipe(init_image).images -images[0].save("fantasy_landscape.png") +images[0].save("variation_image.png") ``` Optionally, you can also pass a prompt to `pipe` such as: @@ -69,7 +64,50 @@ Optionally, you can also pass a prompt to `pipe` such as: prompt = "A fantasy landscape, trending on artstation" images = pipe(init_image, prompt=prompt).images -images[0].save("fantasy_landscape.png") +images[0].save("variation_image_two.png") +``` + +### Memory optimization + +If you are short on GPU memory, you can enable smart CPU offloading so that models that are not needed +immediately for a computation can be offloaded to CPU: + +```python +from diffusers import 
StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16" +) +# Offload to CPU. +pipe.enable_model_cpu_offload() + +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0] +``` + +Further memory optimizations are possible by enabling VAE slicing on the pipeline: + +```python +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16" +) +pipe.enable_model_cpu_offload() +pipe.enable_vae_slicing() + +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0] ``` ### StableUnCLIPPipeline From 6550c88856d97e8bfdefb236424237980ad7585a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 13:11:18 +0530 Subject: [PATCH 2/9] feat: add streaming support to controlnet flax training script. 
--- examples/controlnet/train_controlnet_flax.py | 36 ++++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index c6c95170da2d..d20bccff2b0b 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -206,7 +206,7 @@ def parse_args(): parser.add_argument( "--from_pt", action="store_true", - help="Load the pretrained model from a pytorch checkpoint.", + help="Load the pretrained model from a PyTorch checkpoint.", ) parser.add_argument( "--tokenizer_name", @@ -332,6 +332,7 @@ def parse_args(): " or to a folder containing files that 🤗 Datasets can understand." ), ) + parser.add_argument("--streaming", store_action=True, help="To stream a large dataset from Hub.") parser.add_argument( "--dataset_config_name", type=str, @@ -369,7 +370,7 @@ def parse_args(): default=None, help=( "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." + "value if set. Needed if `streaming` is set to True." ), ) parser.add_argument( @@ -453,10 +454,15 @@ def parse_args(): " or the same number of `--validation_prompt`s and `--validation_image`s" ) + # This idea comes from + # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370 + if args.streaming and args.max_train_samples is None: + raise ValueError("You must specify `max_train_samples` when using dataset streaming.") + return args -def make_train_dataset(args, tokenizer): +def make_train_dataset(args, tokenizer, batch_size=None): # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 
@@ -468,6 +474,7 @@ def make_train_dataset(args, tokenizer): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + streaming=args.streaming, ) else: data_files = {} @@ -565,9 +572,20 @@ def preprocess_train(examples): if jax.process_index() == 0: if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + if args.streaming: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples) + else: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) # Set the training transforms - train_dataset = dataset["train"].with_transform(preprocess_train) + if args.streaming: + train_dataset = dataset["train"].map( + preprocess_train, + batched=True, + batch_size=batch_size, + remove_columns=list(dataset["train"].features.keys()), + ) + else: + train_dataset = dataset["train"].with_transform(preprocess_train) return train_dataset @@ -661,8 +679,8 @@ def main(): raise NotImplementedError("No tokenizer specified!") # Get the datasets: you can either provide your own training and evaluation files (see below) - train_dataset = make_train_dataset(args, tokenizer) total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps + train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size) train_dataloader = torch.utils.data.DataLoader( train_dataset, @@ -897,7 +915,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng): vae_params = jax_utils.replicate(vae_params) # Train! - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.streaming: + dataset_length = args.max_train_samples + else: + dataset_length = len(train_dataloader) + num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps) # Scheduler and math around the number of training steps. 
if args.max_train_steps is None: From def534a58309531cbee954cb418675171b7cb7fd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 13:19:38 +0530 Subject: [PATCH 3/9] fix: CLI arg. --- examples/controlnet/train_controlnet_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index d20bccff2b0b..32ac65d3eff9 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -332,7 +332,7 @@ def parse_args(): " or to a folder containing files that 🤗 Datasets can understand." ), ) - parser.add_argument("--streaming", store_action=True, help="To stream a large dataset from Hub.") + parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.") parser.add_argument( "--dataset_config_name", type=str, From 4d0ab45566309d1e783ef4eb41b4331cee7d7cbd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 13:21:25 +0530 Subject: [PATCH 4/9] fix: torch dataloader shuffle setting. --- examples/controlnet/train_controlnet_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 32ac65d3eff9..3a34b9e3a1cb 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -684,7 +684,7 @@ def main(): train_dataloader = torch.utils.data.DataLoader( train_dataset, - shuffle=True, + shuffle=not args.streaming, collate_fn=collate_fn, batch_size=total_train_batch_size, num_workers=args.dataloader_num_workers, From fc8e1ee0ffbaf10c8fca0336d6e18953b6b1632a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 13:24:51 +0530 Subject: [PATCH 5/9] fix: dataset length. 
--- examples/controlnet/train_controlnet_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 3a34b9e3a1cb..fbde72d994b4 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -928,7 +928,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") From 6c82633ed8ebb3ffdb884ebdbc6df1d3ff8405d9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 13:35:16 +0530 Subject: [PATCH 6/9] fix: wandb config. 
--- examples/controlnet/train_controlnet_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index fbde72d994b4..791c2e43e823 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -938,7 +938,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): wandb.define_metric("*", step_metric="train/step") wandb.config.update( { - "num_train_examples": len(train_dataset), + "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset), "total_train_batch_size": total_train_batch_size, "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, "num_devices": jax.device_count(), From 297f2b43b959d8d51d82cd9767531b2dad20d1e5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 13:38:45 +0530 Subject: [PATCH 7/9] fix: steps_per_epoch in the training loop. --- examples/controlnet/train_controlnet_flax.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 791c2e43e823..ab5a32b7c5c9 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -957,7 +957,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng): train_metrics = [] - steps_per_epoch = len(train_dataset) // total_train_batch_size + steps_per_epoch = ( + args.max_train_samples // total_train_batch_size + if args.streaming + else len(train_dataset) // total_train_batch_size + ) train_step_progress_bar = tqdm( total=steps_per_epoch, desc="Training...", From acc950907b243d37fb94490562ce7e9c7ef432c0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Mar 2023 14:05:05 +0530 Subject: [PATCH 8/9] add: entry about streaming in the readme --- examples/controlnet/README.md | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 
2 deletions(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 0650c2230b71..4e6856560bde 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -335,7 +335,7 @@ huggingface-cli login Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub: -``` +```bash export MODEL_DIR="runwayml/stable-diffusion-v1-5" export OUTPUT_DIR="control_out" export HUB_MODEL_ID="fill-circle-controlnet" @@ -343,7 +343,7 @@ export HUB_MODEL_ID="fill-circle-controlnet" And finally start the training -``` +```bash python3 train_controlnet_flax.py \ --pretrained_model_name_or_path=$MODEL_DIR \ --output_dir=$OUTPUT_DIR \ @@ -363,3 +363,30 @@ python3 train_controlnet_flax.py \ ``` Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet). + +Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. 
Here is an example command: + +```bash +python3 train_controlnet_flax.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=multimodalart/facesyntheticsspigacaptioned \ + --streaming \ + --conditioning_image_column=spiga_seg \ + --image_column=image \ + --caption_column=image_caption \ + --resolution=512 \ + --max_train_samples 50 \ + --max_train_steps 5 \ + --learning_rate=1e-5 \ + --validation_steps=2 \ + --train_batch_size=1 \ + --revision="flax" \ + --report_to="wandb" +``` + +Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options: + +* [WebDataset](https://webdataset.github.io/webdataset/) +* [TorchData](https://github.com/pytorch/data) +* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) \ No newline at end of file From 142c9ffe4f8f2ee49deeb9361e9dd52bb328d025 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 28 Mar 2023 19:37:16 +0000 Subject: [PATCH 9/9] get column names from iterable dataset + fix final logging --- examples/controlnet/train_controlnet_flax.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index ab5a32b7c5c9..f409a539667c 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -35,6 +35,7 @@ from flax.training.common_utils import shard from huggingface_hub import HfFolder, Repository, create_repo, whoami from PIL import Image +from torch.utils.data import IterableDataset from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed @@ -490,7 +491,10 @@ def make_train_dataset(args, tokenizer, batch_size=None): # Preprocessing the datasets. # We need to tokenize inputs and targets. 
- column_names = dataset["train"].column_names + if isinstance(dataset["train"], IterableDataset): + column_names = next(iter(dataset["train"])).keys() + else: + column_names = dataset["train"].column_names # 6. Get the column names for input/target. if args.image_column is None: @@ -1006,7 +1010,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng): # Create the pipeline using using the trained modules and save it. if jax.process_index() == 0: - image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) + if args.validation_prompt is not None: + image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) controlnet.save_pretrained( args.output_dir,