In [1]:
import time

import jax
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom
import optimistix as optx
import optax

# Run JAX in float64 throughout: the solver tolerances used later
# (rtol = atol = 1e-8) are finer than float32 resolution (~1e-7).
jax.config.update("jax_enable_x64", True)

In [2]:
# !pip install --upgrade optax

In [3]:
# Fixed seed so Restart & Run All reproduces the same network and data.
SEED = 0
IN = 10   # size of the code z that the MLP maps from
OUT = 50  # size of the MLP output, compared against the target x

# Give the MLP its own subkey. `key` is split again in the benchmark loop, and a
# JAX key must not be consumed twice (here: once for init, once as a split parent).
key, mlp_key = jrandom.split(jrandom.PRNGKey(SEED))
mlp = eqx.nn.MLP(IN, OUT, depth=3, width_size=OUT, key=mlp_key)



In [4]:
# Intentionally empty: a leftover interactive `?optx.fixed_point` help lookup was
# removed. The benchmark below solves with `optx.minimise`.

In [5]:
import pandas as pd
from datetime import datetime 
from tqdm import tqdm

@eqx.filter_jit
def reconstruction_loss(z, mlp_mask_x):
    """Masked RMSE between the MLP's output for ``z`` and the target ``x``.

    Args:
        z: input code passed to the MLP (the variable being optimised).
        mlp_mask_x: tuple ``(mlp, mask, x)``. Only entries where ``mask`` is
            true contribute to the mean; NaN squared errors are ignored
            because ``jnp.nanmean`` is used.

    Returns:
        ``(loss, loss)``: the scalar loss twice, so the function fits
        ``optx.minimise(..., has_aux=True)`` and the loss can be read back
        from the solution's ``aux``.

    NOTE(review): if ``mask`` selects no entries the mean is over an empty set
    (presumably NaN), and the square root has an unbounded gradient at a
    perfect fit (loss exactly 0).
    """
    model, observed, target = mlp_mask_x
    residual = target - model(z)
    rmse = jnp.sqrt(jnp.nanmean(residual ** 2, where=observed))
    return rmse, rmse

# Minimisers to benchmark, keyed by the label that ends up in the results table.
# Each is wrapped in optx.BestSoFarMinimiser so the best iterate seen is reported.
# Disabled alternative ('optax'):
#     optx.BestSoFarMinimiser(optx.OptaxMinimiser(optim=optax.adam(1e-3), rtol=1e-8, atol=1e-8))
solver = dict(
    cg=optx.BestSoFarMinimiser(optx.NonlinearCG(rtol=1e-8, atol=1e-8)),
    BFGS=optx.BestSoFarMinimiser(optx.BFGS(rtol=1e-8, atol=1e-8)),
)

@eqx.filter_jit
def solve(solver, z_init, args):
    """Minimise ``reconstruction_loss`` over ``z``, starting from ``z_init``.

    ``args`` is forwarded to the loss unchanged (the ``(mlp, mask, x)`` tuple in
    this notebook). The loss returns ``(value, aux)``, hence ``has_aux=True``;
    the final loss is then available as the solution's ``aux``. ``throw=True``
    makes a failed solve raise, and ``max_steps=None`` sets no iteration cap.

    NOTE(review): with no step cap, a solve that never meets its rtol/atol would
    presumably not terminate.
    """
    minimise_kwargs = dict(has_aux=True, max_steps=None, throw=True)
    return optx.minimise(
        reconstruction_loss,
        solver=solver,
        y0=z_init,
        args=args,
        **minimise_kwargs,
    )

# Benchmark: for each random problem instance, time every solver and record how
# much it reduced the loss relative to the random starting point.
# Rows are (iteration, solver label, solve time in seconds, loss reduction).
df = []
for i in tqdm(range(100)):
    # Independent subkeys per sampled quantity. Drawing z_init, mask and x from the
    # same key makes them deterministic functions of the same random bits (in
    # practice the bernoulli mask tracked the sign of x). Splitting also yields a
    # fresh parent `key` for the next iteration instead of reusing a consumed one.
    key, z_key, mask_key, x_key = jrandom.split(key, 4)
    z_init = jrandom.normal(z_key, shape=(IN, ))
    mask = jrandom.bernoulli(mask_key, shape=(OUT, ))
    x = jrandom.normal(x_key, shape=(OUT, ))
    args = (mlp, mask, x)
    init_loss = reconstruction_loss(z_init, args)[0]
    for solver_k, solver_v in solver.items():
        if i == 0:
            # Untimed warm-up: the first call of the jitted `solve` for each
            # solver includes tracing and compilation.
            jax.block_until_ready(solve(solver_v, z_init, args))

        start = time.perf_counter()
        # JAX dispatches asynchronously; block so the clock covers the actual solve.
        solution = jax.block_until_ready(solve(solver_v, z_init, args))
        eval_time = time.perf_counter() - start
        loss_reduction = init_loss - solution.aux

        df.append((i, solver_k, eval_time, loss_reduction.item()))


In [6]:
# Tabulate the tuples collected by the benchmark loop:
# one row per (iteration, solver) run; `duration` is in seconds.
df = pd.DataFrame(
    data=df,
    columns=['iteration', 'solver', 'duration', 'reduction'],
)

In [16]:
# Ensure `reduction` is a float64 column. The loop stores `.item()` values, so
# this is a defensive cast that returns a new frame rather than mutating in place.
df = df.assign(reduction=df['reduction'].astype(float))

In [18]:
import seaborn as sns

# Joint KDE of solve time vs. loss reduction (initial RMSE minus final RMSE),
# one density per solver. Up and to the left is better: faster and larger reduction.
grid = sns.displot(data=df, x="duration", y="reduction", hue="solver", kind="kde")
grid.set_axis_labels("solve time (s)", "loss reduction (initial − final RMSE)")
# Trailing `;` suppresses the FacetGrid repr; the figure still renders.
grid.set(title="Solver comparison over random problem instances");


In [9]:
!pip install seaborn