
[Dynamo] Add operator support to run UNet2DConditionModel from diffusers #151

Merged: 22 commits into hidet-org:main from dynamo_unet on Apr 7, 2023

Conversation

@xinli-git (Collaborator) commented Mar 28, 2023

Stable Diffusion uses UNet2DConditionModel for the diffusion process.

This is a popular model, and because of its size, the torch -> onnx -> hidet workflow is difficult and unnatural to work with.

This PR adds the operator support required by torch.compile for UNet2DConditionModel.

For example:

import torch
from diffusers import UNet2DConditionModel

device = 'cuda'
model_dtype = torch.float16

# load the fp16 UNet weights from Stable Diffusion v1.4
unet = (
  UNet2DConditionModel.from_pretrained(
      'CompVis/stable-diffusion-v1-4',
      subfolder="unet",
      revision="fp16",
  )
  .eval()
  .to(device)
)

hidet_model = torch.compile(unet, backend="hidet")

batch_size = 1
UNET_INPUTS_CHANNEL = 4
height = width = 512

tokenizer_max_len = 64
embedding_hidden_size = 768

# batch is doubled, as in classifier-free guidance; latents live in the
# 8x-downsampled latent space
latents = torch.randn(
    (batch_size * 2, UNET_INPUTS_CHANNEL, height // 8, width // 8),
    device=device,
    dtype=model_dtype,
)
t = torch.ones(1, dtype=torch.int64, device=device)  # timestep
text_embedding = torch.randn(
    batch_size * 2,
    tokenizer_max_len,
    embedding_hidden_size,
    dtype=model_dtype,
    device=device,
)
inputs = (latents, t, text_embedding)

hidet_model(*inputs)

@xinli-git changed the title from "[Dynamo] Add operator support to run UNet" to "[Dynamo] Add operator support to run UNet2DConditionModel from diffusers" on Mar 28, 2023
@yaoyaoding (Member) left a comment


Thanks @xinli-git!

Could you also add some tests for the added operators (under tests/frontends/torch)?

We could also consider adding some model-level tests, like those in tests/unit_tests/test_frontend_onnx.py but for the PyTorch frontend, and placing them in a test script like tests/frontends/torch/models/test_unet.py.
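
For instance, a minimal operator-level test might look like the following (a sketch under assumed conventions, not the PR's actual tests; it simply compares the hidet backend against eager PyTorch via torch.compile):

import torch

def test_baddbmm():
    # random inputs on GPU; shapes follow torch.baddbmm's contract:
    # input (b, n, p), batch1 (b, n, m), batch2 (b, m, p)
    input = torch.randn(2, 3, 5, device='cuda')
    batch1 = torch.randn(2, 3, 4, device='cuda')
    batch2 = torch.randn(2, 4, 5, device='cuda')

    def func(input, batch1, batch2):
        return torch.baddbmm(input, batch1, batch2, beta=1.0, alpha=1.0)

    compiled = torch.compile(func, backend='hidet')
    # allow a small tolerance for backend numerical differences
    torch.testing.assert_close(compiled(input, batch1, batch2),
                               func(input, batch1, batch2),
                               rtol=1e-3, atol=1e-3)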

On the baddbmm mapping:

def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.baddbmm(..., out=...)")
    return beta * input + alpha * ops.matmul(batch1, batch2)

Better to check whether alpha == 1 and beta == 1 and skip the multiplications whenever possible.

Otherwise, we would need to write graph-level pattern-rewrite rules to do this simplification.
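
One possible shape of that check (a sketch, not the merged code; assumes the same ops.matmul, Optional, and Tensor names as the hunk above): skip the scalar multiplies when alpha and beta take their default values, so no extra graph-level rewrite rules are needed.

def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.baddbmm(..., out=...)")
    prod = ops.matmul(batch1, batch2)
    if alpha != 1:
        prod = alpha * prod  # scale the batched product only when needed
    if beta == 0:
        return prod          # input is ignored entirely when beta == 0
    bias = input if beta == 1 else beta * input
    return bias + prod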

@xinli-git (Collaborator, Author)

Thanks Yaoyao! Will add tests shortly and let you know

@xinli-git (Collaborator, Author)

Hi @AndreSlavescu, can you also review this PR to see if everything makes sense? I will add the tests shortly, following your PR merge.

@AndreSlavescu (Contributor) commented Mar 30, 2023

Hey @xinli-git, looks good. I can review fully once test cases are added.

  • For the GroupNorm test, can you add it in /hidet/tests/operators/test_norm.py?
  • For the interpolation test, can you make a new file for vision functions called test_vision.py
    and add it under the same directory as above? (A sketch follows below.)
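
A rough sketch of such an interpolation test (assumed file tests/operators/test_vision.py; it compares the hidet backend against eager PyTorch):

import torch

def test_interpolate():
    x = torch.randn(1, 3, 32, 32, device='cuda')

    def func(x):
        # nearest-neighbour upsampling, one of the modes UNet relies on
        return torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')

    compiled = torch.compile(func, backend='hidet')
    torch.testing.assert_close(compiled(x), func(x))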

Once the PR is merged, please update #132 with the modules and operators supported.

@xinli-git (Collaborator, Author)

Hi @yaoyaoding, this PR is ready for a final review.

@yaoyaoding (Member)

Thanks @xinli-git!

Looks good to me. Merging this PR now.

If you want to track the performance of the Stable Diffusion model, you could add the model to our benchmark CLI
and add a line here. The performance will then be tracked in this issue.

@yaoyaoding merged commit 68faaa5 into hidet-org:main on Apr 7, 2023. 2 checks passed.
@xinli-git (Collaborator, Author)

Thanks! Will check that out.

@xinli-git deleted the dynamo_unet branch on April 21, 2023.