In [3]:
import jax.numpy as jnp
from jax import random
from jax import jit

In [35]:
key  = random.PRNGKey(758493)
def mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t):
    error = 0
    # BEGIN SOLUTION
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_t_minus_1 = alpha_bar_t / alpha_t
    mu_calc = (((alpha_bar_t_minus_1)**0.5)*beta_t/(1 - alpha_bar_t))*x_0 + ((alpha_t**0.5)*(1 - alpha_bar_t_minus_1)/(1 - alpha_bar_t))*x_t
    norm_diff = jnp.linalg.norm(mu_pred - mu_calc)
    error = (1 / (2*sigma_t**2))*norm_diff**2
    # END SOLUTION
    return error
    
def noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t):
    error = 0
    # BEGIN SOLUTION
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    coeff = (beta_t**2) / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
    D = len(eps_pred)
    eps = random.normal(key, (D,))
    norm_diff = jnp.linalg.norm(eps - eps_pred)
    error = coeff * norm_diff **2
    # END SOLUTION
    return error 

In [32]:
sigma_t = 5
alpha_ts = jnp.array([0.5, 0.5])
mu_pred = jnp.array([0.5, 0.5])
x_0 = jnp.array([0.2, 0.2])
x_t = jnp.array([0.2, 0.2])

In [33]:
print(mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t))

0.00387975


In [34]:
print(noise_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t))

0.008128801


In [None]:
for D in range(2, 1000, 10):
    sigma_t = random.uniform(key)
    # alpha_ts = jnp.array([0.5, 0.5])
    # mu_pred = jnp.array([0.5, 0.5])
    # x_0 = jnp.array([0.2, 0.2])
    # x_t = jnp.array([0.2, 0.2])
    alpha_ts = random.uniform(key, (D,))
    mu_pred = random.uniform(key, (D,))
    x_0 = random.uniform(key, (D,))
    x_t = random.uniform(key, (D,))
    print("Mean Loss, D = " + str(D))
    %time mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    print("Noise Loss, D = " + str(D))
    %time noise_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    print("\n")