1. Update README.md;
2. code update;
lawrence-cj committed Apr 11, 2024
1 parent ac59cfb commit 318c5a5
Showing 4 changed files with 125 additions and 51 deletions.
69 changes: 59 additions & 10 deletions README.md
@@ -35,6 +35,7 @@ we will try to keep this repo as simple as possible so that everyone in the PixA

---
## Breaking News 🔥🔥!!
- (🔥 New) Apr. 11, 2024. 💥 [PixArt-Σ Demo](#3-pixart-demo) & [PixArt-Σ Pipeline](#2-integration-in-diffusers)! PixArt-Σ supports `🧨 diffusers` using [patches](scripts/diffusers_patches.py) for fast experience!
- (🔥 New) Apr. 10, 2024. 💥 PixArt-α-DMD one step sampler [demo code](app/app_pixart_dmd.py) & [PixArt-α-DMD checkpoint](https://huggingface.co/PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512) 512px are released!
- (🔥 New) Apr. 9, 2024. 💥 [PixArt-Σ checkpoint](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth) 1024px is released!
- (🔥 New) Apr. 6, 2024. 💥 [PixArt-Σ checkpoint](https://huggingface.co/PixArt-alpha/PixArt-Sigma/tree/main) 256px & 512px are released!
@@ -158,25 +159,72 @@ python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-X
```

## 2. Integration in diffusers
(Coming soon)
**First**
```bash
pip install git+https://github.com/huggingface/diffusers
```
**Then**
```python
import torch
from diffusers import Transformer2DModel
from scripts.diffusers_patches import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline

setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transformer = Transformer2DModel.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
subfolder='transformer',
use_safetensors=True,
)
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
transformer=transformer,
use_safetensors=True,
)
pipe.to(device)

# Enable memory optimizations.
# pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("./cactus.png")
```
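
The `setattr(Transformer2DModel, '_init_patched_inputs', ...)` line above swaps a method on the class itself, which is why it has to run before `from_pretrained` builds the model. Below is a minimal sketch of that pattern in plain Python. `Model` and `custom_init_patched_inputs` are hypothetical stand-ins invented for illustration; they are not part of diffusers or of this repository.

```python
class Model:
    """Hypothetical stand-in for a library class such as Transformer2DModel."""

    def __init__(self):
        # The constructor delegates to a hook that a patch can replace.
        self.layout = self._init_patched_inputs(norm_type="ada_norm_single")

    def _init_patched_inputs(self, norm_type):
        return f"stock init ({norm_type})"


def custom_init_patched_inputs(self, norm_type):
    # Replacement hook; it receives the instance as `self`, like any method.
    return f"patched init ({norm_type})"


# Built before the patch: the constructor ran the stock hook.
before = Model().layout

# Assigning on the class (not on an instance) changes the hook for every
# instance constructed afterwards, including ones built by library code.
setattr(Model, "_init_patched_inputs", custom_init_patched_inputs)

# Built after the patch: the constructor ran the replacement hook.
after = Model().layout

print(before)
print(after)
```

Because the hook runs inside the constructor, an object created before the patch keeps whatever the stock hook produced; only objects created afterwards see the replacement.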

## 3. PixArt-DMD Demo
## 3. PixArt Demo
```bash
pip install git+https://github.com/huggingface/diffusers

# PixArt-Sigma
DEMO_PORT=12345 python app/app_pixart_sigma.py

# PixArt-Sigma One step Sampler(DMD)
DEMO_PORT=12345 python app/app_pixart_dmd.py
```
Then open `http://your-server-ip:12345` in a browser to try the demo.
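
The commands above pass the port through the `DEMO_PORT` environment variable. How the app scripts read it is not shown in this commit, so the following is only a sketch of one reasonable way to do it. `resolve_demo_port` is a hypothetical helper, and the fallback of 7860 is an illustrative default, not a value taken from the repository.

```python
import os


def resolve_demo_port(default: int = 7860) -> int:
    """Return the port from DEMO_PORT, or `default` if it is unset or blank.

    Hypothetical helper for illustration; not part of the repository.
    """
    raw = os.environ.get("DEMO_PORT", "")
    if not raw.strip():
        return default
    port = int(raw)  # raises ValueError on non-numeric input
    if not 1 <= port <= 65535:
        raise ValueError(f"DEMO_PORT out of range: {port}")
    return port


# Equivalent to launching with `DEMO_PORT=12345 python ...`.
os.environ["DEMO_PORT"] = "12345"
print(resolve_demo_port())
```

Validating the range up front turns a typo in the variable into an immediate error instead of a server that fails to bind later.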


## 4. Convert .pth checkpoint into diffusers version
Directly download from [Hugging Face](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS)

or run with:
```bash
pip install git+https://github.com/huggingface/diffusers

python tools/convert_pixart_to_diffusers.py --orig_ckpt_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS.pth --dump_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS --only_transformer=True --image_size=1024 --version sigma
```
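
Two steps of the conversion are visible in the `tools/convert_pixart_to_diffusers.py` hunk of this commit: entries are popped from the original state dict under `blocks.{depth}.*` keys and stored under `transformer_blocks.{depth}.*`, and the fused `attn.qkv.weight` is split into three equal parts along dimension 0. The sketch below imitates both steps with plain Python lists in place of tensors. The `attn1.to_q/to_k/to_v` target names are assumptions made for illustration, since the hunk does not show them.

```python
def split_fused_qkv(rows):
    """Split a fused QKV weight (a list of rows) into three equal row blocks.

    Mirrors what torch.chunk(weight, 3, dim=0) yields when the row count
    is divisible by three.
    """
    if len(rows) % 3 != 0:
        raise ValueError("fused QKV weight must have a row count divisible by 3")
    size = len(rows) // 3
    return rows[:size], rows[size:2 * size], rows[2 * size:]


def convert_block(state_dict, depth):
    """Move one block's entries to renamed keys, consuming the originals."""
    converted = {}
    converted[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
        f"blocks.{depth}.scale_shift_table"
    )
    q, k, v = split_fused_qkv(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"))
    # Target key names below are illustrative assumptions.
    converted[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
    converted[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
    converted[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
    return converted


toy = {
    "blocks.0.scale_shift_table": [[0.0, 0.0]],
    "blocks.0.attn.qkv.weight": [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
}
converted = convert_block(toy, depth=0)
print(sorted(converted))
print(toy)  # whatever is left here was not converted
```

Using `pop` rather than indexing makes leftovers easy to spot: any key still present in the source dict after conversion was never handled.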

# ⏬ Available Models
All models will be automatically downloaded [here](#12-download-pretrained-checkpoint). You can also choose to download manually from this [url](https://huggingface.co/PixArt-alpha/PixArt-Sigma).

| Model | #Params | Checkpoint path | Download in OpenXLab |
|:-----------------|:--------|:--------------------------------------------------------------------------------------------------------------------------|:---------------------|
| T5 & SDXL-VAE | 4.5B | [pixart_sigma_sdxlvae_T5_diffusers](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers) | [coming soon]( ) |
| PixArt-Σ-256 | 0.6B | [PixArt-Sigma-XL-2-256x256.pth](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-256x256.pth) | [coming soon]( ) |
| PixArt-Σ-512 | 0.6B | [PixArt-Sigma-XL-2-512-MS.pth](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-512-MS.pth) | [coming soon]( ) |
| PixArt-α-512-DMD | 0.6B | Diffusers: [PixArt-Alpha-DMD-XL-2-512x512](https://huggingface.co/PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512) | [coming soon]( ) |
| PixArt-Σ-1024 | 0.6B | [PixArt-Sigma-XL-2-1024-MS.pth](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth) | [coming soon]( ) |
| Model | #Params | Checkpoint path | Download in OpenXLab |
|:-----------------|:--------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------|
| T5 & SDXL-VAE | 4.5B | Diffusers: [pixart_sigma_sdxlvae_T5_diffusers](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers) | [coming soon]( ) |
| PixArt-Σ-256 | 0.6B | pth: [PixArt-Sigma-XL-2-256x256.pth](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-256x256.pth) <br/> Diffusers: [PixArt-Sigma-XL-2-256x256](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-256x256) | [coming soon]( ) |
| PixArt-Σ-512 | 0.6B | pth: [PixArt-Sigma-XL-2-512-MS.pth](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-512-MS.pth) <br/> Diffusers: [PixArt-Sigma-XL-2-512-MS](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-512-MS) | [coming soon]( ) |
| PixArt-α-512-DMD | 0.6B | Diffusers: [PixArt-Alpha-DMD-XL-2-512x512](https://huggingface.co/PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512) | [coming soon]( ) |
| PixArt-Σ-1024 | 0.6B | pth: [PixArt-Sigma-XL-2-1024-MS.pth](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth) <br/> Diffusers: [PixArt-Sigma-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS) | [coming soon]( ) |


## 💪To-Do List
@@ -186,6 +234,7 @@ We will try our best to release
- [x] Inference code
- [x] Inference code of One Step Sampling with [DMD](https://arxiv.org/abs/2311.18828)
- [x] Model zoo (256/512/1024)
- [ ] Diffusers
- [x] Diffusers (for fast experience)
- [ ] Diffusers (stable official version)
- [ ] Training code of One Step Sampling with [DMD](https://arxiv.org/abs/2311.18828)
- [ ] Model zoo (2K)
23 changes: 11 additions & 12 deletions app/app_pixart_sigma.py
@@ -10,24 +10,23 @@
import gradio as gr
import numpy as np
import uuid
from diffusers import ConsistencyDecoderVAE, PixArtAlphaPipeline, DPMSolverMultistepScheduler, Transformer2DModel, AutoencoderKL
from diffusers import ConsistencyDecoderVAE, DPMSolverMultistepScheduler, Transformer2DModel, AutoencoderKL
import torch
from typing import Tuple
from datetime import datetime
from diffusion.sa_solver_diffusers import SASolverScheduler
from peft import PeftModel
from scripts.diffusers_patches import pixart_sigma_init_patched_inputs
from scripts.diffusers_patches import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline


DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-alpha.github.io/master/static/images/logo.png)
# PixArt-Alpha 1024px
#### [PixArt-Alpha 1024px](https://github.com/PixArt-alpha/PixArt-alpha) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint.
DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png)
# PixArt-Sigma 1024px
#### [PixArt-Sigma 1024px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint.
#### English prompts ONLY; 提示词仅限英文
Don't want to queue? Try [OpenXLab](https://openxlab.org.cn/apps/detail/PixArt-alpha/PixArt-alpha) or [Google Colab Demo](https://colab.research.google.com/drive/1jZ5UZXk7tcpTfVwnX33dDuefNMcnW9ME?usp=sharing).
### <span style='color: red;'>You may change the DPM-Solver inference steps from 14 to 20, if you didn't get satisfied results.
### <span style='color: red;'>You may change the DPM-Solver inference steps from 14 to 20, or DPM-Solver Guidance scale from 4.5 to 3.5 if you didn't get satisfied results.
"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU �� This demo does not work on CPU.</p>"
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
@@ -142,7 +141,7 @@ def get_args():
subfolder='transformer',
torch_dtype=weight_dtype,
)
pipe = PixArtAlphaPipeline.from_pretrained(
pipe = PixArtSigmaPipeline.from_pretrained(
args.pipeline_load_from,
transformer=transformer,
torch_dtype=weight_dtype,
@@ -152,7 +151,7 @@ def get_args():
assert args.lora_repo_id is not None
transformer = Transformer2DModel.from_pretrained(args.repo_id, subfolder="transformer", torch_dtype=torch.float16)
transformer = PeftModel.from_pretrained(transformer, args.lora_repo_id)
pipe = PixArtAlphaPipeline.from_pretrained(
pipe = PixArtSigmaPipeline.from_pretrained(
args.repo_id,
transformer=transformer,
torch_dtype=torch.float16,
@@ -207,7 +206,7 @@ def generate(
width: int = 1024,
height: int = 1024,
schedule: str = 'DPM-Solver',
dpms_guidance_scale: float = 3.5,
dpms_guidance_scale: float = 4.5,
sas_guidance_scale: float = 3,
dpms_inference_steps: int = 20,
sas_inference_steps: int = 25,
@@ -350,7 +349,7 @@ def generate(
minimum=1,
maximum=10,
step=0.1,
value=3.5,
value=4.5,
)
dpms_inference_steps = gr.Slider(
label="DPM-Solver inference steps",
82 changes: 54 additions & 28 deletions scripts/diffusers_patches.py
@@ -1,13 +1,16 @@
import torch
from diffusers import ImagePipelineOutput
from diffusers import ImagePipelineOutput, PixArtAlphaPipeline, AutoencoderKL, Transformer2DModel, \
DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import PixArtAlphaTextProjection, PatchEmbed
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import retrieve_timesteps
from diffusers.utils import deprecate
from torch import nn
from typing import Callable, List, Optional, Tuple, Union

from diffusers.utils import deprecate
from torch import nn
from transformers import T5Tokenizer, T5EncoderModel

ASPECT_RATIO_2048_BIN = {
"0.25": [1024.0, 4096.0],
@@ -162,30 +165,30 @@


def pipeline_pixart_alpha_call(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -441,6 +444,29 @@ def pipeline_pixart_alpha_call(
return ImagePipelineOutput(images=image)


class PixArtSigmaPipeline(PixArtAlphaPipeline):
r"""
tmp Pipeline for text-to-image generation using PixArt-Sigma.
"""

def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: Transformer2DModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)

self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
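
A note on the `vae_scale_factor` expression above: each step between consecutive VAE blocks halves the spatial resolution, so `n` blocks give a factor of `2 ** (n - 1)`. The sketch below works through that arithmetic. The four-entry channel list is an assumed configuration chosen to resemble the SDXL VAE, and `latent_size` is a hypothetical helper added for illustration.

```python
def vae_scale_factor(block_out_channels):
    """Spatial downsampling factor: one 2x reduction between consecutive blocks."""
    return 2 ** (len(block_out_channels) - 1)


def latent_size(image_size, block_out_channels):
    """Side length of the latent grid for a square image of `image_size` pixels."""
    factor = vae_scale_factor(block_out_channels)
    if image_size % factor != 0:
        raise ValueError(f"image size {image_size} is not a multiple of {factor}")
    return image_size // factor


# Assumed config with four blocks; only the length of the list matters here.
channels = [128, 256, 512, 512]
print(vae_scale_factor(channels))   # 8
print(latent_size(1024, channels))  # 128
```

Under this assumption a 1024px image is denoised on a 128x128 latent grid, which is the quantity the image processor needs in order to map between pixel and latent sizes.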


def pixart_sigma_init_patched_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

@@ -493,7 +519,7 @@ def pixart_sigma_init_patched_inputs(self, norm_type):
)
elif self.config.norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim ** 0.5)
self.proj_out = nn.Linear(
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
)
2 changes: 1 addition & 1 deletion tools/convert_pixart_to_diffusers.py
@@ -82,7 +82,7 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Attention is all you need ��
# Attention is all you need 🤘

# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)