In [None]:
from math import *
import numpy as np
import random as rand
import matplotlib.pyplot as plt
import time as time
import copy
import scipy.optimize as optimize
import scipy.special as kate
from scipy import integrate
from scipy import linalg as linalg
from scipy.interpolate import interp1d
import scipy.stats as stats
import scipy.fft as fft
import math as math
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
from matplotlib import animation, rc
from IPython.display import HTML, display, Math, Latex
import pickle
import mpmath
import scipy.ndimage.filters as filters
import scipy.ndimage as ndimage
import datetime as dtime

from mayavi import mlab
from mayavi.api import Engine
from PIL import Image
import glob
import os

In [None]:
#Unit system used throughout the notebook: times in us, distances in um and masses in units of
#1e-34 kg. Each constant below is its SI value (SI unit given in the trailing comment) rescaled
#into these units, so e.g. hbar*w/kB comes out in kelvin and energy/hbar in 1/us.
tu = 10**(-6) #convert all times to us
du = 10**(-6) #convert all distances to um
mu = 10**(-34) #convert all masses to units of 10**(-34) kg

#Fundamental constants (trailing comment = SI unit of the literal before rescaling)
hbar = 1.054571817*10**(-34)*tu/du**(2)/mu #J*s, reduced Planck constant
kB = 1.380649*10**(-23)*(tu/du)**(2)/mu # J/K, Boltzmann constant
c = 2.99792458*10**(8)*tu/du #m/s, speed of light
mu0 = 1.25663706212*10**(-6)*tu**(2)/mu/du #H/m, vacuum permeability
aB = 5.29e-11/du #m, Bohr radius

#Less fundamental constants
#NOTE(review): 85.4678 g/mol is the standard atomic weight of natural rubidium, while the
#scattering lengths and the "Rubidium-87 Steck" parameters used later are 87Rb values
#(87Rb is ~86.909 g/mol) -- confirm which mass is intended.
m = 85.4678*10**(-3)/(6.022*10**(23))/mu #kg, atomic mass of rubidium

In [None]:
#Scattering lengths for |1,-1>, |1,0>, |1,1>, |2,0>, |2,1>
#a1_*: F=1 channels, a2_*: F=2 channels, a12_*: F=1 + F=2 channels. The numeric suffix is
#presumably the total spin of the colliding pair -- confirm against the source of these values.

#Converts a scattering length a into a contact coupling g = 4*pi*hbar**2*a/m
const = 4*np.pi*hbar**(2)/m

#F=1 manifold: spin-independent (c1_0) and spin-dependent (c1_1) couplings
a1_0 = 101.78*aB
a1_2 = 100.4*aB
c1_0 = const*(a1_0 + 2*a1_2)/3
c1_1 = const*(a1_2 - a1_0)/3

#F=2 manifold: c2_0, c2_1, c2_2 built from the three F=2 channel scattering lengths
a2_0 = 87.93*aB
a2_2 = 91.28*aB
a2_4 = 98.98*aB
c2_0 = const*(4*a2_2 + 3*a2_4)/7
c2_1 = const*(a2_4 - a2_2)/7
c2_2 = const*(7*a2_0 - 10*a2_2 + 3*a2_4)/7

#F=1 / F=2 inter-manifold couplings
a12_1 = 99.29*aB
a12_2 = 96.15*aB
a12_3 = 98.64*aB
c12_0 = const*(2*a12_2 + a12_3)/3
c12_1 = const*(a12_3 - a12_2)/3
c12_2 = const*(3*a12_1 - 5*a12_2 + 2*a12_3)/3

#Interaction strength constants
#Names encode the pair of mF values (presumably 'm' = minus, e.g. g1_m10 is |1,-1> with |1,0>)
g1_m1m1 = c1_0 + c1_1
g1_00 = c1_0
g1_11 = c1_0 + c1_1
g1_m10 = c1_0 + c1_1
g1_m11 = c1_0 - c1_1
g1_01 = c1_0 + c1_1

g2_00 = c2_0 + c2_2/5
g2_11 = c2_0 + c2_1
g2_01 = c2_0 + 3*c2_1

g12_m10 = c12_0 + c12_2/10
g12_m11 = c12_0 - c12_1 + 3*c12_2/10
g12_00 = c12_0 + 2*c12_2/5
g12_01 = c12_0 + 3*c12_2/10
g12_10 = c12_0 + c12_2/10
g12_11 = c12_0 + c12_1

In [None]:
#For different temperatures, calculate the fugacity for a dilute BEC in an isotropic 3D harmonic
#trap, and use that to find the population of each level

Nt = 1e6     #number of atoms
w_HO = 2*np.pi*100*tu #angular harmonic trap frequency (2*pi*100 Hz), in rad/us
N_levels = 2 #number of HO levels to calculate the spectrum of
#Target condensed fraction: the cell below reports the temperature whose ground-state fraction
#is closest to this value.
#NOTE: the global name `frac` is rebound by later cells (drive ratio, print interval), so cells
#that read it must run in order.
frac = 0.7 #target condensed fraction

maxiter = 10000
opts = {"maxiter":maxiter, "disp":False} #options for the minimizer

#Ideal-gas critical temperature in an isotropic 3D trap, including the leading finite-N correction
#(kelvin, given the unit system above)
Tc = (hbar*w_HO/kB)*(Nt/kate.zeta(3))**(1/3)*(1 - kate.zeta(2)*kate.zeta(3)**(-2/3)*Nt**(-1/3)/2)
T = Tc*np.linspace(0.01, 1.2, 100) #K, temperatures

j = np.arange(0, N_levels) #3D HO level indices (energy j*hbar*w_HO above the ground state)

#We'll minimize an expression for Nt that is a function only of the fugacity, z, to find
#the optimal value of z
def func(z, frac):
    """Residual |N(z) - Nt| for the fugacity z of a Bose gas in an isotropic 3D harmonic trap.

    N(z) = z/(1 - z) + Li_3(z)*frac**3 is the ground-state occupation plus the
    semiclassical thermal population.

    Parameters
    ----------
    z : array_like, shape (1,)
        Fugacity as passed by scipy.optimize.minimize. Values >= 1 are clamped just below 1,
        where the ground-state term diverges. The caller's array is NOT modified.
    frac : float
        k_B*T/(hbar*w_HO). This is unrelated to the module-level condensed fraction `frac`.

    Returns
    -------
    float
        Absolute deviation of N(z) from the global atom number Nt.
    """
    #Clamp a local copy instead of writing into the optimizer's array
    z0 = min(float(z[0]), 1 - 1e-10)
    #Convert the mpmath result so the optimizer receives a plain float
    return abs(z0/(1 - z0) + float(mpmath.polylog(3, z0))*frac**3 - Nt)

#Minimizer
def Bose_distribution(T):
    """Relative populations of the lowest N_levels harmonic-oscillator levels at temperature T.

    The fugacity z is found by minimizing `func` so the total atom number equals Nt. Each
    level j then has the Bose-Einstein single-state occupation z*exp(-j/kT_hw) /
    (1 - z*exp(-j/kT_hw)). No degeneracy factor is applied. The result is divided by Nt.

    Parameters
    ----------
    T : float
        Temperature in kelvin.

    Returns
    -------
    ndarray, shape (N_levels,)
        Occupation of each level divided by Nt.
    """
    z = 0.99 #initial guess for the fugacity
    kT_hw = kB*T/(hbar*w_HO) #thermal energy in units of the trap quantum

    #`args` must be a tuple: `(x)` is just a parenthesised value, `(x,)` is a 1-tuple
    out = optimize.minimize(func, z, args=(kT_hw,), method='Nelder-Mead', options=opts)

    z_exp = out.x[0]*np.exp(-j/kT_hw)

    #Return the relative population of each single-particle level
    return z_exp/(1 - z_exp)/Nt

#For each temperature, find the population spectrum and chemical potential
f0 = np.zeros((len(T), N_levels))
for ind in range(len(T)):
    f0[ind] = Bose_distribution(T[ind])

ind_Tc = np.argmin(np.abs(f0[:,0] - frac))
T0 = T[ind_Tc]
print('Critical temperature:', Tc*1e9, 'nK')
print(str(frac*100) + '% condensed fraction temperature:', T0*1e9, 'nK')

plt.figure()
plt.plot(T/Tc, f0[:,0], 'b.')
plt.xlabel('T/Tc')
plt.ylabel('N0/N')
plt.show()

plt.figure()
plt.plot(T/Tc, f0[:,1], 'r.')
plt.xlabel('T/Tc')
plt.ylabel('N1/N')
plt.show()

In [None]:
#Band structure calculation that uses the Fourier transform of any infinite and periodic potential

lam_eff = 1064e-9/du #m, effective wavelength of the lattice
#Lattice depth from a tuned numerical expression (origin of the factors not documented here)
Vb = 2*np.sin(np.pi/2 + np.pi/150)*hbar*800.176251680392e3*tu #J, lattice depth
#Depends on the global `frac` still being the condensed fraction (0.7) when this cell runs
N = int(frac*Nt) #number of condensed atoms

L =  1.0*lam_eff/4 #m, box radius along x: wavefunction should go to zero by the time it reaches +/- L
pow2 = 10 #each axis will be broken down into 2**pow2 spatial points
#2**pow2 + 1 points so that scipy.integrate.romb can be used on this grid
x = np.linspace(-L, L, 2**pow2+1)
#NOTE: `delta` (grid spacing) is later rebound to a detuning in the double-well cell;
#cells that integrate with dx=delta must run before that cell.
delta = x[1]-x[0]

#The harmonic term is included only over this single-period box, so the Fourier expansion
#below effectively periodises it -- presumably intentional; confirm.
pot = m*w_HO**(2)*x**(2)/2 + Vb*np.sin(2*np.pi*x/lam_eff)**2

k = 2*np.pi/lam_eff #1/m, lattice quasimomentum will range from +hbar*k to -hbar*k (don't change)
l = np.arange(-20, 20+1) #number of terms to include in the Fourier series expansion of the wavefunction and potential
q = np.linspace(-hbar*k, hbar*k, 101) #lattice quasimomentum

#Fourier components of pot at wavevectors k*l; keep l >= 0 (pot_ft[i] corresponds to l = i)
indm = int(len(l)/2)
pot_ft = np.sum(pot[:,np.newaxis] * np.exp(-1j*k*l[np.newaxis,:]*x[:,np.newaxis]), axis=0)/len(pot)
pot_ft = pot_ft[indm:]

#First generate the Hamiltonian in momentum space (for the Schrodinger equation after using Bloch's theorem and writing 
#the wavefunction and potential as Fourier series). We'll create an array of Hamiltonians, each corresponding to a different q

#Set the diagonal elements of the Hamiltonian
H = np.fromfunction(lambda i0, i1, i2: (i1 == i2)*((2*hbar*k*l[i1.astype(int)] + q[i0.astype(int)])**(2)/(2*m) + pot_ft[0]), 
                    (len(q), len(l), len(l)))/2 #Divide by 2 so that the diagonals are right after the matrix is made symmetric

#Off-diagonals: only even pot_ft entries (wavevector 2*k*ind0, a reciprocal-lattice vector of the
#2k plane-wave basis) couple basis states ind0 apart; fill the lower triangle only.
for ind in range(2, len(pot_ft), 2):
    ind0 = int(ind/2)
    H[:, np.arange(ind0, len(l)), np.arange(0, len(l)-ind0)] = pot_ft[ind]*np.ones(len(l)-ind0)[np.newaxis,:]

#NOTE(review): symmetrised with a plain transpose, not a conjugate transpose; this is Hermitian only
#if pot_ft is real (expected for this even potential). eigh reads the lower triangle.
H = H + np.transpose(H, axes=[0,2,1]) #Make each Hamiltonian matrix symmetric
eigvals, eigvectors = np.linalg.eigh(H)

def wannier_func(cnl, norm, x, xi):
    """Wannier function centred at xi, assembled from one band's Bloch coefficients.

    Parameters
    ----------
    cnl : ndarray, shape (len(q), len(l))
        Plane-wave coefficients of one band at every quasimomentum, e.g. eigvectors[:,:,band].
    norm : float
        The result is divided by sqrt(norm).
    x : ndarray, 1D
        Positions at which to evaluate.
    xi : float
        Centre of the Wannier function.

    Returns
    -------
    ndarray, shape (len(x),)
        sum over q and l of cnl[q,l]*exp(1j*q*(x - xi)/hbar + 2j*l*k*x), divided by sqrt(norm).
    """
    xs = x[np.newaxis, np.newaxis, :]
    bloch_phase = 1j*q[:, np.newaxis, np.newaxis]*(xs - xi)/hbar
    lattice_phase = 2j*l[np.newaxis, :, np.newaxis]*k*xs
    terms = cnl[:, :, np.newaxis]*np.exp(bloch_phase + lattice_phase)
    return (1/np.sqrt(norm))*np.sum(terms, axis=(0, 1))

font = {'family':'Times New Roman', 'weight':'bold', 'size':15}
plt.rc('font', **font)
fig, ax = plt.subplots(nrows=1,ncols=3,figsize=(16,4))

#Lowest six bands. E/hbar is an angular frequency (no division by 2*pi), scaled by 1e-3/tu;
#NOTE(review): the axis is labelled kHz -- confirm whether 2*pi should be divided out.
ax[0].plot(q/(hbar*k), eigvals[:,0:6]*1e-3/tu/hbar, linewidth=3)
ax[0].set_xlabel(r'$q$ $(\hbar k)$')
ax[0].set_ylabel(r'$E_{n}$ $(kHz)$')
ax[0].set_title('Band structure')

#Lowest-band Bloch wavefunctions at the middle q value (index len(q)//2) and at the last q
ind = int(len(q)/2)
ind2 = -1
ax[1].plot(x, np.real(np.exp(1j*q[ind]*x/hbar)*np.sum(eigvectors[ind][:,0][:,np.newaxis]*np.exp(2j*l[:,np.newaxis]*k*x[np.newaxis,:]), axis=0)), 'b-', label='q = '+str(np.round(q[ind]/(hbar*k), 1)), linewidth=3)
ax[1].plot(x, np.real(np.exp(1j*q[ind2]*x/hbar)*np.sum(eigvectors[ind2][:,0][:,np.newaxis]*np.exp(2j*l[:,np.newaxis]*k*x[np.newaxis,:]), axis=0)), 'r--', label='q = '+str(np.round(q[ind2]/(hbar*k), 1)), linewidth=3)
ax[1].set_xlabel(r'$x$ $(\mu m)$')
#ax[1].set_ylabel(r'$Re(u_{nq})$')
ax[1].set_title('Real part of Bloch wavefunction')
ax[1].legend()

def _normalized_wannier(band):
    """Return (Wannier function of `band` centred at 0 on the grid x, its normalisation integral)."""
    band_coeffs = eigvectors[:,:,band]
    band_norm = integrate.romb(np.abs(wannier_func(band_coeffs, 1, x, 0))**2, dx=delta)
    return wannier_func(band_coeffs, band_norm, x, 0), band_norm

#Normalised Wannier functions of the five lowest bands; `norm` is left holding band 4's value
wan0, norm = _normalized_wannier(0)
wan1, norm = _normalized_wannier(1)
wan2, norm = _normalized_wannier(2)
wan3, norm = _normalized_wannier(3)
wan4, norm = _normalized_wannier(4)

#Rescale the potential IN PLACE so its range matches the peak Wannier probability (plot only;
#`pot` no longer holds the physical potential after this line)
pot = (pot - np.min(pot))*np.max(np.abs(wan0)**2)/np.max(pot - np.min(pot))
#The 'Potential' label is never shown because ax[2].legend() is not called
ax[2].plot(x, pot, 'k-', label='Potential', linewidth=3) 
ax[2].plot(x, np.abs(wan0)**2, 'b--', linewidth=3)
ax[2].set_xlabel(r'$x$ $(\mu m)$')
ax[2].set_title('Localized Wannier probabilities')
plt.show()

#Single particle harmonic oscillator eigenfunctions
def psi_HO(x, w, n):
    """Normalised n-th 1D harmonic-oscillator eigenfunction of frequency w, evaluated at x.

    Uses scipy.special.eval_hermite. A later cell redefines psi_HO to return a list of all
    states up to n, which shadows this single-state version once that cell has run.
    """
    length = np.sqrt(hbar/m/w) #oscillator length sqrt(hbar/(m*w))
    prefactor = 1/np.sqrt(length*kate.factorial(n)*2**n)/np.pi**(1/4)
    return prefactor*np.exp(-0.5*(x/length)**2)*kate.eval_hermite(n, x/length)

#TF approximation
#Scattering length implied by the |1,0> coupling g1_00
asc = g1_00*m/(4*np.pi*hbar**2)
#Thomas-Fermi chemical potential of N atoms in the isotropic trap
mew = 0.5*(15*N*asc/np.sqrt(hbar/m/w_HO))**(2/5)*hbar*w_HO
#Thomas-Fermi radius, where mew equals the trap potential
R = np.sqrt(2*mew/m/w_HO**2)
rTF = np.linspace(0, R, 2**pow2 + 1) #2**pow2 + 1 points for romb integration

def psi_TF(r):
    """Thomas-Fermi amplitude at radius r, normalised per atom.

    |psi_TF|**2 = (mew - m*w_HO**2*r**2/2)/(g1_00*N) inside the cloud and 0 beyond the
    Thomas-Fermi radius. Accepts a scalar r as well as arrays of any shape.
    """
    prob_TF = (1/g1_00/N)*(mew - 0.5*m*w_HO**(2)*r**2)
    #Zero density outside the TF radius; np.maximum also works for scalars (item assignment does not)
    return np.sqrt(np.maximum(prob_TF, 0))

#Trap potential (E/hbar, scaled by 1e-3/tu) with the HO ground state and TF profile rescaled to its peak
potp = 0.5*m*w_HO**(2)*rTF**(2)*1e-3/tu/hbar
plt.figure()
plt.plot(rTF, potp, 'k-', label=r'$V$')
#Uses the single-state psi_HO defined above; rerunning after the later psi_HO redefinition fails
plt.plot(rTF, (np.max(potp)/np.max(psi_HO(rTF, w_HO, 0)))*psi_HO(rTF, w_HO, 0), 'b--', label=r'$\psi_{HO}$')
plt.plot(rTF, (np.max(potp)/np.max(psi_TF(rTF)))*psi_TF(rTF), 'r-', label=r'$\psi_{TF}$')
plt.xlabel(r'$\rho$ $(\mu m)$')
plt.ylabel(r'$V(\rho)$ $(kHz)$')
plt.legend()
plt.show()

#On-site lattice interaction: 3D Wannier assumed separable, so the 1D integral is cubed.
#NOTE: `U` is rebound to the double-well interaction in a later cell.
U = g1_11*integrate.romb(np.abs(wan0)**4, dx=delta)**3
#Trap interaction energy: radial integral of |psi_TF|**4 with the 4*pi*r**2 shell factor
U_HO = g1_00*integrate.romb(4*np.pi*rTF**(2)*np.abs(psi_TF(rTF))**4, dx=rTF[1]-rTF[0])

print('U_lat =', U*1e-3/tu/hbar, 'kHz')
#Tunnelling estimated as a quarter of the lowest band's width
print('J_lat =', 0.25*(np.max(eigvals[:,0]) - np.min(eigvals[:,0]))/tu/hbar, 'Hz')
print('U_HO =', U_HO/tu/hbar, 'Hz')
print('Frequency difference between bands:', np.mean(eigvals[:,1] - eigvals[:,0])*1e-3/tu/hbar, 'kHz')
print('H.O. frequency:', w_HO/tu, 'Hz')
#print('N <<', hbar/(g*(m/hbar/np.pi/2)**(3/2)*w_HO**(1/2)))
print('omega >>', 2*U_HO*100*np.sqrt(1e3)*1e-3/tu/hbar, 'kHz')


In [None]:
#Gaussian fit to Wannier function

def gauss(x, x0, A, w):
    """Gaussian A*exp(-((x - x0)/w)**2/2), with centre x0, amplitude A and rms width w."""
    scaled = (x - x0)/w
    return A*np.exp(-scaled**(2)/2)

#curve_fit returns (popt, pcov); popt = [x0, A, w] in gauss's parameter order
sol_real = optimize.curve_fit(gauss, x, np.real(wan0), p0=[0, 1, lam_eff/20])

#If the Wannier function has a significant imaginary part, do a fit to that as well.
#Otherwise use a placeholder whose popt has zero amplitude (so the imaginary Gaussian is 0).
if np.abs(np.mean(np.real(wan0))) > 100*np.abs(np.mean(np.imag(wan0))):
    sol_imag = np.array([[0,0,1]])
else:
    sol_imag = optimize.curve_fit(gauss, x, np.imag(wan0), p0=[0, -1, lam_eff/20])

#NOTE(review): sin**2(2*pi*x/lam_eff) has period lam_eff/2, so the lam_eff/4 shift below is half a
#lattice period -- presumably intended (e.g. spin-dependent lattice sites); confirm.
gauss_fit = (gauss(x, sol_real[0][0], sol_real[0][1], sol_real[0][2]) + 
             1j*gauss(x, sol_imag[0][0], sol_imag[0][1], sol_imag[0][2]))
gauss_fit_shift = (gauss(x+lam_eff/4, sol_real[0][0], sol_real[0][1], sol_real[0][2]) + 
                   1j*gauss(x+lam_eff/4, sol_imag[0][0], sol_imag[0][1], sol_imag[0][2]))

#Overlap of two normalised Gaussians of width w separated by d is exp(-d**2/(4*w**2))
print('Analytic overlap between Gaussians of adjacent sites:', np.exp(-(lam_eff/4/sol_real[0][2]/2)**2))
print('Numeric overlap between Gaussians of adjacent sites:', np.abs(integrate.romb(gauss_fit*np.conj(gauss_fit_shift), dx=delta)))
print()
#Fit width w is the oscillator length, so omega = hbar/(m*w**2); the analytic value comes from
#expanding Vb*sin**2(k*x) around a minimum. Both are angular frequencies.
print('Fit lattice trapping frequency:', hbar/m/sol_real[0][2]**(2)*1e-3/tu, 'kHz')
print('Analytic lattice trapping frequency:', np.sqrt(np.abs(2*Vb*(2*np.pi/lam_eff)**(2)/m))*1e-3/tu, 'kHz')

plt.figure()
plt.plot(x*du/1e-6, np.abs(wan0)**2, 'b-', label='Wannier', linewidth=3)
plt.plot(x*du/1e-6, np.abs(gauss_fit)**2, 'r--', label='Gaussian', linewidth=3)
plt.xlabel(r'$x$ $(\mu m)$')
plt.ylabel(r'$P$')
plt.legend()
plt.show()

In [None]:
#Mean-field approximation of Bose-Hubbard model for BEC being driven between both mF = 0 states (|0> and |2>) 
#via a two-photon virtual process through an optical lattice site (|1>) that can either be filled or empty

Na = N                                  #Total number of atoms used to set the experimental parameters
Na2 = Na + np.sqrt(Na)                  #Actual number of atoms in the simulation
#NOTE: rebinds the global `frac` (was the condensed fraction) and `f0` (was the population array)
frac = 1/10                             #Ratio of the driving fields to the single-photon detuning
f0 = 4/5                                #Fraction of atoms that will be initialized in |0> (the rest will start in |2>)
phi0 = 0                                #Initial phase difference between |0> and |2>
#Initial n = half the population difference between |0> and |2>. It is defined here because the
#detuning residual `func` in this cell reads n0 before the solver section runs; otherwise a
#fresh kernel raises NameError there.
n0 = Na2*(f0 - 1/2)
times = np.linspace(0, 72.7931982995749e-3, 4000)/tu #sec, times to evaluate the diffEq at

r_tol, a_tol, maxstep = 1e-8, 1e-8, 1e0 #relative tolerance, absolute tolerance, and max step size for the diffEq solver

#Set up for 3D overlap integrals of the Wannier functions and the TF BEC wavefunction
def Wan_wavefunction(x, y, z):
    """Separable 3D Gaussian approximation to the lowest-band Wannier function.

    Each Cartesian factor is gauss(s, *sol_real[0]) + 1j*gauss(s, *sol_imag[0]), using the
    parameters fitted in the Gaussian-fit cell. Inputs broadcast, e.g. sparse meshgrid arrays.
    """
    def axis_factor(s):
        re_part = gauss(s, sol_real[0][0], sol_real[0][1], sol_real[0][2])
        im_part = gauss(s, sol_imag[0][0], sol_imag[0][1], sol_imag[0][2])
        return re_part + 1j*im_part
    return axis_factor(x)*axis_factor(y)*axis_factor(z)

#Coarse 3D grid over one lattice site (2**5 + 1 points per axis for romb)
xi = np.linspace(x[0], x[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)

#Overlap of the (per-atom) TF wavefunction with the Wannier function
intgrnd = psi_TF(np.sqrt(x_grid**2 + y_grid**2 + z_grid**2))*Wan_wavefunction(x_grid, y_grid, z_grid)
overlap = integrate.romb(integrate.romb(integrate.romb(intgrnd, dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])

print('Overlap between Wannier and TF wavefunctions: 1/'+ str(np.abs(1/overlap)))

#Calculate the interaction strengths for each state (and all pairs)

#NOTE(review): psi**4 rather than |psi|**4; the two agree only if the Wannier fit is real
#(sol_imag amplitude 0) -- np.real is taken afterwards.
intgrnd = Wan_wavefunction(x_grid, y_grid, z_grid)**4
U11 = np.real(g1_m1m1*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0]))

#Lattice-atom / BEC density overlap, shared by U01 and U12
intgrl = integrate.romb(integrate.romb(integrate.romb(np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**(2)*
                                                      np.abs(psi_TF(np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)))**2, 
                                                      dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])
U01 = g1_m10*intgrl
U12 = g12_m10*intgrl

#BEC self-interaction integral on a grid spanning the whole TF cloud
xi = np.linspace(-rTF[-1], rTF[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
intgrl = integrate.romb(integrate.romb(integrate.romb(np.abs(psi_TF(np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)))**4, dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])
U00 = g1_00*intgrl
U22 = g2_00*intgrl
U02 = g12_00*intgrl

det = Na*(U22-U00)/hbar/2               #Hz, two-photon detuning between |0> and |2>
#NOTE: overwrites `delta`, which earlier cells use as the position-grid spacing (dx=delta);
#rerunning those cells after this one gives wrong integrals.
delta = np.real(U11/hbar)           #Hz, one-photon detuning of each field from resonance with |1>

#Residual minimised over the one-photon detuning below. The argument shadows the global `delta`.
#Reads globals hbar, det, Na, n0, U00, U01, U02, U11, U12 and U22, so n0 must already be defined.
#NOTE: this rebinds `func` from the fugacity cell; rerunning that cell's temperature loop
#afterwards would call this function instead.
#NOTE(review): uses n0*(...) and Na, whereas eqn1 uses 2*n*(...) and Na2 in similar
#denominators -- confirm against the derivation.
def func(delta):
    return np.abs(0.5/(2*hbar*(delta-det)+Na*(-U00+U01-U02+U12)+n0*(U00-U01-U02+U12)) -
                  1/(2*hbar*(delta-det)+2*U11+Na*(-U00+U01-U02+U12)+n0*(U00-U01-U02+U12)) +
                  0.5/(2*hbar*delta+Na*(U01-U02+U12-U22)+n0*(-U01+U02+U12-U22)) -
                  1/(2*hbar*delta+2*U11+Na*(U01-U02+U12-U22)+n0*(-U01+U02+U12-U22)))

opts = {"maxiter":maxiter, "disp":False} #options for the minimize
out = optimize.minimize(func, delta, method='Nelder-Mead', options=opts)
#The optimum is scaled by 1.6 (factor not explained in SOURCE -- presumably a chosen offset)
delta = 1.6*out.x[0]

#Coupling energy J0 (joules in SI; the comment calls it a Rabi frequency) set as a fraction of hbar*delta
J0 = frac*hbar*delta               #J, one-photon Rabi frequency connecting |0> -> |1>, and |1> -> |2>

print('One-photon detuning:', delta*1e-3/tu, 'kHz')
print('Free-space Rabi frequencies:', (J0/np.abs(overlap)/hbar)*1e-3/tu, 'kHz')

#Parameters for the double-well approximation

U = np.abs(U00 - 2*U02 + U22)/2                        #Effective double-well interaction strength
#Second-order (virtual |1>) tunnelling: J0**2 over the two intermediate-state energy denominators,
#so J has units of energy
J = np.abs(J0**(2)*(1/(-hbar*delta+(Na-1)*(U02-U12)) + 1/(-hbar*delta+(Na-1)*(U22-U12))))  #Effective double-well tunneling energy
#Population offset of the effective double-well interaction term U*(n - n_off)**2
n_off = -hbar*det/U/2 + (U22-U00)*(Na-1)/U/4

#Differential equation we'll be solving for the time dependence of n (half the population difference between
#|0> and |2>) and phi (relative phase between |0> and |2>) in the mean field limit

#Right-hand side for scipy.integrate.solve_ivp when the intermediate lattice site |1> is empty
#(plotted as sigma_1 = 0). `t` is unused.
#Returns [dn/dt, dphi/dt]. Reads globals Na2, J0, hbar, delta, det and U00..U22.
def eqn0(t, vec):
    n, phi = vec
    
    nDot = -np.sqrt(Na2**2 - 4*n**2)*J0**(2)*(-1/(2*hbar*(delta-det)+2*n*(U00-U01-U02+U12) + Na2*(-U00+U01-U02+U12)) - 1/(2*hbar*delta+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22)))*np.sin(phi)/hbar
    
    #NOTE(review): in the cos(phi) term, the second denominator uses 2*n0 (the initial value) where
    #every other term uses the state variable 2*n -- likely a typo; confirm against the derivation.
    #NOTE(review): the delta vs (delta-det) pairing with each U-combination differs between this
    #phiDot and nDot above -- confirm.
    phiDot = 0.5*(-2*hbar*det + (2*n-Na2)*U00 - 4*n*U02 + (2*n+Na2)*U22 + 
                  4*J0**(2)*(-(2*n-Na2)*(-U00+U01+U02-U12)/(2*hbar*(delta-det) - 2*n*(-U00+U01+U02-U12) + Na2*(-U00+U01-U02+U12))**2 + 
                             1/(2*hbar*delta + 2*n*(U00-U01-U02+U12) + Na2*(-U00+U01-U02+U12)) - 
                             (2*n+Na2)*(U01-U02-U12+U22)/(2*hbar*(delta-det) + Na2*(U01-U02+U12-U22) + 2*n*(-U01+U02+U12-U22))**2 - 
                             1/(2*hbar*delta + Na2*(U01-U02+U12-U22) + 2*n*(-U01+U02+U12-U22))) + 
                  J0**(2)*np.cos(phi)*(8*n*(1/(2*hbar*delta+Na2*(-U00+U01-U02+U12)+2*n*(U00-U01-U02+U12)) + 
                                                1/(2*hbar*(delta-det) + 2*n0*(-U01+U02+U12-U22)+Na2*(U01-U02+U12-U22))) + 
                                  4*(Na2**2 - 4*n**2)*((U00-U01-U02+U12)/(2*hbar*delta+2*n*(U00-U01-U02+U12)+Na2*(-U00+U01-U02+U12))**2 - 
                                            (U01-U02-U12+U22)/(2*hbar*(delta-det)+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22))**2))/np.sqrt(4*(Na2**2 - 4*n**2)))/hbar
    
    return [nDot, phiDot]

#Right-hand side for scipy.integrate.solve_ivp when the intermediate lattice site |1> is occupied
#(plotted as sigma_1 = 1). The extra 2*U11 terms are the on-site interaction energy. `t` is unused.
#Returns [dn/dt, dphi/dt]. Reads globals Na2, J0, hbar, delta, det, U11 and U00..U22.
#NOTE(review): nDot pairs 2*hbar*delta with the (U00-U01-U02+U12) denominators and 2*hbar*(delta-det)
#with the (-U01+U02+U12-U22) ones, while phiDot uses the opposite pairing -- confirm.
def eqn1(t, vec):
    n, phi = vec
    
    nDot = -np.sqrt(Na2**2 - 4*n**2)*J0**(2)*(0.5/(2*hbar*delta+Na2*(-U00+U01-U02+U12)+2*n*(U00-U01-U02+U12)) -
                                             1/(2*hbar*(delta-det)+2*U11+Na2*(-U00+U01-U02+U12)+2*n*(U00-U01-U02+U12)) +
                                             0.5/(2*hbar*delta+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22)) -
                                             1/(2*hbar*(delta-det)+2*U11+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22)))*np.sin(phi)/hbar
    
    phiDot = (-hbar*det + (Na2/2)*(U22-U00) + U12-U01 + n*(U00-2*U02+U22) + 
              4*J0**(2)*(-(hbar*(delta-det)+Na2*(U12-U02))/(2*hbar*(delta-det)+Na2*(-U00+U01-U02+U12)+2*n*(U00-U01-U02+U12))**2
                         +2*(hbar*(delta-det)+U11+Na2*(U12-U02))/(2*hbar*(delta-det)+2*U11+Na2*(-U00+U01-U02+U12)+2*n*(U00-U01-U02+U12))**2
                         +(hbar*delta+Na2*(U01-U02))/(2*hbar*delta+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22))**2
                         -2*(hbar*delta+U11+Na2*(U01-U02))/(2*hbar*delta+2*U11+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22))**2) + 
              J0**(2)*np.cos(phi)*((Na2**(2)*(-U00+U01+U02-U12)-2*n*(2*hbar*(delta-det)+Na2*(-U00+U01-U02+U12)))/
                                   (2*hbar*(delta-det)+2*n*(U00-U01-U02+U12)+Na2*(-U00+U01-U02+U12))**2 + 
                                   (Na2**(2)*(U00-U01-U02+U12) + 2*n*(2*hbar*(delta-det)+2*U11+Na2*(-U00+U01-U02+U12)))/
                                   (2*hbar*(delta-det)+2*U11+Na2*(-U00+U01-U02+U12)+2*n*(U00-U01-U02+U12))**2 +
                                   (Na2**(2)*(-U01+U02+U12-U22)+2*n*(2*hbar*delta+2*U11+Na2*(U01-U02+U12-U22)))/
                                   (2*hbar*delta+2*U11+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22))**2 -
                                   (4*hbar*delta*n+2*Na2*n*(U01-U02+U12-U22)+Na2**(2)*(-U01+U02+U12-U22))/
                                   (2*hbar*delta+Na2*(U01-U02+U12-U22)+2*n*(-U01+U02+U12-U22))**2)/np.sqrt(Na2**2-4*n**2))/hbar
    
    return [nDot, phiDot]

def eqn_dw(t, vec):
    """Mean-field equations of the effective double-well (bosonic Josephson) model.

    Follows from H = U*(n - n_off)**2 - J*sqrt(Na2**2 - 4*n**2)*cos(phi), with
    hbar*dn/dt = -dH/dphi and hbar*dphi/dt = dH/dn.

    Parameters
    ----------
    t : float
        Time. Unused, but required by scipy.integrate.solve_ivp.
    vec : sequence of two floats
        (n, phi): half the population difference between |0> and |2>, and their relative phase.

    Returns
    -------
    list of float
        [dn/dt, dphi/dt]. Reads globals Na2, J, U, n_off and hbar.
    """
    n, phi = vec

    root = np.sqrt(Na2**2 - 4*n**2)
    #J already has units of energy (it is J0**2 over an energy denominator), so the population
    #current is linear in J, matching the cos(phi) term of phiDot; J**2 would be dimensionally wrong.
    nDot = -root*J*np.sin(phi)/hbar

    phiDot = (2*U*(n-n_off) + 4*J*n*np.cos(phi)/root)/hbar

    return [nDot, phiDot]


#Initial half population difference (same expression as used for the detuning residual)
n0 = Na2*(f0 - 1/2)
start = time.time()
#Integrate both variants (site |1> empty / occupied) from the same initial state
sol0 = integrate.solve_ivp(eqn0, (times[0], times[-1]), [n0, phi0], t_eval=times, rtol=r_tol, atol=a_tol, max_step = maxstep)
sol1 = integrate.solve_ivp(eqn1, (times[0], times[-1]), [n0, phi0], t_eval=times, rtol=r_tol, atol=a_tol, max_step = maxstep)
#sol_dw = integrate.solve_ivp(eqn_dw, (times[0], times[-1]), [n0, phi0], t_eval=times, rtol=r_tol, atol=a_tol, max_step = maxstep)
end = time.time()

if sol0.success:
    print('Solver time:', np.round((end-start)/60, 2), 'min.')
else:
    print('Solver error:', sol0.message)

if not sol1.success:
    print('Solver error:', sol1.message)

print()
#sol.y[0]/Na2 + 1/2 converts n back to the fraction of atoms in |0>
print('Min n0:', np.min(sol0.y[0]/Na2 + 1/2))
print('Pulse time:', times[np.argmin(sol0.y[0])]*tu/1e-3, 'ms')
#Largest relative excursion of n for the occupied-site case
print(np.max(np.abs(sol1.y[0] - sol1.y[0][0])/Na2))

font = {'family':'Times New Roman', 'weight':'bold', 'size':25}
plt.rc('font', **font)
fig = plt.figure(figsize=(16,8))

plt.plot(times*tu/1e-3, sol0.y[0]/Na2 + 1/2, 'r.-', linewidth=3, label=r'$\sigma_1 = 0$')
plt.plot(times*tu/1e-3, sol1.y[0]/Na2 + 1/2, 'b.-', linewidth=3, label=r'$\sigma_1 = 1$')
#plt.plot(times*tu/1e-3, sol_dw.y[0]/Na2 + 1/2, 'k--', linewidth=3, label=r'Approximations (9) & (10)')
plt.legend(loc='lower left', fontsize=20)
plt.show()

#fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\BEC_double-well.svg", bbox_inches='tight')


In [None]:
#Uncertainty in qubit energy levels due to interactions with the BEC, and mean-field potential arising from
#interactions with the BEC

#Set up for 3D overlap integrals of the Wannier functions and the TF BEC wavefunction
def Wan_wavefunction(x, y, z):
    """Separable 3D Gaussian approximation to the lowest-band Wannier function.

    Each Cartesian factor is gauss(s, *sol_real[0]) + 1j*gauss(s, *sol_imag[0]), using the
    parameters fitted in the Gaussian-fit cell. Inputs broadcast, e.g. sparse meshgrid arrays.
    """
    def axis_factor(s):
        re_part = gauss(s, sol_real[0][0], sol_real[0][1], sol_real[0][2])
        im_part = gauss(s, sol_imag[0][0], sol_imag[0][1], sol_imag[0][2])
        return re_part + 1j*im_part
    return axis_factor(x)*axis_factor(y)*axis_factor(z)

#Finer 3D grid (2**6 + 1 points per axis) over one lattice site
xi = np.linspace(x[0], x[-1], 2**6 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
#The 0*lam_eff/2 terms are a disabled offset of the BEC centre (currently no shift)
r_grid = np.sqrt((x_grid + 0*lam_eff/2)**2 + (y_grid + 0*lam_eff/2)**2 + (z_grid + 0*lam_eff/2)**2)

intgrnd =  np.abs(psi_TF(r_grid))**(2)*np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**2

U10 = g1_m10*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])
print('Interaction strength between BEC and lattice atoms:', U10/tu/hbar, 'Hz')
print('Uncertainty in energy levels due to interactions with BEC for single qubit:', np.sqrt(N)*U10/tu/hbar, 'Hz')

#Uses the single-state psi_HO from the band-structure cell; once the later cell that redefines
#psi_HO (returning a list) has run, rerunning this line fails.
intgrnd = psi_HO(r_grid, w_HO, 2)*Wan_wavefunction(x_grid, y_grid, z_grid)
#intgrnd = psi_TF(r_grid)*Wan_wavefunction(x_grid, y_grid, z_grid)

print(integrate.romb(integrate.romb(integrate.romb(intgrnd, dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0]))

In [None]:
#Rabi frequencies for coupling between Wannier states and broad harmonic trap states

n = 100 #maximum total HO quantum number nx + ny + nz included
#Centres of the two lattice sites, relative to the BEC centre
x0, y0, z0 = -5*lam_eff/2, -5*lam_eff/2, -5*lam_eff/2
x1, y1, z1 = 5*lam_eff/2, 5*lam_eff/2, 5*lam_eff/2
num_print = 10 #number of progress messages printed by the overlap loop

#Wannier grid: +/- 5 fitted Gaussian widths, 2**5 + 1 points per axis for romb
xi = np.linspace(-5*sol_real[0][2], 5*sol_real[0][2], 2**5 + 1)
yi, zi = np.copy(xi), np.copy(xi)
dx, dy, dz = xi[1] - xi[0], yi[1] - yi[0], zi[1] - zi[0]

#BEC grid spanning the Thomas-Fermi diameter, same number of points
x2 = R*np.linspace(-1, 1, len(xi))
dx2 = x2[1] - x2[0]

x_grid, y_grid, z_grid = np.meshgrid(xi, yi, zi, sparse=True)
x2_grid, y2_grid, z2_grid = np.meshgrid(x2, x2, x2, sparse=True)

a0p = np.sqrt(hbar/m/w_HO) #HO length

#Distances will be divided by a0p for the HO calculation
#(these scaled, displaced grids are not used by the code below, which calls psi_HO(xi - x0, ...))
x_grid0, y_grid0, z_grid0 = (x_grid - x0)/a0p, (y_grid - y0)/a0p, (z_grid-z0)/a0p
x_grid1, y_grid1, z_grid1 = (x_grid - x1)/a0p, (y_grid - y1)/a0p, (z_grid-z1)/a0p

#Check normalization
intgrnd = np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**2
print(integrate.romb(integrate.romb(integrate.romb(intgrnd, dx = dz), dx = dy), dx = dx))
print()

#Number of 3D HO states with nx + ny + nz <= n, i.e. (n+3 choose 3)
levels = int((n+1)*(n+2)*(n+3)/6)

#Generate the Wannier, TF and HO wavefunctions
wan = Wan_wavefunction(x_grid, y_grid, z_grid)
TF_wavefunction = psi_TF(np.sqrt(x2_grid**2 + y2_grid**2 + z2_grid**2))

def psi_HO(xp, w, n):
    """Return all normalised 1D harmonic-oscillator eigenfunctions psi_0 ... psi_n at xp.

    Avoids scipy.special.eval_hermite, which fails for n > 268. Instead it uses the
    three-term recurrence for normalised Hermite functions from "A fast algorithm for
    evaluation of normalized Hermite functions." This redefinition shadows the earlier
    single-state psi_HO and has a different return type.

    Parameters
    ----------
    xp : array_like
        Real-space points at which to evaluate the eigenstates.
    w : float
        Frequency of the oscillator.
    n : int
        Highest level returned (all levels 0..n are returned). Must be >= 0.

    Returns
    -------
    psi : list
        List of length n+1. Element i is psi_i evaluated at xp, with the same shape as xp.
    """
    length = np.sqrt(hbar/m/w) #HO length scale
    xs = xp/length #dimensionless distance
    norm_len = 1/np.sqrt(length)
    norm_pi = np.pi**(-0.25)
    envelope = np.exp(-0.5*xs**2)
    root2 = np.sqrt(2)

    states = [norm_pi*norm_len*envelope]
    if n < 1:
        return states
    states.append(root2*norm_pi*norm_len*envelope*xs)

    #Normalised Hermite polynomials (without the Gaussian envelope), advanced by
    #h_i = (sqrt(2)*x*h_{i-1} - sqrt(i-1)*h_{i-2}) / sqrt(i)
    prev, curr = np.ones_like(xs)*norm_pi, root2*norm_pi*xs
    for i in range(2, n+1):
        prev, curr = curr, (root2*xs*curr - np.sqrt(i-1)*prev) / np.sqrt(i)
        states.append(norm_len * curr * envelope)

    return states

#Superseded direct implementation of the 3D HO eigenstates, kept for reference only.
#It is a bare string literal, so nothing inside it is executed.
"""
#Note: eval_hermite cannot handle any nx, ny, nz > 268!
def Harmonic_wavefunctions(nx, ny, nz, a, x, y, z):
    #Assumes x, y, and z have been normalized by the HO length, sqrt(hbar / m / w_HO)
    norm = np.pi**(-3/4)*a**(-3/2) * np.exp(-0.5*(np.log(2) * (nx+ny+nz) + kate.loggamma(nx+1) + kate.loggamma(ny+1) + kate.loggamma(nz+1)))
    return norm * kate.eval_hermite(nx, x) * kate.eval_hermite(ny, y) * kate.eval_hermite(nz, z) * np.exp(-0.5*(x**2 + y**2 + z**2))
"""

#HO eigenstates 0..n on the BEC grid, and on the Wannier grid displaced to each lattice site
HO_wavefunctions = psi_HO(x2, w_HO, n)
HO_wavefunctions0x, HO_wavefunctions0y, HO_wavefunctions0z = (psi_HO(xi - centre, w_HO, n) for centre in (x0, y0, z0))
HO_wavefunctions1x, HO_wavefunctions1y, HO_wavefunctions1z = (psi_HO(xi - centre, w_HO, n) for centre in (x1, y1, z1))

#Field profile: the HO ground state whose length scale equals w_beam, i.e. proportional to
#exp(-x**2/(2*w_beam**2)), applied along two axes and uniform along the third
w_beam = 14.1e-6/du
beam_profile = psi_HO(x2, w_HO*(a0p/w_beam)**2, 1)[0]
field = (beam_profile[np.newaxis,:,np.newaxis]*
         beam_profile[:,np.newaxis,np.newaxis]*
         np.ones(len(x2))[np.newaxis,np.newaxis,:])

#Normalise to unit peak amplitude
field *= 1/np.max(np.abs(field))

#For every 3D HO state (nx, ny, nz) with nx + ny + nz <= n, accumulate:
#  energy_shifts: overlap of the state with field * TF wavefunction on the BEC grid
#  omega0/omega1: overlap of the state (centred on site 0 / site 1) with the Wannier function
#  U02: integral of (state * TF wavefunction)**2, scaled to an interaction strength later
energy_shifts, omega0, omega1 = [], [], []
#NOTE: rebinds the scalar U02 from the double-well cell to a list (turned into an array later)
U02 = []

ind = 0
total = 0
#Progress-print interval. A dedicated name is used so the global `frac` (the condensed fraction
#read by psi_exc) is not clobbered.
print_every = np.ceil(levels / num_print)
print('Wavefunctions:', levels)
for nx in range(n+1):
    HOx = HO_wavefunctions[nx][np.newaxis,:,np.newaxis]
    HOx0 = HO_wavefunctions0x[nx][np.newaxis,:,np.newaxis]
    HOx1 = HO_wavefunctions1x[nx][np.newaxis,:,np.newaxis]
    for ny in range(n+1-nx):
        HOy = HO_wavefunctions[ny][:,np.newaxis,np.newaxis]
        HOy0 = HO_wavefunctions0y[ny][:,np.newaxis,np.newaxis]
        HOy1 = HO_wavefunctions1y[ny][:,np.newaxis,np.newaxis]
        for nz in range(n+1-nx-ny):
            start = time.time()
            
            energy_shifts.append(integrate.romb(integrate.romb(integrate.romb(HOx * HOy * HO_wavefunctions[nz][np.newaxis,np.newaxis,:] * field * TF_wavefunction, dx=dx2), dx=dx2), dx=dx2))
            
            omega0.append(integrate.romb(integrate.romb(integrate.romb(HOx0 * HOy0 * HO_wavefunctions0z[nz][np.newaxis,np.newaxis,:] * wan, dx=dz), dx=dy), dx=dx))
            omega1.append(integrate.romb(integrate.romb(integrate.romb(HOx1 * HOy1 * HO_wavefunctions1z[nz][np.newaxis,np.newaxis,:] * wan, dx=dz), dx=dy), dx=dx))
            
            #Interactions between BEC and atoms in the HO
            U02.append(integrate.romb(integrate.romb(integrate.romb((HOx * HOy * HO_wavefunctions[nz][np.newaxis,np.newaxis,:] * TF_wavefunction)**2, dx=dx2), dx=dx2), dx=dx2))
            
            end = time.time()
            
            total += end-start
            
            if ind % print_every == 0:
                #Linear extrapolation of the mean time per state to estimate the finish time
                end_time = dtime.datetime.today() + dtime.timedelta(seconds = int((levels - ind - 1)*total/(ind+1)))
                print('Stage', str(int(ind/print_every)) + '/' + str(num_print) + ': Elapsed time:', np.round(total/60, 2), 'min, estimated end time:', end_time)
            
            ind += 1

print('Total time:', np.round(total/60, 2), 'min.')
print()

#Disabled pickling of the raw overlap lists (bare string literal, never executed)
"""
with open('energy_shifts.pickle','wb') as f:
    pickle.dump(energy_shifts,f)
with open('omega0.pickle','wb') as f:
    pickle.dump(omega0,f)
with open('omega1.pickle','wb') as f:
    pickle.dump(omega1,f)
"""

energy_shifts = np.array(energy_shifts)
omega0 = np.array(omega0)
omega1 = np.array(omega1)
U02 = g12_00*np.array(U02)

#Generate the energies for each state
#Flattening the masked sum visits states in meshgrid's 'xy' order (ny, nx, nz). The energy depends
#only on nx + ny + nz, which is symmetric in nx and ny, so the sequence of energies matches the
#(nx, ny, nz) loop order used to fill the arrays above.
n_vals = np.arange(n+1)
nx, ny, nz = np.meshgrid(n_vals, n_vals, n_vals, sparse=True)
n_vals = nx + ny + nz
n_vals = n_vals[n_vals <= n]

energies = hbar*w_HO*(n_vals + 3/2)

In [None]:
#Rubidium-87 Steck parameters for calculating the dipole potential
w1, gamma1 = 2*np.pi*377.1*10**(12)*tu, 2*np.pi*5.746*10**(6)*tu #D1 line resonance and linewidth
w2, gamma2 = 2*np.pi*384.2*10**(12)*tu, 2*np.pi*6.065*10**(6)*tu #D2 line resonance and linewidth
#NOTE: rebinds `J` (the double-well tunnelling energy) and `I`
gJ, F, mF, I, J = 2.00233113, 1, -1, 3/2, 1/2 #Lande factor, atom state quantum numbers
gF = gJ*(F*(F+1) - I*(I+1) + J*(J+1))/(2*F*(F+1)) #Lande g-factor
wF = 2*np.pi*2.563*10**(9)*tu #Frequency of the F = 2 ground state

#Waveguide beam (propagation direction denoted z here)
lam_WG = 532*10**(-9)/du #m, wavelength
P0_WG = 0.2*10**(-3)*tu**(3)/mu/du**2 #W, beam power at the focus
w0x_WG = 14.1e-6/du #m, 1/e^2 radius along x-axis
w0y_WG = w0x_WG #, 1/e^2 radius along y-axis
P = 0 #Polarization (0 = linearly, +/- = LH/RH)

#Peak field amplitude of the beam (not used by the code visible below)
E0_WG = np.sqrt(4*P0_WG*mu0*c/(np.pi*w0x_WG*w0y_WG))
w = 2*np.pi*c/lam_WG #laser frequency
delta1 = w - (w1 - wF)
delta2 = w - (w2 - wF)
#Prefactors multiplying the intensity: A_WG for the dipole potential, A_WG2 for the photon
#scattering rate (same structure as the far-detuned D1/D2 expressions, with polarization factors)
A_WG = np.pi*c**(2)*(gamma2*(2+P*gF*mF)/delta2/w2**(3) + gamma1*(1-P*gF*mF)/delta1/w1**3)/2
A_WG2 = np.pi*c**(2)*(gamma2**(2)*2/delta2**(2)/w2**(3) + gamma1**(2)/delta1**(2)/w1**3)/(2*hbar)

#2*P0/(pi*w0x*w0y) is the peak intensity of the elliptical Gaussian beam
energy_shifts2 = (2*A_WG*P0_WG/(np.pi*w0x_WG*w0y_WG)) * N * np.abs(energy_shifts)**2

#Shortest lifetime over all HO states (seconds). energy_shifts**2 equals |energy_shifts|**2 here
#only because those overlaps are real.
print('Lifetime:', np.min(tu / ((2*A_WG2*P0_WG/(np.pi*w0x_WG*w0y_WG)) * N * energy_shifts**2)), 's')

print('Rabi frequency:', (2*A_WG*P0_WG/(np.pi*w0x_WG*w0y_WG))*1e-3/tu/hbar, 'kHz')

indcs = np.argsort(energies)

#plt.figure()
#plt.plot(energies[indcs][:int(len(energies))]/hbar/w_HO - 3/2, energy_shifts2[indcs][:int(len(energies))]*1e-3/tu/hbar, 'b.-')
#plt.show()

In [None]:
#Compute the coupling terms and print the results

omega = 1e6*tu
dN = 0*np.sqrt(N) #Uncertainty in the number of BEC atoms (currently disabled by the 0 factor)

#Field-induced shift of each HO level, added to the bare HO energies
energy_shifts2 = hbar*omega * (N+dN) * np.abs(energy_shifts)**2
energies2 = np.copy(energies) + np.copy(energy_shifts2)

#One-photon detuning placed 1000e3*tu below the lowest shifted HO level (angular units of 1/tu)
delta = np.sort(energies2)[0]/hbar - 1000e3*tu
#delta = -10000e3*tu

#Rescale omega so that the largest |omega*overlap/(delta - E/hbar)| over both sites equals 0.1
omega = 0.1/np.max([np.max(np.abs(omega0 / (delta - energies2/hbar))), np.max(np.abs(omega1 / (delta - energies2/hbar)))])

print('Free-space Rabi frequency:', omega*1e-3/tu, 'kHz')

#Set up for 3D overlap integrals of the Wannier functions and the TF BEC wavefunction
def Wan_wavefunction(x, y, z):
    """Separable 3D Gaussian approximation to the lowest-band Wannier function.

    Each Cartesian factor is gauss(s, *sol_real[0]) + 1j*gauss(s, *sol_imag[0]), using the
    parameters fitted in the Gaussian-fit cell. Inputs broadcast, e.g. sparse meshgrid arrays.
    """
    def axis_factor(s):
        re_part = gauss(s, sol_real[0][0], sol_real[0][1], sol_real[0][2])
        im_part = gauss(s, sol_imag[0][0], sol_imag[0][1], sol_imag[0][2])
        return re_part + 1j*im_part
    return axis_factor(x)*axis_factor(y)*axis_factor(z)

#Wannier grid around the origin; the TF density is evaluated at (grid - site centre), i.e. the
#BEC density seen by a Wannier function sitting at site 0 (U01) or site 1 (U01p)
xi = np.linspace(x[0], x[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)

U01 = g1_m10*integrate.romb(integrate.romb(integrate.romb(np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**(2)*
                                                          np.abs(psi_TF(np.sqrt((x_grid-x0)**2 + 
                                                                                (y_grid-y0)**2 + 
                                                                                (z_grid-z0)**2)))**2, 
                                                          dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])

U01p = g1_m10*integrate.romb(integrate.romb(integrate.romb(np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**(2)*
                                                           np.abs(psi_TF(np.sqrt((x_grid-x1)**2 + 
                                                                                 (y_grid-y1)**2 + 
                                                                                 (z_grid-z1)**2)))**2, 
                                                           dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])

def func(deltap, N):
    """Absolute mismatch between the two light-shifted detunings at atom number N.

    Each side is: detuning + mean-field shift (U01 or U01p)*N/hbar + sum over levels of
    (omega*omega_k)**2 / (that shifted detuning - energies2/hbar - N*U02/hbar)."""
    shifted0 = delta + U01*N/hbar
    shifted1 = deltap + U01p*N/hbar
    side0 = shifted0 + np.sum((omega*omega0)**(2) / (shifted0 - energies2/hbar - N*U02/hbar))
    side1 = shifted1 + np.sum((omega*omega1)**(2) / (shifted1 - energies2/hbar - N*U02/hbar))
    return np.abs(side0 - side1)

#NOTE(review): maxiter and dN must already exist when this cell runs; within this part of the notebook dN is
#only assigned in later cells — confirm an earlier cell defines both so Restart & Run All works.
opts = {"maxiter":maxiter, "disp":False} #options for the minimize
#Choose deltap so the second path's detuning + light shift matches the first (func returns their mismatch).
#args is passed as an explicit 1-tuple, (N,)
out = optimize.minimize(func, delta, args=(N,), method='Nelder-Mead', options=opts)
#out.x is a length-1 array; keep the scalar so deltap and the printed residual are plain numbers
deltap = out.x[0]

print('Two-photon detuning:', func(deltap, N+dN)*1e-3/tu, 'kHz')

#Two-photon coupling through each intermediate level, evaluated at atom number N+dN
coupling_terms = omega**(2) * omega0 * omega1 * (1/(delta + U01*(N+dN)/hbar - energies2/hbar - U02*(N+dN)/hbar) + 1/(deltap + U01p*(N+dN)/hbar - energies2/hbar - U02*(N+dN)/hbar))

#Order the intermediate levels by their unshifted energy
indcs = np.argsort(energies)

print('Sum of coupling terms:', np.sum(coupling_terms)*1e-3/tu, 'kHz')

#Cumulative sum of the (real) coupling terms, restricted to the levels with the largest energy
terms = np.cumsum(np.real(coupling_terms[indcs]))[energies[indcs] == np.max(energies)]

print('Average of the cumulative sum over the terms with the largest energy', 0.5*(np.max(terms) + np.min(terms))*1e-3/tu)

plt.figure()
plt.plot(energies[indcs]/hbar/w_HO - 3/2, np.cumsum(np.real(coupling_terms[indcs]))*1e-3/tu, 'b.')
#plt.plot(energies[indcs][int(len(energies)*0.5):]/hbar/w_HO - 3/2, np.cumsum(np.real(coupling_terms[indcs]))[int(len(energies)*0.5):]*1e-3/tu, 'b.')
#plt.plot(energies[indcs]/hbar/w_HO - 3/2, np.cumsum(np.imag(coupling_terms[indcs]))*1e-3/tu, 'r.')
plt.xlabel('HO energy level')
plt.ylabel('Effective Rabi frequency (kHz)')
plt.show()

In [None]:
#Spatial state of the noncondensed atoms in the TF approximation

def psi_exc(r, n_grid=2**5 + 1):
    """Wavefunction amplitude of the noncondensed atoms at radius r (TF approximation, T = 0).

    The local noncondensed fraction f(r) = (8/3)*sqrt(|psi_TF(r)|^2*asc^3/pi) is rescaled so that
    f*|psi_TF|^2 integrates to (1 - frac) over the whole TF cloud; sqrt(f)*psi_TF(r) is returned.

    The normalization integral is evaluated on its own cubic grid spanning [-rTF[-1], rTF[-1]] on every
    axis, so it no longer depends on the points r it is called with. Before, it integrated over r itself,
    using the spacing of the global `xi`. Calls on a small, shifted Wannier-sized grid were therefore
    normalized over that window only, and 1D r could not be integrated at all. For r on the full
    TF grid the result is the same as before.

    Args:
        r: radii (any shape) at which to evaluate the wavefunction.
        n_grid: points per axis of the normalization grid (must be 2**k + 1 for Romberg integration).
    Returns:
        Array broadcast like psi_TF(r).
    """
    def depletion(psi_vals):
        #T=0 noncondensed fraction of atoms
        return (8/3)*np.sqrt(np.abs(psi_vals)**(2)*asc**(3)/np.pi)

    #Normalization over the full TF cloud, independent of r
    grid_1d = np.linspace(-rTF[-1], rTF[-1], n_grid)
    gx, gy, gz = np.meshgrid(grid_1d, grid_1d, grid_1d, sparse=True)
    step = grid_1d[1] - grid_1d[0]
    psi_grid = psi_TF(np.sqrt(gx**2 + gy**2 + gz**2))
    intgrnd = depletion(psi_grid)*np.abs(psi_grid)**2
    norm = integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=step), dx=step), dx=step)

    psi_r = psi_TF(r)
    f = (1-frac)*depletion(psi_r)/norm
    return np.sqrt(f)*psi_r

In [None]:
#Differential Zeeman shift: 72*B2**2 in Hz, printed in kHz

B2 = 17 #G, magnetic field
zeeman_shift_kHz = 2*np.pi*72*B2**(2)/np.pi/2/1e3
print('Differential Zeeman shift:', zeeman_shift_kHz, 'kHz')

In [None]:
#Bose-Hubbard for two-level atoms in a spin-dependent lattice and spin-independent harmonic trap
#where a limited number of atoms can populate each lattice site, setup to load one atom from the 
#HO trap into each lattice site
#(uses results from the cells above)

#--- Atom number and couplings (energies are stored as hbar times a frequency, in the notebook's scaled units) ---
Na = N
dN = np.sqrt(Na)
#two-photon Rabi coupling for transitions between the lattice and HO states (2*pi*0.1 kHz, divided by sqrt(Na + dN))
J = 2*np.pi*hbar*0.1e3*tu/np.sqrt(Na+dN)
#Rabi coupling for transitions between adjacent HO states (currently switched off)
J2 = hbar*0*tu
#magnetic field (G) and the Zeeman splitting between the lattice sites and HO states
B = 17
V0 = hbar*0.7*B*1e6*tu
#uncertainty in the lattice position relative to the HO trap: 10% of the site spacing lam_eff/2 (same length units as lam_eff)
drift = 0.1*lam_eff/2

#--- Model / basis options ---
#use_cubic: arrange the lattice sites in a cube (latt_states must then be an even number raised to the third power)
#latt_states: number of lattice sites; HO_states: number of HO states (must be 1 or 2)
use_cubic = False
latt_states, HO_states = 1, 1
#restrict the Hilbert space to Fock states with <= n_latt atoms in each lattice site
n_latt = 2
#restrict the Hilbert space to Fock states where at most dN_exc atoms have entered or left the noncondensed state
dN_exc = 0
#gen_eff: generate the Hilbert space basis efficiently (independent of N)
#RWA: remove the time dependence from the Hamiltonian with the RWA (if False a diffEq solver is used)
gen_eff, RWA = True, True
#relative tolerance, absolute tolerance, and max step size for the diffEq solver
r_tol, a_tol, maxstep = 1e-8, 1e-8, 1e1


#Set up for 3D overlap integrals of the Wannier functions and the TF BEC wavefunction
def Wan_wavefunction(x, y, z):
    """Separable 3D Wannier-function fit evaluated at (x, y, z).

    The same complex 1D factor, gauss(.) with sol_real[0] parameters + 1j*gauss(.) with sol_imag[0]
    parameters, is applied along each axis and the three factors are multiplied together, so the
    coordinates may be broadcastable grids such as a sparse meshgrid."""
    product = 1
    for coord in (x, y, z):
        real_part = gauss(coord, sol_real[0][0], sol_real[0][1], sol_real[0][2])
        imag_part = gauss(coord, sol_imag[0][0], sol_imag[0][1], sol_imag[0][2])
        product = product*(real_part + 1j*imag_part)
    return product

#Wannier-sized cubic grid (2**5 + 1 points per axis, as Romberg integration requires)
xi = np.linspace(x[0], x[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
step = xi[1] - xi[0]

#Overlap between the TF wavefunction and the Wannier function, both centred at the origin
r_origin = np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)
intgrnd = psi_TF(r_origin)*Wan_wavefunction(x_grid, y_grid, z_grid)
overlap = integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=step), dx=step), dx=step)
overlap_mag = np.abs(overlap)

#Rescale J by 1/|overlap|, so the effective Rabi frequency below is fixed by the J set above
J *= 1/overlap_mag
omega = np.sqrt(Na)*J*overlap_mag/hbar   #effective Rabi frequency
tf = 2*np.pi/omega                       #calculation time: one full Rabi period, 2*pi/omega

omega_over_U = hbar*omega/U
print('hbar*omega/U =', omega_over_U)
if omega_over_U >= 0.1:
    print('Warning: should have hbar*omega/U =', omega_over_U, '<< 1!')

times = np.linspace(0, tf, 4000) #the time evolution is reported at these times

#Calculate the interaction strengths for each lattice site (i.e. U) and HO state. The HO interaction strengths
#will be calculated using the HO eigenstates

U_vals = np.zeros(latt_states + HO_states, dtype=complex)

#Assume all lattice sites correspond to atoms in the state |F=1,mF=-1>. The Wannier function is separable, so the
#3D integral of |w|^4 is the cube of the 1D integral of |wan0|^4 over the grid x on which wan0 is sampled.
#The step must be the spacing of x. The name `delta` was used here before, but earlier in the notebook it
#is rebound to a two-photon detuning, which corrupted U.
#(assumes x is uniformly spaced with 2**k + 1 points, as integrate.romb requires — TODO confirm)
dx_wan = x[1] - x[0]
for ind in range(latt_states):
    U_vals[ind] = g1_m1m1*integrate.romb(np.abs(wan0)**4, dx=dx_wan)**3

#Assume the HO corresponds to atoms in the state |F=1,mF=0>

xi = np.linspace(-rTF[-1], rTF[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
r_grid = np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)
dxi = xi[1] - xi[0]

for ind in range(HO_states):
    #HO state 0 uses the TF condensate wavefunction, any further HO state the noncondensed one
    profile = psi_TF if ind == 0 else psi_exc
    intgrnd = np.abs(profile(r_grid))**4
    U_vals[latt_states+ind] = g1_00*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=dxi), dx=dxi), dx=dxi)

#Calculate the interaction strengths between the lattice sites and the HO states (assume no inter-lattice
#interactions), and between all pairs of HO states

#Wannier-sized cubic grid used for the lattice-HO overlap integrals
xi = np.linspace(x[0], x[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)

#One entry per unordered pair (i, j), i < j, of the latt_states + HO_states modes: n*(n-1)/2 = sum(arange(n))
U12_vals = np.zeros(np.sum(np.arange(latt_states+HO_states)), dtype=complex)
U12_vals2 = np.zeros_like(U12_vals) #values when the lattice has drifted

#Generate indices for the lattice sites such that the lattices are arranged in a cubic pattern
#indcs[i] holds integer (x, y, z) multiples of lam_eff/2 giving the offset of lattice site i
if latt_states > 1 and use_cubic:
    lattice_length = int(np.round(latt_states**(1/3)/2)) #there will be 2*lattice_length sites along each dimension
    indcs = np.zeros((latt_states, 3), dtype=int) #indices for each lattice site

    #allowed per-axis indices: -lattice_length..-1 and 1..lattice_length (0 is skipped)
    nums = np.zeros(2*lattice_length, dtype=int)
    nums[:lattice_length] = np.arange(-lattice_length, 0)
    nums[lattice_length:] = np.arange(1, lattice_length+1)

    #fill columns z, y, x in turn so that every combination appears once (z varies fastest, x slowest)
    base = [0]
    for ind in range(indcs.shape[1]):
        base = np.repeat(nums, len(base))
        indcs[:,indcs.shape[1]-ind-1] = np.tile(base, int(np.round(len(indcs)/len(base))))
else:
    #non-cubic layout: site i (0-based) gets index i+1 along all three axes
    indcs = 1*np.repeat(np.arange(1, latt_states+1)[:,np.newaxis], repeats=3, axis=1)

#Fill U12_vals / U12_vals2 in the upper-triangular pair order (i, j), i < j, over all lattice + HO modes.
#Lattice-lattice pairs are skipped (left at zero): no inter-lattice interactions.
wan_density = np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**2
step = xi[1] - xi[0]

count = 0
for ind in range(latt_states):
    count += latt_states-ind-1   #skip the pairs of site ind with the later lattice sites
    off_x = indcs[ind][0]*lam_eff/2
    off_y = indcs[ind][1]*lam_eff/2
    off_z = indcs[ind][2]*lam_eff/2
    r_site = np.sqrt((x_grid + off_x)**2 + (y_grid + off_y)**2 + (z_grid + off_z)**2)
    r_drift = np.sqrt((x_grid + drift + off_x)**2 + (y_grid + off_y)**2 + (z_grid + off_z)**2)
    for ind2 in range(HO_states):
        #HO state 0 is the TF condensate, any further HO state the noncondensed cloud
        profile = psi_TF if ind2 == 0 else psi_exc
        intgrnd = np.abs(profile(r_site))**(2)*wan_density
        intgrnd2 = np.abs(profile(r_drift))**(2)*wan_density
        U12_vals[count] = g1_m10*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=step), dx=step), dx=step)
        U12_vals2[count] = g1_m10*integrate.romb(integrate.romb(integrate.romb(intgrnd2, dx=step), dx=step), dx=step)
        count += 1

#HO-HO pairs are integrated on a grid spanning the whole TF cloud
xi = np.linspace(-rTF[-1], rTF[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
step = xi[1] - xi[0]
r_cloud = np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)

for ind in range(HO_states):
    for ind2 in range(ind+1, HO_states):
        #only the TF / noncondensed pair is modelled (HO_states must be 1 or 2)
        intgrnd = np.abs(psi_TF(r_cloud))**(2)*np.abs(psi_exc(r_cloud))**2
        U12_vals[count] = g1_00*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=step), dx=step), dx=step)
        U12_vals2[count] = U12_vals[count]   #HO-HO terms do not depend on the lattice drift
        count += 1

#Allow tunneling between the lattice sites and the HO states only (no inter-lattice or inter-HO tunneling).
#J_vals[hop-1][k] is the amplitude for a hop of length `hop` in mode index, from lattice site k to mode k+hop
#(e.g. 1st lattice site -> 1st HO state and 2nd lattice site -> 2nd HO state share the same hop length).
#Entries whose target is not an HO state stay zero. Hop lengths start at 1.

xi = np.linspace(x[0], x[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
lam_MW = c/(V0/hbar/np.pi/2)
exp_term = np.exp(1j*(2*np.pi/lam_MW)*x_grid)

n_modes = latt_states + HO_states
wan_vals = Wan_wavefunction(x_grid, y_grid, z_grid)
step = xi[1] - xi[0]

J_vals, J_vals2 = [], []   #J_vals2: values when the lattice has drifted
for ind in range(n_modes - 1):
    hop = ind + 1
    row = np.zeros(n_modes - 1 - ind, dtype=complex)
    row2 = np.zeros(n_modes - 1 - ind, dtype=complex)
    J_vals.append(row)
    J_vals2.append(row2)

    #lattice sites ind2 whose target mode ind2 + hop is one of the HO states
    first_site = max([0, latt_states - hop])
    stop_site = latt_states + min([0, HO_states - hop])
    for ind2 in range(first_site, stop_site):
        #hopping into HO state 0 uses the TF wavefunction, into later HO states the noncondensed one
        profile = psi_TF if ind2 + hop == latt_states else psi_exc
        off_x = indcs[ind2][0]*lam_eff/2
        off_y = indcs[ind2][1]*lam_eff/2
        off_z = indcs[ind2][2]*lam_eff/2
        r_site = np.sqrt((x_grid + off_x)**2 + (y_grid + off_y)**2 + (z_grid + off_z)**2)
        r_drift = np.sqrt((x_grid + drift + off_x)**2 + (y_grid + off_y)**2 + (z_grid + off_z)**2)
        intgrnd = profile(r_site)*exp_term*wan_vals
        intgrnd2 = profile(r_drift)*exp_term*wan_vals

        row[ind2] = -(J/2)*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=step), dx=step), dx=step)
        row2[ind2] = -(J/2)*integrate.romb(integrate.romb(integrate.romb(intgrnd2, dx=step), dx=step), dx=step)

#Compute the energy difference between the initial and final states when the field is on due to atom interactions
#so we can correct it
#indcs[i] is the position in U12_vals of the pair (lattice site i, HO state 0 = mode latt_states). It uses the same
#upper-triangular pair order as the U12_vals fill loop: pair (i, j), i < j, of n = latt_states + HO_states
#modes sits at i*n - i*(i+1)/2 + (j - i - 1). With j = latt_states this reduces to the expression below.
indcs = np.arange(latt_states)
indcs = (-1 -indcs**(2)/2 + latt_states + indcs*(latt_states+HO_states-3/2)).astype(int)
#Per-site detuning estimate. First term: change in the HO-ground self-interaction when latt_states atoms leave
#Na, i.e. U/2*[Na(Na-1) - (Na-L)(Na-L-1)] = U/2*L*(2Na-L-1). Second term: lattice-HO interaction of the
#loaded atoms with the Na-latt_states remaining atoms. solve_Schrodinger uses this value as the starting point
#of its numerical optimization.
V1 = np.real((U_vals[latt_states]/2)*latt_states*(2*Na-latt_states-1) - np.sum(U12_vals[indcs]*(Na-latt_states)))/latt_states #two-photon detuning

#Finally, set the energy offset matrices for when the field is on
#NOTE(review): solve_Schrodinger builds its own local Von_vals (with lattice trap energies and without the +1/2
#zero-point term), so this module-level array is not read there — confirm whether it is still needed.
Von_vals = np.zeros(len(U_vals), dtype=complex)
Von_vals[latt_states:] = hbar*w_HO*(np.arange(HO_states) + 1/2) - V0
#shift all HO offsets together: Von_vals[latt_states] is read before the in-place update
Von_vals[latt_states:] += Von_vals[0] - Von_vals[latt_states] - V1

def gen_basis(N, wells):
    """Generate the number state basis for N atoms in a lattice where wells is the number of sites.

    Uses the method from the paper 'Exact diagonalization: the Bose–Hubbard model as an example':
    states are listed in descending lexicographic order, from |N 0 ... 0> to |0 ... 0 N>.

    Args:
        N: total number of atoms (non-negative integer; an integer-valued float is accepted).
        wells: number of sites/modes (>= 1).
    Returns:
        Integer array of shape (C(N + wells - 1, wells - 1), wells), one Fock state per row.
    """
    n_atoms = int(N)
    #Exact state count. The previous float product np.product((N+wells-i)/i) could truncate below the true
    #count (int() of e.g. 9.9999), and np.product no longer exists in NumPy 2.
    num_states = math.comb(n_atoms + wells - 1, wells - 1)

    basis = np.zeros((num_states, wells), dtype=int)

    basis[0,0] = n_atoms
    for ind in range(1,num_states):
        basis[ind] = np.copy(basis[ind-1])
        #rightmost occupied site, excluding the last one
        k = (np.arange(wells-1)[basis[ind,:-1] > 0])[-1]
        basis[ind,k] += -1
        basis[ind,k+1] = n_atoms - np.sum(basis[ind,:k+1])
        basis[ind,k+2:] = 0

    #With the exact count the last generated row is always |0 ... 0 N>, so no patch-up row is needed
    return basis

def gen_basis2(N, latt_states, HO_states, n_latt, n_HO, ignore_odd_parity):
    """Efficiently generates the number state basis for N atoms in a lattice.
    All states will be of the form:

    |latt latt ... latt HO HO ... HO>
    |i0   i1   ... in   j0 j1 ... jm>

    where
    i0...in, all range from 0 to n_latt,
    j0, can range from 0 to N (but each state will have a total of N atoms),
    j1...jm, range from 0 to n_HO (if ignore_odd_parity = True then all j1,j3,... will be 0),
    and there will be a total of N atoms in each state, distributed over the latt and HO levels.

    Rows are ordered with the excited HO columns varying slowest and the lattice columns fastest, each in
    descending occupation. Combinations that would need a negative j0 (more than N atoms already in the
    lattice and excited HO levels) are unphysical and are removed.

    Returns:
        Integer array of shape (n_states, latt_states + HO_states).
    """

    latt_len = (n_latt+1)**latt_states
    basis = np.zeros((latt_len, latt_states+HO_states), dtype=int)
    nums = np.arange(n_latt, -1, -1)

    #every combination of lattice occupations 0..n_latt (first lattice site varies slowest)
    for ind in range(latt_states):
        basis[:,ind] = np.tile(np.repeat(nums, (n_latt+1)**(latt_states - ind - 1)), (n_latt+1)**ind)

    nums = np.arange(n_HO, -1, -1)
    if ignore_odd_parity:
        basis = np.tile(basis, ((n_HO+1)**int((HO_states-1)/2),1))

        for ind in range(latt_states+2, latt_states+HO_states, 2):
            reps = np.repeat(nums, latt_len*(n_HO+1)**(int((HO_states-3)/2) - int((ind-latt_states-2)/2)))
            basis[:,ind] = np.tile(reps, int(len(basis[:,ind])/len(reps)))
    else:
        basis = np.tile(basis, ((n_HO+1)**int(HO_states-1),1))

        for ind in range(latt_states+1, latt_states+HO_states):
            reps = np.repeat(nums, latt_len*(n_HO+1)**(latt_states + HO_states - ind - 1))
            basis[:,ind] = np.tile(reps, int(len(basis[:,ind])/len(reps)))

    #the HO ground level takes whatever atoms remain (its column is still zero here)
    basis[:,latt_states] = N - np.sum(basis, axis=1)

    #drop unphysical states with a negative HO-ground occupation
    basis = basis[basis[:,latt_states] >= 0]

    return basis

def tunneling_terms(basis, J_terms, dist):
    """Returns a matrix in the input basis with tunneling elements filled in above the diagonal.

    J_terms should be an array that contains the tunneling matrix element for each pair of lattice sites
    separated by a distance dist (J_terms[k] couples mode k to mode k+dist). For every pair of Fock states
    that differ by moving one atom from mode k to mode k+dist, the element J_terms[k]*sqrt(n_k*(n_{k+dist}+1))
    (occupations of the source state) belongs at [source, target]. If the target row comes before the source
    row, its complex conjugate is stored at [target, source]. Everything therefore sits in the upper
    triangle, and make_Hermitian(V) (upper=True) restores both elements regardless of the basis ordering.

    Args:
        basis: integer array (n_states, n_modes) of occupation numbers, one Fock state per row.
        J_terms: array of length n_modes - dist with the hop amplitudes.
        dist: hop length in mode index (>= 1).
    Returns:
        (n_states, n_states) array of dtype J_terms.dtype, nonzero only above the diagonal.
    """
    V = np.zeros((len(basis), len(basis)), dtype = J_terms.dtype)
    state_indcs = np.arange(len(V))

    for ind in range(len(basis)):
        diffs = basis[ind] - basis
        #target rows have one atom fewer in mode k and one more in mode k+dist than basis[ind]
        tunnel_left_right = np.logical_and((diffs == 1)[:,0:-1*dist], (diffs == -1)[:,dist:])
        J_vals = np.sum(J_terms*np.sqrt(basis[ind][np.newaxis,:-1*dist].astype(float)*(basis[ind][np.newaxis,dist:]+1))*tunnel_left_right, axis=1)

        #Only choose states separated from the state basis[ind] by a single tunneling event:
        res = np.logical_and(np.sum(tunnel_left_right, axis=1) == 1, np.sum(np.abs(diffs), axis=1) == 2)
        targets = state_indcs[res]
        elems = J_vals[res]

        #Some bases (e.g. gen_basis2 with an excited HO state) list the target before the source. Such elements
        #used to land below the diagonal and were silently dropped by make_Hermitian(upper=True).
        above = targets > ind
        V[ind, targets[above]] = elems[above]
        V[targets[~above], ind] = np.conjugate(elems[~above])

    return V

def make_Hermitian(H0, upper=True):
    """Return a new Hermitian matrix built from one triangle of H0.

    The upper (default) or lower triangle of H0, diagonal included, is added to its own conjugate
    transpose. The diagonal is then halved, which leaves the real part of H0's diagonal."""
    H = np.triu(H0) if upper else np.tril(H0)
    H += np.conjugate(np.transpose(H))
    diag = np.diag_indices(len(H))
    H[diag] = H[diag]/2
    return H

def solve_Schrodinger(plot):
    """Simulate loading one atom per lattice site from the HO trap with the Bose-Hubbard model above.

    Builds the Fock basis (gen_basis or gen_basis2), then assembles the Hamiltonian from the module-level
    interaction (U_vals, U12_vals, U12_vals2), tunneling (J_vals, J_vals2, J2) and potential terms. It tunes
    the two-photon detuning (starting from V1) so the initial and target states are degenerate in the
    drifted-lattice Hamiltonian. The initial state is then evolved over `times`: exactly via
    eigendecomposition if RWA, otherwise with solve_ivp.

    The initial state has all Na atoms in the HO ground state (plus Nt-N noncondensed atoms when
    HO_states == 2); the target state has exactly one atom in every lattice site.

    Args:
        plot: if True, plot the site-occupation fractions and the initial/target-state populations vs time.
    Returns:
        None; results are printed (and plotted if requested).
    """
    #Generate the basis states
    if HO_states == 1:
        if not gen_eff:
            basis = gen_basis(Na, latt_states + HO_states)
            for ind in range(latt_states):
                basis = basis[basis[:,ind] <= n_latt] #keep only states where the lattice sites each have <= n_latt atoms
        else:
            basis = gen_basis2(Na, latt_states, HO_states, n_latt, int(Nt-N+dN_exc), False)

        #initial: all Na atoms in the HO ground state; target: one atom per lattice site
        indi = np.arange(len(basis))[basis[:,latt_states] == Na][0]
        indf = np.arange(len(basis))[np.all(basis[:,:latt_states] == 1, axis=1) & (basis[:,latt_states] == Na-latt_states)][0]
    else:
        if not gen_eff:
            basis = gen_basis(Nt, latt_states + HO_states)
            for ind in range(latt_states):
                basis = basis[basis[:,ind] <= n_latt] #keep only states where the lattice sites each have <= n_latt atoms
        else:
            basis = gen_basis2(Nt, latt_states, HO_states, n_latt, int(Nt-N+dN_exc), False)
            basis = basis[np.abs(basis[:,latt_states+1] - (Nt-N)) <= dN_exc] #keep only states where the noncondensed atoms are within dN_exc of the initial value

        #Both the initial and the target state keep the initial Nt-N noncondensed atoms (same expression for both)
        indi = np.arange(len(basis))[(basis[:,latt_states] == Na) & (basis[:,latt_states+1] == Nt-N)][0]
        indf = np.arange(len(basis))[np.all(basis[:,:latt_states] == 1, axis=1) & (basis[:,latt_states] == Na-latt_states) & (basis[:,latt_states+1] == Nt-N)][0]

    initial = np.zeros(len(basis), dtype=complex)
    initial[indi] = 1

    #Interaction terms between atoms in the same states
    U_matrix = np.diagflat(np.sum((U_vals/2)*basis*(basis-1), axis=1)).astype(complex)

    #Interaction terms between atoms in different states (pair order matches U12_vals)
    U12_matrix = np.zeros(len(basis), dtype=complex)
    U12_matrix2 = np.zeros_like(U12_matrix)
    count = 0
    for ind in range(len(basis[0,:])-1):
        elems = len(basis[0,:])-ind-1
        U12_matrix += np.sum(U12_vals[count:count+elems]*basis[:,ind][:,np.newaxis]*basis[:,ind+1:], axis=1)
        U12_matrix2 += np.sum(U12_vals2[count:count+elems]*basis[:,ind][:,np.newaxis]*basis[:,ind+1:], axis=1)
        count += elems

    U12_matrix = np.diagflat(U12_matrix).astype(complex)
    U12_matrix2 = np.diagflat(U12_matrix2).astype(complex)

    #Set tunneling between the lattice and HO states (J_vals[ind] holds hops of length ind+1)
    J_matrix = np.zeros(U_matrix.shape, dtype=complex)
    J_matrix2 = np.zeros_like(J_matrix)
    for ind in range(len(J_vals)):
        J_matrix += tunneling_terms(basis, J_vals[ind], ind+1)
        J_matrix2 += tunneling_terms(basis, J_vals2[ind], ind+1)
    J_matrix = make_Hermitian(J_matrix)
    J_matrix2 = make_Hermitian(J_matrix2)

    #Set tunneling between adjacent HO states
    J2_terms = (-J2/2)*np.ones(len(basis[0,:])-1) #factor of 1/2 because of the RWA
    J2_terms[:latt_states] = 0 #do not allow J2 to cause tunneling between the lattice sites and the HO
    J2_matrix = make_Hermitian(tunneling_terms(basis, J2_terms, 1))

    #Voff will contain terms corresponding to the potential difference between all of the states (field off)
    Voff_vals = np.zeros(len(U_vals), dtype=complex)
    Voff_vals[:latt_states] = 0.5*m*w_HO**(2)*(0.5*lam_eff*np.arange(1, latt_states+1))**2
    Voff_vals[latt_states:] = hbar*w_HO*np.arange(HO_states) - V0
    Voff_matrix = make_Hermitian(np.diagflat(np.sum(Voff_vals*basis, axis=1)))

    #Von will contain terms corresponding to the potential difference between all of the states (field on)

    #Minimize the energy difference between the initial and final states by changing the two-photon detuning, V1
    #Done using the Hamiltonian when the lattice has drifted, so when the dynamics are calculated using the
    #Hamiltonian for the lattice not drifting, the offset V1 will be off by a little bit (corresponding to our
    #uncertainty in the actual position of the lattice)
    def func(V1):
        Von_vals = np.zeros(len(U_vals), dtype=complex)
        Von_vals[:latt_states] = 0.5*m*w_HO**(2)*(0.5*lam_eff*np.arange(1, latt_states+1))**2
        Von_vals[latt_states:] = hbar*w_HO*np.arange(HO_states) - V0
        Von_vals[latt_states:] += Von_vals[0] - Von_vals[latt_states] - V1

        Von_matrix = make_Hermitian(np.diagflat(np.sum(Von_vals*basis, axis=1)))
        H = U_matrix + U12_matrix2 + J_matrix2 + J2_matrix + Von_matrix
        return np.abs(H[indi,indi] - H[indf,indf])/tu/hbar

    maxiter = 10000
    opts = {"maxiter":maxiter, "disp":False} #options for the minimize
    start = time.time()
    out = optimize.minimize(func, V1, method='Nelder-Mead', options=opts)
    end = time.time()

    #Differential Zeeman shift
    #out.x[0] += hbar*20e3*tu

    Von_vals = np.zeros(len(U_vals), dtype=complex)
    Von_vals[:latt_states] = 0.5*m*w_HO**(2)*(0.5*lam_eff*np.arange(1, latt_states+1))**2
    Von_vals[latt_states:] = hbar*w_HO*np.arange(HO_states) - V0
    Von_vals[latt_states:] += Von_vals[0] - Von_vals[latt_states] - out.x[0]

    Von_matrix = make_Hermitian(np.diagflat(np.sum(Von_vals*basis, axis=1)))

    #Field-on Hamiltonian for the non-drifted lattice, re-zeroed on the energy of the first basis state
    H = U_matrix + U12_matrix + J_matrix + J2_matrix + Von_matrix
    H += -H[0,0]*np.eye(len(H))

    print('Detuning:', out.x[0]*1e-3/tu/hbar, 'kHz')
    print('Energy diff between initial and final states:', np.real(H[indi,indi] - H[indf,indf])/tu/hbar, 'Hz')

    if not RWA:
        H0 = U_matrix + U12_matrix + Voff_matrix
        #Drive frequency: bare (field-off) energy difference between the initial and target states.
        #Both energies must come from H0; the field-on, re-zeroed H was previously used for the target state.
        wd = np.real(H0[indi,indi] - H0[indf,indf])/hbar

        def H_psi(t, psi):
            return np.dot(H0 + J_matrix*np.sin(wd*t) + J2_matrix, psi)/1j/hbar

        start = time.time()
        sol0 = integrate.solve_ivp(H_psi, (times[0], times[-1]), initial, method='RK45', t_eval=times, rtol=r_tol, atol=a_tol, max_step = maxstep)
        end = time.time()

        total = np.sum(np.abs(sol0.y)**2, axis=0)
        if sol0.success:
            print('Solver time:', np.round((end-start)/60, 2), 'min, loss:', str(np.round(np.max(np.abs(1 - total))*100,2)) + '%')
        else:
            print('Solver error:', sol0.message)

        #Store the wavefunction at all times in the Fock basis
        final = np.copy(sol0.y)

    else:
        #Diagonalize the time-independent RWA Hamiltonian
        evals, evecs = np.linalg.eigh(H)

        #Write the initial state in the basis of the eigenvectors. eigh returns orthonormal eigenvectors, so
        #the inverse of the eigenvector matrix is its conjugate transpose (cheaper and better conditioned than inv)
        initial = np.dot(np.conjugate(np.transpose(evecs)), initial)

        #Calculate time evolution of the initial state by applying phase shifts to each eigenvector
        final = initial[:,np.newaxis]*np.exp(-1j*evals[:,np.newaxis]*times[np.newaxis,:]/hbar)

        #Go back to the Fock basis
        final = np.dot(evecs, final)

        total = np.sum(np.abs(final)**2, axis=0)

    #Compute the expected number of lattice sites that have exactly zero, one, or two atoms, as functions of time
    frac_sites_0 = np.sum(np.abs(final)**(2)*np.sum(basis[:,:latt_states] == 0, axis=1)[:,np.newaxis], axis=0)/latt_states
    frac_sites_1 = np.sum(np.abs(final)**(2)*np.sum(basis[:,:latt_states] == 1, axis=1)[:,np.newaxis], axis=0)/latt_states
    frac_sites_2 = np.sum(np.abs(final)**(2)*np.sum(basis[:,:latt_states] == 2, axis=1)[:,np.newaxis], axis=0)/latt_states

    #report at the time where single occupancy is largest
    indt = np.argmax(frac_sites_1)

    print("")
    print('Fraction of sites with 0 atoms:', frac_sites_0[indt])
    print('Fraction of sites with 1 atom: ', frac_sites_1[indt])
    print('Fraction of sites with 2 atoms:', frac_sites_2[indt])

    print('Max error:', np.max(total - np.abs(final[indi])**2 - np.abs(final[indf])**2))

    print('Pulse time:', times[indt]*tu/1e-3, 'ms')

    if plot:
        font = {'family':'Times New Roman', 'weight':'bold', 'size':14}
        plt.rc('font', **font)

        plt.figure()
        plt.plot(times*tu/1e-3, frac_sites_0, 'b-', label='n = 0')
        plt.plot(times*tu/1e-3, frac_sites_1, 'r-', label='n = 1')
        plt.plot(times*tu/1e-3, frac_sites_2, 'g-', label='n = 2')
        plt.plot(times*tu/1e-3, total, 'k-', label='Total')
        plt.xlabel(r'$t$ $(ms)$')
        plt.ylabel('Fraction of sites with n atoms')
        plt.legend()
        plt.show()

        def label(array_nums):
            strings = ''
            for num in array_nums:
                strings += str(num) + ' '
            strings = strings[:-1]

            return r'$\vert$' + strings + r'$\rangle$'

        colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
        linestyles = ['-', '--', '-.', ':']

        plt.figure()
        #for ind in range(len(final[:,0])):
        #    plt.plot(times*tu/1e-6, np.abs(final[ind])**2, c = colors[np.mod(ind, len(colors))],
        #             ls = linestyles[np.mod(int(ind/len(colors)), len(linestyles))], label=label(basis[ind].astype(int)))

        plt.plot(times*tu/1e-3, np.abs(final[indi])**2, 'b-', label=label(basis[indi].astype(int)))
        plt.plot(times*tu/1e-3, np.abs(final[indf])**2, 'r-', label=label(basis[indf].astype(int)))
        plt.plot(times*tu/1e-3, total - np.abs(final[indi])**2 - np.abs(final[indf])**2, 'g-', label=r'Other states')
        plt.plot(times*tu/1e-3, total, 'k-', label='Total')
        plt.legend()
        plt.xlabel(r'$t$ $(ms)$')
        plt.ylabel(r'$P$')
        plt.show()

#Run the lattice-loading simulation once (with plots) and report the wall-clock time in minutes
print("")
start = time.time()
solve_Schrodinger(True)
end = time.time()
elapsed_min = np.round((end - start)/60, 2)
print("Calculation time:", elapsed_min, 'min')


In [None]:
#Schematic figures for the paper. Each figure is saved to FIG_DIR before it is shown (saving after plt.show()
#can write an empty file once the backend has closed the figure). Set the FIG_DIR environment variable to
#redirect the output; the default is the original author's local figure folder.
FIG_DIR = os.environ.get('FIG_DIR', r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures")

colors = ['b', 'r', 'tab:orange', 'tab:brown', 'tab:gray', 'tab:green', 'tab:purple', 'tab:pink', 'tab:olive', 'tab:cyan']

def _save_and_show(fig, filename):
    """Save fig to FIG_DIR/filename with a tight bounding box, then display it."""
    fig.savefig(os.path.join(FIG_DIR, filename), bbox_inches='tight')
    plt.show()

def _scaled_density(psi_vals, pot, height):
    """|psi|^2 rescaled so its peak equals height*max(pot), for overlaying on a potential curve."""
    return height*(np.max(pot)/np.max(np.abs(psi_vals)**2))*np.abs(psi_vals)**2

def _plot_levels(xs, pot, levels, color, linewidth):
    """Draw each level as a horizontal segment between its first two crossings with pot(xs)."""
    for En in levels:
        crossings = xs[np.where(np.diff(np.sign(pot - En)))[0]]
        plt.plot(crossings[0:2], [En, En], c = color, ls = '-', linewidth=linewidth)

#BEC wavefunctions

rp = np.linspace(-2*rTF[-1], 2*rTF[-1], 1000)
potp = 0.5*m*w_HO**(2)*rp**(2)*1e-3/tu/hbar
psi_rp = psi_TF(rp)

for color_ind, filename in [(0, "HO_0.svg"), (2, "HO_2.svg")]:
    fig = plt.figure(figsize=(12,8))
    plt.plot(rp, potp, 'k-', linewidth=6)
    plt.plot(rp, _scaled_density(psi_rp, potp, 0.5), c = colors[color_ind], linewidth=8)
    plt.axis('off')
    _save_and_show(fig, filename)

#Wannier function

xp = np.linspace(5*x[0], 3*x[-1], 1000) + 2*lam_eff/2
potp = (m*w_HO**(2)*xp**(2)/2 + Vb*np.sin(2*np.pi*xp/lam_eff)**2)*1e-3/tu/hbar

fig = plt.figure(figsize=(12,8))
plt.plot(xp, potp, 'k-', linewidth=6)
plt.plot(x, _scaled_density(wan0, potp, 0.5), c = colors[1], linewidth=8)
plt.axis('off')
_save_and_show(fig, "Wannier.svg")

#Legend (only the line colors/labels matter; curves are scaled by the Wannier-plot potential above)

fig = plt.figure(figsize=(12,8))
plt.plot(rp, _scaled_density(psi_rp, potp, 0.5), c = colors[0], linewidth=5, label = r'$| \phi (\mathbf{r}) |^2$')
plt.plot(x, _scaled_density(wan0, potp, 0.5), c = colors[1], linewidth=5, label = r'$| w (\mathbf{r}) |^2$')
plt.plot(rp, _scaled_density(psi_rp, potp, 0.5), c = colors[2], linewidth=5, label = r'$| \phi (\mathbf{r}) |^2$')
plt.legend()
plt.axis('off')
_save_and_show(fig, "Legend.svg")

#Lattice

xplot = np.linspace(0, 10*np.pi, 1000)

fig = plt.figure(figsize=(16,4))
plt.plot(xplot, np.cos(xplot)**2, 'k-', linewidth=4)
plt.axis('off')
_save_and_show(fig, "Lattice.svg")

#Gaussian potential and energy levels.
#The level arrays use the name level_energies so the module-level `energies` (the HO level energies used by
#earlier cells) is not overwritten by this plotting cell.

num = 5 #number of energy levels to plot (NOTE(review): not used below, which draws 20 levels)

fig = plt.figure(figsize=(8,8))

xplot = 0.25*np.linspace(-1, 1, 1000)
pot_plot = -np.exp(-5*xplot**2)
level_energies = np.linspace(-1, -0.75, 21)

_plot_levels(xplot, pot_plot, level_energies[1:], colors[2], 4)
plt.plot(xplot, pot_plot, 'k-', linewidth=4)
plt.axis('off')
_save_and_show(fig, "HO_levels.svg")

#HO ground state plus higher levels

lw = 4
prob_height = 0.25
energies_min = 0.35
rp = np.linspace(-2*rTF[-1], 2*rTF[-1], 1000)
potp = 0.5*m*w_HO**(2)*rp**(2)*1e-3/tu/hbar
level_energies = np.max(potp)*np.linspace(energies_min, 1, 21)

for color_ind, filename in [(0, "HO_ground_state_plus_higher0.svg"), (2, "HO_ground_state_plus_higher2.svg")]:
    fig = plt.figure(figsize=(12,8))
    _plot_levels(rp, potp, level_energies, colors[color_ind], lw)
    plt.plot(rp, potp, 'k-', linewidth=lw)
    plt.plot(rp, _scaled_density(psi_TF(rp), potp, prob_height), c = colors[color_ind], linewidth=lw)
    plt.axis('off')
    _save_and_show(fig, filename)

#Lattice for the other Zeeman level (shifted by a quarter period)

prob_height = 0.3

xp = np.linspace(5*x[0], 5*x[-1], 1000)
potp = (m*w_HO**(2)*xp**(2)/2 + Vb*np.sin(2*np.pi*xp/lam_eff)**2)*1e-3/tu/hbar
potp2 = (m*w_HO**(2)*xp**(2)/2 + Vb*np.sin(2*np.pi*xp/lam_eff + np.pi/2)**2)*1e-3/tu/hbar
prob_height_factor = prob_height*(np.max(potp)/np.max(np.abs(wan0)**2))
wan_prob = prob_height_factor*np.abs(wan0)**2

fig = plt.figure(figsize=(12,8))
plt.plot(xp, potp, 'k-', linewidth=lw)
plt.plot(xp, potp2, 'k--', linewidth=lw)
plt.plot(x - lam_eff, wan_prob, c = colors[3], linewidth=lw)
plt.plot(x - 0.75*lam_eff, wan_prob, c = colors[3], ls = '--', linewidth=lw)
plt.axis('off')
_save_and_show(fig, "Other_Zeeman_levels_lattice.svg")

#Higher Bloch bands

fig = plt.figure(figsize=(12,8))

level_energies = np.mean(eigvals, axis=0)*1e-3/tu/hbar
level_energies = level_energies[level_energies < np.max(potp)][1:]

_plot_levels(xp, potp, level_energies, colors[1], lw)
plt.plot(xp, potp, 'k-', linewidth=lw)
plt.plot(x - lam_eff, wan_prob, c = colors[1], linewidth=lw)
plt.axis('off')
_save_and_show(fig, "Higher_Bloch_bands_lattice.svg")

#Double occupation: the Gaussian Wannier fit widened by a factor `wider`

wider = 3
xp_gauss = 1.75*np.linspace(x[0], x[-1], 1000) - lam_eff
gauss_fit = (gauss(xp_gauss, -1*lam_eff, sol_real[0][1], wider*sol_real[0][2]) +
             1j*gauss(xp_gauss, -1*lam_eff, sol_imag[0][1], wider*sol_imag[0][2]))

fig = plt.figure(figsize=(12,8))

plt.plot(xp, potp, 'k-', linewidth=lw)
plt.plot(xp_gauss, prob_height_factor*np.abs(gauss_fit)**2, c = colors[1], linewidth=lw)
plt.axis('off')
_save_and_show(fig, "Double_occupation_lattice.svg")


In [None]:
#Double-well: mean-field dynamics for two-qubit gates, population is transferred between the two mF = 0 levels directly

Na = N                 #number of atoms entering the interaction energy
dN = 0*np.sqrt(N)      #offset of the atom number from N (zero here)
omegap = 1e3*tu     #Hz, what the Rabi frequency should be
frac0 = 0.999999       #fraction of atoms that will initially be in the first state
frac1 = 4/5         #fraction of atoms that should be in the first state at the end
phi0 = 0            #initial relative phase between the two states
V = -hbar*0*tu         #detuning, stored as an energy (hbar * frequency in code units)

times = np.linspace(0, 1e-3, 100000)/tu
r_tol, a_tol, maxstep = 1e-8, 1e-8, 1e1 #relative tolerance, absolute tolerance, and max step size for the diffEq solver

Na += dN

#Overlap integral of |psi_TF|^4 on a (2**5 + 1)^3 grid (romb needs 2**k + 1 samples per axis)
xi = np.linspace(-rTF[-1], rTF[-1], 2**5 + 1)
dxi = xi[1] - xi[0]
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
intgrnd = np.abs(psi_TF(np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)))**4
res = integrate.romb(integrate.romb(integrate.romb(intgrnd, dx=dxi), dx=dxi), dx=dxi)
Up = 0.5*(g1_00 - 2*g12_00 + g2_00)*res

J = hbar*omegap
wc, wj = Na*Up/hbar, 2*J/hbar
wp = np.sqrt(wj*(wc+wj))

eta_eq = 2*V/(hbar*wc)

#Due to interactions and the uncertainty in Na, the actual detuning will be different from the set one.
#V is an energy, so both printed values are divided by hbar before the kHz conversion.
print('Single-photon detuning:', V*1e-3/tu/hbar, 'kHz')
print('Actual single-photon detuning:', (V - (g1_00 - g2_00)*res*dN/2)*1e-3/tu/hbar, 'kHz')

print('wj =', wj/tu, 'Hz')
print('wc =', wc/tu, 'Hz')
print()

#Initial population imbalance: eta = frac - (1 - frac)
eta0 = 2*frac0 - 1

def eqn(t, vec, wj, wc):
    """Right-hand side of the two-mode mean-field equations, in solve_ivp form.

    vec = [eta, phi] holds the population imbalance and the relative phase;
    wj is the coupling frequency and wc the interaction frequency. The
    equilibrium imbalance eta_eq is read from module scope; t is unused.
    Returns [d(eta)/dt, d(phi)/dt].
    """
    eta, phi = vec
    amplitude = np.sqrt(1-eta**2)
    deta_dt = -wj*amplitude*np.sin(phi)
    dphi_dt = wc*(eta-eta_eq) + wj*eta*np.cos(phi)/amplitude
    return [deta_dt, dphi_dt]

# Integrate the two-mode equations over `times`; solve_ivp passes args=(wj, wc) through to eqn.
start = time.time()
sol0 = integrate.solve_ivp(eqn, (times[0], times[-1]), [eta0, phi0], args=(wj,wc), t_eval=times, rtol=r_tol, atol=a_tol, max_step = maxstep)
end = time.time()

if sol0.success:
    print('Solver time:', np.round((end-start)/60, 2), 'min.')
else:
    print('Solver0 error:', sol0.message)

# First time index at which the first-state fraction (1 + eta)/2 is closest to the target frac1.
indm = np.argmin(np.abs((1/2)*(1+sol0.y[0]) - frac1))
print('Final population fraction:', str((1/2)*(1+sol0.y[0])[indm]) + ',', 'pulse time:', np.round(times[indm]*tu/1e-3, 4), 'ms')

def model_eta(t):
    """Analytic harmonic model of eta(t) oscillating at the frequency wp.

    Offset eta_eq*wc*wj/wp**2 plus a cos/sin oscillation set by the initial
    values eta0 and phi0; all parameters are read from module scope.
    """
    equilibrium = eta_eq*wc*wj/wp**2
    cosine_part = eta0*np.cos(wp*t)
    sine_part = phi0*wj*np.sin(wp*t)/wp
    return equilibrium + cosine_part - sine_part

font = {'family':'Times New Roman', 'weight':'bold', 'size':14}
plt.rc('font', **font)

# Compare the numerical first-state fraction (1 + eta)/2 with the analytic model, time in ms.
# NOTE(review): if the solver stopped early, sol0.y[0] is shorter than times and this plot fails.
fig = plt.figure(figsize=(14.4,3.9))
plt.plot(times*tu/1e-3, (1/2)*(1+sol0.y[0]), 'b-', label='mean-field')
plt.plot(times*tu/1e-3, (1/2)*(1+model_eta(times)), 'k--', label='model')
plt.xlabel('Time (ms)')
plt.ylabel('Fraction of atoms in the first state')
plt.legend()
plt.show()

In [None]:
#Calculate the spin healing length, which must be larger than the condensate size to prevent
#spatial structures from arising, and the spin mixing rate, which must be smaller than the
#quadratic Zeeman splitting to prevent spin mixing from occurring

B2 = B*tu**(2)/mu #Gauss, magnetic field
l = R #length scale of the condensate (R in code units)

# Spin healing length 2*pi*hbar/sqrt(2 m |c1_1| n), with density n approximated as N/l**3.
heal_length = 2*np.pi*hbar/np.sqrt(2*m*np.abs(c1_1)*N/l**3)

print('Ratio of spin healing length to condensate size:', heal_length/l)

# NOTE(review): c_eff below is divided by hbar when printed but delta_B is not — verify the units of delta_B.
delta_B = 2*np.pi*hbar*72*B2**(2)*mu**(2)/tu**3
# Radial integral of 4*pi*r^2 |psi_TF|^4. romb needs a uniform 1D rTF with 2**k + 1 samples.
# NOTE(review): a later cell redefines rTF as a 3D grid, so this cell depends on execution order.
c_eff = c1_1*N*integrate.romb(4*np.pi*rTF**(2)*np.abs(psi_TF(rTF))**4, dx=np.abs(rTF[1]-rTF[0]))

print("")
print('Effective spin-mixing rate:', np.abs(c_eff)/tu/hbar, 'Hz')
print('Quadratic Zeeman energy splitting:', np.abs(delta_B)/tu/1e9, 'GHz')
print("")

In [None]:
#Uncertainty in qubit energy levels due to interactions with the BEC, and mean-field potential arising from
#interactions with the BEC

# (2**5 + 1)^3 grid over the extent of x, as required by romb.
xi = np.linspace(x[0], x[-1], 2**5 + 1)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)

# Overlap |psi_TF|^2 |w|^2 of the BEC density and the qubit Wannier density (the 0*lam_eff/2 offsets are disabled shifts).
intgrnd =  np.abs(psi_TF(np.sqrt((x_grid + 0*lam_eff/2)**2 + (y_grid + 0*lam_eff/2)**2 + (z_grid + 0*lam_eff/2)**2)))**(2)*np.abs(Wan_wavefunction(x_grid, y_grid, z_grid))**2
print('Uncertainty in energy levels due to interactions with BEC for single qubit:', 
      g1_m10*np.sqrt(N)*integrate.romb(integrate.romb(integrate.romb(intgrnd, dx = xi[1]-xi[0]), dx = xi[1]-xi[0]), dx=xi[1]-xi[0])
      /tu/hbar, 'Hz')

# Radial mean-field potential g1_00 * N |psi_TF|^2 in kHz.
plt.figure()
plt.plot(rTF, g1_00*np.abs(psi_TF(rTF))**(2)*N*1e-3/tu/hbar, 'b-')
plt.xlabel(r'$r$ $(\mu m)$')
plt.ylabel(r'$Energy$ $(kHz)$')
plt.title('Mean field interaction potential due to Bose gas')
plt.show()

In [None]:
#Probability that a qubit atom will be lost after k atom losses

n = 1000 #number of filled lattice sites
k = np.arange(5e3) #atom losses

i = np.arange(n)

#Exact result P_k = 1 - prod_{i=0}^{n-1} (1 - k/(N - i)) for each number of losses k.
#np.prod is used because the np.product alias was removed in NumPy 2.0.
frac_lost = np.zeros(len(k))
for ind in range(len(k)):
    frac_lost[ind] = 1 - np.prod(1 - k[ind]/(N-i))

#Only plot the linear approximation where it is still a valid probability (<= 1)
indcs0 = n*k/N <= 1


plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.serif"] = "STIX"
plt.rcParams["mathtext.fontset"] = "stix"
plt.rcParams["font.size"] = 28

fig = plt.figure(figsize=(12,8))
plt.plot(k, frac_lost, 'r-', linewidth=3, label=r'full')
plt.plot(k, 1 - (1-k/N)**n, 'k--', linewidth=3, label=r'approximate')
plt.plot(k[indcs0], n*k[indcs0]/N, 'b--', linewidth=3, label=r'$m k / N$')
plt.xlabel(r'$k$')
plt.ylabel(r'$P_k$')
plt.legend()
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\Atom_losses.eps", bbox_inches='tight')

In [None]:
#Lifetime due to collisions (from Collisional Blockade in Microscopic Optical Dipole Traps paper)

gamma = 0.2*tu                        #rate of collisions with the background gas (0.2 Hz in code units)
width = lam_eff/2                     #width of the trap (code length units)
t = np.linspace(0, 1.5e-3, 1000)/tu   #times (0 to 1.5 ms, in code time units) over which the atom population will be calculated
n0 = 2                                #initial number of atoms in the trap


#These two parameters are used for the volume calculation to get betap to agree with the paper
#NOTE(review): this overwrites the global names `t` and `eta`.
wavelength = 772e-9/du #772 nm in code length units, don't change!
eta = 0.4 #don't change!

def vol(w0):
    """Effective trap volume for beam waist w0, using the module-level eta and wavelength.

    Evaluates pi**2 * w0**4 * ln(1/(1-eta)) * sqrt(eta/(1-eta)) / wavelength.
    """
    prefactor = np.pi**(2)*w0**(4)
    return prefactor*np.log(1/(1-eta))*np.sqrt(eta/(1-eta))/wavelength

#Analytic solution to the differential equation given in the paper for the atom population over time
def pop(t, n0, gamma, betap):
    """Atom number N(t) solving dN/dt = -gamma*N - betap*N*(N - 1) with N(0) = n0.

    With a = gamma - betap the solution is
        N(t) = n0*a / ((a + betap*n0)*exp(a*t) - betap*n0),
    evaluated elementwise for array t.

    When gamma == betap (a == 0) that expression is 0/0. In that case the equation
    reduces to dN/dt = -betap*N**2, and its solution n0/(1 + betap*n0*t) is returned.
    """
    decay = gamma - betap
    if np.ndim(decay) == 0 and decay == 0:
        return n0/(1 + betap*n0*t)
    return n0*decay/(-betap*n0 + np.exp(t*decay)*((n0-1)*betap + gamma))

# Scale a reference rate of 1000 Hz (code units) at a 0.7 um waist by the volume ratio vol(0.7 um)/vol(width).
betap = 1000*tu*vol(0.7e-6/du)/vol(width)

# betap converted back to Hz.
print(betap/tu)

font = {'family':'Times New Roman', 'weight':'bold', 'size':15}
plt.rc('font', **font)
fig = plt.figure()
plt.plot(t*tu/1e-3, pop(t, n0, gamma, betap), 'b-')
plt.xlabel('t (ms)')
plt.ylabel('N(t)')
plt.show()

#fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\Lifetime_collisions.pdf", bbox_inches='tight')

In [None]:
# Harmonic trap potential (kHz) on a grid spanning twice the last rTF value on each side.
rp = np.linspace(-2*rTF[-1], 2*rTF[-1], 1000)
potp = 0.5*m*w_HO**(2)*rp**(2)*1e-3/tu/hbar

# NOTE(review): font.serif is set, but presumably has no effect while font.family is "sans-serif".
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.serif"] = "STIX"
plt.rcParams["mathtext.fontset"] = "stix"
plt.rcParams["font.size"] = 30

# Trap potential with |psi_TF|^2 rescaled to half the potential's maximum.
fig = plt.figure(figsize=(12,8))
plt.plot(rp, potp, 'k-', linewidth=3, label='Harmonic potential')
plt.plot(rp, 0.5*(np.max(potp)/np.max(np.abs(psi_TF(rp))**2))*np.abs(psi_TF(rp))**2, 'b-', linewidth=3, label=r'$\vert \psi_{m_F=0} \vert ^ 2$')
plt.xlabel(r'$x$ $(\mu m)$')
plt.ylabel(r'$V$ $(kHz)$')
plt.legend(loc='upper right', fontsize=36)
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\HO_potential.tiff", bbox_inches='tight')

# Harmonic + lattice potential (kHz) with the Wannier density rescaled to half its maximum.
xp = np.linspace(5*x[0], 5*x[-1], 1000)
potp = (m*w_HO**(2)*xp**(2)/2 + Vb*np.sin(2*np.pi*xp/lam_eff)**2)*1e-3/tu/hbar

fig = plt.figure(figsize=(12,8))
plt.plot(xp, potp, 'k-', linewidth=3, label='Lattice potential') 
plt.plot(x, 0.5*(np.max(potp)/np.max(np.abs(wan0)**2))*np.abs(wan0)**2, 'r-', linewidth=3, label=r'$\vert \psi_{m_F \neq 0} \vert ^ 2$')
plt.xlabel(r'$x$ $(\mu m)$')
plt.ylabel(r'$V$ $(kHz)$')
plt.legend(loc='upper right', fontsize=36)
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\Lattice_potential.tiff", bbox_inches='tight')

# 3x3x3 cube of lattice sites spaced lam_eff/2 (dense meshgrid, default 'xy' indexing).
xl = np.arange(-1, 2)*lam_eff/2
xp, yp, zp = np.meshgrid(xl, xl, xl, sparse=False)


fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(projection='3d')

# First/last index along each axis, used to draw the 12 edges of the cube.
indcs = [0, -1]

#Lines in x-direction
for ind0 in indcs:
    for ind1 in indcs:
        ax.plot(xp[0,:,0], yp[ind0,:,ind0], zp[ind1,:,ind1], c='k')

#Lines in y-direction
for ind0 in indcs:
    for ind1 in indcs:
        ax.plot(xp[:,ind0,ind0], yp[:,0,0], zp[:,ind1,ind1], c='k')

#Lines in z-direction
for ind0 in indcs:
    for ind1 in indcs:
        ax.plot(xp[ind0,ind0,:], yp[ind1,ind1,:], zp[0,0,:], c='k')


# One red marker per lattice site, drawn slice by slice along the last axis.
for ind in range(len(zp[0,0,:])):
    ax.scatter(xp[:,:,ind], yp[:,:,ind], zp[:,:,ind], s=800, marker='o', c='red', alpha=1)


plt.axis('off')
plt.show()

In [None]:
# Harmonic + lattice potential (kHz) over roughly two lattice periods, starting just left of the origin.
xp = np.linspace(-0.5*lam_eff/2, 3.5*lam_eff/2, 1000)
potp = (m*w_HO**(2)*xp**(2)/2 + Vb*np.sin(2*np.pi*xp/lam_eff)**2)*1e-3/tu/hbar

# Panel: bare lattice.
fig = plt.figure(figsize=(12,8))
plt.plot(xp, potp, 'k-', linewidth=6) 
plt.axis('off')
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\Lattice.svg", bbox_inches='tight')


# Panel: lattice with the Wannier density rescaled to half the potential's maximum.
fig = plt.figure(figsize=(12,8))
plt.plot(xp, potp, 'k-', linewidth=6) 
plt.plot(x, 0.5*(np.max(potp)/np.max(np.abs(wan0)**2))*np.abs(wan0)**2, 'r-', linewidth=8)
plt.axis('off')
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\Wannier.svg", bbox_inches='tight')

# Legend-only figure: dummy single-point lines carry the labels.
# NOTE(review): the blue and yellow entries share the same label — confirm the yellow one is intended.
fig = plt.figure(figsize=(12,8))
plt.plot([0], [0], 'b-', label=r'$\vert \phi(\mathbf{r}) \vert ^ 2$')
plt.plot([0], [0], 'r-', label=r'$\vert w(\mathbf{r}) \vert ^ 2$')
plt.plot([0], [0], 'y-', label=r'$\vert \phi(\mathbf{r}) \vert ^ 2$')
plt.legend()
plt.axis('off')
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\Fig1_legend.svg", bbox_inches='tight')

# Harmonic trap potential (kHz) for the BEC panels.
rp = np.linspace(-2*rTF[-1], 2*rTF[-1], 1000)
potp = 0.5*m*w_HO**(2)*rp**(2)*1e-3/tu/hbar

# Panel: trap with the TF density in blue.
fig = plt.figure(figsize=(12,8))
plt.plot(rp, potp, 'k-', linewidth=6)
plt.plot(rp, 0.5*(np.max(potp)/np.max(np.abs(psi_TF(rp))**2))*np.abs(psi_TF(rp))**2, 'b-', linewidth=8)
plt.axis('off')
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\HO_blue.svg", bbox_inches='tight')

# Panel: same trap and TF density, drawn in yellow.
fig = plt.figure(figsize=(12,8))
plt.plot(rp, potp, 'k-', linewidth=6)
plt.plot(rp, 0.5*(np.max(potp)/np.max(np.abs(psi_TF(rp))**2))*np.abs(psi_TF(rp))**2, 'y-', linewidth=8)
plt.axis('off')
plt.show()

fig.savefig(r"C:\Users\ehabe\OneDrive\Desktop\School\CAT_lab\Poshua\Papers\Rydberg_qubit_alternative\Figures\HO_gold.svg", bbox_inches='tight')



In [None]:
#3D spinor condensate using 4th order Runge-Kutta (find the ground state)

num_spatial = 201          #grid points per axis
Rplot = 2*15e-6/du         #half-width of the cubic grid (30 um)
t_relax = 8000e-6/tu       #total relaxation time (8000 us)
t_step = 2e-6/tu           #RK4 step (2 us)
num_evals = 1              #number of chunks the time grid is split into
plot = True
spins = 2                  #number of spin components
eps = 1                    #scale factor on the y term of the trap potential

# Interaction matrix g[i, j]; the lower-left entry is filled in by make_Hermitian.
# NOTE(review): make_Hermitian is defined elsewhere — presumably mirrors the upper triangle; confirm.
g = np.array([[g1_00, g12_00], [0, g2_00]])
g = make_Hermitian(g)

# Atom numbers of the two components (80% / 20% of N).
N_vals = N*np.array([4/5, 1/5])

colors = [(0,0,1), (1,0,0)]

# Sparse 3D grid; rTF is the radial distance on that grid (shape (num_spatial,)*3 after broadcasting).
xi = np.linspace(-Rplot, Rplot, num_spatial)
x_grid, y_grid, z_grid = np.meshgrid(xi, xi, xi, sparse=True)
rTF = np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)
dx = xi[1] - xi[0]

#Create a list of arrays for the initial wavefunction where the first axis is the spin and the last three are spatial
initial = []
for i in range(spins):
    initial.append(np.zeros(rTF.shape, dtype=complex))

#Use TF approximation to create an initial guess for each spin's ground state
#NOTE: the triple-quoted string below is an unused expression (a disabled two-component TF guess).
"""
r0 = 1 - g[0,1]**(2) / (g[0,0]*g[1,1])
r1 = g[0,1] / g[0,0] - g[1,1] / g[0,1]
r2 = g[0,1] / g[1,1] - g[0,0] / g[0,1]

w1_2 = np.abs((w_HO**2) / r0 + (w_HO**2) / r1)
w2_2 = np.abs((w_HO**2) / r0 + (w_HO**2) / r2)

A1 = (15*N_vals[0]**(2)*g[0,0] / (8*np.pi))**(2/5) * (0.5*m*w1_2)**(3/5)
A2 = (15*N_vals[1]**(2)*g[1,1] / (8*np.pi))**(2/5) * (0.5*m*w2_2)**(3/5)

mew1 = r2*A2 / (1 - r1*r2/r0**2) + r0*A1 / (1 - r0**(2) / (r1*r2))
mew2 = r1*A1 / (1 - r1*r2/r0**2) + r0*A2 / (1 - r0**(2) / (r1*r2))

initial[0] = np.sqrt((mew1/r0 + mew2/r1 - (1/2)*m*w1_2*(x_grid**2 + eps*y_grid**2 + z_grid**2)) / (g[0,0]*N_vals[0])) + 0*1j
initial[1] = np.sqrt((mew2/r0 + mew1/r2 - (1/2)*m*w2_2*(x_grid**2 + eps*y_grid**2 + z_grid**2)) / (g[1,1]*N_vals[1])) + 0*1j
"""
# Scattering length recovered from g1_00 = 4*pi*hbar^2*a/m.
asc = g1_00*m/(4*np.pi*hbar**2)
# Single-component TF chemical potential mu = (hbar*w/2)*(15*N*a/a_ho)^(2/5), with a_ho = sqrt(hbar/(m*w)).
mew = 0.5*(15*N*asc/np.sqrt(hbar/m/w_HO))**(2/5)*hbar*w_HO 
# TF radius where mu = m*w^2*R^2/2 (overwrites the global R).
R = np.sqrt(2*mew/m/w_HO**2)

def psi_TF(r):
    """Thomas-Fermi amplitude sqrt(max(mew - m*w_HO^2*r^2/2, 0) / (g1_00*N)).

    r may be a scalar or an array of radial distances; uses the module-level
    mew, m, w_HO, g1_00 and N. Outside the TF radius the density is clipped to 0.
    """
    prob_TF = (1/g1_00/N)*(mew - 0.5*m*w_HO**(2)*r**2)
    # np.maximum clips the negative region for both scalars and arrays; boolean-mask
    # assignment (prob_TF[prob_TF < 0] = 0) fails when r is a scalar.
    return np.sqrt(np.maximum(prob_TF, 0))

#Start every spin component from the same single-component TF profile
for i in range(len(initial)):
    initial[i] = psi_TF(rTF).astype(complex)

#Normalize each component to `norm` on the grid
#(integrate.trapezoid replaces integrate.trapz, which was removed in SciPy 1.14)

norm = 1
for i in range(len(initial)):
    initial[i] *= np.sqrt(norm / integrate.trapezoid(integrate.trapezoid(integrate.trapezoid(np.abs(initial[i])**2, dx=dx), dx=dx), dx=dx))

#Real prefactors for imaginary-time relaxation: d(psi)/dt = (hbar/2m) lap(psi) - (V + interactions) psi / hbar
kin_factor = (hbar**(2)/m/2)/hbar
pot_factor = -1/hbar

#Potential: harmonic trap, with the y term scaled by eps
pot = 0.5*m*w_HO**(2)*(x_grid**2 + eps*y_grid**2 + z_grid**2)

def laplacian(y, dx):
    """Second-order finite-difference Laplacian of a real or complex N-D array (3D here).

    scipy.ndimage.laplace returns the unscaled sum over all axes of the second
    differences f[i+1] - 2*f[i] + f[i-1]. Dividing by dx**2 therefore gives the
    Laplacian for a uniform grid spacing dx. The real and imaginary parts are filtered
    separately, and edges use ndimage's default 'reflect' boundary mode.
    """
    return (ndimage.laplace(np.real(y)) + 1j*ndimage.laplace(np.imag(y)))/dx**2

def H_psi(t, psi):
    """Time derivative of every spin component for the integrator.

    For each component i this returns
        kin_factor*lap(psi_i) + pot_factor*pot*psi_i
        + sum_j pot_factor*N_vals[j]*g[i,j]*|psi_j|^2*psi_i,
    using the module-level kin_factor, pot_factor, pot, N_vals, g and grid spacing dx.
    With the real prefactors set in this cell, this is the imaginary-time equation.
    t is unused; it is only there to match the call signature Runge_Kutta uses.
    """
    def _component_rhs(i):
        psi_i = psi[i]
        rhs = kin_factor*laplacian(psi_i, dx) + pot_factor*pot*psi_i
        for j, psi_j in enumerate(psi):
            rhs += pot_factor*N_vals[j]*g[i,j]*np.abs(psi_j)**(2)*psi_i
        return rhs

    return [_component_rhs(i) for i in range(len(psi))]

def Runge_Kutta(func, initial, t_eval, num_evals, t_step):
    """Relax a multi-component wavefunction with RK4 steps, renormalizing after each step.

    Parameters
    ----------
    func : callable(t, psi_list) -> list of arrays
        Time derivative of each component (e.g. H_psi).
    initial : sequence of 3D arrays
        Starting wavefunction, one array per spin component; copied, not modified.
    t_eval : (t0, tf)
        Start and end of the evolution.
    num_evals : int
        Number of chunks the time grid is split into (capped at the number of steps).
    t_step : float
        RK4 step size.

    Returns
    -------
    temp : list of arrays
        Final, normalized wavefunction components.
    energies : list of 1D arrays
        For each step, <psi_i|(-hbar^2/2m lap + pot + sum_j N_vals[j] g[i,j] |psi_j|^2)|psi_i>
        for every component i (the per-component mean-field chemical potential).
    times : 1D array
        Time grid.

    Uses the module-level dx, norm, pot, N_vals, g, hbar, m and laplacian.
    """
    def _grid_integral(f):
        # Triple trapezoid integral over the uniform grid of spacing dx
        return integrate.trapezoid(integrate.trapezoid(integrate.trapezoid(f, dx=dx), dx=dx), dx=dx)

    temp = [np.copy(component) for component in initial]

    # NOTE(review): linspace with int((tf - t0)/t_step) points has a spacing slightly larger
    # than t_step, while each step below advances the state by exactly t_step.
    times = np.linspace(t_eval[0], t_eval[-1], int((t_eval[-1] - t_eval[0])/t_step))

    if num_evals > len(times) - 1:
        num_evals = len(times) - 1

    t_mid = (times[:-1] + times[1:])/2
    delta_ind = int((len(times)-1)/num_evals)
    t_step_2 = t_step/2
    t_step_6 = t_step/6

    energies = []
    ind = 0
    for ind0 in range(num_evals):
        indf = ind + delta_ind

        for ind1 in range(ind, indf):
            # Classic RK4: every stage is evaluated from y0, the state at the start of the step
            y0 = temp
            k1 = func(times[ind1], y0)
            k2 = func(t_mid[ind1], [y + k*t_step_2 for y, k in zip(y0, k1)])
            k3 = func(t_mid[ind1], [y + k*t_step_2 for y, k in zip(y0, k2)])
            k4 = func(times[ind1+1], [y + k*t_step for y, k in zip(y0, k3)])
            temp = [y + (a + 2*b + 2*c + d)*t_step_6 for y, a, b, c, d in zip(y0, k1, k2, k3, k4)]

            for i in range(len(temp)):
                #Boundary: copy the last slice onto the first along each axis
                temp[i][:,:,0] = temp[i][:,:,-1]
                temp[i][:,0,:] = temp[i][:,-1,:]
                temp[i][0,:,:] = temp[i][-1,:,:]

                #Normalize
                temp[i] *= np.sqrt(norm / _grid_integral(np.abs(temp[i])**2))

            #Calculate the current energies of the spinor BEC once every component is updated
            energies.append(np.zeros(len(temp)))
            for i in range(len(temp)):
                H_psi_prod = -(hbar**2/(2*m))*laplacian(temp[i], dx) + pot*temp[i]
                for j in range(len(temp)):
                    H_psi_prod += N_vals[j]*g[i,j]*np.abs(temp[j])**(2)*temp[i]
                energies[-1][i] = np.real(_grid_integral(np.conj(temp[i])*H_psi_prod))

        ind = indf

    return temp, energies, times

# Relax toward the ground state; `initial` is overwritten with the final wavefunction.
start = time.time()
initial, energies, times = Runge_Kutta(H_psi, initial, (0, t_relax), num_evals, t_step)

# Rearrange to energies[spin][step].
energies = np.transpose(np.array(energies))

end = time.time()

print('Solver time:', np.round((end-start)/60, 2), 'min.')

#Plot the results

def colorplot(xplot, yplot, zplot, axis0, axis1, ax):
    """Draw zplot on ax as a pcolormesh quantized into about 15 levels of the PiYG colormap.

    The last row and column of zplot are dropped, presumably so that it is one element
    smaller than the cell-edge grid given by xplot/yplot, as 'flat' shading expects.
    axis0 / axis1 are the axis names used in the labels '$<name>$ $(\\mu m)$'.
    """
    zplot = zplot[:-1, :-1]
    levels = MaxNLocator(nbins=15).tick_values(zplot.min(), zplot.max())
    cmap = plt.get_cmap('PiYG')
    norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

    # Pass the colormap the norm was built for (it was previously computed but unused).
    ax.pcolormesh(xplot, yplot, zplot, cmap=cmap, norm=norm)
    # Both literals are raw strings so '\m' is not parsed as an (invalid) escape sequence.
    ax.set_xlabel(r'$' + str(axis0) + r'$ $(\mu m)$')
    ax.set_ylabel(r'$' + str(axis1) + r'$ $(\mu m)$')
    ax.set_aspect('equal')

if plot:
    plot_spin = 0

    #Grid indices closest to the origin along each axis
    ind0x = np.argmin(np.abs(xi - 0))
    ind0y = np.argmin(np.abs(xi - 0))
    ind0z = np.argmin(np.abs(xi - 0))

    font = {'family':'Times New Roman', 'weight':'bold', 'size':12}
    plt.rc('font', **font)

    #Left column: density slices through the center; right column: 1D cuts along x, y, z
    fig, ax = plt.subplots(nrows=3,ncols=2,figsize=(18,18))

    colorplot(x_grid[:,:,0], y_grid[:,:,0], np.abs(initial[plot_spin][:,:,ind0z])**2, 'x', 'y', ax[0,0])
    colorplot(x_grid[0,:,0], z_grid[0,0,:], (np.abs(initial[plot_spin][ind0y,:,:]).transpose())**2, 'x', 'z', ax[1,0])
    colorplot(y_grid[:,0,0], z_grid[0,0,:], (np.abs(initial[plot_spin][:,ind0x,:]).transpose())**2, 'y', 'z', ax[2,0])

    ax[0,1].plot(xi, np.abs(initial[plot_spin][ind0y,:,ind0z])**2, 'b.-')
    ax[0,1].set_xlabel(r'$x$ $(\mu m)$')

    ax[1,1].plot(xi, np.abs(initial[plot_spin][:,ind0x,ind0z])**2, 'r.-')
    ax[1,1].set_xlabel(r'$y$ $(\mu m)$')

    ax[2,1].plot(xi, np.abs(initial[plot_spin][ind0y,ind0x,:])**2, 'g.-')
    ax[2,1].set_xlabel(r'$z$ $(\mu m)$')

colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

for i in range(len(initial)):
    plt.figure()
    #energies[i][k] is recorded after step k, i.e. at times[k + 1]; times*tu/1e-6 converts to us
    plt.plot(times[1:len(energies[i])+1]*tu/1e-6, np.abs(energies[i])*1e-3/hbar/tu, c = colors[i])
    plt.xlabel('t (us)')
    plt.ylabel('Energy of spin ' + str(i) + ' (kHz)')

plt.show()

In [None]:
# Load a previously saved wavefunction and use it as the starting state of both spin components.
# NOTE(review): presumably a single 3D array from the ground-state cell — verify. Only unpickle trusted files.
with open('initial.pickle','rb') as f:
    initial_scalar = pickle.load(f)

initial = [initial_scalar, initial_scalar]

In [None]:
#3D spinor condensate using 4th order Runge-Kutta

tf = 4000e-6/tu                 #total evolution time (4000 us)
tf2 = 10e-6/tu                  #NOTE(review): not used in the code visible here
x0, y0, z0 = 0*lam_eff/2, 0*lam_eff/2, 0*lam_eff/2   #center site (multiples of lam_eff/2, all zero)
reverse = False                 #NOTE(review): not used in the code visible here
t_step = 0.4*0.25e-6/tu         #RK4 step (0.1 us)
num_evals = 1
z_neighbors = True              #include neighbor sites along z
spins = 2

# Interaction matrix; make_Hermitian (defined elsewhere) fills in the lower-left entry.
g = np.array([[g1_00, g12_00], [0, g2_00]])
g = make_Hermitian(g)

N_vals = N*np.array([4/5, 1/5])

colors = [(0,0,1), (1,0,0)]

#At each time step, we will record the smallest ratio of the BEC population density
#at the center to all the other qubit locations
#NOTE(review): that ratio bookkeeping is commented out in Runge_Kutta; diffs records psi_0 at the center grid point.
ind0x = np.argmin(np.abs(xi - x0))
ind0y = np.argmin(np.abs(xi - y0))
ind0z = np.argmin(np.abs(xi - z0))

def find_indcs_neighbors(x0, amp):
    """Indices into the module-level grid xi nearest to x0 + amp*n for n = -5, ..., 5.

    Returns an integer array of 11 indices, one per target position.
    """
    offsets = np.arange(-5, 6)
    targets = amp*offsets
    targets += x0
    gaps = np.abs(np.subtract.outer(xi, targets))
    return gaps.argmin(axis=0)

indcsx = find_indcs_neighbors(x0, 0.5*lam_eff)
indcsy = find_indcs_neighbors(y0, 0.5*lam_eff)
if z_neighbors:
    indcsz = find_indcs_neighbors(z0, 0.5*lam_eff)
else:
    #amp = 0 maps every z "neighbor" onto z0 itself, so only the z0 plane is used (with repeats)
    indcsz = find_indcs_neighbors(z0, 0.0)

#All combinations of neighbor indices, with the center site itself removed
indcsx, indcsy, indcsz = np.meshgrid(indcsx, indcsy, indcsz)

ind0 = (indcsx == ind0x) & (indcsy == ind0y) & (indcsz == ind0z)

indcsx = indcsx[~ind0]
indcsy = indcsy[~ind0]
indcsz = indcsz[~ind0]

#Copy of the initial wavefunction as one array whose first axis is the spin and the last three are spatial
initial2 = np.copy(initial)

#Phase masks

phase = np.zeros(rTF.shape)

#For vortex

#phase =  (np.arctan2(y_grid-y0, x_grid-x0) + np.pi + 0*z_grid)
#initial2[0] *= np.exp(-1j*phase)

#if spins > 1:
#    phase = (np.arctan2(y_grid-y0, z_grid-z0) + np.pi + 0*x_grid)
#    initial2[1] *= np.exp(-1j*phase)



#Normalize each component to `norm` on the grid
#(integrate.trapezoid replaces integrate.trapz, which was removed in SciPy 1.14)
for i in range(len(initial2)):
    initial2[i] *= np.sqrt(norm / integrate.trapezoid(integrate.trapezoid(integrate.trapezoid(np.abs(initial2[i])**2, dx=dx), dx=dx), dx=dx))

#Real-time prefactors: i*hbar*d(psi)/dt = -(hbar^2/2m) lap(psi) + (V + interactions) psi
kin_factor = -(hbar**(2)/m/2)/1j/hbar
pot_factor = 1/1j/hbar

#Generate labels for each of the images, and clean up the directory we'll be saving them in
#Labels are zero-padded to a common width, so sorting the file names sorts the frames chronologically.
labels = np.arange(np.ceil(tf/t_step), dtype=int).astype(str)
labels = np.char.zfill(labels, len(labels[-1]))

#NOTE(review): absolute local path; every existing *.png in it is deleted below.
file_path = r'C:/Users/ehabe/OneDrive/Desktop/School/CAT_lab/Poshua/Papers/Rydberg_qubit_alternative/GPE_simulations/'
for f in glob.glob(file_path + '*.png'):
    os.remove(f)

#Initialize the mlab engine
engine = Engine()
engine.start()

fig = mlab.figure(size=(1000, 1000), bgcolor=None, fgcolor=(1,1,1), engine=engine)

#Frame 0: a 1e-4 isosurface of each component's density
for i in range(len(initial2)):
    src = mlab.pipeline.scalar_field(np.abs(initial2[i])**2, figure=fig)
    mlab.pipeline.iso_surface(src, contours=[1e-4, ], opacity=0.3, figure=fig, color=colors[i])

mlab.orientation_axes()
mlab.title('t = ' + str(0) + ' us', size=0.3, height=0.9)

#Fix the camera orientation, render once, then disable rendering for the frames that follow
scene = engine.scenes[0]
scene.scene.camera.view_up = [-0.2571155775605064, -0.6073453530583819, 0.7516802524305908]
scene.scene.render()
scene.scene.disable_render = True

mlab.savefig(filename=file_path + 't' + labels[0] + '.png')
mlab.clf(fig)

def H_psi(t, psi):
    """Time derivative of every spin component for the integrator.

    For each component i this returns
        kin_factor*lap(psi_i) + pot_factor*pot*psi_i
        + sum_j pot_factor*N_vals[j]*g[i,j]*|psi_j|^2*psi_i,
    using the module-level kin_factor, pot_factor, pot, N_vals, g and grid spacing dx.
    With the complex (1/i) prefactors set in this cell, this is the real-time equation.
    t is unused; it is only there to match the call signature Runge_Kutta uses.
    """
    derivatives = []
    for i, psi_i in enumerate(psi):
        rhs = kin_factor*laplacian(psi_i, dx) + pot_factor*pot*psi_i
        for j, psi_j in enumerate(psi):
            rhs += pot_factor*(N_vals[j])*g[i,j]*np.abs(psi_j)**(2)*psi_i
        derivatives.append(rhs)
    return derivatives

def Runge_Kutta(func, initial, t_eval, num_evals, t_step):
    """Evolve a multi-component wavefunction with RK4 steps, saving isosurface frames.

    Parameters
    ----------
    func : callable(t, psi_list) -> list of arrays
        Time derivative of each component (H_psi).
    initial : sequence of 3D arrays
        Starting wavefunction, one array per spin component; copied, not modified.
    t_eval : (t0, tf)
        Start and end of the evolution.
    num_evals : int
        Number of chunks the time grid is split into (capped at the number of steps).
    t_step : float
        RK4 step size.

    Returns
    -------
    temp : list of arrays
        Wavefunction after the last step (renormalized every step).
    diffs : list
        psi_0 at grid point (ind0y, ind0x, ind0z), recorded every 8th step.
    times : 1D array
        Time grid.

    Side effects: every 8th step, a 1e-4 isosurface of each |psi_i|^2 is drawn into the
    module-level mayavi figure `fig` and saved as file_path + 't' + labels[step + 1] + '.png'.
    diffs is re-pickled to 'diffs.pickle' after each such step, and mlab.close() is called at
    the end. Also reads the module-level dx, norm, colors, labels, file_path and ind0x/ind0y/ind0z.
    """
    def _grid_integral(f):
        # Triple trapezoid integral over the uniform grid of spacing dx
        return integrate.trapezoid(integrate.trapezoid(integrate.trapezoid(f, dx=dx), dx=dx), dx=dx)

    temp = [np.copy(component) for component in initial]

    # NOTE(review): linspace with int((tf - t0)/t_step) points has a spacing slightly larger
    # than t_step, while each step below advances the state by exactly t_step.
    times = np.linspace(t_eval[0], t_eval[-1], int((t_eval[-1] - t_eval[0])/t_step))

    if num_evals > len(times) - 1:
        num_evals = len(times) - 1

    t_mid = (times[:-1] + times[1:])/2
    delta_ind = int((len(times)-1)/num_evals)
    t_step_2 = t_step/2
    t_step_6 = t_step/6

    diffs = []
    ind = 0
    for ind0 in range(num_evals):
        indf = ind + delta_ind

        for ind1 in range(ind, indf):
            # Classic RK4: every stage is evaluated from y0, the state at the start of the step
            y0 = temp
            k1 = func(times[ind1], y0)
            k2 = func(t_mid[ind1], [y + k*t_step_2 for y, k in zip(y0, k1)])
            k3 = func(t_mid[ind1], [y + k*t_step_2 for y, k in zip(y0, k2)])
            k4 = func(times[ind1+1], [y + k*t_step for y, k in zip(y0, k3)])
            temp = [y + (a + 2*b + 2*c + d)*t_step_6 for y, a, b, c, d in zip(y0, k1, k2, k3, k4)]

            # ind1 is already the global step index, so it alone selects and labels frames
            save_frame = (ind1 % 8 == 0)

            #diff = 0
            for i in range(len(temp)):
                #Boundary: copy the last slice onto the first along each axis
                temp[i][:,:,0] = temp[i][:,:,-1]
                temp[i][:,0,:] = temp[i][:,-1,:]
                temp[i][0,:,:] = temp[i][-1,:,:]

                #Normalize
                temp[i] *= np.sqrt(norm / _grid_integral(np.abs(temp[i])**2))

                #diff += g[i,i]*(np.abs(temp[i][indcsy,indcsx,indcsz])**2 - np.abs(temp[i][ind0y,ind0x,ind0z])**2)

                if save_frame:
                    #Add an isosurface of this component's current probability density to the frame
                    src = mlab.pipeline.scalar_field(np.abs(temp[i])**2, figure=fig)
                    mlab.pipeline.iso_surface(src, contours=[1e-4, ], opacity=0.3, figure=fig, color=colors[i])

            if save_frame:
                mlab.orientation_axes()
                mlab.title('t = ' + str(np.round(times[ind1+1], 2)) + ' us', size=0.3, height=0.9)

                mlab.savefig(filename=file_path + 't' + labels[ind1+1] + '.png')
                mlab.clf(fig)

                #diffs.append(np.min(np.abs(diff)))
                diffs.append(temp[0][ind0y,ind0x,ind0z])

                #Re-dump the running record so a partial result survives an interrupted run
                with open('diffs.pickle','wb') as f:
                    pickle.dump(diffs,f)

        ind = indf

    mlab.close()

    return temp, diffs, times

start = time.time()
psif, diffs, times = Runge_Kutta(H_psi, initial2, (0, tf), num_evals, t_step)
end = time.time()

#diffs = N*np.array(diffs)
diffs = np.array(diffs)

print('Solver time:', np.round((end-start)/60, 2), 'min.')

# Create a gif of the dynamics.
# glob() returns files in arbitrary, OS-dependent order. The frame labels are zero-padded,
# so sorting the paths puts the frames in chronological order.
frames = [Image.open(img_path) for img_path in sorted(glob.glob(file_path + '*.png'))]

#Save into a GIF file
if frames:
    frames[0].save(file_path + 'Spinor_BEC.gif', format='GIF', append_images=frames[1:], save_all=True, duration=100, loop=0)
else:
    print('No frames found in', file_path)

#Release the image file handles before deleting the files (open files cannot be removed on Windows)
for frame in frames:
    frame.close()

#Clean up the directory
for f in glob.glob(file_path + '*.png'):
    os.remove(f)

#Plot the results

#Spin component shown in the slice plots below
plot_spin = 0

def colorplot(xplot, yplot, zplot, axis0, axis1, ax):
    """Draw zplot on ax as a pcolormesh quantized into about 15 levels of the PiYG colormap.

    The last row and column of zplot are dropped, presumably so that it is one element
    smaller than the cell-edge grid given by xplot/yplot, as 'flat' shading expects.
    axis0 / axis1 are the axis names used in the labels '$<name>$ $(\\mu m)$'.
    """
    zplot = zplot[:-1, :-1]
    levels = MaxNLocator(nbins=15).tick_values(zplot.min(), zplot.max())
    cmap = plt.get_cmap('PiYG')
    norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

    # Pass the colormap the norm was built for (it was previously computed but unused).
    ax.pcolormesh(xplot, yplot, zplot, cmap=cmap, norm=norm)
    # Both literals are raw strings so '\m' is not parsed as an (invalid) escape sequence.
    ax.set_xlabel(r'$' + str(axis0) + r'$ $(\mu m)$')
    ax.set_ylabel(r'$' + str(axis1) + r'$ $(\mu m)$')
    ax.set_aspect('equal')

font = {'family':'Times New Roman', 'weight':'bold', 'size':12}
plt.rc('font', **font)

# Rows 0-2: slices (left) and 1D cuts (right) of the final density; row 3: center density vs time.
# NOTE(review): ax[3,1] stays empty while its plotting code below is commented out.
fig, ax = plt.subplots(nrows=4,ncols=2,figsize=(18,18))

colorplot(x_grid[:,:,0], y_grid[:,:,0], np.abs(psif[plot_spin][:,:,ind0z])**2, 'x', 'y', ax[0,0])
colorplot(x_grid[0,:,0], z_grid[0,0,:], (np.abs(psif[plot_spin][ind0y,:,:]).transpose())**2, 'x', 'z', ax[1,0])
colorplot(y_grid[:,0,0], z_grid[0,0,:], (np.abs(psif[plot_spin][:,ind0x,:]).transpose())**2, 'y', 'z', ax[2,0])

# Solid lines at +-lam_eff/2 and +-5*lam_eff/2 from the center site (dashed).
ax[0,1].axvline(x = x0 + lam_eff/2, c='black')
ax[0,1].axvline(x = x0 - lam_eff/2, c='black')
ax[0,1].axvline(x = x0 + 5*lam_eff/2, c='black')
ax[0,1].axvline(x = x0 - 5*lam_eff/2, c='black')
ax[0,1].axvline(x = x0, c='black', ls='--')
ax[0,1].plot(xi, np.abs(psif[plot_spin][ind0y,:,ind0z])**2, 'b.-')
ax[0,1].set_xlabel(r'$x$ $(\mu m)$')

ax[1,1].axvline(x = y0 + lam_eff/2, c='black')
ax[1,1].axvline(x = y0 - lam_eff/2, c='black')
ax[1,1].axvline(x = y0 + 5*lam_eff/2, c='black')
ax[1,1].axvline(x = y0 - 5*lam_eff/2, c='black')
ax[1,1].axvline(x = y0, c='black', ls='--')
ax[1,1].plot(xi, np.abs(psif[plot_spin][:,ind0x,ind0z])**2, 'r.-')
ax[1,1].set_xlabel(r'$y$ $(\mu m)$')

ax[2,1].axvline(x = z0 + lam_eff/2, c='black')
ax[2,1].axvline(x = z0 - lam_eff/2, c='black')
ax[2,1].axvline(x = z0 + 5*lam_eff/2, c='black')
ax[2,1].axvline(x = z0 - 5*lam_eff/2, c='black')
ax[2,1].axvline(x = z0, c='black', ls='--')
ax[2,1].plot(xi, np.abs(psif[plot_spin][ind0y,ind0x,:])**2, 'g.-')
ax[2,1].set_xlabel(r'$z$ $(\mu m)$')

#ax[3,0].plot(times[1:]*tu/1e-6, diffs/hbar/(2*np.pi)/tu, 'b.-')
#ax[3,0].set_xlabel('t (us)')
#ax[3,0].set_ylabel('Smallest differential shift (Hz)')

#ax[3,1].plot(times[1:]*tu/1e-6, diffs/hbar/(2*np.pi)/tu, 'b.-')
#ax[3,1].set_ylim(0, 100)
#ax[3,1].set_xlabel('t (us)')
#ax[3,1].set_ylabel('Smallest differential shift (Hz)')

# NOTE(review): diffs is sampled every 8th step (at times[1], times[9], ...); spreading it evenly
# over [times[0], times[-1]] only approximates those sample times.
ax[3,0].plot(np.linspace(times[0], times[-1], len(diffs))*tu/1e-6, np.abs(diffs)**2, 'b.-')
ax[3,0].set_xlabel('t (us)')
ax[3,0].set_ylabel('Probability density at center')

plt.show()

In [None]:
#Uncertainty in the effective Rabi frequency due to the wavefunction dynamics during a CNOT gate

# diffs.pickle is written by the real-time Runge_Kutta: psi_0 at the center, sampled every 8 steps.
# Only unpickle trusted files.
with open('diffs.pickle','rb') as f:
    diffs = pickle.load(f)

# Largest change between consecutive samples, divided by the sampling interval (8*t_step),
# then multiplied by 1.807 us. NOTE(review): presumably the gate duration — verify.
indm = np.argmax(np.abs(np.diff(diffs)))
print('Relative uncertainty in the effective Rabi frequency:', (1.807e-6/tu)*np.abs(np.diff(diffs))[indm]/(8*t_step))

In [None]:
#Compare the final (blue) and initial (red) condensates from the real-time simulation.
#psif and initial2 hold one 3D array per spin component, so np.abs(...)**2 is 4D. The densities
#are summed over the spin axis to give the 3D total density that scalar_field expects.

#Initialize the mlab engine
engine = Engine()
engine.start()

fig = mlab.figure(size=(1000, 1000), bgcolor=None, fgcolor=(1,1,1), engine=engine)
src = mlab.pipeline.scalar_field(np.sum(np.abs(psif)**2, axis=0), figure=fig)
mlab.pipeline.iso_surface(src, contours=[1e-4, ], opacity=0.2, figure=fig, color = (0,0,1))
src = mlab.pipeline.scalar_field(np.sum(np.abs(initial2)**2, axis=0), figure=fig)
mlab.pipeline.iso_surface(src, contours=[1e-4, ], opacity=0.2, figure=fig, color = (1,0,0))
mlab.orientation_axes()

scene = engine.scenes[0]
scene.scene.camera.view_up = [-0.2571155775605064, -0.6073453530583819, 0.7516802524305908]
scene.scene.render()
scene.scene.disable_render = True


#mlab.title('t = ' + str(0) + ' us', size=0.3, height=0.9)

file_path = r'C:/Users/ehabe/OneDrive/Desktop/School/CAT_lab/Poshua/Conferences_and_presentations/Prospective_weekend/'

#mlab.savefig(filename=file_path + 'test.png')

mlab.show()