In [None]:
"""Z-Image PyTorch Native Inference."""

import os
import time
import warnings
import random
import torch

warnings.filterwarnings("ignore")

from utils import AttentionBackend, ensure_model_weights, load_from_local_dir, set_attention_backend
from zimage import generate
from peft import PeftModel  # 新增导入



# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Locate (and optionally verify) the Z-Image weights.
model_path = ensure_model_weights("/home/xxx/z-image", verify=False)  # True to verify with md5
dtype = torch.bfloat16
# NOTE: this name shadows the builtin `compile()` in the notebook namespace;
# kept as-is for compatibility with the rest of the notebook.
compile = False  # default False for compatibility
out_path = "/home/xxx/Z-Image/pic_test"
os.makedirs(out_path, exist_ok=True)
output_path = os.path.join(out_path, "test1.png")
height = 1024
width = 1024
num_inference_steps = 8
guidance_scale = 0.0

# Attention backend; can be overridden with the ZIMAGE_ATTENTION env variable.
# `set_attention_backend` (from utils) selects the kernel used to compute
# attention, i.e. the "engine" the model uses for its attention layers.
# Backends this build reports as available:
#   ['flash', 'flash_varlen', '_flash_3', '_flash_varlen_3', 'native', '_native_flash', '_native_math']
# Rough guide (confirm details against utils.AttentionBackend):
#   "_native_flash"                  - native Flash Attention; general-purpose, good performance
#   "flash" / "flash_varlen"         - Flash Attention 2; requires installing the flash-attn package
#   "_flash_3" / "_flash_varlen_3"   - Flash Attention 3; Hopper GPUs only (H100/H200/H800)
#   "native" / "_native_math"        - presumably plain PyTorch attention; most compatible (TODO confirm)
# Note: "sdpa" and "xformers" are NOT in the available list above.
attn_backend = os.environ.get("ZIMAGE_ATTENTION", "_native_flash")

# Example prompt (kept for reference):
# prompt = (
#     "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. "
#     "Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. "
#     "Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, "
#     "silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
# )
device = "cuda:0"
print("Chosen device:"+device)

# ---------------------------------------------------------------------------
# Load models
# ---------------------------------------------------------------------------
components = load_from_local_dir(model_path, device=device, dtype=dtype, compile=compile)

# ================== LoRA loading ==================
# Path to a PEFT LoRA checkpoint; leave empty to run the base model only
# (PeftModel.from_pretrained cannot load from an empty path).
lora_path = ""  # set this to your LoRA path
# False -> keep the adapter separate (recommended: allows switching between
#          several LoRAs with set_adapter and adjusting their weights).
# True  -> merge the LoRA weights into the base transformer (faster inference,
#          but the result is no longer a PeftModel, so it cannot be trained
#          further or re-weighted).
merge_lora = False

if lora_path:
    print(f"正在加载 LoRA 权重: {lora_path}")
    components["transformer"] = PeftModel.from_pretrained(
        components["transformer"],
        lora_path,
        torch_dtype=dtype,  # keep consistent with the base model (bfloat16)
        device_map=None,    # the base model was already moved to `device`
    )
    # Optional: with several LoRAs loaded, choose the active adapter:
    # components["transformer"].set_adapter("default")  # "default" is the default adapter name
    if merge_lora:
        components["transformer"] = components["transformer"].merge_and_unload()  # no longer a PeftModel
        print("✅ LoRA 权重已合并到基础模型")
else:
    print("No LoRA path set; using the base transformer.")

# Make sure the model is in evaluation mode (important for inference).
components["transformer"].eval()

# For a PeftModel, report trainable parameters (expected: 0 for an inference load).
if hasattr(components["transformer"], "print_trainable_parameters"):
    components["transformer"].print_trainable_parameters()
# ==================================================


AttentionBackend.print_available_backends()
set_attention_backend(attn_backend)
print(f"Chosen attention backend: {attn_backend}")



PyTorch version is >= 2.5.0, check pass.


2026-01-29 20:33:18.649969: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-29 20:33:18.706437: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI AVX_VNNI_INT8 AVX_NE_CONVERT FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-29 20:33:19.680450: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
[32m2026-01-29 20:33:20.316[0m | [1mINFO    [0m | [36muti

Chosen device:cuda:0


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

正在加载 LoRA 权重: /home/liaoge/Z-Image/src/zimage_lora_output_1/checkpoint-1200
trainable params: 0 || all params: 6,236,009,536 || trainable%: 0.0000
Available attention backends list: ['flash', 'flash_varlen', '_flash_3', '_flash_varlen_3', 'native', '_native_flash', '_native_math']
Chosen attention backend: _native_flash


In [None]:
# Generate a batch of images with the components loaded in the previous cell.
# Each image uses a fresh random seed; the seed is printed next to the saved
# file so a good result can be reproduced by reusing that seed.
num_images = 10
prompt = ""  # TODO: fill in the prompt (currently empty); identical for every image

for i in range(num_images):
    output_path = os.path.join(out_path, f"qc{i}.png")
    seed = random.randint(43, 10000)

    # Gen an image; perf_counter is a monotonic clock suited to timing.
    start_time = time.perf_counter()
    images = generate(
        prompt=prompt,
        **components,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device).manual_seed(seed),
    )
    elapsed = time.perf_counter() - start_time
    print(f"Time taken: {elapsed:.2f} seconds")
    images[0].save(output_path)
    print(f"Saved {output_path} (seed={seed})")

### !! For best speed performance, recommend to use `_flash_3` backend and set `compile=True`
### This would give you sub-second generation speed on Hopper GPU (H100/H200/H800) after warm-up



[32m2026-01-29 20:33:28.409[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m


[32m2026-01-29 20:33:28.937[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.85it/s]


Pre-decode latent range: [-12.52, 10.09]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 5.64 seconds


[32m2026-01-29 20:33:34.656[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:33:34.700[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-13.08, 9.40]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:33:40.064[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:33:40.107[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-13.70, 9.22]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:33:45.485[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:33:45.528[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.49, 9.41]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:33:50.899[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:33:50.942[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.91, 9.64]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:33:56.330[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:33:56.373[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.77, 9.26]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:34:01.738[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:34:01.780[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.41, 10.06]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:34:07.151[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:34:07.194[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.73, 9.72]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:34:12.567[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:34:12.610[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.51, 8.85]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds


[32m2026-01-29 20:34:17.984[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m106[0m - [1mGenerating image: 1024x1024, steps=8, cfg=0.0[0m
[32m2026-01-29 20:34:18.026[0m | [1mINFO    [0m | [36mzimage.pipeline[0m:[36mgenerate[0m:[36m284[0m - [1mSampling loop start: 8 steps[0m
Denoising: 100%|██████████| 8/8 [00:04<00:00,  1.88it/s]


Pre-decode latent range: [-12.11, 9.08]
Scaling factor: 0.3611
Shift factor: 0.1159
Time taken: 4.81 seconds
