# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import rc

# --- Constants -------------------------------------------------------------
N = 500  # number of grid points
hbar = 1  # reduced Planck's constant
dt = 0.5  # time step
m = 1  # mass of particle in au
omega = 0.02  # angular frequency of the harmonic potential

# --- Spatial grid ----------------------------------------------------------
xMin = 0
xMax = 500
x = np.linspace(xMin, xMax, N)
dx = (xMax - xMin) / (N - 1)  # linspace includes both endpoints -> N - 1 gaps

# --- Momentum grid ---------------------------------------------------------
# splitOperator multiplies the kinetic propagator elementwise onto the output
# of np.fft.fft, whose bins are ordered [0, 1, ..., N/2 - 1, -N/2, ..., -1].
# p must use that same (unshifted) ordering; a centred np.arange(-N/2, N/2)
# grid would pair each FFT bin with the wrong momentum.
p = 2 * np.pi * hbar * np.fft.fftfreq(N, d=dx)

# --- Initial wavefunction: Gaussian wavepacket -----------------------------
k0 = m / 2  # initial wavenumber (the momentum is hbar * k0)
x0 = xMax / 2  # initial position (center of the grid)
sigma = (xMax - xMin) / 30  # width of the Gaussian wavepacket
# NOTE(review): this prefactor normalises the Gaussian envelope itself, not
# |psi|^2, so sum(|psi|^2) * dx != 1. It is left unchanged because the
# hard-coded plot limits for |psi|^2 are sized to this amplitude.
envelope = (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-.5 * ((x - x0) / sigma) ** 2)
psi = envelope * np.exp(1j * k0 * x)
psiC = np.conj(psi)  # complex conjugate of psi (unused in the code shown)

# Probability density |psi|^2. np.abs keeps it real-valued. conj(psi) * psi
# would be complex with a zero imaginary part, which matplotlib warns about
# and then discards when plotting.
waveP = np.abs(psi) ** 2

plt.plot(waveP, label='Initial Wavepacket')

# --- Potential energy operator and half-step propagator --------------------
# Harmonic well centred on the initial packet position x0:
#     V(x) = (1/2) * m * omega^2 * (x - x0)^2
Vhat = 0.5 * m * omega**2 * (x - x0)**2

# exp(-i V dt / (2 hbar)): splitOperator applies the potential for half a
# time step on each side of the kinetic step, hence the factor of 2.
Vprop = np.exp(-1j * dt * Vhat / (2 * hbar))

# Overlay V on the current axes. The y-range is sized for the wavepacket
# density, so only the very bottom of the well is visible here.
plt.plot(Vhat, label='Harmonic Potential')
plt.ylim(bottom=0, top=0.001)
plt.legend()


# --- Kinetic energy operator and full-step propagator ----------------------
# T(p) = p^2 / (2m) is diagonal in momentum space. splitOperator multiplies
# Tprop onto np.fft.fft output, so p (and therefore Tprop) must be laid out in
# the same unshifted frequency order as the FFT bins.
That = np.square(p) / (2 * m)
Tprop = np.exp(-1j * dt * That / hbar)  # exp(-i T dt / hbar) for one full step

def splitOperator(psi, Vprop=Vprop, Tprop=Tprop):
    """Advance a wavefunction by one step of the split-operator method.

    Applies ``Vprop`` in position space, transforms to momentum space with
    ``np.fft.fft``, applies ``Tprop``, transforms back with ``np.fft.ifft``
    and applies ``Vprop`` again. With the module-level defaults, ``Vprop`` is
    the half-step potential propagator exp(-i V dt / 2hbar) and ``Tprop`` is
    the full-step kinetic propagator exp(-i T dt / hbar).

    Parameters
    ----------
    psi : array_like
        Wavefunction sampled on the spatial grid. It is not modified.
    Vprop : ndarray
        Position-space propagator applied before and after the kinetic step.
    Tprop : ndarray
        Momentum-space propagator. It must be ordered like ``np.fft.fft``
        output.

    Returns
    -------
    ndarray
        The propagated wavefunction, as a new complex array.
    """
    # Out-of-place product: leaves the caller's array untouched and also
    # accepts real-valued input (an in-place complex multiply would raise).
    psi = psi * Vprop                        # 1/2 step potential propagator
    psi_p = np.fft.fft(psi) * Tprop          # full step kinetic, momentum space
    return np.fft.ifft(psi_p) * Vprop        # back to space, 1/2 step potential

# --- Animation -------------------------------------------------------------
# Set up the figure, the axes and the artists that FuncAnimation will update.
fig, ax = plt.subplots()
# Close the figure immediately. Presumably this stops the notebook from also
# showing a static, empty copy of it next to the animation.
plt.close()

# Plot parameters. The x-range follows the spatial grid bounds instead of
# repeating them as literals, so the view stays in sync if the grid changes.
ax.set_xlim((xMin, xMax))
# Vertical range covers the initial |psi|^2 peak and the rescaled potential,
# which animate() scales to a maximum of 0.001.
ax.set_ylim((-0.0005, 0.0025))
line, = ax.plot([], [], lw=2)   # probability density |psi|^2
lineV, = ax.plot([], [], lw=2)  # scaled potential

# Initialization function: draws a clear frame for FuncAnimation.
def init():
    """Empty both animated curves.

    Used as FuncAnimation's ``init_func``. With ``blit=True`` it must return
    the artists to be redrawn.
    """
    for artist in (line, lineV):
        artist.set_data([], [])
    return line, lineV

# Animation function: FuncAnimation calls it sequentially, once per frame.
def animate(i):
    """Advance the simulation by one split-operator step and redraw it.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation. It is unused: the simulation
        state lives in the module-level ``psi``, which each call replaces
        with its propagated value.

    Returns
    -------
    tuple
        ``(line, lineV)``, the artists to redraw when blitting.
    """
    global psi
    psi = splitOperator(psi)

    density = np.abs(psi) ** 2  # probability density |psi|^2
    # Rescale the potential so its peak sits at 0.001 on the density axis.
    scaled_potential = Vhat / np.max(Vhat) * 0.001

    line.set_data(x, density)
    lineV.set_data(x, scaled_potential)
    return line, lineV

# Build the animation. FuncAnimation uses init() to draw a clear frame, then
# calls animate(i) for each of the 500 frames, 10 ms apart. With blit=True
# only the artists those functions return are redrawn.
anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=500,
    interval=10,
    blit=True,
)

# Render as an interactive JavaScript (jshtml) widget when the notebook
# displays the bare `anim` expression below; works in Jupyter or Colab.
rc('animation', html='jshtml')
anim