In [192]:
import jax
from cft import *
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm
import optax


In [193]:
def boots(beta, c, deltas, opt_fn, opt_state, steps=100):
    """Fit deltas by penalising the mismatch of the partition function under beta -> 1/beta.

    The loss is the mean squared difference between
    ``reduced_partition_function_spinless(deltas, beta, c)`` and the same
    function evaluated at ``1/beta``. It is minimised with the updates
    produced by ``opt_fn``.

    Params:
        beta - sample points for beta, one per row of ``deltas``
               (vmapped along axis 0, so ``beta.shape[0]`` must equal ``deltas.shape[0]``)
        c - central charge, shared by every sample (not vmapped)
        deltas - initialised deltas; the array being optimised
        opt_fn - optimizer update function with the optax ``update`` signature,
                 ``opt_fn(grads, opt_state) -> (updates, new_opt_state)``,
                 e.g. ``optax.adam(lr).update``
        opt_state - optimizer state matching ``opt_fn``,
                    e.g. ``optax.adam(lr).init(deltas)``
        steps - number of optimisation steps (default 100)

    Returns:
        losses - 1-D array holding the loss evaluated before each update,
                 shape ``(steps,)``; empty when ``steps == 0``
        deltas - deltas after ``steps`` updates
        opt_state - final optimizer state

    NOTE(review): nothing constrains deltas to stay non-negative; confirm
    whether the physics requires a bound (e.g. clipping after each update).
    """
    # Evaluate the partition function for every (delta row, beta) pair; c is broadcast.
    batched_z = jax.vmap(reduced_partition_function_spinless, in_axes=(0, 0, None), out_axes=0)

    def loss_function(deltas):
        identity = batched_z(deltas, beta, c)
        transformed = batched_z(deltas, 1 / beta, c)
        return jnp.mean((identity - transformed) ** 2)  # mean squared loss

    # Build and compile the value-and-gradient function once, rather than
    # re-wrapping loss_function with jax.value_and_grad on every step.
    loss_and_grad = jax.jit(jax.value_and_grad(loss_function))

    losses = []
    for _ in tqdm(range(steps)):
        loss, grads = loss_and_grad(deltas)
        updates, opt_state = opt_fn(grads, opt_state)
        # Rebind instead of `+=`, so a caller-supplied mutable (e.g. NumPy)
        # array is never modified in place.
        deltas = deltas + updates
        losses.append(loss)

    # jnp.stack raises on an empty list, so handle steps == 0 explicitly.
    stacked_losses = jnp.stack(losses) if losses else jnp.zeros((0,))
    return stacked_losses, deltas, opt_state


In [203]:
# Hyper-parameters for the learning-rate sweep.
batch_size = 128
step = 1000  # optimisation steps per learning rate
learning_rates = [1e-5, 1e-4, 5e-3, 1e-2, 5e-2, 1e-1, 0.2]

# Fixed seed, so every learning rate starts from the same beta samples and deltas.
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 2)
# beta ~ N(1, 0.25^2), centred on the self-dual point beta = 1.
# NOTE(review): normal samples can in principle be <= 0, where 1/beta is ill-defined.
beta = jax.random.normal(keys[0], (batch_size,)) * 0.25 + 1
deltas = jax.random.uniform(keys[1], (batch_size, 1)) * 5  # deltas ~ U[0, 5)
c = 1

steps = jnp.linspace(1, step, step)  # x-axis: step index 1..step

# Compare convergence of Adam across learning rates on one log-scaled plot.
fig, ax = plt.subplots()
for lr in learning_rates:
    adam = optax.adam(learning_rate=lr)
    losses, learned_deltas, _ = boots(
        beta, c, deltas, opt_fn=adam.update, opt_state=adam.init(deltas), steps=step
    )
    ax.plot(steps, losses, label=f"{lr}")

ax.set_yscale("log")
ax.set(xlabel="step", ylabel="loss (MSE)", title=f"Adam learning-rate sweep (c = {c})")
ax.legend(title="learning rate")
plt.show()

100%|██████████| 1000/1000 [00:02<00:00, 382.24it/s]
100%|██████████| 1000/1000 [00:02<00:00, 361.78it/s]