In [None]:
import numpy as np
from scipy.signal import convolve2d
import matplotlib.pyplot as plt
from matplotlib.colors import SymLogNorm
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm import tqdm

# Raise matplotlib's size cap (in MB) on animations embedded via to_jshtml();
# 2**128 effectively disables the cap so the full animation below is kept.
plt.rcParams.update({'animation.embed_limit': 2**128})

In [None]:
# Playback rate (frames per second) of the rendered animation.
fps = 60

# Square domain [pos_min, pos_max]^2 sampled on an N x N grid.
pos_min = -1
pos_max = 1
N = 1000
dx = (pos_max - pos_min) / (N - 1)

# Time step as a fixed fraction of the grid spacing (Courant number 0.1 for
# the unit wave speed implied by the kernel below).
dt = 0.1 * dx
T_max = 1.5
frames = int(T_max / dt)  # number of simulation steps to run

# Keep every `frame_select`-th step, so one second of playback at `fps` shows
# 1/stretch seconds of simulated time. Clamp to >= 1: a zero stride would make
# the `% frame_select` test in the time loop raise ZeroDivisionError.
stretch = 5
frame_select = max(1, int(np.rint(1 / (stretch * dt * fps))))

# Squared Courant number (dt/dx)^2 used by the update kernel.
d = dt**2 / dx**2
# The explicit 5-point leapfrog scheme is stable in 2D only for dt/dx <= 1/sqrt(2).
assert d <= 0.5, f"unstable time step: (dt/dx)**2 = {d} > 0.5"

# Source placement (grid indices), amplitude and angular frequency (10 Hz).
source_x = N // 2
source_y = 800
source_A = 1
source_omega = 2 * np.pi * 10

# conv(u, kernel) - u_prev == 2*u - u_prev + d * (5-point Laplacian of u),
# i.e. one leapfrog step of u_tt = Laplacian(u).
kernel = np.array([[0, d, 0],
                   [d, 2 - 4 * d, d],
                   [0, d, 0]])

frames, frame_select

In [None]:
# Leapfrog integration of the 2D wave equation: an oscillating line source
# drives waves toward a double-slit barrier; every `frame_select`-th field
# is kept for the animation.
#
# Geometry in grid indices (tuned for the N = 1000 grid): the emitter spans
# columns 400-599 of row `source_y`; the barrier is row `source_y - 200`,
# open only at columns 430-449 and 550-569.
emitter_cols = slice(400, 600)
barrier_row = source_y - 200
# The last wall segment runs to the grid edge (N) instead of a fixed column,
# so a larger grid does not get an unintended third opening on the right.
barrier_walls = (slice(0, 430), slice(450, 550), slice(570, N))

prev_array = np.zeros((N, N))
curr_array = np.zeros((N, N))

solutions = [curr_array]

for i in tqdm(range(frames)):
    # Wave-equation step. convolve2d's default zero fill outside the grid
    # acts as a fixed (u = 0) boundary.
    next_array = convolve2d(curr_array, kernel, mode='same') - prev_array

    prev_array = curr_array
    curr_array = next_array

    # curr_array now holds the field at time t = (i + 1) * dt.
    t = (i + 1) * dt

    # Impose the oscillating line source, then clamp the barrier to zero.
    curr_array[source_y, emitter_cols] = source_A * np.sin(source_omega * t)
    for wall in barrier_walls:
        curr_array[barrier_row, wall] = 0

    # next_array is a fresh allocation each step and is only read once it
    # becomes prev_array, so stored snapshots never change: no copy needed.
    if (i + 1) % frame_select == 0:
        solutions.append(curr_array)

len(solutions)

In [None]:
# Colour-scale limits: one bound, symmetric about zero, that covers every frame.
scale = 0.1  # fraction of the peak |u| used as the symlog linear threshold

# Reduce frame by frame: np.min(solutions) would first stack the whole list
# into a single (len(solutions), N, N) array, duplicating every snapshot in
# memory. The numpy reductions keep np.min/np.max's NaN propagation.
min_val = np.min([snapshot.min() for snapshot in solutions])
max_val = np.max([snapshot.max() for snapshot in solutions])
edge_val = np.max(np.abs([min_val, max_val]))
edge_val_rescaled = edge_val * scale
print(min_val, max_val, edge_val)

In [None]:
# Figure in which the wave field fills the whole canvas (no ticks, no margins).
fig, ax = plt.subplots(figsize=(6, 6))
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

# Use a colour range symmetric about zero, so u = 0 maps to the middle of the
# colormap. The former np.sign(min_val) * edge_val bounds collapsed to
# vmin == 0 or vmin == vmax whenever the field never went negative.
# SymLogNorm is linear inside +/- edge_val_rescaled and logarithmic outside,
# compressing large amplitudes so weaker waves stay distinguishable.
norm = SymLogNorm(linthresh=edge_val_rescaled, vmin=-edge_val, vmax=edge_val)
im = ax.imshow(solutions[0], cmap='twilight', norm=norm)
# Close the figure so the cell does not also show a static image; the
# animation below still draws into it.
plt.close(fig)

def frame(i):
    """Animation callback: display the i-th stored snapshot.

    Replaces the data of the existing image artist `im` with `solutions[i]`
    and returns the updated artists as a tuple, as blitting expects.
    """
    snapshot = solutions[i]
    im.set_array(snapshot)
    return (im,)

# Build the animation (one stored snapshot per frame, played at `fps`) and
# embed it in the notebook as an interactive JavaScript player.
anim = FuncAnimation(
    fig,
    frame,
    frames=len(solutions),
    interval=1000 / fps,  # milliseconds between frames
    blit=True,
)
HTML(anim.to_jshtml())