# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
from matplotlib.animation import FuncAnimation

# Uniform 1-D spatial grid shared by every function below: 5000 points on [-10, 10].
x = np.linspace(start=-10, stop=10, num=5000)
# Grid spacing (constant because linspace is uniform); used as dx in sums and stencils.
deltax = np.diff(x[:2])[0]

def norm(phi, dx=None):
    """Return ``phi`` rescaled to unit L2 norm on a uniform grid.

    The squared norm is approximated by the rectangle rule
    ``sum(|phi|**2) * dx``.

    Parameters
    ----------
    phi : array_like
        Sampled wave function (real or complex). It is not modified.
    dx : float, optional
        Grid spacing. Defaults to the module-level ``deltax``.

    Returns
    -------
    numpy.ndarray
        ``phi / sqrt(sum(|phi|**2) * dx)``.

    Raises
    ------
    ValueError
        If the squared norm is zero, negative (bad ``dx``) or not finite.
        Such a state cannot be normalised, and dividing by it would silently
        fill the result with NaN/inf.
    """
    if dx is None:
        dx = deltax
    phi = np.asarray(phi)
    # Renamed from ``norm`` so the local no longer shadows this function.
    squared_norm = np.sum(np.square(np.abs(phi))) * dx
    if not np.isfinite(squared_norm) or squared_norm <= 0:
        raise ValueError(f"cannot normalise state with squared norm {squared_norm}")
    return phi / np.sqrt(squared_norm)

def complex_plot(x, y, prob=True, **kwargs):
    """Plot the real and imaginary parts of ``y`` against ``x`` on pyplot's current axes.

    The x-axis limits are set to [-2, 2].

    Parameters
    ----------
    x : array_like
        Abscissa values.
    y : array_like
        Complex samples to plot.
    prob : bool, optional
        If true (default), also plot ``|y|``, labelled as the square root of
        the probability density.
    **kwargs
        Forwarded to ``plt.plot`` for the real and imaginary lines only
        (not for the ``|y|`` line).

    Returns
    -------
    tuple
        ``(re_line, im_line, abs_line)`` when ``prob`` is true, otherwise
        ``(re_line, im_line)``. Each element is the first artist returned by
        ``plt.plot``.
    """
    re_line, *_ = plt.plot(x, np.real(y), label='Re', **kwargs)
    im_line, *_ = plt.plot(x, np.imag(y), label='Im', **kwargs)
    plt.xlim(-2, 2)
    if not prob:
        return re_line, im_line
    # Raw string: in a plain literal '\s' is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). The string value is unchanged.
    abs_line, *_ = plt.plot(x, np.abs(y), label=r'$\sqrt{P}$')
    return re_line, im_line, abs_line

def wave_packet(pos=0, mom=0, sigma=0.2):
    """Build a normalised Gaussian wave packet on the module grid ``x``.

    The un-normalised state is ``exp(-1j*mom*x) * exp(-(x - pos)**2 / sigma**2)``.
    It is passed through ``norm`` so that ``sum(|psi|**2) * deltax == 1``.

    Parameters
    ----------
    pos : float
        Centre of the Gaussian envelope.
    mom : float
        Wavenumber of the plane-wave phase factor.
    sigma : float
        Envelope width parameter. Note that ``|psi|**2`` then has standard
        deviation ``sigma / 2``.

    NOTE(review): the phase is ``exp(-1j*mom*x)``. Under the Schrödinger
    form used in ``d_dt`` this corresponds to momentum ``-h*mom``, so a
    positive ``mom`` should drift toward negative x. Confirm this sign is
    intended.
    """
    envelope = np.exp(-np.square(x - pos) / sigma / sigma, dtype=complex)
    carrier = np.exp(-1j * mom * x)
    return norm(carrier * envelope)

def d_dxdx(phi, x=None):
    """Second spatial derivative of ``phi`` by the 3-point central difference.

    Computes ``(phi[i+1] - 2*phi[i] + phi[i-1]) / dx**2`` on a uniform grid.
    Values outside the array are taken as zero (a Dirichlet boundary), so
    the two end points see only their single interior neighbour.

    Parameters
    ----------
    phi : array_like
        Samples on a uniform grid (real or complex). It is not modified.
    x : array_like, optional
        Grid coordinates; only ``x[1] - x[0]`` is used as the spacing.
        Defaults to the module-level ``deltax``. Previously this argument
        was accepted but ignored.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``phi``.

    Notes
    -----
    The divisor must be ``dx**2``; dividing by ``dx`` (as this function used
    to) under-scales the Laplacian by a factor of ``dx``. The largest
    eigenvalue magnitude of this operator is about ``4 / dx**2``, so an
    explicit integrator (e.g. RK4) driven by it needs a time step that
    shrinks with ``dx**2`` to stay stable.
    """
    dx = deltax if x is None else x[1] - x[0]
    phi = np.asarray(phi)
    # -2*phi allocates a fresh array, so the in-place adds below never touch the input.
    second_diff = -2 * phi
    second_diff[:-1] += phi[1:]
    second_diff[1:] += phi[:-1]
    return second_diff / dx**2

def d_dt(phi, h=1, m=100, V=0):
    """Time derivative of ``phi`` under the 1-D Schrödinger equation.

    Evaluates ``dphi/dt = 1j*h/(2*m) * d2phi/dx2 - 1j*V*phi/h``. Here ``h``
    plays the role of the reduced Planck constant, ``m`` is the particle
    mass and ``V`` is the potential. ``V`` is multiplied element-wise with
    ``phi``, so it can be a scalar or an array matching ``phi``. The spatial
    derivative uses ``d_dxdx`` with its default grid spacing.
    """
    kinetic_term = 1j * h / 2 / m * d_dxdx(phi)
    potential_term = 1j * V * phi / h
    return kinetic_term - potential_term

def rk4(phi, dt, **kwargs):
    """Advance ``phi`` by one classical fourth-order Runge–Kutta step.

    ``kwargs`` (for example ``h``, ``m``, ``V``) are forwarded unchanged to
    every ``d_dt`` evaluation. Returns the new state; ``phi`` itself is not
    modified.
    """
    half_dt = dt / 2
    slope_a = d_dt(phi, **kwargs)
    slope_b = d_dt(phi + half_dt * slope_a, **kwargs)
    slope_c = d_dt(phi + half_dt * slope_b, **kwargs)
    slope_d = d_dt(phi + dt * slope_c, **kwargs)
    weighted_sum = slope_a + 2 * slope_b + 2 * slope_c + slope_d
    return phi + dt / 6 * weighted_sum

def simulate(phi_sim, method='rk4', V=0, steps=100000, dt=1e-1, condition=None, normalize=True, save_every=100):
    """Integrate a wave function forward in time and collect snapshots.

    Parameters
    ----------
    phi_sim : numpy.ndarray
        Initial state. A copy is stored as the first snapshot.
    method : str
        Time integrator; only ``'rk4'`` is supported.
    V : scalar or array
        Potential forwarded to ``rk4`` (and from there to ``d_dt``).
    steps : int
        Number of time steps to take.
    dt : float
        Time step. NOTE(review): explicit RK4 is only conditionally stable
        for this equation. Confirm ``dt`` is small enough for the spatial
        operator and grid in use.
    condition : callable or None
        Optional ``phi -> phi`` hook applied after every step.
    normalize : bool
        If true, renormalise the state with ``norm`` after every step.
    save_every : int or None
        Append a copy of the state after every ``save_every`` steps; ``None``
        keeps only the initial state.

    Returns
    -------
    list of numpy.ndarray
        The initial state followed by the saved snapshots.

    Raises
    ------
    ValueError
        If ``method`` is unknown or ``save_every`` is not positive.
    """
    # Validate once, before any work. Previously the method was re-checked on
    # every step and never checked at all when steps == 0.
    if method != 'rk4':
        raise ValueError(f'Unknown method {method}')
    if save_every is not None and save_every <= 0:
        # Otherwise `step % save_every` would raise ZeroDivisionError mid-run.
        raise ValueError(f'save_every must be a positive integer or None, got {save_every}')

    snapshots = [np.copy(phi_sim)]
    for step in range(1, steps + 1):
        phi_sim = rk4(phi_sim, dt, V=V)
        if condition is not None:
            phi_sim = condition(phi_sim)
        if normalize:
            phi_sim = norm(phi_sim)
        if save_every is not None and step % save_every == 0:
            snapshots.append(np.copy(phi_sim))
    return snapshots

# Free evolution (V = 0 everywhere) of a packet with default position, momentum
# and width. A snapshot is kept every 500 steps, giving 201 frames for the GUI.
sim_zero_mom = simulate(
    wave_packet(),
    V=np.zeros_like(x),
    steps=100000,
    save_every=500,
)
# %%
class SimulationApp:
    """Tk window that animates the precomputed ``sim_zero_mom`` snapshots.

    The window embeds a matplotlib canvas above a row of Play / Pause / Reset
    buttons that control a ``FuncAnimation`` over the snapshots.
    """

    def __init__(self, master):
        self.master = master
        master.title("Zero Momentum Wave Simulation")

        # The figure is created through pyplot because complex_plot() draws with
        # plt.plot on pyplot's *current* axes; see create_animation().
        self.fig, self.ax = plt.subplots()
        self.canvas = FigureCanvasTkAgg(self.fig, master=master)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

        self.sim_zero_mom = sim_zero_mom
        self.anim = self.create_animation()

        self.toolbar_frame = tk.Frame(master)
        self.toolbar_frame.pack(side=tk.BOTTOM)

        self.play_button = tk.Button(self.toolbar_frame, text="Play", command=self.play_animation)
        self.play_button.pack(side=tk.LEFT)

        self.pause_button = tk.Button(self.toolbar_frame, text="Pause", command=self.pause_animation)
        self.pause_button.pack(side=tk.LEFT)

        self.reset_button = tk.Button(self.toolbar_frame, text="Reset", command=self.reset_animation)
        self.reset_button.pack(side=tk.LEFT)

    def create_animation(self):
        """Draw the first snapshot on ``self.ax`` and return a FuncAnimation over all snapshots."""
        # Make this window's axes current so complex_plot() and the plt.* calls
        # below target it, even if another pyplot figure became current meanwhile.
        plt.sca(self.ax)
        re_line, im_line, prob_line = complex_plot(x, self.sim_zero_mom[0])
        plt.xlim(-2, 2)
        plt.ylim(-2, 2)
        plt.legend()

        def update(frame):
            state = self.sim_zero_mom[frame]
            prob_line.set_data(x, np.abs(state))
            re_line.set_data(x, np.real(state))
            im_line.set_data(x, np.imag(state))
            return prob_line, re_line, im_line

        return FuncAnimation(self.fig, update, frames=len(self.sim_zero_mom), interval=50)

    def play_animation(self):
        """Resume the animation timer."""
        self.anim.event_source.start()

    def pause_animation(self):
        """Stop the animation timer, freezing the current frame."""
        self.anim.event_source.stop()

    def reset_animation(self):
        """Stop the current animation, redraw from the first snapshot and restart."""
        self.anim.event_source.stop()
        # Without clearing, every reset stacked another set of Re/Im/sqrt(P)
        # lines and legend entries onto the same axes.
        self.ax.clear()
        self.anim = self.create_animation()
        # A new FuncAnimation starts its timer on the next draw event, so request
        # one; this also displays the freshly plotted first frame.
        self.canvas.draw_idle()

if __name__ == "__main__":
    # Script entry point: create the Tk root window, attach the viewer to it,
    # and block in Tk's event loop until the window is closed.
    root = tk.Tk()
    app = SimulationApp(master=root)
    root.mainloop()
