In [None]:
! pip install git+https://github.com/endolith/complex_colormap

In [None]:
from time import time

import numpy as np
from scipy.integrate import solve_ivp

import matplotlib.pyplot as plt
# from complex_colormap.cplot import cplot

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [None]:
# Spatial grid on [xmin, xmax). `d2x` uses np.roll, i.e. periodic boundaries,
# so the right endpoint is excluded and the wrap-around gap must equal dx too.
xmin = -10
xmax = 10
dx_target = 0.1
# An explicit point count avoids np.arange's float-step ambiguity (its length
# depends on how (xmax - xmin) / step rounds).
nx = int(round((xmax - xmin) / dx_target))
# Actual spacing: equals dx_target when it divides the domain evenly; otherwise
# it is adjusted so that the grid (including the periodic wrap) stays uniform.
dx = (xmax - xmin) / nx
xs = np.linspace(xmin, xmax, nx, endpoint=False)

In [None]:
# Initial state: Gaussian wave packet (hbar = m = 1) centred at X0 with width
# parameter SIGMA and mean wavenumber K0:
#   psi0(x) ∝ exp(i K0 x - (x - X0)**2 / (2 SIGMA**2))
# Other initial states previously tried here:
#   K0 = 2, X0 = 0, SIGMA = 1   (moving packet starting at the origin)
#   K0 = 0, X0 = 0, SIGMA = 1   (stationary packet at the origin)
# Zeroing the end points is not needed: d2x treats the grid as periodic.
K0 = 2       # mean wavenumber of the packet
X0 = -4      # initial centre
SIGMA = 0.5  # Gaussian width parameter

# 1j * ... already yields complex128, so no explicit cast is required.
psi0 = np.exp(1j * K0 * xs - (xs - X0)**2 / (2 * SIGMA**2))
# Normalise with the same rectangle rule the normalisation check uses:
# sum(|psi|^2) * dx == 1.
psi0 /= np.sqrt(np.sum(np.abs(psi0)**2) * dx)
assert np.isclose(np.sum(np.abs(psi0)**2) * dx, 1.0)

In [None]:
def d2x(ys, h=None, axis=-1):
    """Second derivative of ``ys`` by the 3-point central difference.

    Boundaries are periodic: ``np.roll`` wraps the first and last samples
    onto each other, so ``ys`` is treated as one period of a function
    sampled on a uniform grid. (A non-periodic, interior-only variant would
    be ``(ys[2:] - 2*ys[1:-1] + ys[:-2]) / h**2``, returning two fewer points.)

    Parameters
    ----------
    ys : array_like
        Samples on a uniform grid.
    h : float, optional
        Grid spacing. Defaults to the notebook-level ``dx``.
    axis : int, optional
        Axis along which to differentiate (default: last). Rolling along an
        explicit axis keeps N-D input from being flattened.

    Returns
    -------
    ndarray
        Same shape as ``ys``: ``(ys[i-1] - 2*ys[i] + ys[i+1]) / h**2``.
    """
    if h is None:
        h = dx
    # asarray: for a plain list, `2 * ys` would repeat the list, not scale it.
    ys = np.asarray(ys)
    return (np.roll(ys, 1, axis=axis) - 2 * ys + np.roll(ys, -1, axis=axis)) / h**2


def v(x):
    """Double-well potential V(x) = x**4 / 25 - x**2.

    Plain NumPy arithmetic already broadcasts over arrays. Avoiding
    ``np.vectorize``, which loops in Python, matters because ``fun`` calls
    ``v(xs)`` on every right-hand-side evaluation. The result is real-valued
    (float), so it can be plotted without a ComplexWarning; multiplying it
    by a complex ``psi`` still yields a complex array.

    Alternative potentials, written in broadcast-safe form:
      harmonic:     x**2 / 2
      free:         np.zeros_like(x)
      square well:  np.where(np.abs(x) < 4, -4.0, 0.0)
      barrier:      10 * np.exp(-x**2 / (2 * 0.1**2))

    Parameters
    ----------
    x : array_like
        Real positions.

    Returns
    -------
    ndarray
        Float array with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    return x**4 / 25 - x**2


def fun(t, psi):
    """Right-hand side of the time-dependent Schrödinger equation.

    With hbar = m = 1, i dpsi/dt = H psi with H = -(1/2) d^2/dx^2 + V(x),
    hence dpsi/dt = -i H psi. ``t`` is unused (V does not depend on time)
    but is part of the ``fun(t, y)`` signature that ``solve_ivp`` expects.
    """
    kinetic = -0.5 * d2x(psi)
    potential = v(xs) * psi
    h_psi = kinetic + potential
    return -1j * h_psi

In [None]:
# Potential and initial state. Both arrays may be complex-valued and
# Axes.plot would silently drop the imaginary part, so plot explicit real
# quantities. V reaches hundreds at the edges while psi is O(1), so psi
# gets its own (twin) y-axis.
fig, ax_v = plt.subplots(figsize=(10, 2))
ax_v.plot(xs, np.real(v(xs)), color='k')
ax_v.set_xlabel('x')
ax_v.set_ylabel('V(x)')
ax_psi = ax_v.twinx()
ax_psi.plot(xs, np.real(psi0), label=r'Re $\psi_0$')
ax_psi.plot(xs, np.abs(psi0), label=r'$|\psi_0|$')
ax_psi.set_ylabel(r'$\psi_0(x)$')
ax_psi.legend(loc='upper right')
ax_v.set_title('Potential (black) and initial wave function')
# plt.show() instead of fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

In [None]:
from time import perf_counter

# Integrate psi(t) on [0, T_END], keeping N_FRAMES evenly spaced snapshots.
T_END = 5
N_FRAMES = 101
# solve_ivp's default tolerances, made explicit so they are easy to tighten.
# The default RK45 integrator does not conserve the norm exactly; see the
# normalisation check.
RTOL = 1e-3
ATOL = 1e-6
t_frames = np.linspace(0, T_END, N_FRAMES)

tic = perf_counter()  # monotonic clock intended for measuring durations
sol = solve_ivp(fun, (0, T_END), psi0, t_eval=t_frames, rtol=RTOL, atol=ATOL)
elapsed = perf_counter() - tic
print(f'solve_ivp: {elapsed:.2f} s, {sol.nfev} RHS evaluations')
# solve_ivp reports failure via `success`/`message` instead of raising, and
# later cells would otherwise index into a truncated sol.y.
if not sol.success:
    raise RuntimeError(f'solve_ivp failed: {sol.message}')

In [None]:
# Check normalisation: the total probability sum(|psi|^2) * dx should stay
# at 1 for all t. Any drift reflects integration error in solve_ivp.
norms = np.sum(np.abs(sol.y)**2, axis=0) * dx
fig, ax = plt.subplots(figsize=(7, 2))
ax.plot(sol.t, norms)
ax.axhline(1, color='gray', linestyle=':', linewidth=1)
ax.set_ylim((0, 2))
ax.set_xlabel('t')
ax.set_ylabel(r'$\sum |\psi|^2 \, dx$')
ax.set_title(f'Normalisation (max |drift| = {np.max(np.abs(norms - 1)):.2e})')
plt.show()

In [None]:
# Space-time maps of Re(psi) and the probability density |psi|^2.
# meshgrid(sol.t, xs) gives arrays shaped (len(xs), len(sol.t)), matching
# sol.y, which gouraud shading requires.
# (A phase panel, np.angle(sol.y) with cmap='hsv' and vmin/vmax = -pi/pi,
# is another option.)
tg, xg = np.meshgrid(sol.t, xs)
re_psi = np.real(sol.y)
density = np.abs(sol.y)**2
# Symmetric limits so that zero sits at the white centre of the diverging map.
re_lim = np.abs(re_psi).max()

fig, axs = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
pcm0 = axs[0].pcolormesh(xg, tg, re_psi, shading='gouraud', cmap='bwr',
                         vmin=-re_lim, vmax=re_lim)
pcm1 = axs[1].pcolormesh(xg, tg, density, shading='gouraud', cmap='hot')
fig.colorbar(pcm0, ax=axs[0])
fig.colorbar(pcm1, ax=axs[1])
axs[0].set_title(r'Re $\psi(x, t)$')
axs[1].set_title(r'$|\psi(x, t)|^2$')
axs[0].set_xlabel('x')
axs[1].set_xlabel('x')
axs[0].set_ylabel('t')
plt.show()

In [None]:
@interact(frame=widgets.IntSlider(min=0, max=len(sol.t) - 1, value=0,
                                  continuous_update=False))
def answer(frame):
    """Plot Re(psi) at output frame ``frame`` against the initial state.

    The slider range is derived from ``len(sol.t)``, so it stays in sync with
    however many snapshots were requested from solve_ivp.
    """
    # Fixed limits from the whole run: |Re psi| <= |psi|, so nothing is
    # clipped and the axes do not jump between frames.
    lim = 1.05 * np.abs(sol.y).max() or 1.0
    fig, ax = plt.subplots()
    ax.plot(xs, np.real(sol.y[:, 0]), color='gray', label='t = 0')
    ax.plot(xs, np.real(sol.y[:, frame]), label=f't = {sol.t[frame]:.2f}')
    ax.set_ylim((-lim, lim))
    ax.set_xlabel('x')
    ax.set_ylabel(r'Re $\psi$')
    ax.set_title(f'frame {frame}: t = {sol.t[frame]:.2f}')
    ax.legend(loc='upper right')
    # Explicit show so the figure renders inside the interact output area.
    plt.show()