
Incorrect annotation of shapes in Unet in lesson #7? #3

Closed
vedantroy opened this issue Jul 17, 2022 · 4 comments
vedantroy commented Jul 17, 2022

Hi! It's me again.

I'm creating an annotated version of the UNet in lesson #7 (diffusion models). I'm adding more comments + assertions for the shapes of all inputs/outputs/weights/intermediate steps.

While doing this, I noticed there might be a mistake in some of the comments?

Here's the code that runs the UNet on dummy data (from the lesson):

import torch

# A dummy batch of 10 3-channel 32px images
x = torch.randn(10, 3, 32, 32)

# 't' - what timestep are we on
t = torch.tensor([50], dtype=torch.long)

# Define the unet model
unet = UNet()

# The forward pass (takes both x and t)
model_output = unet(x, t)

Inside the actual UNet, this is the forward pass:

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """

        # Get time-step embeddings
        t = self.time_emb(t)

The docstring says that the shape of t is [batch_size]. But the actual shape of t is [1], which is what we'd expect given the code above that tests the UNet.

Specifically, the assertion:

        batch_size = x.shape[0]
        print(t.shape)
        assert t.shape[0] == batch_size

fails.
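For concreteness, here's a minimal reproduction of the mismatch, using the same dummy inputs as above:

import torch

# Same dummy inputs as in the lesson snippet above
x = torch.randn(10, 3, 32, 32)
t = torch.tensor([50], dtype=torch.long)

print(x.shape[0])  # 10
print(t.shape)     # torch.Size([1]), so t.shape[0] == batch_size fails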

I'm not sure exactly what's going on here. My hypothesis is as follows: the UNet is being trained on a batch of images, and each image in the batch should be accompanied by its own timestep. However, it looks like only a single timestep is being passed into the UNet.

Somewhere along the line, this timestep is accidentally broadcast by PyTorch to fit the batch dimension and used as the timestep for every image.

Does that sound correct to you?

vedantroy (Author) commented:

I think the accidental broadcast is happening here (the code is messy because I added my own annotations, but it's inside ResidualBlock):

        batch_size = 1  # hard-coded to 1 because t arrives with shape [1], not [10]

        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))

        time_emb = self.time_emb(t)
        assert t.shape == (batch_size, self.time_channels)
        assert time_emb.shape == (batch_size, self.out_channels)
        time_emb = time_emb[:, :, None, None]
        assert time_emb.shape == (batch_size, self.out_channels, 1, 1)
        # This looks like:
        # [ [[a]], [[b]], [[c]], [[d]], [[e]], [[f]] ]
        # when self.out_channels = 6
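
For illustration, here is a small standalone sketch (shapes chosen to match the dummy batch above, with 6 output channels as a stand-in) of the implicit broadcast that happens when the embedding is added back to the feature map:

import torch

batch_size, out_channels, height, width = 10, 6, 32, 32

# Feature map coming out of conv1: one entry per image in the batch
h = torch.randn(batch_size, out_channels, height, width)

# Time embedding computed from a single timestep: leading dim is 1, not 10
time_emb = torch.randn(1, out_channels)[:, :, None, None]

# PyTorch broadcasts the leading 1 across the batch dimension,
# so every image silently receives the same timestep embedding
out = h + time_emb
print(out.shape)  # torch.Size([10, 6, 32, 32])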

johnowhitaker (Owner) commented:

In general, you'd have a separate random t for each image (so it would be shape [batch_size]). But for the demo and during sampling it's the same t for the whole batch, so it's convenient to also accept a single value. An alternative would be to force the right shape (as you're doing with the assert) and tweak the sampling code to pass in t as a tensor of shape [batch_size] instead of [1].
During training, t has the shape described:
t = torch.randint(0, n_steps, (batch_size,), dtype=torch.long).cuda()
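
For comparison, here's a small sketch of the two calling conventions (variable names are placeholders, not the lesson's exact sampling code):

import torch

n_steps, batch_size = 1000, 10

# Training: a separate random timestep for each image, shape [batch_size]
t_train = torch.randint(0, n_steps, (batch_size,), dtype=torch.long)

# Sampling: the same timestep for the whole batch, made explicit by
# repeating it to shape [batch_size] rather than passing shape [1]
step = 50
t_sample = torch.full((batch_size,), step, dtype=torch.long)

assert t_train.shape == (batch_size,)
assert t_sample.shape == (batch_size,)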

johnowhitaker (Owner) commented:

(I could be mistaken on this, will take a look at the code in more depth when I have a bit more time)


vedantroy commented Jul 18, 2022

Got it! This makes sense. You allow the NN to accept a single value of t, under the assumption that all images are at the same timestep, which makes it more convenient to use. Sounds good!
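
One way to make that convenience explicit (a sketch with a hypothetical helper, not the lesson's code) would be to expand a single timestep to the batch size before computing the time embedding:

import torch

def expand_timestep(t: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Accept either a per-image t of shape [batch_size] or a single
    # timestep of shape [1], and always return shape [batch_size].
    if t.shape[0] == 1 and batch_size != 1:
        t = t.expand(batch_size)
    assert t.shape == (batch_size,)
    return t

x = torch.randn(10, 3, 32, 32)
t = torch.tensor([50], dtype=torch.long)
print(expand_timestep(t, x.shape[0]).shape)  # torch.Size([10])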

No need to look through the code, don't want to waste your time!
