# Diffusion Models

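Diffusion models are generative models that learn to reverse a diffusion process: a forward process gradually corrupts data with noise, and the model learns to undo that corruption one step at a time.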

## Stochastic Processes And Diffusion

To understand diffusion models, we first need to understand stochastic processes. A stochastic process is a family of random variables $\{X_t\}$, where $t \in T$ is called the *index*. Very often the index $t$ is interpreted as *time*, although this is not required. A (discrete-time) Markov process is a stochastic process in which $X_t$, conditioned on the entire past, depends only on the previous value $X_{t-1}$. A diffusion process is one type of Markov process, which for simplicity we will not define rigorously here. A typical example is the Wiener process, which is used to model Brownian motion: at each time step, the increment $X_t - X_{t-1}$ is Gaussian with zero mean and a known variance (for the standard Wiener process, equal to the length of the time step).
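
A minimal sketch of the Wiener process described above, assuming a fixed step size `dt`: each increment is drawn from $\mathcal{N}(0, dt)$ and the path is the running sum of the increments. The function name `simulate_wiener` and its parameters are illustrative, not part of the original notebook.

```python
import numpy as np

def simulate_wiener(n_steps, dt=0.01, rng=None):
    """Simulate one path of a standard Wiener process on a uniform time grid.

    Returns an array of length n_steps + 1 that starts at W_0 = 0.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Independent Gaussian increments: W_t - W_{t-1} ~ N(0, dt).
    increments = rng.normal(loc=0.0, scale=np.sqrt(dt), size=n_steps)
    # Each value is the previous value plus fresh noise (the Markov property).
    return np.concatenate([[0.0], np.cumsum(increments)])

path = simulate_wiener(1000, dt=0.01, rng=np.random.default_rng(0))
print(path.shape)  # (1001,)
```

Because the increments are independent with variance `dt`, the value after `n_steps` steps has variance `n_steps * dt`.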

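By the Markov property, the joint distribution of a chain $x_0, \dots, x_T$ factorizes into one-step transitions:

$$q(x_0,\dots,x_T) = q(x_0)\prod_{i = 1}^{T} q(x_i \mid x_{i-1})$$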
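
To make the factorization concrete, here is a sketch that evaluates $\log q(x_0, \dots, x_T)$ for a scalar Gaussian chain by summing one-step log-densities. It assumes a prior $q(x_0) = \mathcal{N}(0, 1)$ and transitions $q(x_i \mid x_{i-1}) = \mathcal{N}(\sqrt{1-\beta_i}\,x_{i-1}, \beta_i)$ with $\beta_i = i/T$; the helpers `gaussian_logpdf` and `chain_log_density` are hypothetical names.

```python
import numpy as np

def gaussian_logpdf(x, mean, var):
    """Log-density of N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def chain_log_density(xs, betas):
    """log q(x_0, ..., x_T) = log q(x_0) + sum_i log q(x_i | x_{i-1}).

    xs has length T + 1 (x_0 ... x_T); betas has length T (beta_1 ... beta_T).
    """
    log_p = gaussian_logpdf(xs[0], 0.0, 1.0)  # assumed prior q(x_0) = N(0, 1)
    for i in range(1, len(xs)):
        beta = betas[i - 1]
        # One Markov transition: x_i depends only on x_{i-1}.
        log_p += gaussian_logpdf(xs[i], np.sqrt(1 - beta) * xs[i - 1], beta)
    return log_p

T = 5
betas = np.arange(1, T + 1) / T  # beta_i = i / T
xs = np.zeros(T + 1)             # an all-zero trajectory x_0 = ... = x_T = 0
print(chain_log_density(xs, betas))
```

Working in log space turns the product over transitions into a sum, which avoids numerical underflow for long chains.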

```python
import numpy as np
```


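```python
def sample_from_data_distribution(n=1000):
    # Toy stand-in for the data distribution: n i.i.d. draws from N(0, 1).
    # Returning an array (not a single float) lets the later cells use .shape;
    # the default size n=1000 is an arbitrary choice.
    return np.random.randn(n)
```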

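```python
def beta_schedule(t, T):
    # Simple linear beta schedule: beta_t = t / T rises from 1/T at t = 1 to 1 at t = T,
    # so the final step replaces the signal with pure noise.
    return t / T
```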
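
As a rough sketch of what this schedule does to the signal, assuming the linear $\beta_t = t/T$ above, one can compute $\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)$: after $t$ steps the contribution of $x_0$ is scaled by $\sqrt{\bar\alpha_t}$. The helper name `alpha_bar` is hypothetical.

```python
import numpy as np

def alpha_bar(T):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s) for beta_s = s / T."""
    betas = np.arange(1, T + 1) / T  # linear schedule, same values as beta_schedule(t, T)
    return np.cumprod(1 - betas)     # entries alpha_bar_1 ... alpha_bar_T

ab = alpha_bar(1000)
print(ab[0], ab[99], ab[-1])
```

With $T = 1000$, $\bar\alpha_t$ falls below $0.01$ by about step 100 and is exactly $0$ at $t = T$ (because $\beta_T = 1$), so this schedule erases most of the signal early. Commonly used schedules, such as the linear schedule from $10^{-4}$ to $0.02$ in DDPM, keep $\beta_t$ much smaller.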

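```python
def corrupt(x_prev, t, T):
    # One forward step: sample x_t ~ q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I).
    beta = beta_schedule(t, T)
    mean = x_prev * np.sqrt(1 - beta)  # Gradually reduce the influence of x_{t-1}.
    variance = beta

    # Standard normal noise must be scaled by the standard deviation sqrt(beta), not the variance.
    return mean + np.sqrt(variance) * np.random.randn(*x_prev.shape)
```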
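
The transition that `corrupt` is meant to implement can be written as follows, where $\varepsilon$ is standard normal noise independent of $x_{t-1}$. Scaling the mean by $\sqrt{1-\beta_t}$ and adding noise with variance $\beta_t$ (standard deviation $\sqrt{\beta_t}$) keeps a unit-variance input at unit variance, which is why the noise is multiplied by $\sqrt{\beta_t}$ rather than $\beta_t$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right), \qquad x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)$$

$$\operatorname{Var}(x_t) = (1-\beta_t)\operatorname{Var}(x_{t-1}) + \beta_t = 1 \quad \text{whenever } \operatorname{Var}(x_{t-1}) = 1$$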

```python
def forward_process(x0, T):
    # Markov chain forward process x_0 -> x_1 -> ... -> x_T; yields each x_t in turn.
    xt = x0
    for t in range(1, T + 1):
        xt = corrupt(xt, t, T)  # x_t ~ q(x_t | x_{t-1})
        yield xt
```
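
Because each step is Gaussian, $t$ steps compose into a single Gaussian, $q(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\,x_0, (1-\bar\alpha_t) I)$ with $\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)$, so $x_t$ can be sampled directly without running the chain. The sketch below assumes the linear schedule $\beta_t = t/T$ and the transition $\mathcal{N}(\sqrt{1-\beta_t}\,x_{t-1}, \beta_t)$; the helpers `step` and `sample_xt_given_x0` are hypothetical names. It compares the empirical mean and variance of both methods for a fixed $x_0$.

```python
import numpy as np

def step(x_prev, t, T, rng):
    """One forward step x_t ~ N(sqrt(1 - beta_t) x_{t-1}, beta_t), with beta_t = t / T."""
    beta = t / T
    return np.sqrt(1 - beta) * x_prev + np.sqrt(beta) * rng.standard_normal(x_prev.shape)

def sample_xt_given_x0(x0, t, T, rng):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, 1 - alpha_bar_t) in one shot."""
    betas = np.arange(1, t + 1) / T
    alpha_bar = np.prod(1 - betas)
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1 - alpha_bar) * rng.standard_normal(x0.shape)

rng = np.random.default_rng(0)
T, t = 1000, 20
x0 = np.full(100_000, 2.0)  # many copies of the same starting point x_0 = 2

x_iter = x0
for s in range(1, t + 1):
    x_iter = step(x_iter, s, T, rng)  # run the chain for t steps
x_direct = sample_xt_given_x0(x0, t, T, rng)

print(x_iter.mean(), x_direct.mean())  # both close to 2 * sqrt(alpha_bar_t)
print(x_iter.var(), x_direct.var())    # both close to 1 - alpha_bar_t
```

Sampling $x_t$ in one shot is what makes it cheap to draw training pairs at random time steps, since the chain never has to be simulated step by step.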

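```python
a = np.random.randn(1000)  # A batch of toy data samples x_0 ~ N(0, 1); the batch size is arbitrary.

# beta_T = 1, so the last step returns pure N(0, 1) noise: the sample variance of x_T should be close to 1.
list(forward_process(a, 1000))[-1].var()
```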

1.0687883047633855

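```python
def reverse_process(xT, T):
    # Reverse Markov chain x_T -> x_{T-1} -> ... -> x_0; yields each x_{t-1} in turn.
    xt = xT
    for t in range(T, 0, -1):
        xt = uncorrupt(xt, t, T)  # x_{t-1} ~ p(x_{t-1} | x_t)
        yield xt
```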

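```python
def uncorrupt(xt, t, T):
    # To be approximated with a neural network.
    # The network predicts the mean (and possibly the variance) of the posterior on
    # x_{t-1} given x_t and t. For small beta_t, the reverse of a Gaussian diffusion
    # step is approximately Gaussian, which justifies a Gaussian parameterization.
    # We then sample from that Gaussian to obtain a likely x_{t-1}.
    raise NotImplementedError("uncorrupt must be implemented with a trained model")
```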
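
For the toy data distribution used here (standard normal $x_0$), the reverse conditional can be written down exactly, which gives a way to exercise the reverse chain before any network is trained. With unit-variance data and the variance-preserving step $x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\varepsilon$, every marginal $x_t$ is $\mathcal{N}(0, 1)$, and conditioning the jointly Gaussian pair $(x_{t-1}, x_t)$ gives $q(x_{t-1} \mid x_t) = \mathcal{N}(\sqrt{1-\beta_t}\,x_t, \beta_t)$. The sketch below, with hypothetical names `uncorrupt_oracle` and `run_reverse`, plays the role a trained network would: it produces a mean and a variance, then samples. It holds only under this toy assumption; for real data the mean and variance must be learned.

```python
import numpy as np

def uncorrupt_oracle(xt, t, T, rng):
    """Exact reverse step x_{t-1} ~ q(x_{t-1} | x_t), valid only when x_0 ~ N(0, 1) and beta_t = t / T.

    Stand-in for the neural network: compute a mean and a variance, then sample.
    """
    beta = t / T
    mean = np.sqrt(1 - beta) * xt
    variance = beta
    return mean + np.sqrt(variance) * rng.standard_normal(xt.shape)

def run_reverse(xT, T, rng):
    """Apply the reverse chain x_T -> x_{T-1} -> ... -> x_0 and return x_0."""
    xt = xT
    for t in range(T, 0, -1):
        xt = uncorrupt_oracle(xt, t, T, rng)
    return xt

rng = np.random.default_rng(0)
xT = rng.standard_normal(50_000)  # start from pure noise x_T ~ N(0, 1)
x0 = run_reverse(xT, 1000, rng)
print(x0.mean(), x0.var())        # should be close to 0 and 1, the toy data statistics
```

At $t = T$ the oracle ignores its input entirely (since $\beta_T = 1$), mirroring the forward step that turned $x_{T-1}$ into pure noise.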