# Numerically Solving the Time-Dependent Schrödinger Equation

#### [Visscher's Method](https://www.asc.tuwien.ac.at/~juengel/simulations/schroedinger/Visscher_method.html)

- We consider the following dimensionless representation of the time-dependent Schrödinger equation (TDSE), where the relevant physical scales are expressed in [Hartree units](https://en.wikipedia.org/wiki/Hartree_atomic_units) ($\hbar = m_{e} = a_{0} = 1$).
$$
    i\frac{\partial \Psi}{\partial t} = \left(-\frac{1}{2}{\nabla}^{2} + \mathbb{V}\right)\Psi 
$$

- At each point of the discrete space–time grid, the numerical solution of the TDSE is denoted 
$$
    \Psi(x_i, y_j, t_n) = \Psi_{i,j}^{n}
$$

- The numerical method we consider here is that developed by Visscher in 1991, wherein the wavefunction is separated into real and imaginary components and discretized over staggered times differing by a half-step. For $\Psi = u + iv$:

$$
    i\frac{\partial \Psi}{\partial t} = \mathbb{H}\Psi = -\frac{1}{2}\nabla^{2}\Psi + \mathbb{V}\Psi
$$

$$
    \Psi = u + i v \quad \implies \quad
    \begin{cases} 
        \dfrac{\partial u}{\partial t} = \mathbb{H}v\\
        \dfrac{\partial v}{\partial t} = -\mathbb{H}u
    \end{cases}
$$

$$
\begin{align}
    &u\left(x_{i}, y_{j}, t_{n}\right) &\longrightarrow\quad u^{n}_{i,j}\\
    &v\left(x_{i}, y_{j}, t_{n}-{\Delta t}/{2}\right) &\longrightarrow\quad v^{n}_{i,j}
\end{align}
$$

$$
\begin{align}
    v^{n+1}_{i,j} &= 
    v^{n}_{i,j} 
    + \frac{\Delta t}{2\Delta x^{2}}\left(u^{n}_{i,j+1} + u^{n}_{i,j-1} + u^{n}_{i+1,j} + u^{n}_{i-1,j} - 4 u^{n}_{i,j}\right)
    - \Delta t \, V_{i,j} u^{n}_{i,j}
    \\
    u^{n+1}_{i,j} &= 
    u^{n}_{i,j} 
    - \frac{\Delta t}{2\Delta x^{2}}\left(v^{n+1}_{i,j+1} + v^{n+1}_{i,j-1} + v^{n+1}_{i+1,j} + v^{n+1}_{i-1,j} - 4 v^{n+1}_{i,j}\right)
    + \Delta t \, V_{i,j} v^{n+1}_{i,j}
\end{align}
$$

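- The probability density $|\Psi|^2$ at time $t_n$ is defined using the two staggered imaginary parts that straddle $t_n$; with this definition the scheme conserves the total probability exactly: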

$$
    P^{n} = u^{n} u^{n} + v^{n}v^{n+1}
$$

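- For a constant potential $V_0$ and the two-dimensional five-point Laplacian used here, the stability condition is (in one dimension the $4/\Delta x^{2}$ term becomes $2/\Delta x^{2}$)
$$
    -\frac{2}{\Delta t} \leq V_{0} \leq \frac{2}{\Delta t} - \frac{4}{\Delta x^{2}} \quad\Longrightarrow\quad \Delta t_{\mathrm{max}} = \frac{\Delta x^2}{2 + |V_0| \Delta x^2 / 2}
$$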

In [5]:
import time
import math as mt
import numpy as np
import matplotlib as mpl
import scipy.constants as spc
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
from scipy.ndimage import convolve, generate_binary_structure, gaussian_filter
from matplotlib.colors import hsv_to_rgb
from datetime import datetime
from tqdm import tqdm
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Global Matplotlib styling shared by every figure and animation in this notebook.
# NOTE(review): 'animation.ffmpeg_path' points at a machine-specific binary
# ('/Users/pranav/ffmpeg'); FFMpegWriter will likely fail on other machines unless
# this path is changed.
# NOTE(review): 'font.family' is 'serif', so Matplotlib picks fonts from the
# 'font.serif' list; the 'Times New Roman' value given for 'font.sans-serif' is
# probably never used — confirm which rcParam was intended.
# 'animation.embed_limit' (in MB) is set to 2**128, i.e. effectively unlimited.
plt.rcParams.update({
    'figure.constrained_layout.use' : True          , 'figure.dpi'            : 50     ,
    'figure.figsize'        : [9, 9]                , 'figure.titlesize'      : 22     , 'figure.facecolor'     : 'white' ,
    'animation.ffmpeg_path' : '/Users/pranav/ffmpeg', 'animation.embed_limit' : 2**128 , 'animation.html'       : 'jshtml', 
    'font.sans-serif'       : 'Times New Roman'     , 'font.family'           : 'serif', 'mathtext.fontset'     : 'cm'    ,
    'xtick.direction'       : 'inout'               , 'xtick.top'             : False  ,
    'ytick.direction'       : 'inout'               , 'ytick.right'           : False  ,
    'xtick.labelsize'       : 14                    , 'xtick.major.size'      : 6      , 'xtick.major.width'    : 1.25    ,
    'ytick.labelsize'       : 14                    , 'ytick.major.size'      : 6      , 'ytick.major.width'    : 1.25    ,
    'axes.titlesize'        : 20                    , 'axes.titley'           : 1.0    , 'axes.titlepad'        : 5.0     ,
    'axes.labelsize'        : 18                    , 'axes.edgecolor'        : 'black', 'axes.grid'            : False   ,
    'legend.title_fontsize' : 14                    , 'legend.fontsize'       : 14     , 'legend.framealpha'    : 1       ,
    'legend.handleheight'   : 0.5                   , 'legend.handlelength'   : 2.0    , 'legend.labelspacing'  : 0.25    ,
    'legend.borderaxespad'  : 1                     , 'legend.borderpad'      : 0.5    , 'legend.handletextpad' : 0.5     ,
    'legend.fancybox'       : False                 , 'legend.frameon'        : True   , 'legend.edgecolor'     : '0'     ,
    'legend.markerscale'    : 1.25                  , 'agg.path.chunksize'    : 1048576
})

def getHUE(u, v):
    """
    Obtain RGB values for the complex phase of the wavefunction psi = u + i*v.

    Hue encodes the phase arg(psi) mapped onto [0, 1]. Value (brightness) encodes
    |psi| rescaled linearly so that min |psi| -> 0 and max |psi| -> 1. Saturation is
    fixed at 1.

    Parameters
    ----------
    u, v : ndarray
        Real and imaginary parts of psi; same 2-D shape.

    Returns
    -------
    ndarray
        RGB image of shape u.shape + (3,), suitable for ``imshow``. Pixel [i, j]
        corresponds to psi[i, j].
    """
    z   = u + 1j*v
    # np.angle lies in (-pi, pi]; add one full turn to negative fractions so hue >= 0.
    hue = np.angle(z)/(2*np.pi)
    hue = np.where(hue < 0.0, 1.0 + hue, hue)
    mag = np.abs(z)
    z_max, z_min = np.max(mag), np.min(mag)
    span = z_max - z_min
    if span > 0:
        val = (mag - z_min)/span
    else:
        # Uniform magnitude: the linear rescale would be 0/0 (NaN). Render a zero
        # field black and any other uniform field at full brightness.
        val = np.full_like(mag, 1.0 if z_max > 0 else 0.0)
    hsv_im = np.stack([hue, np.full_like(hue, 1), val], axis = -1)
    return hsv_to_rgb(hsv_im)

# Parameters ———————————————————————————————————————————————————————————————————— #
# Hartree Units: (mₑ = ℏ = e = a₀ = 1) —————————————————————————————————————————— #
pi     = spc.pi
L      = 8.             # length of spatial axes    (x = y in this program)
N      = 512            # spatial resolution        (no. of grid intervals; N + 1 points per axis)
Tsim   = 1.2            # set simulation time       (t-scale = (a₀)²mₑ/ℏ) 
stable = 2/7            # dt = stable·Δx²; the 2-D five-point scheme is stable only for stable <= 1/2
R1, R2 = .35*L, .45*L   # inner / outer radius of the annulus
x0, y0 = .50*L, .50*L   # centre of the annulus

φ0     = 0              # initial angular position (rad) of the wave packet on the annulus

# Fail fast: outside this range the leapfrog update grows exponentially instead of oscillating.
if not 0 < stable <= 0.5:
    raise ValueError(f'stable = {stable} must lie in (0, 1/2] for a stable 2-D Visscher update')

# Spatial Grid —————————————————————————————————————————————————————————————————— #
# Square, uniformly spaced grid with N + 1 nodes per axis covering [0, L] (Δy = Δx).
# meshgrid 'xy' indexing: both outputs are 2-D, rows follow y and columns follow x.
dx     = L/N
x, y   = np.meshgrid(np.arange(N + 1) * dx, np.arange(N + 1) * dx, indexing = 'xy')

# Temporal Grid ————————————————————————————————————————————————————————————————— #
# NOTE(review): despite its name, max_dt = Δx² is not the stability limit of this 2-D
# scheme. For V = 0 the five-point Laplacian requires dt <= Δx²/2, so `stable` must not
# exceed 1/2 (the chosen 2/7 is safe).
max_dt = dx**2
dt     = stable * max_dt            # set time-step with stability factor
steps  = int(Tsim/dt)               # no. of time steps (truncated, so steps*dt is at most ~Tsim)
t      = np.arange(steps + 1) * dt  # the steps + 1 integer time levels t_n = n·dt

# Define the Annulus as a Mask —————————————————————————————————————————————————— #
# `annulus` is True on the wall: points inside radius R1 or outside radius R2 about
# (x0, y0). Adding the two boolean conditions acts as a logical OR.
# NOTE(review): gaussian_filter receives a boolean array, so its output is presumably
# also boolean. Any non-zero blurred value would then become True, widening the wall
# by a couple of grid points rather than softening its edge — confirm this is intended.
annulus = gaussian_filter(
                    ((x-x0)**2 + (y-y0)**2 < R1**2) \
                  + ((x-x0)**2 + (y-y0)**2 > R2**2), 
                    sigma = .5
).astype(bool)

def confine(Psi):
    """
    Zero ``Psi`` at every grid point flagged as wall by the global ``annulus`` mask.

    The multiplication is done IN PLACE: the caller's array is modified and the same
    array object is returned for convenience.
    """
    np.multiply(Psi, ~annulus, out = Psi)
    return Psi

# Initial Wavefunction —————————————————————————————————————————————————————————— #
def gaussian(x, y, s = .02*L, kx = 0, ky = -25*pi):
    """
    Return a Gaussian wave packet centred on the origin of the (x, y) arguments.

    The packet is exp(-(x² + y²)/(2 s²)) · exp(i (kx x + ky y)): an isotropic Gaussian
    envelope with width parameter ``s`` times a plane wave with wave vector (kx, ky).
    Shift the coordinates before calling to move the centre. The result is not normalized.
    """
    r_squared = x*x + y*y
    envelope  = np.exp(-(1/(2 * s * s)) * r_squared)
    plane     = np.exp(1j * (kx*x + ky*y))
    return envelope * plane
# Radius of the annulus centre-line; the packet starts on this circle at angle φ0.
R    = (R1 + R2)/2
# Gaussian packet with envelope parameter s = 0.025·L and wave vector (kx, ky) = (0, -30π).
# NOTE(review): the momentum points along -y, which is tangential to the ring only for
# φ0 = 0 or π; it is not rotated together with φ0.
psi0 = gaussian(
                x - x0 - R*np.cos(φ0), 
                y - y0 - R*np.sin(φ0),
                s  = .025*L          , 
                kx = 0               ,
                ky = -30*pi
)

# Solution functions ———————————————————————————————————————————————————————————— #
def normalize(f, h = None):
    """
    Return ``f`` scaled to unit discrete L2 norm on a uniform square grid.

    The norm is sqrt(Σ |f|² h²): a Riemann-sum approximation of ∫∫ |f|² dx dy with
    cell area h².

    Parameters
    ----------
    f : ndarray
        Sampled real or complex field.
    h : float, optional
        Grid spacing; defaults to the module-level ``dx``.

    Returns
    -------
    ndarray
        ``f`` divided by its norm (a new array; ``f`` is not modified).

    Raises
    ------
    ValueError
        If ``f`` has zero norm, which would otherwise yield an all-NaN array.
    """
    if h is None:
        h = dx
    norm = np.sqrt(np.sum(np.abs(f)**2 * h**2))
    if norm == 0:
        raise ValueError('cannot normalize a field with zero norm')
    return f/norm

def Hamiltonian(Psi):
    """
    Apply the finite-difference Hamiltonian H = -(1/2)∇² to a field on the grid.

    The field is first confined to the annulus with ``confine``, which zeroes the wall
    region of ``Psi`` in place. Then the five-point Laplacian (scaled by 1/dx²) is
    applied by convolution with periodic ('wrap') boundaries. The solver calls this
    separately on the real and imaginary parts.
    """
    five_point = np.array([[0,  1,  0],
                           [1, -4,  1],
                           [0,  1,  0]])
    stencil    = (1/dx**2) * five_point
    masked     = confine(Psi)
    return -(1/2) * convolve(masked, stencil, mode = 'wrap')

def solveSchrodinger(psi0, save_every = 1):
    """
    Use the Visscher Method to solve for the real and imaginary parts 
    of the wavefunction.

    The initial state is confined to the annulus and normalized. It is then advanced
    ``steps`` times with the staggered leapfrog update

        Im[n+1] = Im[n] - dt · H(Re[n])
        Re[n+1] = Re[n] + dt · H(Im[n+1])

    where H is ``Hamiltonian``. ``Hamiltonian`` zeroes the wall region of its argument
    in place, so every stored state except (possibly) the final real part is confined
    to the annulus.

    NOTE(review): Im[0] is taken from psi0 at t = 0, whereas the staggered scheme
    defines Im[n] at t_n - dt/2; this introduces a small start-up error.

    Parameters
    ----------
    psi0 : ndarray
        Initial wavefunction on the grid. It is copied, not modified.
    save_every : int, optional
        Keep only every ``save_every``-th time level (levels 0, k, 2k, ...). The
        default of 1 stores all ``steps + 1`` levels. Larger values cut memory by
        roughly that factor.

    Returns
    -------
    (ReΨ, ImΨ) : tuple of ndarray
        Each has shape (steps // save_every + 1,) + psi0.shape. Entry j holds the
        state after j·save_every steps.

    Raises
    ------
    ValueError
        If ``save_every`` is not a positive integer.
    """
    k = int(save_every)
    if k != save_every or k < 1:
        raise ValueError(f'save_every must be a positive integer, got {save_every!r}')

    # Work on a copy: confine() masks in place and would otherwise modify the caller's psi0.
    Ψ0             = normalize(confine(np.array(psi0)))
    ReΨ   , ImΨ    = np.zeros((2, steps // k + 1) + Ψ0.shape)
    re    , im     = Ψ0.real.copy(), Ψ0.imag.copy()
    ImΨ[0]         = im
    for n in tqdm(range(steps)):
        im_next = im - dt * Hamiltonian(re)         # also masks `re` in place
        re_next = re + dt * Hamiltonian(im_next)    # also masks `im_next` in place
        if n % k == 0:
            ReΨ[n // k] = re                        # `re` is final only after its masking above
        if (n + 1) % k == 0:
            ImΨ[(n + 1) // k] = im_next
        re, im = re_next, im_next
    if steps % k == 0:
        ReΨ[steps // k] = re                        # the last real part is never masked
    return (ReΨ, ImΨ)

# Plot Initial State ———————————————————————————————————————————————————————————— #
# Phase (hue) / magnitude (brightness) rendering of the initial packet as built above,
# i.e. before solveSchrodinger confines and normalizes it; displayed inline.
fig, ax = plt.subplots(1, 1, figsize = (9, 9), dpi = 50, facecolor = 'k')
ax.imshow((getHUE(psi0.real, psi0.imag)), interpolation = 'bilinear')
ax.axis(False)
plt.show()

# Compute Solution —————————————————————————————————————————————————————————————— #
# u, v: real / imaginary parts at every time level, each of shape (steps + 1, N + 1, N + 1)
# in float64. NOTE(review): that is (steps + 1)·(N + 1)²·8 bytes per array — tens of GB
# for the parameters above.
u, v = solveSchrodinger(psi0)

# Plot Final State —————————————————————————————————————————————————————————————— #
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
fig, ax = plt.subplots(1, 1, figsize = (9,9), dpi = 50, facecolor = 'k')
ax.imshow((getHUE(u[steps], v[steps])), interpolation = 'bilinear')
ax.axis(False)
# Saved to disk at 360 dpi; the figure is closed without being displayed inline.
fig.savefig(f'annulus_{steps}.jpg', dpi = 360)
plt.close()

# Create Animation —————————————————————————————————————————————————————————————— #
def showTime(i, STEPS, FRAMES, t0, t1):
    """
    Print a one-line, carriage-return-refreshed progress report for the render.

    The line shows:
    - the estimated frame number (out of ``FRAMES``);
    - the percentage of ``STEPS`` reached;
    - the wall time elapsed between ``t0`` and ``t1`` (mm:ss);
    - the estimated time remaining (mm:ss), extrapolated from the mean time per frame.

    Nothing is printed when ``i == 0``. Always returns None.
    """
    if i == 0:
        return

    def pad(value):
        # Prefix a single '0' whenever the value is below 10.
        return ('0' if value < 10 else '') + str(value)

    frame     = int(np.ceil(i/STEPS * FRAMES))
    elapsed   = t1 - t0
    per_frame = elapsed/frame
    remaining = per_frame*(FRAMES - frame)
    rem_min   = int(remaining//60)
    rem_sec   = int(remaining - int(remaining//60)*60)
    el_min    = int(elapsed//60)
    el_sec    = int(elapsed - int(elapsed//60)*60)
    percent   = round((i/STEPS)*100, 1)
    status    = f'{frame}/{FRAMES} ({percent}%) | {pad(el_min)}:{pad(el_sec)} ⟸ {pad(rem_min)}:{pad(rem_sec)} '
    print('\r', status, end = '\r')

# IPython notebook magics: inline figures rendered in 'retina' (high-DPI) format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

def animatePHASE(TIME = 12, FPS = 60, DPI = 360):
    """
    Render the phase/magnitude colouring of the stored solution (globals ``u``, ``v``)
    to an MP4 file with ffmpeg and return the FuncAnimation.

    Parameters
    ----------
    TIME : float
        Requested video length in seconds. It is capped so that every frame shows a
        distinct time step.
    FPS : int
        Frames per second of the output video.
    DPI : int
        Render DPI. The figure is sized (3840/DPI × 2160/DPI inches) so the video is
        3840×2160 pixels.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation that was saved.

    Raises
    ------
    ValueError
        If there are too few time steps to fill even one second at ``FPS``
        (the video would have no frames).
    """
    startTime = time.perf_counter()
    STEPS     = int(steps-1)
    DURATION  = min(TIME, int(STEPS/FPS))
    FRAMES    = int(FPS * DURATION)
    if FRAMES < 1:
        raise ValueError(f'not enough time steps ({STEPS}) to render at least one frame at {FPS} fps')
    WRITER    = animation.FFMpegWriter(fps = FPS, extra_args = ['-vcodec', 'libx264'])
    # Exactly FRAMES evenly spaced step indices in [0, STEPS), using integer arithmetic.
    # A fixed stride of int(STEPS/FRAMES) can produce many more than FRAMES frames,
    # making the video noticeably longer than DURATION.
    frame_steps = (np.arange(FRAMES) * STEPS) // FRAMES

    fig, ax = plt.subplots(1, 1, figsize = (3840/DPI, 2160/DPI), dpi = DPI, facecolor = 'k')
    # Height of the rendered video in pixels (2160 for this figure size), used in the labels.
    HEIGHT  = int(round(fig.get_size_inches()[1] * DPI))
    phase_image = ax.imshow((getHUE(u[0], v[0])), interpolation = 'bilinear')
    ax.axis(False)

    def animate(i):
        # Reuse the existing AxesImage instead of removing and re-creating it every frame.
        phase_image.set_data(getHUE(u[i], v[i]))
        showTime(i, STEPS, FRAMES, startTime, time.perf_counter())
        return phase_image,
    plt.close()

    print(f'Animation Details : {round(DURATION,1)} s, {FPS} fps, {HEIGHT}p', end = '\n\n')
    anim = FuncAnimation(
                        fig,
                        animate,
                        interval = int(1e3/FPS),
                        frames = frame_steps
    )
    stringDT = (datetime.now()).strftime('%m.%d @ %H-%M')
    filename = f'annulus ({stringDT} | {round(DURATION,1)} s, {FPS} fps, {HEIGHT}p).mp4'
    anim.save(filename, writer = WRITER, dpi = DPI)
    print('Animation saved as:  {}'.format(filename))
    return anim

# Render and save a video of up to 60 s of the stored solution (requires ffmpeg).
phase_animation = animatePHASE(TIME = 60)