@@ -25,7 +25,8 @@ import Ref from '../../components/numbering/Ref.astro'
2525' />
2626
2727<div class = " glossary" >
28- <label ><input type = " checkbox" checked /> Glossary (keep visible)</label >
28+ <div class = " glossary-content" >
29+ <label ><input type = " checkbox" id = " cb-glossary" checked /> Glossary (keep visible)</label >
2930- ** ~ beta_t~ ** ((choice)) small added noise variance at step ~ t~
3031- ** ~ alpha_t := 1 - beta_t~ ** signal retention factor at step ~ t~
3132- ** ~ overline(alpha)_ t := product_ (s=1)^t alpha_s~ ** overall signal retention factor from the start to step ~ t~
@@ -34,6 +35,12 @@ import Ref from '../../components/numbering/Ref.astro'
3435- ** ~ q(x_t | x_ (t-1))~ ** ~ !:= cal(N)(sqrt(1 - beta_t)x_ (t-1), beta_t I)~ forward/noising step distribution
3536- ** ~ q(x_t | x_0)~ ** ~ != cal(N)(sqrt(overline(alpha)_ t) x_0, (1 - overline(alpha)_ t) I)~ big-jump forward/noising step distribution
3637- ** ~ q(x_ (t-1)|x_0)~ ** ~ != cal(N)(sqrt(overline(alpha)_ (t-1))x_0, (1 - overline(alpha)_ (t-1))I)~ big-jump for step ~ t-1~
38+ - ** ~ q(x_ (t-1) | x_t, x_0) prop cal(N)(tilde(mu), tilde(beta) I)~ ** "sandwich view"
39+ - ** <T block v = ' tilde(mu)(x_0, t, x_t)= & (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 \ & + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t' />** sandwich mean
40+ - ** ~ tilde(beta)(t) = (1 - overline(alpha)_ (t-1)) / (1 - overline(alpha)_ t) beta_t~ ** sandwich variance
41+
42+ <label for = " cb-glossary" >Hide Glossary</label >
43+ </div >
3744</div >
3845
3946This lesson introduces the principles of diffusion models.
@@ -63,9 +70,9 @@ We will just recall some properties of Gaussians that will be useful in the rest
6370
6471The KL divergence between two normal distributions ~ cal(N)(mu_1, sigma_1^2)~ and ~ cal(N)(mu_2, sigma_2^2)~ is given by:
6572
66- ~ !"KL" (cal(N)(mu_1, sigma_1^2) || cal(N)(mu_2, sigma_2^2)) = (mu_1 - mu_2)^2 / (2 sigma_2^2) + (sigma_1^2 / (2 sigma_2^2)) - 1/2 + log(sigma_2 / sigma_1)~
73+ ~ !KL (cal(N)(mu_1, sigma_1^2), cal(N)(mu_2, sigma_2^2)) = (mu_1 - mu_2)^2 / (2 sigma_2^2) + (sigma_1^2 / (2 sigma_2^2)) - 1/2 + log(sigma_2 / sigma_1)~
6774
68- In the case we will use it, only the first term will be useful: it will be the only term depending on the variable of interest.
75+ In this post, only the first term will be useful, as it is the only one that depends on the variable of interest.
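This closed form is easy to sanity-check numerically. A sketch (assuming `numpy`; the parameter values are arbitrary) comparing it against a grid approximation of the defining integral:

```python
import numpy as np

mu1, s1 = 0.3, 0.8    # arbitrary parameters for the first Gaussian
mu2, s2 = -0.5, 1.4   # ... and for the second

def npdf(x, mu, s):
    """Density of N(mu, s^2) evaluated at x."""
    return np.exp(-(x - mu) ** 2 / (2 * s ** 2)) / (np.sqrt(2 * np.pi) * s)

# Closed-form KL between N(mu1, s1^2) and N(mu2, s2^2)
kl_closed = ((mu1 - mu2) ** 2 / (2 * s2 ** 2)
             + s1 ** 2 / (2 * s2 ** 2) - 0.5 + np.log(s2 / s1))

# Grid approximation of the defining integral: int p(x) ln(p(x)/q(x)) dx
x = np.linspace(-15, 15, 300_001)
p, q = npdf(x, mu1, s1), npdf(x, mu2, s2)
kl_numeric = np.sum(p * np.log(p / q)) * (x[1] - x[0])
```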
6976
7077### Product of two normal densities
7178
@@ -75,20 +82,26 @@ The product of two normal densities ~cal(N)(mu_1, sigma_1^2)~ and ~cal(N)(mu_2,
7582
7683or more explicitly,
7784
78- ~ ! exists K_1, mu, sigma, forall x: cal(N)(mu_1, sigma_1^2)(x) times cal(N)(mu_2, sigma_2^2)(x) = K_1 times cal(N)(mu, sigma)(x)~
85+ <T block v = ' exists K_1, mu, sigma, forall x: \
86+ cal(N)(mu_1, sigma_1^2)(x) times cal(N)(mu_2, sigma_2^2)(x) = K_1 times cal(N)(mu, sigma)(x)
87+ ' />
7988
80- with ~ !mu = (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_2^2 mu_1 + sigma_1^2 mu_2) / (sigma_1^2 + sigma_2^2)~
81- and ~ !sigma^2 = 1 / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_1^2 sigma_2^2) / (sigma_1^2 + sigma_2^2)~
89+ and <T block v = '
90+ mu & = (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_2^2 mu_1 + sigma_1^2 mu_2) / (sigma_1^2 + sigma_2^2) \
91+ sigma^2 & = 1 / (1 / sigma_1^2 + 1 / sigma_2^2) = (sigma_1^2 sigma_2^2) / (sigma_1^2 + sigma_2^2)
92+ ' />
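We can check this numerically before proving it: the ratio between the product of the two densities and the density with the combined parameters should be a constant ~ K_1~ , independent of ~ x~ . A sketch (assuming `numpy`; arbitrary parameters):

```python
import numpy as np

mu1, s1 = 1.0, 0.7    # arbitrary parameters
mu2, s2 = -2.0, 1.5

def npdf(x, mu, s):
    """Density of N(mu, s^2) evaluated at x."""
    return np.exp(-(x - mu) ** 2 / (2 * s ** 2)) / (np.sqrt(2 * np.pi) * s)

# Parameters of the Gaussian appearing in the product (formulas above)
var = 1 / (1 / s1 ** 2 + 1 / s2 ** 2)
mu = var * (mu1 / s1 ** 2 + mu2 / s2 ** 2)

# The ratio should not depend on x: it is the constant K_1
x = np.linspace(-4, 4, 9)
ratio = npdf(x, mu1, s1) * npdf(x, mu2, s2) / npdf(x, mu, np.sqrt(var))
```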
8293
8394<details >
84- <summary >Proof</summary >
95+ <summary >Proof/Derivation </summary >
8596<div >
8697We work in log space, keeping only the terms that depend on ~ x~ (the rest is the normalization constant of a Gaussian).
87- More generally, the log of a gaussian density ~ cal(N)(mu, sigma^2)~ can be expanded as:
98+
99+ So, first, let's see that the log of a Gaussian density ~ cal(N)(mu, sigma^2)~ can be expanded as:
88100<T block v = '
89101ln(cal(N)(mu, sigma^2)(x)) & = -1/2 ((x - mu)^2 / sigma^2) + C_1 \
90102& = -1/2 (x^2 / sigma^2 - 2 (mu x) / sigma^2 + mu^2 / sigma^2) + C_1 \
91103& = -1/(2 sigma^2) (x^2 - 2 mu x) + C_2 \
104+ & = -1/2 (add(1/sigma^2) x^2 - 2 add(mu / sigma^2) x) + C_2 \
92105' />
93106
94107This formula is useful for identifying the parameters of a Gaussian density from its expanded log density.
@@ -100,8 +113,8 @@ For the product of two Gaussian densities, we have:
100113& "(developing)" \
101114& = -1/2 (x^2 / sigma_1^2 - 2 (mu_1 x) / sigma_1^2 + mu_1^2 / sigma_1^2 + x^2 / sigma_2^2 - 2 (mu_2 x) / sigma_2^2 + mu_2^2 / sigma_2^2) + C_3 \
102115& "(factorizing and pushing constant terms in the constant)" \
103- & = -1/2 (( 1 / sigma_1^2 + 1 / sigma_2^2) x^2 - 2 ( mu_1 / sigma_1^2 + mu_2 / sigma_2^2) x) + C_4 \
104- & = -1/2 (x^2 / sigma^2 - 2 (mu / sigma^2) x) + C_4 \
116+ & = -1/2 (add(( 1 / sigma_1^2 + 1 / sigma_2^2)) x^2 - 2 add(( mu_1 / sigma_1^2 + mu_2 / sigma_2^2) ) x) + C_4 \
117+ & = -1/2 ( add(1/ sigma^2) x^2 - 2 add (mu / sigma^2) x) + C_4 \
105118' />
106119We can identify ~ sigma^2~ directly, and then deduce ~ mu~ :
107120<T block v = '
@@ -113,9 +126,9 @@ mu & = sigma^2 (mu_1 / sigma_1^2 + mu_2 / sigma_2^2) = (sigma_2^2 mu_1 + sigma_1
113126
114127### Identities on normal mean
115128
116- - ~ cal(N)(mu , sigma^2)(x) = cal(N)(x, sigma^2)(mu )~
117- - ~ cal(N)(a mu, sigma^2)(x) = cal(N)(mu, sigma^2 / a^2)(x / a )~
118- - combining both: ~ cal(N)(a x, sigma^2)(mu) = cal(N)(mu / a , sigma^2 / a^2)(x )~
129+ - ~ cal(N)(add(mu) , sigma^2)(x) = cal(N)(x, sigma^2)(add(mu) )~
130+ - ~ cal(N)(add(a) mu, sigma^2)(x) = cal(N)(mu, sigma^2 / add( a^2)) (x / add(a) )~
131+ - combining both: ~ cal(N)(add( a x) , sigma^2)(mu) = cal(N)(mu / add(a) , sigma^2 / add( a^2))(add(x) )~
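These identities can be checked numerically. Note that the first one is an exact equality, while the last two hold up to a constant factor independent of the evaluation point, which is all we need since we only use them under proportionality. A sketch (assuming `numpy` and ~ a > 0~ ; values are arbitrary):

```python
import numpy as np

def npdf(x, mu, s):
    """Density of N(mu, s^2) evaluated at x."""
    return np.exp(-(x - mu) ** 2 / (2 * s ** 2)) / (np.sqrt(2 * np.pi) * s)

mu, s, a = 0.9, 1.3, 2.5          # arbitrary values, a > 0
x = np.linspace(-3, 3, 7)

# 1) exact: mean and evaluation point can be swapped
sym_lhs, sym_rhs = npdf(x, mu, s), npdf(mu, x, s)

# 2) and 3) hold up to an x-independent factor: the ratios are constant
r2 = npdf(x, a * mu, s) / npdf(x / a, mu, s / a)
r3 = npdf(mu, a * x, s) / npdf(x, mu / a, s / a)
```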
119132
120133
121134
@@ -133,7 +146,7 @@ The forward/noising process progressively adds noise to the data until it become
133146As shown in Fig. <Ref label = " fig:global-noise" />, the forward process starts from data points
134147~ forall i, x^i_0 tilde.op q(x_0)~ (e.g. images from the training set, also named ~ hat(p)_ "data"~ ).
135148The forward process progressively adds Gaussian noise to these data points, until they are completely shuffled and become close to pure noise.
136- The total process is run for a finite (but high) number of steps ~ T~ , and we obtain ~ forall i, x^i_T tilde.op cal(N)(0,I)~ .
149+ The total process is run for a finite (but high) number of steps ~ T~ , and we (almost) obtain ~ forall i, x^i_T tilde.op cal(N)(0,I)~ .
137150
138151<figure >
139152 <InlineSvg asset = " diffusion" hide = ' #FORWARD, #BACKWARD, #forward, #backward , #more, #bigkl' />
@@ -152,11 +165,12 @@ forall t in [1..T], x_t & tilde.op q(x_t | x_(t-1)) \
152165" />
153166
154167Even more precisely:
168+ - we suppose the original dataset has been normalized,
155169- we can decide on a variance addition schedule ~ beta_1, ..., beta_T~ saying how much noise variance to add at each step,
156170- as we add noise at each step, the distribution would spread more and more over time steps (~ t~ ) and would not converge to a Gaussian noise with identity covariance,
157171- to avoid this, we also rescale the signal at each step by a factor ~ sqrt(1 - beta_t)~ .
158172
159- The goal of the rescaling is to ensure that the variance at step ~ t~ is always ~ 1~ , whatever ~ t~ .
173+ The goal of the rescaling is to ensure that the variance of the dataset at step ~ t~ is always ~ 1~ , whatever ~ t~ .
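We can verify this variance-preserving effect numerically. A sketch (assuming `numpy`, a normalized i.e. unit-variance dataset, and an arbitrary schedule):

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 200_000, 50
betas = np.linspace(1e-4, 0.05, T)            # arbitrary illustrative schedule

# Normalized (unit-variance) toy dataset; not Gaussian on purpose
x = rng.uniform(-np.sqrt(3), np.sqrt(3), n)

variances = []
for beta in betas:
    # one forward step: rescale the signal, then add noise of variance beta
    x = np.sqrt(1 - beta) * x + np.sqrt(beta) * rng.standard_normal(n)
    variances.append(x.var())
```

The variance stays close to ~ 1~ at every step, even though the initial dataset is not Gaussian.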
160174Overall the forward/noising process is defined as:
161175
162176<T block v = "
@@ -173,13 +187,18 @@ forall t in [1..T], x_t & tilde.op q(x_t | x_(t-1)) = cal(N)(sqrt(1 - beta_t) x_
173187### Big-jump view
174188
175189Thanks to the properties of Gaussians, we can express the distribution at step ~ t~ as a function of the initial data point ~ x_0~ .
176- Indeed, since each step is a Gaussian distribution , the composition of all the steps is also a Gaussian distribution as shown in Fig. <Ref label = " fig:multi-step-noise" />.
190+ Indeed, since each step adds Gaussian noise (and rescales), the composition of all the steps is also a Gaussian distribution as shown in Fig. <Ref label = " fig:multi-step-noise" />.
177191More precisely, we have:
178192
179193<T block v = "
180194forall t in [1..T], x_t & tilde.op q(x_t | x_0) = cal(N)(sqrt(overline(alpha)_t) x_0, (1 - overline(alpha)_t) I) \
181195" />
182196
197+ with
198+ ~ overline(alpha)_ t := product_ (s=1)^t alpha_s~ the overall signal retained from the start to step ~ t~ ,
199+ in which ~ alpha_t := 1 - beta_t~ is the signal retention factor at step ~ t~ .
200+
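A numeric check of this big-jump formula (a `numpy` sketch with an arbitrary schedule): running the step-by-step forward process should produce the same distribution as sampling directly from ~ q(x_t | x_0)~ .

```python
import numpy as np

rng = np.random.default_rng(1)
T, n = 40, 200_000
betas = np.linspace(1e-3, 0.08, T)            # arbitrary schedule
alpha_bar = np.cumprod(1 - betas)             # overall signal retention

x0 = 1.7                                      # a fixed scalar data point
x = np.full(n, x0)
for beta in betas:
    # iterate the single-step forward process
    x = np.sqrt(1 - beta) * x + np.sqrt(beta) * rng.standard_normal(n)

# Big-jump prediction: x_T ~ N(sqrt(alpha_bar_T) x0, 1 - alpha_bar_T)
mean_pred = np.sqrt(alpha_bar[-1]) * x0
var_pred = 1 - alpha_bar[-1]
```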
201+
183202<figure >
184203 <InlineSvg asset = " diffusion" hide = ' #BACKWARD, #qtt0, #backward, #more, #bigkl' />
185204 <figcaption >[ <Counter label = " fig:multi-step-noise" />] Multi-step noising process.</figcaption >
@@ -197,13 +216,13 @@ This is illustrated in Fig. <Ref label="fig:sandwich-noise"/>:
197216More precisely, we have:
198217
199218<T block v = "
200- forall t in [1 ..T], q(x_(t-1) | x_t, x_0) & prop q(x_(t-1) | x_t) times q(x_(t-1) | x_0) \
219+ forall t in [2 ..T], q(x_(t-1) | x_t, x_0) & prop q(x_(t-1) | x_t) times q(x_(t-1) | x_0) \
201220" />
202221
203222We will show that we can derive a closed-form expression for this distribution:
204223
205224<T block v = '
206- forall t in [1 ..T], q(x_(t-1) | x_t, x_0) & prop cal(N)(tilde(mu), tilde(beta) I) \
225+ forall t in [2 ..T], q(x_(t-1) | x_t, x_0) & prop cal(N)(tilde(mu), tilde(beta) I) \
207226"with" \
208227tilde(mu)(x_0, t, x_t) & = (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t \
209228tilde(beta)(t) & = (1 - overline(alpha)_(t-1)) / (1 - overline(alpha)_t) beta_t \
@@ -212,14 +231,14 @@ tilde(beta)(t) & = (1 - overline(alpha)_(t-1)) / (1 - overline(alpha)_t) beta_t
212231The proof relies on showing that ~ q(x_ (t-1) | x_t)~ (the anti-forward step) is proportional to a Gaussian density, and then using the product-of-two-Gaussians property that we derived in the preliminaries.
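Before the derivation, a numeric check (`numpy` sketch, arbitrary schedule and values): combining the reversed forward step and the big-jump to step ~ t-1~ with the product-of-Gaussians formula should land exactly on ~ tilde(mu)~ and ~ tilde(beta)~ .

```python
import numpy as np

betas = np.linspace(1e-3, 0.08, 40)           # arbitrary schedule
alphas = 1 - betas
alpha_bar = np.cumprod(alphas)

t = 10                                        # some step (0-indexed into the arrays)
x0, xt = 0.8, -0.3                            # arbitrary values
at, bt = alphas[t], betas[t]
ab_t, ab_tm1 = alpha_bar[t], alpha_bar[t - 1]

# Reversed forward step, as a Gaussian over x_{t-1}: N(x_t / sqrt(alpha_t), beta_t / alpha_t)
mu1, v1 = xt / np.sqrt(at), bt / at
# Big-jump to step t-1: N(sqrt(ab_{t-1}) x0, 1 - ab_{t-1})
mu2, v2 = np.sqrt(ab_tm1) * x0, 1 - ab_tm1

# Product-of-Gaussians parameters (from the preliminaries)
v = 1 / (1 / v1 + 1 / v2)
mu = v * (mu1 / v1 + mu2 / v2)

# Closed-form sandwich parameters
mu_tilde = (np.sqrt(ab_tm1) * bt / (1 - ab_t) * x0
            + np.sqrt(at) * (1 - ab_tm1) / (1 - ab_t) * xt)
beta_tilde = (1 - ab_tm1) / (1 - ab_t) * bt
```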
213232
214233<details >
215- <summary >Derivation</summary >
234+ <summary >Proof/ Derivation</summary >
216235<div >
217236
218237Knowing ~ x_0~ , we can express the anti-forward step ~ q(x_ (t-1) | x_t, x_0)~ using the Bayes rule.
219- Remembering that this is a distribution over ~ x_ (t-1)~ , we can focus on the factors that depend on it (and thus drop $q(x_t)$ below):
238+ Remembering that this is a distribution over ~ x_ (t-1)~ , we can focus on the factors that depend on it (and thus drop ~ q(x_t | x_0)~ below):
220239
221240<T block v = '
222- forall t in [1 ..T], q(x_(t-1) | x_t, x_0) & = q(x_t | x_(t-1), x_0) q(x_(t-1) | x_0) / q(x_t | x_0) \
241+ forall t in [2 ..T], q(x_(t-1) | x_t, x_0) & = q(x_t | x_(t-1), x_0) q(x_(t-1) | x_0) / q(x_t | x_0) \
223242& prop q(x_t | x_(t-1)) q(x_(t-1) | x_0) \
224243& prop cal(N)(sqrt(1 - beta_t) x_(t-1), beta_t I)(x_t) times q(x_(t-1) | x_0) \
225244cal(N)(a x,sigma^2)(mu) => cal(N)(mu/a, sigma^2/a^2)(x) " "
@@ -295,8 +314,9 @@ This part details the key insight of how to transform a global KL loss into a su
295314We will show that we can simplify the global objective into a sum of local objectives, one for each step, with the proper conditioning.
296315We will follow a few steps:
297316- starting by writing the KL divergence between the two joint distributions of the Markov chains, in their natural direction (forward for the noising process, backward for the learned process),
298- - re-introducing an expectation on ~ x_0~ to make it the expression tractable (using the "sandwich view"),
317+ - re-introducing an expectation on ~ x_0~ to make the expression tractable,
299318- "reversing" the conditional forward noising process,
319+ - using the sandwich view,
300320- using the closed-form of the KL between Gaussians to get a final closed-form loss.
301321
302322The final terms involved in the loss are KL divergences between two Gaussian distributions, which can be computed in closed form.
@@ -310,7 +330,7 @@ A big part of these derivations are also detailed at the end of this page, in is
311331
312332<T block v = '
313- cal(L) & := "KL" (q(x_(0:T)) || p_theta (x_(0:T))) \
333+ cal(L) & := KL (q(x_(0:T)), p_theta (x_(0:T))) \
314334& = EE_(x_(0:T) tilde.op q(x_(0:T))) [ln (q(x_(0:T)) / (p_theta (x_(0:T))))] \
315335& = EE_(x_(0:T) tilde.op q(x_(0:T))) [ln ((q(x_0) product_(t=1)^T q(x_t | x_(t-1))) / (p_theta (x_T) product_(t=1)^T p_theta (x_(t-1) | x_t)))] \
316336& = EE_(x_(0:T) tilde.op q(x_(0:T))) [ln q(x_0) - ln p_theta (x_T) + sum_(t=1)^T ln q(x_t | x_(t-1)) - sum_(t=1)^T ln p_theta (x_(t-1) | x_t)] \
@@ -364,12 +384,14 @@ To avoid having to handle this special case below, we override ~tilde(mu)(x_0, t=
364384We thus get a closed-form expression for the loss, involving square norms (coming from the KL between gaussians).
365385
366386<T block v = '
367- cal(L) & = C + EE_(x_0 tilde.op q(x_0)) sum_(t=1)^T lambda_t ||tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t)||^2 \
368- & = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) ||tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t)||^2
387+ cal(L) & = C + EE_(x_0 tilde.op q(x_0)) sum_(t=1)^T EE_(x_t) lambda_t norm(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2 \
388+ & = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(x_t tilde.op q(x_t)) norm(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2 \
389+ & = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(x_t tilde.op q(x_t | x_0)) norm(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2
369390' />
370391
371- with ~ lambda_t = 1/(2 sigma_t^2)~ .
392+ with ~ lambda_t = 1/(2 sigma_t^2)~ , and where, as the last two lines show, ~ x_t~ can equivalently be sampled either marginally or conditioned on ~ x_0~ .
372393
394+ Overall, these derivations yield a simple loss: it is conditioned on the data point ~ x_0~ and is local in time (a sum over ~ t~ ).
373395
374396<figure >
375397 <InlineSvg asset = " diffusion" hide = ' #forward, #backward, #more, #bigkl, #qt, #qtt' />
@@ -394,57 +416,69 @@ with ~lambda_t = 1/(2 sigma_t^2)~.
394416
395417The above derivations reason in terms of fitting the means of the backward steps ~ mu_theta (x_t, t)~ .
396418
397- We can reparametrize the sampling of ~ x_t~ as a function of ~ x_0~ and some noise ~ epsilon~ :
419+ Based on the "big jump view" of the forward process, and the Gaussian properties,
420+ we can reparametrize the sampling of ~ x_t~ (conditioned on ~ x_0~ ) as a function of ~ x_0~ and some noise ~ epsilon~ :
398421
399422<T block v = '
400423x_t & = sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon \
401424epsilon & tilde.op cal(N)(0,I) \
402425' />
403426
404- We can further use the formula of ~ x_t~ to rewrite ~ tilde(mu)~ as a function of ~ x_0~ and ~ epsilon~ only (which will help re-interpret and reparametrize the loss):
427+ Inverting this relation gives:
428+ ~ !x_0 = 1 / sqrt(overline(alpha)_ t) x_t - sqrt((1 - overline(alpha)_ t) / overline(alpha)_ t) epsilon~ .
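A quick check of this inversion (`numpy` sketch, arbitrary values): reconstructing ~ x_0~ from ~ x_t~ and the very ~ epsilon~ that generated it is exact.

```python
import numpy as np

rng = np.random.default_rng(2)
alpha_bar_t = 0.37                            # arbitrary overall retention at some step t
x0 = rng.standard_normal(5)
eps = rng.standard_normal(5)

# Reparametrized sampling of x_t ...
x_t = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1 - alpha_bar_t) * eps
# ... and its inversion back to x_0
x0_rec = x_t / np.sqrt(alpha_bar_t) - np.sqrt((1 - alpha_bar_t) / alpha_bar_t) * eps
```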
405429
406- TODO redo/check exact simplifications
430+ We can plug this expression into ~ tilde(mu)~ so that it depends only on ~ x_t~ and ~ epsilon~ , not on ~ x_0~ :
407431
408432<T block v = '
409433tilde(mu)(x_0, t, x_t)
410- & = tilde(mu)(x_0, t, sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) \
411- & = (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) (sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) \
412- & = (sqrt(overline(alpha)_(t-1)) beta_t + sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(overline(alpha)_t)) / (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
413- & = 1 / sqrt(overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
414-
434+ & = (sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) (1 / sqrt(overline(alpha)_t) x_t - sqrt((1 - overline(alpha)_t) / overline(alpha)_t) epsilon)
435+ + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t \
436+ & = (sqrt(overline(alpha)_(t-1)) beta_t) / (sqrt(overline(alpha)_t) (1 - overline(alpha)_t)) x_t
437+ + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) x_t
438+ - (sqrt(overline(alpha)_(t-1)) beta_t sqrt((1 - overline(alpha)_t) / overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
439+ & = ( (sqrt(overline(alpha)_(t-1)) beta_t) / (sqrt(overline(alpha)_t) (1 - overline(alpha)_t)) + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) ) x_t
440+ - (sqrt(overline(alpha)_(t-1)) beta_t sqrt((1 - overline(alpha)_t) / overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon \
441+ & =: K_(x_t)(t) x_t + K_(epsilon)(t) epsilon \
415442' />
416443
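A numeric check of this decomposition (`numpy` sketch, arbitrary schedule and seed): evaluating ~ tilde(mu)~ directly and through ~ K_ (x_t)(t) x_t + K_ (epsilon)(t) epsilon~ gives the same value.

```python
import numpy as np

betas = np.linspace(1e-3, 0.08, 40)           # arbitrary schedule
alphas = 1 - betas
alpha_bar = np.cumprod(alphas)

t = 7                                         # some step (0-indexed into the arrays)
at, bt = alphas[t], betas[t]
ab_t, ab_tm1 = alpha_bar[t], alpha_bar[t - 1]

rng = np.random.default_rng(3)
x0, eps = rng.standard_normal(), rng.standard_normal()
x_t = np.sqrt(ab_t) * x0 + np.sqrt(1 - ab_t) * eps   # big-jump sampling

# Sandwich mean in terms of (x0, x_t)
mu_tilde = (np.sqrt(ab_tm1) * bt / (1 - ab_t) * x0
            + np.sqrt(at) * (1 - ab_tm1) / (1 - ab_t) * x_t)

# The two coefficients derived above
K_xt = (np.sqrt(ab_tm1) * bt / (np.sqrt(ab_t) * (1 - ab_t))
        + np.sqrt(at) * (1 - ab_tm1) / (1 - ab_t))
K_eps = -np.sqrt(ab_tm1) * bt * np.sqrt((1 - ab_t) / ab_t) / (1 - ab_t)
```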
444+ The constants can be refined further, but for now, let's look at the implications.
445+ Using the ~ epsilon~ -reparametrized sampling of ~ x_t~ and substituting the value of ~ tilde(mu)~ we just derived, we can rewrite the loss as:
446+ <T block v = '
447+ cal(L) & = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) norm(K_(x_t)(t) x_t + K_(epsilon)(t) epsilon - mu_theta (x_t, t))^2 \
448+ & = C + sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I))
449+ norm(K_(epsilon)(t) (epsilon - (mu_theta (x_t, t) - K_(x_t)(t) x_t) / (K_(epsilon)(t))))^2 \
450+ & = C + sum_(t=1)^T (lambda_t K^2_(epsilon)(t)) EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I))
451+ norm(epsilon - (mu_theta (x_t, t) - K_(x_t)(t) x_t) / (K_(epsilon)(t)))^2 \
452+ ' />
417453
418-
419-
420- This allows to rewrite the loss as:
454+ So, up to a time reweighting (~ gamma_t := lambda_t K^2_ (epsilon)(t)~ instead of ~ lambda_t~ ), and since ~ mu_theta~ takes ~ x_t~ and ~ t~ as inputs, we can equivalently train a network to predict the noise ~ epsilon~ that was used to generate ~ x_t~ from ~ x_0~ , instead of predicting ~ tilde(mu)~ .
455+ We can thus define a network ~ epsilon_theta (x_t, t)~ and train it to minimize the loss:
421456
422457<T block v = '
423- cal(L)
424- & <= sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [(tilde(mu)(x_0, t, x_t) - mu_theta (x_t, t))^2] + C \
425- & = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [(tilde(mu)(x_0, t, sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) - mu_theta (x_t))^2] + C \
458+ cal(L) & = C + sum_(t=1)^T gamma_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I))
459+ norm(epsilon - epsilon_theta (x_t, t))^2 \
426460' />
427461
428-
429- Where we developed ~ x_t~ in ~ tilde(mu)~ to be able to simplify the formula (but kept it in ~ mu_theta~ as we will keep it untouched).
430- By substituting the expression of ~ tilde(mu)~ , we get:
462+ With the two-way mapping between ~ mu_theta~ and ~ epsilon_theta~ being:
431463
432464<T block v = '
433- cal(L)
434- & <= sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [((sqrt(overline(alpha)_(t-1)) beta_t)/ (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1))) / (1 - overline(alpha)_t) (sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon) - mu_theta (x_t, t))^2] + C \
435- & = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [((sqrt(overline(alpha)_(t-1)) beta_t + sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(overline(alpha)_t)) / (1 - overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon - mu_theta (x_t, t))^2] + C \
436- & = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [(1 / sqrt(overline(alpha)_t) x_0 + (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) / (1 - overline(alpha)_t) epsilon - mu_theta (x_t, t))^2] + C \
437- & = sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [epsilon - ( (1 - overline(alpha)_t) / (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) (mu_theta (x_t, t) - 1 / sqrt(overline(alpha)_t) x_0) )^2] + C \
465+ mu_theta (x_t, t) & = K_(x_t)(t) x_t + K_(epsilon)(t) epsilon_theta (x_t, t) \
466+ epsilon_theta (x_t, t) & = (mu_theta (x_t, t) - K_(x_t)(t) x_t) / (K_(epsilon)(t)) \
438467' />
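Putting the pieces together, the noise-prediction objective can be sketched in `numpy` (everything here is illustrative: `eps_theta` is a hypothetical stand-in for a neural network, the schedule is arbitrary, and we sample ~ t~ uniformly with the weighting ~ gamma_t~ dropped, as is commonly done):

```python
import numpy as np

rng = np.random.default_rng(4)
T = 100
betas = np.linspace(1e-4, 0.02, T)            # illustrative schedule
alpha_bar = np.cumprod(1 - betas)

def eps_theta(x_t, t, w):
    # Hypothetical stand-in predictor: a per-step linear function of x_t.
    # A real model would be a neural network taking (x_t, t).
    return w[t] * x_t

def loss_estimate(w, data):
    # One stochastic estimate of the loss: sample t, x_0 and epsilon,
    # build x_t by the big-jump formula, and score the noise prediction.
    t = rng.integers(0, T, size=len(data))
    eps = rng.standard_normal(len(data))
    x_t = np.sqrt(alpha_bar[t]) * data + np.sqrt(1 - alpha_bar[t]) * eps
    return np.mean((eps - eps_theta(x_t, t, w)) ** 2)

data = rng.uniform(-1.0, 1.0, 4096)           # normalized toy "dataset"
w = np.zeros(T)                               # the zero predictor
base = loss_estimate(w, data)                 # close to E||eps||^2 = 1
```

With the zero predictor, the loss sits close to ~ 1~ (the variance of ~ epsilon~ ); any optimizer over the predictor's parameters should then push it below that baseline.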
439468
469+ ## Aside: link with flow matching
440470
471+ Looking at the training algorithms, we can uncover the link with flow matching.
472+ Conceptually, both sample a time step (although with different semantics), a data point, and a unit Gaussian noise.
473+ However, they differ in the path, i.e., the formula for ~ x_t~ :
474+ - diffusion aims at preserving the variance across time,
475+ - flow matching (in its typical form) aims at a linear interpolation between data and noise.
476+
477+ We can, however, instantiate a flow matching variant whose path matches the diffusion path.
478+ The similarities/differences then reduce to the time weighting and the quantity being fit.
479+ The details are left out for now.
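The difference in paths can be visualized numerically (`numpy` sketch, arbitrary schedule, unit-variance toy data): the diffusion path keeps the standard deviation of ~ x_t~ near ~ 1~ at all steps, while the linear interpolation path dips below it.

```python
import numpy as np

rng = np.random.default_rng(5)
n = 100_000
x0 = rng.uniform(-np.sqrt(3), np.sqrt(3), n)  # unit-variance toy data
eps = rng.standard_normal(n)

T = 100
betas = np.linspace(1e-4, 0.05, T)
alpha_bar = np.cumprod(1 - betas)

# Diffusion path: variance-preserving mixture of data and noise
diff_std = [np.std(np.sqrt(ab) * x0 + np.sqrt(1 - ab) * eps) for ab in alpha_bar]

# Typical flow-matching path: linear interpolation between data and noise
ts = np.linspace(0.0, 1.0, T)
fm_std = [np.std((1 - t) * x0 + t * eps) for t in ts]
```

The interpolation path's standard deviation bottoms out around ~ sqrt(0.5) approx 0.71~ at the midpoint, while the diffusion path stays flat at ~ 1~ .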
441480
442- Substituting to get rid of ~ x_0~ and use only ~ x_t~ as a parameter of ~ epsilon_theta~ , we get:
443481
444- <T block v = '
445- cal(L)
446- & <= sum_(t=1)^T lambda_t EE_(x_0 tilde.op q(x_0)) EE_(epsilon tilde.op cal(N)(0,I)) [ ( epsilon - ( (1 - overline(alpha)_t) / (sqrt(alpha_t) (1- overline(alpha)_(t-1)) sqrt(1 - overline(alpha)_t)) (mu_theta (sqrt(overline(alpha)_t) x_0 + sqrt(1 - overline(alpha)_t) epsilon, t) - 1 / sqrt(overline(alpha)_t) x_0) )^2] + C \
447- ' />
448482
449483## Aside: Bayes rule over a Markov chain
450484