
Error when using torch.compile on StableVideoDiffusionPipeline and passing a different image the second time #10317

@ZHJ19970917


Describe the bug

I built a Gradio page that generates videos with StableVideoDiffusionPipeline and wrapped the UNet with torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True) for acceleration. After inference, GPU memory usage rises from 4.8 GB to 16 GB, and sometimes to as much as 35 GB. The first uploaded image runs normally, but uploading a different image a second time raises the following error:

Traceback (most recent call last):
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/gradio/queueing.py", line 622, in process_events
    response = await route_utils.call_process_api(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/gradio/blocks.py", line 2014, in process_api
    result = await self.call_function(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/gradio/blocks.py", line 1567, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/gradio/utils.py", line 846, in wrapper
    response = f(*args, **kwargs)
  File "/u01/SdProject/app/gradio_ui/gen_video.py", line 38, in generate_video_from_image
    frames = stable_video(image, pipeline, generator)
  File "/u01/SdProject/app/service/sd_inference.py", line 226, in stable_video
    frames = pipeline(image, decode_chunk_size=8, generator=generator).frames[0]
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py", line 576, in __call__
    noise_pred = self.unet(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/diffusers/models/unets/unet_spatio_temporal_condition.py", line 357, in forward
    def forward(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
    return compiled_fn(full_args)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 321, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
    outs = compiled_fn(args)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
    return compiled_fn(runtime_args)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
    return self.current_callable(inputs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1008, in run
    return compiled_fn(new_inputs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 398, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 420, in cudagraphify
    manager = get_container(device_index).get_tree_manager()
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 341, in get_container
    container_dict = get_obj(local, "tree_manager_containers")
  File "/home/self/anaconda3/envs/Diffusers/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 336, in get_obj
    assert torch._C._is_key_in_tls(attr_name)
AssertionError
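The last frames of the traceback are in `torch/_inductor/cudagraph_trees.py`, where `get_obj` asserts that `tree_manager_containers` is present in thread-local storage. The earlier `anyio.to_thread.run_sync` frames show that Gradio runs the synchronous handler on a worker thread. One plausible reading, which is an assumption and not confirmed in this issue, is that the CUDA-graph state created while serving the first request is not visible from the thread that serves the second one. The snippet below only illustrates the underlying Python behavior: attributes stored on a `threading.local()` object exist solely in the thread that set them. `init_state`, `read_state`, and `first_then_second` are hypothetical names used for illustration, not PyTorch internals.

```python
import threading

# Hypothetical stand-in for per-thread state such as a CUDA-graph tree manager.
local = threading.local()


def init_state() -> None:
    # Attributes set here exist only for the calling thread.
    local.tree_manager_containers = {}


def read_state() -> bool:
    # True only if *this* thread previously called init_state().
    return hasattr(local, "tree_manager_containers")


def run_in_new_thread(fn):
    """Run fn on a fresh thread and return its result."""
    result = {}

    def target():
        result["value"] = fn()

    t = threading.Thread(target=target)
    t.start()
    t.join()
    return result["value"]


def first_then_second() -> tuple[bool, bool]:
    # "First request": one worker thread initializes and then reads its own state.
    first = run_in_new_thread(lambda: (init_state(), read_state())[1])
    # "Second request": a different worker thread never called init_state().
    second = run_in_new_thread(read_state)
    return first, second


if __name__ == "__main__":
    print(first_then_second())  # (True, False)
```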

Reproduction

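import os
import random
import time

import gradio as gr
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video

# `settings` is the project's own config module (VIDEO_DIR, STABLE_VIDEO_DIFFUSION_DIR);
# its import path is not shown in the issue.


def generate_video_path() -> str: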
    timestamp = int(time.time())

    video_dir = settings.VIDEO_DIR
    if not os.path.exists(video_dir):
        os.makedirs(video_dir)

    video_path = os.path.join(video_dir, f"generated_video_{timestamp}.mp4")

    return video_path

def generate_random_seed():
    """Generate a random seed."""
    return random.randint(0, 2 ** 32 - 1)

def stable_video(image, pipeline, generator):

    image = image.resize((1024, 576))
    frames = pipeline(image, decode_chunk_size=8, generator=generator).frames[0]
    return frames


def generate_video_from_image(image, fps, seed):
    # pipeline = model_loader.pipeline
    model_path = settings.STABLE_VIDEO_DIFFUSION_DIR
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        model_path, torch_dtype=torch.float16, variant="fp16"
    ).to('cuda')

    pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
    generator = torch.Generator(device=pipeline.device)
    # gr.Number may deliver the seed as a float; manual_seed() expects an int.
    generator.manual_seed(int(seed))
    frames = stable_video(image, pipeline, generator)
    video_path = generate_video_path()
    video_path = export_to_video(frames, video_path, fps)
    return video_path


def make_video():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Accordion("stable-video", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            seed_input = gr.Number(value=generate_random_seed(), show_label=False)
                        generate_seed_button = gr.Button("🎲 Generate Random Seed")
                        generate_seed_button.click(fn=generate_random_seed, outputs=seed_input)
                        with gr.Row():
                            fps = gr.Slider(label="Frames per second", minimum=7, maximum=25, step=1, value=10)
                        steps_slider = gr.Slider(label="Inference Steps", minimum=0, maximum=100, value=28,
                                                 info="The number of steps the denoiser removes image noise")

                    with gr.Column():
                        video_output = gr.Video(label='Video Output', width='600px', height='200px')  # Video output
                        img_input = gr.Image(type="pil", width='600px', height='200px')  # Image input

                generate_button = gr.Button("Generate Video")
                generate_button.click(generate_video_from_image,
                                      inputs=[img_input, fps, seed_input],
                                      outputs=video_output)
    # Return the Blocks object so the caller can call .launch() on it.
    return demo

# Launch the interface
make_video().launch()

Logs

No response

System Info

GPU: A5880, 48 GB
Python 3.10
diffusers 0.32
torch 2.5.1

Who can help?

No response


Labels: bug, stale
