## Hands-on Example: Fluxonium Qubit

### Energy spectrum
There are two ways to pass the Fluxonium parameters (EC, EJ, EL):

- Directly set parameters when creating a class instance
- Use Haiku to manage model parameters


Haiku's model parameter management:

- `supergrad.Helper.ls_params` returns a nested dictionary mapping parameter names to their values.
- Parameters are passed as the first argument when calling a method of the `Helper` subclass.

In [1]:
import numpy as np
import jax
import jax.numpy as jnp

import supergrad
from supergrad.quantum_system import Fluxonium


class ExploreFluxonium(supergrad.Helper):

    def _init_quantum_system(self):
        self.fluxonium = Fluxonium(phiext=0, phi_max=5 * np.pi)

    def energy_spectrum(self, phi):
        self.fluxonium.phiext = phi * 2 * jnp.pi  # modify phiext, default 0
        return self.fluxonium.eigenenergies()


explore = ExploreFluxonium()
explore.ls_params()


{'fluxonium': {'ec': Array(1., dtype=float32),
  'ej': Array(1., dtype=float32),
  'el': Array(1., dtype=float32)}}

In [2]:
params = {
    'fluxonium': {
        'ec': jnp.array(1.68),
        'ej': jnp.array(3.5),
        'el': jnp.array(0.5)
    }
}
# each parameter should be a float
explore.energy_spectrum(params, 0)


Array([-0.24092618,  5.06891579,  7.16328215,  8.28020319, 10.81742776,
       13.5851526 , 16.36989903, 19.3794768 , 22.42063552, 25.39180377],      dtype=float64)
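
The requirement that every parameter be a float can be enforced before calling the model. The following is a minimal sketch in plain JAX, not part of SuperGrad: the helper name `as_float_params` and the integer values in `raw_params` are hypothetical and chosen only for illustration. It treats the nested dictionary as a pytree and casts every leaf to the default floating-point dtype.

```python
import jax
import jax.numpy as jnp


def as_float_params(params):
    """Cast every leaf of a nested parameter dict to a floating-point JAX array.

    Hypothetical helper for illustration; it is not a SuperGrad function.
    """
    # Default float dtype: float32, or float64 when 64-bit mode is enabled
    float_dtype = jnp.result_type(float)
    return jax.tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf, dtype=float_dtype), params)


# Integer leaves, as they might be typed by hand (illustrative values)
raw_params = {'fluxonium': {'ec': 2, 'ej': 4, 'el': 1}}
float_params = as_float_params(raw_params)
```

The nesting and key names are unchanged by `tree_map`, so the result can be used wherever the original dictionary was expected.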

For a fluxonium, one can vary the external flux bias `phiext` and recompute the
energy spectrum. Because `energy_spectrum` multiplies `phi` by 2π, `phi = 0.5` sets `phiext = π`.

In [3]:
explore.energy_spectrum(params, 0.5)


Array([ 1.3041627 ,  2.09823523,  6.19190726,  9.14847985, 12.57031594,
       15.37689449, 17.40065635, 19.16545642, 21.37778417, 23.89876961],      dtype=float64)

Below we show how JAX can transform the function above.

### Auto-vectorization with `vmap()`
JAX provides the `vmap()` transformation, the vectorizing map. It maps a function
along array axes (here, the values of `phi`), but instead of keeping the loop on the
outside, it pushes the loop down into the function's primitive operations for
better performance.

In [4]:
phi_list = np.linspace(0, 1, 20)
vmap_energy_spectrum = jax.vmap(explore.energy_spectrum, in_axes=(None, 0))
%timeit vmap_energy_spectrum(params, phi_list).block_until_ready()


1.13 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
phi_list = np.linspace(0, 1, 20)
def forloop_energy_spectrum(params, phi_list):
    spectrum_list = []
    for phi in phi_list:
        spectrum_list.append(explore.energy_spectrum(params, phi))
    return jnp.array(spectrum_list)

%timeit forloop_energy_spectrum(params, phi_list)


3.08 s ± 272 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
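
The vectorized and looped versions are expected to return the same values; only the way the work is dispatched differs. The sketch below checks this on a toy problem in plain JAX. The function `toy_spectrum`, its 2x2 matrix, and the parameter values are assumptions made for illustration and do not represent the Fluxonium Hamiltonian.

```python
import jax
import jax.numpy as jnp


def toy_spectrum(params, phi):
    """Eigenvalues of a hypothetical 2x2 symmetric matrix that depends on phi."""
    coupling = params['ej'] * jnp.cos(2 * jnp.pi * phi)
    matrix = jnp.array([[params['ec'], coupling],
                        [coupling, -params['ec']]])
    return jnp.linalg.eigvalsh(matrix)  # eigenvalues in ascending order


toy_params = {'ec': jnp.array(1.0), 'ej': jnp.array(3.0)}
phi_values = jnp.linspace(0.0, 1.0, 5)

# in_axes=(None, 0): keep params fixed, map over the leading axis of phi_values
batched = jax.vmap(toy_spectrum, in_axes=(None, 0))(toy_params, phi_values)

# Reference result from an explicit Python loop
looped = jnp.stack([toy_spectrum(toy_params, phi) for phi in phi_values])

same_result = bool(jnp.allclose(batched, looped, atol=1e-5))
```

Here `in_axes=(None, 0)` plays the same role as in `vmap_energy_spectrum`: the parameter dictionary is shared across all evaluations, while the second argument is mapped element by element.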


### Use `jit()` to speed up functions
JAX runs transparently on the CPU or GPU. In the examples above, however, JAX
dispatches kernels one operation at a time. If we have a sequence of operations (for
example, in parameter optimization), we can use `jax.jit` to compile them
together using `XLA`.
We can speed up `vmap_energy_spectrum` with `jax.jit`: the function is JIT-compiled
(just-in-time) the first time it is called, and the compiled version is cached for later calls.

In [6]:
jit_energy_spectrum = jax.jit(vmap_energy_spectrum)
spec_out = jit_energy_spectrum(params, phi_list)


In [7]:
%timeit jit_energy_spectrum(params, phi_list).block_until_ready()


973 ms ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
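
The compile-once, reuse-afterwards behaviour of `jax.jit` can be observed directly. The sketch below uses a toy function in plain JAX; the names `sum_of_squares` and `trace_log` are hypothetical. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so the length of `trace_log` counts how many times the function was traced.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time the function is traced


def sum_of_squares(x):
    # Python side effect: executed during tracing, not on cached calls
    trace_log.append(x.shape)
    return jnp.sum(x ** 2)


jit_sum_of_squares = jax.jit(sum_of_squares)

first = jit_sum_of_squares(jnp.arange(3.0))       # traced and compiled here
second = jit_sum_of_squares(jnp.arange(3.0) + 1)  # same shape and dtype: cached
count_after_same_shape = len(trace_log)

third = jit_sum_of_squares(jnp.arange(4.0))       # new shape: traced again
count_after_new_shape = len(trace_log)
```

This is also why the first call to a jitted function is best kept out of the timed region, as done above by calling `jit_energy_spectrum` once before `%timeit`.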
