Skip to content

Commit

Permalink
Add CompVisVDenoiser wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Nov 23, 2022
1 parent 686dbad commit 4314f91
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions k_diffusion/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,42 @@ def __init__(self, model, quantize=False, device='cpu'):

def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)


class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""

def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.

def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in

def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)

def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)

def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""

def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)

def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)

0 comments on commit 4314f91

Please sign in to comment.