In [14]:
import os

import flax.linen as nn
import jax
import jax.numpy as jnp
import orbax.checkpoint

import src.data_generate_sde.sde_ornstein_uhlenbeck as ou
import src.models
from src.training import utils

Set up the SDE, model and training state, then load the trained checkpoint

In [15]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

# SDE configuration. These values are baked into the checkpoint directory name
# below and are also part of the restore target, so they presumably must match
# the run that produced the checkpoint.
sde = {"x0": jnp.ones(shape=(8,)), "N": 100, "dim": 8, "T": 1.0, "y": jnp.ones(shape=(8,))}
y = sde["y"]
dim = sde["dim"]
T = sde["T"]

# Checkpoint root: can be overridden with the DOOBS_CHECKPOINT_ROOT environment
# variable so the notebook is not tied to one machine. The default is the
# original author's local path. Relative overrides are resolved against the CWD.
checkpoint_root = os.path.abspath(
    os.environ.get(
        "DOOBS_CHECKPOINT_ROOT",
        "/Users/libbybaker/Documents/Python/doobs-score-project/doobs_score_matching/checkpoints",
    )
)
# NOTE: `y` is an array, so its str() (e.g. "[1. 1. ...]") becomes part of the
# directory name -- presumably mirroring how the training run named it; confirm.
checkpoint_path = f"{checkpoint_root}/ou/fixed_y_{y}_d_{dim}_T_{T}"

# Drift/diffusion of the OU process and the training helpers built from them.
# score_fn, train_step and data_fn are not used by the cells below; they appear
# to be kept to mirror the training setup.
drift, diffusion = ou.vector_fields()
score_fn = utils.get_score(drift=drift, diffusion=diffusion)
train_step = utils.create_train_step_reverse(score_fn)
data_fn = ou.data_reverse(y, T, sde["N"])

# Architecture hyper-parameters forwarded to ScoreMLP. They are also stored in
# the restore target, so they presumably must match the trained checkpoint.
network = dict(
    output_dim=sde["dim"],
    time_embedding_dim=16,
    init_embedding_dim=16,
    activation=nn.leaky_relu,
    encoder_layer_dims=[16],
    decoder_layer_dims=[128, 128],
)

# Training hyper-parameters; of these, the cells below read only
# batch_size and lr.
training = dict(
    batch_size=1000,
    epochs_per_load=1,
    lr=0.01,
    num_reloads=1000,
    load_size=1000,
)

# Build an untrained TrainState with the expected structure, then restore the
# checkpoint into it.
model = src.models.ScoreMLP(**network)

num_samples = sde["N"] * training["batch_size"]
# Dummy inputs passed to create_train_state (presumably only to trace parameter
# shapes). Despite their names, these are arrays, not shape tuples.
x_shape = jnp.empty((num_samples, sde["dim"]))
t_shape = jnp.empty((num_samples, 1))

empty_train_state = utils.create_train_state(
    model,
    jax.random.PRNGKey(0),
    training["lr"],
    x_shape,
    t_shape,
)
target = dict(state=empty_train_state, sde=sde, network=network, training=training)

restored = orbax_checkpointer.restore(checkpoint_path, target)
train_state = restored["state"]
trained_score = utils.trained_score(train_state)


Compute the score error for different dimensions $d$ and terminal times $T$ (so far only $d = 8$, $T = 1$)

In [16]:
def error(ts, true_score, trained_score, x0=None, T=None, y=None):
    """Mean absolute difference between a reference score and a trained score.

    Both score functions are evaluated at the same state ``x0`` for every time
    in ``ts`` (vectorised over the leading axis of ``ts`` with ``jax.vmap``).
    The element-wise absolute difference is then averaged over all times and
    state components.

    Args:
        ts: array of evaluation times, mapped over its leading axis.
        true_score: callable ``true_score(t, x0, T, y)`` giving the reference score.
        trained_score: callable ``trained_score(t, x0)`` giving the learned score.
        x0: state at which both scores are evaluated. Defaults to ``sde["x0"]``.
        T: value passed as the third argument of ``true_score``.
            Defaults to ``sde["T"]``.
        y: value passed as the fourth argument of ``true_score``.
            Defaults to ``sde["y"]``.

    Returns:
        Scalar array ``mean(|true - trained|)``.
    """
    # Fall back to the notebook-level config so existing calls keep working,
    # while allowing other (d, T, y) settings to be evaluated explicitly.
    x0 = jnp.asarray(sde["x0"] if x0 is None else x0)
    T = sde["T"] if T is None else T
    y = sde["y"] if y is None else y

    # Only the time argument is batched; the remaining arguments are broadcast.
    true = jax.vmap(true_score, in_axes=(0, None, None, None))(ts, x0, T, y)
    trained = jax.vmap(trained_score, in_axes=(0, None))(ts, x0)
    return jnp.mean(jnp.abs(true - trained))

In [22]:
from src.data_generate_sde import sde_ornstein_uhlenbeck, time

# Compare the learned score against the analytic OU score on the simulation
# time grid. The final grid point is dropped -- presumably t = T, where the
# true score may not be finite; confirm.
true_score = sde_ornstein_uhlenbeck.score
ts = time.grid(0, sde["T"], sde["N"])
error_d_8_T_1 = error(ts[:-1], true_score=true_score, trained_score=trained_score)

In [23]:
# Mean absolute score error for the configuration loaded above (dim=8, T=1.0).
print(error_d_8_T_1)

1.1553881


In [None]:
from tueplots import bundles
import matplotlib.pyplot as plt

# Apply the tueplots NeurIPS 2023 style to every figure drawn after this cell.
bundle = bundles.neurips2023()
plt.rcParams.update(bundle)

# Same analytic OU score as in the error cell, re-bound here for the plotting cells.
true_score = sde_ornstein_uhlenbeck.score

In [None]:
# Presumably plots the true vs. trained score over a range of y values and saves it as a PDF.
# NOTE(review): `plot_score_variable_y` is neither defined nor imported anywhere in
# this notebook, so this cell raises NameError on a fresh kernel -- TODO import it
# (or define it in an earlier cell). The meaning of the four numeric bounds is not
# visible here; the filename suggests one pair is the y range -- confirm.
fig, axs = plot_score_variable_y(true_score, trained_score, -1, 1, -1, 1)
plt.savefig('ou_score_varied_y_-1.0_to_1.0.pdf')
plt.show()

In [None]:
# Presumably plots the true-vs-trained score error as y varies.
# NOTE(review): `plot_score_error_variable_y` is neither defined nor imported in this
# notebook (NameError on a fresh kernel) -- TODO import it. The meaning of the bounds
# (-3, 3, -1, 1) is not visible here -- confirm against its definition.
plot_score_error_variable_y(true_score, trained_score, -3, 3, -1, 1)