
Conditioning on y_cond #29

Closed
PouriaRouzrokh opened this issue Jul 21, 2022 · 5 comments


@PouriaRouzrokh

PouriaRouzrokh commented Jul 21, 2022

Hi, thanks for the awesome code :)

One question for the inpainting task:

Looking at the following snippet from your code in networks.py, I cannot understand why you condition the model on y_cond when y_noisy is already being modified with the y_0 image via the expression y_noisy*mask + (1.-mask)*y_0.

Shouldn't concatenating y_cond be redundant in this case? The model already sees the ground-truth parts of the image in the modified version of y_noisy.

    def forward(self, y_0, y_cond=None, mask=None, noise=None):
        # Sample a noise level: pick a random timestep t per example and draw
        # gamma uniformly between gamma_{t-1} and gamma_t.
        b, *_ = y_0.shape
        t = torch.randint(1, self.num_timesteps, (b,), device=y_0.device).long()
        gamma_t1 = extract(self.gammas, t-1, x_shape=(1, 1))
        sqrt_gamma_t2 = extract(self.gammas, t, x_shape=(1, 1))
        sample_gammas = (sqrt_gamma_t2-gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1
        sample_gammas = sample_gammas.view(b, -1)

        # Diffuse the ground truth y_0 to the sampled noise level.
        noise = default(noise, lambda: torch.randn_like(y_0))
        y_noisy = self.q_sample(
            y_0=y_0, sample_gammas=sample_gammas.view(-1, 1, 1, 1), noise=noise)

        if mask is not None:
            # Inpainting: keep ground-truth pixels where mask == 0, noise where
            # mask == 1, concatenate y_cond on the channel axis, and restrict
            # the loss to the masked region.
            noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy*mask+(1.-mask)*y_0], dim=1), sample_gammas)
            loss = self.loss_fn(mask*noise, mask*noise_hat)
        else:
            noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy], dim=1), sample_gammas)
            loss = self.loss_fn(noise, noise_hat)
        return loss
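
For concreteness, here is a toy example of what the denoiser receives at this step (shapes and tensor values are made up; I am assuming mask == 1 marks the region to inpaint):

    import torch

    b, c, h, w = 2, 3, 64, 64
    y_0 = torch.randn(b, c, h, w)      # ground-truth image
    y_cond = torch.randn(b, c, h, w)   # conditioning image
    y_noisy = torch.randn(b, c, h, w)  # diffused image at the sampled noise level
    mask = torch.zeros(b, 1, h, w)
    mask[..., 16:48, 16:48] = 1.       # 1 = region to inpaint

    # Known pixels come from y_0; only the masked region stays noisy.
    composed = y_noisy * mask + (1. - mask) * y_0

    # Concatenating y_cond doubles the input channels of the denoiser.
    net_input = torch.cat([y_cond, composed], dim=1)
    print(net_input.shape)  # torch.Size([2, 6, 64, 64])
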
@Janspiry
Owner

Hi, thanks for this great question. I think there are two potential considerations here:

  1. Keeping consistency between training and inference across all tasks: in the inference stage, the model samples from random noise plus y_cond (see the sketch below).
  2. y_cond helps the network distinguish masked from unmasked areas, since y_t alone may not make this obvious when t is small.
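
To illustrate point 1, here is a rough sketch contrasting the denoiser input at the two stages (the helper names are made up, not the repository's API; mask == 1 marks the region to inpaint):

    import torch

    def denoiser_input_train(y_cond, y_noisy, y_0, mask):
        # Training: ground-truth pixels fill the unmasked region.
        return torch.cat([y_cond, y_noisy * mask + (1. - mask) * y_0], dim=1)

    def denoiser_input_infer(y_cond, y_t):
        # Inference: y_0 is unknown, so y_t starts as pure noise and y_cond
        # carries the known content; the channel layout matches training.
        return torch.cat([y_cond, y_t], dim=1)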

@PouriaRouzrokh
Author

Thanks for the kind reply. This makes sense, though the second point is worth testing empirically. I will post here if I find anything different.

@Janspiry
Owner

Feel free to reopen the issue if you have any further questions.

@vinodrajendran001

@PouriaRouzrokh I have opened a separate issue on this, but I am in urgent need of a solution, so I just wanted to check with you.

In my inpainting case, only the y_cond and mask images are available during inference. How should I run inference in that situation?

In the network.py script, the line below is executed as part of the restoration function for the inpainting task. Since y_0 is None in my case, I am not sure how to handle this line. If I skip it, the results are very bad (only a whitish sort of image is generated). Also, in the Process.png image I can see that the noise level increases at each step rather than decreasing.

if mask is not None:
    y_t = y_0*(1.-mask) + mask*y_t
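
For reference, here is my understanding of where this line sits in the reverse sampling loop (a toy sketch with placeholder values; p_sample stands in for one denoising step and is not the repository's exact API):

    import torch

    # Placeholder stand-ins for what the repository provides:
    num_timesteps = 10
    y_cond = torch.randn(1, 3, 64, 64)
    y_0 = torch.randn(1, 3, 64, 64)     # known image; None in my setting
    mask = torch.zeros(1, 1, 64, 64)
    mask[..., 16:48, 16:48] = 1.
    p_sample = lambda y, t, y_cond: y   # one denoising step (placeholder)

    y_t = torch.randn_like(y_cond)      # start from pure noise
    for t in reversed(range(num_timesteps)):
        y_t = p_sample(y_t, t, y_cond=y_cond)
        # Re-impose the known pixels after every step, so only the masked
        # region is actually generated; skipping this lets the whole image
        # drift, which matches the bad results I am seeing.
        y_t = y_0 * (1. - mask) + mask * y_t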

Any idea on how to proceed?

@yc-cui

yc-cui commented Jan 26, 2023

@Janspiry Why not just set the mask as y_cond, for consistency across all tasks?
