In [None]:
%matplotlib inline

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# NOTE: The sys.path manipulation below (before the project-local imports) is required because
#       this notebook runs from the `notebooks` directory rather than the main `two_step_zoo` directory.
import os
import sys

# Global figure styling: STIX text font, Computer Modern math glyphs, and matplotlib's
# built-in mathtext renderer instead of an external LaTeX installation.
mpl.rcParams.update({
    'font.family': 'STIXGeneral',
    'axes.linewidth': 1.0,
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

# Colour constants (hex). `dark_red` and `blue` are used by the score-norm plot.
dark_red = '#DC3220'
light_blue = '#20cadc'
dark_blue = '#21b0bf'
blue = '#005AB5'

# Temporarily put the parent of the current working directory (the repository root, when
# running from `notebooks/`) at the front of sys.path so the project-local modules import.
# os.path.dirname is used rather than str.strip: strip() removes a *set of characters*
# from both ends, not a suffix, and can eat the start of the path (e.g. `C:\work\Code`).
cwd = os.getcwd()
repo_root = os.path.dirname(cwd)
sys.path.insert(0, repo_root)
try:
    from load_run import load_single_module, load_twostep_module
    from two_step_zoo.evaluators.metrics import fd
finally:
    # Undo the insertion even if an import fails, so re-running the cell does not
    # leave stale or duplicated entries on sys.path.
    sys.path.remove(repo_root)

In [None]:
# Run directories to compare. The "MonthDay_Hour-Minute-Second" names look like templates:
# replace them with the actual timestamped folders under ../runs before running.
run_dir = "../runs/MonthDay_Hour-Minute-Second"  # path to diffusion model on ambient space
run_dir_ts = "../runs/MonthDay_Hour-Minute-Second"  # path to diffusion model on latent space

# Ambient model: single-module run. Latent model: two-step run.
load_dict = load_single_module(run_dir)
load_dict_ts = load_twostep_module(run_dir_ts)

In [None]:
# Sampling / time-grid configuration shared by the sampling and plotting cells.
T = 1.0          # time horizon; the plot's x-axis is T - t
steps = 1_000    # number of sampling steps (also the length of the plotted time grid)
eps = 1e-3       # the plotted time grid runs from 0 to T - eps; also passed to `sample`
n_samples = 100  # first argument to each model's `sample` call


In [None]:
# Sample from the ambient-space model and keep only the second return value (`scores`).
# NOTE(review): the trailing positional `True` presumably asks `sample` to also return the
# per-step scores — confirm against its signature.
_, scores = load_dict["module"].sample(n_samples, eps, steps, True)
scores = np.array(scores)
# The plotting cell masks axis 0 with a length-`steps` boolean array and averages over
# axis 1, so check both requirements here and fail early with a readable message.
assert scores.ndim >= 2 and scores.shape[0] == steps, (
    f"expected scores with shape (steps={steps}, ...), got {scores.shape}"
)
scores.shape

In [None]:
# Sample from the latent-space (two-step) model via its `density_estimator` and keep only
# the second return value. NOTE(review): the trailing positional `True` presumably asks
# `sample` to also return the per-step scores — confirm against its signature.
_, scores_ts = load_dict_ts["module"].density_estimator.sample(n_samples, eps, steps, True)
scores_ts = np.array(scores_ts)
# The plotting cell masks axis 0 with a length-`steps` boolean array and averages over
# axis 1, so check both requirements here and fail early with a readable message.
assert scores_ts.ndim >= 2 and scores_ts.shape[0] == steps, (
    f"expected scores_ts with shape (steps={steps}, ...), got {scores_ts.shape}"
)
scores_ts.shape

In [None]:
linewidth = 6    # width of the mean curves
fs1 = 50         # font size for axis labels and legend
fs2 = 40         # font size for tick labels
lp = 30          # padding between axis labels and the axes
tick_size = 20   # length of the major tick marks

# Only plot the end of sampling, i.e. points with T - t < T_prime. With T=1, eps=1e-3 and
# steps=1000 the grid spacing is 1e-3, so this keeps roughly the final ten steps.
T_prime = 0.011
save_path = None  # set to e.g. 'score_norms.pdf' to write the figure to disk


def _plot_mean_std_band(ax, xs, mean, std, color, label, lw):
    """Draw `mean` against `xs` on `ax`, with a shaded mean ± std band in the same colour."""
    ax.plot(xs, mean, c=color, label=label, linewidth=lw)
    ax.fill_between(xs, mean - std, mean + std, alpha=0.3, color=color)


# One x value (T - t) per sampling step. Assumes axis 0 of the score arrays follows this
# same step order — TODO confirm against the sampler.
x = T - np.linspace(start=0., stop=T-eps, num=steps)
window = x < T_prime  # compute the mask once and reuse it for all three arrays
x_prime = x[window]
scores_prime = scores[window]
scores_ts_prime = scores_ts[window]

# Statistics over axis 1 at each retained step (presumably across samples).
y_mean = np.mean(scores_prime, axis=1)
y_stddev = np.std(scores_prime, axis=1)
y_mean_ts = np.mean(scores_ts_prime, axis=1)
y_stddev_ts = np.std(scores_ts_prime, axis=1)

fig, ax = plt.subplots(figsize=(20, 15))

_plot_mean_std_band(ax, x_prime, y_mean, y_stddev, dark_red, 'Diffusion model', linewidth)
_plot_mean_std_band(ax, x_prime, y_mean_ts, y_stddev_ts, blue, 'Latent diffusion model', linewidth)

ax.set_ylabel(r'$\dfrac{\Vert s_{\theta^*}(\hat{Y}_{t}, T-t) \Vert_2^2}{\mathrm{dim}}$', fontsize=fs1, labelpad=lp)
ax.set_xlabel(r'$T-t$', fontsize=fs1, labelpad=lp)

ax.tick_params(axis='both', which='major', labelsize=fs2, size=tick_size)

ax.legend(fontsize=fs1)
fig.tight_layout()

if save_path is not None:
    fig.savefig(save_path)