In [43]:
import dataclasses

import numpy as np
import progressbar
import scipy.integrate
import scipy.interpolate

from joshpyutil import mpl


# --- Time ------------------------------------------------------------------
t_max = 15                  # end of the integration window [0, t_max]
t_rain_stop = 4             # rain is applied only while t < t_rain_stop

# --- Spatial grid on [0, 1] ------------------------------------------------
n = 200                     # number of grid points
x, dx = np.linspace(0, 1, n, retstep=True)

# --- Numerics --------------------------------------------------------------
lax_friedrichs_dt = 0.005   # time scale setting the Lax-Friedrichs diffusion strength

# --- Terrain (Gaussian hill centred at x = 1/2) ----------------------------
terrain_r = 0.2             # width of the hill
terrain_h = 0.3             # height of the hill

# --- Physics ---------------------------------------------------------------
gamma = 5                   # linear damping rate applied to the momentum
R = 0.1                     # rain rate while it is raining
g = 10                      # gravitational acceleration
water_wave_speed = np.sqrt(g)  # sqrt(g); presumably the gravity-wave speed at unit depth


def grad(field):
    """Return d(field)/dx on the module grid ``x``.

    An Akima spline is fitted through the points ``(x, field)``; its first
    derivative is then evaluated back at the same grid points, so the result
    has the same length as ``field``.
    """
    first_derivative = scipy.interpolate.Akima1DInterpolator(x, field).derivative()
    return first_derivative(x)


# Bed elevation: a Gaussian hill (height terrain_h, width terrain_r) centred at
# x = 1/2, plus a sinusoidal ripple whose amplitude is 10% of the hill height.
_centred = x - 1/2
terrain = (
    terrain_h * np.exp(-_centred**2 / terrain_r**2)
    + 0.1 * terrain_h * np.sin(14 * np.pi * _centred + 1)
)
del _centred

# Terrain slope, computed once up front; time_deriv_func reads it on every call.
grad_terrain = grad(terrain)

# Zero-initialised arrays; not referenced elsewhere in this cell (the solver
# starts from Fields.zeros()) — presumably kept for use in other cells.
water, moment = np.zeros(n), np.zeros(n)


@dataclasses.dataclass
class Fields:
    """Per-grid-point state of the system: water height and momentum.

    ``pack``/``unpack`` convert to and from the flat vector the ODE solver
    works with. In that vector the two fields are interleaved as
    ``[h0, p0, h1, p1, ...]``.
    """

    height: np.ndarray[float]    # water depth above the terrain at each point
    momentum: np.ndarray[float]  # momentum at each point

    def pack(self) -> np.ndarray[float]:
        """Return the fields interleaved into one flat vector ``[h0, p0, h1, p1, ...]``."""
        stacked = np.stack([self.height, self.momentum], axis=1)
        return np.ravel(stacked)

    @classmethod
    def unpack(cls, packed: np.ndarray[float]) -> 'Fields':
        """Inverse of ``pack``: split an interleaved vector back into fields.

        The number of grid points is inferred from the input length rather
        than taken from the module-level ``n``. The returned arrays are views
        into ``packed`` (no copy), so writing to them writes into ``packed``.

        Raises:
            ValueError: if ``packed`` does not have an even number of elements.
        """
        packed = np.asarray(packed)
        if packed.size % 2:
            raise ValueError(
                f'packed state must have an even number of elements, got {packed.size}'
            )
        stacked = packed.reshape((-1, 2))
        return cls(height=stacked[:, 0], momentum=stacked[:, 1])

    @classmethod
    def zeros(cls, size: 'int | None' = None) -> 'Fields':
        """Return all-zero fields with ``size`` points (default: the module grid size ``n``)."""
        if size is None:
            size = n
        return cls(height=np.zeros(size), momentum=np.zeros(size))


def lax_friedrichs_diff(field: np.ndarray[float]) -> np.ndarray[float]:
    """Lax-Friedrichs numerical-diffusion term at the interior grid points.

    Computes ``(f[i+1] + f[i-1] - 2*f[i]) / (2 * lax_friedrichs_dt)`` for each
    interior index. The result has ``len(field) - 2`` entries and lines up
    with ``field[1:-1]``.
    """
    left, centre, right = field[:-2], field[1:-1], field[2:]
    scale = 1 / 2 / lax_friedrichs_dt
    return scale * (right + left - 2 * centre)


def time_deriv_func(t, packed) -> np.ndarray:
    """Evaluate d/dt of the packed (height, momentum) state at time ``t``.

    This is the right-hand side passed to ``scipy.integrate.solve_ivp``. It
    takes the interleaved vector produced by ``Fields.pack`` and returns the
    time derivative in the same layout.

    At interior grid points (centred differences for the flux terms):
      * height:   -d(p)/dx + Lax-Friedrichs diffusion + rain (only if t < t_rain_stop)
      * momentum: -d(p*v)/dx + gravity * h - gamma * p + Lax-Friedrichs diffusion

    The two end points keep a zero derivative. Any height below -0.0001, at
    any point, is instead relaxed back toward zero via dh/dt = -h.
    """
    state = Fields.unpack(packed)
    h, p = state.height, state.momentum

    # Slope of the water surface (depth slope + terrain slope) and the driving
    # acceleration -g*s/(1 + s**2), i.e. -g*sin(theta)*cos(theta) for slope angle theta.
    grad_h = grad(h)
    slope = grad_h + grad_terrain
    accel_grav = -g / (1 + slope**2) * slope

    # Velocity p/h, pinned to zero where the layer is nearly dry (h < 0.001)
    # so the division never blows up.
    dry = h < 0.001
    wet = ~dry
    velocity = np.zeros_like(p)
    velocity[wet] = p[wet] / h[wet]
    momentum_flux = p * velocity

    rain_rate = R if t < t_rain_stop else 0
    # Rain lands on a flat-topped (super-Gaussian, exponent 6) patch around x = 1/2.
    rain_profile = np.exp(-((x[1:-1] - 1/2) / terrain_r)**6)

    deriv = Fields.zeros()

    # The height flux is the momentum itself.
    deriv.height[1:-1] = (
        -1/2/dx * (p[2:] - p[:-2])
        + lax_friedrichs_diff(h)
        + rain_rate * rain_profile
    )
    # Disabled alternative boundary conditions for the height:
    # deriv.height[0] = deriv.height[1]
    # deriv.height[-1] = deriv.height[-2]
    # deriv.height[0] = water_wave_speed * grad_h[0]
    # deriv.height[-1] = -water_wave_speed * grad_h[-1]

    # Relax significantly negative depths back toward zero. This must run after
    # the assignment above because it overrides those values.
    negative = h < -0.0001
    deriv.height[negative] = -h[negative]

    deriv.momentum[1:-1] = (
        -1/2/dx * (momentum_flux[2:] - momentum_flux[:-2])
        + accel_grav[1:-1] * h[1:-1]
        - gamma * p[1:-1]
        + lax_friedrichs_diff(p)
    )
    # Disabled alternative boundary conditions for the momentum:
    # deriv.momentum[0] = deriv.momentum[1]
    # deriv.momentum[-1] = deriv.momentum[-2]

    return deriv.pack()


print('Solving PDE... ')
sol = scipy.integrate.solve_ivp(
    time_deriv_func,
    (0, t_max),
    Fields.zeros().pack(),  # start dry: zero height and momentum everywhere
    # Alternative integrators tried (solve_ivp's default is RK45):
    # method='LSODA',
    # method='BDF',
    # method='RK23',
    # method='DOP853',
)
if not sol.success:
    # Fail loudly. Otherwise a solution that stopped early would be
    # interpolated and rendered below as if it covered [0, t_max].
    raise RuntimeError(f'solve_ivp failed (status {sol.status}): {sol.message}')
print(f'Done, num steps: {sol.t.shape[0]}', flush=True)

# Akima spline through the accepted solver steps (one column per packed state
# component), so the state can be sampled at arbitrary frame times.
sol_interp = scipy.interpolate.Akima1DInterpolator(sol.t, sol.y.T)

frame_rate_hz = 10
num_frames = 100

# Plot extents that are the same for every frame; computed once, not per frame.
min_terrain = terrain.min()
rain_top = terrain.max() + 0.1

# The second positional argument (3) presumably sets the number of stacked
# panels per frame — matches the three panels drawn below; confirm against joshpyutil.
with mpl.autovideo(
    'video_v2.mp4', 3, frame_rate_hz=frame_rate_hz, size_inches=(6, 6), sharex=True
) as av:
    for i in progressbar.progressbar(range(num_frames)):
        # Frames are evenly spaced over [0, t_max); t_max itself is not drawn.
        t = i / num_frames * t_max
        fields = Fields.unpack(sol_interp(t))
        fields_time_deriv = Fields.unpack(time_deriv_func(t, fields.pack()))
        with av.next_frame() as ap:
            # Panel 1: terrain, the water column on top of it and, while it is
            # raining, a grey band from the water surface up to rain_top.
            _ = ap.ax.fill_between(x, min_terrain, terrain, label='terrain', color='sienna')
            _ = ap.ax.fill_between(x, terrain, terrain + fields.height, label='water', color='dodgerblue')
            if t < t_rain_stop:
                _ = ap.ax.fill_between(x, terrain + fields.height, rain_top, label='rain', color='lightgray')
            ap.set(xlim=[0, 1])

            # Panel 2: momentum (ap.next() presumably advances to the next panel).
            ap = ap.next()
            ap.plot(x, fields.momentum, label='p')
            ap.legend()

            # Panel 3: momentum time derivative as returned by time_deriv_func.
            ap = ap.next()
            ap.plot(x, fields_time_deriv.momentum, label='dp/dt')
            ap.legend()

Solving PDE... 
Done, num steps: 2204


100% (100 of 100) |######################| Elapsed Time: 0:00:20 Time:  0:00:20
