Skip to content
31 changes: 29 additions & 2 deletions examples/controlnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,15 @@ huggingface-cli login

Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:

```
```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
```

And finally start the training

```
```bash
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
Expand All @@ -363,3 +363,30 @@ python3 train_controlnet_flax.py \
```

Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).

Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:

```bash
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \
--conditioning_image_column=spiga_seg \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 50 \
--max_train_steps 5 \
--learning_rate=1e-5 \
--validation_steps=2 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb"
```

Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:

* [Webdataset](https://webdataset.github.io/webdataset/)
* [TorchData](https://github.com/pytorch/data)
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
57 changes: 44 additions & 13 deletions examples/controlnet/train_controlnet_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torch.utils.data import IterableDataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
Expand Down Expand Up @@ -206,7 +207,7 @@ def parse_args():
parser.add_argument(
"--from_pt",
action="store_true",
help="Load the pretrained model from a pytorch checkpoint.",
help="Load the pretrained model from a PyTorch checkpoint.",
)
parser.add_argument(
"--tokenizer_name",
Expand Down Expand Up @@ -332,6 +333,7 @@ def parse_args():
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
parser.add_argument(
"--dataset_config_name",
type=str,
Expand Down Expand Up @@ -369,7 +371,7 @@ def parse_args():
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"value if set. Needed if `streaming` is set to True."
),
)
parser.add_argument(
Expand Down Expand Up @@ -453,10 +455,15 @@ def parse_args():
" or the same number of `--validation_prompt`s and `--validation_image`s"
)

# This idea comes from
# https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
if args.streaming and args.max_train_samples is None:
raise ValueError("You must specify `max_train_samples` when using dataset streaming.")

return args


def make_train_dataset(args, tokenizer):
def make_train_dataset(args, tokenizer, batch_size=None):
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

Expand All @@ -468,6 +475,7 @@ def make_train_dataset(args, tokenizer):
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
streaming=args.streaming,
)
else:
data_files = {}
Expand All @@ -483,7 +491,10 @@ def make_train_dataset(args, tokenizer):

# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
if isinstance(dataset["train"], IterableDataset):
column_names = next(iter(dataset["train"])).keys()
else:
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
if args.image_column is None:
Expand Down Expand Up @@ -565,9 +576,20 @@ def preprocess_train(examples):

if jax.process_index() == 0:
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
if args.streaming:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
else:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
if args.streaming:
train_dataset = dataset["train"].map(
preprocess_train,
batched=True,
batch_size=batch_size,
remove_columns=list(dataset["train"].features.keys()),
)
else:
train_dataset = dataset["train"].with_transform(preprocess_train)

return train_dataset

Expand Down Expand Up @@ -661,12 +683,12 @@ def main():
raise NotImplementedError("No tokenizer specified!")

# Get the datasets: you can either provide your own training and evaluation files (see below)
train_dataset = make_train_dataset(args, tokenizer)
total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
shuffle=not args.streaming,
collate_fn=collate_fn,
batch_size=total_train_batch_size,
num_workers=args.dataloader_num_workers,
Expand Down Expand Up @@ -897,7 +919,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
vae_params = jax_utils.replicate(vae_params)

# Train!
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.streaming:
dataset_length = args.max_train_samples
else:
dataset_length = len(train_dataloader)
num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)

# Scheduler and math around the number of training steps.
if args.max_train_steps is None:
Expand All @@ -906,7 +932,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
Expand All @@ -916,7 +942,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
wandb.define_metric("*", step_metric="train/step")
wandb.config.update(
{
"num_train_examples": len(train_dataset),
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
"total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(),
Expand All @@ -935,7 +961,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

train_metrics = []

steps_per_epoch = len(train_dataset) // total_train_batch_size
steps_per_epoch = (
args.max_train_samples // total_train_batch_size
if args.streaming
else len(train_dataset) // total_train_batch_size
)
train_step_progress_bar = tqdm(
total=steps_per_epoch,
desc="Training...",
Expand Down Expand Up @@ -980,7 +1010,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

# Create the pipeline using using the trained modules and save it.
if jax.process_index() == 0:
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.validation_prompt is not None:
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)

controlnet.save_pretrained(
args.output_dir,
Expand Down