# Introduction
While the mean loss and the noise loss should, in theory, lead the diffusion network to the same optimal parameters, in practice it is important to weigh the advantages and drawbacks of each, such as computational cost and numerical accuracy. In this problem, we compare the computational cost of the mean loss and the noise loss.
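A short derivation of why the two objectives coincide, sketched under the standard DDPM forward-process assumptions and in the notation of the code in Part A ($\alpha_t$, $\bar\alpha_t = \prod_{s \le t} \alpha_s$, $\beta_t = 1 - \alpha_t$, $\sigma_t$). The steps are: write $x_t$ through the forward process, rewrite the posterior mean in terms of the true noise $\epsilon$, and parameterize the predicted mean $\mu_\theta$ the same way through a noise prediction $\epsilon_\theta$.

$$
\begin{aligned}
x_t &= \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t
= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon\right), \\
\mu_\theta &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta\right)
\;\Longrightarrow\;
\frac{1}{2\sigma_t^2}\left\lVert \tilde\mu_t - \mu_\theta \right\rVert^2
= \frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar\alpha_t)}\left\lVert \epsilon - \epsilon_\theta \right\rVert^2 .
\end{aligned}
$$

Under this parameterization the mean loss and the noise loss are the same function of $\epsilon_\theta$, so they share minimizers.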

# Library imports
Before you begin, make sure you have the following libraries installed.

In [15]:
import jax.numpy as jnp
from jax import random
from jax import jit

# Part A: Implementing the Loss Functions
Implement the mean and noise loss functions as specified in the homework.

In [17]:
def mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t):
    error = 0
    # BEGIN SOLUTION
    alpha_bar_t = jnp.prod(alpha_ts)             # \bar{alpha}_t = alpha_1 * ... * alpha_t
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_t_minus_1 = alpha_bar_t / alpha_t  # \bar{alpha}_{t-1}
    # Closed-form posterior mean of q(x_{t-1} | x_t, x_0)
    mu_calc = (((alpha_bar_t_minus_1)**0.5)*beta_t/(1 - alpha_bar_t))*x_0 + ((alpha_t**0.5)*(1 - alpha_bar_t_minus_1)/(1 - alpha_bar_t))*x_t
    # Squared L2 distance between predicted and true posterior means, scaled by 1 / (2 sigma_t^2)
    norm_diff = jnp.linalg.norm(mu_pred - mu_calc)
    error = (1 / (2*sigma_t**2))*norm_diff**2
    # END SOLUTION
    return error

def noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps):
    error = 0
    # BEGIN SOLUTION
    # (x_0 and x_t are not needed to evaluate this loss.)
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    # Weight that makes this loss equal to the mean loss when mu is parameterized by eps
    coeff = (beta_t**2) / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
    # Squared L2 distance between the true and predicted noise
    norm_diff = jnp.linalg.norm(eps - eps_pred)
    error = coeff * norm_diff**2
    # END SOLUTION
    return error
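As an optional numerical check that the two objectives agree, the sketch below builds $x_t$ from the forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, converts a noise prediction into a mean prediction via $\mu_\theta = (x_t - \beta_t \epsilon_\theta / \sqrt{1-\bar\alpha_t}) / \sqrt{\alpha_t}$, and compares the two loss values. The helpers `posterior_mean` and `mean_from_eps`, the toy schedule, $\sigma_t$, and the dimension are illustrative assumptions, not part of the homework specification. The block is self-contained and does not call the functions above.

```python
import jax.numpy as jnp
from jax import random


def posterior_mean(alpha_ts, x_0, x_t):
    """Closed-form posterior mean of q(x_{t-1} | x_t, x_0) (same expression as in mean_loss)."""
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    alpha_bar_prev = alpha_bar_t / alpha_t
    return (jnp.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)) * x_0 + (
        jnp.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    ) * x_t


def mean_from_eps(alpha_ts, x_t, eps_pred):
    """Hypothetical helper: the mean implied by a noise prediction."""
    alpha_bar_t = jnp.prod(alpha_ts)
    alpha_t = alpha_ts[-1]
    beta_t = 1 - alpha_t
    return (x_t - beta_t / jnp.sqrt(1 - alpha_bar_t) * eps_pred) / jnp.sqrt(alpha_t)


# Assumed toy setup: 4 diffusion steps, 8-dimensional data, fixed sigma_t.
alpha_ts = jnp.array([0.99, 0.98, 0.97, 0.96])
sigma_t = 0.2
k0, k1, k2 = random.split(random.PRNGKey(0), 3)
x_0 = random.normal(k0, (8,))
eps = random.normal(k1, (8,))
eps_pred = random.normal(k2, (8,))

alpha_bar_t = jnp.prod(alpha_ts)
alpha_t = alpha_ts[-1]
beta_t = 1 - alpha_t
# Forward process: noisy sample at step t
x_t = jnp.sqrt(alpha_bar_t) * x_0 + jnp.sqrt(1 - alpha_bar_t) * eps

mu_true = posterior_mean(alpha_ts, x_0, x_t)
mu_pred = mean_from_eps(alpha_ts, x_t, eps_pred)

# Mean loss evaluated on the eps-parameterized mean vs. the weighted noise loss
loss_via_mean = jnp.sum((mu_true - mu_pred) ** 2) / (2 * sigma_t**2)
loss_via_noise = (
    beta_t**2 / (2 * sigma_t**2 * alpha_t * (1 - alpha_bar_t))
) * jnp.sum((eps - eps_pred) ** 2)

print(loss_via_mean, loss_via_noise)  # should agree up to float32 rounding
```

The two printed values should differ only by floating-point rounding; with a different schedule or larger $D$, a looser tolerance may be needed in float32.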

## Sanity Check
Run the cell below to test the correctness of your solution.

In [18]:
# Test correctness of your solutions

import test_cases
sigma_ts = jnp.array(test_cases.sigma_ts)
alpha_tses = test_cases.alpha_tses
mu_preds = test_cases.mu_preds
eps_preds = test_cases.eps_preds
epses = test_cases.epses
x_0s = test_cases.x_0s
x_ts = test_cases.x_ts
mean_vals = test_cases.mean_vals
noise_vals = test_cases.noise_vals

for i in range(10):
    sigma_t = sigma_ts[i]
    alpha_ts = jnp.array(alpha_tses[i])
    mu_pred = jnp.array(mu_preds[i])
    eps_pred = jnp.array(eps_preds[i])
    eps = jnp.array(epses[i])
    x_0 = jnp.array(x_0s[i])
    x_t = jnp.array(x_ts[i])
    mean_val_student = mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    noise_val_student = noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
    mean_error = abs(mean_val_student - mean_vals[i])
    noise_error = abs(noise_val_student - noise_vals[i])
    print("Test " + str(i) + ": ")
    print("Mean Loss error: " + str(mean_error))
    print("Noise Loss error: " + str(noise_error))
    print("\n")

Test 0: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 1: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 2: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 3: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 4: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 5: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 6: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 7: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 8: 
Mean Loss error: 0.0
Noise Loss error: 0.0


Test 9: 
Mean Loss error: 0.0
Noise Loss error: 0.0
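Before running the timing comparison in Part B, it may help to keep two JAX behaviors in mind. JAX dispatches work asynchronously, so a call can return before its result has been computed, and the first call on a new input shape can also pay for tracing and compilation. Since `D` changes on every iteration of the Part B loop and `mean_loss` runs first, it may absorb compilation of operations (for example `jnp.prod` and `jnp.linalg.norm` on the new shape) that `noise_loss` then reuses. The helper below is a minimal sketch, not part of the assignment: `time_per_call` is a hypothetical name, and it assumes the timed function returns a single JAX array. It excludes compilation with a warm-up call and forces completion with `block_until_ready()`.

```python
import timeit

import jax
import jax.numpy as jnp


def time_per_call(fn, *args, number=100):
    """Average wall-clock seconds per call of fn(*args), excluding first-call compilation.

    Assumes fn returns a single JAX array.
    """
    # Warm-up call: triggers tracing/compilation for these argument shapes
    # so that one-time cost is not mixed into the measurement.
    fn(*args).block_until_ready()
    # JAX dispatches asynchronously; block_until_ready() makes each timed call
    # wait until its result has actually been computed.
    total = timeit.timeit(lambda: fn(*args).block_until_ready(), number=number)
    return total / number


# Usage with a toy jitted function (a stand-in for the losses in this notebook).
sq_norm = jax.jit(lambda v: jnp.sum(v**2))
for D in (2, 502, 952):
    x = jnp.ones(D)
    print(f"D = {D}: {time_per_call(sq_norm, x) * 1e6:.1f} µs per call")
```

To apply this to the homework losses, one option is to pass `jit(mean_loss)` and `jit(noise_loss)` (the notebook already imports `jit`) together with the arrays generated for each `D` in Part B; the per-call numbers then reflect steady-state execution rather than first-call overhead.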




# Part B: Comparing the Speeds of the Loss Functions
After you've implemented the two loss functions, run the cell below and comment on any differences.

In [19]:
seed = 0
for D in range(2, 1000, 10):
    key = random.PRNGKey(seed)  # use the per-iteration seed (it was previously unused)
    s_key, a_key, mu_key, e0_key, e1_key, x0_key, xt_key = random.split(key, num=7)
    sigma_t = random.uniform(s_key)
    alpha_ts = random.uniform(a_key, (D,))
    mu_pred = random.uniform(mu_key, (D,))
    eps_pred = random.uniform(e0_key, (D,))
    eps = random.uniform(e1_key, (D,))
    x_0 = random.uniform(x0_key, (D,))
    x_t = random.uniform(xt_key, (D,))
    print("Mean Loss, D = " + str(D))
    %time mean_loss(sigma_t, alpha_ts, mu_pred, x_0, x_t)
    print("Noise Loss, D = " + str(D))
    %time noise_loss(sigma_t, alpha_ts, eps_pred, x_0, x_t, eps)
    print("\n")
    seed += 1

Mean Loss, D = 2
CPU times: user 71.3 ms, sys: 1.39 ms, total: 72.7 ms
Wall time: 73.6 ms
Noise Loss, D = 2
CPU times: user 1.24 ms, sys: 13 µs, total: 1.25 ms
Wall time: 1.26 ms


Mean Loss, D = 12
CPU times: user 92.7 ms, sys: 2.34 ms, total: 95 ms
Wall time: 96.7 ms
Noise Loss, D = 12
CPU times: user 1.36 ms, sys: 90 µs, total: 1.45 ms
Wall time: 1.37 ms


Mean Loss, D = 22
CPU times: user 78.4 ms, sys: 1.14 ms, total: 79.5 ms
Wall time: 79.5 ms
Noise Loss, D = 22
CPU times: user 1.18 ms, sys: 1e+03 ns, total: 1.19 ms
Wall time: 1.19 ms


Mean Loss, D = 32
CPU times: user 75.6 ms, sys: 1.8 ms, total: 77.4 ms
Wall time: 78.5 ms
Noise Loss, D = 32
CPU times: user 1.17 ms, sys: 1 µs, total: 1.17 ms
Wall time: 1.18 ms


Mean Loss, D = 42
CPU times: user 102 ms, sys: 3.64 ms, total: 106 ms
Wall time: 115 ms
Noise Loss, D = 42
CPU times: user 3.26 ms, sys: 282 µs, total: 3.54 ms
Wall time: 3.35 ms


Mean Loss, D = 52
CPU times: user 117 ms, sys: 3.67 ms, total: 120 ms
Wall time: 129 ms
...
Mean Loss, D = 452
CPU times: user 120 ms, sys: 3.62 ms, total: 124 ms
Wall time: 130 ms
Noise Loss, D = 452
CPU times: user 2.34 ms, sys: 156 µs, total: 2.49 ms
Wall time: 2.53 ms


Mean Loss, D = 462
CPU times: user 98.3 ms, sys: 2.26 ms, total: 101 ms
Wall time: 101 ms
Noise Loss, D = 462
CPU times: user 1.44 ms, sys: 115 µs, total: 1.55 ms
Wall time: 1.5 ms


Mean Loss, D = 472
CPU times: user 88.7 ms, sys: 1.35 ms, total: 90 ms
Wall time: 90.4 ms
Noise Loss, D = 472
CPU times: user 1.16 ms, sys: 1 µs, total: 1.16 ms
Wall time: 1.17 ms


Mean Loss, D = 482
CPU times: user 92.9 ms, sys: 1.43 ms, total: 94.3 ms
Wall time: 94.2 ms
Noise Loss, D = 482
CPU times: user 1.2 ms, sys: 3 µs, total: 1.2 ms
Wall time: 1.21 ms


Mean Loss, D = 492
CPU times: user 94.3 ms, sys: 1.63 ms, total: 96 ms
Wall time: 95.7 ms
Noise Loss, D = 492
CPU times: user 1.26 ms, sys: 116 µs, total: 1.38 ms
Wall time: 1.28 ms


Mean Loss, D = 502
CPU times: user 96.3 ms, sys: 1.89 ms, total: 98.2 ms
Wall time: 98.3 ms
...
Mean Loss, D = 902
CPU times: user 97.6 ms, sys: 2.27 ms, total: 99.8 ms
Wall time: 101 ms
Noise Loss, D = 902
CPU times: user 1.38 ms, sys: 49 µs, total: 1.43 ms
Wall time: 1.44 ms


Mean Loss, D = 912
CPU times: user 89.6 ms, sys: 1.84 ms, total: 91.5 ms
Wall time: 91.5 ms
Noise Loss, D = 912
CPU times: user 1.39 ms, sys: 116 µs, total: 1.51 ms
Wall time: 1.45 ms


Mean Loss, D = 922
CPU times: user 94.2 ms, sys: 1.5 ms, total: 95.7 ms
Wall time: 95.1 ms
Noise Loss, D = 922
CPU times: user 1.32 ms, sys: 91 µs, total: 1.41 ms
Wall time: 1.33 ms


Mean Loss, D = 932
CPU times: user 94.7 ms, sys: 1.7 ms, total: 96.4 ms
Wall time: 96.2 ms
Noise Loss, D = 932
CPU times: user 1.93 ms, sys: 74 µs, total: 2 ms
Wall time: 2.39 ms


Mean Loss, D = 942
CPU times: user 91.1 ms, sys: 1.67 ms, total: 92.8 ms
Wall time: 92.8 ms
Noise Loss, D = 942
CPU times: user 1.45 ms, sys: 116 µs, total: 1.57 ms
Wall time: 1.51 ms


Mean Loss, D = 952
CPU times: user 93.9 ms, sys: 2.09 ms, total: 96 ms
Wall time: 96 ms
Noise Loss