# Introduction
In theory, minimizing the mean loss and minimizing the noise loss should yield the same parameters for the diffusion network. In practice, the two objectives can differ in properties such as computational cost and numerical accuracy, so it is worth weighing their advantages and drawbacks. In this problem, we compare the computational cost of the mean loss and the noise loss.
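One way to see why the two objectives agree in theory is to rewrite both means in terms of the noise. This sketch assumes the standard DDPM setup, which this notebook does not spell out: the forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, and a predicted mean that is parameterized through a predicted noise $\epsilon_\theta$. Substituting $x_0$ from the forward process into the posterior mean gives:

```latex
\begin{aligned}
\beta_t &= 1 - \alpha_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \\
\tilde\mu_t(x_t, x_0) &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon\right), \qquad
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right), \\
\frac{1}{2\sigma_t^2}\left\lVert \tilde\mu_t - \mu_\theta \right\rVert^2
&= \frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar\alpha_t)}\left\lVert \epsilon - \epsilon_\theta(x_t, t) \right\rVert^2 .
\end{aligned}
```

Under these assumptions the two losses are equal for every input, so any parameters that minimize one also minimize the other.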

# Library imports
Before you begin, make sure you have the following libraries installed.

In [1]:
import jax.numpy as jnp
from jax import random

# Part A: Implementing the Loss Functions
Implement the mean and noise loss functions as specified in the homework.
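For reference, these are the formulas that the solution cell below implements. They are read back from its code, because the homework statement itself is not reproduced here. Here $\beta_t = 1-\alpha_t$, $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$ (the product of `alpha_ts`), and $\bar\alpha_{t-1} = \bar\alpha_t/\alpha_t$:

```latex
\begin{aligned}
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t, \\
L_{\text{mean}} &= \frac{1}{2\sigma_t^2}\left\lVert \mu_{\text{pred}} - \tilde\mu_t(x_t, x_0) \right\rVert^2, \\
L_{\text{noise}} &= \frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar\alpha_t)}\left\lVert \epsilon - \epsilon_{\text{pred}} \right\rVert^2.
\end{aligned}
```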

In [2]:
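def mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t):
    error = 0
    # BEGIN SOLUTION
    # alpha_ts holds alpha_1, ..., alpha_t; its last entry is alpha_t.
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_t_minus_1 = alpha_bar_t / alpha_t  # alpha_1 * ... * alpha_{t-1}
    # Posterior mean mu_tilde_t(x_t, x_0) of q(x_{t-1} | x_t, x_0)
    mu_calc = (((alpha_bar_t_minus_1) ** 0.5) * beta_t / (1 - alpha_bar_t)) * x_0 + (
        (alpha_t**0.5) * (1 - alpha_bar_t_minus_1) / (1 - alpha_bar_t)
    ) * x_t
    # Squared L2 distance between predicted and true posterior mean, weighted by 1 / (2 sigma_t^2)
    norm_diff = jnp.linalg.norm(mu_pred - mu_calc)
    error = (1 / (2 * sigma_t**2)) * norm_diff**2
    # END SOLUTION
    return error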


def noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps):
    # x_0 and x_t are not needed to compute this loss; they are part of the required signature.
    error = 0
    # BEGIN SOLUTION
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    # Weight on the noise error; it matches the mean loss under the epsilon parameterization
    coeff = (beta_t**2) / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
    norm_diff = jnp.linalg.norm(eps - eps_pred)
    error = coeff * norm_diff**2
    # END SOLUTION
    return error


## Sanity Check
Run the cell below to test the correctness of your solution.

In [3]:
from test_cases import run_mean_loss_tests, run_noise_loss_tests
run_mean_loss_tests(mean_loss)
run_noise_loss_tests(noise_loss)

Mean loss: Average difference between expected and actual: 2.8610230629055877e-07
Mean loss: Result vs threshold: 2.8610230629055877e-07 < 1e-05
Mean loss test passed! :)

Noise loss: Average difference between expected and actual: 7.450580707946131e-10
Noise loss: Result vs threshold: 7.450580707946131e-10 < 1e-05
Noise loss test passed! :)
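The tests above check each loss against reference values. A complementary check is that the two losses agree numerically when `x_t` and `mu_pred` are built consistently. The sketch below assumes the standard DDPM forward process and an epsilon-parameterized predicted mean, neither of which this notebook spells out. The helper `check_equivalence` and its default arguments (`D`, `T`, `sigma_t`, and the alpha range) are hypothetical. Both losses are re-implemented inline so that the block runs on its own.

```python
import jax.numpy as jnp
from jax import random


def check_equivalence(key, D=8, T=10, sigma_t=0.5):
    """Hypothetical helper: return (mean loss, noise loss) for consistent inputs.

    Assumes x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps and
    mu_pred = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_t).
    """
    a_key, x_key, e_key, p_key = random.split(key, 4)
    # Alphas kept in [0.9, 0.99] (an assumed range) to limit float32 cancellation
    alpha_ts = random.uniform(a_key, (T,), minval=0.9, maxval=0.99)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_bar_prev = jnp.prod(alpha_ts[:-1])  # alpha_bar_{t-1}

    x_0 = random.normal(x_key, (D,))
    eps = random.normal(e_key, (D,))
    eps_pred = random.normal(p_key, (D,))

    # Forward-process sample and epsilon-parameterized predicted mean
    x_t = jnp.sqrt(alpha_bar_t) * x_0 + jnp.sqrt(1 - alpha_bar_t) * eps
    mu_pred = (x_t - beta_t / jnp.sqrt(1 - alpha_bar_t) * eps_pred) / jnp.sqrt(alpha_t)

    # Posterior mean mu_tilde_t(x_t, x_0), same formula as mean_loss in Part A
    mu_tilde = (jnp.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)) * x_0 + (
        jnp.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    ) * x_t

    mean_val = jnp.sum((mu_pred - mu_tilde) ** 2) / (2 * sigma_t**2)
    noise_val = (
        beta_t**2 / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
        * jnp.sum((eps - eps_pred) ** 2)
    )
    return float(mean_val), float(noise_val)


mean_val, noise_val = check_equivalence(random.PRNGKey(0))
print(mean_val, noise_val)  # the two values should agree up to float32 rounding
```

In the notebook, you could instead pass the constructed `x_t` and `mu_pred` to the Part A `mean_loss`, and `eps_pred` and `eps` to `noise_loss`, and compare their outputs.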



# Part B: Comparing the Speeds of the Loss Functions
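After you've implemented the two loss functions, run the cell below. It times a single call of each loss with IPython's `%time` magic for input dimensions D = 2, 202, 402, 602, and 802.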

In [4]:
seed = 0
for D in range(2, 1000, 200):  # D = 2, 202, 402, 602, 802
    # Fresh random inputs of dimension D for each run
    key = random.PRNGKey(seed)
    s_key, a_key, mu_key, e0_key, e1_key, x0_key, xt_key = random.split(key, num=7)
    sigma_t = random.uniform(s_key)
    # alpha_ts also has length D, so the timestep t equals D here
    alpha_ts = random.uniform(a_key, (D,))
    mu_pred = random.uniform(mu_key, (D,))
    eps_pred = random.uniform(e0_key, (D,))
    eps = random.uniform(e1_key, (D,))
    x_0 = random.uniform(x0_key, (D,))
    x_t = random.uniform(xt_key, (D,))
    # Time one call of each loss on the same inputs
    print("Mean Loss, D = " + str(D))
    %time mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    print("Noise Loss, D = " + str(D))
    %time noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
    print("\n")
    seed += 1

Mean Loss, D = 2
CPU times: user 60.8 ms, sys: 1.1 ms, total: 61.9 ms
Wall time: 63 ms
Noise Loss, D = 2
CPU times: user 556 µs, sys: 0 ns, total: 556 µs
Wall time: 559 µs


Mean Loss, D = 202
CPU times: user 62.8 ms, sys: 2.17 ms, total: 65 ms
Wall time: 71.8 ms
Noise Loss, D = 202
CPU times: user 643 µs, sys: 8 µs, total: 651 µs
Wall time: 661 µs


Mean Loss, D = 402
CPU times: user 67.7 ms, sys: 2.27 ms, total: 69.9 ms
Wall time: 77.5 ms
Noise Loss, D = 402
CPU times: user 937 µs, sys: 29 µs, total: 966 µs
Wall time: 1.04 ms


Mean Loss, D = 602
CPU times: user 56.3 ms, sys: 1.3 ms, total: 57.6 ms
Wall time: 57.9 ms
Noise Loss, D = 602
CPU times: user 627 µs, sys: 3 µs, total: 630 µs
Wall time: 635 µs


Mean Loss, D = 802
CPU times: user 58.2 ms, sys: 1.64 ms, total: 59.8 ms
Wall time: 61.2 ms
Noise Loss, D = 802
CPU times: user 599 µs, sys: 2 µs, total: 601 µs
Wall time: 603 µs
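`%time` measures a single eager call. In JAX, such a call can include one-time costs, such as tracing and compiling operations for a new input shape. JAX also dispatches work asynchronously, so the timer may stop before the result has been computed. One possible way to measure more carefully is to compile the function with `jax.jit`, run it once to warm up, and call `block_until_ready()` on each result inside the timed loop. The sketch below follows that approach. The helper `time_fn` and its `n_iter` parameter are hypothetical names, and a simple stand-in function replaces the losses so that the block runs on its own.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, n_iter=100):
    """Hypothetical helper: mean wall-clock seconds per call of fn(*args).

    The first call is excluded so that jit tracing and compilation are not counted.
    """
    fn(*args).block_until_ready()  # warm-up: triggers tracing and compilation
    start = time.perf_counter()
    for _ in range(n_iter):
        # Wait for each result so asynchronous dispatch does not hide the cost
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iter


# Stand-in function; in the notebook, pass jax.jit(mean_loss) or jax.jit(noise_loss)
squared_norm = jax.jit(lambda v: jnp.sum(v**2))
seconds = time_fn(squared_norm, jnp.ones((802,)))
print(f"{seconds * 1e6:.1f} µs per call")
```

For example, `time_fn(jax.jit(mean_loss), sigma_t, alpha_ts, mu_pred, x_0, x_t)` would time the Part A mean loss on one set of inputs from the loop above.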


