Cannot build stable diffusion model: "BackendCompilerFailed: backend='_capture' raised AssertionError" #24

@loicmagne

I tried building the Stable Diffusion model using either the walkthrough.ipynb notebook or the build.py file, but it fails when I run the "Combine every piece together" part:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
clip = clip_to_text_embeddings(pipe)
unet = unet_latents_to_noise_pred(pipe, torch_dev_key)
vae = vae_to_image(pipe)
concat_embeddings = concat_embeddings()
image_to_rgba = image_to_rgba()
schedulers = [
    dpm_solver_multistep_scheduler_steps(),
    trace.PNDMScheduler.scheduler_steps()
]

mod: tvm.IRModule = utils.merge_irmodules(
    clip,
    unet,
    vae,
    concat_embeddings,
    image_to_rgba,
    *schedulers,
)

Both result in the same error:

│ /usr/local/lib/python3.10/dist-packages/torch/__init__.py:1565 in __call__                       │
│                                                                                                  │
│   1562 │   │   │   │   self.dynamic == other.dynamic)                                            │
│   1563 │                                                                                         │
│   1564 │   def __call__(self, model_, inputs_):                                                  │
│ ❱ 1565 │   │   return self.compiler_fn(model_, inputs_, **self.kwargs)                           │
│   1566                                                                                           │
│   1567                                                                                           │
│   1568 def compile(model: Optional[Callable] = None, *,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/torch/dynamo.py:151 in _capture       │
│                                                                                                  │
│   148 │   def _capture(graph_module: fx.GraphModule, example_inputs):                            │
│   149 │   │   assert isinstance(graph_module, torch.fx.GraphModule)                              │
│   150 │   │   input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inp   │
│ ❱ 151 │   │   mod_ = from_fx(                                                                    │
│   152 │   │   │   graph_module,                                                                  │
│   153 │   │   │   input_info,                                                                    │
│   154 │   │   │   keep_params_as_input=keep_params_as_input,                                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/torch/fx_translator.py:1387 in        │
│ from_fx                                                                                          │
│                                                                                                  │
│   1384 │   to print out the tabular representation of the PyTorch module, and then               │
│   1385 │   check the placeholder rows in the beginning of the tabular.                           │
│   1386 │   """                                                                                   │
│ ❱ 1387 │   return TorchFXImporter().from_fx(                                                     │
│   1388 │   │   model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_retur  │
│   1389 │   )                                                                                     │
│   1390                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/torch/fx_translator.py:1282 in        │
│ from_fx                                                                                          │
│                                                                                                  │
│   1279 │   │   │   │   │   │   self.env[node] = self.convert_map[node.target](node)              │
│   1280 │   │   │   │   │   else:                                                                 │
│   1281 │   │   │   │   │   │   raise ValueError(f"Unsupported op {node.op}")                     │
│ ❱ 1282 │   │   │   assert output is not None                                                     │
│   1283 │   │   │   self.block_builder.emit_func_output(output)                                   │
│   1284 │   │                                                                                     │
│   1285 │   │   mod = self.block_builder.get()                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: backend='_capture' raised:
AssertionError: 


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

The failure seems to come from the TorchDynamo path: the _capture backend calls TVM's from_fx, which fails on "assert output is not None".

Also, a somewhat unrelated error: I couldn't install the CUDA version of the mlc/tvm package:

!python3 -m pip install mlc-ai-nightly-cu116 -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
ERROR: Could not find a version that satisfies the requirement mlc-ai-nightly-cu116 (from versions: none)
ERROR: No matching distribution found for mlc-ai-nightly-cu116

Both errors can be reproduced by running the notebook on Google Colab.
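
Since both errors depend on the environment (Python version, platform, installed torch and TVM builds), it may help to attach a version summary to the report. The following is a minimal sketch using only the standard library. The helper names (installed_versions, environment_summary) are hypothetical, and the list of distribution names is an assumption that should be adjusted to whatever was actually installed.

```python
import platform
import sys
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


def environment_summary(names):
    """Collect interpreter, platform, and package versions for a bug report."""
    return {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "platform": platform.platform(),
        "machine": platform.machine(),
        "packages": installed_versions(names),
    }


if __name__ == "__main__":
    # Assumed distribution names; edit this list to match the actual install.
    packages = ["torch", "diffusers", "transformers", "mlc-ai-nightly"]
    summary = environment_summary(packages)
    print("python  :", summary["python"])
    print("platform:", summary["platform"])
    print("machine :", summary["machine"])
    for name, version in summary["packages"].items():
        print(f"{name}: {version if version is not None else 'not installed'}")
```

The Python version and platform lines are also relevant to the pip failure, since "from versions: none" generally means pip found no distribution under that name matching the current interpreter and platform at the given index or links page.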
