In [1]:
from IPython.display import Image

In [2]:
# Show the remotely hosted figure (referenced by URL) in the cell output, with width=600.
Image(url='https://i.imgur.com/S7KH5hZ.png', width=600)

## Training

$$
\nabla_\theta\|\underbrace{\epsilon}_{\text{target noise}} - \underbrace{\epsilon_\theta(\underbrace{\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon}_{\text{noisy image: }x_t}, t)}_{\text{predict noise}}\|^2
$$

- $\epsilon$: target noise
- $\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon$: noisy image
- $\epsilon_\theta(\cdot, t)$: neural network becomes a **noise predictor**,
    - $\epsilon_\theta(\cdot, t)$：neural network
        - unet
- $\epsilon_\theta(x_t,t)$：预测的是添加进 $x_t$ 上的 noise（added noise）；
    - $t$ 刻画着 noise 的强度；
    - Diffusion forward的过程就是构造noise predictor训练集的过程
        - input：$x_{t}, t$
        - output: $\epsilon$（由 $x_0$ 得到 $x_t$ 时所加入的噪声，即 loss 中的 target noise，而不是 $x_{t} - x_{t-1}$）
- 这个 loss 是真实噪声（target noise）与预测噪声（predicted noise）之差的平方；

## Sampling

$$
x_{t-1}=\underbrace{\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha_t}}}\epsilon_\theta(x_t,t)\right)}_{\mu_\theta(x_t,t)}+\underbrace{\sigma_tz}_{\text{reparameterize}}
$$

- $\epsilon_\theta(x_t, t)$ unet denoising model
    - input: $(x_t, t)$
- $\sigma_tz$：重参数化；    
- 右侧部分（不加重参数化的 $\sigma_tz$），可以视为 $\mu_\theta(x_t,t)$（mean predictor，概率分布均值的估计）

- 抛开这个 noise （$\sigma_tz$）不谈，
    - $\epsilon_\theta(x_t,t)$：预测的是添加进 $x_t$ 上的 noise（added noise）；

$$
\begin{split}
&x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t,t)\right)\\
&\sqrt{\alpha_t}x_{t-1}+\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t,t)=x_t
\end{split}
$$

## coding

### training the unet (noise pred) model 

```
# Train the UNet as a noise predictor: every batch of clean images is noised at
# randomly drawn timesteps, and the network output is regressed (MSE) onto the
# noise that was added.
# Uses `train_dataloader`, `device`, `noise_scheduler`, `unet`, `optimizer` and
# `losses`, which are expected to be defined before this cell.
for epoch in range(30):
    for step, batch in enumerate(train_dataloader):

        # x0
        clean_images = batch["images"].to(device)
        bs = clean_images.shape[0]

        # epsilon
        # Sample the noise directly on the images' device (no CPU -> device copy)
        noise = torch.randn_like(clean_images)

        # t
        # Sample a random timestep for each image
        # NOTE(review): newer diffusers releases may expect
        # `noise_scheduler.config.num_train_timesteps` -- confirm against the installed version
        timesteps = torch.randint(
            0,
            noise_scheduler.num_train_timesteps,
            (bs,),
            device=clean_images.device,
            dtype=torch.long,
        )

        # scheduler(x0s, epsilons, ts) => xts
        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # unet(xts, ts) => epsilons
        # Get the model prediction
        noise_pred = unet(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        # A scalar loss needs no explicit upstream gradient. Passing `loss` itself
        # (as the original did) multiplies every gradient by the loss value.
        loss.backward()
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()
```

### sampling 

#### `pipe.__call__()`

```
# Excerpt of the pipeline's __call__ body: the sampling loop that starts from a
# random tensor and repeatedly applies the scheduler to obtain the final image.
# Starting point of sampling (presumably Gaussian noise x_T -- inferred from the helper's name)
image = randn_tensor(image_shape, generator=generator, device=self.device)

# set step values
self.scheduler.set_timesteps(num_inference_steps)

# Iterate over the timesteps prepared by the scheduler above
for t in self.progress_bar(self.scheduler.timesteps):
    # 1. predict noise model_output
    model_output = self.unet(image, t).sample

    # 2. compute previous image: x_t -> x_t-1
    image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

# Affine map x/2 + 0.5 followed by clamping to [0, 1] (i.e. [-1, 1] -> [0, 1])
image = (image / 2 + 0.5).clamp(0, 1)
# Move to CPU, reorder the axes and convert to NumPy
# (presumably channels-first -> channels-last; TODO confirm the tensor layout)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
    image = self.numpy_to_pil(image)
```

#### scheduler.step

```
# Excerpt of the scheduler's step() body: one reverse update that combines the
# model output with the current sample (x_t) to produce pred_prev_sample, the
# predicted previous sample, with noise added for every step except t == 0.
t = timestep

prev_t = self.previous_timestep(t)

# If the model output has twice the sample's size along dim 1 and the variance type
# is learned, the output is split along dim 1 into the prediction and a predicted variance.
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
    model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
    predicted_variance = None

# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
# When there is no earlier timestep (prev_t < 0), self.one is used instead of a
# table lookup (presumably the constant 1 -- confirm in the scheduler's __init__)
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# Per-step alpha recovered as the ratio of the two cumulative products
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
current_beta_t = 1 - current_alpha_t

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.prediction_type == "epsilon":
    # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
    # The model output is used directly as the x_0 estimate
    pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
    pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
else:
    raise ValueError(
        f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
        " `v_prediction`  for the DDPMScheduler."
    )

# 3. Clip or threshold "predicted x_0"
# Thresholding takes precedence; clipping is only applied when thresholding is off
if self.config.thresholding:
    pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
    pred_original_sample = pred_original_sample.clamp(
        -self.config.clip_sample_range, self.config.clip_sample_range
    )

# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

# 6. Add noise
# `variance` holds the noise term added to the mean; it stays 0 when t == 0,
# so the last step returns the mean without added noise.
variance = 0
if t > 0:
    device = model_output.device
    variance_noise = randn_tensor(
        model_output.shape, generator=generator, device=device, dtype=model_output.dtype
    )
    if self.variance_type == "fixed_small_log":
        # The value from _get_variance is used as the scale directly (no square root)
        variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
    elif self.variance_type == "learned_range":
        # exp(0.5 * v) is applied, i.e. the value from _get_variance is treated as a log-variance
        variance = self._get_variance(t, predicted_variance=predicted_variance)
        variance = torch.exp(0.5 * variance) * variance_noise
    else:
        # Square root turns the variance into a standard deviation before scaling the noise
        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

pred_prev_sample = pred_prev_sample + variance

```

$$
\begin{split}
&x_0\approx \hat x_0 = (x_t-\sqrt{1-\bar\alpha_t}\epsilon_\theta(x_t))/\sqrt{\bar\alpha_t},\quad (15)\\
&q(x_{t-1}|x_t,x_0)=\mathcal N(x_{t-1};\tilde\mu_t(x_t,x_0),\tilde\beta_tI),\quad (6)\\
&\tilde\mu_t(x_t,x_0):=\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar\alpha_t}x_0+\frac{\sqrt{\alpha_t}{(1-\bar\alpha_{t-1})}}{1-\bar\alpha_t}x_t\\
&\tilde\beta_t:=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t,\quad (7)
\end{split}
$$