From 2c899e9ee55e1241955240c7b9a1336bcd7b5bea Mon Sep 17 00:00:00 2001 From: slin000111 Date: Tue, 9 Jan 2024 20:48:49 +0800 Subject: [PATCH 1/5] add examples aigc sdxl --- .../pytorch/sdxl/infer_text_image_lora.py | 15 + examples/pytorch/sdxl/infer_text_to_image.py | 14 + .../sdxl/infer_text_to_image_lora_sdxl.py | 15 + .../pytorch/sdxl/infer_text_to_image_sdxl.py | 10 + examples/pytorch/sdxl/requirements.txt | 7 + examples/pytorch/sdxl/requirements_sdxl.txt | 7 + .../sdxl/scripts/run_train_text_to_image.sh | 17 + .../scripts/run_train_text_to_image_lora.sh | 17 + .../run_train_text_to_image_lora_sdxl.sh | 19 + .../scripts/run_train_text_to_image_sdxl.sh | 24 + examples/pytorch/sdxl/train_text_to_image.py | 6 + .../pytorch/sdxl/train_text_to_image_lora.py | 6 + .../sdxl/train_text_to_image_lora_sdxl.py | 6 + .../pytorch/sdxl/train_text_to_image_sdxl.py | 6 + swift/aigc/__init__.py | 4 + swift/aigc/diffusers/__init__.py | 4 + swift/aigc/diffusers/train_text_to_image.py | 1227 ++++++++++++++ .../diffusers/train_text_to_image_lora.py | 1099 +++++++++++++ .../train_text_to_image_lora_sdxl.py | 1380 ++++++++++++++++ .../diffusers/train_text_to_image_sdxl.py | 1463 +++++++++++++++++ 20 files changed, 5346 insertions(+) create mode 100644 examples/pytorch/sdxl/infer_text_image_lora.py create mode 100644 examples/pytorch/sdxl/infer_text_to_image.py create mode 100644 examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py create mode 100644 examples/pytorch/sdxl/infer_text_to_image_sdxl.py create mode 100644 examples/pytorch/sdxl/requirements.txt create mode 100644 examples/pytorch/sdxl/requirements_sdxl.txt create mode 100644 examples/pytorch/sdxl/scripts/run_train_text_to_image.sh create mode 100644 examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh create mode 100644 examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh create mode 100644 examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh create mode 100644 examples/pytorch/sdxl/train_text_to_image.py create mode 100644 examples/pytorch/sdxl/train_text_to_image_lora.py create mode 100644 examples/pytorch/sdxl/train_text_to_image_lora_sdxl.py create mode 100644 examples/pytorch/sdxl/train_text_to_image_sdxl.py create mode 100644 swift/aigc/diffusers/__init__.py create mode 100644 swift/aigc/diffusers/train_text_to_image.py create mode 100644 swift/aigc/diffusers/train_text_to_image_lora.py create mode 100644 swift/aigc/diffusers/train_text_to_image_lora_sdxl.py create mode 100644 swift/aigc/diffusers/train_text_to_image_sdxl.py diff --git a/examples/pytorch/sdxl/infer_text_image_lora.py b/examples/pytorch/sdxl/infer_text_image_lora.py new file mode 100644 index 0000000000..248d47525d --- /dev/null +++ b/examples/pytorch/sdxl/infer_text_image_lora.py @@ -0,0 +1,15 @@ +from diffusers import StableDiffusionPipeline +import torch +from swift import Swift +from modelscope import snapshot_download + + +model_path = snapshot_download("AI-ModelScope/stable-diffusion-v1-5") +lora_model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image_lora" +pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +pipe.unet = Swift.from_pretrained(pipe.unet, lora_model_path) +pipe.to("cuda") + +prompt = "A pokemon with green eyes and red legs." 
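+# Run the LoRA-adapted pipeline: 30 denoising steps with classifier-free guidance scale 7.5.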
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("sw_sd_lora_pokemon.png")
diff --git a/examples/pytorch/sdxl/infer_text_to_image.py b/examples/pytorch/sdxl/infer_text_to_image.py
new file mode 100644
index 0000000000..c0d7994615
--- /dev/null
+++ b/examples/pytorch/sdxl/infer_text_to_image.py
@@ -0,0 +1,14 @@
+import torch
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+from modelscope import snapshot_download
+
+model_path = snapshot_download("AI-ModelScope/stable-diffusion-v1-5")
+
+unet_model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image/unet"
+unet = UNet2DConditionModel.from_pretrained(unet_model_path, torch_dtype=torch.float16)
+
+pipe = StableDiffusionPipeline.from_pretrained(model_path, unet=unet, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("sw_yoda-pokemon.png")
diff --git a/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py b/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py
new file mode 100644
index 0000000000..f2a1e15d9f
--- /dev/null
+++ b/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py
@@ -0,0 +1,15 @@
+from diffusers import StableDiffusionXLPipeline
+import torch
+from swift import Swift
+import os
+from modelscope import snapshot_download
+
+model_path = snapshot_download("AI-ModelScope/stable-diffusion-xl-base-1.0")
+lora_model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image_lora_sdxl"
+
+pipe = StableDiffusionXLPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.unet = Swift.from_pretrained(pipe.unet, os.path.join(lora_model_path, 'unet'))
+prompt = "A pokemon with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("sw_sdxl_lora_pokemon.png")
diff --git a/examples/pytorch/sdxl/infer_text_to_image_sdxl.py b/examples/pytorch/sdxl/infer_text_to_image_sdxl.py
new file mode 100644
index 0000000000..1d7cfa8305
--- /dev/null
+++ b/examples/pytorch/sdxl/infer_text_to_image_sdxl.py
@@ -0,0 +1,10 @@
+from diffusers import DiffusionPipeline
+import torch
+
+model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image_sdxl"
+pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+prompt = "A pokemon with green eyes and red legs."
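+# Sample an image from the fine-tuned SDXL pipeline: 30 denoising steps, guidance scale 7.5.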
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] +image.save("sdxl_pokemon.png") diff --git a/examples/pytorch/sdxl/requirements.txt b/examples/pytorch/sdxl/requirements.txt new file mode 100644 index 0000000000..e92fe73913 --- /dev/null +++ b/examples/pytorch/sdxl/requirements.txt @@ -0,0 +1,7 @@ +accelerate>=0.16.0 +datasets +ftfy +Jinja2 +tensorboard +torchvision +transformers>=4.25.1 diff --git a/examples/pytorch/sdxl/requirements_sdxl.txt b/examples/pytorch/sdxl/requirements_sdxl.txt new file mode 100644 index 0000000000..6e50452b96 --- /dev/null +++ b/examples/pytorch/sdxl/requirements_sdxl.txt @@ -0,0 +1,7 @@ +accelerate>=0.22.0 +datasets +ftfy +Jinja2 +tensorboard +torchvision +transformers>=4.25.1 diff --git a/examples/pytorch/sdxl/scripts/run_train_text_to_image.sh b/examples/pytorch/sdxl/scripts/run_train_text_to_image.sh new file mode 100644 index 0000000000..4d0543f7d9 --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_train_text_to_image.sh @@ -0,0 +1,17 @@ +PYTHONPATH=../../../ \ +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ + --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \ + --dataset_name="AI-ModelScope/pokemon-blip-captions" \ + --use_ema \ + --resolution=512 \ + --center_crop \ + --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --output_dir="train_text_to_image" \ diff --git a/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh b/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh new file mode 100644 index 0000000000..8e163b1cca --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh @@ -0,0 +1,17 @@ +PYTHONPATH=../../../ \ +accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ + --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \ + --dataset_name="AI-ModelScope/pokemon-blip-captions" \ + --caption_column="text" \ + --resolution=512 \ + --random_flip \ + --train_batch_size=1 \ + --num_train_epochs=100 \ + --checkpointing_steps=5000 \ + --learning_rate=1e-04 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="train_text_to_image_lora" \ + --validation_prompt="cute dragon creature" \ + --report_to="tensorboard" \ diff --git a/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh b/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh new file mode 100644 index 0000000000..ac837f77d5 --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh @@ -0,0 +1,19 @@ +PYTHONPATH=../../../ \ +accelerate launch train_text_to_image_lora_sdxl.py \ + --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \ + --pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \ + --dataset_name="AI-ModelScope/pokemon-blip-captions" \ + --caption_column="text" \ + --resolution=1024 \ + --random_flip \ + --train_batch_size=1 \ + --num_train_epochs=2 \ + --checkpointing_steps=500 \ + --learning_rate=1e-04 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="fp16" \ + --seed=42 \ + --output_dir="train_text_to_image_lora_sdxl" \ + --validation_prompt="cute dragon creature" \ + --report_to="tensorboard" \ diff --git a/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh 
b/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh
new file mode 100644
index 0000000000..2126961c4c
--- /dev/null
+++ b/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh
@@ -0,0 +1,24 @@
+PYTHONPATH=../../../ \
+accelerate launch train_text_to_image_sdxl.py \
+  --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
+  --pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
+  --dataset_name="AI-ModelScope/pokemon-blip-captions" \
+  --enable_xformers_memory_efficient_attention \
+  --resolution=512 \
+  --center_crop \
+  --random_flip \
+  --proportion_empty_prompts=0.2 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --max_train_steps=10000 \
+  --use_8bit_adam \
+  --learning_rate=1e-06 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --mixed_precision="fp16" \
+  --report_to="tensorboard" \
+  --validation_prompt="a cute Sundar Pichai creature" \
+  --validation_epochs 5 \
+  --checkpointing_steps=5000 \
+  --output_dir="train_text_to_image_sdxl" \
diff --git a/examples/pytorch/sdxl/train_text_to_image.py b/examples/pytorch/sdxl/train_text_to_image.py
new file mode 100644
index 0000000000..b94e9fe421
--- /dev/null
+++ b/examples/pytorch/sdxl/train_text_to_image.py
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from swift.aigc import train_text_to_image
+
+if __name__ == '__main__':
+    train_text_to_image()
diff --git a/examples/pytorch/sdxl/train_text_to_image_lora.py b/examples/pytorch/sdxl/train_text_to_image_lora.py
new file mode 100644
index 0000000000..d2ed51a054
--- /dev/null
+++ b/examples/pytorch/sdxl/train_text_to_image_lora.py
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from swift.aigc import train_text_to_image_lora
+
+if __name__ == '__main__':
+    train_text_to_image_lora()
diff --git a/examples/pytorch/sdxl/train_text_to_image_lora_sdxl.py b/examples/pytorch/sdxl/train_text_to_image_lora_sdxl.py
new file mode 100644
index 0000000000..90b0d6f2b5
--- /dev/null
+++ b/examples/pytorch/sdxl/train_text_to_image_lora_sdxl.py
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from swift.aigc import train_text_to_image_lora_sdxl
+
+if __name__ == '__main__':
+    train_text_to_image_lora_sdxl()
diff --git a/examples/pytorch/sdxl/train_text_to_image_sdxl.py b/examples/pytorch/sdxl/train_text_to_image_sdxl.py
new file mode 100644
index 0000000000..467b3bc7ef
--- /dev/null
+++ b/examples/pytorch/sdxl/train_text_to_image_sdxl.py
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
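+# Thin CLI wrapper: delegates to the SDXL text-to-image fine-tuning entrypoint in swift.aigc.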
+ +from swift.aigc import train_text_to_image_sdxl + +if __name__ == '__main__': + train_text_to_image_sdxl() diff --git a/swift/aigc/__init__.py b/swift/aigc/__init__.py index 37c04da323..d5a0439e8b 100644 --- a/swift/aigc/__init__.py +++ b/swift/aigc/__init__.py @@ -7,11 +7,15 @@ # Recommend using `xxx_main` from .animatediff import animatediff_sft, animatediff_main from .animatediff_infer import animatediff_infer, animatediff_infer_main + from .diffusers import train_text_to_image, train_text_to_image_lora, train_text_to_image_lora_sdxl, \ + train_text_to_image_sdxl from .utils import AnimateDiffArguments, AnimateDiffInferArguments else: _import_structure = { 'animatediff': ['animatediff_sft', 'animatediff_main'], 'animatediff_infer': ['animatediff_infer', 'animatediff_infer_main'], + 'diffusers': ['train_text_to_image', 'train_text_to_image_lora', 'train_text_to_image_lora_sdxl', + 'train_text_to_image_sdxl'], 'utils': ['AnimateDiffArguments', 'AnimateDiffInferArguments'], } diff --git a/swift/aigc/diffusers/__init__.py b/swift/aigc/diffusers/__init__.py new file mode 100644 index 0000000000..09de856eab --- /dev/null +++ b/swift/aigc/diffusers/__init__.py @@ -0,0 +1,4 @@ +from .train_text_to_image_sdxl import main as train_text_to_image_sdxl +from .train_text_to_image_lora_sdxl import main as train_text_to_image_lora_sdxl +from .train_text_to_image import main as train_text_to_image +from .train_text_to_image_lora import main as train_text_to_image_lora \ No newline at end of file diff --git a/swift/aigc/diffusers/train_text_to_image.py b/swift/aigc/diffusers/train_text_to_image.py new file mode 100644 index 0000000000..c5cc7a8367 --- /dev/null +++ b/swift/aigc/diffusers/train_text_to_image.py @@ -0,0 +1,1227 @@ +#!/usr/bin/env python +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from diffusers import (AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, + UNet2DConditionModel) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import (check_min_version, deprecate, is_wandb_available, + make_image_grid) +from diffusers.utils.import_utils import is_xformers_available +from modelscope import MsDataset +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers + +from swift import push_to_hub, snapshot_download + +if is_wandb_available(): + import wandb + +logger = get_logger(__name__, log_level='INFO') + +DATASET_NAME_MAPPING = { + 'AI-ModelScope/pokemon-blip-captions': ('text', 'image:FILE'), +} + + +def save_model_card( + args, + repo_id: str, + images=None, + repo_folder=None, +): + img_str = '' + if len(images) > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) + image_grid.save(os.path.join(repo_folder, 'val_imgs_grid.png')) + img_str += '![val_imgs_grid](./val_imgs_grid.png)\n' + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {args.pretrained_model_name_or_path} +datasets: +- {args.dataset_name} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below +are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) +prompt = "{args.validation_prompts[0]}" +image = pipeline(prompt).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = '' + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). 
+""" + + model_card += wandb_info + + with open(os.path.join(repo_folder, 'README.md'), 'w') as f: + f.write(yaml + model_card) + + +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, + weight_dtype, epoch): + logger.info('Running validation... ') + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=accelerator.unwrap_model(vae), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed( + args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast('cuda'): + image = pipeline( + args.validation_prompts[i], + num_inference_steps=20, + generator=generator).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == 'tensorboard': + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + 'validation', np_images, epoch, dataformats='NHWC') + elif tracker.name == 'wandb': + tracker.log({ + 'validation': [ + wandb.Image( + image, caption=f'{i}: {args.validation_prompts[i]}') + for i, image in enumerate(images) + ] + }) + else: + logger.warn(f'image logging not implemented for {tracker.name}') + + del pipeline + torch.cuda.empty_cache() + + return images + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Simple example of a training script.') + parser.add_argument( + '--input_perturbation', + type=float, + default=0, + help='The scale of input perturbation. Recommended 0.1.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default=None, + required=True, + help= + 'Path to pretrained model or model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--variant', + type=str, + default=None, + help= + "Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + '--dataset_name', + type=str, + default=None, + help= + ('The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,' + ' dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,' + ' or to a folder containing files that 🤗 Datasets can understand.'), + ) + parser.add_argument( + '--dataset_config_name', + type=str, + default=None, + help= + "The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + '--train_data_dir', + type=str, + default=None, + help= + ('A folder containing the training data. Folder contents must follow the structure described in' + ' https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file' + ' must exist to provide the captions for the images. Ignored if `dataset_name` is specified.' 
+ ), + ) + parser.add_argument( + '--image_column', + type=str, + default='image:FILE', + help='The column of the dataset containing an image.') + parser.add_argument( + '--caption_column', + type=str, + default='text', + help= + 'The column of the dataset containing a caption or a list of captions.', + ) + parser.add_argument( + '--max_train_samples', + type=int, + default=None, + help= + ('For debugging purposes or quicker training, truncate the number of training examples to this ' + 'value if set.'), + ) + parser.add_argument( + '--validation_prompts', + type=str, + default=None, + nargs='+', + help= + ('A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.' + ), + ) + parser.add_argument( + '--output_dir', + type=str, + default='sd-model-finetuned', + help= + 'The output directory where the model predictions and checkpoints will be written.', + ) + parser.add_argument( + '--cache_dir', + type=str, + default=None, + help= + 'The directory where the downloaded models and datasets will be stored.', + ) + parser.add_argument( + '--seed', + type=int, + default=None, + help='A seed for reproducible training.') + parser.add_argument( + '--resolution', + type=int, + default=512, + help= + ('The resolution for input images, all the images in the train/validation dataset will be resized to this' + ' resolution'), + ) + parser.add_argument( + '--center_crop', + default=False, + action='store_true', + help= + ('Whether to center crop the input images to the resolution. If not set, the images will be randomly' + ' cropped. The images will be resized to the resolution first before cropping.' + ), + ) + parser.add_argument( + '--random_flip', + action='store_true', + help='whether to randomly flip images horizontally', + ) + parser.add_argument( + '--train_batch_size', + type=int, + default=16, + help='Batch size (per device) for the training dataloader.') + parser.add_argument('--num_train_epochs', type=int, default=100) + parser.add_argument( + '--max_train_steps', + type=int, + default=None, + help= + 'Total number of training steps to perform. If provided, overrides num_train_epochs.', + ) + parser.add_argument( + '--gradient_accumulation_steps', + type=int, + default=1, + help= + 'Number of updates steps to accumulate before performing a backward/update pass.', + ) + parser.add_argument( + '--gradient_checkpointing', + action='store_true', + help= + 'Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.', + ) + parser.add_argument( + '--learning_rate', + type=float, + default=1e-4, + help= + 'Initial learning rate (after the potential warmup period) to use.', + ) + parser.add_argument( + '--scale_lr', + action='store_true', + default=False, + help= + 'Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.', + ) + parser.add_argument( + '--lr_scheduler', + type=str, + default='constant', + help= + ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument( + '--lr_warmup_steps', + type=int, + default=500, + help='Number of steps for the warmup in the lr scheduler.') + parser.add_argument( + '--snr_gamma', + type=float, + default=None, + help= + 'SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. 
' + 'More details here: https://arxiv.org/abs/2303.09556.', + ) + parser.add_argument( + '--use_8bit_adam', + action='store_true', + help='Whether or not to use 8-bit Adam from bitsandbytes.') + parser.add_argument( + '--allow_tf32', + action='store_true', + help= + ('Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see' + ' https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices' + ), + ) + parser.add_argument( + '--use_ema', action='store_true', help='Whether to use EMA model.') + parser.add_argument( + '--non_ema_revision', + type=str, + default=None, + required=False, + help= + ('Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or' + ' remote repository specified with --pretrained_model_name_or_path.'), + ) + parser.add_argument( + '--dataloader_num_workers', + type=int, + default=0, + help= + ('Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.' + ), + ) + parser.add_argument( + '--adam_beta1', + type=float, + default=0.9, + help='The beta1 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_beta2', + type=float, + default=0.999, + help='The beta2 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_weight_decay', + type=float, + default=1e-2, + help='Weight decay to use.') + parser.add_argument( + '--adam_epsilon', + type=float, + default=1e-08, + help='Epsilon value for the Adam optimizer') + parser.add_argument( + '--max_grad_norm', default=1.0, type=float, help='Max gradient norm.') + parser.add_argument( + '--push_to_hub', + action='store_true', + help='Whether or not to push the model to the Hub.') + parser.add_argument( + '--hub_token', + type=str, + default=None, + help='The token to use to push to the Model Hub.') + parser.add_argument( + '--prediction_type', + type=str, + default=None, + help= + "The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave \ + `None`. If left to `None` the default prediction type of the scheduler: \ + `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + '--hub_model_id', + type=str, + default=None, + help= + 'The name of the repository to keep in sync with the local `output_dir`.', + ) + parser.add_argument( + '--logging_dir', + type=str, + default='logs', + help= + ('[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to' + ' *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.'), + ) + parser.add_argument( + '--mixed_precision', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the' + ' flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.' + ), + ) + parser.add_argument( + '--report_to', + type=str, + default='tensorboard', + help= + ('The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument( + '--local_rank', + type=int, + default=-1, + help='For distributed training: local_rank') + parser.add_argument( + '--checkpointing_steps', + type=int, + default=500, + help= + ('Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming' + ' training using `--resume_from_checkpoint`.'), + ) + parser.add_argument( + '--checkpoints_total_limit', + type=int, + default=None, + help=('Max number of checkpoints to store.'), + ) + parser.add_argument( + '--resume_from_checkpoint', + type=str, + default=None, + help= + ('Whether training should be resumed from a previous checkpoint. Use a path saved by' + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + '--enable_xformers_memory_efficient_attention', + action='store_true', + help='Whether or not to use xformers.') + parser.add_argument( + '--noise_offset', + type=float, + default=0, + help='The scale of noise offset.') + parser.add_argument( + '--validation_epochs', + type=int, + default=5, + help='Run validation every X epochs.', + ) + parser.add_argument( + '--tracker_project_name', + type=str, + default='text2image-fine-tune', + help= + ('The `project_name` argument passed to Accelerator.init_trackers for' + ' more information see ' + 'https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator' + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get('LOCAL_RANK', -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError('Need either a dataset name or a training folder.') + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + args.base_model_id = args.pretrained_model_name_or_path + if not os.path.exists(args.pretrained_model_name_or_path): + args.pretrained_model_name_or_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + return args + + +def main(): + args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + 'non_ema_revision!=None', + '0.15.0', + message= + ("Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + ' use `--variant=non_ema` instead.'), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='tokenizer', + revision=args.revision) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState( + ).deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='text_encoder', + revision=args.revision, + variant=args.variant) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='vae', + revision=args.revision, + variant=args.variant) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='unet', + revision=args.non_ema_revision) + + # Freeze vae and text_encoder and set unet to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.train() + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='unet', + revision=args.revision, + variant=args.variant) + ema_unet = EMAModel( + ema_unet.parameters(), + model_cls=UNet2DConditionModel, + model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse('0.0.16'): + logger.warn( + 'xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training,' + ' please update xFormers to at least 0.0.17. 
See ' + 'https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.' + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + 'xformers is not available. Make sure it is installed correctly' + ) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse('0.16.0'): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained( + os.path.join(output_dir, 'unet_ema')) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, 'unet')) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained( + os.path.join(input_dir, 'unet_ema'), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained( + input_dir, subfolder='unet') + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps + * args.train_batch_size * accelerator.num_processes) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + 'Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`' + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + def path_to_img(example): + example['image'] = Image.open(example['image:FILE']) + return example + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
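+        # MsDataset.load pulls the dataset from ModelScope; it is converted to Hugging Face datasets format below.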
+ dataset = MsDataset.load( + args.dataset_name, + args.dataset_config_name, + data_dir=args.train_data_dir, + ) + if isinstance(dataset, dict): + dataset = { + key: value.to_hf_dataset() + for key, value in dataset.items() + } + else: + dataset = {'train': dataset.to_hf_dataset()} + else: + data_files = {} + if args.train_data_dir is not None: + data_files['train'] = os.path.join(args.train_data_dir, '**') + dataset = load_dataset( + 'imagefolder', + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset['train'].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[ + 1] if dataset_columns is not None else column_names[1] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[ + 0] if dataset_columns is not None else column_names[0] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + if image_column.endswith(':FILE'): + dataset['train'] = dataset['train'].map(path_to_img) + image_column = 'image' + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append( + random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f'Caption column `{caption_column}` should contain either strings or lists of strings.' + ) + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding='max_length', + truncation=True, + return_tensors='pt') + return inputs.input_ids + + # Preprocessing the datasets. 
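+    # Image transforms: resize to args.resolution, center or random crop, optional horizontal flip, normalize to [-1, 1].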
+ train_transforms = transforms.Compose([ + transforms.Resize( + args.resolution, + interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) + if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() + if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def preprocess_train(examples): + images = [image.convert('RGB') for image in examples[image_column]] + examples['pixel_values'] = [ + train_transforms(image) for image in images + ] + examples['input_ids'] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset['train'] = dataset['train'].shuffle(seed=args.seed).select( + range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset['train'].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack( + [example['pixel_values'] for example in examples]) + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example['input_ids'] for example in examples]) + return {'pixel_values': pixel_values, 'input_ids': input_ids} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) + # to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == 'fp16': + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == 'bf16': + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps + / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. 
+ # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop('validation_prompts') + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info('***** Running training *****') + logger.info(f' Num examples = {len(train_dataset)}') + logger.info(f' Num Epochs = {args.num_train_epochs}') + logger.info( + f' Instantaneous batch size per device = {args.train_batch_size}') + logger.info( + f' Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}' + ) + logger.info( + f' Gradient Accumulation steps = {args.gradient_accumulation_steps}') + logger.info(f' Total optimization steps = {args.max_train_steps}') + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != 'latest': + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith('checkpoint')] + dirs = sorted(dirs, key=lambda x: int(x.split('-')[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f'Resuming from checkpoint {path}') + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split('-')[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc='Steps', + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch['pixel_values'].to( + weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), + device=latents.device) + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like( + noise) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, (bsz, ), + device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.input_perturbation: + noisy_latents = noise_scheduler.add_noise( + latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise( + latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch['input_ids'])[0] + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config( + prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == 'epsilon': + target = noise + elif noise_scheduler.config.prediction_type == 'v_prediction': + target = noise_scheduler.get_velocity( + latents, noise, timesteps) + else: + raise ValueError( + f'Unknown prediction type {noise_scheduler.config.prediction_type}' + ) + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, + encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='mean') + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == 'v_prediction': + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], + dim=1).min(dim=1)[0] / snr) + + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='none') + loss = loss.mean( + dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather( + loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item( + ) / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), + args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({'train_loss': train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints + if d.startswith('checkpoint') + ] + checkpoints = sorted( + checkpoints, + key=lambda x: int(x.split('-')[1])) + + # before we save the new checkpoint, we need to have at _most_ \ + # `checkpoints_total_limit - 1` checkpoints + if len(checkpoints + ) >= args.checkpoints_total_limit: + num_to_remove = len( + checkpoints + ) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[ + 0:num_to_remove] + + logger.info( + f'{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)}\ + checkpoints') + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, + f'checkpoint-{global_step}') + accelerator.save_state(save_path) + logger.info(f'Saved state to {save_path}') + + logs = { + 'step_loss': loss.detach().item(), + 'lr': lr_scheduler.get_last_lr()[0] + } + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + variant=args.variant, + ) + pipeline.save_pretrained(args.output_dir) + + # Run a final round of inference. 
+ images = [] + if args.validation_prompts is not None: + logger.info('Running inference for collecting generated images...') + pipeline = pipeline.to(accelerator.device) + pipeline.torch_dtype = weight_dtype + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator( + device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.autocast('cuda'): + image = pipeline( + args.validation_prompts[i], + num_inference_steps=20, + generator=generator).images[0] + images.append(image) + + if args.push_to_hub: + save_model_card( + args, args.hub_model_id, images, repo_folder=args.output_dir) + push_to_hub(args.hub_model_id, args.output_dir, args.hub_token) + + accelerator.end_training() + + +if __name__ == '__main__': + main() diff --git a/swift/aigc/diffusers/train_text_to_image_lora.py b/swift/aigc/diffusers/train_text_to_image_lora.py new file mode 100644 index 0000000000..bae08b1be5 --- /dev/null +++ b/swift/aigc/diffusers/train_text_to_image_lora.py @@ -0,0 +1,1099 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from diffusers import (AutoencoderKL, DDPMScheduler, DiffusionPipeline, + UNet2DConditionModel) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from modelscope import MsDataset +from packaging import version +from peft.utils import get_peft_model_state_dict +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from swift import LoRAConfig, Swift, push_to_hub, snapshot_download + +logger = get_logger(__name__, log_level='INFO') + + +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, + (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f'text_model.encoder.layers.{i}.self_attn' + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.q_proj.lora_linear_layer.{k}'] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.k_proj.lora_linear_layer.{k}'] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.v_proj.lora_linear_layer.{k}'] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.out_proj.lora_linear_layer.{k}'] = v + + return state_dict + + +def save_model_card(repo_id: str, + images=None, + base_model=str, + dataset_name=str, + repo_folder=None): + img_str = '' + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f'image_{i}.png')) + img_str += f'![img_{i}](./image_{i}.png)\n' + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA text2image fine-tuning - {repo_id} +These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. +You can find some example images in the following. 
\n +{img_str} +""" + with open(os.path.join(repo_folder, 'README.md'), 'w') as f: + f.write(yaml + model_card) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Simple example of a training script.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default=None, + required=True, + help= + 'Path to pretrained model or model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--variant', + type=str, + default=None, + help= + "Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + '--dataset_name', + type=str, + default=None, + help= + ('The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,' + ' dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,' + ' or to a folder containing files that 🤗 Datasets can understand.'), + ) + parser.add_argument( + '--dataset_config_name', + type=str, + default=None, + help= + "The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + '--train_data_dir', + type=str, + default=None, + help= + ('A folder containing the training data. Folder contents must follow the structure described in' + ' https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file' + ' must exist to provide the captions for the images. Ignored if `dataset_name` is specified.' + ), + ) + parser.add_argument( + '--image_column', + type=str, + default='image:FILE', + help='The column of the dataset containing an image.') + parser.add_argument( + '--caption_column', + type=str, + default='text', + help= + 'The column of the dataset containing a caption or a list of captions.', + ) + parser.add_argument( + '--validation_prompt', + type=str, + default=None, + help='A prompt that is sampled during training for inference.') + parser.add_argument( + '--num_validation_images', + type=int, + default=4, + help= + 'Number of images that should be generated during validation with `validation_prompt`.', + ) + parser.add_argument( + '--validation_epochs', + type=int, + default=1, + help= + ('Run fine-tuning validation every X epochs. The validation process consists of running the prompt' + ' `args.validation_prompt` multiple times: `args.num_validation_images`.' 
+ ), + ) + parser.add_argument( + '--max_train_samples', + type=int, + default=None, + help= + ('For debugging purposes or quicker training, truncate the number of training examples to this ' + 'value if set.'), + ) + parser.add_argument( + '--output_dir', + type=str, + default='sd-model-finetuned-lora', + help= + 'The output directory where the model predictions and checkpoints will be written.', + ) + parser.add_argument( + '--cache_dir', + type=str, + default=None, + help= + 'The directory where the downloaded models and datasets will be stored.', + ) + parser.add_argument( + '--seed', + type=int, + default=None, + help='A seed for reproducible training.') + parser.add_argument( + '--resolution', + type=int, + default=512, + help= + ('The resolution for input images, all the images in the train/validation dataset will be resized to this' + ' resolution'), + ) + parser.add_argument( + '--center_crop', + default=False, + action='store_true', + help= + ('Whether to center crop the input images to the resolution. If not set, the images will be randomly' + ' cropped. The images will be resized to the resolution first before cropping.' + ), + ) + parser.add_argument( + '--random_flip', + action='store_true', + help='whether to randomly flip images horizontally', + ) + parser.add_argument( + '--train_batch_size', + type=int, + default=16, + help='Batch size (per device) for the training dataloader.') + parser.add_argument('--num_train_epochs', type=int, default=100) + parser.add_argument( + '--max_train_steps', + type=int, + default=None, + help= + 'Total number of training steps to perform. If provided, overrides num_train_epochs.', + ) + parser.add_argument( + '--gradient_accumulation_steps', + type=int, + default=1, + help= + 'Number of updates steps to accumulate before performing a backward/update pass.', + ) + parser.add_argument( + '--gradient_checkpointing', + action='store_true', + help= + 'Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.', + ) + parser.add_argument( + '--learning_rate', + type=float, + default=1e-4, + help= + 'Initial learning rate (after the potential warmup period) to use.', + ) + parser.add_argument( + '--scale_lr', + action='store_true', + default=False, + help= + 'Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.', + ) + parser.add_argument( + '--lr_scheduler', + type=str, + default='constant', + help= + ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument( + '--lr_warmup_steps', + type=int, + default=500, + help='Number of steps for the warmup in the lr scheduler.') + parser.add_argument( + '--snr_gamma', + type=float, + default=None, + help= + 'SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. ' + 'More details here: https://arxiv.org/abs/2303.09556.', + ) + parser.add_argument( + '--use_8bit_adam', + action='store_true', + help='Whether or not to use 8-bit Adam from bitsandbytes.') + parser.add_argument( + '--allow_tf32', + action='store_true', + help= + ('Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see' + ' https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices' + ), + ) + parser.add_argument( + '--dataloader_num_workers', + type=int, + default=0, + help= + ('Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.' + ), + ) + parser.add_argument( + '--adam_beta1', + type=float, + default=0.9, + help='The beta1 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_beta2', + type=float, + default=0.999, + help='The beta2 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_weight_decay', + type=float, + default=1e-2, + help='Weight decay to use.') + parser.add_argument( + '--adam_epsilon', + type=float, + default=1e-08, + help='Epsilon value for the Adam optimizer') + parser.add_argument( + '--max_grad_norm', default=1.0, type=float, help='Max gradient norm.') + parser.add_argument( + '--push_to_hub', + action='store_true', + help='Whether or not to push the model to the Hub.') + parser.add_argument( + '--hub_token', + type=str, + default=None, + help='The token to use to push to the Model Hub.') + parser.add_argument( + '--prediction_type', + type=str, + default=None, + help= + "The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or \ + leave `None`. If left to `None` the default prediction type of the scheduler: \ + `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + '--hub_model_id', + type=str, + default=None, + help= + 'The name of the repository to keep in sync with the local `output_dir`.', + ) + parser.add_argument( + '--logging_dir', + type=str, + default='logs', + help= + ('[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to' + ' *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.'), + ) + parser.add_argument( + '--mixed_precision', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the' + ' flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.' + ), + ) + parser.add_argument( + '--report_to', + type=str, + default='tensorboard', + help= + ('The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + '--local_rank', + type=int, + default=-1, + help='For distributed training: local_rank') + parser.add_argument( + '--checkpointing_steps', + type=int, + default=500, + help= + ('Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming' + ' training using `--resume_from_checkpoint`.'), + ) + parser.add_argument( + '--checkpoints_total_limit', + type=int, + default=None, + help=('Max number of checkpoints to store.'), + ) + parser.add_argument( + '--resume_from_checkpoint', + type=str, + default=None, + help= + ('Whether training should be resumed from a previous checkpoint. Use a path saved by' + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + '--enable_xformers_memory_efficient_attention', + action='store_true', + help='Whether or not to use xformers.') + parser.add_argument( + '--noise_offset', + type=float, + default=0, + help='The scale of noise offset.') + parser.add_argument( + '--rank', + type=int, + default=4, + help=('The dimension of the LoRA update matrices.'), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get('LOCAL_RANK', -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError('Need either a dataset name or a training folder.') + + args.base_model_id = args.pretrained_model_name_or_path + if not os.path.exists(args.pretrained_model_name_or_path): + args.pretrained_model_name_or_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + return args + + +DATASET_NAME_MAPPING = { + 'AI-ModelScope/pokemon-blip-captions': ('text', 'image:FILE'), +} + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + if args.report_to == 'wandb': + if not is_wandb_available(): + raise ImportError( + 'Make sure to install wandb if you want to use it for logging during training.' + ) + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. 
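+    # The pretrained checkpoint is expected to follow the usual diffusers
+    # directory layout, with one subfolder per component, roughly:
+    #   <model_path>/{scheduler,tokenizer,text_encoder,vae,unet}/...
+    # All of these components stay frozen; only the LoRA adapters attached to
+    # the UNet further below are trained.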
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder='tokenizer',
+        revision=args.revision)
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder='text_encoder',
+        revision=args.revision)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder='vae',
+        revision=args.revision,
+        variant=args.variant)
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder='unet',
+        revision=args.revision,
+        variant=args.variant)
+    # freeze parameters of models to save more memory
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet)
+    # to half-precision, as these weights are only used for inference and keeping them in full precision is not
+    # required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == 'fp16':
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == 'bf16':
+        weight_dtype = torch.bfloat16
+
+    # Freeze the unet parameters before adding adapters
+    for param in unet.parameters():
+        param.requires_grad_(False)
+
+    unet_lora_config = LoRAConfig(
+        r=args.rank,
+        init_lora_weights='gaussian',
+        target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'])
+
+    # Move unet, vae and text_encoder to device and cast to weight_dtype
+    unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+    unet = Swift.prepare_model(unet, unet_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in unet.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse('0.0.16'):
+                logger.warn(
+                    'xFormers 0.0.16 cannot be used for training in some GPUs. \
+                    If you observe problems during training, please update xFormers to at least 0.0.17. \
+                    See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.'
+                )
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError(
+                'xformers is not available. Make sure it is installed correctly'
+            )
+
+    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps
+            * args.train_batch_size * accelerator.num_processes)
+
+    # Initialize the optimizer
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                'Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`' + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + lora_layers, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + def path_to_img(example): + example['image'] = Image.open(example['image:FILE']) + return example + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = MsDataset.load( + args.dataset_name, + args.dataset_config_name, + data_dir=args.train_data_dir, + ) + if isinstance(dataset, dict): + dataset = { + key: value.to_hf_dataset() + for key, value in dataset.items() + } + else: + dataset = {'train': dataset.to_hf_dataset()} + else: + data_files = {} + if args.train_data_dir is not None: + data_files['train'] = os.path.join(args.train_data_dir, '**') + dataset = load_dataset( + 'imagefolder', + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset['train'].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[ + 1] if dataset_columns is not None else column_names[1] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[ + 0] if dataset_columns is not None else column_names[0] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + if image_column.endswith(':FILE'): + dataset['train'] = dataset['train'].map(path_to_img) + image_column = 'image' + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append( + random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f'Caption column `{caption_column}` should contain either strings or lists of strings.' + ) + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding='max_length', + truncation=True, + return_tensors='pt') + return inputs.input_ids + + # Preprocessing the datasets. 
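+    # After `preprocess_train` below, each training example carries a
+    # `pixel_values` tensor of shape (3, resolution, resolution) normalized to
+    # [-1, 1] and an `input_ids` tensor padded/truncated to
+    # `tokenizer.model_max_length` tokens (typically 77 for the CLIP tokenizer
+    # of SD v1.x). For example, with the default --resolution 512:
+    #   pixel_values -> torch.Size([3, 512, 512]), input_ids -> torch.Size([77])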
+ train_transforms = transforms.Compose([ + transforms.Resize( + args.resolution, + interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) + if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() + if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def preprocess_train(examples): + images = [image.convert('RGB') for image in examples[image_column]] + examples['pixel_values'] = [ + train_transforms(image) for image in images + ] + examples['input_ids'] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset['train'] = dataset['train'].shuffle(seed=args.seed).select( + range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset['train'].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack( + [example['pixel_values'] for example in examples]) + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example['input_ids'] for example in examples]) + return {'pixel_values': pixel_values, 'input_ids': input_ids} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps + / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers('text2image-fine-tune', config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info('***** Running training *****') + logger.info(f' Num examples = {len(train_dataset)}') + logger.info(f' Num Epochs = {args.num_train_epochs}') + logger.info( + f' Instantaneous batch size per device = {args.train_batch_size}') + logger.info( + f' Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}' + ) + logger.info( + f' Gradient Accumulation steps = {args.gradient_accumulation_steps}') + logger.info(f' Total optimization steps = {args.max_train_steps}') + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != 'latest': + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith('checkpoint')] + dirs = sorted(dirs, key=lambda x: int(x.split('-')[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f'Resuming from checkpoint {path}') + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split('-')[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc='Steps', + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch['pixel_values'].to( + dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), + device=latents.device) + + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, (bsz, ), + device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise( + latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch['input_ids'])[0] + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config( + prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == 'epsilon': + target = noise + elif noise_scheduler.config.prediction_type == 'v_prediction': + target = noise_scheduler.get_velocity( + latents, noise, timesteps) + else: + raise ValueError( + f'Unknown prediction type {noise_scheduler.config.prediction_type}' + ) + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, + encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='mean') + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. 
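+                    # Concretely, the per-sample weight computed below is
+                    #   min(SNR(t), snr_gamma) / SNR(t)
+                    # with SNR(t) + 1 substituted for SNR(t) under v-prediction.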
+ # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == 'v_prediction': + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], + dim=1).min(dim=1)[0] / snr) + + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='none') + loss = loss.mean( + dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather( + loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item( + ) / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = lora_layers + accelerator.clip_grad_norm_(params_to_clip, + args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({'train_loss': train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints + if d.startswith('checkpoint') + ] + checkpoints = sorted( + checkpoints, + key=lambda x: int(x.split('-')[1])) + + # before we save the new checkpoint, we need to have at _most_ \ + # `checkpoints_total_limit - 1` checkpoints + if len(checkpoints + ) >= args.checkpoints_total_limit: + num_to_remove = len( + checkpoints + ) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[ + 0:num_to_remove] + + logger.info( + f'{len(checkpoints)} checkpoints already exist, ' + f'removing {len(removing_checkpoints)} checkpoints' + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, + f'checkpoint-{global_step}') + accelerator.save_state(save_path) + + unet.save_pretrained(save_path) + logger.info(f'Saved state to {save_path}') + + logs = { + 'step_loss': loss.detach().item(), + 'lr': lr_scheduler.get_last_lr()[0] + } + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f'Running validation... 
\n Generating {args.num_validation_images} images with prompt:' + f' {args.validation_prompt}.') + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet.base_model), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append( + pipeline( + args.validation_prompt, + num_inference_steps=30, + generator=generator).images[0]) + + for tracker in accelerator.trackers: + if tracker.name == 'tensorboard': + np_images = np.stack( + [np.asarray(img) for img in images]) + tracker.writer.add_images( + 'validation', np_images, epoch, dataformats='NHWC') + if tracker.name == 'wandb': + tracker.log({ + 'validation': [ + wandb.Image( + image, + caption=f'{i}: {args.validation_prompt}') + for i, image in enumerate(images) + ] + }) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unet.to(torch.float32) + + unet.save_pretrained(args.output_dir) + + if args.push_to_hub: + save_model_card( + args.hub_model_id, + images=images, + base_model=args.base_model_id, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + ) + push_to_hub(args.hub_model_id, args.output_dir, args.hub_token) + + # Final inference + # Load previous pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet = Swift.from_pretrained(pipeline.unet, args.output_dir) + + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append( + pipeline( + args.validation_prompt, + num_inference_steps=30, + generator=generator).images[0]) + + if accelerator.is_main_process: + for tracker in accelerator.trackers: + if len(images) != 0: + if tracker.name == 'tensorboard': + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + 'test', np_images, epoch, dataformats='NHWC') + if tracker.name == 'wandb': + tracker.log({ + 'test': [ + wandb.Image( + image, + caption=f'{i}: {args.validation_prompt}') + for i, image in enumerate(images) + ] + }) + + accelerator.end_training() diff --git a/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py b/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py new file mode 100644 index 0000000000..735c680697 --- /dev/null +++ b/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py @@ -0,0 +1,1380 @@ +#!/usr/bin/env python +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA.""" + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path +from typing import Dict + +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import (DistributedDataParallelKwargs, + ProjectConfiguration, set_seed) +from datasets import load_dataset +from diffusers import (AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, + UNet2DConditionModel) +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from modelscope import AutoTokenizer, MsDataset +from packaging import version +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import PretrainedConfig + +from swift import (LoRAConfig, Swift, get_peft_model_state_dict, push_to_hub, + snapshot_download) + +logger = get_logger(__name__) + + +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, + (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f'text_model.encoder.layers.{i}.self_attn' + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.q_proj.lora_linear_layer.{k}'] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.k_proj.lora_linear_layer.{k}'] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.v_proj.lora_linear_layer.{k}'] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f'{name}.out_proj.lora_linear_layer.{k}'] = v + + return state_dict + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + dataset_name=str, + train_text_encoder=False, + repo_folder=None, + vae_path=None, +): + img_str = '' + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f'image_{i}.png')) + img_str += f'![img_{i}](./image_{i}.png)\n' + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +dataset: {dataset_name} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA 
text2image fine-tuning - {repo_id} + +These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. +You can find some example images in the following. \n +{img_str} + +LoRA for the text encoder was enabled: {train_text_encoder}. + +Special VAE used for training: {vae_path}. +""" + with open(os.path.join(repo_folder, 'README.md'), 'w') as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, + revision: str, + subfolder: str = 'text_encoder'): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision) + model_class = text_encoder_config.architectures[0] + + if model_class == 'CLIPTextModel': + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == 'CLIPTextModelWithProjection': + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f'{model_class} is not supported.') + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser( + description='Simple example of a training script.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default=None, + required=True, + help= + 'Path to pretrained model or model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--pretrained_vae_model_name_or_path', + type=str, + default=None, + help='Path to pretrained VAE model with better numerical stability. \ + More details: https://github.com/huggingface/diffusers/pull/4038.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--variant', + type=str, + default=None, + help= + "Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + '--dataset_name', + type=str, + default=None, + help= + ('The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,' + ' dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,' + ' or to a folder containing files that 🤗 Datasets can understand.'), + ) + parser.add_argument( + '--dataset_config_name', + type=str, + default=None, + help= + "The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + '--train_data_dir', + type=str, + default=None, + help= + ('A folder containing the training data. Folder contents must follow the structure described in' + ' https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file' + ' must exist to provide the captions for the images. Ignored if `dataset_name` is specified.' 
+ ), + ) + parser.add_argument( + '--image_column', + type=str, + default='image:FILE', + help='The column of the dataset containing an image.') + parser.add_argument( + '--caption_column', + type=str, + default='text', + help= + 'The column of the dataset containing a caption or a list of captions.', + ) + parser.add_argument( + '--validation_prompt', + type=str, + default=None, + help= + 'A prompt that is used during validation to verify that the model is learning.', + ) + parser.add_argument( + '--num_validation_images', + type=int, + default=4, + help= + 'Number of images that should be generated during validation with `validation_prompt`.', + ) + parser.add_argument( + '--validation_epochs', + type=int, + default=1, + help= + ('Run fine-tuning validation every X epochs. The validation process consists of running the prompt' + ' `args.validation_prompt` multiple times: `args.num_validation_images`.' + ), + ) + parser.add_argument( + '--max_train_samples', + type=int, + default=None, + help= + ('For debugging purposes or quicker training, truncate the number of training examples to this ' + 'value if set.'), + ) + parser.add_argument( + '--output_dir', + type=str, + default='sd-model-finetuned-lora', + help= + 'The output directory where the model predictions and checkpoints will be written.', + ) + parser.add_argument( + '--cache_dir', + type=str, + default=None, + help= + 'The directory where the downloaded models and datasets will be stored.', + ) + parser.add_argument( + '--seed', + type=int, + default=None, + help='A seed for reproducible training.') + parser.add_argument( + '--resolution', + type=int, + default=1024, + help= + ('The resolution for input images, all the images in the train/validation dataset will be resized to this' + ' resolution'), + ) + parser.add_argument( + '--center_crop', + default=False, + action='store_true', + help= + ('Whether to center crop the input images to the resolution. If not set, the images will be randomly' + ' cropped. The images will be resized to the resolution first before cropping.' + ), + ) + parser.add_argument( + '--random_flip', + action='store_true', + help='whether to randomly flip images horizontally', + ) + parser.add_argument( + '--train_text_encoder', + action='store_true', + help= + 'Whether to train the text encoder. If set, the text encoder should be float32 precision.', + ) + parser.add_argument( + '--train_batch_size', + type=int, + default=16, + help='Batch size (per device) for the training dataloader.') + parser.add_argument('--num_train_epochs', type=int, default=100) + parser.add_argument( + '--max_train_steps', + type=int, + default=None, + help= + 'Total number of training steps to perform. If provided, overrides num_train_epochs.', + ) + parser.add_argument( + '--checkpointing_steps', + type=int, + default=500, + help= + ('Save a checkpoint of the training state every X updates. These checkpoints can be used both as final' + ' checkpoints in case they are better than the last checkpoint, and are also suitable for resuming' + ' training using `--resume_from_checkpoint`.'), + ) + parser.add_argument( + '--checkpoints_total_limit', + type=int, + default=None, + help=('Max number of checkpoints to store.'), + ) + parser.add_argument( + '--resume_from_checkpoint', + type=str, + default=None, + help= + ('Whether training should be resumed from a previous checkpoint. Use a path saved by' + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + '--gradient_accumulation_steps', + type=int, + default=1, + help= + 'Number of updates steps to accumulate before performing a backward/update pass.', + ) + parser.add_argument( + '--gradient_checkpointing', + action='store_true', + help= + 'Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.', + ) + parser.add_argument( + '--learning_rate', + type=float, + default=1e-4, + help= + 'Initial learning rate (after the potential warmup period) to use.', + ) + parser.add_argument( + '--scale_lr', + action='store_true', + default=False, + help= + 'Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.', + ) + parser.add_argument( + '--lr_scheduler', + type=str, + default='constant', + help= + ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument( + '--lr_warmup_steps', + type=int, + default=500, + help='Number of steps for the warmup in the lr scheduler.') + parser.add_argument( + '--snr_gamma', + type=float, + default=None, + help= + 'SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. ' + 'More details here: https://arxiv.org/abs/2303.09556.', + ) + parser.add_argument( + '--allow_tf32', + action='store_true', + help= + ('Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see' + ' https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices' + ), + ) + parser.add_argument( + '--dataloader_num_workers', + type=int, + default=0, + help= + ('Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.' + ), + ) + parser.add_argument( + '--use_8bit_adam', + action='store_true', + help='Whether or not to use 8-bit Adam from bitsandbytes.') + parser.add_argument( + '--adam_beta1', + type=float, + default=0.9, + help='The beta1 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_beta2', + type=float, + default=0.999, + help='The beta2 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_weight_decay', + type=float, + default=1e-2, + help='Weight decay to use.') + parser.add_argument( + '--adam_epsilon', + type=float, + default=1e-08, + help='Epsilon value for the Adam optimizer') + parser.add_argument( + '--max_grad_norm', default=1.0, type=float, help='Max gradient norm.') + parser.add_argument( + '--push_to_hub', + action='store_true', + help='Whether or not to push the model to the Hub.') + parser.add_argument( + '--hub_token', + type=str, + default=None, + help='The token to use to push to the Model Hub.') + parser.add_argument( + '--prediction_type', + type=str, + default=None, + help= + "The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or \ + leave `None`. If left to `None` the default prediction type of the scheduler: \ + `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + '--hub_model_id', + type=str, + default=None, + help= + 'The name of the repository to keep in sync with the local `output_dir`.', + ) + parser.add_argument( + '--logging_dir', + type=str, + default='logs', + help= + ('[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to' + ' *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.'), + ) + parser.add_argument( + '--report_to', + type=str, + default='tensorboard', + help= + ('The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + '--mixed_precision', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the' + ' flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.' + ), + ) + parser.add_argument( + '--local_rank', + type=int, + default=-1, + help='For distributed training: local_rank') + parser.add_argument( + '--enable_xformers_memory_efficient_attention', + action='store_true', + help='Whether or not to use xformers.') + parser.add_argument( + '--noise_offset', + type=float, + default=0, + help='The scale of noise offset.') + parser.add_argument( + '--rank', + type=int, + default=4, + help=('The dimension of the LoRA update matrices.'), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get('LOCAL_RANK', -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError('Need either a dataset name or a training folder.') + + args.base_model_id = args.pretrained_model_name_or_path + if not os.path.exists(args.pretrained_model_name_or_path): + args.pretrained_model_name_or_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + + args.vae_base_model_id = args.pretrained_vae_model_name_or_path + if args.pretrained_vae_model_name_or_path and not os.path.exists( + args.pretrained_vae_model_name_or_path): + args.pretrained_vae_model_name_or_path = snapshot_download( + args.pretrained_vae_model_name_or_path) + return args + + +DATASET_NAME_MAPPING = { + 'AI-ModelScope/pokemon-blip-captions': ('text', 'image:FILE'), +} + + +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + """ + Returns: + a state dict containing just the attention processor parameters. 
+ """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[ + f'{attn_processor_key}.{parameter_key}'] = parameter + + return attn_processors_state_dict + + +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding='max_length', + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors='pt', + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + if args.report_to == 'wandb': + if not is_wandb_available(): + raise ImportError( + 'Make sure to install wandb if you want to use it for logging during training.' + ) + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='tokenizer', + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='tokenizer_2', + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, + args.revision, + subfolder='text_encoder_2') + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='text_encoder', + revision=args.revision, + variant=args.variant) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='text_encoder_2', + revision=args.revision, + variant=args.variant) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None else + args.pretrained_vae_model_name_or_path) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder='vae' + if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='unet', + revision=args.revision, + variant=args.variant) + + # We only train the additional adapter LoRA layers + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to + # half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == 'fp16': + weight_dtype = torch.float16 + elif accelerator.mixed_precision == 'bf16': + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + unet.to(accelerator.device, dtype=weight_dtype) + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse('0.0.16'): + logger.warn( + 'xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, \ + please update xFormers to at least 0.0.17. \ + See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.' + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + 'xformers is not available. 
Make sure it is installed correctly' + ) + + # now we will add new LoRA weights to the attention layers + # Set correct lora layers + unet_lora_config = LoRAConfig( + r=args.rank, + init_lora_weights='gaussian', + target_modules=['to_k', 'to_q', 'to_v', 'to_out.0']) + + unet = Swift.prepare_model(unet, unet_lora_config) + if args.mixed_precision == "fp16": + for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + # The text encoder comes from 🤗 transformers, we will also attach adapters to it. + if args.train_text_encoder: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_config = LoRAConfig( + r=args.rank, + init_lora_weights='gaussian', + target_modules=['q_proj', 'k_proj', 'v_proj', 'out_proj']) + text_encoder_one = Swift.prepare_model(text_encoder_one, + text_lora_config) + text_encoder_two = Swift.prepare_model(text_encoder_two, + text_lora_config) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps + * args.train_batch_size * accelerator.num_processes) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + 'To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.' + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = list( + filter(lambda p: p.requires_grad, unet.parameters())) + if args.train_text_encoder: + params_to_optimize = ( + params_to_optimize + list( + filter(lambda p: p.requires_grad, + text_encoder_one.parameters())) + list( + filter(lambda p: p.requires_grad, + text_encoder_two.parameters()))) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + def path_to_img(example): + example['image'] = Image.open(example['image:FILE']) + return example + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = MsDataset.load( + args.dataset_name, + args.dataset_config_name, + data_dir=args.train_data_dir, + ) + if isinstance(dataset, dict): + dataset = { + key: value.to_hf_dataset() + for key, value in dataset.items() + } + else: + dataset = {'train': dataset.to_hf_dataset()} + else: + data_files = {} + if args.train_data_dir is not None: + data_files['train'] = os.path.join(args.train_data_dir, '**') + dataset = load_dataset( + 'imagefolder', + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. 
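+    # Note that SDXL conditions on two text encoders, so each caption is
+    # tokenized twice below (`input_ids_one` with tokenizer_one and
+    # `input_ids_two` with tokenizer_two). The preprocessing also records each
+    # image's original size and crop offset, which the training loop later
+    # packs into the additional "time ids" conditioning expected by SDXL.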
+ column_names = dataset['train'].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[ + 1] if dataset_columns is not None else column_names[1] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[ + 0] if dataset_columns is not None else column_names[0] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + if image_column.endswith(':FILE'): + dataset['train'] = dataset['train'].map(path_to_img) + image_column = 'image' + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append( + random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f'Caption column `{caption_column}` should contain either strings or lists of strings.' + ) + tokens_one = tokenize_prompt(tokenizer_one, captions) + tokens_two = tokenize_prompt(tokenizer_two, captions) + return tokens_one, tokens_two + + # Preprocessing the datasets. + train_resize = transforms.Resize( + args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop( + args.resolution) if args.center_crop else transforms.RandomCrop( + args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def preprocess_train(examples): + images = [image.convert('RGB') for image in examples[image_column]] + # image aug + original_sizes = [] + all_images = [] + crop_top_lefts = [] + for image in images: + original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + # flip + x1 = image.width - x1 + image = train_flip(image) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + all_images.append(image) + + examples['original_sizes'] = original_sizes + examples['crop_top_lefts'] = crop_top_lefts + examples['pixel_values'] = all_images + tokens_one, tokens_two = tokenize_captions(examples) + examples['input_ids_one'] = tokens_one + examples['input_ids_two'] = tokens_two + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset['train'] = dataset['train'].shuffle(seed=args.seed).select( + range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset['train'].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack( + 
[example['pixel_values'] for example in examples]) + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() + original_sizes = [example['original_sizes'] for example in examples] + crop_top_lefts = [example['crop_top_lefts'] for example in examples] + input_ids_one = torch.stack( + [example['input_ids_one'] for example in examples]) + input_ids_two = torch.stack( + [example['input_ids_two'] for example in examples]) + return { + 'pixel_values': pixel_values, + 'input_ids_one': input_ids_one, + 'input_ids_two': input_ids_two, + 'original_sizes': original_sizes, + 'crop_top_lefts': crop_top_lefts, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps + * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps + * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, + train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps + / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers('text2image-fine-tune', config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info('***** Running training *****') + logger.info(f' Num examples = {len(train_dataset)}') + logger.info(f' Num Epochs = {args.num_train_epochs}') + logger.info( + f' Instantaneous batch size per device = {args.train_batch_size}') + logger.info( + f' Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}' + ) + logger.info( + f' Gradient Accumulation steps = {args.gradient_accumulation_steps}') + logger.info(f' Total optimization steps = {args.max_train_steps}') + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != 'latest': + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith('checkpoint')] + dirs = sorted(dirs, key=lambda x: int(x.split('-')[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f'Resuming from checkpoint {path}') + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split('-')[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc='Steps', + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder_one.train() + text_encoder_two.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = batch['pixel_values'].to(dtype=weight_dtype) + else: + pixel_values = batch['pixel_values'] + + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), + device=model_input.device) + + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, (bsz, ), + device=model_input.device) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise( + model_input, noise, timesteps) + + # time ids + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to( + accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat([ + compute_time_ids(s, c) for s, c in zip( + batch['original_sizes'], batch['crop_top_lefts']) + ]) + + # Predict the noise residual + unet_added_conditions = {'time_ids': add_time_ids} + prompt_embeds, pooled_prompt_embeds = encode_prompt( + 
text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[ + batch['input_ids_one'], batch['input_ids_two'] + ], + ) + unet_added_conditions.update( + {'text_embeds': pooled_prompt_embeds}) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions).sample + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config( + prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == 'epsilon': + target = noise + elif noise_scheduler.config.prediction_type == 'v_prediction': + target = noise_scheduler.get_velocity( + model_input, noise, timesteps) + else: + raise ValueError( + f'Unknown prediction type {noise_scheduler.config.prediction_type}' + ) + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='mean') + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == 'v_prediction': + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], + dim=1).min(dim=1)[0] / snr) + + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='none') + loss = loss.mean( + dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
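The `--snr_gamma` branch above implements Min-SNR loss weighting (arXiv:2303.09556). A self-contained sketch of the weight computation, assuming per-timestep SNR values have already been obtained from `compute_snr`:

```python
import torch

# Min-SNR-gamma weights: min(SNR, gamma) / SNR, computed per sampled timestep.
# The SNR values below are made up; in the script they come from
# compute_snr(noise_scheduler, timesteps), with snr += 1 for v-prediction.
snr = torch.tensor([0.5, 2.0, 10.0, 50.0])
snr_gamma = 5.0

mse_loss_weights = torch.stack(
    [snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
print(mse_loss_weights)   # tensor([1.0000, 1.0000, 0.5000, 0.1000])
```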
+ avg_loss = accelerator.gather( + loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item( + ) / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, + args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({'train_loss': train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints + if d.startswith('checkpoint') + ] + checkpoints = sorted( + checkpoints, + key=lambda x: int(x.split('-')[1])) + + # before we save the new checkpoint, we need to have at _most_ \ + # `checkpoints_total_limit - 1` checkpoints + if len(checkpoints + ) >= args.checkpoints_total_limit: + num_to_remove = len( + checkpoints + ) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[ + 0:num_to_remove] + + logger.info( + f'{len(checkpoints)} checkpoints already exist, ' + f'removing {len(removing_checkpoints)} checkpoints' + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, + f'checkpoint-{global_step}') + accelerator.save_state(save_path) + logger.info(f'Saved state to {save_path}') + + logs = { + 'step_loss': loss.detach().item(), + 'lr': lr_scheduler.get_last_lr()[0] + } + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f'Running validation... 
\n Generating {args.num_validation_images} images with prompt:' + f' {args.validation_prompt}.') + # create pipeline + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + unet=accelerator.unwrap_model(unet.base_model), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator( + device=accelerator.device).manual_seed( + args.seed) if args.seed else None + pipeline_args = {'prompt': args.validation_prompt} + + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, + generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == 'tensorboard': + np_images = np.stack( + [np.asarray(img) for img in images]) + tracker.writer.add_images( + 'validation', np_images, epoch, dataformats='NHWC') + if tracker.name == 'wandb': + tracker.log({ + 'validation': [ + wandb.Image( + image, + caption=f'{i}: {args.validation_prompt}') + for i, image in enumerate(images) + ] + }) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(os.path.join(args.output_dir, 'unet')) + + if args.train_text_encoder: + text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_one.save_pretrained( + os.path.join(args.output_dir, 'text_encoder1')) + text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_two.save_pretrained( + os.path.join(args.output_dir, 'text_encoder2')) + + del unet + del text_encoder_one + del text_encoder_two + torch.cuda.empty_cache() + + # Final inference + # Load previous pipeline + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet = Swift.from_pretrained( + pipeline.unet, os.path.join(args.output_dir, 'unet')) + if args.train_text_encoder: + pipeline.text_encoder_one = Swift.from_pretrained( + pipeline.text_encoder_one, + os.path.join(args.output_dir, 'text_encoder1')) + pipeline.text_encoder_two = Swift.from_pretrained( + pipeline.text_encoder_two, + os.path.join(args.output_dir, 'text_encoder2')) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed( + args.seed) if args.seed else None + images = [ + pipeline( + args.validation_prompt, + num_inference_steps=25, + generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == 'tensorboard': + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + 'test', np_images, epoch, dataformats='NHWC') + if tracker.name == 'wandb': + tracker.log({ + 'test': [ + wandb.Image( + image, + caption=f'{i}: {args.validation_prompt}') + for i, image in enumerate(images) + ] + }) + + if args.push_to_hub: + save_model_card( + args.hub_model_id, + images=images, + base_model=args.base_model_id, + 
dataset_name=args.dataset_name, + train_text_encoder=args.train_text_encoder, + repo_folder=args.output_dir, + vae_path=args.vae_base_model_id, + ) + push_to_hub( + args.hub_model_id, + args.output_dir, + args.hub_token, + ) + + accelerator.end_training() diff --git a/swift/aigc/diffusers/train_text_to_image_sdxl.py b/swift/aigc/diffusers/train_text_to_image_sdxl.py new file mode 100644 index 0000000000..efef370846 --- /dev/null +++ b/swift/aigc/diffusers/train_text_to_image_sdxl.py @@ -0,0 +1,1463 @@ +#!/usr/bin/env python +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion XL for text2image.""" + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from diffusers import (AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, + UNet2DConditionModel) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from modelscope import AutoTokenizer, MsDataset +from packaging import version +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import PretrainedConfig + +from swift import push_to_hub, snapshot_download + +logger = get_logger(__name__) + +DATASET_NAME_MAPPING = { + 'AI-ModelScope/pokemon-blip-captions': ('text', 'image:FILE'), +} + + +def save_model_card( + repo_id: str, + images=None, + validation_prompt=None, + base_model=str, + dataset_name=str, + repo_folder=None, + vae_path=None, +): + img_str = '' + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f'image_{i}.png')) + img_str += f'![img_{i}](./image_{i}.png)\n' + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +dataset: {dataset_name} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{base_model}** on the **{args.dataset_name}** dataset. Below are some example images +generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n +{img_str} + +Special VAE used for training: {vae_path}. 
+""" + with open(os.path.join(repo_folder, 'README.md'), 'w') as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, + revision: str, + subfolder: str = 'text_encoder'): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision) + model_class = text_encoder_config.architectures[0] + + if model_class == 'CLIPTextModel': + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == 'CLIPTextModelWithProjection': + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f'{model_class} is not supported.') + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser( + description='Simple example of a training script.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default=None, + required=True, + help= + 'Path to pretrained model or model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--pretrained_vae_model_name_or_path', + type=str, + default=None, + help='Path to pretrained VAE model with better numerical stability. \ + More details: https://github.com/huggingface/diffusers/pull/4038.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from huggingface.co/models.', + ) + parser.add_argument( + '--variant', + type=str, + default=None, + help= + "Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + '--dataset_name', + type=str, + default=None, + help= + ('The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,' + ' dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,' + ' or to a folder containing files that 🤗 Datasets can understand.'), + ) + parser.add_argument( + '--dataset_config_name', + type=str, + default=None, + help= + "The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + '--train_data_dir', + type=str, + default=None, + help= + ('A folder containing the training data. Folder contents must follow the structure described in' + ' https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file' + ' must exist to provide the captions for the images. Ignored if `dataset_name` is specified.' + ), + ) + parser.add_argument( + '--image_column', + type=str, + default='image:FILE', + help='The column of the dataset containing an image.') + parser.add_argument( + '--caption_column', + type=str, + default='text', + help= + 'The column of the dataset containing a caption or a list of captions.', + ) + parser.add_argument( + '--validation_prompt', + type=str, + default=None, + help= + 'A prompt that is used during validation to verify that the model is learning.', + ) + parser.add_argument( + '--num_validation_images', + type=int, + default=4, + help= + 'Number of images that should be generated during validation with `validation_prompt`.', + ) + parser.add_argument( + '--validation_epochs', + type=int, + default=1, + help= + ('Run fine-tuning validation every X epochs. The validation process consists of running the prompt' + ' `args.validation_prompt` multiple times: `args.num_validation_images`.' 
+ ), + ) + parser.add_argument( + '--max_train_samples', + type=int, + default=None, + help= + ('For debugging purposes or quicker training, truncate the number of training examples to this ' + 'value if set.'), + ) + parser.add_argument( + '--proportion_empty_prompts', + type=float, + default=0, + help= + 'Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).', + ) + parser.add_argument( + '--output_dir', + type=str, + default='sdxl-model-finetuned', + help= + 'The output directory where the model predictions and checkpoints will be written.', + ) + parser.add_argument( + '--cache_dir', + type=str, + default=None, + help= + 'The directory where the downloaded models and datasets will be stored.', + ) + parser.add_argument( + '--seed', + type=int, + default=None, + help='A seed for reproducible training.') + parser.add_argument( + '--resolution', + type=int, + default=1024, + help= + ('The resolution for input images, all the images in the train/validation dataset will be resized to this' + ' resolution'), + ) + parser.add_argument( + '--center_crop', + default=False, + action='store_true', + help= + ('Whether to center crop the input images to the resolution. If not set, the images will be randomly' + ' cropped. The images will be resized to the resolution first before cropping.' + ), + ) + parser.add_argument( + '--random_flip', + action='store_true', + help='whether to randomly flip images horizontally', + ) + parser.add_argument( + '--train_batch_size', + type=int, + default=16, + help='Batch size (per device) for the training dataloader.') + parser.add_argument('--num_train_epochs', type=int, default=100) + parser.add_argument( + '--max_train_steps', + type=int, + default=None, + help= + 'Total number of training steps to perform. If provided, overrides num_train_epochs.', + ) + parser.add_argument( + '--checkpointing_steps', + type=int, + default=500, + help= + ('Save a checkpoint of the training state every X updates. These checkpoints can be used both as final' + ' checkpoints in case they are better than the last checkpoint, and are also suitable for resuming' + ' training using `--resume_from_checkpoint`.'), + ) + parser.add_argument( + '--checkpoints_total_limit', + type=int, + default=None, + help=('Max number of checkpoints to store.'), + ) + parser.add_argument( + '--resume_from_checkpoint', + type=str, + default=None, + help= + ('Whether training should be resumed from a previous checkpoint. Use a path saved by' + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + '--gradient_accumulation_steps', + type=int, + default=1, + help= + 'Number of updates steps to accumulate before performing a backward/update pass.', + ) + parser.add_argument( + '--gradient_checkpointing', + action='store_true', + help= + 'Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.', + ) + parser.add_argument( + '--learning_rate', + type=float, + default=1e-4, + help= + 'Initial learning rate (after the potential warmup period) to use.', + ) + parser.add_argument( + '--scale_lr', + action='store_true', + default=False, + help= + 'Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.', + ) + parser.add_argument( + '--lr_scheduler', + type=str, + default='constant', + help= + ('The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument( + '--lr_warmup_steps', + type=int, + default=500, + help='Number of steps for the warmup in the lr scheduler.') + parser.add_argument( + '--timestep_bias_strategy', + type=str, + default='none', + choices=['earlier', 'later', 'range', 'none'], + help= + ('The timestep bias strategy, which may help direct the model toward learning low or high frequency details.' + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + '--timestep_bias_multiplier', + type=float, + default=1.0, + help= + ('The multiplier for the bias. Defaults to 1.0, which means no bias is applied.' + ' A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it.' + ), + ) + parser.add_argument( + '--timestep_bias_begin', + type=int, + default=0, + help= + ('When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias.' + ' Defaults to zero, which equates to having no specific bias.'), + ) + parser.add_argument( + '--timestep_bias_end', + type=int, + default=1000, + help= + ('When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias.' + ' Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on.' + ), + ) + parser.add_argument( + '--timestep_bias_portion', + type=float, + default=0.25, + help= + ('The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased.' + ' A value of 0.5 will bias one half of the timesteps. ' + 'The value provided for `--timestep_bias_strategy` determines' + ' whether the biased portions are in the earlier or later timesteps.' + ), + ) + parser.add_argument( + '--snr_gamma', + type=float, + default=None, + help= + 'SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. ' + 'More details here: https://arxiv.org/abs/2303.09556.', + ) + parser.add_argument( + '--use_ema', action='store_true', help='Whether to use EMA model.') + parser.add_argument( + '--allow_tf32', + action='store_true', + help= + ('Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see' + ' https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices' + ), + ) + parser.add_argument( + '--dataloader_num_workers', + type=int, + default=0, + help= + ('Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.' 
+ ), + ) + parser.add_argument( + '--use_8bit_adam', + action='store_true', + help='Whether or not to use 8-bit Adam from bitsandbytes.') + parser.add_argument( + '--adam_beta1', + type=float, + default=0.9, + help='The beta1 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_beta2', + type=float, + default=0.999, + help='The beta2 parameter for the Adam optimizer.') + parser.add_argument( + '--adam_weight_decay', + type=float, + default=1e-2, + help='Weight decay to use.') + parser.add_argument( + '--adam_epsilon', + type=float, + default=1e-08, + help='Epsilon value for the Adam optimizer') + parser.add_argument( + '--max_grad_norm', default=1.0, type=float, help='Max gradient norm.') + parser.add_argument( + '--push_to_hub', + action='store_true', + help='Whether or not to push the model to the Hub.') + parser.add_argument( + '--hub_token', + type=str, + default=None, + help='The token to use to push to the Model Hub.') + parser.add_argument( + '--prediction_type', + type=str, + default=None, + help= + "The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or \ + leave `None`. If left to `None` the default prediction type of the scheduler: \ + `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + '--hub_model_id', + type=str, + default=None, + help= + 'The name of the repository to keep in sync with the local `output_dir`.', + ) + parser.add_argument( + '--logging_dir', + type=str, + default='logs', + help= + ('[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to' + ' *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.'), + ) + parser.add_argument( + '--report_to', + type=str, + default='tensorboard', + help= + ('The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + '--mixed_precision', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the' + ' flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.' 
+ ), + ) + parser.add_argument( + '--local_rank', + type=int, + default=-1, + help='For distributed training: local_rank') + parser.add_argument( + '--enable_xformers_memory_efficient_attention', + action='store_true', + help='Whether or not to use xformers.') + parser.add_argument( + '--noise_offset', + type=float, + default=0, + help='The scale of noise offset.') + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get('LOCAL_RANK', -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError('Need either a dataset name or a training folder.') + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError( + '`--proportion_empty_prompts` must be in the range [0, 1].') + + args.base_model_id = args.pretrained_model_name_or_path + if not os.path.exists(args.pretrained_model_name_or_path): + args.pretrained_model_name_or_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + + args.vae_base_model_id = args.pretrained_vae_model_name_or_path + if args.pretrained_vae_model_name_or_path and not os.path.exists( + args.pretrained_vae_model_name_or_path): + args.pretrained_vae_model_name_or_path = snapshot_download( + args.pretrained_vae_model_name_or_path) + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(batch, + text_encoders, + tokenizers, + proportion_empty_prompts, + caption_column, + is_train=True): + prompt_embeds_list = [] + prompt_batch = batch[caption_column] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append('') + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding='max_length', + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors='pt', + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return { + 'prompt_embeds': prompt_embeds.cpu(), + 'pooled_prompt_embeds': pooled_prompt_embeds.cpu() + } + + +def compute_vae_encodings(batch, vae): + images = batch.pop('pixel_values') + pixel_values = torch.stack(list(images)) + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to(vae.device, dtype=vae.dtype) + + with torch.no_grad(): + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + return {'model_input': model_input.cpu()} + + +def generate_timestep_weights(args, num_timesteps): + 
weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == 'later': + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == 'earlier': + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == 'range': + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + 'When using the range strategy for timestep bias, you must provide a beginning timestep greater \ + or equal to zero.') + if range_end > num_timesteps: + raise ValueError( + 'When using the range strategy for timestep bias, you must provide an ending timestep smaller than \ + the number of timesteps.') + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + 'The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific ' + 'timesteps.' + ' If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead.' + ' A timestep bias multiplier less than or equal to 0 is not allowed.' + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == 'wandb': + if not is_wandb_available(): + raise ImportError( + 'Make sure to install wandb if you want to use it for logging during training.' + ) + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
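`generate_timestep_weights` above turns the `--timestep_bias_*` flags into a multinomial distribution over timesteps. A standalone sketch of the `'later'` strategy with illustrative settings:

```python
import torch

# Bias the last 25% of timesteps by a factor of 2, then sample from the
# resulting distribution (values are illustrative, not script defaults).
num_timesteps = 1000
num_to_bias = int(0.25 * num_timesteps)          # --timestep_bias_portion
weights = torch.ones(num_timesteps)
weights[slice(-num_to_bias, None)] *= 2.0        # --timestep_bias_multiplier
weights /= weights.sum()

bsz = 8
timesteps = torch.multinomial(weights, bsz, replacement=True).long()
```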
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='tokenizer', + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='tokenizer_2', + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, + args.revision, + subfolder='text_encoder_2') + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') + # Check for terminal SNR in combination with SNR Gamma + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='text_encoder', + revision=args.revision, + variant=args.variant) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='text_encoder_2', + revision=args.revision, + variant=args.variant) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None else + args.pretrained_vae_model_name_or_path) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder='vae' + if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='unet', + revision=args.revision, + variant=args.variant) + + # Freeze vae and text encoders. + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + # Set unet as trainable. + unet.train() + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == 'fp16': + weight_dtype = torch.float16 + elif accelerator.mixed_precision == 'bf16': + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder='unet', + revision=args.revision, + variant=args.variant) + ema_unet = EMAModel( + ema_unet.parameters(), + model_cls=UNet2DConditionModel, + model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse('0.0.16'): + logger.warn( + 'xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training,' + ' please update xFormers to at least 0.0.17. 
' + 'See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.' + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + 'xformers is not available. Make sure it is installed correctly' + ) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse('0.16.0'): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained( + os.path.join(output_dir, 'unet_ema')) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, 'unet')) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained( + os.path.join(input_dir, 'unet_ema'), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained( + input_dir, subfolder='unet') + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps + * args.train_batch_size * accelerator.num_processes) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + 'To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.' + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = unet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + def path_to_img(example): + example['image'] = Image.open(example['image:FILE']) + return example + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
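For reference, a hedged sketch of what this ModelScope dataset loading is expected to produce for the pokemon captions dataset; depending on the ModelScope version, `MsDataset.load` may return a single subset or a dict of splits, which is why both cases are handled:

```python
from modelscope import MsDataset

raw = MsDataset.load('AI-ModelScope/pokemon-blip-captions')
if isinstance(raw, dict):
    dataset = {key: value.to_hf_dataset() for key, value in raw.items()}
else:
    dataset = {'train': raw.to_hf_dataset()}

# Per DATASET_NAME_MAPPING, the caption column is 'text' and the image column
# is 'image:FILE' (file paths that path_to_img() opens into PIL images).
print(dataset['train'].column_names)   # expected: ['text', 'image:FILE']
```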
+ dataset = MsDataset.load( + args.dataset_name, + args.dataset_config_name, + data_dir=args.train_data_dir, + ) + if isinstance(dataset, dict): + dataset = { + key: value.to_hf_dataset() + for key, value in dataset.items() + } + else: + dataset = {'train': dataset.to_hf_dataset()} + else: + data_files = {} + if args.train_data_dir is not None: + data_files['train'] = os.path.join(args.train_data_dir, '**') + dataset = load_dataset( + 'imagefolder', + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset['train'].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[ + 1] if dataset_columns is not None else column_names[1] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[ + 0] if dataset_columns is not None else column_names[0] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + if image_column.endswith(':FILE'): + dataset['train'] = dataset['train'].map(path_to_img) + image_column = 'image' + + # Preprocessing the datasets. + train_resize = transforms.Resize( + args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop( + args.resolution) if args.center_crop else transforms.RandomCrop( + args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize([0.5], [0.5])]) + + def preprocess_train(examples): + images = [image.convert('RGB') for image in examples[image_column]] + # image aug + original_sizes = [] + all_images = [] + crop_top_lefts = [] + for image in images: + original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + # flip + x1 = image.width - x1 + image = train_flip(image) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + all_images.append(image) + + examples['original_sizes'] = original_sizes + examples['crop_top_lefts'] = crop_top_lefts + examples['pixel_values'] = all_images + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset['train'] = dataset['train'].shuffle(seed=args.seed).select( + range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset['train'].with_transform(preprocess_train) + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. We will pre-compute the VAE encodings too. 
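`preprocess_train` above records `original_sizes` and `crop_top_lefts` for SDXL's size/crop micro-conditioning. A small sketch of the crop bookkeeping on a synthetic image that has already been resized so its short side equals the target resolution:

```python
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop

resolution = 1024
image = Image.new('RGB', (1365, 1024))   # synthetic, short side == resolution

train_crop = transforms.RandomCrop(resolution)
y1, x1, h, w = train_crop.get_params(image, (resolution, resolution))
image = crop(image, y1, x1, h, w)

crop_top_left = (y1, x1)   # later packed into add_time_ids for the UNet
```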
+ text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + compute_embeddings_fn = functools.partial( + encode_prompt, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + caption_column=args.caption_column, + ) + compute_vae_encodings_fn = functools.partial( + compute_vae_encodings, vae=vae) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + new_fingerprint_for_vae = Hasher.hash('vae') + train_dataset = train_dataset.map( + compute_embeddings_fn, + batched=True, + new_fingerprint=new_fingerprint) + train_dataset = train_dataset.map( + compute_vae_encodings_fn, + batched=True, + batch_size=args.train_batch_size * accelerator.num_processes + * args.gradient_accumulation_steps, + new_fingerprint=new_fingerprint_for_vae, + ) + + del text_encoders, tokenizers, vae + gc.collect() + torch.cuda.empty_cache() + + def collate_fn(examples): + model_input = torch.stack( + [torch.tensor(example['model_input']) for example in examples]) + original_sizes = [example['original_sizes'] for example in examples] + crop_top_lefts = [example['crop_top_lefts'] for example in examples] + prompt_embeds = torch.stack( + [torch.tensor(example['prompt_embeds']) for example in examples]) + pooled_prompt_embeds = torch.stack([ + torch.tensor(example['pooled_prompt_embeds']) + for example in examples + ]) + + return { + 'model_input': model_input, + 'prompt_embeds': prompt_embeds, + 'pooled_prompt_embeds': pooled_prompt_embeds, + 'original_sizes': original_sizes, + 'crop_top_lefts': crop_top_lefts, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps + * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps + * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps + / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. 
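The scheduler bookkeeping above derives the number of optimizer updates per epoch and, from that, `max_train_steps`. A quick standalone check with made-up numbers:

```python
import math

# Hypothetical sizes, only to illustrate the bookkeeping.
num_batches_per_epoch = 250          # len(train_dataloader)
gradient_accumulation_steps = 4
num_train_epochs = 10

num_update_steps_per_epoch = math.ceil(
    num_batches_per_epoch / gradient_accumulation_steps)          # 63
max_train_steps = num_train_epochs * num_update_steps_per_epoch   # 630
```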
+ if accelerator.is_main_process: + accelerator.init_trackers( + 'text2image-fine-tune-sdxl', config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info('***** Running training *****') + logger.info(f' Num examples = {len(train_dataset)}') + logger.info(f' Num Epochs = {args.num_train_epochs}') + logger.info( + f' Instantaneous batch size per device = {args.train_batch_size}') + logger.info( + f' Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}' + ) + logger.info( + f' Gradient Accumulation steps = {args.gradient_accumulation_steps}') + logger.info(f' Total optimization steps = {args.max_train_steps}') + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != 'latest': + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith('checkpoint')] + dirs = sorted(dirs, key=lambda x: int(x.split('-')[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f'Resuming from checkpoint {path}') + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split('-')[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc='Steps', + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Sample noise that we'll add to the latents + model_input = batch['model_input'].to(accelerator.device) + noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), + device=model_input.device) + + bsz = model_input.shape[0] + if args.timestep_bias_strategy == 'none': + # Sample a random timestep for each image without bias. + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, (bsz, ), + device=model_input.device) + else: + # Sample a random timestep for each image, potentially biased by the timestep weights. + # Biasing the timestep weights allows us to spend less time training irrelevant timesteps. 
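The `--noise_offset` branch in the loop above adds a per-sample, per-channel constant to the Gaussian noise (see the linked crosslabs post), which helps the model shift overall image brightness. A minimal sketch with SDXL-shaped latents (4 channels, 1024/8 = 128 spatial):

```python
import torch

noise_offset = 0.05                      # hypothetical --noise_offset value
noise = torch.randn(4, 4, 128, 128)      # (batch, latent channels, h/8, w/8)
noise = noise + noise_offset * torch.randn(4, 4, 1, 1)   # broadcast per sample/channel
```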
+ weights = generate_timestep_weights( + args, noise_scheduler.config.num_train_timesteps).to( + model_input.device) + timesteps = torch.multinomial( + weights, bsz, replacement=True).long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise( + model_input, noise, timesteps) + + # time ids + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to( + accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat([ + compute_time_ids(s, c) for s, c in zip( + batch['original_sizes'], batch['crop_top_lefts']) + ]) + + # Predict the noise residual + unet_added_conditions = {'time_ids': add_time_ids} + prompt_embeds = batch['prompt_embeds'].to(accelerator.device) + pooled_prompt_embeds = batch['pooled_prompt_embeds'].to( + accelerator.device) + unet_added_conditions.update( + {'text_embeds': pooled_prompt_embeds}) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions).sample + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config( + prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == 'epsilon': + target = noise + elif noise_scheduler.config.prediction_type == 'v_prediction': + target = noise_scheduler.get_velocity( + model_input, noise, timesteps) + elif noise_scheduler.config.prediction_type == 'sample': + # We set the target to latents here, but the model_pred will return the noise sample prediction. + target = model_input + # We will have to subtract the noise residual from the prediction to get the target sample. + model_pred = model_pred - noise + else: + raise ValueError( + f'Unknown prediction type {noise_scheduler.config.prediction_type}' + ) + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='mean') + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == 'v_prediction': + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], + dim=1).min(dim=1)[0] / snr) + + loss = F.mse_loss( + model_pred.float(), target.float(), reduction='none') + loss = loss.mean( + dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
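`compute_time_ids` above builds SDXL's additional conditioning vector from the preprocessing metadata. A standalone sketch of the layout for one sample, with hypothetical sizes:

```python
import torch

original_size = (960, 1280)     # (height, width) before resizing
crop_top_left = (0, 164)        # (top, left) recorded during cropping
target_size = (1024, 1024)      # (args.resolution, args.resolution)

add_time_ids = torch.tensor([list(original_size + crop_top_left + target_size)])
print(add_time_ids)             # tensor([[ 960, 1280,    0,  164, 1024, 1024]])
print(add_time_ids.shape)       # torch.Size([1, 6])
```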
+ avg_loss = accelerator.gather( + loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item( + ) / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, + args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({'train_loss': train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints + if d.startswith('checkpoint') + ] + checkpoints = sorted( + checkpoints, + key=lambda x: int(x.split('-')[1])) + + # before we save the new checkpoint, we need to have at _most_ \ + # `checkpoints_total_limit - 1` checkpoints + if len(checkpoints + ) >= args.checkpoints_total_limit: + num_to_remove = len( + checkpoints + ) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[ + 0:num_to_remove] + + logger.info( + f'{len(checkpoints)} checkpoints already exist, ' + f'removing {len(removing_checkpoints)} checkpoints' + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir,f'checkpoint-{global_step}') + accelerator.save_state(save_path) + logger.info(f'Saved state to {save_path}') + + logs = { + 'step_loss': loss.detach().item(), + 'lr': lr_scheduler.get_last_lr()[0] + } + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f'Running validation... \n Generating {args.num_validation_images} images with prompt:' + f' {args.validation_prompt}.') + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
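Validation below temporarily swaps the EMA weights into the UNet. For reference, a minimal sketch (with a stand-in module) of the usual store / copy_to / restore round-trip offered by `diffusers.training_utils.EMAModel`:

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)        # stand-in for the UNet
ema = EMAModel(model.parameters())

for _ in range(10):                  # normally once per optimizer step
    ema.step(model.parameters())

ema.store(model.parameters())        # keep the live training weights
ema.copy_to(model.parameters())      # evaluate / serialize with EMA weights
ema.restore(model.parameters())      # switch back before training continues
```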
+ ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + # create pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder='vae' if + args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if args.prediction_type is not None: + scheduler_args = {'prediction_type': args.prediction_type} + pipeline.scheduler = pipeline.scheduler.from_config( + pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator( + device=accelerator.device).manual_seed( + args.seed) if args.seed else None + pipeline_args = {'prompt': args.validation_prompt} + + with torch.cuda.amp.autocast(): + images = [ + pipeline( + **pipeline_args, + generator=generator, + num_inference_steps=25).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == 'tensorboard': + np_images = np.stack( + [np.asarray(img) for img in images]) + tracker.writer.add_images( + 'validation', np_images, epoch, dataformats='NHWC') + if tracker.name == 'wandb': + tracker.log({ + 'validation': [ + wandb.Image( + image, + caption=f'{i}: {args.validation_prompt}') + for i, image in enumerate(images) + ] + }) + + del pipeline + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + # Serialize pipeline. 
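# --- Illustrative aside (not part of the patch): how `--prediction_type` is pushed
# into the scheduler both for the validation pipeline above and the serialized
# pipeline below. `from_config` rebuilds a scheduler from its own config with
# selected fields overridden. Standalone sketch; the real script reuses the
# pipeline's existing scheduler instead of loading one separately.
from diffusers import DDPMScheduler
from modelscope import snapshot_download

model_path = snapshot_download('AI-ModelScope/stable-diffusion-xl-base-1.0')
scheduler = DDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
scheduler = scheduler.from_config(scheduler.config, prediction_type='v_prediction')
print(scheduler.config.prediction_type)  # -> 'v_prediction'
# --- end of aside ---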
+ vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder='vae' + if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unet, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if args.prediction_type is not None: + scheduler_args = {'prediction_type': args.prediction_type} + pipeline.scheduler = pipeline.scheduler.from_config( + pipeline.scheduler.config, **scheduler_args) + pipeline.save_pretrained(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed( + args.seed) if args.seed else None + with torch.cuda.amp.autocast(): + images = [ + pipeline( + args.validation_prompt, + num_inference_steps=25, + generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == 'tensorboard': + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + 'test', np_images, epoch, dataformats='NHWC') + if tracker.name == 'wandb': + tracker.log({ + 'test': [ + wandb.Image( + image, + caption=f'{i}: {args.validation_prompt}') + for i, image in enumerate(images) + ] + }) + + if args.push_to_hub: + save_model_card( + repo_id=args.hub_model_id, + images=images, + validation_prompt=args.validation_prompt, + base_model=args.base_model_id, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + vae_path=args.vae_base_model_id, + ) + push_to_hub( + args.hub_model_id, + args.output_dir, + args.hub_token, + ) + + accelerator.end_training() From 672c7324cb679453ce77394e573dabf500c92944 Mon Sep 17 00:00:00 2001 From: slin000111 Date: Tue, 9 Jan 2024 20:51:56 +0800 Subject: [PATCH 2/5] add examples aigc sdxl --- .../pytorch/sdxl/infer_text_image_lora.py | 17 ++++++------- examples/pytorch/sdxl/infer_text_to_image.py | 16 +++++++------ .../sdxl/infer_text_to_image_lora_sdxl.py | 24 +++++++++++-------- .../pytorch/sdxl/infer_text_to_image_sdxl.py | 10 ++++---- swift/aigc/__init__.py | 6 +++-- swift/aigc/diffusers/__init__.py | 7 +++--- .../diffusers/train_text_to_image_lora.py | 2 +- .../train_text_to_image_lora_sdxl.py | 2 +- .../diffusers/train_text_to_image_sdxl.py | 3 ++- 9 files changed, 49 insertions(+), 38 deletions(-) diff --git a/examples/pytorch/sdxl/infer_text_image_lora.py b/examples/pytorch/sdxl/infer_text_image_lora.py index 248d47525d..f068b2c15d 100644 --- a/examples/pytorch/sdxl/infer_text_image_lora.py +++ b/examples/pytorch/sdxl/infer_text_image_lora.py @@ -1,15 +1,16 @@ -from diffusers import StableDiffusionPipeline import torch -from swift import Swift +from diffusers import StableDiffusionPipeline from modelscope import snapshot_download +from swift import Swift -model_path = snapshot_download("AI-ModelScope/stable-diffusion-v1-5") -lora_model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image_lora" -pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') +lora_model_path = '/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image_lora' +pipe = StableDiffusionPipeline.from_pretrained( + model_path, torch_dtype=torch.float16) pipe.unet = 
Swift.from_pretrained(pipe.unet, lora_model_path) -pipe.to("cuda") +pipe.to('cuda') -prompt = "A pokemon with green eyes and red legs." +prompt = 'A pokemon with green eyes and red legs.' image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save("sw_sd_lora_pokemon.png") +image.save('sw_sd_lora_pokemon.png') diff --git a/examples/pytorch/sdxl/infer_text_to_image.py b/examples/pytorch/sdxl/infer_text_to_image.py index c0d7994615..00ef4cc143 100644 --- a/examples/pytorch/sdxl/infer_text_to_image.py +++ b/examples/pytorch/sdxl/infer_text_to_image.py @@ -2,13 +2,15 @@ from diffusers import StableDiffusionPipeline, UNet2DConditionModel from modelscope import snapshot_download -model_path = snapshot_download("AI-ModelScope/stable-diffusion-v1-5") +model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') -unet_model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image/unet" -unet = UNet2DConditionModel.from_pretrained(unet_model_path, torch_dtype=torch.float16) +unet_model_path = '/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image/unet' +unet = UNet2DConditionModel.from_pretrained( + unet_model_path, torch_dtype=torch.float16) -pipe = StableDiffusionPipeline.from_pretrained(model_path, unet=unet, torch_dtype=torch.float16) -pipe.to("cuda") +pipe = StableDiffusionPipeline.from_pretrained( + model_path, unet=unet, torch_dtype=torch.float16) +pipe.to('cuda') -image = pipe(prompt="yoda").images[0] -image.save("sw_yoda-pokemon.png") +image = pipe(prompt='yoda').images[0] +image.save('sw_yoda-pokemon.png') diff --git a/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py b/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py index f2a1e15d9f..f6e4d53bc8 100644 --- a/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py +++ b/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py @@ -1,15 +1,19 @@ -from diffusers import DiffusionPipeline, StableDiffusionXLPipeline -import torch -from swift import Swift import os + +import torch +from diffusers import DiffusionPipeline, StableDiffusionXLPipeline from modelscope import snapshot_download -model_path = snapshot_download("AI-ModelScope/stable-diffusion-v1-5") -lora_model_path = "/mnt/workspace/swift_trans_test/examples/pytorch/sdxl/train_text_to_image_lora_sdxl" +from swift import Swift + +model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') +lora_model_path = '/mnt/workspace/swift_trans_test/examples/pytorch/sdxl/train_text_to_image_lora_sdxl' -pipe = StableDiffusionXLPipeline.from_pretrained(model_path, torch_dtype=torch.float16) -pipe = pipe.to("cuda") -pipe.unet = Swift.from_pretrained(pipe.unet, os.path.join(lora_model_path, 'unet')) -prompt = "A pokemon with green eyes and red legs." +pipe = StableDiffusionXLPipeline.from_pretrained( + model_path, torch_dtype=torch.float16) +pipe = pipe.to('cuda') +pipe.unet = Swift.from_pretrained(pipe.unet, + os.path.join(lora_model_path, 'unet')) +prompt = 'A pokemon with green eyes and red legs.' 
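# --- Illustrative note (not part of the patch): the SDXL LoRA example above builds a
# StableDiffusionXLPipeline from the 'AI-ModelScope/stable-diffusion-v1-5' snapshot,
# which looks like a leftover from the SD-1.5 example; the run_infer_text_to_image_lora_sdxl.sh
# script added in the final commit of this series points at the SDXL base instead.
# A sketch of the corresponding load, assuming that base model:
#
#     model_path = snapshot_download('AI-ModelScope/stable-diffusion-xl-base-1.0')
#     pipe = StableDiffusionXLPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
# --- end of note ---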
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save("sw_sdxl_lora_pokemon.png") +image.save('sw_sdxl_lora_pokemon.png') diff --git a/examples/pytorch/sdxl/infer_text_to_image_sdxl.py b/examples/pytorch/sdxl/infer_text_to_image_sdxl.py index 1d7cfa8305..03721742e5 100644 --- a/examples/pytorch/sdxl/infer_text_to_image_sdxl.py +++ b/examples/pytorch/sdxl/infer_text_to_image_sdxl.py @@ -1,10 +1,10 @@ -from diffusers import DiffusionPipeline import torch +from diffusers import DiffusionPipeline -model_path = "/mnt/workspace/swift/examples/pytorch/sdxl/sdxl-pokemon-model" +model_path = '/mnt/workspace/swift/examples/pytorch/sdxl/sdxl-pokemon-model' pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) -pipe.to("cuda") +pipe.to('cuda') -prompt = "A pokemon with green eyes and red legs." +prompt = 'A pokemon with green eyes and red legs.' image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save("sdxl_pokemon.png") +image.save('sdxl_pokemon.png') diff --git a/swift/aigc/__init__.py b/swift/aigc/__init__.py index d5a0439e8b..fa54269279 100644 --- a/swift/aigc/__init__.py +++ b/swift/aigc/__init__.py @@ -14,8 +14,10 @@ _import_structure = { 'animatediff': ['animatediff_sft', 'animatediff_main'], 'animatediff_infer': ['animatediff_infer', 'animatediff_infer_main'], - 'diffusers': ['train_text_to_image', 'train_text_to_image_lora', 'train_text_to_image_lora_sdxl', - 'train_text_to_image_sdxl'], + 'diffusers': [ + 'train_text_to_image', 'train_text_to_image_lora', + 'train_text_to_image_lora_sdxl', 'train_text_to_image_sdxl' + ], 'utils': ['AnimateDiffArguments', 'AnimateDiffInferArguments'], } diff --git a/swift/aigc/diffusers/__init__.py b/swift/aigc/diffusers/__init__.py index 09de856eab..592852c99c 100644 --- a/swift/aigc/diffusers/__init__.py +++ b/swift/aigc/diffusers/__init__.py @@ -1,4 +1,5 @@ -from .train_text_to_image_sdxl import main as train_text_to_image_sdxl -from .train_text_to_image_lora_sdxl import main as train_text_to_image_lora_sdxl from .train_text_to_image import main as train_text_to_image -from .train_text_to_image_lora import main as train_text_to_image_lora \ No newline at end of file +from .train_text_to_image_lora import main as train_text_to_image_lora +from .train_text_to_image_lora_sdxl import \ + main as train_text_to_image_lora_sdxl +from .train_text_to_image_sdxl import main as train_text_to_image_sdxl diff --git a/swift/aigc/diffusers/train_text_to_image_lora.py b/swift/aigc/diffusers/train_text_to_image_lora.py index bae08b1be5..51ec736071 100644 --- a/swift/aigc/diffusers/train_text_to_image_lora.py +++ b/swift/aigc/diffusers/train_text_to_image_lora.py @@ -568,7 +568,7 @@ def main(): text_encoder.to(accelerator.device, dtype=weight_dtype) unet = Swift.prepare_model(unet, unet_lora_config) - if args.mixed_precision == "fp16": + if args.mixed_precision == 'fp16': for param in unet.parameters(): # only upcast trainable parameters (LoRA) into fp32 if param.requires_grad: diff --git a/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py b/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py index 735c680697..db8e176516 100644 --- a/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py +++ b/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py @@ -741,7 +741,7 @@ def main(): target_modules=['to_k', 'to_q', 'to_v', 'to_out.0']) unet = Swift.prepare_model(unet, unet_lora_config) - if args.mixed_precision == "fp16": + if args.mixed_precision == 'fp16': for param in 
unet.parameters(): # only upcast trainable parameters (LoRA) into fp32 if param.requires_grad: diff --git a/swift/aigc/diffusers/train_text_to_image_sdxl.py b/swift/aigc/diffusers/train_text_to_image_sdxl.py index efef370846..1b3f4a4d8d 100644 --- a/swift/aigc/diffusers/train_text_to_image_sdxl.py +++ b/swift/aigc/diffusers/train_text_to_image_sdxl.py @@ -1304,7 +1304,8 @@ def compute_time_ids(original_size, crops_coords_top_left): args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir,f'checkpoint-{global_step}') + save_path = os.path.join(args.output_dir, + f'checkpoint-{global_step}') accelerator.save_state(save_path) logger.info(f'Saved state to {save_path}') From 1550ea465d8ba58111ca3d26973ae2484079993d Mon Sep 17 00:00:00 2001 From: slin000111 Date: Tue, 9 Jan 2024 21:30:50 +0800 Subject: [PATCH 3/5] add diffusers version --- examples/pytorch/sdxl/requirements.txt | 1 + examples/pytorch/sdxl/requirements_sdxl.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/pytorch/sdxl/requirements.txt b/examples/pytorch/sdxl/requirements.txt index e92fe73913..f6535bb514 100644 --- a/examples/pytorch/sdxl/requirements.txt +++ b/examples/pytorch/sdxl/requirements.txt @@ -1,5 +1,6 @@ accelerate>=0.16.0 datasets +diffusers==0.25.0 ftfy Jinja2 tensorboard diff --git a/examples/pytorch/sdxl/requirements_sdxl.txt b/examples/pytorch/sdxl/requirements_sdxl.txt index 6e50452b96..16ef8946b5 100644 --- a/examples/pytorch/sdxl/requirements_sdxl.txt +++ b/examples/pytorch/sdxl/requirements_sdxl.txt @@ -1,5 +1,6 @@ accelerate>=0.22.0 datasets +diffusers==0.25.0 ftfy Jinja2 tensorboard From 2babe53d892dfadcbc1523d3aebfeb5e83363ba4 Mon Sep 17 00:00:00 2001 From: slin000111 Date: Tue, 9 Jan 2024 23:12:34 +0800 Subject: [PATCH 4/5] fix sdxl sh file --- examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh | 3 ++- examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh b/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh index 8e163b1cca..5a4397da24 100644 --- a/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh +++ b/examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh @@ -1,5 +1,5 @@ PYTHONPATH=../../../ \ -accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ +accelerate launch train_text_to_image_lora.py \ --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \ --dataset_name="AI-ModelScope/pokemon-blip-captions" \ --caption_column="text" \ @@ -11,6 +11,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ --learning_rate=1e-04 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ + --mixed_precision="fp16" \ --seed=42 \ --output_dir="train_text_to_image_lora" \ --validation_prompt="cute dragon creature" \ diff --git a/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh b/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh index 2126961c4c..6094141542 100644 --- a/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh +++ b/examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh @@ -1,6 +1,6 @@ PYTHONPATH=../../../ \ accelerate launch train_text_to_image_sdxl.py \ - --pretrained_model_name_or_path"AI-ModelScope/stable-diffusion-xl-base-1.0" \ + --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \ 
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \ --dataset_name="AI-ModelScope/pokemon-blip-captions" \ --enable_xformers_memory_efficient_attention \ From ea808aabaaf48ff7bd7ccdcb70407554070cebbd Mon Sep 17 00:00:00 2001 From: slin000111 Date: Thu, 11 Jan 2024 23:05:14 +0800 Subject: [PATCH 5/5] modify inference scripts --- .../pytorch/sdxl/infer_text_image_lora.py | 19 +-- examples/pytorch/sdxl/infer_text_to_image.py | 19 +-- .../sdxl/infer_text_to_image_lora_sdxl.py | 22 +--- .../pytorch/sdxl/infer_text_to_image_sdxl.py | 13 +- examples/pytorch/sdxl/requirements.txt | 8 -- examples/pytorch/sdxl/requirements_sdxl.txt | 8 -- .../sdxl/scripts/run_infer_text_to_image.sh | 8 ++ .../scripts/run_infer_text_to_image_lora.sh | 8 ++ .../run_infer_text_to_image_lora_sdxl.sh | 8 ++ .../scripts/run_infer_text_to_image_sdxl.sh | 8 ++ requirements/aigc.txt | 2 +- swift/aigc/__init__.py | 7 +- swift/aigc/diffusers/__init__.py | 5 + swift/aigc/diffusers/infer_text_to_image.py | 111 +++++++++++++++++ .../diffusers/infer_text_to_image_lora.py | 112 ++++++++++++++++++ .../infer_text_to_image_lora_sdxl.py | 112 ++++++++++++++++++ .../diffusers/infer_text_to_image_sdxl.py | 111 +++++++++++++++++ swift/aigc/diffusers/train_text_to_image.py | 4 +- .../diffusers/train_text_to_image_lora.py | 4 +- .../train_text_to_image_lora_sdxl.py | 4 +- .../diffusers/train_text_to_image_sdxl.py | 4 +- 21 files changed, 513 insertions(+), 84 deletions(-) delete mode 100644 examples/pytorch/sdxl/requirements.txt delete mode 100644 examples/pytorch/sdxl/requirements_sdxl.txt create mode 100644 examples/pytorch/sdxl/scripts/run_infer_text_to_image.sh create mode 100644 examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh create mode 100644 examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora_sdxl.sh create mode 100644 examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh create mode 100644 swift/aigc/diffusers/infer_text_to_image.py create mode 100644 swift/aigc/diffusers/infer_text_to_image_lora.py create mode 100644 swift/aigc/diffusers/infer_text_to_image_lora_sdxl.py create mode 100644 swift/aigc/diffusers/infer_text_to_image_sdxl.py diff --git a/examples/pytorch/sdxl/infer_text_image_lora.py b/examples/pytorch/sdxl/infer_text_image_lora.py index f068b2c15d..8956a3f726 100644 --- a/examples/pytorch/sdxl/infer_text_image_lora.py +++ b/examples/pytorch/sdxl/infer_text_image_lora.py @@ -1,16 +1,5 @@ -import torch -from diffusers import StableDiffusionPipeline -from modelscope import snapshot_download +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.aigc import infer_text_to_image_lora -from swift import Swift - -model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') -lora_model_path = '/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image_lora' -pipe = StableDiffusionPipeline.from_pretrained( - model_path, torch_dtype=torch.float16) -pipe.unet = Swift.from_pretrained(pipe.unet, lora_model_path) -pipe.to('cuda') - -prompt = 'A pokemon with green eyes and red legs.' 
-image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save('sw_sd_lora_pokemon.png') +if __name__ == '__main__': + infer_text_to_image_lora() diff --git a/examples/pytorch/sdxl/infer_text_to_image.py b/examples/pytorch/sdxl/infer_text_to_image.py index 00ef4cc143..d9b3563617 100644 --- a/examples/pytorch/sdxl/infer_text_to_image.py +++ b/examples/pytorch/sdxl/infer_text_to_image.py @@ -1,16 +1,5 @@ -import torch -from diffusers import StableDiffusionPipeline, UNet2DConditionModel -from modelscope import snapshot_download +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.aigc import infer_text_to_image -model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') - -unet_model_path = '/mnt/workspace/swift/examples/pytorch/sdxl/train_text_to_image/unet' -unet = UNet2DConditionModel.from_pretrained( - unet_model_path, torch_dtype=torch.float16) - -pipe = StableDiffusionPipeline.from_pretrained( - model_path, unet=unet, torch_dtype=torch.float16) -pipe.to('cuda') - -image = pipe(prompt='yoda').images[0] -image.save('sw_yoda-pokemon.png') +if __name__ == '__main__': + infer_text_to_image() diff --git a/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py b/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py index f6e4d53bc8..c67b859a65 100644 --- a/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py +++ b/examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py @@ -1,19 +1,5 @@ -import os +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.aigc import infer_text_to_image_lora_sdxl -import torch -from diffusers import DiffusionPipeline, StableDiffusionXLPipeline -from modelscope import snapshot_download - -from swift import Swift - -model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') -lora_model_path = '/mnt/workspace/swift_trans_test/examples/pytorch/sdxl/train_text_to_image_lora_sdxl' - -pipe = StableDiffusionXLPipeline.from_pretrained( - model_path, torch_dtype=torch.float16) -pipe = pipe.to('cuda') -pipe.unet = Swift.from_pretrained(pipe.unet, - os.path.join(lora_model_path, 'unet')) -prompt = 'A pokemon with green eyes and red legs.' -image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save('sw_sdxl_lora_pokemon.png') +if __name__ == '__main__': + infer_text_to_image_lora_sdxl() diff --git a/examples/pytorch/sdxl/infer_text_to_image_sdxl.py b/examples/pytorch/sdxl/infer_text_to_image_sdxl.py index 03721742e5..452e2ba1f1 100644 --- a/examples/pytorch/sdxl/infer_text_to_image_sdxl.py +++ b/examples/pytorch/sdxl/infer_text_to_image_sdxl.py @@ -1,10 +1,5 @@ -import torch -from diffusers import DiffusionPipeline +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.aigc import infer_text_to_image_sdxl -model_path = '/mnt/workspace/swift/examples/pytorch/sdxl/sdxl-pokemon-model' -pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) -pipe.to('cuda') - -prompt = 'A pokemon with green eyes and red legs.' 
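# --- Illustrative aside (not part of the patch): after this commit the example .py
# files are thin wrappers around argparse-driven entry points in swift.aigc, and the
# run_infer_*.sh scripts added below supply the CLI flags. The same entry point can
# be driven from Python by populating sys.argv; a sketch using the flags from
# run_infer_text_to_image.sh (paths are the ones used in that script):
import sys

from swift.aigc import infer_text_to_image

sys.argv = [
    'infer_text_to_image.py',
    '--pretrained_model_name_or_path', 'AI-ModelScope/stable-diffusion-v1-5',
    '--unet_model_path', 'train_text_to_image/checkpoint-15000/unet',
    '--prompt', 'yoda',
    '--image_save_path', 'yoda-pokemon.png',
    '--torch_dtype', 'fp16',
]
infer_text_to_image()  # parses sys.argv, loads the pipeline and saves the image
# --- end of aside ---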
-image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] -image.save('sdxl_pokemon.png') +if __name__ == '__main__': + infer_text_to_image_sdxl() diff --git a/examples/pytorch/sdxl/requirements.txt b/examples/pytorch/sdxl/requirements.txt deleted file mode 100644 index f6535bb514..0000000000 --- a/examples/pytorch/sdxl/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -accelerate>=0.16.0 -datasets -diffusers==0.25.0 -ftfy -Jinja2 -tensorboard -torchvision -transformers>=4.25.1 diff --git a/examples/pytorch/sdxl/requirements_sdxl.txt b/examples/pytorch/sdxl/requirements_sdxl.txt deleted file mode 100644 index 16ef8946b5..0000000000 --- a/examples/pytorch/sdxl/requirements_sdxl.txt +++ /dev/null @@ -1,8 +0,0 @@ -accelerate>=0.22.0 -datasets -diffusers==0.25.0 -ftfy -Jinja2 -tensorboard -torchvision -transformers>=4.25.1 diff --git a/examples/pytorch/sdxl/scripts/run_infer_text_to_image.sh b/examples/pytorch/sdxl/scripts/run_infer_text_to_image.sh new file mode 100644 index 0000000000..a8ac4d8e79 --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_infer_text_to_image.sh @@ -0,0 +1,8 @@ +PYTHONPATH=../../.. \ +CUDA_VISIBLE_DEVICES=0 \ +python infer_text_to_image.py \ + --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \ + --unet_model_path "train_text_to_image/checkpoint-15000/unet" \ + --prompt "yoda" \ + --image_save_path "yoda-pokemon.png" \ + --torch_dtype "fp16" \ diff --git a/examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh b/examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh new file mode 100644 index 0000000000..bf73047c01 --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh @@ -0,0 +1,8 @@ +PYTHONPATH=../../.. \ +CUDA_VISIBLE_DEVICES=0 \ +python infer_text_to_image_lora.py \ + --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \ + --lora_model_path "train_text_to_image_lora/checkpoint-80000" \ + --prompt "A pokemon with green eyes and red legs." \ + --image_save_path "lora_pokemon.png" \ + --torch_dtype "fp16" \ diff --git a/examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora_sdxl.sh b/examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora_sdxl.sh new file mode 100644 index 0000000000..4e95035cb7 --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora_sdxl.sh @@ -0,0 +1,8 @@ +PYTHONPATH=../../.. \ +CUDA_VISIBLE_DEVICES=0 \ +python infer_text_to_image_lora_sdxl.py \ + --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \ + --lora_model_path "train_text_to_image_lora_sdxl/unet" \ + --prompt "A pokemon with green eyes and red legs." \ + --image_save_path "sdxl_lora_pokemon.png" \ + --torch_dtype "fp16" \ diff --git a/examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh b/examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh new file mode 100644 index 0000000000..f87a3d44f8 --- /dev/null +++ b/examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh @@ -0,0 +1,8 @@ +PYTHONPATH=../../.. \ +CUDA_VISIBLE_DEVICES=0 \ +python infer_text_to_image_sdxl.py \ + --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \ + --unet_model_path "train_text_to_image_sdxl/checkpoint-10000/unet" \ + --prompt "A pokemon with green eyes and red legs." 
\ + --image_save_path "sdxl_pokemon.png" \ + --torch_dtype "fp16" \ diff --git a/requirements/aigc.txt b/requirements/aigc.txt index 513d3c4094..fe11a24b33 100644 --- a/requirements/aigc.txt +++ b/requirements/aigc.txt @@ -1,4 +1,4 @@ decord -diffusers>=0.18.0 +diffusers==0.25.0 einops torchvision diff --git a/swift/aigc/__init__.py b/swift/aigc/__init__.py index fa54269279..11218b26cb 100644 --- a/swift/aigc/__init__.py +++ b/swift/aigc/__init__.py @@ -8,7 +8,8 @@ from .animatediff import animatediff_sft, animatediff_main from .animatediff_infer import animatediff_infer, animatediff_infer_main from .diffusers import train_text_to_image, train_text_to_image_lora, train_text_to_image_lora_sdxl, \ - train_text_to_image_sdxl + train_text_to_image_sdxl, infer_text_to_image, infer_text_to_image_lora, infer_text_to_image_sdxl, \ + infer_text_to_image_lora_sdxl from .utils import AnimateDiffArguments, AnimateDiffInferArguments else: _import_structure = { @@ -16,7 +17,9 @@ 'animatediff_infer': ['animatediff_infer', 'animatediff_infer_main'], 'diffusers': [ 'train_text_to_image', 'train_text_to_image_lora', - 'train_text_to_image_lora_sdxl', 'train_text_to_image_sdxl' + 'train_text_to_image_lora_sdxl', 'train_text_to_image_sdxl', + 'infer_text_to_image', 'infer_text_to_image_lora', + 'infer_text_to_image_sdxl', 'infer_text_to_image_lora_sdxl' ], 'utils': ['AnimateDiffArguments', 'AnimateDiffInferArguments'], } diff --git a/swift/aigc/diffusers/__init__.py b/swift/aigc/diffusers/__init__.py index 592852c99c..930064e5e7 100644 --- a/swift/aigc/diffusers/__init__.py +++ b/swift/aigc/diffusers/__init__.py @@ -1,3 +1,8 @@ +from .infer_text_to_image import main as infer_text_to_image +from .infer_text_to_image_lora import main as infer_text_to_image_lora +from .infer_text_to_image_lora_sdxl import \ + main as infer_text_to_image_lora_sdxl +from .infer_text_to_image_sdxl import main as infer_text_to_image_sdxl from .train_text_to_image import main as train_text_to_image from .train_text_to_image_lora import main as train_text_to_image_lora from .train_text_to_image_lora_sdxl import \ diff --git a/swift/aigc/diffusers/infer_text_to_image.py b/swift/aigc/diffusers/infer_text_to_image.py new file mode 100644 index 0000000000..8ca1054b60 --- /dev/null +++ b/swift/aigc/diffusers/infer_text_to_image.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import os + +import torch +from diffusers import StableDiffusionPipeline, UNet2DConditionModel +from modelscope import snapshot_download + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Simple example of a text to image inference.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default='AI-ModelScope/stable-diffusion-v1-5', + required=True, + help= + 'Path to pretrained model or model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--unet_model_path', + type=str, + default=None, + required=False, + help='The path to trained unet model.', + ) + parser.add_argument( + '--prompt', + type=str, + default=None, + required=True, + help= + 'The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`', + ) + parser.add_argument( + '--image_save_path', + type=str, + default=None, + required=True, + help='The path to save generated image', + ) + parser.add_argument( + '--torch_dtype', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of the' + ' mixed_precision passed with the `accelerate.launch` command in training script.' + ), + ) + parser.add_argument( + '--num_inference_steps', + type=int, + default=50, + help= + ('The number of denoising steps. More denoising steps usually lead to a higher quality image at the \ + expense of slower inference.'), + ) + parser.add_argument( + '--guidance_scale', + type=float, + default=7.5, + choices=['no', 'fp16', 'bf16'], + help= + ('A higher guidance scale value encourages the model to generate images closely linked to the text \ + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.' + ), + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if os.path.exists(args.pretrained_model_name_or_path): + model_path = args.pretrained_model_name_or_path + else: + model_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + + if args.torch_dtype == 'fp16': + torch_dtype = torch.float16 + elif args.torch_dtype == 'bf16': + torch_dtype = torch.bfloat16 + else: + torch_dtype = torch.float32 + + pipe = StableDiffusionPipeline.from_pretrained( + model_path, torch_dtype=torch_dtype) + if args.unet_model_path is not None: + pipe.unet = UNet2DConditionModel.from_pretrained( + args.unet_model_path, torch_dtype=torch_dtype) + pipe.to('cuda') + image = pipe( + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale).images[0] + image.save(args.image_save_path) diff --git a/swift/aigc/diffusers/infer_text_to_image_lora.py b/swift/aigc/diffusers/infer_text_to_image_lora.py new file mode 100644 index 0000000000..6c2aa39fcf --- /dev/null +++ b/swift/aigc/diffusers/infer_text_to_image_lora.py @@ -0,0 +1,112 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import os + +import torch +from diffusers import StableDiffusionPipeline +from modelscope import snapshot_download + +from swift import Swift + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Simple example of a text to image inference.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default='AI-ModelScope/stable-diffusion-v1-5', + required=True, + help= + 'Path to pretrained model or model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--lora_model_path', + type=str, + default=None, + required=False, + help='The path to trained lora model.', + ) + parser.add_argument( + '--prompt', + type=str, + default=None, + required=True, + help= + 'The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`', + ) + parser.add_argument( + '--image_save_path', + type=str, + default=None, + required=True, + help='The path to save generated image', + ) + parser.add_argument( + '--torch_dtype', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of the' + ' mixed_precision passed with the `accelerate.launch` command in training script.' + ), + ) + parser.add_argument( + '--num_inference_steps', + type=int, + default=30, + help= + ('The number of denoising steps. More denoising steps usually lead to a higher quality image at the \ + expense of slower inference.'), + ) + parser.add_argument( + '--guidance_scale', + type=float, + default=7.5, + choices=['no', 'fp16', 'bf16'], + help= + ('A higher guidance scale value encourages the model to generate images closely linked to the text \ + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.' + ), + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if os.path.exists(args.pretrained_model_name_or_path): + model_path = args.pretrained_model_name_or_path + else: + model_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + + if args.torch_dtype == 'fp16': + torch_dtype = torch.float16 + elif args.torch_dtype == 'bf16': + torch_dtype = torch.bfloat16 + else: + torch_dtype = torch.float32 + + pipe = StableDiffusionPipeline.from_pretrained( + model_path, torch_dtype=torch_dtype) + if args.lora_model_path is not None: + pipe.unet = Swift.from_pretrained(pipe.unet, args.lora_model_path) + pipe.to('cuda') + image = pipe( + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale).images[0] + image.save(args.image_save_path) diff --git a/swift/aigc/diffusers/infer_text_to_image_lora_sdxl.py b/swift/aigc/diffusers/infer_text_to_image_lora_sdxl.py new file mode 100644 index 0000000000..ca6647b6ba --- /dev/null +++ b/swift/aigc/diffusers/infer_text_to_image_lora_sdxl.py @@ -0,0 +1,112 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import os + +import torch +from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel +from modelscope import snapshot_download + +from swift import Swift + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Simple example of a text to image lora sdxl inference.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default='AI-ModelScope/stable-diffusion-v1-5', + required=True, + help= + 'Path to pretrained model or model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--lora_model_path', + type=str, + default=None, + required=False, + help='The path to trained lora model.', + ) + parser.add_argument( + '--prompt', + type=str, + default=None, + required=True, + help= + 'The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`', + ) + parser.add_argument( + '--image_save_path', + type=str, + default=None, + required=True, + help='The path to save generated image', + ) + parser.add_argument( + '--torch_dtype', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of the' + ' mixed_precision passed with the `accelerate.launch` command in training script.' + ), + ) + parser.add_argument( + '--num_inference_steps', + type=int, + default=30, + help= + ('The number of denoising steps. More denoising steps usually lead to a higher quality image at the \ + expense of slower inference.'), + ) + parser.add_argument( + '--guidance_scale', + type=float, + default=7.5, + choices=['no', 'fp16', 'bf16'], + help= + ('A higher guidance scale value encourages the model to generate images closely linked to the text \ + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.' + ), + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if os.path.exists(args.pretrained_model_name_or_path): + model_path = args.pretrained_model_name_or_path + else: + model_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + + if args.torch_dtype == 'fp16': + torch_dtype = torch.float16 + elif args.torch_dtype == 'bf16': + torch_dtype = torch.bfloat16 + else: + torch_dtype = torch.float32 + + pipe = StableDiffusionXLPipeline.from_pretrained( + model_path, torch_dtype=torch_dtype) + if args.lora_model_path is not None: + pipe.unet = Swift.from_pretrained(pipe.unet, args.lora_model_path) + pipe.to('cuda') + image = pipe( + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale).images[0] + image.save(args.image_save_path) diff --git a/swift/aigc/diffusers/infer_text_to_image_sdxl.py b/swift/aigc/diffusers/infer_text_to_image_sdxl.py new file mode 100644 index 0000000000..0cb7b2dbe6 --- /dev/null +++ b/swift/aigc/diffusers/infer_text_to_image_sdxl.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import os + +import torch +from diffusers import DiffusionPipeline, UNet2DConditionModel +from modelscope import snapshot_download + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Simple example of a text to image inference.') + parser.add_argument( + '--pretrained_model_name_or_path', + type=str, + default='AI-ModelScope/stable-diffusion-v1-5', + required=True, + help= + 'Path to pretrained model or model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--revision', + type=str, + default=None, + required=False, + help= + 'Revision of pretrained model identifier from modelscope.cn/models.', + ) + parser.add_argument( + '--unet_model_path', + type=str, + default=None, + required=False, + help='The path to trained unet model.', + ) + parser.add_argument( + '--prompt', + type=str, + default=None, + required=True, + help= + 'The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`', + ) + parser.add_argument( + '--image_save_path', + type=str, + default=None, + required=True, + help='The path to save generated image', + ) + parser.add_argument( + '--torch_dtype', + type=str, + default=None, + choices=['no', 'fp16', 'bf16'], + help= + ('Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=' + ' 1.10.and an Nvidia Ampere GPU. Default to the value of the' + ' mixed_precision passed with the `accelerate.launch` command in training script.' + ), + ) + parser.add_argument( + '--num_inference_steps', + type=int, + default=30, + help= + ('The number of denoising steps. More denoising steps usually lead to a higher quality image at the \ + expense of slower inference.'), + ) + parser.add_argument( + '--guidance_scale', + type=float, + default=7.5, + choices=['no', 'fp16', 'bf16'], + help= + ('A higher guidance scale value encourages the model to generate images closely linked to the text \ + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.' + ), + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if os.path.exists(args.pretrained_model_name_or_path): + model_path = args.pretrained_model_name_or_path + else: + model_path = snapshot_download( + args.pretrained_model_name_or_path, revision=args.revision) + + if args.torch_dtype == 'fp16': + torch_dtype = torch.float16 + elif args.torch_dtype == 'bf16': + torch_dtype = torch.bfloat16 + else: + torch_dtype = torch.float32 + + pipe = DiffusionPipeline.from_pretrained( + model_path, torch_dtype=torch_dtype) + if args.unet_model_path is not None: + pipe.unet = UNet2DConditionModel.from_pretrained( + args.unet_model_path, torch_dtype=torch_dtype) + pipe.to('cuda') + image = pipe( + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale).images[0] + image.save(args.image_save_path) diff --git a/swift/aigc/diffusers/train_text_to_image.py b/swift/aigc/diffusers/train_text_to_image.py index c5cc7a8367..f4a218f3e5 100644 --- a/swift/aigc/diffusers/train_text_to_image.py +++ b/swift/aigc/diffusers/train_text_to_image.py @@ -209,7 +209,7 @@ def parse_args(): default=None, required=True, help= - 'Path to pretrained model or model identifier from huggingface.co/models.', + 'Path to pretrained model or model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--revision', @@ -217,7 +217,7 @@ def parse_args(): default=None, required=False, help= - 'Revision of pretrained model identifier from huggingface.co/models.', + 'Revision of pretrained model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--variant', diff --git a/swift/aigc/diffusers/train_text_to_image_lora.py b/swift/aigc/diffusers/train_text_to_image_lora.py index 51ec736071..0c5d74c323 100644 --- a/swift/aigc/diffusers/train_text_to_image_lora.py +++ b/swift/aigc/diffusers/train_text_to_image_lora.py @@ -127,7 +127,7 @@ def parse_args(): default=None, required=True, help= - 'Path to pretrained model or model identifier from huggingface.co/models.', + 'Path to pretrained model or model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--revision', @@ -135,7 +135,7 @@ def parse_args(): default=None, required=False, help= - 'Revision of pretrained model identifier from huggingface.co/models.', + 'Revision of pretrained model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--variant', diff --git a/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py b/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py index db8e176516..47ca0ba756 100644 --- a/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py +++ b/swift/aigc/diffusers/train_text_to_image_lora_sdxl.py @@ -162,7 +162,7 
@@ def parse_args(input_args=None): default=None, required=True, help= - 'Path to pretrained model or model identifier from huggingface.co/models.', + 'Path to pretrained model or model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--pretrained_vae_model_name_or_path', @@ -177,7 +177,7 @@ def parse_args(input_args=None): default=None, required=False, help= - 'Revision of pretrained model identifier from huggingface.co/models.', + 'Revision of pretrained model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--variant', diff --git a/swift/aigc/diffusers/train_text_to_image_sdxl.py b/swift/aigc/diffusers/train_text_to_image_sdxl.py index 1b3f4a4d8d..675d2f439f 100644 --- a/swift/aigc/diffusers/train_text_to_image_sdxl.py +++ b/swift/aigc/diffusers/train_text_to_image_sdxl.py @@ -128,7 +128,7 @@ def parse_args(input_args=None): default=None, required=True, help= - 'Path to pretrained model or model identifier from huggingface.co/models.', + 'Path to pretrained model or model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--pretrained_vae_model_name_or_path', @@ -143,7 +143,7 @@ def parse_args(input_args=None): default=None, required=False, help= - 'Revision of pretrained model identifier from huggingface.co/models.', + 'Revision of pretrained model identifier from huggingface.co/models or modelscope.cn/models.', ) parser.add_argument( '--variant',