In [None]:
import numpy as np
import matplotlib.pyplot as plt
import dedalus.public as d3
import logging
import os
from matplotlib.animation import FuncAnimation
logger = logging.getLogger(__name__)


# Parameters
Lx = 2 * np.pi  # domain length in x
Nx = 128  # number of Fourier modes / grid points in x
Ly = 2 * np.pi  # domain length in y
Ny = 128  # number of Fourier modes / grid points in y
beta = 10  # beta-plane parameter: coefficient of dx(psi) in the vorticity equation
sigma = 1  # width of the initial Gaussian streamfunction
# 3/2-rule padding so the quadratic Jacobian nonlinearity is computed without aliasing.
dealias = 3 / 2
stop_sim_time = 1
timestepper = d3.SBDF2
timestep = 0.001
dtype = np.float64

# Bases: doubly periodic domain centred on the origin, spanning [-Lx/2, Lx/2] x [-Ly/2, Ly/2].
xcoord = d3.Coordinate('x')
ycoord = d3.Coordinate('y')

dist = d3.Distributor([xcoord, ycoord], dtype=dtype)
xbasis = d3.RealFourier(xcoord, size=Nx, bounds=(-Lx / 2, Lx / 2), dealias=dealias)
ybasis = d3.RealFourier(ycoord, size=Ny, bounds=(-Ly / 2, Ly / 2), dealias=dealias)

# Fields
psi = dist.Field(name='psi', bases=(xbasis, ybasis))  # streamfunction
zeta = dist.Field(name='zeta', bases=(xbasis, ybasis))  # relative vorticity
tau = dist.Field(name='tau')  # spatially constant field absorbing the mean mode of the Poisson equation
# Velocity components; not unknowns of the problem below, kept so later cells can reference them.
u = dist.Field(name='u', bases=(xbasis, ybasis))
v = dist.Field(name='v', bases=(xbasis, ybasis))

# Substitutions
dx = lambda A: d3.Differentiate(A, xcoord)
dy = lambda A: d3.Differentiate(A, ycoord)
delta = lambda A: dx(dx(A)) + dy(dy(A))  # Laplacian

# Problem: beta-plane vorticity equation
#   dt(zeta) + beta*dx(psi) = dy(psi)*dx(zeta) - dx(psi)*dy(zeta)
# coupled to zeta = delta(psi) - tau, with integ(psi) = 0 fixing the free mean of psi.
problem = d3.IVP([psi, zeta, tau], namespace=locals())
problem.add_equation("zeta - delta(psi) + tau = 0")
problem.add_equation("dt(zeta) + beta*dx(psi) = dy(psi)*dx(zeta) - dx(psi)*dy(zeta)")
problem.add_equation("integ(psi) = 0")

# Initial conditions: streamfunction 0.1 * x * Gaussian centred at (x_0, y_0).
x_0 = 0
y_0 = 0
x = dist.local_grid(xbasis)
y = dist.local_grid(ybasis)
XX, YY = np.meshgrid(x, y, indexing='ij')
# NOTE(review): the factor XX makes psi(-Lx/2, y) = -psi(Lx/2, y) (for x_0 = 0), so psi is
# not periodic in x; with sigma = 1 the jump at y = 0 is about 0.2*pi*exp(-pi**2/2) ~ 4.5e-3.
# A spectral Laplacian of that jump produces Gibbs ringing in zeta near x = +-Lx/2.
# Consider np.sin(XX) or a narrower sigma -- confirm the intended initial state.
psi['g'] = 0.1 * XX * np.exp(-((XX - x_0)**2 + (YY - y_0)**2) / (2 * sigma**2))
# Dedalus evaluates operators onto the dealias grid scales, which differ from zeta's
# scale-1 grid when dealias != 1; coefficient arrays have the same shape at every
# scale, so copy coefficients rather than grid values.
zeta['c'] = delta(psi).evaluate()['c']


# Solver
solver = problem.build_solver(timestepper)
solver.stop_sim_time = stop_sim_time

# Main loop
log_cadence = 100  # iterations between progress log lines
snapshot_cadence = 25  # iterations between stored snapshots


def _snapshot():
    """Append the current zeta/psi grid data and simulation time to the history lists."""
    # Solver state may be left on the dealiased grid; store every snapshot on the
    # scale-1 grid so all entries share the shape of the x/y grids used for plotting.
    for field in (psi, zeta):
        field.change_scales(1)
    zeta_list.append(np.copy(zeta['g']))
    psi_list.append(np.copy(psi['g']))
    t_list.append(solver.sim_time)


zeta_list, psi_list, t_list = [], [], []
_snapshot()  # initial state at t = solver.sim_time
while solver.proceed:
    solver.step(timestep)
    if solver.iteration % log_cadence == 0:
        logger.info('Iteration=%i, Time=%e, dt=%e', solver.iteration, solver.sim_time, timestep)
    if solver.iteration % snapshot_cadence == 0:
        _snapshot()
        # Fail loudly on numerical blow-up instead of silently storing NaN frames.
        if not np.isfinite(zeta_list[-1]).all():
            raise FloatingPointError(
                f'Non-finite vorticity at iteration {solver.iteration}, t={solver.sim_time:e}')


In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# Use one colour scale for every frame, symmetric about zero so the diverging colormap's
# midpoint marks zeta = 0. Otherwise imshow would rescale each frame and amplitudes
# could not be compared over time.
zeta_absmax = float(np.nanmax(np.abs(np.asarray(zeta_list))))
if not np.isfinite(zeta_absmax) or zeta_absmax == 0:
    zeta_absmax = 1.0  # degenerate (all-zero / all-NaN) history: fall back to a unit range

ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$y$')
# zeta_list entries are indexed [x, y]; imshow wants [row = y, column = x], hence .T.
im = ax.imshow(zeta_list[0].T, cmap='coolwarm', aspect='auto',
               extent=[x.min(), x.max(), y.min(), y.max()],
               origin='lower', interpolation='bicubic',
               vmin=-zeta_absmax, vmax=zeta_absmax)
fig.colorbar(im, ax=ax, label=r'$\zeta$')
title = ax.set_title(f'Time: {t_list[0]:.2f}')


def animate(i):
    """
    Show stored snapshot ``i`` of the vorticity history.

    Args:
        i (int): Index into ``zeta_list`` / ``t_list`` (one frame per stored snapshot).

    Returns:
        tuple: The artists updated for this frame.
    """
    im.set_data(zeta_list[i].T)
    title.set_text(f'Time: {t_list[i]:.2f}')
    return im, title


# One frame per stored snapshot. blit=False so the changing title is redrawn too.
ani = FuncAnimation(fig, animate, frames=len(zeta_list), blit=False)

# The 'pillow' writer relies only on Pillow (a matplotlib dependency); 'imagemagick'
# requires an external binary on PATH.
ani.save('wave_animation.gif', writer='pillow', fps=10)

plt.show()