# Introduction
While the mean loss and the noise loss should, in theory, lead to the same optimal parameters for the diffusion network, in practice it is important to weigh their advantages and drawbacks, such as computational cost and accuracy. In this problem, we explore the computational cost of each loss.

# Library imports
Before you begin, make sure you have the following libraries installed.

In [1]:
import jax.numpy as jnp
from jax import random
from jax import jit

# Part A: Implementing the Loss Functions
Implement the mean and noise loss functions as specified in the homework.

In [2]:
def mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t):
    error = 0
    # BEGIN SOLUTION
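    # Cumulative product of the alphas up to step t, and the final-step alpha.
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_t_minus_1 = alpha_bar_t / alpha_t
    # Posterior mean of q(x_{t-1} | x_t, x_0), written as a weighted sum of x_0 and x_t.
    x_0_coeff = (alpha_bar_t_minus_1**0.5) * beta_t / (1 - alpha_bar_t)
    x_t_coeff = (alpha_t**0.5) * (1 - alpha_bar_t_minus_1) / (1 - alpha_bar_t)
    mu_calc = x_0_coeff * x_0 + x_t_coeff * x_t
    # Squared L2 distance between the predicted and posterior means, scaled by 1 / (2 * sigma_t^2).
    norm_diff = jnp.linalg.norm(mu_pred - mu_calc)
    error = (1 / (2 * sigma_t**2)) * norm_diff**2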
    # END SOLUTION
    return error


def noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps):
    error = 0
    # BEGIN SOLUTION
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    # Weight applied to the squared noise error.
    coeff = (beta_t**2) / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
    # x_0 and x_t are not used here; only the true and predicted noise are compared.
    norm_diff = jnp.linalg.norm(eps - eps_pred)
    error = coeff * norm_diff**2
    # END SOLUTION
    return error
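
For reference, the two functions above can be written out as equations. This is read directly off the code rather than quoted from the homework statement, so the notation ($\tilde{\mu}_t$, $L_{\text{mean}}$, $L_{\text{noise}}$) is an assumption and may differ from the handout. Here $\bar{\alpha}_t$ is the product of all entries of `alpha_ts`, $\alpha_t$ is its last entry, $\beta_t = 1 - \alpha_t$, and $\bar{\alpha}_{t-1} = \bar{\alpha}_t / \alpha_t$.

```latex
\begin{align*}
\tilde{\mu}_t &= \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t \\
L_{\text{mean}} &= \frac{1}{2\sigma_t^2}\,
  \bigl\lVert \mu_{\text{pred}} - \tilde{\mu}_t \bigr\rVert^2 \\
L_{\text{noise}} &= \frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1 - \bar{\alpha}_t)}\,
  \bigl\lVert \epsilon - \epsilon_{\text{pred}} \bigr\rVert^2
\end{align*}
```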


## Sanity Check
Run the cell below to test the correctness of your solution.

In [3]:
from test_cases import run_mean_loss_tests, run_noise_loss_tests
run_mean_loss_tests(mean_loss)
run_noise_loss_tests(noise_loss)

Mean loss: Average difference between expected and actual: 2.8610230629055877e-07
Mean loss: Result vs threshold: 2.8610230629055877e-07 < 1e-05
Mean loss test passed! :)

Noise loss: Average difference between expected and actual: 7.450580707946131e-10
Noise loss: Result vs threshold: 7.450580707946131e-10 < 1e-05
Noise loss test passed! :)
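
Beyond the provided tests, one way to see why the two losses should agree in theory is to check them numerically on a single example. The sketch below is not part of the assignment. It restates both losses compactly and assumes the standard DDPM relations $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$ and $\mu_{\text{pred}} = \frac{1}{\sqrt{\alpha_t}}\bigl(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\,\epsilon_{\text{pred}}\bigr)$. The sizes `T` and `D`, the range of `alpha_ts`, and the choice `sigma_t = sqrt(beta_t)` are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import random


def mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t):
    # Compact restatement of the mean loss from Part A.
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_prev = alpha_bar_t / alpha_t
    mu_calc = (
        jnp.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t) * x_0
        + jnp.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t) * x_t
    )
    return jnp.sum((mu_pred - mu_calc) ** 2) / (2 * sigma_t**2)


def noise_loss(sigma_t, alpha_ts, eps_pred, eps):
    # Compact restatement of the noise loss (unused x_0 and x_t arguments dropped).
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    coeff = beta_t**2 / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
    return coeff * jnp.sum((eps - eps_pred) ** 2)


# Hypothetical setup: T diffusion steps, D-dimensional data.
T, D = 10, 8
a_key, x0_key, eps_key, pred_key = random.split(random.PRNGKey(0), num=4)
alpha_ts = random.uniform(a_key, (T,), minval=0.9, maxval=0.99)
x_0 = random.normal(x0_key, (D,))
eps = random.normal(eps_key, (D,))        # true noise
eps_pred = random.normal(pred_key, (D,))  # stand-in for a network's noise prediction

alpha_bar_t = jnp.prod(alpha_ts)
alpha_t = alpha_ts[-1]
beta_t = 1 - alpha_t
sigma_t = jnp.sqrt(beta_t)  # assumed variance choice, for illustration only

# Forward-process sample, and the mean prediction implied by eps_pred.
x_t = jnp.sqrt(alpha_bar_t) * x_0 + jnp.sqrt(1 - alpha_bar_t) * eps
mu_pred = (x_t - beta_t / jnp.sqrt(1 - alpha_bar_t) * eps_pred) / jnp.sqrt(alpha_t)

loss_from_mean = mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
loss_from_noise = noise_loss(sigma_t, alpha_ts, eps_pred, eps)
print(loss_from_mean, loss_from_noise)
```

Under these assumptions the two printed values should agree up to floating-point error, since substituting the two relations into the mean loss reduces it algebraically to the noise loss.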



# Part B: Comparing the Speeds of the Loss Functions
After you've implemented the two loss functions, run the cell below and comment on any differences in their run times.
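
When reading timings from JAX code, keep in mind that JAX dispatches work asynchronously and compiles operations the first time it sees a new input shape. A single timed call can therefore include one-off compilation cost, and whichever function runs first on a new shape may pay more of it. The sketch below shows one common way to measure steady-state time instead: warm up once, then average several calls that each wait on `block_until_ready()`. The helper `time_call` and the stand-in function `sq_error` are hypothetical names used only for illustration, not part of the assignment.

```python
import time

import jax.numpy as jnp
from jax import jit, random


def sq_error(a, b):
    # Stand-in for a loss: squared L2 distance between two vectors.
    return jnp.sum((a - b) ** 2)


def time_call(fn, *args, repeats=10):
    # Warm-up call: triggers compilation for these input shapes.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready() waits for the asynchronous computation to finish.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


key_a, key_b = random.split(random.PRNGKey(0))
a = random.uniform(key_a, (512,))
b = random.uniform(key_b, (512,))

sq_error_jit = jit(sq_error)
avg_seconds = time_call(sq_error_jit, a, b)
print(f"average steady-state time: {avg_seconds * 1e6:.1f} µs")
```

The same helper could be applied to `mean_loss` and `noise_loss` with their own arguments to compare them with compilation cost excluded.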

In [4]:
seed = 0
for D in range(2, 1000, 10):
    key = random.PRNGKey(seed)
    s_key, a_key, mu_key, e0_key, e1_key, x0_key, xt_key = random.split(key, num=7)
    sigma_t = random.uniform(s_key)
    alpha_ts = random.uniform(a_key, (D,))
    mu_pred = random.uniform(mu_key, (D,))
    eps_pred = random.uniform(e0_key, (D,))
    eps = random.uniform(e1_key, (D,))
    x_0 = random.uniform(x0_key, (D,))
    x_t = random.uniform(xt_key, (D,))
    print("Mean Loss, D = " + str(D))
    %time mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    print("Noise Loss, D = " + str(D))
    %time noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
    print("\n")
    seed += 1

Mean Loss, D = 2
CPU times: user 57.4 ms, sys: 1.55 ms, total: 58.9 ms
Wall time: 59.5 ms
Noise Loss, D = 2
CPU times: user 548 µs, sys: 0 ns, total: 548 µs
Wall time: 551 µs


Mean Loss, D = 12
CPU times: user 45.2 ms, sys: 684 µs, total: 45.9 ms
Wall time: 46.5 ms
Noise Loss, D = 12
CPU times: user 544 µs, sys: 1 µs, total: 545 µs
Wall time: 548 µs


Mean Loss, D = 22
CPU times: user 49.2 ms, sys: 1.01 ms, total: 50.2 ms
Wall time: 50 ms
Noise Loss, D = 22
CPU times: user 572 µs, sys: 37 µs, total: 609 µs
Wall time: 586 µs


Mean Loss, D = 32
CPU times: user 44.9 ms, sys: 628 µs, total: 45.6 ms
Wall time: 45.5 ms
Noise Loss, D = 32
CPU times: user 548 µs, sys: 5 µs, total: 553 µs
Wall time: 557 µs


Mean Loss, D = 42
CPU times: user 49 ms, sys: 787 µs, total: 49.8 ms
Wall time: 50.1 ms
Noise Loss, D = 42
CPU times: user 536 µs, sys: 0 ns, total: 536 µs
Wall time: 538 µs


Mean Loss, D = 52
CPU times: user 48.8 ms, sys: 605 µs, total: 49.4 ms
Wall time: 49.3 ms
Noise Loss, D = 52
CPU 