Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support SDXL training #2040

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,4 @@ batchscript-*
*.zip
work_dir
work_dir/
!tests/data/sd/*
23 changes: 23 additions & 0 deletions configs/_base_/datasets/pokemon_blip_xl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Transforms applied, in order, to every sample of the training dataset.
pipeline = [
    # Read the image stored under the `img` key, in RGB channel order.
    dict(
        type='LoadImageFromHuggingFaceDataset',
        key='img',
        channel_order='rgb',
    ),
    dict(type='ResizeEdge', scale=1024),
    dict(type='RandomCropXL', size=1024),
    dict(
        type='FlipXL',
        keys=['img'],
        flip_ratio=0.5,
        direction='horizontal',
    ),
    dict(type='ComputeTimeIds'),
    # NOTE(review): no transform above visibly produces a `merged` key --
    # confirm that PackInputs tolerates keys that are missing.
    dict(type='PackInputs', keys=['merged', 'img', 'time_ids']),
]

# Pokemon BLIP captions dataset, loaded through HuggingFaceDataset.
dataset = dict(
    type='HuggingFaceDataset',
    dataset='lambdalabs/pokemon-blip-captions',
    pipeline=pipeline,
)

train_dataloader = dict(
    batch_size=1,
    num_workers=2,
    dataset=dataset,
    sampler=dict(type='DefaultSampler', shuffle=True),
)

# Neither validation nor test dataloaders/evaluators are configured.
val_dataloader = None
val_evaluator = None
test_dataloader = None
test_evaluator = None
36 changes: 36 additions & 0 deletions configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Use DiffuserWrapper!
# Pretrained sources: the UNet, scheduler, text encoders and tokenizers are
# taken from the SDXL base repository, the VAE from a separate repository.
stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0'
vae_url = 'madebyollin/sdxl-vae-fp16-fix'

unet = dict(
    type='UNet2DConditionModel',
    subfolder='unet',
    from_pretrained=stable_diffusion_xl_url,
)
vae = dict(type='AutoencoderKL', from_pretrained=vae_url)

# The same scheduler config is used for both training and testing below.
diffusion_scheduler = dict(
    type='DDPMScheduler',
    from_pretrained=stable_diffusion_xl_url,
    subfolder='scheduler',
)

model = dict(
    type='StableDiffusionXL',
    dtype='fp16',
    with_cp=True,
    unet=unet,
    vae=vae,
    enable_xformers=False,
    # First text encoder / tokenizer pair.
    text_encoder_one=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_xl_url,
        subfolder='text_encoder',
    ),
    tokenizer_one=stable_diffusion_xl_url,
    # Second text encoder / tokenizer pair.
    text_encoder_two=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_xl_url,
        subfolder='text_encoder_2',
    ),
    tokenizer_two=stable_diffusion_xl_url,
    scheduler=diffusion_scheduler,
    test_scheduler=diffusion_scheduler,
    data_preprocessor=dict(type='DataPreprocessor', data_keys=None),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Use DiffuserWrapper!
# Pretrained sources: the UNet, scheduler, text encoders and tokenizers are
# taken from the SDXL base repository, the VAE from a separate repository.
stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0'
vae_url = 'madebyollin/sdxl-vae-fp16-fix'

unet = dict(
    type='UNet2DConditionModel',
    subfolder='unet',
    from_pretrained=stable_diffusion_xl_url,
)
vae = dict(type='AutoencoderKL', from_pretrained=vae_url)

# The same scheduler config is used for both training and testing below.
diffusion_scheduler = dict(
    type='DDPMScheduler',
    from_pretrained=stable_diffusion_xl_url,
    subfolder='scheduler',
)

# LoRA settings: rank 8, targeting the modules named to_q / to_k / to_v.
lora_config = dict(
    rank=8,
    target_modules=['to_q', 'to_k', 'to_v'],
)

model = dict(
    type='StableDiffusionXL',
    dtype='fp16',
    with_cp=True,
    unet=unet,
    vae=vae,
    enable_xformers=False,
    # First text encoder / tokenizer pair.
    text_encoder_one=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_xl_url,
        subfolder='text_encoder',
    ),
    tokenizer_one=stable_diffusion_xl_url,
    # Second text encoder / tokenizer pair.
    text_encoder_two=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_xl_url,
        subfolder='text_encoder_2',
    ),
    tokenizer_two=stable_diffusion_xl_url,
    scheduler=diffusion_scheduler,
    test_scheduler=diffusion_scheduler,
    data_preprocessor=dict(type='DataPreprocessor', data_keys=None),
    lora_config=lora_config,
)
11 changes: 11 additions & 0 deletions configs/_base_/schedules/sd_10e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Mixed-precision (float16) optimizer wrapper around AdamW, with gradient
# clipping at a max norm of 1.0 and no gradient accumulation.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',
    optimizer=dict(
        type='AdamW',
        lr=1e-5,
        weight_decay=1e-2,
    ),
    clip_grad=dict(max_norm=1.0),
    accumulative_counts=1,
)

# Epoch-based training for 10 epochs; no validation or test loop.
train_cfg = dict(by_epoch=True, max_epochs=10)
val_cfg = None
test_cfg = None
16 changes: 16 additions & 0 deletions configs/_base_/schedules/sdxl_10e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Mixed-precision (float16) optimizer wrapper around Adafactor, with gradient
# clipping at a max norm of 1.0 and no gradient accumulation.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',
    optimizer=dict(
        type='Adafactor',
        # A fixed learning rate is given explicitly; parameter scaling and
        # relative step sizes are both switched off.
        lr=1e-5,
        weight_decay=1e-2,
        scale_parameter=False,
        relative_step=False,
    ),
    clip_grad=dict(max_norm=1.0),
    accumulative_counts=1,
)

# Epoch-based training for 10 epochs; no validation or test loop.
train_cfg = dict(by_epoch=True, max_epochs=10)
val_cfg = None
test_cfg = None
46 changes: 46 additions & 0 deletions configs/_base_/sd_default_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
default_scope = 'mmagic'

# Default hooks.
default_hooks = dict(
    # Record the time of every iteration.
    timer=dict(type='IterTimerHook'),
    # Print the log every 100 iterations.
    logger=dict(type='LoggerHook', interval=100),
    # Save a checkpoint (optimizer state included) every epoch, with
    # `max_keep_ckpts` set to 3.
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        by_epoch=True,
        max_keep_ckpts=3,
        save_optimizer=True,
    ),
)

# Environment settings.
env_cfg = dict(
    # Whether to enable cudnn benchmark.
    cudnn_benchmark=True,
    # Multi-process parameters.
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # Distributed parameters.
    dist_cfg=dict(backend='nccl'),
)

# Logging: level, and epoch-based log processing.
log_level = 'INFO'
log_processor = dict(type='LogProcessor', by_epoch=True)

# Checkpoint to load from (None: do not load any).
load_from = None

# Whether to resume training from the loaded checkpoint.
resume = None

# Model wrapper settings.
model_wrapper_cfg = dict(
    type='MMSeparateDistributedDataParallel',
    broadcast_buffers=False,
    find_unused_parameters=False,
)

# Visualizer and its backends.
vis_backends = [dict(type='VisBackend')]
visualizer = dict(type='Visualizer', vis_backends=vis_backends)

# No fixed random seed; deterministic mode is off.
randomness = dict(seed=None, deterministic=False)
8 changes: 5 additions & 3 deletions configs/stable_diffusion_xl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ We present SDXL, a latent diffusion model for text-to-image synthesis. Compared

## Pretrained models

| Model | Task | Dataset | Download |
| :----------------------------------------------------------------: | :--------: | :-----: | :------: |
| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - |
| Model | Task | Dataset | Download |
| :---------------------------------------------------------------------------------: | :--------: | :--------------------------------------------------------------------------------------: | :---------: |
| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - |
| [stable_diffusion_xl_pokemon_blip](./stable-diffusion_xl_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) |
| [stable-diffusion_xl_lora_pokemon_blip](./stable-diffusion_xl_lora_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) |

We use the pretrained Stable Diffusion XL weights. The model consists of several sets of weights, including the VAE, the UNet and the CLIP text encoders.

Expand Down
16 changes: 16 additions & 0 deletions configs/stable_diffusion_xl/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,19 @@ Models:
- Dataset: '-'
Metrics: {}
Task: Text2Image
- Config: configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py
In Collection: Stable Diffusion XL
Name: stable-diffusion_xl_pokemon_blip
Results:
- Dataset: '[pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)'
Metrics: {}
Task: Text2Image
Weights: <>
- Config: configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py
In Collection: Stable Diffusion XL
Name: stable-diffusion_xl_lora_pokemon_blip
Results:
- Dataset: '[pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)'
Metrics: {}
Task: Text2Image
Weights: <>
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Base configs: LoRA model, dataset, AdamW schedule and runtime.
_base_ = [
    '../_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py',
    '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/sd_10e.py',
    '../_base_/sd_default_runtime.py',
]

# The same prompt, repeated four times, is passed to the model as
# `val_prompts`.
val_prompts = ['yoda pokemon'] * 4

model = dict(val_prompts=val_prompts)

# Dataloader batch size and worker count for this experiment.
train_dataloader = dict(batch_size=4, num_workers=4)

# hooks
custom_hooks = [
    dict(
        type='VisualizationHook',
        by_epoch=True,
        interval=1,
        fixed_input=True,
        # visualize train dataset
        vis_kwargs_list=dict(type='Data', name='fake_img'),
        n_samples=1,
    ),
    dict(type='LoRACheckpointToSaveHook'),
]
21 changes: 21 additions & 0 deletions configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Base configs: SDXL model, dataset, Adafactor schedule and runtime.
_base_ = [
    '../_base_/models/stable_diffusion_xl/stable_diffusion_xl.py',
    '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/sdxl_10e.py',
    '../_base_/sd_default_runtime.py',
]

# The same prompt, repeated four times, is passed to the model as
# `val_prompts`.
val_prompts = ['yoda pokemon'] * 4

model = dict(val_prompts=val_prompts)

# hooks
custom_hooks = [
    dict(
        type='VisualizationHook',
        by_epoch=True,
        interval=1,
        fixed_input=True,
        # visualize train dataset
        vis_kwargs_list=dict(type='Data', name='fake_img'),
        n_samples=1,
    ),
]
4 changes: 3 additions & 1 deletion mmagic/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .controlnet_dataset import ControlNetDataset
from .dreambooth_dataset import DreamBoothDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .hf_dataset import HuggingFaceDataset
from .imagenet_dataset import ImageNet
from .mscoco_dataset import MSCoCoDataset
from .paired_image_dataset import PairedImageDataset
Expand All @@ -19,5 +20,6 @@
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset', 'ViCoDataset',
'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset'
'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset',
'HuggingFaceDataset'
]
84 changes: 84 additions & 0 deletions mmagic/datasets/hf_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from mmengine.dataset import BaseDataset

from mmagic.registry import DATASETS


@DATASETS.register_module()
class HuggingFaceDataset(BaseDataset):
    """Image-caption dataset loaded through the HuggingFace ``datasets``
    library.

    ``dataset`` is first checked as a local path. If it exists, the caption
    csv file inside that folder is loaded; otherwise ``dataset`` is passed
    to ``datasets.load_dataset`` as a dataset name. In both cases only the
    ``'train'`` split is used.

    Args:
        dataset (str): Dataset name for Huggingface datasets, or the path
            of a local folder.
        image_column (str): Image column name. Defaults to 'image'.
        caption_column (str): Caption column name. Defaults to 'text'.
        csv (str): Caption csv file name when loading local folder.
            Defaults to 'metadata.csv'.
        cache_dir (str, optional): The directory where the downloaded
            datasets will be stored. Defaults to None.
        pipeline (list[dict | callable], optional): A sequence of data
            transforms. Defaults to None, which is treated as an empty
            pipeline.
    """

    def __init__(self,
                 dataset: str,
                 image_column: str = 'image',
                 caption_column: str = 'text',
                 csv: str = 'metadata.csv',
                 cache_dir: Optional[str] = None,
                 pipeline: Optional[List[Union[dict, Callable]]] = None):

        # These attributes are read by ``load_data_list``, so they are set
        # before the base initializer runs.
        # NOTE(review): presumably ``BaseDataset.__init__`` may trigger
        # ``load_data_list`` -- confirm against mmengine.
        self.dataset = dataset
        self.image_column = image_column
        self.caption_column = caption_column
        self.csv = csv
        self.cache_dir = cache_dir

        # Hand the base class a fresh list instead of sharing one mutable
        # default object between all instances.
        super().__init__(pipeline=[] if pipeline is None else pipeline)

    def load_data_list(self) -> list:
        """Load every sample of the 'train' split into a list.

        Returns:
            list[dict]: One dict per sample with the keys ``img`` (the value
            of the image column; string values are joined onto
            ``self.dataset`` as a path) and ``prompt`` (the caption).

        Raises:
            ImportError: If the ``datasets`` package cannot be imported.
            ValueError: If a caption is neither a string nor a list /
                ndarray of strings.
        """
        try:
            from datasets import load_dataset
        except ImportError as err:
            raise ImportError(
                'HuggingFaceDataset requires datasets, please '
                'install it by `pip install datasets`.') from err

        if Path(self.dataset).exists():
            # load local folder
            data_file = os.path.join(self.dataset, self.csv)
            dataset = load_dataset(
                'csv', data_files=data_file,
                cache_dir=self.cache_dir)['train']
        else:
            # load huggingface online
            dataset = load_dataset(
                self.dataset, cache_dir=self.cache_dir)['train']

        data_list = []
        # Iterate rows directly so that each row is fetched only once
        # instead of once per accessed column.
        for item in dataset:
            caption = item[self.caption_column]
            if isinstance(caption, str):
                pass
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                caption = random.choice(caption)
            else:
                raise ValueError(
                    f'Caption column `{self.caption_column}` should contain'
                    ' either strings or lists of strings.')

            img = item[self.image_column]
            if isinstance(img, str):
                # String entries are treated as paths relative to the
                # dataset folder.
                img = os.path.join(self.dataset, img)

            data_list.append(dict(img=img, prompt=caption))

        return data_list
8 changes: 6 additions & 2 deletions mmagic/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
GenerateFrameIndiceswithPadding,
GenerateSegmentIndices)
from .get_masked_image import GetMaskedImage
from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask,
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
LoadImageFromHuggingFaceDataset, LoadMask,
LoadPairedImageFromFile)
from .matlab_like_resize import MATLABLikeResize
from .normalization import Normalize, RescaleToZeroOne
from .random_degradations import (DegradationsWithShuffle, RandomBlur,
RandomJPEGCompression, RandomNoise,
RandomResize, RandomVideoCompression)
from .random_down_sampling import RandomDownSampling
from .sdxl import ComputeTimeIds, FlipXL, RandomCropXL, ResizeEdge
from .trimap import (FormatTrimap, GenerateTrimap,
GenerateTrimapWithDistTransform, TransformTrimap)
from .values import CopyValues, SetValues
Expand All @@ -49,5 +51,7 @@
'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg',
'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile',
'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop',
'Albumentations', 'AlbuCorruptFunction', 'PairedAlbuTransForms'
'Albumentations', 'AlbuCorruptFunction', 'PairedAlbuTransForms',
'LoadImageFromHuggingFaceDataset', 'RandomCropXL', 'FlipXL',
'ComputeTimeIds', 'ResizeEdge'
]
Loading
Loading