# Visualizing Diffusion

## Introduction

One of the key ideas behind diffusion models is the forward diffusion process.
By adding Gaussian noise to an input over a series of timesteps,
we gradually move the initial image into a latent space whose distribution approaches a Gaussian.

In this notebook, we'll see the effects of different choices of variance schedule.
The variance schedule determines the variance of the noise we add to the image at each timestep.
What happens when we choose higher variances? What about lower ones?
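For reference, here is how a variance schedule enters the standard DDPM formulation, writing $\beta_t$ for the $t$-th entry of the schedule. This parameterization is an assumption. If the formulas in problem (1) of the homework differ, use those instead.

$$
q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\!\left(\mathbf{x}_t;\ \sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\ \beta_t \mathbf{I}\right),
\qquad
q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\!\left(\mathbf{x}_t;\ \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ (1-\bar{\alpha}_t)\mathbf{I}\right),
\qquad
\bar{\alpha}_t = \prod_{s=1}^{t} (1-\beta_s).
$$

Each factor $1-\beta_s$ lies below 1, so $\bar{\alpha}_t$ shrinks as $t$ grows. It shrinks faster when the $\beta_s$ are larger. As a result, less of $\mathbf{x}_0$ survives, and $\mathbf{x}_t$ looks more like a sample from $\mathcal{N}(\mathbf{0}, \mathbf{I})$.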

We'll explore the effects of these choices by visualizing the state throughout the diffusion process.
Along the way, we'll try to get a sense of how they affect the rate at which the state moves from
the original image (or distribution) toward the latent Gaussian distribution.

## Library imports

Before you begin, you'll need to set up your environment.

### If you're running locally

To set up a new environment that contains the necessary libraries,
run the startup script with `bash startup.sh`.

Double check that you are now in the `env-proj` conda environment that was created.
If not, run `conda activate env-proj` in your terminal.

### If you're running on Colab

Please change `%matplotlib widget` to `%matplotlib inline` in the following cell,
since Colab does not currently seem to support `ipympl` (the `widget` backend) out of the box.
(It may be possible to use `widget` mode on Colab, but that likely requires extra setup.)
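If you switch between running locally and on Colab, you can choose the mode automatically instead of editing the cell each time. The sketch below relies on a common heuristic: it checks whether the `google.colab` module has already been imported. The helper name `matplotlib_mode` is hypothetical and is only for illustration.

```python
import sys

def matplotlib_mode() -> str:
    """Return the %matplotlib mode to use: "inline" on Colab, "widget" elsewhere."""
    # Heuristic: Colab notebook sessions already have `google.colab` imported.
    return "inline" if "google.colab" in sys.modules else "widget"

# In a notebook cell, this could replace the hard-coded magic line:
# get_ipython().run_line_magic("matplotlib", matplotlib_mode())
```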

### Double check

Make sure that the following libraries have been installed.
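You can check for the libraries before running the import cells. The sketch below uses the standard library's `importlib.util.find_spec` to list any packages it cannot find. The helper name `missing_packages` is hypothetical.

```python
import importlib.util
from typing import List

def missing_packages(names: List[str]) -> List[str]:
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# ipympl provides the `%matplotlib widget` backend. If you're on Colab and using
# `%matplotlib inline`, you can drop it from this list.
print(missing_packages(["jax", "matplotlib", "ipympl"]))
```

An empty list means every package was found.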

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from typing import List

In [None]:
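import jax
import jax.numpy as jnp
from jax import jit, random, Array

# Raise an error as soon as any JAX operation produces a NaN.
jax.config.update("jax_debug_nans", True)
# jax.Array is the default array type in recent JAX releases,
# so the old "jax_array" option is no longer needed.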

## Part A: Mean and Covariance of Diffusion

Implement the forward diffusion step according to the formulas specified in problem (1) of the homework.
Specifically, compute the mean and covariance of the Normal distribution from which the next state is sampled, given the current (flattened) state `x` and the variance `var` for this timestep.
Then run the cells immediately following to test your implementation.

**Note:**
If you run into issues or errors that mention `jit`,
feel free to comment out the `@jit` decorator at the top of the functions.
The `@jit` decorator is not needed for correctness, but may help speed up the code.
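You can also turn compilation off temporarily without editing the decorators. `jax.disable_jit()` is a context manager: inside it, jitted functions run eagerly, op by op. The function `sum_of_squares` below is only a toy example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Inside this block, jitted functions are not compiled, so Python tracebacks
# and print statements inside them behave as they would in plain Python.
with jax.disable_jit():
    result = sum_of_squares(jnp.arange(3.0))

print(result)  # 0 + 1 + 4 = 5
```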

In [None]:
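def compute_mean(x: Array, var: Array) -> Array:
    """Mean of the next state, given flattened x of shape (n,); returns shape (n,)."""
    mean = None # REPLACE WITH YOUR CODE HERE
    return mean

def compute_cov(x: Array, var: Array) -> Array:
    """Covariance of the next state, given flattened x of shape (n,); returns shape (n, n)."""
    n = x.shape[0]
    cov = None # REPLACE WITH YOUR CODE HERE
    return cov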

### Test your implementation

The following cells will test your implementation of the mean and covariance functions.
While there may be slight differences due to floating point accuracy,
your implementation should fall within the allowed error range.

In [None]:
from test_cases import run_diffusion_mean_tests, run_diffusion_cov_tests
run_diffusion_mean_tests(compute_mean)
run_diffusion_cov_tests(compute_cov)
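To compare your outputs by hand against values you worked out yourself, a tolerance-based check is more robust than exact equality. The helper `assert_close` below is hypothetical, and its default tolerances are illustrative. The tolerances used in `test_cases` may differ.

```python
import jax.numpy as jnp

def assert_close(actual, expected, atol=1e-5, rtol=1e-5):
    """Raise AssertionError if the arrays differ in shape or beyond the tolerance."""
    actual = jnp.asarray(actual)
    expected = jnp.asarray(expected)
    if actual.shape != expected.shape:
        raise AssertionError(f"shape {actual.shape} != {expected.shape}")
    if not bool(jnp.allclose(actual, expected, atol=atol, rtol=rtol)):
        max_err = float(jnp.max(jnp.abs(actual - expected)))
        raise AssertionError(f"max abs error {max_err:.2e} exceeds tolerance")

# A difference at the level of floating-point rounding passes the check.
assert_close(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0 + 1e-7]))
```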

## Visualizing states over time

Now that we've verified the mean and covariance calculations,
let's visualize how the states get diffused over time.

First, `diffuse` advances the (flattened) state by a single timestep.
It does this by sampling from the multivariate Gaussian whose mean and covariance are given by `compute_mean` and `compute_cov`.
`diffuse_over_time` then repeats this step once for each entry of the variance schedule
and collects the state after every step.

In [None]:
@jit
def diffuse(key, x: Array, var: Array) -> Array:
    """
    Given (flattened) x, sample x diffused with Gaussian noise
    whose variance is `var` for this timestep.
    """
    mean = compute_mean(x, var)
    cov = compute_cov(x, var)
    return random.multivariate_normal(key, mean, cov)

@jit
def diffuse_over_time(key, x: Array, var_schedule: Array) -> List[Array]:
    """
    Diffuse x once per entry of var_schedule and return the states
    [x_0, x_1, ..., x_T], each with the same shape as x.
    """
    states = [x]
    shape = x.shape
    x = x.flatten()
    # The schedule length is static under jit, so this Python loop is unrolled.
    for t in range(var_schedule.shape[0]):
        # Split the key so that every timestep draws fresh noise.
        key, subkey = random.split(key)
        x = diffuse(subkey, x, var_schedule[t])
        states.append(x.reshape(shape))
    return states
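`diffuse_over_time` unrolls its loop: under `jit`, it traces one copy of `diffuse` per timestep. For long schedules, `jax.lax.scan` is a common alternative, because its loop body is traced only once. This technique is swapped in here; the notebook does not use it. The sketch takes the per-step function as an argument so it is self-contained. The names `diffuse_over_time_scan` and `StepFn` are hypothetical.

```python
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array, random

# Assumed per-step signature: (key, flattened x, var) -> next flattened x.
StepFn = Callable[[Array, Array, Array], Array]

def diffuse_over_time_scan(step_fn: StepFn, key: Array, x: Array, var_schedule: Array) -> Array:
    """Return a stacked array of shape (T + 1, *x.shape) holding x_0, ..., x_T."""
    shape = x.shape
    # One key per timestep, generated up front instead of splitting inside the loop.
    step_keys = random.split(key, var_schedule.shape[0])

    def body(x_flat, inputs):
        subkey, var = inputs
        x_next = step_fn(subkey, x_flat, var)
        # Carry the flattened state forward; emit the reshaped state as this step's output.
        return x_next, x_next.reshape(shape)

    _, later_states = jax.lax.scan(body, x.flatten(), (step_keys, var_schedule))
    # Prepend the initial state so the result lines up with diffuse_over_time.
    return jnp.concatenate([x[None], later_states], axis=0)
```

With the notebook's functions, you could call `diffuse_over_time_scan(diffuse, key, x, var_schedule)` and then pass `list(states)` to `show_plots`. This assumes `show_plots` accepts a list of arrays, as it does in the next section. The keys are generated differently from `diffuse_over_time`, so the samples will not match it exactly for the same seed.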


With these functions defined, we can now visualize how the identity matrix gets diffused over time,
according to each variance schedule.

In [None]:
from utils import show_plots

In [None]:
x = jnp.eye(3)

var_schedules = [
    jnp.array([0.01, 0.01, 0.01]),
    jnp.array([0.1, 0.2, 0.5]),
    jnp.array([0.5, 0.2, 0.1]),
    jnp.array([0.99, 0.99, 0.99]),
]

key = random.PRNGKey(0)
for var_schedule in var_schedules:
    states = diffuse_over_time(key, x, var_schedule)
    show_plots(states)
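To connect these plots to the schedules numerically, we can assume the standard DDPM forward process, in which $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon$ with $\bar\alpha_t = \prod_{s \le t}(1-\beta_s)$. The coefficient $\sqrt{\bar\alpha_t}$ then measures how much of the original input survives after step $t$. This parameterization is an assumption, so if problem (1) of the homework uses a different one, the numbers will differ. The helper name `signal_coefficient` is hypothetical.

```python
import jax.numpy as jnp

def signal_coefficient(var_schedule):
    """sqrt(alpha_bar_t) for t = 1..T, assuming the DDPM forward process.

    alpha_bar_t is the running product of (1 - beta_s) over the schedule.
    """
    return jnp.sqrt(jnp.cumprod(1.0 - var_schedule))

schedules = {
    "constant 0.01": jnp.array([0.01, 0.01, 0.01]),
    "increasing": jnp.array([0.1, 0.2, 0.5]),
    "decreasing": jnp.array([0.5, 0.2, 0.1]),
    "constant 0.99": jnp.array([0.99, 0.99, 0.99]),
}
for name, schedule in schedules.items():
    print(f"{name:>13}: {signal_coefficient(schedule)}")
```

Under this assumption, the increasing and decreasing schedules both end at $\sqrt{0.36} = 0.6$, but they take different paths. After the first step, the decreasing schedule keeps about 0.71 of the signal, while the increasing one keeps about 0.95. The constant 0.99 schedule leaves a coefficient of only about 0.001 after three steps.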

Now, can you come up with an interesting input `x` and one or more interesting variance schedules?
**Define your input and add variance schedules below.**

You can then run the cell to visualize your input being diffused
according to the variance schedules you proposed.

What do you observe?
What happens when you make the variances very high?
How about very low?
What if you alternate?
Can you still see traces of the original input's patterns in the final diffused state?
What about in the intermediate diffused states?
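If you'd like a starting point, here is one possible set of choices. The checkerboard input and the specific variance values are arbitrary examples. The linear ramp borrows its endpoints, 1e-4 and 0.02, from the linear schedule in the original DDPM paper, but uses far fewer steps. The names `checkerboard` and `example_schedules` are hypothetical.

```python
import jax.numpy as jnp

size = 8
# An 8x8 checkerboard of 0s and 1s: a strong, regular pattern that is easy to spot.
rows, cols = jnp.indices((size, size))
checkerboard = ((rows + cols) % 2).astype(jnp.float32)

num_steps = 10
example_schedules = [
    jnp.full((num_steps,), 0.001),                       # very low variance at every step
    jnp.full((num_steps,), 0.9),                         # very high variance at every step
    jnp.tile(jnp.array([0.01, 0.5]), num_steps // 2),    # alternate low and high
    jnp.linspace(1e-4, 0.02, num_steps),                 # linearly increasing ramp
]
```

You could then set `x = checkerboard` and `var_schedules = example_schedules` in the cell below. Note that an 8x8 input produces a 64x64 covariance matrix at each step.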

In [None]:
x = None # REPLACE WITH YOUR SOLUTION HERE

var_schedules = [
    # YOUR CODE HERE
]

key = random.PRNGKey(1)
for var_schedule in var_schedules:
    states = diffuse_over_time(key, x, var_schedule)
    show_plots(states)

## (Optional) Comparing the JIT and non-JIT versions of the diffusion process

You may have noticed the `@jit` decorator above the `diffuse` function earlier in this notebook.
`jax` provides `jit` to compile user-defined functions so that they run faster.
Here, we'll compare the performance of `diffuse` and its non-JIT version, `diffuse_nojit`.

Run the following cell to see the difference in speed between `diffuse`,
which is compiled with `jit`, and `diffuse_nojit`, which is not.

In [None]:
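# An interesting comparison in timing between
# non-JIT and JIT versions of the diffuse function

def diffuse_nojit(key, x: Array, var: Array) -> Array:
    """
    Given (flattened) x, sample x diffused with Gaussian noise
    whose variance is `var` for this timestep.
    """
    mean = compute_mean(x, var)
    cov = compute_cov(x, var)
    return random.multivariate_normal(key, mean, cov)

key = random.PRNGKey(0)
x = jnp.eye(3).flatten()

# Call the JIT version once so that compilation isn't part of the timing.
diffuse(key, x, 0.01).block_until_ready()

# JAX dispatches work asynchronously, so wait for each result with
# block_until_ready() to time the computation itself.
print("Timing for JIT diffuse:")
%timeit diffuse(key, x, 0.01).block_until_ready()

print("Timing for non-JIT diffuse:")
%timeit diffuse_nojit(key, x, 0.01).block_until_ready()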
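`%timeit` is an IPython magic. Outside a notebook, a minimal sketch with `time.perf_counter` can separate the first call from later calls. For a jitted function, the first call includes tracing and compilation. `toy_fn` is a hypothetical stand-in workload, not the notebook's `diffuse`, and `time_call` is likewise a hypothetical helper.

```python
import time
from typing import Callable, Tuple

import jax
import jax.numpy as jnp

def toy_fn(x):
    # Hypothetical stand-in workload: a matrix product followed by elementwise ops.
    return jnp.tanh(x @ x.T) + jnp.sin(x).sum()

def time_call(fn: Callable, *args, repeats: int = 100) -> Tuple[float, float]:
    """Return (first-call seconds, mean seconds per later call) for fn(*args)."""
    start = time.perf_counter()
    fn(*args).block_until_ready()  # for a jitted fn, this includes compilation
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for each result; otherwise only JAX's async dispatch would be timed.
        fn(*args).block_until_ready()
    steady = (time.perf_counter() - start) / repeats
    return first, steady

x_timing = jnp.ones((64, 64))
for label, fn in [("jit", jax.jit(toy_fn)), ("no jit", toy_fn)]:
    first, steady = time_call(fn, x_timing)
    print(f"{label:>6}: first call {first * 1e3:.2f} ms, then {steady * 1e6:.1f} us/call")
```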

When diffusion is used to train larger models, the time spent in each forward diffusion step
adds up over long training loops.
Since the forward diffusion process is applied to every input during training,
even a small speedup for an inner function like `diffuse` can pay off in the long run.
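During training, the forward process is usually applied to a whole batch of inputs at once. One way to do this in JAX is `jax.vmap`, which the notebook itself does not use. It vectorizes a per-example function over a batch axis, and wrapping the result in `jax.jit` compiles it once. `add_noise` is a hypothetical stand-in for `diffuse` that simply adds isotropic Gaussian noise.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x, var):
    # Hypothetical stand-in for `diffuse`: adds isotropic Gaussian noise with variance `var`.
    return x + jnp.sqrt(var) * jax.random.normal(key, x.shape)

# vmap maps over the leading (batch) axis of `key` and `x`; `var` is shared (in_axes=None).
# jit then compiles the batched function once and reuses it on later calls.
batched_add_noise = jax.jit(jax.vmap(add_noise, in_axes=(0, 0, None)))

batch = jnp.zeros((32, 9))                                            # 32 flattened 3x3 inputs
batch_keys = jax.random.split(jax.random.PRNGKey(0), batch.shape[0])  # one key per input
noisy_batch = batched_add_noise(batch_keys, batch, 0.01)
print(noisy_batch.shape)  # (32, 9)
```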

However, since our homework deals only with small, toy examples,
the speedup isn't necessary here; it's simply interesting to note.
The difference may also be more or less significant depending on whether you're
running this locally or in Colab (where you and `jax` can take advantage of
accelerators such as GPUs for even faster computation).
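To check whether `jax` can see an accelerator in your current session, you can ask it which backend and devices it is using. On Colab, a GPU is only available after you select one under Runtime > Change runtime type.

```python
import jax

# The platform JAX uses by default; this is "cpu" when no accelerator is visible.
print(jax.default_backend())
# Every device JAX can see, e.g. a single CPU device or one entry per GPU.
print(jax.devices())
```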

Later in this homework, we'll train a model for the *denoising* process,
and speedups will still help there, even though our inputs remain small.