In [7]:
import numpy as np
from scipy import sparse as sp
from scipy.sparse import linalg as spla
import matplotlib.pyplot as plt

#Define function that takes in some wavefunction, potential, minx, maxx, dx, mint, maxt, dt
def sch_solve(psi_init,  
              minx = 0, maxx = 10, nx = 10, x0 = 5, sigmax=0.1, vx=0,
              miny = 0, maxy = 10, ny = 10, y0 = 5, sigmay=0.1, vy=0,
              mint = 0, maxt = 10, N = 200, m1 = 1, m2 = 1, lam = 1):

    #Define X and Y matrices
    x = np.linspace(minx, maxx, nx+2)[1:nx+1]
    y = np.linspace(miny, maxy, ny+2)[1:ny+1]
    hx = x[1] - x[0]
    hy = y[1] - y[0]

    #Set up initial data
    X,Y = np.meshgrid(x,y)
    Z = psi_init(X,Y,x0,y0,vx,vy,sigmax,sigmay)

    #Set up zeros for psi
    psi = np.zeros((N+1,nx * ny), dtype = 'complex')
    psi[0,:] = np.reshape(Z,nx*ny,order='C')

    #Normalize
    norm = hy*hx*np.vdot(psi[0,:], psi[0,:])
    psi[0,:] = norm**(-0.5)*psi[0,:]

    #Define x and y propely
    Px2 = (1/(2*m1))*sp.diags([1/hx**2, -2/hx**2, 1/hx**2], [-1, 0, 1], shape=(nx, nx))
    Py2 = (1/(2*m2))*sp.diags([1/hy**2, -2/hy**2, 1/hy**2], [-1, 0, 1], shape=(ny, ny))
    Y2 = sp.diags(np.vectorize(lambda x : x**2)(y))
    PotXmY = sp.diags(np.vectorize(lambda x,y,lam : lam*np.exp(x-y))(x,y,lam))

    #Build the official A matrix 
    fnd = 1j*(sp.kron(sp.identity(ny),Px2) - sp.kron(Py2, sp.identity(nx)))
    fnd = fnd + 1j*((1/(2*m2))*Y2)
    fnd = fnd + 1j * PotXmY
    A = sps.identity(nx*ny) - (fnd)*(maxt/(2*N))
    B = sps.identity(nx*ny) + (fnd)*(maxt/(2*N))

    #Run the CN scheme to solve
    for t in range(1,N):
        b = B.dot(psi[t-1,:])
        psi[i,:] = spla.spsolve(A,b)

    #Return results
    return(psi)


#Set up primary wavefunction defining function
def wavefunction(x, y, x0, y0, vx, vy, sigmax, sigmay):
    return(np.exp(-vx*1j*x)*np.exp(-vy*1j*y)*np.exp(-(x-x0)**2/4/sigmax**2)*np.exp(-(y-y0)**2/4/sigmay**2))

#def plotCNContour(data, mint, maxt, dt, minx, maxx, dx, phi):
#    M = np.trunc((maxx-minx)/dx).astype(int)
#    N = np.trunc((maxt-mint)/dt).astype(int)#
#
#    X = np.linspace(minx, stop = maxx-dx, num = M)*2*np.pi/maxx
#    T = np.linspace(mint, stop = maxt-dx, num = N)
#    absdat = np.square(np.abs(data))
#    absphi = -np.real(phi)
#   
#    maxPsi = np.max(absdat)
#    if(maxPsi > 2):
#        maxPsi = 2
#    maxPhi = np.max(absphi)
#    
#    ratio = maxPsi/maxPhi
#
#    X, T = np.meshgrid(X,T)
#    fig, ax = plt.subplots()
#    cs = ax.contour(X,T,absdat, levels = 15)
#    ax.set_title(r"$|\Psi(x,t)|^2$")
#    ax.set_xlabel(r"Position")
#    ax.set_ylabel(r"Time (t)")

In [8]:
#ORDER OF VARIABLES
#psi_init,  
#minx = 0, maxx = 10, nx = 10, x0 = 5, sigmax=0.1, vx=0,
#miny = 0, maxy = 10, xy = 10, y0 = 5, sigmay=0.1, vy=0,
#mint = 0, maxt = 10, N = 200, m1 = 1, m2 = 1, lam = 1
psi = np.abs(sch_solve(wavefunction))**2

ValueError: inconsistent shapes