In [1]:
import jax.numpy as jnp
from jax import random
from jax import jit
import loss_func

In [18]:
def mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t):
    """Weighted squared error between a predicted mean and the reference mean.

    With abar_t = prod(alpha_ts), alpha_t = alpha_ts[-1], beta_t = 1 - alpha_t
    and abar_{t-1} = abar_t / alpha_t, the reference mean is

        mu = sqrt(abar_{t-1}) * beta_t / (1 - abar_t) * x_0
             + sqrt(alpha_t) * (1 - abar_{t-1}) / (1 - abar_t) * x_t

    which is the DDPM forward-process posterior mean.

    Args:
        sigma_t: scalar; the loss is scaled by 1 / (2 * sigma_t**2).
        alpha_ts: per-step alphas up to step t (assumed 1-D); the last entry is
            alpha_t. alpha_t must be nonzero (it is divided by).
        mu_pred: predicted mean, broadcast-compatible with x_0 and x_t.
        x_0: the x_0 term of the mean (the clean sample in DDPM notation).
        x_t: the x_t term of the mean (the noisy sample at step t).

    Returns:
        1 / (2 * sigma_t**2) * ||mu_pred - mu||**2 as a JAX scalar.
    """
    # Default kept outside the solution markers (presumably what a stripped
    # student stub returns).
    error = 0
    # BEGIN SOLUTION
    a_bar = jnp.prod(alpha_ts)
    a_t = alpha_ts[-1]
    # abar_{t-1} recovered by dividing out the last factor of the product.
    a_bar_prev = a_bar / a_t
    one_minus_a_bar = 1 - a_bar
    coef_x0 = ((a_bar_prev)**0.5) * (1 - a_t) / one_minus_a_bar
    coef_xt = (a_t**0.5) * (1 - a_bar_prev) / one_minus_a_bar
    mu_true = coef_x0 * x_0 + coef_xt * x_t
    weight = 1 / (2 * sigma_t**2)
    error = weight * jnp.linalg.norm(mu_pred - mu_true)**2
    # END SOLUTION
    return error
    
def noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps):
    """Weighted squared error between predicted and true noise.

    With abar_t = prod(alpha_ts), alpha_t = alpha_ts[-1] and
    beta_t = 1 - alpha_t, the loss is

        beta_t**2 / (2 * sigma_t**2 * alpha_t * (1 - abar_t)) * ||eps - eps_pred||**2

    Args:
        sigma_t: scalar noise scale used in the weighting.
        alpha_ts: per-step alphas up to step t (assumed 1-D); the last entry is
            alpha_t.
        eps_pred: predicted noise, same shape as eps.
        x_0: unused; accepted so the signature matches mean_loss-style callers
            and the staff ``loss_func.noise_loss``.
        x_t: unused; see x_0.
        eps: true noise.

    Returns:
        The weighted squared error as a JAX scalar.
    """
    # Default kept outside the solution markers (presumably what a stripped
    # student stub returns).
    error = 0
    # BEGIN SOLUTION
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    coeff = (beta_t**2) / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
    norm_diff = jnp.linalg.norm(eps - eps_pred)
    error = coeff * norm_diff**2
    # END SOLUTION
    return error

In [24]:
# Test correctness of your solutions against the staff reference in loss_func.
# The differences are printed for inspection and also asserted (with a float
# tolerance) so that Restart & Run All fails loudly on a wrong implementation.
D = 5
key = random.PRNGKey(0)
s_key, a_key, mu_key, e0_key, e1_key, x0_key, xt_key = random.split(key, num=7)
sigma_t = random.uniform(s_key)
alpha_ts = random.uniform(a_key, (D,))
mu_pred = random.uniform(mu_key, (D,))
eps_pred = random.uniform(e0_key, (D,))
eps = random.uniform(e1_key, (D,))
x_0 = random.uniform(x0_key, (D,))
x_t = random.uniform(xt_key, (D,))

mean_mine = mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
mean_staff = loss_func.mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
mean_error = mean_mine - mean_staff
print("Difference between your mean loss and staff mean loss: " + str(mean_error))
noise_mine = noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
noise_staff = loss_func.noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
noise_error = noise_mine - noise_staff
print("Difference between your noise loss and staff noise loss: " + str(noise_error))

# Compare relatively: sigma_t is drawn from [0, 1), so the 1 / sigma_t**2
# weighting can make the losses large and a fixed absolute threshold on the
# raw difference meaningless.
assert jnp.allclose(mean_mine, mean_staff, rtol=1e-5, atol=1e-6), "mean_loss does not match the staff solution"
assert jnp.allclose(noise_mine, noise_staff, rtol=1e-5, atol=1e-6), "noise_loss does not match the staff solution"

Difference between your mean loss and staff mean loss: 0.0
Difference between your noise loss and staff noise loss: 0.0


In [25]:
seed = 0
for D in range(2, 1000, 10):
    key = random.PRNGKey(0)
    s_key, a_key, mu_key, e0_key, e1_key, x0_key, xt_key = random.split(key, num=7)
    sigma_t = random.uniform(s_key)
    alpha_ts = random.uniform(a_key, (D,))
    mu_pred = random.uniform(mu_key, (D,))
    eps_pred = random.uniform(e0_key, (D,))
    eps = random.uniform(e1_key, (D,))
    x_0 = random.uniform(x0_key, (D,))
    x_t = random.uniform(xt_key, (D,))
    print("Mean Loss, D = " + str(D))
    %time mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    print("Noise Loss, D = " + str(D))
    %time noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
    print("\n")
    seed += 1

Mean Loss, D = 2
CPU times: user 72.7 ms, sys: 1.89 ms, total: 74.5 ms
Wall time: 75.4 ms
Noise Loss, D = 2
CPU times: user 1.26 ms, sys: 20 µs, total: 1.28 ms
Wall time: 1.31 ms


Mean Loss, D = 12
CPU times: user 89.9 ms, sys: 2.23 ms, total: 92.1 ms
Wall time: 93.9 ms
Noise Loss, D = 12
CPU times: user 1.44 ms, sys: 114 µs, total: 1.56 ms
Wall time: 1.5 ms


Mean Loss, D = 22
CPU times: user 78.6 ms, sys: 1.99 ms, total: 80.6 ms
Wall time: 82.2 ms
Noise Loss, D = 22
CPU times: user 1.2 ms, sys: 1 µs, total: 1.2 ms
Wall time: 1.21 ms


Mean Loss, D = 32
CPU times: user 75.5 ms, sys: 1.92 ms, total: 77.5 ms
Wall time: 78.8 ms
Noise Loss, D = 32
CPU times: user 1.19 ms, sys: 1 µs, total: 1.19 ms
Wall time: 1.2 ms


Mean Loss, D = 42
CPU times: user 96.9 ms, sys: 3.56 ms, total: 100 ms
Wall time: 107 ms
Noise Loss, D = 42
CPU times: user 1.45 ms, sys: 117 µs, total: 1.57 ms
Wall time: 1.51 ms


Mean Loss, D = 52
CPU times: user 88.4 ms, sys: 2.27 ms, total: 90.7 ms
Wall time: 92.5 ms
No

Mean Loss, D = 452
CPU times: user 98.2 ms, sys: 1.78 ms, total: 100 ms
Wall time: 101 ms
Noise Loss, D = 452
CPU times: user 1.2 ms, sys: 1e+03 ns, total: 1.2 ms
Wall time: 1.2 ms


Mean Loss, D = 462
CPU times: user 98 ms, sys: 1.96 ms, total: 99.9 ms
Wall time: 101 ms
Noise Loss, D = 462
CPU times: user 1.24 ms, sys: 1 µs, total: 1.24 ms
Wall time: 1.24 ms


Mean Loss, D = 472
CPU times: user 97.9 ms, sys: 1.68 ms, total: 99.6 ms
Wall time: 100 ms
Noise Loss, D = 472
CPU times: user 1.21 ms, sys: 1 µs, total: 1.21 ms
Wall time: 1.22 ms


Mean Loss, D = 482
CPU times: user 94.7 ms, sys: 1.79 ms, total: 96.5 ms
Wall time: 96.9 ms
Noise Loss, D = 482
CPU times: user 1.2 ms, sys: 1e+03 ns, total: 1.2 ms
Wall time: 1.2 ms


Mean Loss, D = 492
CPU times: user 95.6 ms, sys: 1.86 ms, total: 97.5 ms
Wall time: 98.1 ms
Noise Loss, D = 492
CPU times: user 1.19 ms, sys: 1 µs, total: 1.19 ms
Wall time: 1.2 ms


Mean Loss, D = 502
CPU times: user 95.9 ms, sys: 1.82 ms, total: 97.7 ms
Wall time: 9

Mean Loss, D = 902
CPU times: user 99.2 ms, sys: 2.52 ms, total: 102 ms
Wall time: 103 ms
Noise Loss, D = 902
CPU times: user 1.2 ms, sys: 1e+03 ns, total: 1.2 ms
Wall time: 1.21 ms


Mean Loss, D = 912
CPU times: user 98 ms, sys: 2.18 ms, total: 100 ms
Wall time: 101 ms
Noise Loss, D = 912
CPU times: user 1.21 ms, sys: 2 µs, total: 1.21 ms
Wall time: 1.22 ms


Mean Loss, D = 922
CPU times: user 93.6 ms, sys: 1.5 ms, total: 95.1 ms
Wall time: 95 ms
Noise Loss, D = 922
CPU times: user 1.66 ms, sys: 20 µs, total: 1.68 ms
Wall time: 1.71 ms


Mean Loss, D = 932
CPU times: user 98.8 ms, sys: 1.91 ms, total: 101 ms
Wall time: 104 ms
Noise Loss, D = 932
CPU times: user 1.31 ms, sys: 1 µs, total: 1.31 ms
Wall time: 1.32 ms


Mean Loss, D = 942
CPU times: user 96.2 ms, sys: 1.68 ms, total: 97.9 ms
Wall time: 98.5 ms
Noise Loss, D = 942
CPU times: user 1.21 ms, sys: 1e+03 ns, total: 1.21 ms
Wall time: 1.21 ms


Mean Loss, D = 952
CPU times: user 94.5 ms, sys: 1.5 ms, total: 96 ms
Wall time: 96.