In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch

from ddm_dynamical.scheduler import BinarizedScheduler, EDMSamplingScheduler

In [None]:
# Apply the project's custom matplotlib styles; later entries override earlier ones,
# so "wiley" settings take precedence over "paper".
plt.style.use(["paper", "wiley"])

In [None]:
# Trained diffusion checkpoint; only its "state_dict" entry is used below.
# map_location="cpu" keeps every tensor on the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
CKPT_PATH = "../data/models/diffusion/diff_l_exp/best.ckpt"
ckpt = torch.load(CKPT_PATH, map_location="cpu")

In [None]:
# Scheduler used during training (its adaptive state is restored from the checkpoint next)
# and the EDM scheduler used at sampling time. gamma_* bound the log signal-to-noise range.
train_scheduler = BinarizedScheduler(
    gamma_min=-20,
    gamma_max=20,
)
test_scheduler = EDMSamplingScheduler(
    gamma_min=-10,
    gamma_max=15,
)

In [None]:
# Copy the adaptive bin state saved under the "scheduler." prefix of the checkpoint's
# state_dict onto the freshly constructed training scheduler (same order as before).
for bin_attr in ("_bin_times", "bin_values", "bin_limits"):
    setattr(train_scheduler, bin_attr, ckpt["state_dict"]["scheduler." + bin_attr])

In [None]:
# Fine pseudo-time grid for the continuous curves; coarse 21-point grid for the scatter markers.
time_range = torch.linspace(0, 1, 1001)
sample_range = torch.linspace(0, 1, 21)

# Evaluation only: without autograd the returned tensors can be handed straight to matplotlib.
with torch.no_grad():
    train_gamma, test_gamma = train_scheduler(time_range), test_scheduler(time_range)
    sample_gamma = test_scheduler(sample_range)

In [None]:
fig, ax = plt.subplots(figsize=(3, 2))
ax.grid(ls="dotted", lw=0.5, alpha=0.5)

# Continuous log-SNR schedules over pseudo time plus the discrete points actually used for
# sampling; the high zorder draws the data above the grid.
ax.plot(time_range, train_gamma, label="Adaptive train scheduler", c="#89CAFF", ls="--", zorder=99)
ax.plot(time_range, test_gamma, label="EDM sampling scheduler", c="#FF5A54", zorder=99)
ax.scatter(sample_range, sample_gamma, label="Used during sampling", c="#FF5A54", marker="x", zorder=99)

# Small padding so markers at tau=0 and tau=1 are not cut by the axes frame.
ax.set_xlim(-0.01, 1.01)
ax.set_xlabel(r"Pseudo time $\tau$")

ax.set_ylim(-15.5, 21)
ax.set_ylabel(r"Log signal-to-noise $(\lambda_{\tau})$")
ax.legend(loc="upper right", bbox_to_anchor=(1.05, 1.05))

# savefig does not create missing directories; ensure the output folder exists so the
# notebook also runs on a fresh checkout.
fig_path = Path("figures") / "fig_app_a2_noise_scheduling.png"
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path, dpi=300)