
Running Speed is Slower for SDXL Model #16

Closed
alecyan1993 opened this issue Nov 6, 2023 · 11 comments
Labels
bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@alecyan1993

alecyan1993 commented Nov 6, 2023

Hi, another issue I found is that it's not accelerating SDXL. I'm running the demo on an A100: the compiled SDXL model runs at 5.3 it/s, while plain diffusers runs at 8.8 it/s. The model compiled with stable-fast is slower.

@chengzeyi
Owner

I have observed the same performance regression when testing on my PC.
Initially I thought it was caused by insufficient GPU VRAM.
But since the A100 has relatively large VRAM, it must be caused by other restrictions or bugs.

chengzeyi added the bug label on Nov 6, 2023
@chengzeyi
Owner

Hi, another issue I found is that it's not accelerating SDXL. I'm running the demo on an A100: the compiled SDXL model runs at 5.3 it/s, while plain diffusers runs at 8.8 it/s. The model compiled with stable-fast is slower.

I still think it's because of insufficient VRAM. Could you please share more info about your system and inference configuration? I want to know the peak VRAM utilization during inference and your image resolution.
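
One way to check this (a minimal sketch, where pipe stands for whichever pipeline object you are benchmarking) is to reset PyTorch's peak-memory counter, run one inference, and then print the peak allocation next to the device's free/total memory:

import torch

torch.cuda.reset_peak_memory_stats()
images = pipe(prompt='a beautiful girl', height=1024, width=1024,
              num_inference_steps=30).images  # your normal inference call here
peak = torch.cuda.max_memory_allocated()
free, total = torch.cuda.mem_get_info()
print(f'peak allocated during inference: {peak / 2**30:.2f} GiB')
print(f'free / total device memory:      {free / 2**30:.2f} / {total / 2**30:.2f} GiB')

If the peak allocation gets close to the device total, VRAM offloading becomes a plausible explanation.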

@chengzeyi
Owner

Do you happen to run SDXL on WSL, on Windows, or on another operating system that supports shared VRAM?

I think I have found the reason: on systems that support shared VRAM, the NVIDIA driver dispatches memory allocation requests to shared VRAM instead of throwing an OOM error when GPU VRAM is insufficient, whether because the model is too large, the resolution is too high, or a PyTorch leak prevents previously allocated memory from being released.

And shared VRAM is really just your computer's system memory. It is orders of magnitude slower than the dedicated VRAM on the board, so inference slows down even if only a few layers and intermediate buffers end up in shared VRAM.

@alecyan1993
Author

Hi,

Thanks so much for your reply. I'm running on Ubuntu 20.04 inside a Docker container, started with --gpus all. The GPU I'm using is an A100 40G, so there should be enough VRAM for the SDXL model.

@chengzeyi
Owner

Hi,

Thanks so much for your reply. I'm running on Ubuntu 20.04 inside a Docker container, started with --gpus all. The GPU I'm using is an A100 40G, so there should be enough VRAM for the SDXL model.

That's really weird. On my system I'm about 90% sure this problem is caused by the VRAM offloading mechanism of the NVIDIA driver on Windows. But I don't have a GPU with large VRAM like the A100 to test on, so it is hard to debug.

chengzeyi reopened this on Nov 9, 2023
chengzeyi added the help wanted label on Nov 9, 2023
@alecyan1993
Author

Do you have any debugging script so I can run some tests on my instance?

@chengzeyi
Owner

chengzeyi commented Nov 9, 2023

Do you have any debugging script so I can run some tests on my instance?

The following script should work.
A detailed performance analysis can be exported with Nsight Systems.

import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from sfast.compilers.stable_diffusion_pipeline_compiler import (compile,
                                                                CompilationConfig)

def load_model():
    # NOTE:
    # This loads the SDXL base model with StableDiffusionXLPipeline.
    # If the resolution is high (e.g. 1024x1024), ensure your VRAM is sufficient,
    # or the performance might regress.
    model = StableDiffusionXLPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)

    model.scheduler = EulerAncestralDiscreteScheduler.from_config(
        model.scheduler.config)
    model.safety_checker = None
    model.to(torch.device('cuda'))
    return model

model = load_model()

config = CompilationConfig.Default()

# xformers and Triton are suggested for achieving the best performance.
# Triton may take a while to generate, compile and tune its kernels.
try:
    import xformers
    config.enable_xformers = True
except ImportError:
    print('xformers not installed, skip')
# NOTE:
# When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
# Disable Triton if you encounter this problem.
try:
    import triton
    config.enable_triton = True
except ImportError:
    print('Triton not installed, skip')
# NOTE:
# CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead.
# My implementation can handle dynamic shapes, at the cost of extra GPU memory.
# But when your GPU VRAM is insufficient or the image resolution is high,
# CUDA Graph can cause less efficient VRAM utilization and slow down inference.
# If you run into problems related to it, disable it.
config.enable_cuda_graph = True

compiled_model = compile(model, config)

kwarg_inputs = dict(
    prompt='(masterpiece:1.2), best quality, masterpiece, best detail face, lineart, monochrome, a beautiful girl',
    # NOTE: If you use SDXL, use a higher resolution (e.g. 1024x1024) to improve the generation quality.
    height=1024,
    width=1024,
    num_inference_steps=30,
    num_images_per_prompt=1,
)

# NOTE: Warm it up.
# The first call will trigger compilation and might be very slow.
# After the first call, it should be very fast.
output_image = compiled_model(**kwarg_inputs).images[0]

# Let's see the second call!
output_image = compiled_model(**kwarg_inputs).images[0]
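
If Nsight Systems is installed, the script above can be profiled with something like the following (sdxl_debug is just an arbitrary report name, and sdxl_debug.py is whatever filename you save the script as):

nsys profile -o sdxl_debug python sdxl_debug.py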

@chengzeyi
Owner

Do you have any debugging script so I can run some tests on my instance?

Information about the current PyTorch environment can be collected as follows:

python -m torch.utils.collect_env

@SuperSecureHuman

In my case, it was close to 10 to 12 it/s (30 steps).

The A100 may already be fast enough that compilation doesn't show much improvement (?)

Stock SDXL - 3.87 sec
Manual torch.compile - 3.43 sec
stable-fast compiled SDXL - 3.3 sec

I'm not including iterations-per-second numbers because they vary too much: they start off really high and end up low.

But I do notice that the initial iterations run faster than stock.
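
(For reference, a "manual torch.compile" baseline for SDXL typically looks like the standard diffusers recipe below; this is a sketch, not necessarily the exact setup used here.)

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16
).to('cuda')
# Compile only the UNet, which dominates the per-step cost.
pipe.unet = torch.compile(pipe.unet, mode='reduce-overhead', fullgraph=True)
image = pipe(prompt='a beautiful girl', num_inference_steps=30).images[0]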

@chengzeyi
Owner

In my case, it was close to 10 to 12 it/s (30 steps).

The A100 may already be fast enough that compilation doesn't show much improvement (?)

Stock SDXL - 3.87 sec / Manual torch.compile - 3.43 sec / stable-fast compiled SDXL - 3.3 sec

I'm not including iterations-per-second numbers because they vary too much: they start off really high and end up low.

But I do notice that the initial iterations run faster than stock.

The printed table results could be incorrect due to some mysterious bug in Python's cProfile (maybe?), and the profiling itself can add relatively high CPU overhead. I don't know how to solve it. Maybe a plain time.time() would be a good replacement?
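
For example, a minimal sketch of that approach, reusing compiled_model and kwarg_inputs from the script above and synchronizing around the call so the measurement isn't cut short by asynchronous CUDA execution:

import time
import torch

compiled_model(**kwarg_inputs)  # warm-up: the first call triggers compilation
torch.cuda.synchronize()
start = time.time()
output_image = compiled_model(**kwarg_inputs).images[0]
torch.cuda.synchronize()
print(f'end-to-end: {time.time() - start:.2f} s')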

@chengzeyi
Owner

chengzeyi commented Nov 22, 2023

The A100 80GB can reach very impressive speeds. About six months ago, I could achieve a generation speed of 61.8 it/s on an A100. However, to achieve that I needed a modified scheduler to reduce CPU overhead, which conflicts with my wish that users can use any scheduler they want. So I had to sacrifice it, and we also disable some further optimizations now so that users can switch LoRA dynamically.

Triton autotuning is another important technique for making kernels run faster, but since so many people want compilation itself to be fast, it has also been replaced by heuristics. 😂
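
(For context, and purely as an illustration rather than stable-fast's actual kernels, Triton autotuning looks like this: the decorator benchmarks each listed config on the first launch for a given key, which is exactly where the extra compilation time comes from.)

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 512}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],  # re-tune when the problem size changes
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance scales one BLOCK_SIZE-wide chunk of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)

x = torch.randn(1 << 20, device='cuda')
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
scale_kernel[grid](x, out, 2.0, x.numel())  # first call triggers autotuning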
