# Consistency Models Training Example

## Introduction
Consistency models are a family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing few-step sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without explicit training on these tasks.
### Key Idea
Learn a model that maps any point on a probability flow trajectory, at any time step, back to the initial data point of that trajectory. In other words, points that lie on the same probability flow trajectory are mapped to the same initial data point.
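
As a toy illustration of this idea (not part of the training code), assume the data distribution is a single point `x0` and noise is added as $x_t = x_0 + t z$. Under that assumption the probability flow trajectories are straight lines $x_t = x_0 + c\,t$, so the consistency function has a closed form. The names `toy_trajectory` and `toy_consistency_function` are hypothetical, and `T_MIN = 0.002` is borrowed from the noise range used later in the training cell.

```python
import math

T_MIN = 0.002  # smallest noise level (assumed, matches the training cell below)


def toy_trajectory(x0: float, c: float, t: float) -> float:
    """Point at time t on the toy trajectory x_t = x0 + c * t."""
    return x0 + c * t


def toy_consistency_function(x_t: float, t: float, x0: float) -> float:
    """Map (x_t, t) to the point at T_MIN on the same toy trajectory."""
    return x0 + (T_MIN / t) * (x_t - x0)


x0, c = 2.0, 1.5
# Two points on the same trajectory, taken at very different noise levels
early = toy_consistency_function(toy_trajectory(x0, c, 1.0), 1.0, x0)
late = toy_consistency_function(toy_trajectory(x0, c, 80.0), 80.0, x0)
# Both are mapped to the same starting point of the trajectory
assert math.isclose(early, late)
```

A real consistency model has no access to `x0`; it has to learn this mapping from data, which is what the training loop later in this notebook is for.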
### Contributions
* Single-step sampling
* Zero-shot data editing: inpainting, outpainting, etc.
### Definition
Given a diffusion trajectory $\{x_t\}_{t\in[t_{min},t_{max}]}$, we define a consistency function $f:(x_t,t)\mapsto x_{t_{min}}$.
We can then train a consistency model $f_\theta(\cdot,\cdot)$ to approximate the consistency function. The consistency function satisfies the boundary condition $f(x_{t_{min}},t_{min}) = x_{t_{min}}$, i.e. it is the identity at $t=t_{min}$. To enforce this, we parameterize the consistency model using skip connections:
$$
f_\theta(x_t,t) = c_{skip}(t)x_t+c_{out}(t)F_\theta(x_t,t),
$$
where $c_{skip}(t_{min})=1$, $c_{out}(t_{min})=0$, and $F_\theta(\cdot, \cdot)$ is the neural network.
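
A minimal sketch of this parameterization, assuming the coefficient choice from the original consistency models paper: $c_{skip}(t)=\sigma_{data}^2/((t-t_{min})^2+\sigma_{data}^2)$ and $c_{out}(t)=\sigma_{data}(t-t_{min})/\sqrt{\sigma_{data}^2+t^2}$. The value `SIGMA_DATA = 0.5` and the function names are assumptions for illustration; the model used in this notebook may define its coefficients differently.

```python
import math

SIGMA_DATA = 0.5  # assumed standard deviation of the data
T_MIN = 0.002     # smallest noise level (assumed)


def c_skip(t: float) -> float:
    """Weight of the skip connection; equals 1 at t = T_MIN."""
    return SIGMA_DATA**2 / ((t - T_MIN) ** 2 + SIGMA_DATA**2)


def c_out(t: float) -> float:
    """Weight of the network output; equals 0 at t = T_MIN."""
    return SIGMA_DATA * (t - T_MIN) / math.sqrt(SIGMA_DATA**2 + t**2)


def consistency_model(network, x_t: float, t: float) -> float:
    """f_theta(x_t, t) = c_skip(t) * x_t + c_out(t) * F_theta(x_t, t)."""
    return c_skip(t) * x_t + c_out(t) * network(x_t, t)


# At t = T_MIN the output equals the input, whatever the network returns
assert consistency_model(lambda x, t: 123.0, 0.7, T_MIN) == 0.7
```

Because the boundary condition is built into the coefficients, it holds for any network weights and does not have to be learned.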

## Algorithms
### Training
To train the model, we use the following loop:

In [2]:
for itr in range(iterations):
    data = data_distribution()
    
    # Use the settings from "Improved Techniques for Training Consistency Models" if enabled
    if improved_CT:
        N = improved_timesteps_schedule(itr, iterations, initial_timesteps=10, final_timesteps=1280)
    else:
        N = timesteps_schedule(itr, iterations, initial_timesteps=2, final_timesteps=150)
    
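    if adaptive_ema:
        # Scale the EMA decay with the current number of timesteps N
        start_scales = 2.0
        c = np.log(ema_decay) * start_scales
        target_ema = np.exp(c / N)
    else:
        # Assumption: fall back to the fixed decay so target_ema is always defined
        target_ema = ema_decay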
    
    boundaries = kerras_boundaries(7, 0.002, N, 80).to(device)
    sigma = boundaries
    z = torch.randn_like(data)
    
    # Use the settings from "Improved Techniques for Training Consistency Models" if enabled
    if improved_CT:
        # Sample timestep indices from a discretized lognormal distribution over sigma
        # t = lognormal_timesteps_distribution(data.shape[0], boundaries, mean=-1.1, std=2.0)
        mean = -1.1
        std = 2.0
        erf_vals = torch.erf((torch.log(sigma) - mean) / (std * math.sqrt(2)))
        # Probability mass of each interval [sigma_i, sigma_{i+1}]
        pdf = erf_vals[1:] - erf_vals[:-1]
        pdf = pdf / pdf.sum()
        # Draw one timestep index per sample in the batch
        t = torch.multinomial(pdf, data.shape[0], replacement=True)
        t = t.view(-1, 1).to(device)
    else:
        t = torch.randint(0, N - 1, (data.shape[0], 1), device=device)
    
    t_1 = sigma[t]
    t_2 = sigma[t + 1]
    
    # Improved consistency training uses no EMA teacher; the original recipe uses the EMA copy
    if improved_CT:
        teacher_model = None
    else:
        teacher_model = ema_actor
    
    loss = actor.loss(data, z, t_1, t_2, teacher_model)
    mean_loss = loss.mean()
    
    if loss_ema is None:
        loss_ema = mean_loss.item()
    else:
        loss_ema = 0.9 * loss_ema + 0.1 * mean_loss.item()
    
    actor_optimizer.zero_grad()
    mean_loss.backward()
    if grad_norm > 0:
        actor_grad_norms = nn.utils.clip_grad_norm_(actor.parameters(), max_norm=grad_norm, norm_type=2)
    actor_optimizer.step()
    
    # Step target network: exponential moving average of the online weights
    with torch.no_grad():
        for p, ema_p in zip(actor.parameters(), ema_actor.parameters()):
            ema_p.mul_(target_ema).add_(p, alpha=1 - target_ema)
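
The loop above relies on helpers that are not shown in this cell. As a rough sketch of what they might compute, the functions below follow the formulas published in the consistency models papers and in Karras et al. for the noise levels. The `_sketch` names are hypothetical, the argument order `(rho, sigma_min, n, sigma_max)` is inferred from the call `kerras_boundaries(7, 0.002, N, 80)`, and the notebook's own implementations may differ.

```python
import math


def karras_boundaries_sketch(rho, sigma_min, n, sigma_max):
    """n noise levels from sigma_min to sigma_max, spaced as in Karras et al. (n >= 2)."""
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return [(lo + i / (n - 1) * (hi - lo)) ** rho for i in range(n)]


def timesteps_schedule_sketch(k, total, s0=2, s1=150):
    """Original schedule: the step count grows smoothly from s0 to s1 + 1."""
    n = math.sqrt(k / total * ((s1 + 1) ** 2 - s0**2) + s0**2)
    return math.ceil(n - 1) + 1


def improved_timesteps_schedule_sketch(k, total, s0=10, s1=1280):
    """Improved schedule: the step count doubles at regular intervals, capped at s1."""
    interval = math.floor(total / (math.log2(s1 // s0) + 1))
    interval = max(interval, 1)  # guard against very short training runs
    return min(s0 * 2 ** (k // interval), s1) + 1


# Number of discretization steps at the start and end of an 800-iteration run
print(timesteps_schedule_sketch(0, 800), timesteps_schedule_sketch(800, 800))
print(improved_timesteps_schedule_sketch(0, 800), improved_timesteps_schedule_sketch(799, 800))
print(karras_boundaries_sketch(7, 0.002, 5, 80))
```

Both schedules start with a coarse discretization and refine it as training progresses, so the number of noise levels `N` changes across iterations.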

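The `improved_CT` branch of the loop samples timestep indices from a discretized lognormal distribution: each interval between adjacent noise levels is weighted by the difference of `erf` values at its endpoints. The dependency-free sketch below mirrors that computation with the standard library so it can be inspected on a small grid. The function names and the example grid are hypothetical; `mean=-1.1` and `std=2.0` are the values used in the loop.

```python
import math
import random


def lognormal_index_weights(sigmas, mean=-1.1, std=2.0):
    """Normalized weight of each interval [sigmas[i], sigmas[i + 1]]."""
    erf_vals = [math.erf((math.log(s) - mean) / (std * math.sqrt(2))) for s in sigmas]
    weights = [hi - lo for lo, hi in zip(erf_vals[:-1], erf_vals[1:])]
    total = sum(weights)
    return [w / total for w in weights]


def sample_indices(sigmas, batch_size, rng=random):
    """Draw batch_size interval indices with replacement."""
    weights = lognormal_index_weights(sigmas)
    return rng.choices(range(len(weights)), weights=weights, k=batch_size)


sigmas = [0.002, 0.1, 1.0, 10.0, 80.0]  # small illustrative grid, increasing
weights = lognormal_index_weights(sigmas)
indices = sample_indices(sigmas, batch_size=8)
print(weights)
print(indices)
```

An index `i` drawn this way selects the adjacent pair `sigmas[i]` and `sigmas[i + 1]`, which play the roles of `t_1` and `t_2` in the loop.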