In [None]:
#| echo: false
# Notebook setup: auto-reload edited modules and render inline figures at retina resolution.
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [None]:
# Test np.where with vmap
import jax.numpy as np
from jax import random
from jax.scipy.stats import norm
import matplotlib.pyplot as plt
from jax import vmap


mus = np.array([80, 120])
sigmas = np.array([10, 20])

# One shared x-grid, stacked twice so that vmap pairs row i with (mus[i], sigmas[i]).
x = np.linspace(0, 300, 5000)
xs = np.tile(x, (2, 1))

# Row i of `ys` is the pdf of N(mus[i], sigmas[i]) evaluated on `x`.
ys = vmap(norm.pdf)(xs, mus, sigmas)
y1, y2 = ys

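# Goal: for each distribution, find the outermost x-values at which its height is
# at least 13.5% of its own maximum height. 0.135 is approximately e^(-2): a Gaussian
# drops to this fraction of its peak at mu +/- 2*sigma, so the two x-values are 4*sigma
# apart. This is a peak-width convention used in chromatography.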


# The most obvious solution is to use the following code:
def values_at_frac_max(xs, ys, fraction: float = 0.135):
    """Return the outermost x-values whose y-value reaches ``fraction`` of the peak.

    Plain-Python reference implementation. It branches in Python on array
    values, so it cannot be traced by ``jit``/``vmap``.

    :param xs: 1-D sequence of x-values, aligned element-wise with ``ys``.
    :param ys: 1-D array of y-values (must support ``.max()`` and slicing).
    :param fraction: Threshold as a fraction of ``ys.max()``.
    :returns: ``(min_x, max_x)``, the first and last ``x`` with
        ``y >= fraction * ys.max()``. Each is ``None`` if no point qualifies.
    """
    # Compute the threshold once; evaluating ``ys.max()`` inside the scan
    # would make it O(n^2) instead of O(n).
    threshold = fraction * ys.max()

    # Scan from the left for the first qualifying point...
    min_x = next((x for x, y in zip(xs, ys) if y >= threshold), None)
    # ...and from the right for the last one.
    max_x = next((x for x, y in zip(xs[::-1], ys[::-1]) if y >= threshold), None)
    return min_x, max_x


# Loop-based result for the second curve (mu=120, sigma=20).
values_at_frac_max(x, y2, 0.135)


In [None]:
def values_at_frac_max(xs, ys, fraction: float = 0.135):
    """Boolean-mask version: first and last x where ``ys >= fraction * max(ys)``.

    Fine when called eagerly. However, the number of matching indices depends on
    the data, so this cannot be traced by ``vmap``/``jit``.
    """
    # Single-argument ``where`` is ``nonzero``: indices of every qualifying sample.
    hits = np.nonzero(ys >= np.max(ys) * fraction)[0]
    first_idx, last_idx = hits.min(), hits.max()
    return xs[first_idx], xs[last_idx]


# Pass the 1-D grid `x`, not the stacked 2-D `xs`. Indexing `xs` with a sample
# index is out of bounds on axis 0, and JAX clamps it silently, returning a whole row.
values_at_frac_max(x, y1, 0.135), values_at_frac_max(x, y2, 0.135)


# But this doesn't work:
# vmap(values_at_frac_max)(xs, ys)


# This raises a ConcretizationTypeError: the single-argument form of `np.where` returns an
# array whose shape depends on the data, which `vmap` cannot trace. Since we only need the
# first and last matching indices, we can compute them with fixed-shape operations instead.
def values_at_frac_max(xs, ys, fraction: float = 0.135):
    """vmap/jit-compatible: first and last x where ``ys >= fraction * max(ys)``.

    Uses only fixed-shape operations, so it can be batched with ``vmap``.

    :param xs: 1-D array of x-values, aligned element-wise with ``ys``.
    :param ys: 1-D array of y-values.
    :param fraction: Threshold as a fraction of ``max(ys)``.
    :returns: Array ``[x_first, x_last]``. If no sample qualifies, both
        entries fall back to ``xs[0]``.
    """
    mask = ys >= np.max(ys) * fraction
    n = mask.shape[0]
    # argmax of a boolean mask is the index of its first True.
    first = np.argmax(mask)
    # The first True of the reversed mask is the last True of the original.
    last = n - 1 - np.argmax(mask[::-1])
    # Keep the index-0 fallback for an all-False mask.
    last = np.where(mask.any(), last, 0)
    return np.array([xs[first], xs[last]])


# Batch over curves: row i holds (left x, right x) for curve i.
values_wanted = vmap(values_at_frac_max, in_axes=0)(xs, ys)
values_wanted


In [None]:
tuple(values_at_frac_max(x, curve, 0.135) for curve in (y1, y2))


In [None]:
import seaborn as sns

fig, ax = plt.subplots()
for curve, colour, label in ((y1, "red", "curve 1"), (y2, "blue", "curve 2")):
    ax.plot(x, curve, color=colour, label=label)
ax.legend()
sns.despine(ax=ax)


In [None]:
plt.plot(x, y1, color="red", label="curve 1")
plt.plot(x, y2, color="blue", label="curve 2")

# Dashed segment at 13.5% of each curve's peak, spanning the x-range in `values_wanted`.
for curve_idx, (curve, colour) in enumerate([(y1, "red"), (y2, "blue")]):
    left, right = values_wanted[curve_idx]
    plt.hlines(
        y=0.135 * curve.max(),
        xmin=float(left),
        xmax=float(right),
        color=colour,
        linestyle="--",
    )
plt.legend()
sns.despine()


In [None]:
from score_models.losses import score_matching_loss
from score_models.models.feedforward import FeedForwardModel
from score_models.models.gaussian import GaussianModel
from score_models.data import make_gaussian
from score_models.training import fit, default_optimizer, adam_optimizer
import matplotlib.pyplot as plt

from jax import random, numpy as np, vmap, jacfwd
import optax

data, model = make_gaussian(), GaussianModel()
# Train with Adam for 600 steps; `history` is plotted below.
fit_result = fit(
    model,
    data,
    score_matching_loss,
    optimizer=adam_optimizer(),
    steps=600,
)
model, history = fit_result
plt.plot(history)


In [None]:
data = make_gaussian()
model = FeedForwardModel()
# Shapes of the batched model output and of the batched Jacobian.
batched_out_shape = vmap(model)(data).shape
batched_jac_shape = vmap(jacfwd(model))(data).shape
print(batched_out_shape, batched_jac_shape)

score_matching_loss(model_func=model, batch=data)


In [None]:
import equinox as eqx
from jax import grad, nn

# Jitted gradient of `score_matching_loss` w.r.t. its first argument (Equinox filters to arrays).
dloss = eqx.filter_jit(eqx.filter_grad(score_matching_loss))
# optimizer = optax.chain(
#     optax.clip(0.01),
#     optax.sgd(learning_rate=5e-3),
# )
optimizer = optax.adabelief(learning_rate=1e-3)
# model = GaussianModel()
# 1 -> 1024 -> 1 network with a ReLU hidden layer.
# The activation must be passed uncalled: `nn.relu()` raises a TypeError (no input).
# Wrapping it in `eqx.nn.Lambda` keeps every layer of the Sequential an Equinox Module.
model = eqx.nn.Sequential(
    [
        eqx.nn.Linear(in_features=1, out_features=1024, key=random.PRNGKey(45)),
        eqx.nn.Lambda(nn.relu),
        eqx.nn.Linear(in_features=1024, out_features=1, key=random.PRNGKey(39)),
    ]
)
# Initialise optimizer state on array leaves only; the activation function is a
# non-array leaf that the optimizer cannot build state for.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
# print(score_matching_loss(model_func=model, batch=data))
# model = eqx.nn.MLP(
#             in_size=1,
#             out_size=1,
#             width_size=1024,
#             depth=1,
#             key=random.PRNGKey(45),
#         )

# opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
# grads = dloss(model, data)
# updates, opt_state = optimizer.update(grads, opt_state)
# model = eqx.apply_updates(model, updates)
# print(score_matching_loss(model_func=model, batch=data))


In [None]:
# FIXME(review): `updates` is only assigned in commented-out code in an earlier cell,
# so this cell raises NameError under Restart & Run All.
dir(updates.submodule)


In [None]:
# FIXME(review): `updates` is only assigned in commented-out code in an earlier cell,
# so this cell raises NameError under Restart & Run All.
updates.mu


In [None]:
# DEBUG
# FIXME(review): `grads` is only assigned in commented-out code in an earlier cell,
# so this cell raises NameError under Restart & Run All.
grads.mu


In [None]:
data = make_gaussian()
model = GaussianModel()
# Shapes of the batched model output and of the batched Jacobian.
batched_out_shape = vmap(model)(data).shape
batched_jac_shape = vmap(jacfwd(model))(data).shape
print(batched_out_shape, batched_jac_shape)


In [None]:
import optax
import equinox as eqx
from jax.example_libraries import stax
from jax import random, nn

optimizer = optax.adabelief(learning_rate=1e-3)
# model = GaussianModel()
model = eqx.nn.Sequential(
    [
        eqx.nn.Linear(in_features=1, out_features=1024, key=random.PRNGKey(45)),
        # The activation is a plain function, not a Module; wrap it as a layer.
        eqx.nn.Lambda(nn.relu),
        eqx.nn.Linear(in_features=1024, out_features=1, key=random.PRNGKey(39)),
    ]
)
# `optimizer.init(model)` fails once the activation is part of the model, because it
# is a non-array leaf. Initialise on the array leaves only.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
