# Rayleigh–Taylor instability

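In this notebook, we simulate the Rayleigh–Taylor instability, which develops at the interface between two fluids of different density when the heavier fluid lies above the lighter one under gravity. The initial interface is given a small cosine perturbation, whose growth we then follow.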

## Installing dependencies and importing modules

In [None]:
# Installing dependencies for this notebook
%pip install einops moviepy proglog scikit-image matplotlib

In [None]:
import jax.numpy as jnp
from einops import rearrange
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
from cglbm.lbm import grid_eq_dist
from cglbm.simulation import multi_step_simulation
from cglbm.environment import State, System

In [None]:
# from jax import config

# config.update("jax_debug_nans", True)
# config.update("jax_enable_x64", True)
# config.update("jax_platforms", "cpu")

## Simulation Setup

### Loading pre-defined environment

In [None]:
from cglbm.config import load_sandbox_config

# Load the pre-defined Rayleigh–Taylor sandbox configuration.
CONFIG_FILE = "RT-instability-config.ini"
system = load_sandbox_config(CONFIG_FILE)

In [None]:
# to change a system parameter, use system.replace
# system = system.replace(LX=250, LY=1500)

### Initial conditions of simulation

In [None]:
# Lattice dimensions come straight from the loaded configuration.
LX, LY = system.LX, system.LY
# Index grids of shape (LY, LX): X varies along columns, Y along rows.
X, Y = jnp.meshgrid(jnp.arange(LX), jnp.arange(LY))

grid_shape = Y.shape  # same as X.shape for the default 'xy' meshgrid
# Placeholder; the interface cell below recomputes phase_field.
phase_field = jnp.zeros(grid_shape)
center = tuple(extent // 2 for extent in grid_shape)

# NOTE(review): `center` and `radius` are not read by any later cell shown here
# (they look like leftovers from a droplet setup) — confirm before removing.
radius = system.drop_radius

### Initializing the perturbed interface

In [None]:
# Lattice coordinates as (row, column) pairs on the last axis -> (LY, LX, 2).
coordinates = jnp.stack([Y, X], axis=-1)

# Interface height per column: baseline 2*wavelength plus a cosine
# perturbation of amplitude 0.1*wavelength.
# NOTE(review): the perturbation period is LX - 2, which does not divide LX;
# if the x direction is periodic the interface has a small kink at the
# x boundary — confirm this is intended.
wavelength = LX - 2
dist = 2*wavelength + 0.1*wavelength * jnp.cos(2*jnp.pi*(LY - X)/wavelength)

# Sharp (step) phase field: 1 where the row index Y exceeds dist, 0 elsewhere.
phase_field = jnp.where(Y > dist, jnp.ones(grid_shape), jnp.zeros(grid_shape))
# Smooth alternative with a tanh profile:
# phase_field = 0.5 * (1.0 + jnp.tanh((dist - Y) * 2.0 / 20.0))

### Initializing Density, Velocity, Pressure

In [None]:
# Density: linear blend of the two fluid densities, weighted by the phase field.
rho = system.density_two * (1.0 - phase_field) + system.density_one * phase_field
# Uniform reference pressure everywhere.
pressure = jnp.full(grid_shape, system.ref_pressure)

# Fluid starts at rest; components are stacked on the last axis -> (LY, LX, 2).
u_x, u_y = jnp.zeros(grid_shape), jnp.zeros(grid_shape)
u = jnp.stack([u_x, u_y], axis=-1)

### Defining Obstacle

In [None]:
# Static walls: the first two and the last two rows of the lattice.
wall_rows = [0, 1, -2, -1]
obs = jnp.zeros(grid_shape, dtype=bool).at[wall_rows, :].set(True)

# The walls do not move: zero obstacle velocity, stacked as (LY, LX, 2).
obs_velX = jnp.zeros(grid_shape)
obs_velY = jnp.zeros(grid_shape)
obs_vel = jnp.stack([obs_velX, obs_velY], axis=-1)

# (row, column) coordinates of every obstacle cell.
obs_indices = jnp.argwhere(obs)

### Plotting the obstacles and the phase field

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6, 4))

# (image, colormap, title) for each panel; one loop instead of two
# copy-pasted panel blocks.
panels = [
    (phase_field, 'RdBu', "Initial phase field"),
    (obs, 'binary', "Obstacles"),
]
for ax, (image, cmap, title) in zip(axs, panels):
    im = ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    # A dedicated colorbar axis keeps each colorbar as tall as its image.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')

plt.show()

### Initializing f and N

In [None]:
# Zero velocity with the same (LY, LX, 2) shape as `coordinates`.
zero_velocity = jnp.zeros(coordinates.shape)
# Presumably the equilibrium populations for the initial pressure at rest
# (per the helper's name) — confirm against cglbm.lbm.grid_eq_dist.
N = grid_eq_dist(
    system.cXYs, system.weights, system.phi_weights, pressure, zero_velocity,
)
# Phase-field populations: N scaled site-by-site by phase_field; N carries an
# extra leading axis k (presumably lattice directions).
f = jnp.einsum("ij,kij->kij", phase_field, N)

### Initializing the simulation state

In [None]:
# Collect every initial field in one mapping, then build the simulation state.
initial_fields = dict(
    rho=rho,
    pressure=pressure,
    u=u,
    phase_field=phase_field,
    obs=obs,
    obs_velocity=obs_vel,
    obs_indices=obs_indices,
    f=f,
    N=N,
)
state = State(**initial_fields)

## Running the Simulation

In [None]:
nr_iter = 1000      # number of simulation steps to run
nr_snapshots = 10   # passed to multi_step_simulation; presumably the number of recorded frames

# Only the first returned value is used below (a mapping with "u" and
# "phase_field" entries); the second is discarded.
simulation_output = multi_step_simulation(system, state, nr_iter, nr_snapshots)
final_results, _ = simulation_output

### Benchmark

In [None]:
# %timeit multi_step_simulation(system, state, 1000, 100)

## Visualizations

### Getting data from all iterations

In [None]:
# Move the velocity-component axis to the front and unpack it, giving one
# (t, i, j) stack per component (exactly two components are expected).
ux_final, uy_final = np.moveaxis(final_results["u"], -1, 0)
phase_field_final = final_results["phase_field"]

### Calculating vorticity

In [None]:
# Central differences via shifts (axis 1 = rows/y, axis 2 = columns/x):
#   dux_dy = u_x[y+1] - u_x[y-1],   duy_dx = u_y[x+1] - u_y[x-1]
# Nothing is divided by 2, so this is 2 * (du_x/dy - du_y/dx) in lattice units.
# np.roll wraps around, so the first/last rows mix data from opposite edges.
dux_dy = np.roll(ux_final, -1, axis=1) - np.roll(ux_final, 1, axis=1)
duy_dx = np.roll(uy_final, -1, axis=2) - np.roll(uy_final, 1, axis=2)
vorticity = dux_dy - duy_dx

### Phase field visualization

In [None]:
# set iteration appropriately
iteration = 10

plt.imshow(phase_field_final[iteration,:,:])
plt.colorbar()

### Stream Plot

#### Single snapshot

In [None]:
iteration = 1
# Lattice-site coordinates matching the (LY, LX) grid; plain NumPy suffices
# for matplotlib.
X, Y = np.meshgrid(np.arange(LX), np.arange(LY))
plt.streamplot(X, Y, ux_final[iteration], uy_final[iteration])
# imshow draws row 0 at the top; flip the y axis so the streamlines share the
# orientation of the phase-field and vorticity images.
plt.gca().invert_yaxis()
plt.title(f"Streamlines, snapshot {iteration}");

#### Plotting multiple iterations

In [None]:
n = 3
# squeeze=False keeps `axes` 2-D, so indexing also works when n == 1.
fig, axes = plt.subplots(n, 1, figsize=(6, 6), squeeze=False)

for i in range(min(n, len(ux_final))):
    ax = axes[i, 0]
    ax.streamplot(X, Y, ux_final[i], uy_final[i])
    # Match imshow's row-0-at-top orientation used elsewhere in the notebook.
    ax.invert_yaxis()


### Quiver Plot

In [None]:
szx = LX
szy = LY
timestep = 1

# Plotting all the arrows would be messy, so sample one point in every
# skipx * skipy = 16 lattice sites.
skipx = 4
skipy = 4
# Use the same stop as the slices below, so the arrow positions and the
# sampled components always have matching shapes for any grid size.
x = np.arange(0, szx, skipx)
y = np.arange(0, szy, skipy)

xx, yy = np.meshgrid(x, y)

# Distinct names so this cell does not overwrite `u`, the initial velocity
# field the simulation State was built from.
u_sampled = np.asarray(ux_final[timestep])[:szy:skipy, :szx:skipx]
v_sampled = np.asarray(uy_final[timestep])[:szy:skipy, :szx:skipx]

plt.quiver(xx, yy, u_sampled, v_sampled)
# Match imshow's row-0-at-top orientation used elsewhere in the notebook.
plt.gca().invert_yaxis()
plt.title(f"Velocity field, snapshot {timestep}");

### Vorticity Plot

#### Single snapshot

In [None]:
iteration = 1

# Fixed symmetric colour limits so different snapshots are visually comparable.
plt.imshow(vorticity[iteration], cmap='RdBu', vmin=-.0001, vmax=.0001)
plt.colorbar();

#### Plotting multiple iterations

In [None]:
n = 3
first_frame = 2  # the earliest snapshots are skipped
# squeeze=False keeps `axes` 2-D, so indexing also works when n == 1.
fig, axes = plt.subplots(n, 1, figsize=(3, 12), squeeze=False)

# zip stops at the shorter sequence: with fewer than first_frame + n
# snapshots the trailing panels stay empty instead of raising IndexError.
for ax, frame in zip(axes[:, 0], vorticity[first_frame:first_frame + n]):
    ax.imshow(frame, cmap='RdBu', vmin=-0.0001, vmax=0.0001)

## Creating videos from snapshots

In [None]:
import matplotlib.cm
import matplotlib.colors
from PIL import Image

def make_images(data, cmap='RdBu', vmax=None, vmin=None):
    """Render each 2-D frame of ``data`` as an RGBA PIL image.

    Args:
        data: iterable of 2-D arrays (e.g. a (t, i, j) stack of snapshots);
            NumPy or JAX arrays are accepted.
        cmap: matplotlib colormap name or object.
        vmax: upper colour limit. If None, each frame uses its own
            ``max(|frame|)``.
        vmin: lower colour limit. If None, ``-vmax`` is used, so the colormap
            is centred on zero (the default).

    Returns:
        list of RGBA ``PIL.Image.Image`` objects, one per frame.
    """
    images = []
    for frame in data:
        frame = np.asarray(frame)  # materialize JAX arrays as NumPy
        this_vmax = np.max(np.abs(frame)) if vmax is None else vmax
        this_vmin = -this_vmax if vmin is None else vmin
        norm = matplotlib.colors.Normalize(vmin=this_vmin, vmax=this_vmax)
        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        # bytes=True yields a uint8 (h, w, 4) array, which Pillow maps to RGBA
        # on its own; passing ``mode`` explicitly is deprecated in Pillow.
        rgba = mappable.to_rgba(frame, bytes=True)
        images.append(Image.fromarray(rgba))
    return images

def save_movie(images, path, duration=100, loop=0, **kwargs):
    """Write a sequence of PIL images as one animated file (e.g. a GIF).

    Args:
        images: non-empty sequence of PIL images; the first one is the file's
            base frame and the rest are appended after it.
        path: output file path.
        duration: per-frame display time in ms (a number, or one value per frame).
        loop: loop count forwarded to Pillow (0 loops forever for GIFs).
        **kwargs: extra options forwarded to ``Image.save``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("save_movie needs at least one image")
    images[0].save(path, save_all=True, append_images=images[1:],
                   duration=duration, loop=loop, **kwargs)


In [None]:
from functools import partial
import proglog
from moviepy.editor import ImageSequenceClip

# Silence moviepy's progress bars before showing movies.
# Binding `None` as the first positional argument of proglog's
# default_bar_logger presumably makes every call fall back to a muted logger
# — confirm against the installed proglog version.
# The isinstance guard keeps this cell idempotent: without it, every re-run
# wraps the already-patched function in one more partial layer.
if not isinstance(proglog.default_bar_logger, partial):
    proglog.default_bar_logger = partial(proglog.default_bar_logger, None)

In [None]:
# Video from vorticity plots
# Vorticity video: each snapshot is coloured with its own symmetric limits
# (make_images is called without vmax).
vorticity_images = make_images(vorticity)
vorticity_frames = list(map(np.array, vorticity_images))
ImageSequenceClip(vorticity_frames, fps=10).ipython_display()

In [None]:
# save_movie(vorticity_images,'vorticity.gif', duration=[2000]+[200]*(len(vorticity_images)-2)+[2000])

In [None]:
# Video from phase field 
# Phase-field video. make_images centres the colormap on zero, so if the
# phase field stays within [0, 1] only the upper half of 'RdBu' is used.
phase_images = make_images(phase_field_final, cmap='RdBu')
phase_frames = list(map(np.array, phase_images))
ImageSequenceClip(phase_frames, fps=10).ipython_display()

In [None]:
# save_movie(phase_images,'phase_field.gif', duration=[2000]+[200]*(len(phase_images)-2)+[2000])