6 changes: 6 additions & 0 deletions diffsynth_engine/__init__.py
@@ -7,12 +7,14 @@
QwenImagePipelineConfig,
HunyuanPipelineConfig,
ZImagePipelineConfig,
Flux2KleinPipelineConfig,
SDStateDicts,
SDXLStateDicts,
FluxStateDicts,
WanStateDicts,
QwenImageStateDicts,
ZImageStateDicts,
Flux2StateDicts,
AttnImpl,
SpargeAttentionParams,
VideoSparseAttentionParams,
@@ -26,6 +28,7 @@
SDImagePipeline,
SDXLImagePipeline,
FluxImagePipeline,
Flux2KleinPipeline,
WanVideoPipeline,
WanDMDPipeline,
QwenImagePipeline,
@@ -59,12 +62,14 @@
"QwenImagePipelineConfig",
"HunyuanPipelineConfig",
"ZImagePipelineConfig",
"Flux2KleinPipelineConfig",
"SDStateDicts",
"SDXLStateDicts",
"FluxStateDicts",
"WanStateDicts",
"QwenImageStateDicts",
"ZImageStateDicts",
"Flux2StateDicts",
"AttnImpl",
"SpargeAttentionParams",
"VideoSparseAttentionParams",
@@ -78,6 +83,7 @@
"SDXLImagePipeline",
"SDXLControlNetUnion",
"FluxImagePipeline",
"Flux2KleinPipeline",
"FluxControlNet",
"FluxIPAdapter",
"FluxRedux",
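Note: with these new exports, the Flux2 Klein pipeline is reachable from the package root. A minimal usage sketch — the loader and call signature below are assumed by analogy with the repo's existing pipelines (e.g. FluxImagePipeline) and are not shown in this diff; paths are hypothetical:

    from diffsynth_engine import Flux2KleinPipeline, Flux2KleinPipelineConfig

    config = Flux2KleinPipelineConfig.basic_config(
        model_path="models/flux2_klein_4B.safetensors",  # hypothetical path
    )
    pipe = Flux2KleinPipeline.from_pretrained(config)    # assumed entry point
    image = pipe(prompt="a watercolor harbor at dawn")   # assumed call signature
    image.save("harbor.png")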
68 changes: 68 additions & 0 deletions diffsynth_engine/conf/models/flux2/qwen3_8B_config.json
@@ -0,0 +1,68 @@
{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"layer_types": [
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention"
],
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": null,
"tie_word_embeddings": false,
"transformers_version": "4.56.1",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}
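Note: this is a stock transformers Qwen3-8B configuration (36 layers, hidden size 4096, GQA with 8 KV heads), bundled so the text encoder can be built without a network fetch. A minimal sketch of reading it directly, assuming a transformers version with Qwen3 support and this repo-relative path:

    from transformers import Qwen3Config

    # Load the bundled config file directly; path is relative to the repo root.
    cfg = Qwen3Config.from_json_file(
        "diffsynth_engine/conf/models/flux2/qwen3_8B_config.json"
    )
    assert cfg.num_hidden_layers == 36 and cfg.num_key_value_heads == 8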
4 changes: 4 additions & 0 deletions diffsynth_engine/configs/__init__.py
@@ -11,6 +11,7 @@
QwenImagePipelineConfig,
HunyuanPipelineConfig,
ZImagePipelineConfig,
Flux2KleinPipelineConfig,
BaseStateDicts,
SDStateDicts,
SDXLStateDicts,
@@ -19,6 +20,7 @@
WanS2VStateDicts,
QwenImageStateDicts,
ZImageStateDicts,
Flux2StateDicts,
AttnImpl,
SpargeAttentionParams,
VideoSparseAttentionParams,
@@ -44,6 +46,7 @@
"QwenImagePipelineConfig",
"HunyuanPipelineConfig",
"ZImagePipelineConfig",
"Flux2KleinPipelineConfig",
"BaseStateDicts",
"SDStateDicts",
"SDXLStateDicts",
@@ -52,6 +55,7 @@
"WanS2VStateDicts",
"QwenImageStateDicts",
"ZImageStateDicts",
"Flux2StateDicts",
"AttnImpl",
"SpargeAttentionParams",
"VideoSparseAttentionParams",
51 changes: 50 additions & 1 deletion diffsynth_engine/configs/pipeline.py
@@ -3,6 +3,7 @@
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional
from typing_extensions import Literal

from diffsynth_engine.configs.controlnet import ControlType

@@ -339,6 +340,47 @@ def __post_init__(self):
init_parallel_config(self)


@dataclass
class Flux2KleinPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
model_path: str | os.PathLike | List[str | os.PathLike]
model_dtype: torch.dtype = torch.bfloat16
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
vae_dtype: torch.dtype = torch.bfloat16
encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
encoder_dtype: torch.dtype = torch.bfloat16
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
image_encoder_dtype: torch.dtype = torch.bfloat16
model_size: Literal["4B", "9B"] = "4B"

@classmethod
def basic_config(
cls,
model_path: str | os.PathLike | List[str | os.PathLike],
encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
device: str = "cuda",
parallelism: int = 1,
offload_mode: Optional[str] = None,
offload_to_disk: bool = False,
) -> "Flux2KleinPipelineConfig":
return cls(
model_path=model_path,
device=device,
encoder_path=encoder_path,
vae_path=vae_path,
image_encoder_path=image_encoder_path,
parallelism=parallelism,
use_cfg_parallel=True if parallelism > 1 else False,
use_fsdp=True if parallelism > 1 else False,
offload_mode=offload_mode,
offload_to_disk=offload_to_disk,
)

def __post_init__(self):
init_parallel_config(self)


@dataclass
class BaseStateDicts:
pass
@@ -398,7 +440,14 @@ class ZImageStateDicts:
image_encoder: Optional[Dict[str, torch.Tensor]] = None


-def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
@dataclass
class Flux2StateDicts:
model: Dict[str, torch.Tensor]
vae: Dict[str, torch.Tensor]
encoder: Dict[str, torch.Tensor]


+def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig | Flux2KleinPipelineConfig):
assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg

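Note: a sketch of building the new config through basic_config, grounded in the code above (only the paths are hypothetical). Parallelism greater than 1 toggles the CFG-parallel and FSDP flags, and __post_init__ runs init_parallel_config, so an unsupported world size fails fast:

    from diffsynth_engine import Flux2KleinPipelineConfig

    config = Flux2KleinPipelineConfig.basic_config(
        model_path="models/flux2_klein_4B.safetensors",  # hypothetical path
        encoder_path="models/qwen3_8b",                  # hypothetical path
        parallelism=2,
    )
    assert config.use_cfg_parallel and config.use_fsdp   # enabled when parallelism > 1
    assert config.model_size == "4B"                     # default size

    # parallelism=3 would raise AssertionError: "parallelism must be 1, 2, 4 or 8"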
7 changes: 7 additions & 0 deletions diffsynth_engine/models/flux2/__init__.py
@@ -0,0 +1,7 @@
from .flux2_dit import Flux2DiT
from .flux2_vae import Flux2VAE

__all__ = [
"Flux2DiT",
"Flux2VAE",
]
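Note: the new subpackage keeps the import path consistent with the repo's other model families:

    from diffsynth_engine.models.flux2 import Flux2DiT, Flux2VAE

    # How these are constructed from Flux2StateDicts (model/vae/encoder weights)
    # is handled inside Flux2KleinPipeline and is not shown in this diff.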