[feature] support stable diffusion inference (#1502)
* [sd] add files and run good.
* [sd] misc change.
* [sd] remove unused files.
* [sd] directly load 7 submodels.
* [sd] make pipeline clear
* [sd] remove unrelated scheduler
* [sd] use our own scheduler
* [sd] rm schedulers and outputs data structures
* [sd] replace log with mmengine log
* [sd] rm utils dir
* [sd] remove configure utils and model utils
* [sd] move transformer models to clip wrapper.
* [sd] load resource from url.
* [sd] remove utils and accelerate related
* [sd] remove utils
* [sd] separate vae from unet.
* [sd] move vae outside.
* [sd] move conditional unet to ddpm
* [sd] add stable unet to denoisenet.
* [sd] use denoising unet in ddpm and run good.
* [sd] unet forward with stable type
* [sd] delete unused code.
* [sd] remove default parameters
* [sd] add copyright and format clip_wrapper.py
* [sd] format vae.py
* [sd] format stable_diffuser.py
* [sd] append to last commit
* [sd] format unet_blocks.py
* [sd] format files.
* [sd] format init.py
* [sd] format demo
* [sd] format config
* [sd] add docstring.
* [sd] add transformers dependency.
* [sd] rename to stablediffusion.
* [sd] add docstr in stable_diffusion.py
* [sd] fix linter complain
* [sd] res_block.py add docstring.
* [sd] add docstring for vae.py
* [sd] fix linter.
* [sd] add docstrings for unet_blocks.py
* [sd] stable diffusion return torch tensor
* [sd] run linter.
* [sd] add docstr
* [sd] add docstring
* [sd] put load ckpt together.
* [sd] misc change
* [sd] add clip wrapper ut.
* [sd] add stable_diffusion ut
* [sd] sd ut skip windows cuda
* [sd] add vae ut.
* [sd] fix linter.
* [sd] fix vae ut.
* [sd] add ddpm ddim ut and remove unused block
* [sd] remove ut untested code.
* [sd] add resblock ut
* [sd] add resblock ut.
* [sd] add attention ut.
* [sd] add unet block ut.
* [sd] add embeddings ut.
* [sd] add unet block ut.
* [sd] add denoising unet ut.
* [sd] add vae ut.
* [sd] ddim ut
* [sd] add ddim ddpm ut
* [sd] add ut.
* [sd] add attention ut.
* [sd] remove useless code.
* [sd] add sd ut.
* [sd] add ddpm ut.
* [sd] rename config.
* [sd] put function inside timestep class.
* [sd] use basemodel for sd.
* [sd] add check for silu
* [sd] fix typo.
* [sd] add ut.
* [sd] add ut for unet_blocks.py
* [sd] add ut.
* [sd] load pretrained weights as mm way.
* [sd] rename device.
* [sd] remove main function in test files.
* [sd] add readme and remove demo.
* [sd] add stable diffusion readme.
* [sd] load pretrained ckpt by diffusers.
* [sd] update readme.
* [sd] fix clip_wrapper ut.
* [sd] format sd config.
* [sd] update metafile.yml
* [sd] try import transformers.
Showing 28 changed files with 4,311 additions and 153 deletions.
@@ -0,0 +1,67 @@
# Stable Diffusion (2022)

> [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
> **Task**: Text2Image
<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

Stable Diffusion is a latent diffusion model conditioned on the text embeddings of a CLIP text encoder, which allows you to create images from text inputs.

<!-- [IMAGE] -->

<div align=center >
<img src="https://user-images.githubusercontent.com/12782558/209609229-8221c7cc-d5c9-44d5-a1af-c254b5a95fae.png" width="400"/>
</div >

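As a rough illustration of what "latent" means here, the sketch below works out the tensor shapes involved, using values from the config added in this commit (`sample_size=512`, `latent_channels=4`, four VAE blocks, `cross_attention_dim=768`). The downsampling rule (a factor of 2 for every VAE block except the last) and the 77-token prompt length are assumptions based on the usual Stable Diffusion v1 setup, and the helper names are hypothetical.

```python
# Values mirror the config added in this commit; the helper names are
# hypothetical and exist only for this illustration.
IMAGE_SIZE = 512                       # vae sample_size / unet image_size
LATENT_CHANNELS = 4                    # vae latent_channels == unet in_channels
VAE_BLOCK_OUT_CHANNELS = [128, 256, 512, 512]
CROSS_ATTENTION_DIM = 768              # width of each text-token embedding
MAX_PROMPT_TOKENS = 77                 # assumption: CLIP tokenizer context length


def latent_shape(image_size, latent_channels, num_vae_blocks):
    """Return (channels, height, width) of the latent the UNet denoises.

    Assumes every VAE encoder block except the last halves the resolution,
    so the total downsampling factor is 2 ** (num_vae_blocks - 1).
    """
    factor = 2 ** (num_vae_blocks - 1)
    side = image_size // factor
    return (latent_channels, side, side)


def text_embedding_shape(max_tokens, embed_dim):
    """Return (tokens, embedding width) of the conditioning sequence."""
    return (max_tokens, embed_dim)


latent = latent_shape(IMAGE_SIZE, LATENT_CHANNELS, len(VAE_BLOCK_OUT_CHANNELS))
text = text_embedding_shape(MAX_PROMPT_TOKENS, CROSS_ATTENTION_DIM)

print(latent)  # (4, 64, 64)
print(text)    # (77, 768)
```

Under these assumptions the UNet works on a 4x64x64 tensor rather than a 3x512x512 image, which is where most of the compute saving of a latent diffusion model comes from.
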
## Pretrained models

We use the Stable Diffusion v1.5 weights. The model is made up of several sub-models with separate weights, including the VAE, the UNet, and the CLIP text encoder. Download the weights from [stable-diffusion-1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and set `pretrained_model_path` in the config to the directory that contains them.

| Diffusion Model | Config | Download |
| :-------------------: | :------------------------------------------------: | :------------------------------------------------------------: |
| stable_diffusion_v1.5 | [config](./stable-diffusion_ddim_denoisingunet.py) | [model](https://huggingface.co/runwayml/stable-diffusion-v1-5) |

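Before pointing `pretrained_model_path` at a local directory, it can help to check that the download is complete. The following is a minimal sketch, assuming the directory mirrors the Hugging Face repository layout with `vae`, `unet`, `text_encoder` and `tokenizer` sub-folders. `missing_subdirs` is a hypothetical helper, not part of mmedit, and the loader in this commit may need a different set of files.

```python
import tempfile
from pathlib import Path

# Assumption: sub-folders of the Hugging Face repository that hold the
# sub-models named above, plus the tokenizer used by the text encoder.
EXPECTED_SUBDIRS = ("vae", "unet", "text_encoder", "tokenizer")


def missing_subdirs(weights_dir):
    """Return the expected sub-folders that are absent from weights_dir."""
    root = Path(weights_dir)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


# Demonstration on a throwaway directory that only has two of the folders.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("vae", "unet"):
        (Path(tmp) / name).mkdir()
    print(missing_subdirs(tmp))  # ['text_encoder', 'tokenizer']
```

An empty list means all four folders are present, and the path can then be used as the value of `pretrained_model_path`.
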
## Quick Start

Run the following code to generate an image from a text prompt.

```python
from mmengine import MODELS, Config
from torchvision import utils

from mmedit.utils import register_all_modules

# Register all mmedit modules so the config types can be resolved.
register_all_modules()

# Build the model from the config and move it to the GPU.
config = 'configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py'
StableDiffuser = MODELS.build(Config.fromfile(config).model)
prompt = 'A mecha robot in a favela in expressionist style'
StableDiffuser = StableDiffuser.to('cuda')

# Generate an image from the prompt and save it to disk.
image = StableDiffuser.infer(prompt)['samples']
utils.save_image(image, 'robot.png')
```

## Comments

Our Stable Diffusion implementation builds heavily on the [diffusers codebase](https://github.com/huggingface/diffusers), and the model weights come from [stable-diffusion-1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5).

Thanks to the community for their efforts!

## Citation

```bibtex
@misc{rombach2021highresolution,
      title={High-Resolution Image Synthesis with Latent Diffusion Models},
      author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
      year={2021},
      eprint={2112.10752},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
@@ -0,0 +1,22 @@
Collections:
  - Metadata:
      Architecture:
        - Stable Diffusion
    Name: Stable Diffusion
    Paper:
      - https://github.com/CompVis/stable-diffusion
    README: configs/stable_diffusion/README.md
    Task:
      - text2image
    Year: 2022
Models:
  - Config: configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py
    In Collection: Stable Diffusion
    Metadata:
      Training Data: Others
    Name: stable-diffusion_ddim_denoisingunet
    Results:
      - Dataset: Others
        Metrics: {}
        Task: Text2Image
    Weights: https://huggingface.co/runwayml/stable-diffusion-v1-5
configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py (58 changes: 58 additions & 0 deletions)
@@ -0,0 +1,58 @@
unet = dict(
    type='DenoisingUnet',
    image_size=512,
    base_channels=320,
    channels_cfg=[1, 2, 4, 4],
    unet_type='stable',
    act_cfg=dict(type='silu'),
    cross_attention_dim=768,
    num_heads=8,
    in_channels=4,
    layers_per_block=2,
    down_block_types=[
        'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D',
        'DownBlock2D'
    ],
    up_block_types=[
        'UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D',
        'CrossAttnUpBlock2D'
    ],
    output_cfg=dict(var='fixed'))

vae = dict(
    act_fn='silu',
    block_out_channels=[128, 256, 512, 512],
    down_block_types=[
        'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D',
        'DownEncoderBlock2D'
    ],
    in_channels=3,
    latent_channels=4,
    layers_per_block=2,
    norm_num_groups=32,
    out_channels=3,
    sample_size=512,
    up_block_types=[
        'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D',
        'UpDecoderBlock2D'
    ])

diffusion_scheduler = dict(
    type='DDIMScheduler',
    variance_type='learned_range',
    beta_end=0.012,
    beta_schedule='scaled_linear',
    beta_start=0.00085,
    num_train_timesteps=1000,
    set_alpha_to_one=False,
    clip_sample=False)

init_cfg = dict(type='Pretrained', pretrained_model_path='')

model = dict(
    type='StableDiffusion',
    diffusion_scheduler=diffusion_scheduler,
    unet=unet,
    vae=vae,
    init_cfg=init_cfg,
)
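
The `diffusion_scheduler` settings in this config can be made concrete with a small sketch in plain Python. It assumes that `scaled_linear` means betas spaced linearly in square-root space and then squared (the convention used in diffusers, on which this code is based), and that DDIM sampling visits an evenly strided subset of the training timesteps. The 50 inference steps and all function names are illustrative only, not taken from the commit.

```python
import math

# Values copied from diffusion_scheduler in the config.
BETA_START = 0.00085
BETA_END = 0.012
NUM_TRAIN_TIMESTEPS = 1000


def scaled_linear_betas(beta_start, beta_end, num_steps):
    """Betas spaced linearly in sqrt-space, then squared.

    Assumption: this is what 'scaled_linear' denotes, following the
    diffusers convention; the scheduler code itself is not shown here.
    """
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    step = (hi - lo) / (num_steps - 1)
    return [(lo + i * step) ** 2 for i in range(num_steps)]


def cumulative_alphas(betas):
    """Running product of (1 - beta_t), often written alpha-bar_t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def ddim_timesteps(num_train_steps, num_inference_steps):
    """Evenly strided subset of the training timesteps, highest first."""
    stride = num_train_steps // num_inference_steps
    return [i * stride for i in range(num_inference_steps)][::-1]


betas = scaled_linear_betas(BETA_START, BETA_END, NUM_TRAIN_TIMESTEPS)
alphas_cumprod = cumulative_alphas(betas)
timesteps = ddim_timesteps(NUM_TRAIN_TIMESTEPS, 50)  # 50 steps: illustrative

print(len(betas), len(timesteps))  # 1000 50
print(timesteps[:3])               # [980, 960, 940]
```

The cumulative product falls steadily from just under 1 towards 0 as the timestep grows, which is the amount of signal left in the noised latent at each step.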