In [33]:
from navier_stokes_2D import progress_timestep, progress_timestep_with_particles
import jax.numpy as jnp
import matplotlib.pyplot as plt
import math
import jax.config as config

import imageio

In [34]:
"""Determine Constants."""

TIMESTEPS = 5000000
# Number of grid points along each of the x and y axes.
# NOTE(review): an earlier note said this "should be odd to make parameter
# comparisons later" (an odd count puts a grid point exactly at 0.5), but the
# value is 64 (even) — confirm which is intended before comparing runs.
NUM_POINTS = 64
dt = 0.0001        # The size of each timestep.

# Plot only every IMAGE_SCALAR-th grid point so output images aren't too dense to read.
IMAGE_SCALAR = math.ceil(NUM_POINTS / 32)

element_length = 1 / (NUM_POINTS - 1)   # Grid spacing on the x and y axes.

"""Define the Initial and Boundary Conditions."""

# Make a meshgrid over the unit square to create our domain.
x = jnp.linspace(0, 1, NUM_POINTS)
y = jnp.linspace(0, 1, NUM_POINTS)

# meshgrid's default 'xy' indexing yields arrays of shape (len(y), len(x)),
# so axis 0 indexes y and row -1 is the y = 1 edge of the domain.
X, Y = jnp.meshgrid(x, y)

# Boundary conditions: zero everywhere except u = 1 along the y = 1 row
# (presumably the moving lid of a lid-driven cavity — confirm against navier_stokes_2D).
u_bound = jnp.zeros_like(X).at[-1, :].set(1)
v_bound = jnp.zeros_like(X)

# Initial conditions: zero velocity and zero pressure everywhere.
u_prev = jnp.zeros_like(X)
v_prev = jnp.zeros_like(X)
p_prev = jnp.zeros_like(X)

# Initial particle coordinates (presumably element k of x_prev/y_prev is particle k,
# as plotted by the scatter in the simulation cell) and their velocities, starting at rest.
x_prev = jnp.array([0.25, 0.75])
y_prev = jnp.array([0.25, 0.75])

dx_dt_prev = jnp.zeros_like(x_prev)
dy_dt_prev = jnp.zeros_like(y_prev)

In [35]:
"""Run the simulation loop and save images."""
# Save one frame every SAVE_EVERY timesteps; frame k is written to img_{k}.png.
SAVE_EVERY = 100

# Subsample the grid so the contour/quiver plots stay legible.
stride = (slice(None, None, IMAGE_SCALAR), slice(None, None, IMAGE_SCALAR))
X_plot, Y_plot = X[stride], Y[stride]

for i in range(TIMESTEPS):
    # Advance the velocities, pressure and particle states by one timestep.
    (u_prev, v_prev, p_prev,
     x_prev, y_prev, dx_dt_prev, dy_dt_prev) = progress_timestep_with_particles(
        u_prev, v_prev, p_prev,
        x_prev, y_prev, dx_dt_prev, dy_dt_prev,
        u_bound, v_bound, element_length,
        drag_constant=99999999, dt=dt, density=1.,
        viscosity=0.1, jacobi_iterations=50)

    # Save an image of the current state every SAVE_EVERY timesteps.
    if i % SAVE_EVERY == 0:
        fig, ax = plt.subplots()
        # Keep a handle on the pressure contour: pyplot's scatter would otherwise
        # become the "current" mappable and the colorbar would describe the particles.
        pressure = ax.contourf(X_plot, Y_plot, p_prev[stride], 100, cmap="coolwarm")
        ax.scatter(x_prev, y_prev)
        fig.colorbar(pressure, ax=ax)

        ax.quiver(X_plot, Y_plot, u_prev[stride], v_prev[stride])

        fig.savefig(f'img_{i // SAVE_EVERY}.png',
                    transparent=False,
                    facecolor='white')
        # Close explicitly so thousands of saved frames don't accumulate in memory.
        plt.close(fig)

KeyboardInterrupt: 

In [20]:
"""Make images into a gif."""
from pathlib import Path

NUM_FRAMES = 100   # Maximum number of saved frames to include in the GIF.

# Read img_0.png, img_1.png, ... in order, stopping at the first missing frame so
# an interrupted simulation run still produces a GIF instead of a FileNotFoundError.
frames = []
for i in range(NUM_FRAMES):
    frame_path = Path(f'img_{i}.png')
    if not frame_path.exists():
        break
    frames.append(imageio.v2.imread(frame_path))

if not frames:
    raise FileNotFoundError('No frames found: expected img_0.png from the simulation cell.')

imageio.mimsave('./example.gif',  # output gif
                frames,           # list of input frames, in playback order
                duration=0.4)     # display time per frame (NOT frames per second); units depend
                                  # on the imageio GIF plugin in use (seconds vs. ms) — verify