## Navier–Stokes Example with Flow Parallel to the Top Boundary

In this example, every point on the top boundary (the lid) is held at a velocity of 1 in the x direction, while the velocity on the other three walls is zero (a lid-driven cavity).

                    -> -> -> -> -> -> -> -> -> -> -> -> -> ->
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |                                       |
                    |_______________________________________|

In [108]:
from navier_stokes_2D import progress_timestep
import jax.numpy as jnp
import matplotlib.pyplot as plt
import math

In [109]:
"""Determine Constants."""

TIMESTEPS = 5000    # Number of timesteps to simulate.
NUM_POINTS = 255    # Points along each of the x and y axes (see the oddness check below).
dt = 0.00001        # The size of each timestep.

# The grid-refinement comparison later keeps every other point ([::2, ::2]). Only an
# odd point count makes those kept points include both ends (0 and 1) of each axis,
# so the coarse grid covers the same domain. At least 3 points are needed so the
# spacing below is finite and the coarse grid has more than one point.
if NUM_POINTS < 3 or NUM_POINTS % 2 == 0:
    raise ValueError(f"NUM_POINTS must be an odd integer >= 3; got {NUM_POINTS}.")

# Plot stride: keeps at most 32 samples per axis so the output images aren't too
# dense and the arrows can actually be made out.
IMAGE_SCALAR = math.ceil(NUM_POINTS / 32)

element_length = 1 / (NUM_POINTS - 1)   # The distance between neighbouring points on the x and y axes.

In [110]:
"""Define the Initial and Boundary Conditions."""

# Build the square unit domain: the same evenly spaced axis is used for x and y.
# (JAX arrays are immutable, so sharing one linspace between x and y is safe.)
x = y = jnp.linspace(0, 1, NUM_POINTS)
X, Y = jnp.meshgrid(x, y)

# Boundary values: the last row (top of the domain) moves at u = 1; everything else is 0.
u_bound = jnp.zeros_like(X).at[-1, :].set(1)
v_bound = jnp.zeros_like(X)

# Initial state: the fluid starts at rest with zero pressure everywhere.
u_prev, v_prev, p_prev = (jnp.zeros_like(X) for _ in range(3))

In [None]:
"""Run the update loop."""

PLOT_EVERY = 1000   # Show the current state once every PLOT_EVERY timesteps.

# Subsample both axes by IMAGE_SCALAR so the plotted arrows stay legible.
plot_slice = (slice(None, None, IMAGE_SCALAR), slice(None, None, IMAGE_SCALAR))

# Count steps from 1: when `step` is checked, exactly `step` updates have been applied,
# so the state being shown is at simulated time step * dt. (Counting from 0 labelled the
# state after the first update as time 0 and never showed the final state.)
for step in range(1, TIMESTEPS + 1):
    # Update the velocities and pressure by one timestep.
    u_prev, v_prev, p_prev = progress_timestep(u_prev, v_prev, p_prev, u_bound, v_bound, element_length, dt=dt)

    if step % PLOT_EVERY == 0:
        print(f"timestep: {step}, time: {step * dt:g}.")
        plt.figure()
        plt.contourf(X[plot_slice], Y[plot_slice], p_prev[plot_slice], 100, cmap="coolwarm")
        plt.colorbar()

        plt.quiver(X[plot_slice], Y[plot_slice], u_prev[plot_slice], v_prev[plot_slice])

        plt.show()
        plt.close()

## Evaluation
To check that the iterative method is working, run it again with:

1.  The same parameters,
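2.  dt divided by 10 and TIMESTEPS multiplied by 10 (the same total simulated time, reached with ten times smaller steps),
3.  NUM_POINTS decreased by keeping only every other grid point, so the remaining points sit at the same positions as in the standard run.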

Images of each system can then be plotted, and the mean squared error between each varied system and the standard one calculated. If the method is working, the results should be very similar.

In [112]:
"""Define the new initial conditions."""

# Every comparison run starts from rest: all-zero velocities and pressure.

# Standard run: same grid and timestep as the first simulation.
u_prev_standard, v_prev_standard, p_prev_standard = (jnp.zeros_like(X) for _ in range(3))

# Decreased-points run: X[::2, ::2] keeps every other point (the even indices) along
# each axis, so these fields are smaller than the fine-grid ones and need their own
# correctly sized zero arrays.
u_prev_points_decreased, v_prev_points_decreased, p_prev_points_decreased = (
    jnp.zeros_like(X[::2, ::2]) for _ in range(3)
)

# Decreased-timestep run: same grid as the standard run.
u_prev_dt_decreased, v_prev_dt_decreased, p_prev_dt_decreased = (jnp.zeros_like(X) for _ in range(3))

In [113]:
"""Run the update loop."""

# The decreased-points run lives on the grid X[::2, ::2], so its boundary arrays are
# sliced the same way. Its neighbouring points are two fine-grid spacings apart, so it
# must be given 2 * element_length. Passing the fine element_length would make the
# solver treat the coarse grid as spanning only half the domain, which invalidates
# the comparison. These values don't change between steps, so compute them once.
u_bound_points_decreased = u_bound[::2, ::2]
v_bound_points_decreased = v_bound[::2, ::2]
element_length_points_decreased = 2 * element_length

# Loop for the standard and points varied systems, since they require the same number of timesteps.
for i in range(TIMESTEPS):
    u_prev_standard, v_prev_standard, p_prev_standard = progress_timestep(
        u_prev_standard, v_prev_standard, p_prev_standard,
        u_bound, v_bound,
        element_length, dt=dt,
    )

    u_prev_points_decreased, v_prev_points_decreased, p_prev_points_decreased = progress_timestep(
        u_prev_points_decreased, v_prev_points_decreased, p_prev_points_decreased,
        u_bound_points_decreased, v_bound_points_decreased,
        element_length_points_decreased, dt=dt,
    )

# Loop for the system with decreased dt: ten times more steps of a tenth the size,
# so it reaches the same simulated time as the standard system.
for i in range(TIMESTEPS * 10):
    u_prev_dt_decreased, v_prev_dt_decreased, p_prev_dt_decreased = progress_timestep(
        u_prev_dt_decreased, v_prev_dt_decreased, p_prev_dt_decreased,
        u_bound, v_bound,
        element_length,
        dt=dt/10,
    )

In [None]:
"""Plot graphs of the systems, and print the mean squared errors of the varied systems copared to the standard ones."""

def _plot_state(X_grid, Y_grid, u, v, p, stride, title):
    """Show pressure as filled contours with the velocity field drawn as arrows on top.

    Every array is subsampled with the same ``stride`` along both axes, so the
    plotted coordinates and field values always have matching shapes.

    Args:
        X_grid, Y_grid: meshgrid coordinates of the grid the fields live on.
        u, v: x and y velocity components on that same grid.
        p: pressure on that same grid.
        stride: keep every ``stride``-th point along each axis (must be >= 1).
        title: plot title.
    """
    sub = (slice(None, None, stride), slice(None, None, stride))
    plt.figure()
    plt.contourf(X_grid[sub], Y_grid[sub], p[sub], 100, cmap="coolwarm")
    plt.colorbar()
    plt.title(title)
    plt.quiver(X_grid[sub], Y_grid[sub], u[sub], v[sub])
    plt.show()
    plt.close()


# Standard system.
_plot_state(X, Y, u_prev_standard, v_prev_standard, p_prev_standard,
            IMAGE_SCALAR,
            f"Standard system after {TIMESTEPS*dt} seconds.")

# Varied number of points. These fields live on the coarse grid X[::2, ::2], so plot
# them against the coarse coordinates with half the stride. This keeps the coordinate
# and data shapes equal even when IMAGE_SCALAR is odd. max(1, ...) avoids a zero
# slice step when IMAGE_SCALAR is 1.
_plot_state(X[::2, ::2], Y[::2, ::2],
            u_prev_points_decreased, v_prev_points_decreased, p_prev_points_decreased,
            max(1, IMAGE_SCALAR // 2),
            f"system with varied number of points after {TIMESTEPS*dt} seconds.")

# Varied number of timesteps (dt / 10 for TIMESTEPS * 10 steps, i.e. the same simulated time).
_plot_state(X, Y, u_prev_dt_decreased, v_prev_dt_decreased, p_prev_dt_decreased,
            IMAGE_SCALAR,
            f"System with varied number of timesteps after {TIMESTEPS*dt} seconds.")

def mse(A, B):
    """Return the mean squared error between two arrays of identical shape.

    Args:
        A, B: arrays (e.g. JAX arrays) with the same shape.

    Returns:
        The mean of the element-wise squared differences ``(A - B) ** 2``.

    Raises:
        ValueError: if the shapes differ. Without this check, shapes such as
            (1, n) and (n, 1) would silently broadcast to (n, n) and produce a
            meaningless error value instead of failing.
    """
    if A.shape != B.shape:
        raise ValueError(f"mse: shape mismatch {A.shape} vs {B.shape}")
    return ((A - B) ** 2).mean()

# Grid-refinement check: the coarse run kept every other point, so it is compared
# against the matching (even-index) points of the standard run.
print("Mean Squared Error of u from decreasing number of points: ",
      mse(u_prev_points_decreased, u_prev_standard[::2, ::2]))
print("Mean Squared Error of v from decreasing number of points: ",
      mse(v_prev_points_decreased, v_prev_standard[::2, ::2]))
print("Mean Squared Error of p from decreasing number of points: ",
      mse(p_prev_points_decreased, p_prev_standard[::2, ::2]))

# Timestep check: both runs share the same grid and the same total simulated time.
print("Mean Squared Error of u from varying timestep size: ",
      mse(u_prev_dt_decreased, u_prev_standard))
print("Mean Squared Error of v from varying timestep size: ",
      mse(v_prev_dt_decreased, v_prev_standard))
print("Mean Squared Error of p from varying timestep size: ",
      mse(p_prev_dt_decreased, p_prev_standard))
