30 changes: 29 additions & 1 deletion examples/research_projects/consistency_training/README.md
@@ -21,4 +21,32 @@ accelerate launch examples/research_projects/consistency_training/train_cm_ct_un
--validation_steps=100 --eval_batch_size=4 \
--checkpointing_steps=100 --checkpoints_total_limit=10 \
--class_conditional --num_classes=10 \
```

## Hyperparameters

A short description of the consistency training-specific hyperparameters is as follows. The default hyperparameter values follow those in the improved consistency training column in Table 1 of [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189).

- Time Discretization
- `sigma_min`/`sigma_max`: These define the lower and upper boundaries of the noise level $\sigma \in [\sigma_{min}, \sigma_{max}]$. By default, these are set to $\sigma_{min} = 0.002$ and $\sigma_{max} = 80.0$, following both the [original consistency models paper](https://arxiv.org/abs/2303.01469) and the [improved consistency training paper](https://arxiv.org/abs/2310.14189).
- `rho`: in practice, the time interval $[\sigma_{min}, \sigma_{max}]$ is discretized into a sequence of noise levels $\sigma_{min} = \sigma_1 < \ldots < \sigma_{N} = \sigma_{max}$ following the Karras sigmas with parameter $\rho$:
$$\sigma_i = \left(\sigma_{min}^{1 / \rho} + \frac{i - 1}{N - 1}\left(\sigma_{max}^{1 / \rho} - \sigma_{min}^{1 / \rho}\right)\right)^\rho$$
By default, $\rho = 7$, which is the value originally suggested in the [EDM paper](https://arxiv.org/abs/2206.00364) and used in the consistency model papers.
- `discretization_s_0`/`discretization_s_1`: During training, we vary the number of discretization steps $N$ following a discretization curriculum $N(k)$ based on the current training step $k$ out of $K$ (`max_train_steps`) total:
$$N(k) = \min{(s_02^{\lfloor k / K' \rfloor}, s_1)} + 1, K' = \lfloor\frac{K}{\log_{2}{\lfloor s_1 / s_0 \rfloor} + 1}\rfloor$$
In this exponential curriculum, we start with $s_0 + 1$ discretization steps at the beginning of training, with the number of discretization steps $N$ doubling after a set number of training iterations until the maximum number of discretization steps $s_1 + 1$ is reached. By default, $s_0 = 10$ and $s_1 = 1280$, which are the values used in the [improved consistency training paper](https://arxiv.org/abs/2310.14189).
- `constant_discretization_steps`: If set, disables the above discretization curriculum and uses a constant curriculum $N(k) = s_0 + 1$. This is useful for debugging.
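
For intuition, the discretization and curriculum described above can be sketched in a few lines of NumPy; the helper names here are illustrative and are not the training script's actual functions:

```python
import numpy as np


def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # sigma_1 = sigma_min, ..., sigma_N = sigma_max, spaced according to the Karras rho schedule
    i = np.arange(1, num_steps + 1)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (min_inv_rho + (i - 1) / (num_steps - 1) * (max_inv_rho - min_inv_rho)) ** rho


def discretization_steps(k, max_train_steps, s_0=10, s_1=1280):
    # Exponential curriculum: starts at s_0 + 1 steps and doubles until reaching s_1 + 1
    k_prime = max_train_steps // (int(np.log2(s_1 // s_0)) + 1)
    return min(s_0 * 2 ** (k // k_prime), s_1) + 1
```

For example, with the defaults and, say, `max_train_steps = 400_000`, $K' = 50000$, so training starts with $N = 11$ discretization steps and progresses through 21, 41, ..., up to the cap of 1281.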
- Input and Output Preconditioning
- `input_precond_type`: this specifies how the $c_{in}(\sigma)$ input preconditioning parameter is calculated. By default, this is set to `'cm'`, which uses the input preconditioning from the original CM paper (which is also the original EDM input preconditioning) $c_{in}(\sigma) = 1 / \sqrt{\sigma^2 + \sigma_{data}^2}$. If `'none'` is specified, no input preconditioning will be used.
- `noise_precond_type`: this specifies the function $c_{noise}(\sigma)$ which transforms discrete timesteps $\sigma_i$ for input into the consistency model U-Net. By default, this is set to `'cm'`, which uses the function $c_{noise}(\sigma) = 1000 \cdot \frac{1}{4}\log{(\sigma + 10^{-44})}$ from [the original consistency models repo](https://github.com/openai/consistency_models/blob/e32b69ee436d518377db86fb2127a3972d0d8716/cm/karras_diffusion.py#L346). The original EDM noise preconditioning function $c_{noise}(\sigma) = \frac{1}{4}\log{\sigma}$ can be used by setting this argument to `'edm'`. If `'none'` is specified, no noise preconditioning will be used.
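
As a rough sketch of the preconditioning options above (scalar inputs for clarity; the function names are illustrative, not the script's actual helpers):

```python
import math


def input_preconditioning(sigma, sigma_data=0.5, input_precond_type="cm"):
    # 'cm' (same as EDM): c_in = 1 / sqrt(sigma^2 + sigma_data^2); 'none' leaves the input unscaled
    if input_precond_type == "cm":
        return 1.0 / math.sqrt(sigma**2 + sigma_data**2)
    return 1.0


def noise_preconditioning(sigma, noise_precond_type="cm"):
    # 'cm' follows the original consistency models repo, 'edm' is the EDM choice, 'none' passes sigma through
    if noise_precond_type == "cm":
        return 1000 * 0.25 * math.log(sigma + 1e-44)
    if noise_precond_type == "edm":
        return 0.25 * math.log(sigma)
    return sigma
```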
- Noise Schedule
- `p_mean`/`p_std`: the noise level $\sigma$ used at each training step is sampled from a lognormal distribution with $\log{\sigma} \sim \mathcal{N}(P_{mean}, P_{std}^2)$. Since we discretize the noise levels $\{\sigma_i\}$, we use a discretized version of the distribution where $i \sim p(i)$ and
$$p(i) \propto \textrm{erf}{(\frac{\log{\sigma_{i + 1}} - P_{mean}}{\sqrt{2}P_{std}})} - \textrm{erf}{(\frac{\log{\sigma_{i}} - P_{mean}}{\sqrt{2}P_{std}})}$$
By default, $P_{mean} = -1.1$ and $P_{std} = 2.0$, which are the default values used in the [improved consistency training paper](https://arxiv.org/abs/2310.14189).
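
A minimal sketch of sampling a timestep index from this discretized lognormal distribution, using SciPy's `erf` (the names are illustrative):

```python
import numpy as np
from scipy.special import erf


def timestep_index_probabilities(sigmas, p_mean=-1.1, p_std=2.0):
    # p(i) is proportional to erf((log sigma_{i+1} - P_mean) / (sqrt(2) P_std))
    #                        - erf((log sigma_i - P_mean) / (sqrt(2) P_std))
    cdf = erf((np.log(sigmas) - p_mean) / (np.sqrt(2.0) * p_std))
    probs = cdf[1:] - cdf[:-1]
    return probs / probs.sum()


# Example: sample an index i, then train on the adjacent noise levels (sigma_i, sigma_{i+1})
# probs = timestep_index_probabilities(sigmas)
# i = np.random.choice(len(probs), p=probs)
```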
- Loss
- `huber_c`: this corresponds to the $c$ parameter in the Pseudo-Huber metric
$$d(x, y) = \sqrt{\lVert x - y \rVert_2^2 + c^2} - c$$
If not set, this will default to the heuristic value of $c = 0.00054\sqrt{d}$, where $d$ is the dimensionality of the input image data, as suggested in the [improved consistency training paper](https://arxiv.org/abs/2310.14189).
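
A sketch of the Pseudo-Huber loss with the heuristic default for `c` (the training script may organize this computation differently):

```python
import math

import torch


def pseudo_huber_loss(pred, target, huber_c=None):
    # d(x, y) = sqrt(||x - y||_2^2 + c^2) - c, with the L2 norm taken over each sample
    if huber_c is None:
        # Heuristic c = 0.00054 * sqrt(d), where d is the per-image dimensionality (C * H * W)
        huber_c = 0.00054 * math.sqrt(pred[0].numel())
    squared_norm = (pred - target).flatten(start_dim=1).pow(2).sum(dim=1)
    return (torch.sqrt(squared_norm + huber_c**2) - huber_c).mean()
```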
- Exponential Moving Average (EMA)
- `use_ema`: set this to use EMA for the student model (the model updated via gradient descent). Note that EMA is not used to update the teacher model (the model evaluated at the lower noise level, which is not updated via gradient descent); rather, the teacher parameters $\theta^-$ are set to the student parameters $\theta$ after each training step (equivalent to an EMA decay rate of 0).
- `ema_min_decay`/`ema_max_decay`: specifies the minimum and maximum EMA decay. The [improved consistency training paper](https://arxiv.org/abs/2310.14189) uses a fixed EMA decay rate of `0.99993` for CIFAR10, which is achieved by the default setting of `ema_max_decay == 0.99993` and not setting `ema_min_decay` (when not set, `ema_min_decay` defaults to `ema_max_decay` so that the EMA decay is fixed throughout training).
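
For reference, a fixed-decay EMA update of the student parameters can be sketched as follows (the training script's own EMA utilities may differ):

```python
import torch


@torch.no_grad()
def update_ema(ema_params, student_params, decay=0.99993):
    # ema <- decay * ema + (1 - decay) * student
    for ema_param, student_param in zip(ema_params, student_params):
        ema_param.mul_(decay).add_(student_param, alpha=1.0 - decay)
```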
@@ -188,10 +188,10 @@ def get_input_preconditioning(sigmas, sigma_data=0.5, input_precond_type: str =
)


-def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=1.0):
+def scalings_for_boundary_conditions(timestep, sigma_min, sigma_data=0.5, timestep_scaling=1.0):
     scaled_timestep = timestep_scaling * timestep
-    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
-    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+    c_skip = sigma_data**2 / ((scaled_timestep - sigma_min) ** 2 + sigma_data**2)
+    c_out = (scaled_timestep - sigma_min) * sigma_data / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
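
With the shifted parameterization above, `c_skip` evaluates to 1 and `c_out` to 0 at $\sigma = \sigma_{min}$ (assuming `timestep_scaling` is 1), so the parameterized model reduces to the identity at the minimum noise level, which is the consistency-model boundary condition. An illustrative check, not part of the diff:

```python
# Illustrative check (assumes timestep_scaling=1.0 so scaled_timestep == sigma_min at the boundary)
sigma_min = 0.002
c_skip, c_out = scalings_for_boundary_conditions(sigma_min, sigma_min)
assert c_skip == 1.0 and c_out == 0.0
```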


@@ -1255,8 +1255,8 @@ def unwrap_model(model):
 c_in_teacher = get_input_preconditioning(teacher_timesteps, input_precond_type=args.input_precond_type)
 c_in_student = get_input_preconditioning(student_timesteps, input_precond_type=args.input_precond_type)

-c_skip_teacher, c_out_teacher = scalings_for_boundary_conditions(teacher_timesteps)
-c_skip_student, c_out_student = scalings_for_boundary_conditions(student_timesteps)
+c_skip_teacher, c_out_teacher = scalings_for_boundary_conditions(teacher_timesteps, args.sigma_min)
+c_skip_student, c_out_student = scalings_for_boundary_conditions(student_timesteps, args.sigma_min)
Comment on lines +1258 to +1259
Member:
Are we defaulting to reasonable values for these args?

Collaborator Author:
sigma_min defaults to 0.002, which is the value of $\sigma_{min}$/$\epsilon$ used in the original Consistency Models paper and the improved Consistency Training paper (see Table 1 in the latter paper), which I believe is a reasonable default value.

(Note that sigma_min should be a small positive value rather than 0 to avoid numerical issues when using ODE solvers.)

Member:
Cool! Do you think it makes sense to add this info to the README as well?

Collaborator Author:
I have added documentation on the consistency training-specific hyperparameters to the README.


c_skip_teacher, c_out_teacher, c_in_teacher = [
append_dims(x, clean_images.ndim) for x in [c_skip_teacher, c_out_teacher, c_in_teacher]