# Schrödinger Equation in 2D

Jonny Hyman's implementation: https://github.com/jonnyhyman/QuantumWaves

Original implementation: https://github.com/Azercoco/Python-2D-Simulation-of-Schrodinger-Equation

In [None]:
import os
import toml, sys
import numpy as np
from PIL import Image
from time import time
from time import sleep
import scipy.linalg
import scipy as sp
import scipy.sparse
import scipy.sparse.linalg
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
from colorsys import hls_to_rgb
from numba import njit, prange
from IPython.display import display, clear_output
from mpl_toolkits.mplot3d import Axes3D 

# Short alias for the complex dtype used by the numba kernels below.
c16 = np.complex128
# IPython magic: request 'retina' (high-DPI) inline figures.
%config InlineBackend.figure_format = 'retina'
# Global plot styling: serif font and large tick labels.
plt.rc('font', family='serif')
plt.rc('xtick',labelsize=18)
plt.rc('ytick',labelsize=18)
# Render matplotlib animations inline as a JavaScript/HTML player.
plt.rcParams['animation.html'] = 'jshtml'
# NOTE(review): huge value, presumably to lift the size cap on embedded
# animations entirely — confirm the units against the matplotlib rcParams docs.
plt.rcParams['animation.embed_limit'] = 2**128

# Field

In [None]:
class Field:
    """Potential V(x, y) and obstacle mask given as Python expression strings.

    Expressions are evaluated with ``eval`` against this module's globals
    (so e.g. ``np`` is available) plus the names ``x`` and ``y``.

    SECURITY: ``eval`` executes arbitrary code; only pass trusted expressions.
    """

    # Probe value used to validate expressions. A numpy scalar makes a division
    # by zero at the origin yield inf/nan (as it does for numpy grid
    # coordinates) instead of raising ZeroDivisionError. As a result,
    # expressions such as '1/(x**2 + y**2)' are not rejected just because they
    # are singular at (0, 0).
    _PROBE = np.float64(0.0)

    def __init__(self):
        self.potential_expr = None  # expression string for V(x, y); set by setPotential
        self.obstacle_expr = None   # boolean expression string; set by setObstacle

    @staticmethod
    def _evaluate(expr, x, y):
        """Evaluate ``expr`` with ``x`` and ``y`` bound; any exception propagates."""
        return eval(expr, globals(), {'x': x, 'y': y})

    def _can_evaluate(self, expr):
        """Return True if ``expr`` evaluates without raising at the probe point (0, 0)."""
        try:
            with np.errstate(all='ignore'):
                self._evaluate(expr, self._PROBE, self._PROBE)
        except Exception:  # SyntaxError, NameError, TypeError, ... (not KeyboardInterrupt)
            return False
        return True

    def setPotential(self, expr):
        """Set the potential expression; falls back to '0' if it cannot be evaluated."""
        self.potential_expr = expr
        self.test_pot_expr()

    def setObstacle(self, expr):
        """Set the obstacle expression; falls back to 'False' if it cannot be evaluated."""
        self.obstacle_expr = expr
        self.test_obs_expr()

    def test_pot_expr(self):
        """Reset ``potential_expr`` to '0' (printing a message) if it does not evaluate."""
        if not self._can_evaluate(self.potential_expr):
            print(self.potential_expr)
            print('Potential calculation error: set to 0 by default')
            self.potential_expr = '0'

    def test_obs_expr(self):
        """Reset ``obstacle_expr`` to 'False' (printing a message) if it does not evaluate."""
        if not self._can_evaluate(self.obstacle_expr):
            print('Error setting obstacle: Set to False by default')
            self.obstacle_expr = 'False'

    def isObstacle(self, x, y):
        """Return the obstacle expression's value at (x, y), or False if evaluation fails."""
        try:
            return self._evaluate(self.obstacle_expr, x, y)
        except Exception:
            print(f'Invalid obstacle: {self.obstacle_expr}')
            return False

    def getPotential(self, x, y):
        """Return the potential expression's value at (x, y), or 0j if evaluation fails."""
        try:
            return self._evaluate(self.potential_expr, x, y)
        except Exception:
            print(f'Invalid potential: {self.potential_expr}')
            return 0 + 0j


# Utils

In [None]:
# Passed as numba's ``parallel=`` flag to several @njit kernels below.
# The (misspelled) name is kept because those decorators reference it.
do_parralel = False

def colorize(z):
    """Colour a complex 2-D array: hue from the phase, lightness from |z|.

    For an (n, m) input, returns an (n, m, 3) array of RGB values. Zero
    entries map to black, and large moduli approach white.
    """
    modulus, phase = np.abs(z), np.angle(z)
    hue = (phase + np.pi) / (2 * np.pi) + 0.5
    lightness = 1.0 - 1.0 / (1.0 + 2 * modulus**1.2)
    saturation = 0.8
    # vectorize yields one array per channel -> (3, n, m); move channels last.
    channels = np.array(np.vectorize(hls_to_rgb)(hue, lightness, saturation))
    return channels.swapaxes(0, 2).swapaxes(0, 1)


@njit(cache=True, parallel=do_parralel)
def x_concatenate(MM, N):
    """Flatten an (N, N) array column by column (Fortran order).

    Element ``MM[i, j]`` lands at flat index ``i + N*j``, so entries that share
    the second index form a contiguous run of length N.
    """
    flat = np.zeros(N * N, dtype=c16)
    for col in prange(N):
        offset = N * col
        for row in range(N):
            flat[offset + row] = MM[row, col]
    return flat


@njit(cache=True)
def x_deconcatenate(vector, N):
    """Inverse of x_concatenate: rebuild an (N, N) array from a column-major vector.

    Flat index ``k`` goes to row ``k % N``, column ``k // N``.
    """
    grid = np.zeros((N, N), dtype=c16)
    for k in range(N * N):
        grid[k % N, k // N] = vector[k]
    return grid


@njit(cache=True, parallel=do_parralel)
def y_concatenate(MM, N):
    """Flatten an (N, N) array row by row (C order).

    Element ``MM[i, j]`` lands at flat index ``j + N*i``, so entries that share
    the first index form a contiguous run of length N.
    """
    flat = np.zeros(N * N, dtype=c16)
    for row in prange(N):
        base = row * N
        for col in range(N):
            flat[base + col] = MM[row, col]
    return flat


@njit(cache=True)
def y_deconcatenate(vector, N):
    """Inverse of y_concatenate: rebuild an (N, N) array from a row-major vector.

    Flat index ``k`` goes to row ``k // N``, column ``k % N``.
    """
    grid = np.zeros((N, N), dtype=c16)
    for k in range(N * N):
        grid[k // N, k % N] = vector[k]
    return grid


@njit(cache=True, parallel=do_parralel)
def dx_square(MM, N, Δx):
    """Second difference of MM along its first axis, divided by Δx**2.

    Entries outside the array are taken as zero, so the first and last rows
    each use only their single in-range neighbour.
    """
    lap = np.zeros((N, N), dtype=c16)
    last = N - 1
    for col in prange(N):
        lap[0, col] = MM[1, col] - 2*MM[0, col]
        for row in range(1, last):
            lap[row, col] = MM[row + 1, col] + MM[row - 1, col] - 2*MM[row, col]
        lap[last, col] = MM[last - 1, col] - 2*MM[last, col]
    return lap / (Δx**2)


@njit(cache=True, parallel=do_parralel)
def dy_square(MM, N, Δx):
    """Second difference of MM along its second axis, divided by Δx**2.

    Entries outside the array are taken as zero, so the first and last
    columns each use only their single in-range neighbour.
    """
    lap = np.zeros((N, N), dtype=c16)
    last = N - 1
    for row in prange(N):
        lap[row, 0] = MM[row, 1] - 2*MM[row, 0]
        for col in range(1, last):
            lap[row, col] = MM[row, col + 1] + MM[row, col - 1] - 2*MM[row, col]
        lap[row, last] = MM[row, last - 1] - 2*MM[row, last]
    return lap / (Δx**2)



def apply_obstacle(MM, N, meshX, meshY, is_obstacle=None):
    """Zero ``MM`` in place wherever the obstacle predicate holds; return ``MM``.

    Parameters
    ----------
    MM : 2-D array, modified in place at obstacle points.
    N : number of rows and columns visited (the leading N x N block).
    meshX, meshY : coordinate arrays indexed like ``MM``.
    is_obstacle : callable ``(x, y) -> truthy``, e.g. ``sim.field.isObstacle``.
        Defaults to a module-level ``isObstacle`` function if one exists.

    Raises
    ------
    NameError
        If no predicate is given and no module-level ``isObstacle`` exists.
    """
    if is_obstacle is None:
        # The predicate lives on Field (Field.isObstacle); no module-level
        # isObstacle is visible here, so fail with an actionable message.
        is_obstacle = globals().get('isObstacle')
        if is_obstacle is None:
            raise NameError("apply_obstacle needs an is_obstacle predicate, "
                            "e.g. apply_obstacle(MM, N, X, Y, sim.field.isObstacle)")
    for i in range(N):
        for j in range(N):
            if is_obstacle(meshX[i][j], meshY[i][j]):
                MM[i][j] = 0 + 0j
    return MM


def getAdjPos(x, y, N):
    """Return the 8 index pairs surrounding (x, y): 4 edge and 4 diagonal neighbours.

    Positions are not clipped to the grid, so callers must bounds-check them.
    ``N`` is accepted for interface compatibility but is unused.
    """
    return [
        (x - 1, y),
        (x + 1, y),
        (x, y - 1),
        (x, y + 1),
        (x - 1, y + 1),
        (x - 1, y - 1),
        (x + 1, y + 1),
        (x + 1, y - 1),  # previously a duplicate of (x + 1, y + 1)
    ]
    

def clear():
    """Clear the terminal with the platform's shell command."""
    # 'cls' only exists on Windows; POSIX shells use 'clear'.
    os.system('cls' if os.name == 'nt' else 'clear')

def launch(filename):
    """Pass ``filename`` to the shell via ``os.system``.

    SECURITY: the string is executed as a shell command; never pass untrusted input.
    """
    os.system(filename)

    
@njit(cache=True)
def integrate(MM, N, Δx):
    """Approximate the integral of MM over an N x N grid with spacing Δx.

    Each grid cell is split into two triangles along the diagonal joining its
    [r, c+1] and [r+1, c] corners. Each triangle contributes area Δx*Δx/2
    times the mean of its three corner values.
    """
    total = 0
    tri_area = Δx*Δx/2
    for r in range(N - 1):
        for c in range(N - 1):
            a00 = MM[r][c]
            a01 = MM[r][c+1]
            a10 = MM[r+1][c]
            a11 = MM[r+1][c+1]
            total += tri_area*(a00 + a01 + a10)/3
            total += tri_area*(a11 + a01 + a10)/3
    return total

# Main

In [None]:
class Simulation:
    """Time-dependent 2-D Schrödinger solver on an N x N grid.

    Each call to :meth:`evolve` performs two half-steps:
    - The first solves implicitly along the first array axis (matrix ``HX``)
      and treats the second axis explicitly.
    - The second does the reverse (matrix ``HY``).

    Attributes read elsewhere in the notebook:
    - ``Ψ``: the current (N, N) wavefunction.
    - ``x_axis`` / ``y_axis``: the grid axes, set by :meth:`simulation_initialize`.
    - ``N_iter``.
    - ``potential_boudnary`` (sic): a list of ``(i, j)`` obstacle-edge cells.
    """

    def __init__(self, N, size, delta_t, V, N_iter):
        """
        Parameters
        ----------
        N : number of grid points per axis (must be >= 2).
        size : side length of the square domain centred on the origin.
        delta_t : solver time step.
        V : potential as a Python expression in ``x`` and ``y`` (see Field).
        N_iter : number of steps used for progress reporting and the memory estimate.
        """
        if N < 2:
            raise ValueError("N must be at least 2")
        self.N = N  # number of grid points per axis
        self.SIZE = size
        self.Δt = delta_t
        # simulation_initialize samples N points on [-size/2, size/2] with both
        # ends included, so neighbouring points are size/(N-1) apart. The
        # finite-difference Laplacian and the norm integral must use that
        # spacing, not size/N.
        self.Δx = self.SIZE/(self.N - 1)
        self.N_iter = N_iter

        # Potential as a function of x and y, e.g. "x**2 + y**2".
        self.field = Field()
        self.field.setPotential(V)
        # (sic) name kept: map_to_rgb reads this list of obstacle-edge cells.
        self.potential_boudnary = []

        # Obstacle: boolean expression in x and y ("False" means no obstacle).
        self.field.setObstacle("False")
        print("{:.4f} GB of memory".format(16*self.N_iter*self.N*self.N*1e-9))

    def simulation_initialize(self, x0=(0,), y0=(0,), k_x=(5.0,), k_y=(5.0,), a_x=(0.2,), a_y=(0.2,)):
        """Build the grid, initial wavefunction, potential and ADI matrices.

        Entry ``n`` of each sequence describes one Gaussian wavepacket:
        - ``x0``, ``y0``: centre.
        - ``k_x``, ``k_y``: wave vector.
        - ``a_x``, ``a_y``: width.

        ``len(x0)`` packets are built, each normalised on its own. They are
        then summed and the sum is normalised again.
        """
        wall_potential = 1e10
        N = self.N
        SIZE = self.SIZE
        Δx = self.Δx
        Δt = self.Δt

        self.counter = 0

        ######## Create points at all xy coordinates in meshgrid ########
        self.x_axis = np.linspace(-SIZE/2, SIZE/2, N)
        self.y_axis = np.linspace(-SIZE/2, SIZE/2, N)
        X, Y = np.meshgrid(self.x_axis, self.y_axis)

        ######## Initialize Wavepackets ########
        # TODO: figure out best way to add wavefunctions (antisymmetric / symmetric combination?)
        packets = [self._gaussian_packet(X, Y, x0[n], y0[n], k_x[n], k_y[n], a_x[n], a_y[n])
                   for n in range(len(x0))]
        Ψ = packets[0]
        for packet in packets[1:]:
            Ψ = Ψ + packet
        self.Ψ = Ψ/np.sqrt(integrate(np.abs(Ψ)**2, N, Δx))

        ######## Create Potential ########
        obstacle, potential = self._sample_field(wall_potential)
        # Flatten in the orderings produced by x_concatenate (column-major) and
        # y_concatenate (row-major) so V lines up with the solver vectors.
        self.V_x = potential.flatten(order='F')
        self.V_y = potential.flatten()
        self.V_x_matrix = sp.sparse.diags([self.V_x], [0])
        self.V_y_matrix = sp.sparse.diags([self.V_y], [0])

        ######## Create Hamiltonian ########
        # NOTE(review): V enters every half-step twice: implicitly here via
        # HX/HY, and explicitly on the right-hand side in evolve(). Each
        # Laplacian direction enters only once, so V's effective weight is
        # doubled relative to a V/2 + V/2 split. Confirm this scaling is intended.
        LAPLACE_MATRIX = self._block_laplacian(N)/(Δx**2)
        identity = sp.sparse.identity(N*N)
        self.H1 = sp.sparse.dia_matrix(identity - 1j*(Δt/2)*LAPLACE_MATRIX)  # not used by evolve()
        self.HX = sp.sparse.dia_matrix(identity - 1j*(Δt/2)*(LAPLACE_MATRIX - self.V_x_matrix))
        self.HY = sp.sparse.dia_matrix(identity - 1j*(Δt/2)*(LAPLACE_MATRIX - self.V_y_matrix))

        ######## Place Obstacles ########
        self.potential_boudnary.extend(self._obstacle_boundary(obstacle))

        self.start_time = time()
        self.i_time = time()

    def _gaussian_packet(self, X, Y, x0, y0, k_x, k_y, a_x, a_y):
        """Return one Gaussian wavepacket on the mesh (X, Y), scaled so integrate(|Ψ|²) == 1."""
        phase = np.exp(1j*(X*k_x + Y*k_y))
        px = np.exp(-((x0 - X)**2)/(4*a_x**2))
        py = np.exp(-((y0 - Y)**2)/(4*a_y**2))
        packet = phase*px*py
        return packet/np.sqrt(integrate(np.abs(packet)**2, self.N, self.Δx))

    @staticmethod
    def _block_laplacian(N):
        """Return the unscaled 3-point second-difference matrix (-2 on the diagonal) on N*N points.

        Index k is coupled to k+1 only inside each consecutive block of N
        entries, i.e. along one grid axis of the flattened vector.
        """
        off_diagonal = np.ones(N*N - 1)
        off_diagonal[N - 1::N] = 0.0  # no coupling across block boundaries
        return sp.sparse.diags([off_diagonal, -2.0*np.ones(N*N), off_diagonal],
                               [-1, 0, 1], format='csr')

    def _sample_field(self, wall_potential):
        """Evaluate the obstacle mask and the potential once per grid point.

        Entry ``[i, j]`` corresponds to ``(x_axis[j], y_axis[i])``, matching
        ``np.meshgrid(x_axis, y_axis)``. Obstacle points get ``wall_potential``.
        """
        N = self.N
        obstacle = np.zeros((N, N), dtype=bool)
        potential = np.zeros((N, N), dtype='c16')
        for i, y in enumerate(self.y_axis):
            for j, x in enumerate(self.x_axis):
                if self.field.isObstacle(x, y):
                    obstacle[i, j] = True
                    potential[i, j] = wall_potential
                else:
                    potential[i, j] = self.field.getPotential(x, y)
        return obstacle, potential

    def _obstacle_boundary(self, obstacle):
        """Return ``(i, j)`` for obstacle cells that have a free in-grid neighbour.

        A cell is listed once per free neighbour reported by getAdjPos.
        """
        N = self.N
        boundary = []
        for i, j in np.argwhere(obstacle).tolist():
            for ii, jj in getAdjPos(i, j, N):
                if 0 <= ii < N and 0 <= jj < N and not obstacle[ii, jj]:
                    boundary.append((i, j))
        return boundary

    def evolve(self):
        """Advance Ψ by one solver step (two half-steps, each using Δt/2) and increment ``counter``."""
        N, Δt, Δx = self.N, self.Δt, self.Δx
        Ψ = self.Ψ

        # Half-step 1: implicit along the first array axis, explicit along the second.
        vector_wrt_x = x_concatenate(Ψ, N)
        vector_deriv_y_wrt_x = x_concatenate(dy_square(Ψ, N, Δx), N)
        U_wrt_x = vector_wrt_x + (1j*Δt/2)*(vector_deriv_y_wrt_x - self.V_x*vector_wrt_x)
        U_wrt_x_plus = scipy.sparse.linalg.spsolve(self.HX, U_wrt_x)
        Ψ = x_deconcatenate(U_wrt_x_plus, N)

        # Half-step 2: implicit along the second array axis, explicit along the first.
        vector_wrt_y = y_concatenate(Ψ, N)
        vector_deriv_x_wrt_y = y_concatenate(dx_square(Ψ, N, Δx), N)
        U_wrt_y = vector_wrt_y + (1j*Δt/2)*(vector_deriv_x_wrt_y - self.V_y*vector_wrt_y)
        U_wrt_y_plus = scipy.sparse.linalg.spsolve(self.HY, U_wrt_y)
        self.Ψ = y_deconcatenate(U_wrt_y_plus, N)

        self.counter += 1

    def print_update(self):
        """Print a progress bar, elapsed time, estimated time remaining and the norm of Ψ."""
        NORM = np.sqrt(integrate(np.abs(self.Ψ)**2, self.N, self.Δx))
        report = self.counter/self.N_iter

        M = 40  # progress-bar width in characters
        k = int(report*M)
        to_print = '[' + k*'#' + (M - k)*'-' + ']   {0:.3f} %'

        d_time = time() - self.start_time

        # (seconds since the previous call) * steps remaining; assumes one call
        # per evolve(). divmod on whole seconds keeps the seconds field in 0-59.
        remaining = max(0.0, (time() - self.i_time)*(self.N_iter - self.counter))
        minutes, seconds = divmod(int(round(remaining)), 60)
        ETA = f'{minutes}:{seconds:02d}'

        self.i_time = time()
        clear_output(wait=True)
        print('--- Simulation in progress ---')
        print(to_print.format(report*100))
        print('Time elapsed : {0:.1f} s'.format(d_time))
        print(f'Estimated time remaining : {ETA}')
        print('Function norm : {0:.5f} '.format(NORM))


In [None]:
def map_to_rgb(sim, intensity):
    """Render ``sim.Ψ`` as an RGB image array with channels last.

    If ``intensity`` is truthy, |Ψ|² is min-max normalised and coloured with
    viridis. Otherwise the complex field is coloured by ``colorize``. Cells
    listed in ``sim.potential_boudnary`` are painted white.
    """
    psi = sim.Ψ
    if not intensity:
        image = colorize(psi)
    else:
        density = np.abs(psi)**2
        scale = plt.Normalize(density.min(), density.max())
        image = plt.cm.viridis(scale(density))[:, :, :3]  # drop the alpha channel

    for row, col in sim.potential_boudnary:
        image[row][col] = 1, 1, 1
    return image

In [None]:
# Simulation: 1024 points per axis on a box of side 10, a quadratic
# potential given as an expression in x and y, and 2000 steps.
sim = Simulation(
    N=1024,
    size=10,
    delta_t=0.003,
    V="2*x**2 + 2*y**2 + 2*x*y",
    N_iter=2000,
)

# Three Gaussian wavepackets; entry n of every list belongs to packet n.
sim.simulation_initialize(
    x0=[-2.3, 2.0, 2.1],
    y0=[2.2, -2.6, 2.0],
    k_x=[5.5, -6.5, -1.0],
    k_y=[1.0, 1.0, -6.0],
    a_x=[0.4, 0.3, 0.5],
    a_y=[0.3, 0.5, 0.4],
)


def generate_rgb_data(sim, intensity=True, return_frames=False):
    """Evolve ``sim`` for ``sim.N_iter`` steps, rendering each step with map_to_rgb.

    Parameters
    ----------
    sim : Simulation already prepared with simulation_initialize().
    intensity : forwarded to map_to_rgb (True -> |Ψ|² coloured with viridis).
    return_frames : if True, keep every frame and return them as an array of
        shape (N_iter, N, N, 3). Off by default because that buffer holds
        N_iter*N*N*3 float64 values.

    Returns
    -------
    ``(sim, last_frame)`` by default, or ``(sim, frames)`` when
    ``return_frames`` is True. ``last_frame`` is None if ``sim.N_iter`` is 0.
    """
    frames = np.zeros((sim.N_iter, sim.N, sim.N, 3)) if return_frames else None
    rgb_map = None
    for i in range(sim.N_iter):
        sim.evolve()
        sim.print_update()
        rgb_map = map_to_rgb(sim=sim, intensity=intensity)
        if frames is not None:
            frames[i] = rgb_map
    # Return only after the loop so that all N_iter steps actually run.
    return (sim, frames) if return_frames else (sim, rgb_map)
        
# Evolve the simulation and keep |Ψ|² of every step for the 3-D animation.
ΨP_arr = np.zeros((sim.N_iter, sim.N, sim.N))
for i, frame in enumerate(ΨP_arr):
    sim.evolve()
    sim.print_update()
    frame[...] = np.abs(sim.Ψ)**2  # writes through the view into ΨP_arr[i]

In [None]:
# X, Y = np.meshgrid(sim.x_axis, sim.y_axis)
# cmap = plt.cm.viridis
# psi = 30


# start = time() 
# for i in range(sim.N_iter):
#     i_start = time()
#     psi = psi = psi + 0.1
#     fig = plt.figure(figsize=(16,16), dpi=200)
#     ax = fig.add_subplot(111, projection='3d')
#     ax.plot_surface(X, Y, ΨP_arr[i,:,:], cmap=cmap, rstride=2, cstride=2)
#     ax.view_init(20,psi )
#     ax.set_zlim([0,0.8])
#     ax.set_xlim(-5,5)
#     ax.set_ylim(-5,5)
#     ax.get_xaxis().set_visible(False)
#     ax.get_yaxis().set_visible(False)
#     plt.axis('off')
#     clear_output(wait=True)
#     display(plt.gcf())
#     plt.close()
#     i_end = time()
#     print("\n{:.4f} seconds".format(i_end-i_start))
    

# end = time()
# print("\n{:.4f} seconds".format(end-start))

In [None]:
# Shared state for the animation cell below.
N_iter = sim.N_iter
psi = 30  # passed as view_init's second argument (azimuth); +0.1 per frame

cmap = plt.cm.viridis
X, Y = np.meshgrid(sim.x_axis, sim.y_axis)

def update_plot(frame_number, zarray, plot):
    """FuncAnimation callback: redraw the surface for ``frame_number`` and rotate the view.

    ``plot`` is a one-element list holding the current surface so it can be
    swapped in place. ``ax`` is the module-level 3-D axes.
    """
    global psi
    psi += 0.1
    old_surface = plot[0]
    old_surface.remove()
    plot[0] = ax.plot_surface(X, Y, zarray[frame_number, :, :],
                              cmap=cmap, rstride=1, cstride=1)
    ax.view_init(20, psi)

    
# Render |Ψ|² for every stored step as a rotating 3-D surface and save to MP4.
fig = plt.figure(figsize=(16,16), dpi=250)
ax = fig.add_subplot(111, projection='3d')

plot = [ax.plot_surface(X, Y, ΨP_arr[0,:,:], cmap=cmap, rstride=1, cstride=1)]
ax.view_init(20,30)
ax.set_zlim([0,0.8])
ax.set_xlim(-5,5)
ax.set_ylim(-5,5)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.axis('off')

start = time()

output_path = 'animations/2D_wavepacket_sim_long_v5.mp4'
# Create the output directory first; the movie writer does not create missing
# directories, so the save would otherwise fail after the animation is set up.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

anim = animation.FuncAnimation(fig, update_plot, frames=N_iter, fargs=(ΨP_arr, plot), interval=10)
anim.save(output_path, fps=30)
plt.close(fig)

end = time()
print("\n{:.4f} seconds".format(end-start))