# Solving the Schrödinger Equation Using the Finite-Difference Time-Domain Method

Reference: doi:10.1088/1751-8113/40/8/013 (https://doi.org/10.1088/1751-8113/40/8/013)

In [1]:
import dataclasses
import math

import numpy as np

from IPython import display
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from scipy.io import savemat
from tqdm import tqdm

In [2]:
@dataclasses.dataclass(frozen=True, eq=True)
class Boundary:
    """Immutable description of a 1-D domain: its two endpoints and a step size."""

    start: float  # first value of the domain
    end: float    # last value of the domain
    step: float   # requested spacing between neighbouring samples

In [3]:
class DirichletSystem:
    """Imaginary-time finite-difference solver for the Schroedinger equation.

    A real trial wavefunction on a uniform Cartesian grid is advanced with

        psi <- alpha * psi + dt * beta * (1/2) * Laplacian(psi),
        beta  = 1 / (1 + dt * V / 2),
        alpha = (1 - dt * V / 2) * beta,

    and renormalized after every step.  The Rayleigh quotient of
    ``-(1/2) Laplacian + V`` is recorded in ``energy_series`` at every step.
    States passed as ``prior_vecs`` are projected out (Gram-Schmidt) before
    every step so the iteration can be steered away from them.  With the
    default boundary handling the outermost grid layer is zeroed when the
    initial state is set, and the update only adds the Laplacian at interior
    points, so that layer stays zero.
    """

    def __init__(self, pot, time_dom, spatial_doms, bd_cond=None, max_iter=None, stop_tol=1.0e-5,
                 prior_vecs=None, checkpoint_path='psi', checkpoint_every=100, **kwargs) -> None:
        """
        Args:
            pot: callable ``pot(*coords) -> float`` evaluated once per grid point.
            time_dom: ``Boundary`` of imaginary time; its ``step`` is the time step ``dt``.
            spatial_doms: one ``Boundary`` per spatial axis.
            bd_cond: when ``None`` the outermost grid layer is zeroed; any other
                value leaves the boundary untouched.
            max_iter: optional cap on the number of steps (``None``: run the full time domain).
            stop_tol: stop once the relative energy change between consecutive
                steps drops below this value; ``None`` disables early stopping.
            prior_vecs: grids of previously found states to project out.
            checkpoint_path: file passed to ``np.save`` every ``checkpoint_every``
                steps; ``None`` disables checkpointing.
            checkpoint_every: checkpoint period in steps.
            **kwargs: accepted for compatibility and ignored.
        """
        self.pbc = False  # periodic boundaries; not exposed as a constructor parameter
        self.priors = [] if prior_vecs is None else list(prior_vecs)
        self.pot = pot
        # 0: time, 1: X, 2: Y, 3: Z
        self.time_dom = time_dom
        self.spat_dom = list(spatial_doms)
        self.time_len = self._num_points(time_dom)
        self._init_mesh()
        # Cell volume from the spacing the grid actually has (see _init_mesh).
        self.vol = float(np.prod(self.steps))
        self.bd_cond = bd_cond
        self.counter = 0
        self.stop_tol = stop_tol
        self.max_iter = max_iter
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.energy_series = np.zeros(0)  # filled by solve()
        self.logs = []  # norm of each new iterate before renormalization

    @staticmethod
    def _num_points(dom):
        """Number of samples covering ``dom`` with both endpoints included."""
        ratio = (dom.end - dom.start) / dom.step
        nearest = round(ratio)
        # Float division can land just below an integer (0.3 / 0.1 ==
        # 2.9999999999999996); plain int() truncation would then drop a point.
        if math.isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-9):
            ratio = nearest
        return int(math.floor(ratio)) + 1

    def _init_mesh(self):
        """Build the coordinate axes, the flattened coordinate mesh and a zero solution grid."""
        counts = [self._num_points(dom) for dom in self.spat_dom]
        self.spaces = [np.linspace(dom.start, dom.end, n) for dom, n in zip(self.spat_dom, counts)]
        # linspace pins both endpoints, so when (end - start) is not an integer
        # multiple of dom.step the true spacing differs from dom.step (e.g.
        # 20 / 0.07 -> 286 points spaced 20 / 285).  The Laplacian and the cell
        # volume must use the true spacing to be consistent with the grid.
        self.steps = [(dom.end - dom.start) / (n - 1) if n > 1 else dom.step
                      for dom, n in zip(self.spat_dom, counts)]
        space_meshs = np.meshgrid(*self.spaces, indexing='ij')
        print(space_meshs[0].shape)
        self.sol_mesh = np.zeros(space_meshs[0].shape)
        self.sol_shape = tuple(self.sol_mesh.shape)
        self.space_meshs = np.asarray([x.flatten() for x in space_meshs])
        self.ndim = len(self.sol_shape)

    def normalize(self, psi):
        """Return ``psi`` scaled so that ``sum(psi * psi) * vol == 1`` (real-valued psi)."""
        nn = np.sqrt(np.sum(psi * psi) * self.vol)
        return psi * (1.0 / nn)

    def _initialize(self, psi0_grid):
        """Install a copy of ``psi0_grid`` as the current state and tabulate the potential.

        Raises:
            ValueError: if ``psi0_grid`` does not match the solver's grid shape.
        """
        # Copy: solve() and _prepare() update sol_mesh in place, which must not
        # leak back into the caller's array (it is reused for later runs).
        psi0 = np.array(psi0_grid, dtype=float)
        if psi0.shape != self.sol_shape:
            raise ValueError(f"psi0 grid has shape {psi0.shape}, expected {self.sol_shape}")
        self.sol_mesh = psi0
        self._apply_boundary()

        pot_grid = np.zeros(np.prod(self.sol_shape))
        for i, xi in enumerate(zip(*self.space_meshs)):
            pot_grid[i] = self.pot(*xi)
        self.pot_grid = pot_grid.reshape(self.sol_shape)

    def set_psi0_by_grid(self, psi0_grid):
        """Set the initial wavefunction from a grid of shape ``sol_shape``."""
        self._initialize(psi0_grid)

    def _along(self, axis, index):
        """Index tuple applying ``index`` on ``axis`` and ``:`` on every other axis."""
        return tuple(index if j == axis else slice(None) for j in range(self.ndim))

    def _apply_boundary(self):
        """Zero the first and last layer along every axis (default Dirichlet boundary)."""
        if self.bd_cond is None and not self.pbc:
            for axis, length in enumerate(self.sol_shape):
                self.sol_mesh[self._along(axis, 0)] = 0.0
                self.sol_mesh[self._along(axis, length - 1)] = 0.0

    def _prepare(self, *args, **kwargs):
        """Project every prior state out of the current solution (in place)."""
        for vec in self.priors:
            unit = self.normalize(vec)
            overlap = np.sum(self.sol_mesh * unit) * self.vol
            # Subtract along the *normalized* prior: the overlap was measured
            # against ``unit``, so subtracting ``overlap * vec`` would only
            # orthogonalize when ``vec`` already had unit norm.
            self.sol_mesh -= overlap * unit

    def _half_laplacian(self, px, out):
        """Write ``(1/2) * Laplacian(px)`` into ``out`` using second-order central differences.

        Only interior points are written; boundary entries of ``out`` stay zero.
        """
        out.fill(0)
        for i in range(self.ndim):
            if self.pbc:
                # NOTE(review): experimental periodic wrap.  ``pbc`` is never
                # enabled by the constructor, and the end-layer terms below look
                # like they broadcast over the whole interior -- verify before use.
                px[self._along(i, -1)] = (px[self._along(i, 0)] + (1/3) * px[self._along(i, -2)]
                                          - (1/3) * px[self._along(i, 1)])

            term = (px[self._along(i, slice(0, -2))] + px[self._along(i, slice(2, None))]
                    - 2.0 * px[self._along(i, slice(1, -1))])

            if self.pbc:
                term += px[self._along(i, -2)]
                term += px[self._along(i, 1)]
                term -= px[self._along(i, 0)]
                term -= px[self._along(i, -1)]

            term /= (2.0 * self.steps[i] ** 2)
            out[self._along(i, slice(1, -1))] += term
        return out

    def solve(self):
        """Run the imaginary-time iteration, updating ``sol_mesh`` and ``energy_series``.

        Runs ``min(max_iter, time_len)`` steps, or fewer once the relative energy
        change drops below ``stop_tol``.  On exit, including an interrupt,
        ``energy_series`` is trimmed to the steps actually computed, so
        ``get_energy()[-1]`` is always the latest energy.
        """
        dt = self.time_dom.step
        self.energy_series = np.zeros(self.time_len)

        beta = 1.0 / (1.0 + 0.5 * dt * self.pot_grid)
        alpha = (1.0 - 0.5 * dt * self.pot_grid) * beta
        self._prepare()

        if self.max_iter is not None and self.max_iter < self.time_len:
            t_len = self.max_iter
        else:
            t_len = self.time_len

        kine = np.zeros(self.sol_shape)
        computed = 0
        try:
            for n in tqdm(range(0, t_len)):
                self.counter = n
                self._prepare()
                px = self.sol_mesh
                self._half_laplacian(px, kine)

                # Rayleigh quotient <psi|V - (1/2) Laplacian|psi> / <psi|psi>.
                energy = np.sum(self.pot_grid * px * px) - (kine * px).sum()
                energy /= (px * px).sum()
                self.energy_series[n] = energy
                computed = n + 1

                if n + 1 < self.time_len:
                    psi = alpha * px + dt * beta * kine
                    self.logs.append(np.sqrt(np.sum(psi * psi) * self.vol))
                    self.sol_mesh = self.normalize(psi)

                if n > 0 and self.stop_tol is not None and energy != 0:
                    error = np.abs(1 - self.energy_series[n - 1] / energy)
                    if error < self.stop_tol:
                        break

                if (self.checkpoint_path is not None and self.checkpoint_every
                        and n % self.checkpoint_every == 0):
                    np.save(self.checkpoint_path, self.sol_mesh)
        finally:
            self.energy_series = self.energy_series[:computed]

    def get_solution(self):
        """Return ``(axes, solution_grid)``."""
        return self.spaces, self.sol_mesh

    def get_full_mesh(self):
        """Return ``(axes, solution_grid)`` (same as ``get_solution``)."""
        return self.spaces, self.sol_mesh

    def get_energy(self):
        """Return the energies recorded by the last ``solve()`` call."""
        return self.energy_series

In [4]:
# Spatial grid: every axis spans [x0, x1] with spacing dx.
x0, x1, dx = -10.0, 10.0, 0.07
ndim = 3
# Imaginary-time domain; the step is a quarter of dx**2.
t0, t1, dt = 0.0, 1.5, dx*dx*0.25

dom_t = Boundary(start=t0, end=t1, step=dt)
dom_x = Boundary(start=x0, end=x1, step=dx)
space_dom = [dom_x for _ in range(ndim)]

def pot(x, y, z, eps=1.0e-05, cut=1.0/dx):
    """Coulomb potential ``-1/r`` shifted up by ``cut``, set to 0 for ``r < 1/cut``.

    The two branches meet at ``r == 1/cut`` (both give 0), so the potential is
    continuous there.  ``eps`` is accepted but not used.
    """
    radius = np.sqrt(x*x + y*y + z*z)
    if radius < 1.0/cut:
        return 0.0
    return cut - 1.0/radius

In [5]:
# sys1 = DirichletSystem(pot=pot, time_dom=dom_t, spatial_doms=space_dom, stop_tol=1.0e-5, max_iter=None)
sys1 = DirichletSystem(pot=pot, time_dom=dom_t, spatial_doms=space_dom, stop_tol=None, max_iter=None)

try:
    # Resume from the checkpoint that DirichletSystem.solve() saves as psi.npy.
    psi0 = np.load('psi.npy')
except FileNotFoundError:
    # No checkpoint yet: start from exp(-r) plus noise.  The full coordinate
    # meshes (one float64 per grid point each) are only built on this path.
    X, Y, Z = np.meshgrid(*sys1.spaces, indexing='ij')
    R = np.sqrt(X*X+Y*Y+Z*Z)
    psi0 = np.exp(-R) + 0.1*np.random.randn(*sys1.sol_shape)
    del X, Y, Z, R

sys1.set_psi0_by_grid(psi0)
sys1.solve()

_, u = sys1.get_solution()
v1 = u.copy()

(286, 286, 286)


  0%|                                                                                            | 1/1225 [00:05<2:02:09,  5.99s/it]


KeyboardInterrupt: 

In [None]:
(X, Y, Z), V = sys1.get_solution()
# Radial density r^2 |psi|^2 along a z line.  The line must pass through the
# grid centre (x ~ 0, y ~ 0); otherwise Z*Z is not r^2 at the sampled points.
ix = int(np.argmin(np.abs(X)))
iy = int(np.argmin(np.abs(Y)))
fig1, ax1 = plt.subplots()
ax1.plot(Z, Z*Z*V[ix, iy, :] ** 2, 'rx', label='1s orbital')
ax1.legend()
ax1.grid()
ax1.set_title("1s orbital")
# plt.savefig("1s columb potential.png", dpi = 300)
plt.show()


In [None]:
# Latest energy with the cut = 1/dx offset that `pot` adds (outside its core)
# removed.  Trailing zeros are trimmed in case solve() stopped before filling
# its preallocated energy array.
energies = np.trim_zeros(sys1.get_energy(), 'b')
enr = str(energies[-1] - 1.0/dx) + "\n"
print(enr)
with open("energy.txt", 'a') as f:
    f.write(enr)

In [None]:
fig2, ax2 = plt.subplots()
# V is sys1's solution (the ground state).  Plot r^2 |psi|^2 along a z line
# through the grid centre so that Z*Z equals r^2 at the sampled points.
ix = int(np.argmin(np.abs(X)))
iy = int(np.argmin(np.abs(Y)))
ax2.plot(Z, Z*Z*V[ix, iy, :] ** 2, 'rx', label='1s orbital')
ax2.legend()
ax2.grid()
ax2.set_title("1s orbital")
plt.savefig("1s columb potential.png", dpi = 300)
plt.show()

np.save('psi.npy', V)
X, Y, Z = np.meshgrid(X, Y, Z, indexing='ij')
density = (X*X+Y*Y+Z*Z)*V**2
savemat('dens1.mat', {'x': X.ravel(), 'y': Y.ravel(), 'z': Z.ravel(), 'density': density.ravel()})


In [None]:
# Store the ground state as the restart checkpoint (np.save appends ".npy").
np.save(file='psi', arr=v1)

In [None]:
# Display the recorded ground-state energy series.
energies = sys1.get_energy()
energies

In [None]:
# Each run gets its own copy of psi0 so no run can alter the starting guess of
# a later one (the solver updates its state array in place).
sys2 = DirichletSystem(pot=pot, time_dom=dom_t, spatial_doms=space_dom, prior_vecs=[v1], stop_tol=None, max_iter=None)
sys2.set_psi0_by_grid(psi0.copy())
sys2.solve()

_, u = sys2.get_solution()
v2 = u.copy()

sys3 = DirichletSystem(pot=pot, time_dom=dom_t, spatial_doms=space_dom, prior_vecs=[v1, v2], stop_tol=None, max_iter=None)
sys3.set_psi0_by_grid(psi0.copy())
sys3.solve()

# NOTE(review): sys2 is solved a second time (now with stop_tol=1e-4) after
# sys3 was already orthogonalised against the first v2 -- confirm intended.
sys2 = DirichletSystem(pot=pot, time_dom=dom_t, spatial_doms=space_dom, prior_vecs=[v1], stop_tol=1.0e-4, max_iter=None)
sys2.set_psi0_by_grid(psi0.copy())
sys2.solve()

_, u = sys2.get_solution()
v2 = u.copy()

(X, Y, Z), V = sys2.get_solution()

# r^2 |psi|^2 along a z line through the grid centre (Z*Z == r^2 there).
ix = int(np.argmin(np.abs(X)))
iy = int(np.argmin(np.abs(Y)))
fig3, ax3 = plt.subplots()
ax3.plot(Z, Z*Z*V[ix, iy, :] ** 2, 'rx', label=r'$\psi_1$')
ax3.legend()
ax3.grid()
plt.show()

np.save('psi2.npy', V)
X, Y, Z = np.meshgrid(X, Y, Z, indexing='ij')
density = (X*X+Y*Y+Z*Z)*V**2
savemat('dens2.mat', {'x': X.ravel(), 'y': Y.ravel(), 'z': Z.ravel(), 'density': density.ravel()})
# NOTE(review): this prints sys1's energy in the sys2 section; sys2.get_energy()[-1] may be intended.
print(sys1.get_energy()[-1])

x = X[:, 0, 0]

# Laguerre polynomial L_5(x) = (-x^5 + 25x^4 - 200x^3 + 600x^2 - 600x + 120) / 120.
plt.plot(x, (-x**5 + 25*x**4 - 200*x**3 + 600*x**2 - 600*x + 120)/120)
# plt.xlim(0,8)
# plt.ylim(-1,5)