Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: Fix p_sample input args and resample loop of Repaint method #320

Merged
62 changes: 43 additions & 19 deletions denoising_diffusion_pytorch/repaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,17 @@ def p_sample(self, x, t: int, x_self_cond = None, gt=None, mask=None):
return pred_img, x_start

@torch.inference_mode()
def p_sample_loop(
    self,
    shape,
    return_all_timesteps=False,
    gt=None,
    mask=None,
    resample=True,
    resample_iter=10,
    resample_jump=3,
    resample_every=50,
):
    """Ancestral sampling loop with optional RePaint-style resampling.

    Runs the reverse diffusion chain from pure Gaussian noise down to
    t = 0, passing the known region (``gt``, ``mask``) through to
    ``p_sample`` at every step.  While ``resample`` is enabled, each
    time the sweep reaches a timestep that is a positive multiple of
    ``resample_every`` it performs ``resample_iter`` jump-back /
    re-denoise repetitions (line 9 of Algorithm 1 in
    https://arxiv.org/pdf/2201.09865).

    Args:
        shape: output tensor shape ``(batch, channels, h, w)``.
        return_all_timesteps: if True, return every intermediate image
            stacked along dim 1 instead of only the final image.
        gt: ground-truth image providing the known pixels; forwarded to
            ``p_sample``.
        mask: mask selecting the known region; forwarded to ``p_sample``.
        resample: enable the RePaint resampling schedule.
        resample_iter: number of jump-back/re-denoise repetitions per
            trigger.
        resample_jump: timestep used for the jump-back noise injection
            and the number of re-denoise steps that follow it.
        resample_every: trigger resampling whenever the sweep timestep
            is a multiple of this value (and > 0).

    Returns:
        The unnormalized sample tensor, or the stack of all intermediate
        images when ``return_all_timesteps`` is True.
    """
    batch, device = shape[0], self.device

    img = torch.randn(shape, device=device)
    imgs = [img]

    x_start = None

    for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
        self_cond = x_start if self.self_condition else None
        # Keyword arguments throughout so gt/mask can never be mis-bound
        # to the x_self_cond positional slot (the bug this PR fixes).
        img, x_start = self.p_sample(x=img, t=t, x_self_cond=self_cond, gt=gt, mask=mask)
        imgs.append(img)

        # Resampling loop: line 9 of Algorithm 1 in https://arxiv.org/pdf/2201.09865
        if resample is True and (t > 0) and (t % resample_every == 0):
            # Jump back for resample_jump timesteps, resample_iter times.
            for _ in tqdm(range(resample_iter), desc='resample loop', total=resample_iter):
                # Local name: do not clobber the sweep variable t, and do
                # not shadow the builtin iter as the old code did.
                t_jump = resample_jump
                beta = self.betas[t_jump]
                # One forward (noising) step: x <- sqrt(1-beta)*x + sqrt(beta)*eps
                img = torch.sqrt(1 - beta) * img + torch.sqrt(beta) * torch.randn_like(img)
                for j in reversed(range(0, resample_jump)):
                    # NOTE(review): j is unused — every re-denoise step runs at
                    # the fixed timestep t_jump. Confirm against Algorithm 1
                    # whether the timestep should descend (t=j) instead.
                    img, x_start = self.p_sample(x=img, t=t_jump, gt=gt, mask=mask)
                    imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim=1)

    ret = self.unnormalize(ret)
    return ret

Expand Down Expand Up @@ -758,11 +764,29 @@ def ddim_sample(self, shape, return_all_timesteps = False):
return ret

@torch.inference_mode()
def sample(
    self,
    batch_size=16,
    return_all_timesteps=False,
    gt=None,
    mask=None,
    resample=True,
    resample_iter=10,
    resample_jump=10,
    resample_every=50,
):
    """Draw inpainted samples via the RePaint ancestral sampling loop.

    Thin entry point that builds the output shape from the model's
    configured image size/channels and delegates to ``p_sample_loop``.

    Args:
        batch_size: number of samples to draw; overridden by
            ``mask.shape[0]`` when a mask is supplied, so each mask gets
            exactly one sample.
        return_all_timesteps: forwarded to ``p_sample_loop``.
        gt: ground-truth image with the known pixels; forwarded.
        mask: mask selecting the known region; forwarded.
        resample: enable RePaint resampling; forwarded.
        resample_iter: resampling repetitions per trigger; forwarded.
        resample_jump: jump-back timestep count; forwarded.
        resample_every: resampling trigger period; forwarded.

    Returns:
        Whatever ``p_sample_loop`` returns (the unnormalized samples).
    """
    (h, w), channels = self.image_size, self.channels
    # With a mask, the number of masks dictates the batch size.
    batch_size = mask.shape[0] if mask is not None else batch_size
    return self.p_sample_loop(
        shape=(batch_size, channels, h, w),
        return_all_timesteps=return_all_timesteps,
        gt=gt,
        mask=mask,
        resample=resample,
        resample_iter=resample_iter,
        resample_jump=resample_jump,
        resample_every=resample_every,
    )

@torch.inference_mode()
def interpolate(self, x1, x2, t = None, lam = 0.5):
Expand Down