
@devanshi00 commented on Jan 19, 2026

What does this PR do?

This PR adds FlashPack support to Diffusers, enabling significantly faster model loading and improved inference throughput through an optimized on-disk weight format.

What is FlashPack?

FlashPack is a weight-packing format that serializes model parameters into a contiguous, GPU-friendly layout.
It reduces:

  • Disk I/O overhead during model loading
  • CPU overhead from fragmented tensor deserialization
  • Runtime memory indirection during inference

As a result, FlashPack provides:

  • Faster pipeline initialization
  • Faster end-to-end inference
  • Zero changes to sampling logic or model architecture

FlashPack is particularly beneficial for large transformer-based diffusion models, including video and multi-modal pipelines.
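To make the idea of a contiguous layout concrete, here is a minimal conceptual sketch (not the actual FlashPack file format, which also handles mixed dtypes, alignment, and memory-mapped I/O) of packing a state dict into one flat buffer plus an offset index; pack_contiguous and unpack are illustrative names only:

import torch

# Conceptual sketch only (not the real FlashPack format); assumes all tensors share one dtype.
def pack_contiguous(state_dict):
    index, chunks, offset = {}, [], 0
    for name, tensor in state_dict.items():
        flat = tensor.detach().contiguous().view(-1)
        index[name] = {"offset": offset, "numel": flat.numel(), "shape": tuple(tensor.shape)}
        chunks.append(flat)
        offset += flat.numel()
    return torch.cat(chunks), index  # one contiguous buffer + lookup table

def unpack(buffer, index):
    # Reconstruct named tensors as zero-copy views into the shared buffer.
    return {
        name: buffer[m["offset"] : m["offset"] + m["numel"]].view(m["shape"])
        for name, m in index.items()
    }

weights = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}
buffer, index = pack_contiguous(weights)
restored = unpack(buffer.cuda() if torch.cuda.is_available() else buffer, index)

Moving the single buffer to the GPU is one large transfer and one allocation, which is the intuition behind the load-time gains reported below.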

Architecture & Design

FlashPack is implemented as a storage-level optimization, not a runtime execution hook.

Design principles

  • Non-intrusive: No changes to model forward passes
  • Opt-in: Enabled explicitly via use_flashpack=True
  • One-time conversion: Packed weights are generated once and reused
  • Graceful fallback: Automatically falls back to standard weights if FlashPack artifacts are missing

Workflow

  1. Load a pipeline using standard Diffusers APIs
  2. Convert weights using save_pretrained(..., use_flashpack=True)
  3. Reload the pipeline with use_flashpack=True
  4. FlashPack weights are used automatically if present

Note : To push a FlashPack model to the Hugging Face Hub, users must first save with use_flashpack=True and then manually upload the resulting directory, as push_to_hub=True currently only supports standard PyTorch/safetensors weights.
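For example, one possible manual upload flow using huggingface_hub's upload_folder (the repo id and local path below are placeholders):

from huggingface_hub import upload_folder

# Upload the directory produced by save_pretrained(..., use_flashpack=True).
# Both the repo id and the local folder path are placeholders.
upload_folder(
    repo_id="your-username/wan2.1-1.3b-flashpack",
    folder_path="./wan2.1-1.3b-flashpack",
)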

Benchmark Results

Model Load Time Comparison

| Model | Size (GB) | Standard Load (s) | FlashPack Load (s) | Speedup |
|---|---|---|---|---|
| Wan2.1 1.3B DiT | 2.64 | 1.181 | 0.109 | 10.81× |
| FLUX.1 [dev] 12B DiT | 22.17 | 3.738 | 0.092 | 40.64× |
| Wan2.1 14B DiT | 26.61 | 4.105 | 0.125 | 32.84× |
| Qwen-Image-Edit 20B DiT | 38.05 | 5.988 | 0.166 | 36.09× |

Effective Weight Loading Throughput & FlashPack Conversion Cost

| Model | Standard Load (GB/s) | FlashPack Load (GB/s) | FlashPack Conversion Time (s) |
|---|---|---|---|
| Wan2.1 1.3B DiT | 2.43 | 27.99 | 8.21 |
| FLUX.1 [dev] 12B DiT | 5.94 | 241.07 | 77.44 |
| Wan2.1 14B DiT | 6.52 | 213.27 | 81.85 |
| Qwen-Image-Edit 20B DiT | 6.36 | 229.78 | 129.59 |
(Attached: benchmark comparison plots.)
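These throughput numbers follow directly from the load-time table: for example, FLUX.1 [dev] moves 22.17 GB in 0.092 s, i.e. roughly 22.17 / 0.092 ≈ 241 GB/s with FlashPack versus 22.17 / 3.738 ≈ 5.9 GB/s for the standard path, and the 40.64× speedup above is simply the ratio 3.738 / 0.092.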

Note: All timings use direct GPU loading (device_map="cuda" or equivalent) to ensure a fair comparison between standard loading and FlashPack.

Benchmark Setup

  • Hardware: NVIDIA A100
  • Precision: bfloat16
  • Framework: Diffusers with FlashPack enabled
  • Measurement: End-to-end model load time (disk → GPU)
  • Reproducibility: Same environment and configuration across runs

Usage

One-time FlashPack conversion

import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# One-time conversion to FlashPack format
pipe.save_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    use_flashpack=True,
)

This writes the FlashPack-packed weights alongside the regular weights in the saved model directory.

Loading with FlashPack

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
    use_flashpack=True,  # Enable FlashPack loading
)
pipe.to("cuda")

video = pipe(
    prompt="A beautiful sunset over a calm ocean",
    width=832,
    height=480,
    num_inference_steps=16,
).videos[0]

Loading falls back gracefully to the standard weights if the FlashPack artifacts do not exist.
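Conceptually, the fallback is just a file-existence check. A minimal sketch, assuming the model.flashpack artifact name used by the benchmark script later in this thread (not the PR's exact resolution logic):

import os

def resolve_weights(model_dir: str, use_flashpack: bool) -> str:
    # Prefer the FlashPack artifact when requested and present; otherwise fall back
    # to the standard Diffusers safetensors weights.
    flashpack_file = os.path.join(model_dir, "model.flashpack")
    if use_flashpack and os.path.exists(flashpack_file):
        return flashpack_file
    return os.path.join(model_dir, "diffusion_pytorch_model.safetensors")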

Inference-Time Results (FlashPack)

| Model | Standard Inference (s) | FlashPack Inference (s) |
|---|---|---|
| Wan2.1-T2V-1.3B-Diffusers | 65.527 | 64.326 |

FlashPack optimizes model loading, not transformer compute. Inference speedups are expected to be minimal.

Fixes #12550

Before submitting

Who can review?

@sayakpaul @yiyixuxu @DN6

devanshi00 mentioned this pull request on Jan 19, 2026
devanshi00 changed the title from "added fal-flashpack support" to "[feat] added fal-flashpack support" on Jan 19, 2026
@sayakpaul (Member)

Thanks for the comprehensive PR. To benchmark model loading, are we reporting just the pipeline loading time or just the denoiser?

Could we benchmark just the denoiser loading? Also, during the standard load, let's specify device_map="cuda"; that should provide faster loading.

@sayakpaul (Member) left a comment


Thanks for starting this. I left some questions.

sayakpaul requested a review from DN6 on January 20, 2026
@devanshi00 (Author) commented on Jan 20, 2026

> Thanks for the comprehensive PR. To benchmark model loading, are we reporting just the pipeline loading time or just the denoiser?
>
> Could we benchmark just the denoiser loading? Also, during the standard load, let's specify device_map="cuda"; that should provide faster loading.

For benchmarking model loading, I used the official benchmarking script from the original FlashPack repository, which measures only the diffusion transformer (DiT / denoiser) loading time, not the full pipeline. The results reported in the table above therefore correspond exclusively to denoiser loading for the respective Diffusers models.

Here is the script

import csv
import gc
import os
import shutil
import tempfile
import time
import torch

from huggingface_hub import snapshot_download
from diffusers.models import AutoModel as DiffusersAutoModel
def test_model(
    repo_id: str,
    subfolder: str | None = None,
    accelerate_device: str | torch.device = "cuda",
    flashpack_device: str | torch.device = "cuda",
    dtype: torch.dtype | None = None,
    use_transformers: bool = False,
) -> tuple[float, float, int]:
    """
    Test a model from a repository.
    """
    repo_dir = snapshot_download(
        repo_id, allow_patterns=None if subfolder is None else [f"{subfolder}/*"]
    )
    model_dir = repo_dir if subfolder is None else os.path.join(repo_dir, subfolder)
    saved_flashpack_path = os.path.join(model_dir, "model.flashpack")
    saved_flashpack_config_path = os.path.join(model_dir, "flashpack_config.json")
    
    with tempfile.TemporaryDirectory() as tmpdir:
        # Make a new model directory with the model in it so it isn't cached
        temp_model_dir = os.path.join(tmpdir, "model")
        flashpack_dir = os.path.join(tmpdir, "flashpack")
        os.makedirs(flashpack_dir, exist_ok=True)
        print("Copying model to temporary directory")
        shutil.copytree(model_dir, temp_model_dir)
        # Load from the temporary model directory
        print("Loading model from temporary directory using from_pretrained")
        start_time = time.time()
        model = DiffusersAutoModel.from_pretrained(
            temp_model_dir,
            torch_dtype=dtype,
            device_map={"": accelerate_device},
        )

        end_time = time.time()
        accelerate_time = end_time - start_time
        print(f"Time taken with from_pretrained: {accelerate_time} seconds")

        if os.path.exists(saved_flashpack_path) and os.path.exists(
            saved_flashpack_config_path
        ):
            print("Copying flashpack to temporary directory")
            shutil.copy(
                saved_flashpack_path, os.path.join(flashpack_dir, "model.flashpack")
            )
            shutil.copy(
                saved_flashpack_config_path, os.path.join(flashpack_dir, "config.json")
            )
        else:
            print("Packing model to flashpack")
            pack_start_time = time.time()
            model.save_pretrained(
                flashpack_dir, target_dtype=dtype, use_flashpack=True
            )
            pack_end_time = time.time()
            print(
                f"Time taken for flashpack packing: {pack_end_time - pack_start_time} seconds"
            )
            # Copy back to the original model directory
            shutil.copy(
                os.path.join(flashpack_dir, "model.flashpack"), saved_flashpack_path
            )
            shutil.copy(
                os.path.join(flashpack_dir, "config.json"), saved_flashpack_config_path
            )

        del model
        sync_and_flush()

        print("Loading model from flashpack directory using from_pretrained_flashpack")
        flashpack_start_time = time.time()

        flashpack_model = DiffusersAutoModel.from_pretrained(
            flashpack_dir, device=flashpack_device, target_dtype=dtype, use_flashpack=True
        )

        flashpack_end_time = time.time()
        flashpack_time = flashpack_end_time - flashpack_start_time
        print(f"Time taken with flashpack loading: {flashpack_time} seconds")

        total_numel = 0
        for param in flashpack_model.parameters():
            total_numel += param.numel()

        total_bytes = total_numel * dtype.itemsize

        del flashpack_model
        sync_and_flush()

        return accelerate_time, flashpack_time, total_bytes


def test_wan_small_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:1" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def test_wan_large_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Wan-AI/Wan2.1-T2V-14B-Diffusers",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:1" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def test_flux_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:1" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def test_qwen_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Qwen/Qwen-Image-Edit",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:1" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
        use_transformers=True,
    )


def print_test_result(
    model_name: str,
    accelerate_time: float,
    flashpack_time: float,
    total_bytes: int,
) -> None:
    print(f"{model_name}: Accelerate time: {accelerate_time} seconds")
    print(f"{model_name}: Flashpack time: {flashpack_time} seconds")
    accelerate_gbps = (total_bytes / 1000**3) / accelerate_time
    flashpack_gbps = (total_bytes / 1000**3) / flashpack_time
    print(f"{model_name}: Accelerate GB/s: {accelerate_gbps} GB/s")
    print(f"{model_name}: Flashpack GB/s: {flashpack_gbps} GB/s")


def sync_and_flush() -> None:
    torch.cuda.empty_cache()
    gc.collect()
    os.system("sync")
    if os.geteuid() == 0:
        os.system("echo 3 | tee /proc/sys/vm/drop_caches")


if __name__ == "__main__":
    with open("benchmark_results_finall.csv", "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(
            ["model", "accelerate_time", "flashpack_time", "total_overhead_time", "total_bytes"]
        )

        for i in range(10):
            for test_model_name, test_func in [
                ("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", test_wan_small_transformer),
                ("Wan-AI/Wan2.1-T2V-14B-Diffusers", test_wan_large_transformer),
                ("black-forest-labs/FLUX.1-dev", test_flux_transformer),
                ("Qwen/Qwen-Image-Edit", test_qwen_transformer),
            ]:
                accelerate_time, flashpack_time, total_bytes = test_func()
                # ---- main benchmark CSV (unchanged, written every run) ----
                writer.writerow(
                    [
                        test_model_name,
                        accelerate_time,
                        flashpack_time,
                        total_bytes,
                    ]
                )

                print_test_result(
                    test_model_name,
                    accelerate_time,
                    flashpack_time,
                    total_bytes,
                )

Regarding FlashPack overhead, the total overhead time is the sum of:

  • FlashPack conversion time (one-time cost), and
  • FlashPack load time (per load)

These values are approximately as follows, based on the measurements reported earlier in the PR:

| Model | Total FlashPack Overhead Time (s) |
|---|---|
| Wan2.1 1.3B DiT | 8.319 |
| FLUX.1 [dev] 12B DiT | 77.532 |
| Wan2.1 14B DiT | 81.975 |
| Qwen-Image-Edit 20B DiT | 129.756 |
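For example, for Wan2.1 1.3B DiT this is 8.21 s (conversion) + 0.109 s (load) = 8.319 s, matching the conversion and load times in the tables above.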

Note: All timings use direct GPU loading (device_map="cuda" or equivalent) to ensure a fair comparison between standard loading and FlashPack.

@sayakpaul (Member)

Thanks, so it should be `device_map="cuda"`.

> These values are approximately as follows, based on the measurements reported earlier in the PR:

The table doesn't make clear the from_pretrained() time when device_map="cuda" is used.

@sayakpaul (Member) left a comment


Left some further comments.

except EnvironmentError:
    resolved_model_file = None
with no_init_weights():
    model = cls.from_config(config, **unused_kwargs)
Member

Why do we need to initialize the model here? It is already done here:

model = cls.from_config(config, **unused_kwargs)

IMO, the flashpack weight name resolving related code can be around the following block:

if resolved_model_file is None and not is_sharded:
    resolved_model_file = _get_model_file(
        pretrained_model_name_or_path,
        weights_name=_add_variant(WEIGHTS_NAME, variant),
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        commit_hash=commit_hash,
        dduf_entries=dduf_entries,
    )

And then the rest of the code can follow.

Author

I’ve changed the load_pretrained() function and removed the redundant initialization. Happy to clarify or discuss further if you have any questions.

@sayakpaul (Member)

Also, WDYT of running a conversion on the go from bin / safetensors -> flashpack when a user requests to load in flashpack for a non-flashpack checkpoint?

@devanshi00 (Author)

> Thanks, so it should be `device_map="cuda"`.
>
> > These values are approximately as follows, based on the measurements reported earlier in the PR:
>
> The table doesn't make clear the from_pretrained() time when device_map="cuda" is used.

I hope it is clear now.

@devanshi00 (Author)

> Also, WDYT of running a conversion on the go from bin / safetensors -> flashpack when a user requests to load in flashpack for a non-flashpack checkpoint?

I don’t think automatic conversion should be enabled by default. FlashPack conversion can take a long time for large models, so triggering it automatically when use_flashpack=True would result in unexpected delays on the first run. It may also silently create files on disk or modify the local cache, which can be confusing for users. Keeping conversion explicit avoids these surprises and makes behavior easier to understand.
The current behavior should remain the default, with no automatic conversion. If needed, this can be offered as an explicit opt-in (for example, flashpack_auto_convert=True) or via a separate helper utility to generate FlashPack files ahead of time. This keeps loading predictable while still giving advanced users more control.
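One hypothetical shape for such a helper (illustrative only, not part of this PR; it just wraps the explicit save_pretrained opt-in):

import torch
from diffusers.models import AutoModel

def convert_to_flashpack(model_dir: str, output_dir: str, dtype=torch.bfloat16) -> None:
    # Explicit, ahead-of-time conversion: load the standard weights once, then
    # write FlashPack artifacts to output_dir via the opt-in flag from this PR.
    model = AutoModel.from_pretrained(model_dir, torch_dtype=dtype)
    model.save_pretrained(output_dir, use_flashpack=True)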

@sayakpaul (Member)

> I don’t think automatic conversion should be enabled by default. FlashPack conversion can take a long time for large models, so triggering it automatically when use_flashpack=True would result in unexpected delays on the first run. It may also silently create files on disk or modify the local cache, which can be confusing for users. Keeping conversion explicit avoids these surprises and makes behavior easier to understand.
> The current behavior should remain the default, with no automatic conversion. If needed, this can be offered as an explicit opt-in (for example, flashpack_auto_convert=True) or via a separate helper utility to generate FlashPack files ahead of time. This keeps loading predictable while still giving advanced users more control.

This makes sense. We can make all of this explicitly documented. Thanks for the context!

> I hope it is clear now.

It is not. I would have expected to see a table with the following columns:

  • Model checkpoint id
  • Flashpack loading time
  • from_pretrained(device_map="cuda") timing

LMK if this is still unclear.

@devanshi00 (Author) commented on Jan 21, 2026

> It is not. I would have expected to see a table with the following columns:
>
>   • Model checkpoint id
>   • Flashpack loading time
>   • from_pretrained(device_map="cuda") timing
>
> LMK if this is still unclear.

So, according to you, the benchmarking script is okay? I use

model = DiffusersAutoModel.from_pretrained(
    temp_model_dir,
    torch_dtype=dtype,
    device_map={"": accelerate_device},
)

for the standard loading time. I just need to change the CSV header, right?

> This makes sense. We can make all of this explicitly documented. Thanks for the context!

So do we need an explicit opt-in (for example, flashpack_auto_convert=True) or not?

@sayakpaul (Member) commented on Jan 21, 2026

> So do we need an explicit opt-in (for example, flashpack_auto_convert=True) or not?

I think we can skip that for now and advise that users run the conversion with use_flashpack=True during save_pretrained().

> For the standard loading time. I just need to change the CSV header, right?

I only see two columns in the table provided in #12999 (comment). It has the flashpack timing but not the from_pretrained() with device_map timing unless I am missing out on something obvious.

@devanshi00 (Author)

> I only see two columns in the table provided in #12999 (comment). It has the flashpack timing but not the from_pretrained() with device_map timing unless I am missing out on something obvious.

Sorry for the confusion, but I'm referring to the tables in the first comment, where "Standard Load" is equivalent to from_pretrained() with device_map="cuda". I had uploaded those results with the same benchmarking script I mentioned afterwards.
#12999 (comment)

@sayakpaul (Member)

I ran the benchmark script myself on an H100 machine (with this PR branch). Here are my findings and I am awestruck!

| model | avg_accelerate_time | std_accelerate_time | avg_flashpack_time | std_flashpack_time | total_bytes |
|---|---|---|---|---|---|
| Qwen/Qwen-Image-Edit | 20.3753 | 0.7633 | 2.7582 | 0.8035 | 40860802176 |
| Wan-AI/Wan2.1-T2V-1.3B-Diffusers | 7.0213 | 10.5124 | 0.263 | 0.0612 | 2837993600 |
| Wan-AI/Wan2.1-T2V-14B-Diffusers | 26.9982 | 0.8495 | 1.9878 | 0.5714 | 28576983168 |
| black-forest-labs/FLUX.1-dev | 12.3575 | 0.6743 | 1.6569 | 0.4736 | 23802816640 |

Comment on lines +1333 to +1334
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
Member

Why should these have to be conditioned under if flashpack_file is not None?

Author

With the current early-return behavior, the FlashPack code path is intentionally isolated from the standard loading flow. After instantiating the model and loading weights via FlashPack, we finalize the model (registering _name_or_path and switching to eval mode) and return immediately.
As a result, the model deliberately does not go through the usual checkpoint-loading logic (load_state_dict, _load_pretrained_model), device-map dispatch, sharding handling, quantizer post-processing, or the common end-of-function finalization.

Member

Let me look into it and see if I can help with further simplifying.

Comment on lines +1336 to +1344
if output_loading_info:
    return model, {
        "missing_keys": [],
        "unexpected_keys": [],
        "mismatched_keys": [],
        "error_msgs": [],
    }

return model
Member

This as well.

Author

This block exists to preserve the from_pretrained(..., output_loading_info=True) API contract for the FlashPack loading path. Since FlashPack bypasses _load_pretrained_model, none of the usual key-matching or error-collection logic is executed, so there is no meaningful loading diagnostics to report. Returning an empty but well-formed loading_info dict keeps the return type consistent with the standard loading path without implying any missing or mismatched keys were checked.
