Commit af78373

Diffusion mostly done

1 parent d5ac83d commit af78373

2 files changed: +120 −80 lines
src/content/lessons/diffusion.mdx

Lines changed: 87 additions & 53 deletions
@@ -25,7 +25,8 @@ import Ref from '../../components/numbering/Ref.astro'
 '/>
 
 <div class="glossary">
-<label><input type="checkbox" checked /> Glossary (keep visible)</label>
+<div class="glossary-content">
+<label><input type="checkbox" id="cb-glossary" checked /> Glossary (keep visible)</label>
 - **~beta_t~** ((choice)) small added noise variance at step ~t~
 - **~alpha_t := 1 - beta_t~** signal retention factor at step ~t~
 - **~overline(alpha)_ t := product_(s=1)^t alpha_s~** overall signal retention factor from the start to step ~t~
@@ -34,6 +35,12 @@ import Ref from '../../components/numbering/Ref.astro'
 - **~q(x_t | x_(t-1))~** ~!:= cal(N)(sqrt(1 - beta_t)x_(t-1), beta_t I)~ forward/noising step distribution
 - **~q(x_t | x_0)~** ~!= cal(N)(sqrt(overline(alpha)_t) x_0, (1 - overline(alpha)_t) I)~ big-jump forward/noising step distribution
 - **~q(x_(t-1)|x_0)~** ~!= cal(N)(sqrt(overline(alpha)_ (t-1))x_0, (1 - overline(alpha)_(t-1))I)~ big-jump for step ~t-1~
+- **~q(x_(t-1) | x_t, x_0) prop cal(N)(tilde(mu), tilde(beta) I)~** "sandwich view"
+- **<T block v='tilde(mu)(x_0, t, x_t)= & (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 \ & + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t'/>** sandwich mean
+- **~tilde(beta)(t) = (1 - overline(alpha)_(t-1)) / (1 - overline(alpha)_t) beta_t~** sandwich variance
+
+<label for="cb-glossary">Hide Glossary</label>
+</div>
 </div>
 
 This lesson introduces the principles of diffusion models.
@@ -63,9 +70,9 @@ We will just recall some properties of Gaussians that will be useful in the rest
 
 The KL divergence between two normal distributions ~cal(N)(mu_1, sigma_1^2)~ and ~cal(N)(mu_2, sigma_2^2)~ is given by:
 
-~!"KL"(cal(N)(mu_1, sigma_1^2) || cal(N)(mu_2, sigma_2^2)) = (mu_1 - mu_2)^2 / (2 sigma_2^2) + (sigma_1^2 / (2 sigma_2^2)) - 1/2 + log(sigma_2 / sigma_1)~
+~!KL(cal(N)(mu_1, sigma_1^2), cal(N)(mu_2, sigma_2^2)) = (mu_1 - mu_2)^2 / (2 sigma_2^2) + (sigma_1^2 / (2 sigma_2^2)) - 1/2 + log(sigma_2 / sigma_1)~
 
-In the case we will use it, only the first term will be useful: it will be the only term depending on the variable of interest.
+In this post only the first term will be useful: it is the only term depending on the variable of interest.
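As a quick sanity check (not part of the lesson; scalar case with arbitrary parameters), the closed-form KL above can be compared against direct numerical integration of ~integral p ln(p/q)~:

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    # closed-form KL(N(mu1, s1^2) || N(mu2, s2^2)) from the formula above
    return (mu1 - mu2)**2 / (2 * s2**2) + s1**2 / (2 * s2**2) - 0.5 + np.log(s2 / s1)

# numerical cross-check by trapezoidal quadrature on a wide grid
x = np.linspace(-20.0, 20.0, 200_001)

def pdf(mu, s):
    # density of N(mu, s^2) evaluated on the grid
    return np.exp(-0.5 * ((x - mu) / s) ** 2) / (s * np.sqrt(2 * np.pi))

p, q = pdf(0.3, 0.8), pdf(-1.0, 1.5)
f = p * np.log(p / q)
kl_numeric = np.sum((f[1:] + f[:-1]) * 0.5 * np.diff(x))
```

Both values agree to numerical precision, and the formula correctly gives zero when the two distributions coincide.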
 
 ### Product of two normal densities
 
@@ -75,20 +82,26 @@ The product of two normal densities ~cal(N)(mu_1, sigma_1^2)~ and ~cal(N)(mu_2,
 
 or more explicitly,
 
-~! exists K_1, mu, sigma, forall x: cal(N)(mu_1, sigma_1^2)(x) times cal(N)(mu_2, sigma_2^2)(x) = K_1 times cal(N)(mu, sigma)(x)~
+<T block v='exists K_1, mu, sigma, forall x: \
+cal(N)(mu_1, sigma_1^2)(x) times cal(N)(mu_2, sigma_2^2)(x) = K_1 times cal(N)(mu, sigma)(x)
+'/>
 
-with ~!mu = (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_2^2 mu_1 + sigma_1^2 mu_2) / (sigma_1^2 + sigma_2^2)~
-and ~!sigma^2 = 1 / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_1^2 sigma_2^2) / (sigma_1^2 + sigma_2^2)~
+with <T block v='
+mu & = (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_2^2 mu_1 + sigma_1^2 mu_2) / (sigma_1^2 + sigma_2^2) \
+sigma^2 & = 1 / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_1^2 sigma_2^2) / (sigma_1^2 + sigma_2^2)
+'/>
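This property is easy to verify numerically (a sketch, not from the lesson; arbitrary scalar parameters): the pointwise product divided by the predicted Gaussian should be a constant ~K_1~:

```python
import numpy as np

def npdf(mu, var, x):
    # density of N(mu, var) at x
    return np.exp(-0.5 * (x - mu)**2 / var) / np.sqrt(2 * np.pi * var)

m1, v1 = 0.7, 0.5**2
m2, v2 = -1.2, 1.3**2
var = 1 / (1 / v1 + 1 / v2)       # sigma^2 from the formula above
mu = var * (m1 / v1 + m2 / v2)    # precision-weighted mean

x = np.linspace(-4.0, 4.0, 101)
product = npdf(m1, v1, x) * npdf(m2, v2, x)
ratio = product / npdf(mu, var, x)  # should equal the constant K_1 everywhere
```

The ratio is flat across the whole grid, confirming the product is proportional to ~cal(N)(mu, sigma^2)~.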
 
 <details>
-<summary>Proof</summary>
+<summary>Proof/Derivation</summary>
 <div>
 We work in log space, keeping only the terms that depend on ~x~ (the rest is the normalization constant of a Gaussian).
-More generally, the log of a gaussian density ~cal(N)(mu, sigma^2)~ can be expanded as:
+
+So, first, let's see that the log of a Gaussian density ~cal(N)(mu, sigma^2)~ can be expanded as:
 <T block v='
 ln(cal(N)(mu, sigma^2)(x)) & = -1/2 ((x - mu)^2 / sigma^2) + C_1 \
 & = -1/2 (x^2 / sigma^2 - 2 (mu x) / sigma^2 + mu^2 / sigma^2) + C_1 \
 & = -1/(2 sigma^2) (x^2 - 2 mu x) + C_2 \
+& = -1/2 (add(1/sigma^2) x^2 - 2 add(mu / sigma^2) x) + C_2 \
 '/>
 
 This formula is useful for identifying the parameters of a Gaussian density from its expanded log density.
@@ -100,8 +113,8 @@ For the product of two Gaussian densities, we have:
 & "(developing)" \
 & = -1/2 (x^2 / sigma_1^2 - 2 (mu_1 x) / sigma_1^2 + mu_1^2 / sigma_1^2 + x^2 / sigma_2^2 - 2 (mu_2 x) / sigma_2^2 + mu_2^2 / sigma_2^2) + C_3 \
 & "(factorizing and pushing constant terms in the constant)" \
-& = -1/2 ((1 / sigma_1^2 + 1 / sigma_2^2) x^2 - 2 (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) x) + C_4 \
-& = -1/2 (x^2 / sigma^2 - 2 (mu / sigma^2) x) + C_4 \
+& = -1/2 (add((1 / sigma_1^2 + 1 / sigma_2^2)) x^2 - 2 add((mu_1 / sigma_1^2 + mu_2 / sigma_2^2)) x) + C_4 \
+& = -1/2 ( add(1/sigma^2) x^2 - 2 add(mu / sigma^2) x) + C_4 \
 '/>
 We can identify ~sigma^2~ directly, and then deduce ~mu~:
 <T block v='
@@ -113,9 +126,9 @@ mu & = sigma^2 (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) = (sigma_2^2 mu_1 + sigma_1
 
 ### Identities on normal mean
 
-- ~cal(N)(mu, sigma^2)(x) = cal(N)(x, sigma^2)(mu)~
-- ~cal(N)(a mu, sigma^2)(x) = cal(N)(mu, sigma^2 / a^2)(x / a)~
-- combining both: ~cal(N)(a x, sigma^2)(mu) = cal(N)(mu / a, sigma^2 / a^2)(x)~
+- ~cal(N)(add(mu), sigma^2)(x) = cal(N)(x, sigma^2)(add(mu))~
+- ~cal(N)(add(a) mu, sigma^2)(x) = cal(N)(mu, sigma^2 / add(a^2))(x / add(a))~ (up to a constant factor ~a~, which is all we need when reasoning up to proportionality)
+- combining both: ~cal(N)(add(a x), sigma^2)(mu) = cal(N)(mu / add(a), sigma^2 / add(a^2))(add(x))~ (again up to a factor ~a~)
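These identities can be checked numerically (a sketch, not from the lesson; note the constant Jacobian factor ~a~ in the rescaling identities, which is harmless under proportionality):

```python
import numpy as np

def npdf(mu, var, x):
    # density of N(mu, var) at x
    return np.exp(-0.5 * (x - mu)**2 / var) / np.sqrt(2 * np.pi * var)

mu, s2, a, x = 0.4, 0.9, 1.7, -0.8

# symmetry in mean and argument (exact equality)
assert np.isclose(npdf(mu, s2, x), npdf(x, s2, mu))
# rescaled-mean identity: equal up to the constant factor a (a > 0 here)
assert np.isclose(npdf(a * mu, s2, x), npdf(mu, s2 / a**2, x / a) / a)
# combined identity, same constant factor
assert np.isclose(npdf(a * x, s2, mu), npdf(mu / a, s2 / a**2, x) / a)
```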
 
 
 
@@ -133,7 +146,7 @@ The forward/noising process progressively adds noise to the data until it become
 As shown in Fig. <Ref label="fig:global-noise"/>, the forward process starts from data points
 ~forall i, x^i_0 tilde.op q(x_0)~ (e.g. images from the training set, also named ~hat(p)_"data"~).
 The forward process progressively adds Gaussian noise to these data points, until they are completely scrambled and become close to pure noise.
-The total process is run for a finite (but high) number of steps ~T~, and we obtain ~forall i, x^i_T tilde.op cal(N)(0,I)~.
+The total process is run for a finite (but large) number of steps ~T~, and we (almost) obtain ~forall i, x^i_T tilde.op cal(N)(0,I)~.
 
 <figure>
 <InlineSvg asset="diffusion" hide='#FORWARD, #BACKWARD, #forward, #backward , #more, #bigkl' />
@@ -152,11 +165,12 @@ forall t in [1..T], x_t & tilde.op q(x_t | x_(t-1)) \
 "/>
 
 Even more precisely:
+- we suppose the original dataset has been normalized,
 - we can decide on a variance addition schedule ~beta_1, ..., beta_T~ saying how much noise variance to add at each step,
 - as we add noise at each step, the distribution would spread more and more over time steps (~t~) and would not reach Gaussian noise with identity covariance,
 - to avoid this, we also rescale the signal at each step by a factor ~sqrt(1 - beta_t)~.
 
-The goal of the rescaling is to ensure that the variance at step ~t~ is always ~1~, whatever ~t~.
+The goal of the rescaling is to ensure that the variance of the dataset at step ~t~ is always ~1~, whatever ~t~.
 Overall the forward/noising process is defined as:
 
 <T block v="
@@ -173,13 +187,18 @@ forall t in [1..T], x_t & tilde.op q(x_t | x_(t-1)) = cal(N)(sqrt(1 - beta_t) x_
 ### Big-jump view
 
 Thanks to the properties of Gaussians, we can express the distribution at step ~t~ as a function of the initial data point ~x_0~.
-Indeed, since each step is a Gaussian distribution, the composition of all the steps is also a Gaussian distribution as shown in Fig. <Ref label="fig:multi-step-noise"/>.
+Indeed, since each step adds Gaussian noise (and rescales), the composition of all the steps is also a Gaussian distribution, as shown in Fig. <Ref label="fig:multi-step-noise"/>.
 More precisely, we have:
 
 <T block v="
 forall t in [1..T], x_t & tilde.op q(x_t | x_0) = cal(N)(sqrt(overline(alpha)_t) x_0, (1 - overline(alpha)_t) I) \
 "/>
 
+with
+~overline(alpha)_ t := product_(s=1)^t alpha_s~ the overall signal retained from the start to step ~t~,
+in which ~alpha_t := 1 - beta_t~ is the signal retention factor at step ~t~.
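The equivalence between the step-by-step process and the big-jump view is easy to check by simulation (a sketch, not from the lesson; a scalar "dataset" and an arbitrary linear beta schedule):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 50
beta = np.linspace(1e-4, 0.2, T)        # an arbitrary variance schedule
alpha_bar = np.cumprod(1 - beta)        # overline(alpha)_t

x0, n = 2.0, 200_000                    # scalar data point, Monte Carlo samples
x = np.full(n, x0)
for b in beta:                          # step-by-step forward process
    x = np.sqrt(1 - b) * x + np.sqrt(b) * rng.standard_normal(n)

# big-jump prediction: x_T ~ N(sqrt(alpha_bar_T) * x0, 1 - alpha_bar_T)
mean_err = abs(x.mean() - np.sqrt(alpha_bar[-1]) * x0)
var_err = abs(x.var() - (1 - alpha_bar[-1]))
```

Both the empirical mean and variance of the iterated process match the big-jump Gaussian up to Monte Carlo error.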
+
+
 <figure>
 <InlineSvg asset="diffusion" hide='#BACKWARD, #qtt0, #backward, #more, #bigkl' />
 <figcaption>[<Counter label="fig:multi-step-noise"/>] Multi-step noising process.</figcaption>
@@ -197,13 +216,13 @@ This is illustrated in Fig. <Ref label="fig:sandwich-noise"/>:
 More precisely, we have:
 
 <T block v="
-forall t in [1..T], q(x_(t-1) | x_t, x_0) & prop q(x_(t-1) | x_t) times q(x_(t-1) | x_0) \
+forall t in [2..T], q(x_(t-1) | x_t, x_0) & prop q(x_(t-1) | x_t) times q(x_(t-1) | x_0) \
 "/>
 
 We will show that we can derive a closed-form expression for this distribution:
 
 <T block v='
-forall t in [1..T], q(x_(t-1) | x_t, x_0) & prop cal(N)(tilde(mu), tilde(beta) I) \
+forall t in [2..T], q(x_(t-1) | x_t, x_0) & prop cal(N)(tilde(mu), tilde(beta) I) \
 "with" \
 tilde(mu)(x_0, t, x_t) & = (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t \
 tilde(beta)(t) & = (1 - overline(alpha)_(t-1)) / (1 - overline(alpha)_t) beta_t \
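These closed forms can be cross-checked against a generic scalar Gaussian Bayes update, with the big-jump ~q(x_(t-1)|x_0)~ as prior and the forward step ~q(x_t|x_(t-1))~ as likelihood (a sketch, not from the lesson; arbitrary schedule and values):

```python
import numpy as np

beta = np.linspace(1e-4, 0.2, 50)
alpha = 1 - beta
abar = np.cumprod(alpha)

t = 30                           # an arbitrary step, t >= 2 (1-indexed)
x0, xt = 1.4, -0.3               # arbitrary scalar values
b_t, a_t = beta[t-1], alpha[t-1]
ab_t, ab_tm1 = abar[t-1], abar[t-2]

# closed forms from the lesson
mu_tilde = (np.sqrt(ab_tm1) * b_t / (1 - ab_t)) * x0 \
         + (np.sqrt(a_t) * (1 - ab_tm1) / (1 - ab_t)) * xt
beta_tilde = (1 - ab_tm1) / (1 - ab_t) * b_t

# generic Bayes update: prior N(sqrt(ab_tm1) x0, 1 - ab_tm1),
# likelihood x_t | x_{t-1} ~ N(sqrt(a_t) x_{t-1}, b_t)
post_var = 1 / (1 / (1 - ab_tm1) + a_t / b_t)
post_mu = post_var * (np.sqrt(ab_tm1) * x0 / (1 - ab_tm1) + np.sqrt(a_t) * xt / b_t)
```

The precision-weighted posterior matches ~tilde(mu)~ and ~tilde(beta)~ exactly.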
@@ -212,14 +231,14 @@ tilde(beta)(t) & = (1 - overline(alpha)_(t-1)) / (1 - overline(alpha)_t) beta_t
 The proof relies on showing that ~q(x_(t-1) | x_t)~ (the anti-forward step) is proportional to a Gaussian density, and then using the product-of-two-Gaussians property that we derived in the preliminaries.
 
 <details>
-<summary>Derivation</summary>
+<summary>Proof/Derivation</summary>
 <div>
 
 Knowing ~x_0~, we can express the anti-forward step ~q(x_(t-1) | x_t, x_0)~ using Bayes' rule.
-Remembering that this is a distribution over ~x_(t-1)~, we can focus on the factors that depend on it (and thus drop $q(x_t)$ below):
+Remembering that this is a distribution over ~x_(t-1)~, we can focus on the factors that depend on it (and thus drop $q(x_t | x_0)$ below):
 
 <T block v='
-forall t in [1..T], q(x_(t-1) | x_t, x_0) & = q(x_t | x_(t-1), x_0) q(x_(t-1) | x_0) / q(x_t | x_0) \
+forall t in [2..T], q(x_(t-1) | x_t, x_0) & = q(x_t | x_(t-1), x_0) q(x_(t-1) | x_0) / q(x_t | x_0) \
 & prop q(x_t | x_(t-1)) q(x_(t-1) | x_0) \
 & prop cal(N)(sqrt(1 - beta_t) x_(t-1), beta_t I)(x_t) times q(x_(t-1) | x_0) \
 cal(N)(a x, sigma^2)(mu) => cal(N)(mu/a, sigma^2/a^2)(x) "   "
@@ -295,8 +314,9 @@ This part details the key insight of how to transform a global KL loss into a su
 We will show that we can simplify the global objective into a sum of local objectives, one for each step, with the proper conditioning.
 We will follow a few steps:
 - starting by writing the KL divergence between the two joint distributions of the Markov chains, in their natural direction (forward for the noising process, backward for the learned process),
-- re-introducing an expectation on ~x_0~ to make it the expression tractable (using the "sandwich view"),
+- re-introducing an expectation on ~x_0~ to make the expression tractable,
 - "reversing" the conditional forward noising process,
+- using the sandwich view,
 - using the closed-form of the KL between Gaussians to get a final closed-form loss.
 
 The final terms involved in the loss are KL divergences between two Gaussian distributions, which can be computed in closed form.
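The key mechanism, splitting a KL between joint distributions into per-variable conditional KLs, can be illustrated on a tiny discrete two-variable chain (a sketch, not from the lesson; random distributions over a toy state space):

```python
import numpy as np

rng = np.random.default_rng(0)
S = 3                                  # tiny state space

def rand_dist(*shape):
    # random distribution(s), normalized along the last axis
    p = rng.random(shape)
    return p / p.sum(axis=-1, keepdims=True)

q0 = rand_dist(S)          # q(x0)
q10 = rand_dist(S, S)      # q(x1 | x0), rows indexed by x0
p1 = rand_dist(S)          # p(x1)
p01 = rand_dist(S, S)      # p(x0 | x1), rows indexed by x1

q_joint = q0[:, None] * q10            # q(x0, x1)
p_joint = p1[None, :] * p01.T          # p(x0, x1)
kl_joint = np.sum(q_joint * np.log(q_joint / p_joint))

# chain rule of KL:
# KL(q(x0,x1) || p(x0,x1)) = KL(q(x1) || p(x1)) + E_{x1}[KL(q(x0|x1) || p(x0|x1))]
q1 = q_joint.sum(axis=0)               # marginal q(x1)
q01 = q_joint / q1[None, :]            # conditional q(x0 | x1)
kl_chain = np.sum(q1 * np.log(q1 / p1)) \
         + np.sum(q1 * np.sum(q01 * np.log(q01 / p01.T), axis=0))
```

The global KL equals the sum of the marginal KL and the expected conditional KL, which is the structure exploited (over ~T~ steps) in the derivation below.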
@@ -310,7 +330,7 @@ A big part of these derivations are also detailed at the end of this page, in is
 NB: seems ok but needs to be chunked better.
 
 <T block v='
-cal(L) & := "KL"(q(x_(0:T)) || p_theta (x_(0:T))) \
+cal(L) & := KL(q(x_(0:T)), p_theta (x_(0:T))) \
 & = EE_(x_(0:T) tilde.op q(x_(0:T))) [ln (q(x_(0:T)) / (p_theta (x_(0:T))))] \
 & = EE_(x_(0:T) tilde.op q(x_(0:T))) [ln ((q(x_0) product_(t=1)^T q(x_t | x_(t-1))) / (p_theta (x_T) product_(t=1)^T p_theta (x_(t-1) | x_t)))] \
 & = EE_(x_(0:T) tilde.op q(x_(0:T))) [ln q(x_0) - ln p_theta (x_T) + sum_(t=1)^T ln q(x_t | x_(t-1)) - sum_(t=1)^T ln p_theta (x_(t-1) | x_t)] \
@@ -364,12 +384,14 @@ To avoid having to handle this special case below, we override ~tilde(mu)(x_0, t=
 We thus get a closed-form expression for the loss, involving square norms (coming from the KL between Gaussians).
 
 <T block v='
-cal(L) & = C + EE_(x_0 tilde.op q(x_0)) sum_(t=1)^T lambda_t ||tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t)||^2 \
-& = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) ||tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t)||^2
+cal(L) & = C + EE_(x_0 tilde.op q(x_0)) sum_(t=1)^T EE_(x_t) lambda_t norm(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2 \
+& = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(x_t tilde.op q(x_t)) norm(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2 \
+& = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(x_t tilde.op q(x_t | x_0)) norm(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2
 '/>
 
-with ~lambda_t = 1/(2 sigma_t^2)~.
+with ~lambda_t = 1/(2 sigma_t^2)~, and where, as above with $x_1$ and $x_0$, $x_t$ can equivalently be sampled either independently or conditioned on $x_0$.
 
+Overall, thanks to these derivations, we obtain a loss that is simple: it is conditioned on the data point ~x_0~ and is local in time (a sum over ~t~).
 
 <figure>
 <InlineSvg asset="diffusion" hide='#forward, #backward, #more, #bigkl, #qt, #qtt' />
@@ -394,57 +416,69 @@ with ~lambda_t = 1/(2 sigma_t^2)~.
 
 The above derivations reason in terms of fitting the means of the backward steps ~mu_theta (x_t, t)~.
 
-We can reparametrize the sampling of ~x_t~ as a function of ~x_0~ and some noise ~epsilon~:
+Based on the "big-jump view" of the forward process and the properties of Gaussians,
+we can reparametrize the sampling of ~x_t~ (conditioned on $x_0$) as a function of ~x_0~ and some noise ~epsilon~:
 
 <T block v='
 x_t & = sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon \
 epsilon & tilde.op cal(N)(0,I) \
 '/>
 
-We can further use the formula of ~x_t~ to rewrite ~tilde(mu)~ as a function of ~x_0~ and ~epsilon~ only (which will help re-interpret and reparametrize the loss):
+Reversing this relation gives:
+~!x_0 = 1 / sqrt(overline(alpha)_t) x_t - sqrt((1 - overline(alpha)_t) / overline(alpha)_t) epsilon~.
 
-TODO redo/check exact simplifications
+We can plug this expression into ~tilde(mu)~ to make it depend only on ~x_t~ and ~epsilon~, not on ~x_0~:
 
 <T block v='
 tilde(mu)(x_0, t, x_t)
-& = tilde(mu)(x_0, t, sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) \
-& = (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) (sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) \
-& = (sqrt(overline(alpha)_(t-1)) beta_t + sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(overline(alpha)_t)) / (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
-& = 1 / sqrt(overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
-
+& = (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) (1 / sqrt(overline(alpha)_t) x_t - sqrt((1 - overline(alpha)_t) / overline(alpha)_t) epsilon)
++ (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t \
+& = (sqrt(overline(alpha)_(t-1)) beta_t) / (sqrt(overline(alpha)_t) (1 - overline(alpha)_t)) x_t
++ (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t
+- (sqrt(overline(alpha)_(t-1)) beta_t sqrt((1 - overline(alpha)_t) / overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
+& = ( (sqrt(overline(alpha)_(t-1)) beta_t) / (sqrt(overline(alpha)_t) (1 - overline(alpha)_t)) + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) ) x_t
+- (sqrt(overline(alpha)_(t-1)) beta_t sqrt((1 - overline(alpha)_t) / overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
+& =: K_(x_t)(t) x_t + K_(epsilon)(t) epsilon \
 '/>
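The text notes that the constants can be refined further; as a sanity check (my own, not in the lesson), ~K_(x_t)~ and ~K_(epsilon)~ collapse numerically to the familiar DDPM coefficients ~1/sqrt(alpha_t)~ and ~-beta_t/(sqrt(alpha_t) sqrt(1 - overline(alpha)_t))~:

```python
import numpy as np

beta = np.linspace(1e-4, 0.2, 50)
alpha = 1 - beta
abar = np.cumprod(alpha)

t = 30                                   # arbitrary step, t >= 2 (1-indexed)
b, a = beta[t-1], alpha[t-1]
ab, ab_m1 = abar[t-1], abar[t-2]

# constants as defined in the derivation above
K_xt = np.sqrt(ab_m1) * b / (np.sqrt(ab) * (1 - ab)) \
     + np.sqrt(a) * (1 - ab_m1) / (1 - ab)
K_eps = -np.sqrt(ab_m1) * b * np.sqrt((1 - ab) / ab) / (1 - ab)

# simplified forms (using sqrt(ab_m1 / ab) = 1 / sqrt(a) and b + a - ab = 1 - ab)
assert np.isclose(K_xt, 1 / np.sqrt(a))
assert np.isclose(K_eps, -b / (np.sqrt(a) * np.sqrt(1 - ab)))
```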
 
+The constants can be refined further, but for now, let's look at the implications.
+Using the sampling of ~x_t~ reparametrized using ~epsilon~ and substituting the value of ~tilde(mu)~ we just derived, we can rewrite the loss as:
+<T block v='
+cal(L) & = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) norm(K_(x_t)(t) x_t + K_(epsilon)(t) epsilon - mu_theta (x_t, t))^2 \
+& = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I))
+norm(K_(epsilon)(t) (epsilon - (mu_theta (x_t, t) - K_(x_t)(t) x_t) / (K_(epsilon)(t))))^2 \
+& = C + sum_(t=1)^T (lambda_t K^2_(epsilon)(t)) EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I))
+norm(epsilon - (mu_theta (x_t, t) - K_(x_t)(t) x_t) / (K_(epsilon)(t)))^2 \
+'/>
 
-
-
-This allows to rewrite the loss as:
+So, up to the time reweighting (~gamma_t := lambda_t K^2_(epsilon)(t)~ instead of ~lambda_t~), since ~mu_theta~ takes ~x_t~ and ~t~ as inputs, we can equivalently train a network to predict the noise ~epsilon~ that was used to generate ~x_t~ from ~x_0~, instead of predicting ~tilde(mu)~.
+We can thus define a network ~epsilon_theta (x_t, t)~ and train it to minimize:
 
 <T block v='
-cal(L)
-& <= sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2] + C \
-& = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [(tilde(mu)(x_0, t, sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) - mu_theta (x_t))^2] + C \
+cal(L) & = C + sum_(t=1)^T gamma_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I))
+norm(epsilon - epsilon_theta (x_t, t))^2 \
 '/>
 
-
-Where we developed ~x_t~ in ~tilde(mu)~ to be able to simplify the formula (but kept it in ~mu_theta~ as we will keep it untouched).
-By substituting the expression of ~tilde(mu)~, we get:
+With the two-way mapping between ~mu_theta~ and ~epsilon_theta~ being:
 
 <T block v='
-cal(L)
-& <= sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [((sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) (sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) - mu_theta (x_t, t))^2] + C \
-& = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [((sqrt(overline(alpha)_(t-1)) beta_t + sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(overline(alpha)_t)) / (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon - mu_theta (x_t, t))^2] + C \
-& = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [(1 / sqrt(overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon - mu_theta (x_t, t))^2] + C \
-& = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [epsilon - ( (1 - overline(alpha)_t) / (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) (mu_theta (x_t, t) - 1 / sqrt(overline(alpha)_t) x_0) )^2] + C \
+mu_theta (x_t, t) & = K_(x_t)(t) x_t + K_(epsilon)(t) epsilon_theta (x_t, t) \
+epsilon_theta (x_t, t) & = (mu_theta (x_t, t) - K_(x_t)(t) x_t) / (K_(epsilon)(t)) \
 '/>
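Putting the pieces together, one epsilon-prediction training objective can be sketched as follows (a minimal sketch, not the lesson's code: `eps_theta` is a hypothetical placeholder predictor, and ~gamma_t~ is set to 1, a common simplification assumed here):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 100
beta = np.linspace(1e-4, 0.02, T)
abar = np.cumprod(1 - beta)             # overline(alpha)_t

def eps_theta(x_t, t):
    # hypothetical placeholder network; a real model would be trained here
    return np.zeros_like(x_t)

def training_loss(x0):
    t = rng.integers(1, T + 1, size=x0.shape[0])         # one step per example
    eps = rng.standard_normal(x0.shape)                  # target noise
    x_t = np.sqrt(abar[t-1]) * x0 + np.sqrt(1 - abar[t-1]) * eps
    return np.mean((eps - eps_theta(x_t, t))**2)         # gamma_t = 1

loss = training_loss(rng.standard_normal(10_000))
```

With the zero predictor, the loss sits near 1 (the variance of ~epsilon~), which is the baseline any trained network must beat.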
 
+## Aside: link with flow matching
+
+Looking at the algorithms, we can uncover the link with flow matching.
+Conceptually, both sample a time step (although with different semantics), a data point, and a unit noise.
+However, they differ in the path, i.e., the formula for ~x_t~:
+- diffusion aims at preserving the variance across time,
+- flow matching (in its typical form) aims at a linear interpolation between data and noise.
+
+We can, however, instantiate a flow matching variant that matches the diffusion path.
+The similarities/differences are then just in the time weighting and in what is fit.
+The details are left out for now.
The details are left out for now.
441480

442-
Substituting to get rid of ~x_0~ and use only ~x_t~ as a parameter of ~epsilon_theta~, we get:
443481

444-
<T block v='
445-
cal(L)
446-
& <= sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [ ( epsilon - ( (1 - overline(alpha)_t) / (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) (mu_theta (sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon, t) - 1 / sqrt(overline(alpha)_t) x_0) )^2] + C \
447-
'/>
448482

449483
## Aside: Bayes rule over a Markov chain
450484
