Skip to content

Commit

Permalink
[Alpha-VLLM Team] Add Lumina-T2X to diffusers (#8652)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: zhuole1025 <zhuole1025@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
  • Loading branch information
3 people committed Jul 8, 2024
1 parent 9e9ed35 commit 9838867
Show file tree
Hide file tree
Showing 22 changed files with 2,478 additions and 17 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@
title: DiTTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/sd3_transformer2d
Expand Down Expand Up @@ -324,6 +326,8 @@
title: Latent Diffusion
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/panorama
Expand Down Expand Up @@ -435,6 +439,8 @@
title: EulerDiscreteScheduler
- local: api/schedulers/flow_match_euler_discrete
title: FlowMatchEulerDiscreteScheduler
- local: api/schedulers/flow_match_heun_discrete
title: FlowMatchHeunDiscreteScheduler
- local: api/schedulers/heun
title: HeunDiscreteScheduler
- local: api/schedulers/ipndm
Expand Down
20 changes: 20 additions & 0 deletions docs/source/en/api/models/lumina_nextdit2d.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LuminaNextDiT2DModel

A Next Version of Diffusion Transformer model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).

## LuminaNextDiT2DModel

[[autodoc]] LuminaNextDiT2DModel

88 changes: 88 additions & 0 deletions docs/source/en/api/pipelines/lumina.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Lumina-T2X
![concepts](https://github.com/Alpha-VLLM/Lumina-T2X/assets/54879512/9f52eabb-07dc-4881-8257-6d8a5f2a0a5a)

[Lumina-Next : Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.

The abstract from the paper is:

*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains. By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*

**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.

Lumina-Next has the following components:
* It improves sampling efficiency with fewer and faster Steps.
* It uses a Next-DiT as a transformer backbone with Sandwichnorm 3D RoPE, and Grouped-Query Attention.
* It uses a Frequency- and Time-Aware Scaled RoPE.

---

[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://arxiv.org/abs/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.

The abstract from the paper is:

*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*


You can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).

**Highlights**: Lumina-T2X supports Any Modality, Resolution, and Duration.

Lumina-T2X has the following components:
* It uses a Flow-based Large Diffusion Transformer as the backbone
* It supports different any modalities with one backbone and corresponding encoder, decoder.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

### Inference (Text-to-Image)

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

First, load the pipeline:

```python
from diffusers import LuminaText2ImgPipeline
import torch

pipeline = LuminaText2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```

Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:

```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```

Finally, compile the components and run inference:

```python
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
```

## LuminaText2ImgPipeline

[[autodoc]] LuminaText2ImgPipeline
- all
- __call__

18 changes: 18 additions & 0 deletions docs/source/en/api/schedulers/flow_match_heun_discrete.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# FlowMatchHeunDiscreteScheduler

`FlowMatchHeunDiscreteScheduler` is based on the flow-matching sampling introduced in [EDM](https://arxiv.org/abs/2403.03206).

## FlowMatchHeunDiscreteScheduler
[[autodoc]] FlowMatchHeunDiscreteScheduler
142 changes: 142 additions & 0 deletions scripts/convert_lumina_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import argparse
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline


def main(args):
# checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
all_sd = load_file(args.origin_ckpt_path, device="cpu")
converted_state_dict = {}
# pad token
converted_state_dict["pad_token"] = all_sd["pad_token"]

# patch embed
converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]

# time and caption embed
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]

for i in range(24):
# adaln
converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]

# qkv
converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]

# cap
converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]

# output
converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]

# attention
# qk norm
converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]

converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]

converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]

converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]

# attention norm
converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]

# feed forward
converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]

# feed forward norm
converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]

# final layer
converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]

converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]

# Lumina-Next-SFT 2B
transformer = LuminaNextDiT2DModel(
sample_size=128,
patch_size=2,
in_channels=4,
hidden_size=2304,
num_layers=24,
num_attention_heads=32,
num_kv_heads=8,
multiple_of=256,
ffn_dim_multiplier=None,
norm_eps=1e-5,
learn_sigma=True,
qk_norm=True,
cross_attention_dim=2048,
scaling_factor=1.0,
)
transformer.load_state_dict(converted_state_dict, strict=True)

num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")

if args.only_transformer:
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
else:
scheduler = FlowMatchEulerDiscreteScheduler()

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
text_encoder = AutoModel.from_pretrained("google/gemma-2b")

pipeline = LuminaText2ImgPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)
pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[256, 512, 1024],
required=False,
help="Image size of pretrained model, either 512 or 1024.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--only_transformer", default=True, type=bool, required=True)

args = parser.parse_args()
main(args)
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
"HunyuanDiT2DMultiControlNetModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"LuminaNextDiT2DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
Expand Down Expand Up @@ -162,6 +163,7 @@
"EulerAncestralDiscreteScheduler",
"EulerDiscreteScheduler",
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
Expand Down Expand Up @@ -270,6 +272,7 @@
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MusicLDMPipeline",
Expand Down Expand Up @@ -509,6 +512,7 @@
HunyuanDiT2DMultiControlNetModel,
I2VGenXLUNet,
Kandinsky3UNet,
LuminaNextDiT2DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
Expand Down Expand Up @@ -580,6 +584,7 @@
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
Expand Down Expand Up @@ -669,6 +674,7 @@
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MusicLDMPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
Expand Down Expand Up @@ -85,6 +86,7 @@
DiTTransformer2DModel,
DualTransformer2DModel,
HunyuanDiT2DModel,
LuminaNextDiT2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
Expand Down
Loading

0 comments on commit 9838867

Please sign in to comment.