# Precomputing Bragg pulses

Here we precompute the Bragg pulses needed to simulate an SCI.

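First, define a function that calls `gbragg` in a parallel loop over a grid of `omegas` and `deltas` and writes the results to an HDF5 file. Also define a test function that compares values interpolated from the stored grid against direct `gbragg` evaluations at random points.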

In [None]:
from mwave.integrate import make_kvec, make_phi, gbragg, pops_vs_time
from mwave.precompute import write_bragg_precompute, read_bragg_precompute
import numpy as np
from scipy.interpolate import RegularGridInterpolator as RGI
from matplotlib import pyplot as plt
from tqdm import tqdm
from joblib import Parallel, delayed
import h5py

# Define simulation parameters
n_init = 0     # base order from which the (n0, nf) pairs in the cells below are built
n_bragg = 5    # shifts the delta grid by 4*n_bragg and enters omega_mod = 8*(N_bloch + n_bragg)
N_bloch = 10   # only used for the multifrequency pulses (omega_mod and the stored dataset key)
sigma = 0.247  # passed to gbragg both as `sigma` and, scaled by 6, as its third argument
               # NOTE(review): presumably the pulse width, with 6*sigma the integration span - confirm

# Define output file name. It is derived from sigma so that the file always
# matches the value used in the computation (evaluates to 'sig0.247.h5' here).
fname = f'sig{sigma}.h5'

# Define function to perform precomputation
def precompute_gbragg(n0, nf, n_bragg, N_bloch=None):
    """Tabulate the final `gbragg` state over an (omega, delta) grid and save it.

    Each grid point is solved independently in a joblib worker pool; the last
    column of the solution's ``y`` array is kept. The resulting table is written
    to the module-level ``fname`` through ``write_bragg_precompute``.

    Parameters
    ----------
    n0, nf
        Forwarded to ``make_kvec``; ``n0`` is also used to build the initial
        state with ``make_phi``.
    n_bragg
        Shifts the delta axis by ``4*n_bragg`` and enters ``omega_mod``.
    N_bloch : optional
        When not None, every ``gbragg`` call receives
        ``omega_mod=8*(N_bloch+n_bragg)``; otherwise ``omega_mod`` is not passed.

    Notes
    -----
    Reads the module-level ``sigma`` and ``fname``. Returns nothing.
    """
    # Basis vector for this (n0, nf) pair; the two returned indices are not needed here
    kvec, _, _ = make_kvec(n0, nf)

    # Grid axes
    omega_axis = np.linspace(0, 43, 800)
    delta_axis = np.linspace(-2, 2, 400) + 4*n_bragg
    n_omega, n_delta = len(omega_axis), len(delta_axis)

    # Output table, NaN-filled so any unwritten entry is easy to spot
    table = np.full((n_omega, n_delta, len(kvec)), np.nan, dtype=np.complex128)

    # The multifrequency case differs from the plain one only by this keyword
    extra = {} if N_bloch is None else {'omega_mod': 8*(N_bloch + n_bragg)}

    def solve_point(flat):
        # Row-major flat index -> (omega index, delta index)
        row, col = divmod(flat, n_delta)
        sol = gbragg(kvec, make_phi(kvec, n0), 6*sigma, delta_axis[col], omega_axis[row], sigma, **extra)
        return sol.y[:, -1]

    # Fan the grid points out over all available cores
    results = Parallel(n_jobs=-1)(
        delayed(solve_point)(flat) for flat in tqdm(range(n_omega * n_delta))
    )

    # np.ndindex walks the grid in the same row-major order as the flat task index
    for (row, col), final_state in zip(np.ndindex(n_omega, n_delta), results):
        table[row, col, :] = final_state

    # Write output to HDF5 file
    write_bragg_precompute(
        fname, table, kvec,
        ((omega_axis, 'omegas'), (delta_axis, 'deltas')),
        n0, nf, n_bragg, N_bloch=N_bloch,
    )

def test_precomp_grid(n0, nf, multifreq=False):
    """Compare the stored grid, cubically interpolated, against direct `gbragg` runs.

    Loads the table for ``(n0, nf)`` from the module-level ``fname``, draws 4000
    uniformly random (omega, delta) points inside the stored grid, evaluates
    both the interpolant and ``gbragg`` there, and reports the difference
    summed over the last (state) axis as scatter plots, histograms and
    mean/std printouts.

    Parameters
    ----------
    n0, nf
        Select the stored dataset; ``n0`` also seeds ``make_phi``.
    multifreq : bool
        When true, the dataset is looked up with ``N_bloch`` and ``gbragg`` is
        called with ``omega_mod=8*(N_bloch+n_bragg)``; otherwise ``None`` is
        used for the lookup and ``omega_mod`` is not passed.

    Notes
    -----
    Reads the module-level ``fname``, ``sigma``, ``n_bragg`` and ``N_bloch``.
    """
    # Load the stored table and its axes
    stored_bloch = N_bloch if multifreq else None
    phi_grid, kvec, grid = read_bragg_precompute(fname, n0, nf, n_bragg, stored_bloch)
    omega_axis = grid[0][0]
    delta_axis = grid[1][0]

    # Cubic interpolant over the two grid axes
    interpolator = RGI([omega_axis, delta_axis], phi_grid, method='cubic')

    # Extra keyword used only for multifrequency pulses
    extra = {'omega_mod': 8*(N_bloch + n_bragg)} if multifreq else {}

    def direct_gbragg(omegas, deltas):
        # Evaluate gbragg point by point for paired omega/delta samples
        if len(omegas) != len(deltas):
            raise ValueError('omegas and deltas must have the same length')
        states = np.full((len(omegas), len(kvec)), np.nan, dtype=np.complex128)
        for row, (omega, delta) in enumerate(zip(tqdm(omegas), deltas)):
            sol = gbragg(kvec, make_phi(kvec, n0), 6*sigma, delta, omega, sigma, **extra)
            states[row, :] = sol.y[:, -1]
        return states

    # Random sample points inside the grid bounds
    npoints = 4000
    rng = np.random.default_rng()
    omega_pnts = rng.uniform(omega_axis[0], omega_axis[-1], npoints)
    delta_pnts = rng.uniform(delta_axis[0], delta_axis[-1], npoints)

    # Reference values minus interpolated values at the same points
    phi_diff = direct_gbragg(omega_pnts, delta_pnts) - interpolator((omega_pnts, delta_pnts))

    # Collapse the state axis into one complex error per sample point
    err = np.sum(phi_diff, axis=-1)
    re_err = np.real(err)
    im_err = np.imag(err)
    components = ((re_err, 're abs err'), (im_err, 'im abs err'))

    # Error maps over the sampled points (real part first, then imaginary)
    for values, label in components:
        plt.figure()
        plt.scatter(omega_pnts, delta_pnts, c=values)
        plt.colorbar().ax.set_ylabel(label)
        plt.show()

    # Error histograms on a fixed +/-2e-6 window
    for values, label in components:
        plt.figure()
        plt.hist(values, range=(-2e-6, 2e-6), bins=200)
        plt.xlabel(label)
        plt.show()

    # Summary statistics
    print(f'Re[error] mean={np.mean(re_err)}, error std={np.std(re_err)}')
    print(f'Im[error] mean={np.mean(im_err)}, error std={np.std(im_err)}')

Precompute the single-frequency Bragg pulses.

In [None]:
# Single-frequency grid for the pair n_init -> n_init + n_bragg
n0 = n_init
nf = n_init + n_bragg
# Uncomment the below code when you want to perform the computation
# precompute_gbragg(n0, nf, n_bragg)
# test_precomp_grid(n0,nf)

In [None]:
# Single-frequency grid for the reverse pair n_init + n_bragg -> n_init
n0 = n_init + n_bragg
nf = n_init
# Uncomment the below code when you want to perform the computation
# precompute_gbragg(n0, nf, n_bragg)
# test_precomp_grid(n0,nf)

Precompute the multifrequency Bragg pulses

In [None]:
# Multifrequency grid for the pair -N_bloch -> -N_bloch - n_bragg
n0 = -N_bloch
nf = -N_bloch - n_bragg
# Uncomment the below code when you want to perform the computation
# precompute_gbragg(n0, nf, n_bragg, N_bloch=N_bloch)
# test_precomp_grid(n0,nf,multifreq=True)

In [None]:
# Multifrequency grid for the pair n_bragg + N_bloch -> 2*n_bragg + N_bloch
n0 = n_bragg + N_bloch
nf = 2*n_bragg + N_bloch
# Uncomment the below code when you want to perform the computation
# precompute_gbragg(n0, nf, n_bragg, N_bloch=N_bloch)
# test_precomp_grid(n0,nf,multifreq=True)

In [None]:
# Multifrequency grid for the reverse pair 2*n_bragg + N_bloch -> n_bragg + N_bloch
n0 = 2*n_bragg + N_bloch
nf = n_bragg + N_bloch
# Uncomment the below code when you want to perform the computation
# precompute_gbragg(n0, nf, n_bragg, N_bloch=N_bloch)
# test_precomp_grid(n0,nf,multifreq=True)

In [None]:
# Multifrequency grid for the reverse pair -n_bragg - N_bloch -> -N_bloch
n0 = -n_bragg - N_bloch
nf = -N_bloch
# Uncomment the below code when you want to perform the computation
# precompute_gbragg(n0, nf, n_bragg, N_bloch=N_bloch)
# test_precomp_grid(n0,nf,multifreq=True)