In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import optimistix as optx
import optax

In [332]:
jax.grad(loss)(params_0, 5)

Array([-1.7215826e-06,  2.4857254e+01], dtype=float32)

In [756]:
def potential(phi, m2, beta):
    return -m2/4*phi**2 + beta/4*jnp.log(phi)*phi**4

def loss(params, args):
    eps = 1e-8
    m2, beta = params
    def potential_mask(phi):
        return potential(phi, m2, beta)
    def potential_grad(phi, args):
        return (jax.grad(potential_mask)(phi))**2/phi**2 #phi**2 is here to remove the minimum at phi=0 
    eps = 1e-4
    solver=optx.OptaxMinimiser(optax.adabelief(learning_rate=1e-3), rtol=1e-8, atol=1e-8)
    sol1 = optx.minimise(potential_grad, y0=jnp.array(0.1), solver=solver, max_steps=100000, throw=False)
    sol2 = optx.minimise(potential_grad, y0=1., solver=solver, max_steps=100000, throw=False)
    return jnp.abs(sol1.value - sol2.value - eps)

def critical_params(params_0):
    meta_solver = optx.OptaxMinimiser(optax.adam(learning_rate=1e-6), rtol=1e-8, atol=1e-8)
    param_sol = optx.minimise(loss, y0=params_0, solver=meta_solver, max_steps = 1000000, throw=False)
    return param_sol.value

beta = jnp.linspace(-0.01, -1, 3)
m2 = -beta/100.
params_0 = jnp.stack([m2, beta], axis=1)
results = jax.vmap(critical_params)(params_0)

In [757]:
results

Array([[ 0.00175084, -0.00760542],
       [ 0.0924923 , -0.41427147],
       [ 0.18393622, -0.8241027 ]], dtype=float32)

In [758]:
jnp.log(-results[:,0]/results[:,1])

Array([-1.4687648, -1.4993961, -1.4997061], dtype=float32)

In [723]:
loss(critical_params(jnp.array([0.01, -0.5])), 5)

Array(0.65179014, dtype=float32)

In [697]:
jnp.log(-param_sol.value[1]/param_sol.value[0]) #log(beta/m^2) should be 3/2 at criticality

Array(1.4964715, dtype=float32)

In [698]:
param_sol.value[0]

Array(0.00180499, dtype=float32)