-
Notifications
You must be signed in to change notification settings - Fork 98
/
equations.py
195 lines (165 loc) · 6.63 KB
/
equations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Examples of defining equations."""
import functools
from typing import Callable, Optional
import jax
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import diffusion
from jax_cfd.base import grids
from jax_cfd.base import pressure
from jax_cfd.base import time_stepping
import tree_math
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
# Short module-level aliases for the grid container types used in the
# signatures below.
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
# Convection term: maps a velocity vector to one array per component; the
# result is accumulated into the velocity time derivative.
ConvectFn = Callable[[GridVariableVector], GridArrayVector]
# Diffusion operator for a single component, called as
# `diffuse(u, coefficient)`.
DiffuseFn = Callable[[GridVariable, float], GridArray]
# Forcing term: maps a velocity vector to one array per component.
ForcingFn = Callable[[GridVariableVector], GridArrayVector]
def sum_fields(*args):
  """Adds pytrees of identical structure together, leaf by leaf.

  Args:
    *args: one or more pytrees sharing the same tree structure.

  Returns:
    A pytree of that structure whose leaves are the Python `sum` of the
    corresponding input leaves.
  """
  def _add_leaves(*leaves):
    return sum(leaves)

  return jax.tree.map(_add_leaves, *args)
def stable_time_step(
    max_velocity: float,
    max_courant_number: float,
    viscosity: float,
    grid: grids.Grid,
    implicit_diffusion: bool = False,
) -> float:
  """Calculate a stable time step for Navier-Stokes.

  The returned step is always the advective limit from
  `advection.stable_time_step`. When diffusion is explicit, the diffusive
  limit is only checked against it, not used to shrink it.

  NOTE(review): `viscosity` is passed to `diffusion.stable_time_step`
  unchanged (no division by density, unlike `navier_stokes_explicit_terms`);
  confirm callers pass the coefficient that function expects.

  Args:
    max_velocity: largest velocity magnitude expected in the flow.
    max_courant_number: Courant number bound for advection.
    viscosity: viscosity used for the explicit-diffusion check.
    grid: the simulation grid.
    implicit_diffusion: if True, the diffusion check is skipped.

  Returns:
    The advective time step.

  Raises:
    ValueError: if diffusion is explicit and its stable step is smaller than
      the advective step.
  """
  advective_dt = advection.stable_time_step(
      max_velocity, max_courant_number, grid)
  if implicit_diffusion:
    return advective_dt

  diffusive_dt = diffusion.stable_time_step(viscosity, grid)
  if diffusive_dt < advective_dt:
    raise ValueError(f'stable time step for diffusion is smaller than '
                     f'the chosen timestep: {diffusive_dt} vs {advective_dt}')
  return advective_dt
def dynamic_time_step(v: GridVariableVector,
                      max_courant_number: float,
                      viscosity: float,
                      grid: grids.Grid,
                      implicit_diffusion: bool = False) -> float:
  """Pick a time step from the current velocity field.

  The peak velocity magnitude over the grid (square root of the maximum of
  the summed squared components) is fed to `stable_time_step`.

  NOTE(review): with `implicit_diffusion=False`, `stable_time_step` compares
  time steps in a Python `if`, which presumably fails when `v` holds traced
  (jitted) values; confirm before calling under `jax.jit`.

  Args:
    v: velocity components, each exposing a `.data` array.
    max_courant_number: Courant number bound for advection.
    viscosity: forwarded to `stable_time_step`.
    grid: the simulation grid.
    implicit_diffusion: forwarded to `stable_time_step`.

  Returns:
    The time step chosen by `stable_time_step`.
  """
  speed_squared = sum(component.data ** 2 for component in v)
  peak_speed = jnp.sqrt(jnp.max(speed_squared))
  return stable_time_step(  # pytype: disable=wrong-arg-types  # jax-types
      max_velocity=peak_speed,
      max_courant_number=max_courant_number,
      viscosity=viscosity,
      grid=grid,
      implicit_diffusion=implicit_diffusion)
def _wrap_term_as_vector(fun, *, name):
  """Labels `fun` with `name` and lets it accept a tree_math vector.

  The first argument of the returned callable is unwrapped from a
  `tree_math` vector before `fun` is called (via
  `tree_math.unwrap(..., vector_argnums=0)`).
  """
  labelled = jax.named_call(fun, name=name)
  return tree_math.unwrap(labelled, vector_argnums=0)
def navier_stokes_explicit_terms(
    density: float,
    viscosity: float,
    dt: float,
    grid: grids.Grid,
    convect: Optional[ConvectFn] = None,
    diffuse: DiffuseFn = diffusion.diffuse,
    forcing: Optional[ForcingFn] = None,
) -> Callable[[GridVariableVector], GridVariableVector]:
  """Builds the explicit right-hand side of the momentum equation.

  The returned callable maps a velocity vector `v` to its time derivative
  from the explicitly treated terms:
    * convection, `convect(v)`;
    * diffusion, `diffuse(u, viscosity / density)` for each component `u`
      (omitted when `viscosity` is None);
    * forcing, `forcing(v) / density` (omitted when `forcing` is None).
  It does not advance time itself. Each output component reuses the boundary
  conditions of the matching input component.

  Args:
    density: fluid density; divides the diffusion coefficient and forcing.
    viscosity: viscosity; None disables the diffusion term.
    dt: time step, used only by the default van Leer convection.
    grid: unused; accepted for signature compatibility.
    convect: convection term; defaults to van Leer advection of each
      component.
    diffuse: per-component diffusion operator.
    forcing: optional forcing term.

  Returns:
    A function from a velocity vector to a tuple of `GridVariable` time
    derivatives.
  """
  del grid  # unused

  if convect is None:
    def _van_leer_convection(v):
      return tuple(
          advection.advect_van_leer_using_limiters(u, v, dt) for u in v)
    convect = _van_leer_convection

  def _diffuse_each_component(v, *args):
    return tuple(diffuse(u, *args) for u in v)

  convection_term = _wrap_term_as_vector(convect, name='convection')
  diffusion_term = _wrap_term_as_vector(
      _diffuse_each_component, name='diffusion')
  forcing_term = (
      None if forcing is None
      else _wrap_term_as_vector(forcing, name='forcing'))

  def _momentum_rhs(v):
    # `v` and every term are tree_math vectors here, so `+` and `/` act
    # leaf-wise over all velocity components.
    total = convection_term(v)
    if viscosity is not None:
      total += diffusion_term(v, viscosity / density)
    if forcing_term is not None:
      total += forcing_term(v) / density
    return total

  rhs = tree_math.wrap(
      jax.named_call(_momentum_rhs, name='navier_stokes_momentum'))

  def explicit_terms_with_same_bcs(v):
    derivatives = rhs(v)
    return tuple(
        grids.GridVariable(array, component.bc)
        for array, component in zip(derivatives, v))

  return explicit_terms_with_same_bcs
# TODO(shoyer): rename this to explicit_diffusion_navier_stokes
def semi_implicit_navier_stokes(
    density: float,
    viscosity: float,
    dt: float,
    grid: grids.Grid,
    convect: Optional[ConvectFn] = None,
    diffuse: DiffuseFn = diffusion.diffuse,
    pressure_solve: Callable = pressure.solve_fast_diag,
    forcing: Optional[ForcingFn] = None,
    time_stepper: Callable = time_stepping.forward_euler,
) -> Callable[[GridVariableVector], GridVariableVector]:
  """Returns a function that performs a time step of Navier Stokes.

  Convection, diffusion and forcing all come from
  `navier_stokes_explicit_terms`. They are paired with a pressure projection
  in a `time_stepping.ExplicitNavierStokesODE`, which is handed, together
  with `dt`, to `time_stepper`.

  Args:
    density: fluid density.
    viscosity: viscosity.
    dt: time step.
    grid: the simulation grid (forwarded to `navier_stokes_explicit_terms`).
    convect: optional convection term.
    diffuse: per-component diffusion operator.
    pressure_solve: solver passed to `pressure.projection`.
    forcing: optional forcing term.
    time_stepper: called as `time_stepper(ode, dt)` to build the step.

  Returns:
    Whatever `time_stepper` returns for the assembled ODE.
  """
  rhs = navier_stokes_explicit_terms(
      density=density,
      viscosity=viscosity,
      dt=dt,
      grid=grid,
      convect=convect,
      diffuse=diffuse,
      forcing=forcing)
  labelled_projection = jax.named_call(pressure.projection, name='pressure')

  def _project(v):
    return labelled_projection(v, pressure_solve)

  # TODO(jamieas): Consider a scheme where pressure calculations and
  # advection/diffusion are staggered in time.
  ode = time_stepping.ExplicitNavierStokesODE(rhs, _project)
  return time_stepper(ode, dt)
def implicit_diffusion_navier_stokes(
    density: float,
    viscosity: float,
    dt: float,
    grid: grids.Grid,
    convect: Optional[ConvectFn] = None,
    diffusion_solve: Callable = diffusion.solve_fast_diag,
    pressure_solve: Callable = pressure.solve_fast_diag,
    forcing: Optional[ForcingFn] = None,
) -> Callable[[GridVariableVector], GridVariableVector]:
  """Returns a function that performs a time step of Navier Stokes.

  Each step proceeds in three stages:
    1. Convection and forcing (divided by `density`) are integrated with a
       single forward Euler step of size `dt`.
    2. The result is pressure-projected.
    3. Diffusion is applied through `diffusion_solve`.

  Args:
    density: fluid density. Both the forcing and `viscosity` are divided by
      it, so the diffusion coefficient is `viscosity / density`, the same
      coefficient `navier_stokes_explicit_terms` uses.
    viscosity: viscosity; divided by `density` before the diffusion solve.
    dt: time step.
    grid: unused; accepted for signature compatibility.
    convect: convection term; defaults to van Leer advection of each
      component.
    diffusion_solve: called as `diffusion_solve(v, coefficient, dt)` on the
      projected velocity.
    pressure_solve: solver passed to `pressure.projection`.
    forcing: optional forcing term, called as `forcing(v)`.

  Returns:
    A function mapping the velocity at time `t` to the velocity at `t + dt`.
  """
  del grid  # unused

  if convect is None:
    def convect(v):  # pylint: disable=function-redefined
      return tuple(
          advection.advect_van_leer_using_limiters(u, v, dt) for u in v)

  convect = jax.named_call(convect, name='convection')
  pressure_projection = jax.named_call(pressure.projection, name='pressure')
  diffusion_solve = jax.named_call(diffusion_solve, name='diffusion')

  # TODO(shoyer): refactor to support optional higher-order time integrators
  @jax.named_call
  def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
    """Computes state at time `t + dt` using first order time integration."""
    accelerations = [convect(v)]
    if forcing is not None:
      # TODO(shoyer): include time in state?
      force = forcing(v)
      accelerations.append(tuple(f_i / density for f_i in force))
    dvdt = sum_fields(*accelerations)

    # Forward Euler update; each component keeps its boundary conditions.
    v = tuple(
        grids.GridVariable(u.array + dudt * dt, u.bc)
        for u, dudt in zip(v, dvdt))

    # Pressure projection to incompressible velocity field.
    v = pressure_projection(v, pressure_solve)

    # Implicit diffusion with coefficient viscosity / density. This matches
    # the forcing scaling above and `navier_stokes_explicit_terms`.
    v = diffusion_solve(v, viscosity / density, dt)
    return v

  return navier_stokes_step