5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image_lora.py
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image_lora

if __name__ == '__main__':
    infer_text_to_image_lora()
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image.py
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image

if __name__ == '__main__':
    infer_text_to_image()
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image_lora_sdxl

if __name__ == '__main__':
    infer_text_to_image_lora_sdxl()
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image_sdxl.py
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image_sdxl

if __name__ == '__main__':
    infer_text_to_image_sdxl()
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image.sh
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \
--unet_model_path "train_text_to_image/checkpoint-15000/unet" \
--prompt "yoda" \
--image_save_path "yoda-pokemon.png" \
--torch_dtype "fp16" \
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_lora.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \
--lora_model_path "train_text_to_image_lora/checkpoint-80000" \
--prompt "A pokemon with green eyes and red legs." \
--image_save_path "lora_pokemon.png" \
--torch_dtype "fp16" \
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora_sdxl.sh
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \
--lora_model_path "train_text_to_image_lora_sdxl/unet" \
--prompt "A pokemon with green eyes and red legs." \
--image_save_path "sdxl_lora_pokemon.png" \
--torch_dtype "fp16" \
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_sdxl.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \
--unet_model_path "train_text_to_image_sdxl/checkpoint-10000/unet" \
--prompt "A pokemon with green eyes and red legs." \
--image_save_path "sdxl_pokemon.png" \
--torch_dtype "fp16" \
17 changes: 17 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image.sh
@@ -0,0 +1,17 @@
PYTHONPATH=../../../ \
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--use_ema \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--output_dir="train_text_to_image" \
18 changes: 18 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh
@@ -0,0 +1,18 @@
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--caption_column="text" \
--resolution=512 \
--random_flip \
--train_batch_size=1 \
--num_train_epochs=100 \
--checkpointing_steps=5000 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--seed=42 \
--output_dir="train_text_to_image_lora" \
--validation_prompt="cute dragon creature" \
--report_to="tensorboard" \
19 changes: 19 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh
@@ -0,0 +1,19 @@
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--caption_column="text" \
--resolution=1024 \
--random_flip \
--train_batch_size=1 \
--num_train_epochs=2 \
--checkpointing_steps=500 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--seed=42 \
--output_dir="train_text_to_image_lora_sdxl" \
--validation_prompt="cute dragon creature" \
--report_to="tensorboard" \
24 changes: 24 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh
@@ -0,0 +1,24 @@
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--enable_xformers_memory_efficient_attention \
--resolution=512 \
--center_crop \
--random_flip \
--proportion_empty_prompts=0.2 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=10000 \
--use_8bit_adam \
--learning_rate=1e-06 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--report_to="tensorboard" \
--validation_prompt="a cute Sundar Pichai creature" \
--validation_epochs 5 \
--checkpointing_steps=5000 \
--output_dir="train_text_to_image_sdxl" \
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image.py
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image

if __name__ == '__main__':
    train_text_to_image()
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image_lora.py
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image_lora

if __name__ == '__main__':
    train_text_to_image_lora()
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image_lora_sdxl.py
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image_lora_sdxl

if __name__ == '__main__':
    train_text_to_image_lora_sdxl()
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image_sdxl.py
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image_sdxl

if __name__ == '__main__':
    train_text_to_image_sdxl()
2 changes: 1 addition & 1 deletion requirements/aigc.txt
@@ -1,4 +1,4 @@
decord
diffusers>=0.18.0
diffusers==0.25.0
einops
torchvision
9 changes: 9 additions & 0 deletions swift/aigc/__init__.py
@@ -7,11 +7,20 @@
    # Recommend using `xxx_main`
    from .animatediff import animatediff_sft, animatediff_main
    from .animatediff_infer import animatediff_infer, animatediff_infer_main
    from .diffusers import train_text_to_image, train_text_to_image_lora, train_text_to_image_lora_sdxl, \
        train_text_to_image_sdxl, infer_text_to_image, infer_text_to_image_lora, infer_text_to_image_sdxl, \
        infer_text_to_image_lora_sdxl
    from .utils import AnimateDiffArguments, AnimateDiffInferArguments
else:
    _import_structure = {
        'animatediff': ['animatediff_sft', 'animatediff_main'],
        'animatediff_infer': ['animatediff_infer', 'animatediff_infer_main'],
        'diffusers': [
            'train_text_to_image', 'train_text_to_image_lora',
            'train_text_to_image_lora_sdxl', 'train_text_to_image_sdxl',
            'infer_text_to_image', 'infer_text_to_image_lora',
            'infer_text_to_image_sdxl', 'infer_text_to_image_lora_sdxl'
        ],
        'utils': ['AnimateDiffArguments', 'AnimateDiffInferArguments'],
    }

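Note: the `_import_structure` branch above follows the package's lazy-import pattern, where a submodule is imported only when one of its symbols is first accessed. Below is a minimal, self-contained sketch of that idea using a PEP 562 module-level `__getattr__`; it is illustrative only and not the repo's actual `LazyImportModule` helper, whose implementation is not shown in this diff.

# lazy_pkg/__init__.py, illustrative sketch only (not swift's implementation)
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static type checkers and IDEs still see the real imports.
    from .diffusers import train_text_to_image  # noqa: F401

_import_structure = {
    'diffusers': ['train_text_to_image', 'infer_text_to_image'],
}
_symbol_to_module = {
    symbol: module
    for module, symbols in _import_structure.items() for symbol in symbols
}


def __getattr__(name):
    # Import the owning submodule only on first attribute access.
    if name in _symbol_to_module:
        module = importlib.import_module('.' + _symbol_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')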
10 changes: 10 additions & 0 deletions swift/aigc/diffusers/__init__.py
@@ -0,0 +1,10 @@
from .infer_text_to_image import main as infer_text_to_image
from .infer_text_to_image_lora import main as infer_text_to_image_lora
from .infer_text_to_image_lora_sdxl import \
    main as infer_text_to_image_lora_sdxl
from .infer_text_to_image_sdxl import main as infer_text_to_image_sdxl
from .train_text_to_image import main as train_text_to_image
from .train_text_to_image_lora import main as train_text_to_image_lora
from .train_text_to_image_lora_sdxl import \
    main as train_text_to_image_lora_sdxl
from .train_text_to_image_sdxl import main as train_text_to_image_sdxl
111 changes: 111 additions & 0 deletions swift/aigc/diffusers/infer_text_to_image.py
@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from modelscope import snapshot_download


def parse_args():
    parser = argparse.ArgumentParser(
        description='Simple example of text-to-image inference.')
    parser.add_argument(
        '--pretrained_model_name_or_path',
        type=str,
        default='AI-ModelScope/stable-diffusion-v1-5',
        required=True,
        help=
        'Path to pretrained model or model identifier from modelscope.cn/models.',
    )
    parser.add_argument(
        '--revision',
        type=str,
        default=None,
        required=False,
        help=
        'Revision of pretrained model identifier from modelscope.cn/models.',
    )
    parser.add_argument(
        '--unet_model_path',
        type=str,
        default=None,
        required=False,
        help='The path to the trained UNet model.',
    )
    parser.add_argument(
        '--prompt',
        type=str,
        default=None,
        required=True,
        help=
        'The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.',
    )
    parser.add_argument(
        '--image_save_path',
        type=str,
        default=None,
        required=True,
        help='The path to save the generated image.',
    )
    parser.add_argument(
        '--torch_dtype',
        type=str,
        default=None,
        choices=['no', 'fp16', 'bf16'],
        help=
        ('Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >='
         ' 1.10 and an Nvidia Ampere GPU. Defaults to the value of the'
         ' mixed_precision passed with the `accelerate.launch` command in the training script.'
         ),
    )
    parser.add_argument(
        '--num_inference_steps',
        type=int,
        default=50,
        help=
        ('The number of denoising steps. More denoising steps usually lead to'
         ' a higher quality image at the expense of slower inference.'),
    )
    parser.add_argument(
        '--guidance_scale',
        type=float,
        default=7.5,
        help=
        ('A higher guidance scale value encourages the model to generate images closely linked to the text'
         ' `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.'
         ),
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if os.path.exists(args.pretrained_model_name_or_path):
        model_path = args.pretrained_model_name_or_path
    else:
        model_path = snapshot_download(
            args.pretrained_model_name_or_path, revision=args.revision)

    if args.torch_dtype == 'fp16':
        torch_dtype = torch.float16
    elif args.torch_dtype == 'bf16':
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, torch_dtype=torch_dtype)
    if args.unet_model_path is not None:
        pipe.unet = UNet2DConditionModel.from_pretrained(
            args.unet_model_path, torch_dtype=torch_dtype)
    pipe.to('cuda')
    image = pipe(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale).images[0]
    image.save(args.image_save_path)
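The LoRA inference entry points added by this PR (`infer_text_to_image_lora` and its SDXL counterpart) are not expanded in this diff. As a rough, hypothetical sketch of how LoRA weights are typically applied at inference time with diffusers 0.25, the snippet below mirrors the paths used in `run_infer_text_to_image_lora.sh` above; it is not the PR's actual implementation.

# Hypothetical sketch, not the code added by this PR.
import torch
from diffusers import StableDiffusionPipeline
from modelscope import snapshot_download

model_path = snapshot_download('AI-ModelScope/stable-diffusion-v1-5')
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
# load_lora_weights reads the trained LoRA weights from a local directory,
# e.g. the --lora_model_path checkpoint passed in the shell script above.
pipe.load_lora_weights('train_text_to_image_lora/checkpoint-80000')
pipe.to('cuda')
image = pipe(
    prompt='A pokemon with green eyes and red legs.',
    num_inference_steps=50,
    guidance_scale=7.5).images[0]
image.save('lora_pokemon.png')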