From 11a449183eabdd6442555226923f725867e89a60 Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra <56443877+Bhavay-2001@users.noreply.github.com> Date: Fri, 5 Apr 2024 15:35:47 +0530 Subject: [PATCH 1/2] Create diffusers.yml --- diffusers.yml | 166 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 diffusers.yml diff --git a/diffusers.yml b/diffusers.yml new file mode 100644 index 000000000000..fc0f93bbf670 --- /dev/null +++ b/diffusers.yml @@ -0,0 +1,166 @@ +name: diffusers +channels: + - defaults +dependencies: + - bzip2=1.0.8=h80987f9_5 + - ca-certificates=2023.12.12=hca03da5_0 + - expat=2.5.0=h313beb8_0 + - libcxx=14.0.6=h848a8c0_0 + - libffi=3.4.4=hca03da5_0 + - ncurses=6.4=h313beb8_0 + - openssl=3.0.13=h1a28f6b_0 + - pip=23.3.1=py312hca03da5_0 + - python=3.12.2=h99e199e_0 + - readline=8.2=h1a28f6b_0 + - setuptools=68.2.2=py312hca03da5_0 + - sqlite=3.41.2=h80987f9_0 + - tk=8.6.12=hb8d0fd4_0 + - wheel=0.41.2=py312hca03da5_0 + - xz=5.4.6=h80987f9_0 + - zlib=1.2.13=h5a0b063_0 + - pip: + - absl-py==2.1.0 + - accelerate==0.27.2 + - aiohttp==3.9.3 + - aiosignal==1.3.1 + - anyio==4.3.0 + - appdirs==1.4.4 + - attrs==23.2.0 + - audioread==3.0.1 + - backoff==2.2.1 + - certifi==2024.2.2 + - cffi==1.16.0 + - charset-normalizer==3.3.2 + - chex==0.1.85 + - clean-fid==0.1.35 + - click==8.1.7 + - clip-anytorch==2.6.0 + - compel==0.1.8 + - datasets==2.18.0 + - dctorch==0.1.2 + - decorator==5.1.1 + - diffusers==0.27.0.dev0 + - dill==0.3.8 + - docker-pycreds==0.4.0 + - einops==0.7.0 + - etils==1.7.0 + - execnet==2.0.2 + - fastjsonschema==2.19.1 + - filelock==3.13.1 + - flax==0.8.1 + - frozenlist==1.4.1 + - fsspec==2024.2.0 + - ftfy==6.1.3 + - gitdb==4.0.11 + - gitpython==3.1.18 + - gql==3.5.0 + - graphql-core==3.2.3 + - grpcio==1.62.0 + - hf-doc-builder==0.4.0 + - huggingface-hub==0.21.3 + - idna==3.6 + - imageio==2.34.0 + - importlib-metadata==7.0.1 + - importlib-resources==6.1.2 + - iniconfig==2.0.0 + - invisible-watermark==0.2.0 + - isort==5.13.2 + - jax==0.4.25 + - jaxlib==0.4.25 + - jinja2==3.1.3 + - joblib==1.3.2 + - jsonmerge==1.9.2 + - jsonschema==4.21.1 + - jsonschema-specifications==2023.12.1 + - jupyter-core==5.7.1 + - k-diffusion==0.1.1.post1 + - kornia==0.7.1 + - lazy-loader==0.3 + - librosa==0.10.1 + - llvmlite==0.42.0 + - markdown==3.5.2 + - markdown-it-py==3.0.0 + - markupsafe==2.1.5 + - mdurl==0.1.2 + - ml-dtypes==0.3.2 + - mpmath==1.3.0 + - msgpack==1.0.8 + - multidict==6.0.5 + - multiprocess==0.70.16 + - nbformat==5.9.2 + - nest-asyncio==1.6.0 + - networkx==3.2.1 + - numba==0.59.0 + - numpy==1.26.4 + - opencv-python==4.9.0.80 + - opt-einsum==3.3.0 + - optax==0.1.9 + - orbax-checkpoint==0.5.3 + - packaging==23.2 + - pandas==2.2.1 + - parameterized==0.9.0 + - peft==0.9.0 + - pillow==10.2.0 + - platformdirs==4.2.0 + - pluggy==1.4.0 + - pooch==1.8.1 + - protobuf==3.20.3 + - psutil==5.9.8 + - pyarrow==15.0.0 + - pyarrow-hotfix==0.6 + - pycparser==2.21 + - pygments==2.17.2 + - pyparsing==3.1.1 + - pytest==8.0.2 + - pytest-timeout==2.2.0 + - pytest-xdist==3.5.0 + - python-dateutil==2.9.0.post0 + - pytz==2024.1 + - pywavelets==1.5.0 + - pyyaml==6.0.1 + - referencing==0.33.0 + - regex==2023.12.25 + - requests==2.31.0 + - requests-mock==1.10.0 + - requests-toolbelt==1.0.0 + - rich==13.7.1 + - rpds-py==0.18.0 + - ruff==0.1.5 + - safetensors==0.4.2 + - scikit-image==0.22.0 + - scikit-learn==1.4.1.post1 + - scipy==1.12.0 + - sentencepiece==0.2.0 + - sentry-sdk==1.40.6 + - setproctitle==1.3.3 + - six==1.16.0 + - smmap==5.0.1 + - sniffio==1.3.1 + - soundfile==0.12.1 + - soxr==0.3.7 + - sympy==1.12 + - tensorboard==2.16.2 + - tensorboard-data-server==0.7.2 + - tensorstore==0.1.54 + - threadpoolctl==3.3.0 + - tifffile==2024.2.12 + - tokenizers==0.15.2 + - toolz==0.12.1 + - torch==2.2.1 + - torchdiffeq==0.2.3 + - torchsde==0.2.6 + - torchvision==0.17.1 + - tqdm==4.66.2 + - traitlets==5.14.1 + - trampoline==0.1.2 + - transformers==4.38.2 + - typing-extensions==4.10.0 + - tzdata==2024.1 + - urllib3==1.26.18 + - wandb==0.16.3 + - wcwidth==0.2.13 + - werkzeug==3.0.1 + - xxhash==3.4.1 + - yarl==1.9.4 + - zipp==3.17.0 +prefix: /Users/shubhammalhotra/Documents/miniconda3/envs/diffusers From d2ff881e282f56415deee14857ceeeb25d379238 Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra Date: Tue, 11 Jun 2024 23:02:08 +0530 Subject: [PATCH 2/2] num_train_epochs --- diffusers.yml | 166 ------------------- examples/controlnet/train_controlnet_sdxl.py | 25 ++- 2 files changed, 18 insertions(+), 173 deletions(-) delete mode 100644 diffusers.yml diff --git a/diffusers.yml b/diffusers.yml deleted file mode 100644 index fc0f93bbf670..000000000000 --- a/diffusers.yml +++ /dev/null @@ -1,166 +0,0 @@ -name: diffusers -channels: - - defaults -dependencies: - - bzip2=1.0.8=h80987f9_5 - - ca-certificates=2023.12.12=hca03da5_0 - - expat=2.5.0=h313beb8_0 - - libcxx=14.0.6=h848a8c0_0 - - libffi=3.4.4=hca03da5_0 - - ncurses=6.4=h313beb8_0 - - openssl=3.0.13=h1a28f6b_0 - - pip=23.3.1=py312hca03da5_0 - - python=3.12.2=h99e199e_0 - - readline=8.2=h1a28f6b_0 - - setuptools=68.2.2=py312hca03da5_0 - - sqlite=3.41.2=h80987f9_0 - - tk=8.6.12=hb8d0fd4_0 - - wheel=0.41.2=py312hca03da5_0 - - xz=5.4.6=h80987f9_0 - - zlib=1.2.13=h5a0b063_0 - - pip: - - absl-py==2.1.0 - - accelerate==0.27.2 - - aiohttp==3.9.3 - - aiosignal==1.3.1 - - anyio==4.3.0 - - appdirs==1.4.4 - - attrs==23.2.0 - - audioread==3.0.1 - - backoff==2.2.1 - - certifi==2024.2.2 - - cffi==1.16.0 - - charset-normalizer==3.3.2 - - chex==0.1.85 - - clean-fid==0.1.35 - - click==8.1.7 - - clip-anytorch==2.6.0 - - compel==0.1.8 - - datasets==2.18.0 - - dctorch==0.1.2 - - decorator==5.1.1 - - diffusers==0.27.0.dev0 - - dill==0.3.8 - - docker-pycreds==0.4.0 - - einops==0.7.0 - - etils==1.7.0 - - execnet==2.0.2 - - fastjsonschema==2.19.1 - - filelock==3.13.1 - - flax==0.8.1 - - frozenlist==1.4.1 - - fsspec==2024.2.0 - - ftfy==6.1.3 - - gitdb==4.0.11 - - gitpython==3.1.18 - - gql==3.5.0 - - graphql-core==3.2.3 - - grpcio==1.62.0 - - hf-doc-builder==0.4.0 - - huggingface-hub==0.21.3 - - idna==3.6 - - imageio==2.34.0 - - importlib-metadata==7.0.1 - - importlib-resources==6.1.2 - - iniconfig==2.0.0 - - invisible-watermark==0.2.0 - - isort==5.13.2 - - jax==0.4.25 - - jaxlib==0.4.25 - - jinja2==3.1.3 - - joblib==1.3.2 - - jsonmerge==1.9.2 - - jsonschema==4.21.1 - - jsonschema-specifications==2023.12.1 - - jupyter-core==5.7.1 - - k-diffusion==0.1.1.post1 - - kornia==0.7.1 - - lazy-loader==0.3 - - librosa==0.10.1 - - llvmlite==0.42.0 - - markdown==3.5.2 - - markdown-it-py==3.0.0 - - markupsafe==2.1.5 - - mdurl==0.1.2 - - ml-dtypes==0.3.2 - - mpmath==1.3.0 - - msgpack==1.0.8 - - multidict==6.0.5 - - multiprocess==0.70.16 - - nbformat==5.9.2 - - nest-asyncio==1.6.0 - - networkx==3.2.1 - - numba==0.59.0 - - numpy==1.26.4 - - opencv-python==4.9.0.80 - - opt-einsum==3.3.0 - - optax==0.1.9 - - orbax-checkpoint==0.5.3 - - packaging==23.2 - - pandas==2.2.1 - - parameterized==0.9.0 - - peft==0.9.0 - - pillow==10.2.0 - - platformdirs==4.2.0 - - pluggy==1.4.0 - - pooch==1.8.1 - - protobuf==3.20.3 - - psutil==5.9.8 - - pyarrow==15.0.0 - - pyarrow-hotfix==0.6 - - pycparser==2.21 - - pygments==2.17.2 - - pyparsing==3.1.1 - - pytest==8.0.2 - - pytest-timeout==2.2.0 - - pytest-xdist==3.5.0 - - python-dateutil==2.9.0.post0 - - pytz==2024.1 - - pywavelets==1.5.0 - - pyyaml==6.0.1 - - referencing==0.33.0 - - regex==2023.12.25 - - requests==2.31.0 - - requests-mock==1.10.0 - - requests-toolbelt==1.0.0 - - rich==13.7.1 - - rpds-py==0.18.0 - - ruff==0.1.5 - - safetensors==0.4.2 - - scikit-image==0.22.0 - - scikit-learn==1.4.1.post1 - - scipy==1.12.0 - - sentencepiece==0.2.0 - - sentry-sdk==1.40.6 - - setproctitle==1.3.3 - - six==1.16.0 - - smmap==5.0.1 - - sniffio==1.3.1 - - soundfile==0.12.1 - - soxr==0.3.7 - - sympy==1.12 - - tensorboard==2.16.2 - - tensorboard-data-server==0.7.2 - - tensorstore==0.1.54 - - threadpoolctl==3.3.0 - - tifffile==2024.2.12 - - tokenizers==0.15.2 - - toolz==0.12.1 - - torch==2.2.1 - - torchdiffeq==0.2.3 - - torchsde==0.2.6 - - torchvision==0.17.1 - - tqdm==4.66.2 - - traitlets==5.14.1 - - trampoline==0.1.2 - - transformers==4.38.2 - - typing-extensions==4.10.0 - - tzdata==2024.1 - - urllib3==1.26.18 - - wandb==0.16.3 - - wcwidth==0.2.13 - - werkzeug==3.0.1 - - xxhash==3.4.1 - - yarl==1.9.4 - - zipp==3.17.0 -prefix: /Users/shubhammalhotra/Documents/miniconda3/envs/diffusers diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 288a1e3fb612..59a2a65326eb 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1088,17 +1088,22 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer ) # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, num_cycles=args.lr_num_cycles, power=args.lr_power, ) @@ -1110,8 +1115,14 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: + if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)