Incorrect annotation of shapes in Unet in lesson #7? #3
Comments
I think the accidental broadcast is happening here (the code is messy because I added my own annotations), but it's inside of
In general, you'd have a separate random `t` for each image (so it would be shape `[batch_size]`). But for the demo and during sampling it's the same `t` for the whole batch, so it's convenient to also accept a single value. An alternative would be to force the right shape (as you're doing with the assert) and tweak the sampling code to pass in `t` as a tensor of shape `[batch_size]` instead of `[1]`.
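A minimal sketch of why both conventions work. This uses NumPy, which follows the same broadcasting rules as PyTorch; the `time_embedding` function here is a hypothetical stand-in for the lesson's embedding, not the actual code:

```python
import numpy as np

def time_embedding(t, emb_dim=8):
    # Sinusoidal-style embedding: t broadcasts against the frequency vector,
    # so t can be shape [batch_size] or shape [1].
    freqs = np.exp(-np.arange(emb_dim // 2))           # [emb_dim // 2]
    angles = t[:, None] * freqs[None, :]               # [len(t), emb_dim // 2]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

batch_size = 4
per_image_t = np.array([3.0, 7.0, 1.0, 9.0])  # one timestep per image: [batch_size]
shared_t = np.array([5.0])                    # one timestep for the whole batch: [1]

print(time_embedding(per_image_t).shape)  # (4, 8)
print(time_embedding(shared_t).shape)     # (1, 8) -- broadcasts to (4, 8) later
```

The `[1]`-shaped embedding then broadcasts against the batch dimension wherever it is added to per-image activations, which is why the single-value call works without an explicit repeat.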
(I could be mistaken on this, will take a look at the code in more depth when I have a bit more time)
Got it! This makes sense. You allow the NN to accept a single value of `t`. No need to look through the code, I don't want to waste your time!
Hi! It's me again.
I'm creating an annotated version of the UNet in lesson #7 (diffusion models). I'm adding more comments + assertions for the shapes of all inputs/outputs/weights/intermediate steps.
While doing this, I noticed there might be a mistake in some of the comments?
Here's the code that runs the UNet on dummy data (from the lesson):
Inside the actual UNet, this is the forward pass:
It says that the shape of `t` is `[batch_size]`. But the shape of `t` is `[1]`, which is to be expected if we look at the code that is testing the UNet. Specifically, the assertion fails.
I'm not sure exactly what's going on here. My hypothesis is as follows: the UNet is being trained on a batch of images, and each image in the batch should be accompanied by its own time-step number. However, it looks like only a single time-step is being passed into the UNet.
Somewhere along the line, this time-step is being accidentally broadcast by PyTorch to fit the batch dimension and used as the time-step for all images.
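The hypothesized silent broadcast can be reproduced in isolation. This is a toy sketch (NumPy, same broadcasting rules as PyTorch), not the lesson's code: adding a single time-step embedding to a batch of per-image features quietly replicates it across the batch dimension:

```python
import numpy as np

batch_size, channels = 4, 8
features = np.zeros((batch_size, channels))  # per-image activations
t_emb = np.ones((1, channels))               # embedding of ONE shared timestep

out = features + t_emb  # [1, channels] silently broadcasts to [batch_size, channels]
print(out.shape)        # (4, 8): every image received the same timestep embedding
```

No error is raised, which is why the mismatch only shows up if you assert on the shape of `t` directly.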
Does that sound correct to you?