In [1]:
import pandas as pd

locs_df = pd.read_csv("data/simulated_data_spcap.csv")

In [None]:
import jax
import jax.numpy as jnp

locs = jnp.asarray(locs_df.values[:, 0:3], dtype=jnp.float32)
stddev = jnp.asarray(locs_df.values[:, 3:], dtype=jnp.float32)
half_tau  = (stddev ** -2) / 2.0
log_const = -0.5 * jnp.log(jnp.prod(stddev**2, axis=1))

In [8]:
log_const.shape

(44, 3)

In [3]:
import locmofitpy2

key = jax.random.PRNGKey(1)
sc = locmofitpy2.SphericalCap.init(key)
# trainable0, static0 = locmofitpy2.partition_with_freeze(sc, freeze=("c",))
trainable0, static0 = locmofitpy2.partition_with_freeze(sc, freeze=())

In [4]:
import optax

loss = locmofitpy2.loss(static0, locs, half_tau, log_const)

optimizer = optax.lbfgs()
opt_state = optimizer.init(trainable0)

value_and_grad_fun = optax.value_and_grad_from_state(loss)


@jax.jit
def solve_lbfgs(trainable, opt_state, max_iter=50, tol=1e-6):
    def step(carry):
        trainable, state = carry
        value, grad = value_and_grad_fun(trainable, state=state)
        updates, state = optimizer.update(
            grad, state, trainable, value=value, grad=grad, value_fn=loss
        )
        trainable = optax.apply_updates(trainable, updates)
        return trainable, state

    def cond(carry):
        _, state = carry
        k = optax.tree.get(state, "count")
        g = optax.tree.get(state, "grad")
        return (k == 0) | ((k < max_iter) & (optax.tree.norm(g) >= tol))

    return jax.lax.while_loop(cond, step, (trainable, opt_state))


In [7]:
# %%timeit

trainable_opt, opt_state = solve_lbfgs(trainable0, opt_state, max_iter=800, tol=1e-10)

model_opt = locmofitpy2.combine(trainable_opt, static0)
final_loss = loss(trainable_opt)
final_loss.block_until_ready()

TypeError: sub got incompatible shapes for broadcasting: (44, 3), (100, 44).

In [None]:
print(
    float(model_opt.x),
    float(model_opt.y),
    float(model_opt.z),
    float(model_opt.c),
    float(model_opt.vartheta),
    float(model_opt.phi0),
    float(model_opt.theta),
    float(model_opt.phi),
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

ground_truth = pd.read_csv("data/ground_truth_spcap.csv")

positions = model_opt()

# If running on GPU/TPU, ensure compute finished before transferring for plotting
positions.block_until_ready()

positions_np = np.array(positions)  # converts JAX array -> NumPy array on host

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.set_box_aspect((1, 1, 1))
ax.scatter(locs_df["x"], locs_df["y"], locs_df["z"], c="gray", s=8)  # type: ignore[arg-type]
ax.scatter(ground_truth["x"], ground_truth["y"], ground_truth["z"])  # type: ignore[arg-type]
ax.scatter(positions_np[:, 0], positions_np[:, 1], positions_np[:, 2])  # type: ignore[arg-type]

plt.show()
