Skip to content

Commit

Permalink
add self conditioning for elucidated ddpm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 10, 2022
1 parent f0d59ac commit beb2f2d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
31 changes: 25 additions & 6 deletions denoising_diffusion_pytorch/elucidated_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from math import sqrt
from random import random
import torch
from torch import nn, einsum
import torch.nn.functional as F
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(
):
super().__init__()
assert net.learned_sinusoidal_cond
assert not net.self_condition, 'not supported yet'
self.self_condition = net.self_condition

self.net = net

Expand Down Expand Up @@ -100,7 +101,7 @@ def c_noise(self, sigma):
# preconditioned network output
# equation (7) in the paper

def preconditioned_network_forward(self, noised_images, sigma, clamp = False):
def preconditioned_network_forward(self, noised_images, sigma, self_cond = None, clamp = False):
batch, device = noised_images.shape[0], noised_images.device

if isinstance(sigma, float):
Expand All @@ -110,7 +111,8 @@ def preconditioned_network_forward(self, noised_images, sigma, clamp = False):

net_out = self.net(
self.c_in(padded_sigma) * noised_images,
self.c_noise(sigma)
self.c_noise(sigma),
self_cond
)

out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out
Expand Down Expand Up @@ -161,6 +163,10 @@ def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):

images = init_sigma * torch.randn(shape, device = self.device)

# for self conditioning

x_start = None

# gradually denoise

for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
Expand All @@ -171,19 +177,24 @@ def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
sigma_hat = sigma + gamma * sigma
images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps

model_output = self.preconditioned_network_forward(images_hat, sigma_hat, clamp = clamp)
self_cond = x_start if self.self_condition else None

model_output = self.preconditioned_network_forward(images_hat, sigma_hat, self_cond, clamp = clamp)
denoised_over_sigma = (images_hat - model_output) / sigma_hat

images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

# second order correction, if not the last timestep

if sigma_next != 0:
model_output_next = self.preconditioned_network_forward(images_next, sigma_next, clamp = clamp)
self_cond = model_output if self.self_condition else None

model_output_next = self.preconditioned_network_forward(images_next, sigma_next, self_cond, clamp = clamp)
denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

images = images_next
x_start = model_output

images = images.clamp(-1., 1.)
return unnormalize_to_zero_to_one(images)
Expand Down Expand Up @@ -211,7 +222,15 @@ def forward(self, images):

noised_images = images + padded_sigmas * noise # alphas are 1. in the paper

denoised = self.preconditioned_network_forward(noised_images, sigmas)
self_cond = None

if self.self_condition and random() < 0.5:
# from hinton's group's bit diffusion paper
with torch.no_grad():
self_cond = self.preconditioned_network_forward(noised_images, sigmas)
self_cond.detach_()

denoised = self.preconditioned_network_forward(noised_images, sigmas, self_cond)

losses = F.mse_loss(denoised, images, reduction = 'none')
losses = reduce(losses, 'b ... -> b', 'mean')
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.27.1',
version = '0.27.2',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit beb2f2d

Please sign in to comment.