Add extra field variables as output for the step function #224

Closed
scaomath opened this issue Apr 11, 2024 · 1 comment

@scaomath

Current

```python
# Excerpt: the step closure defined and returned inside semi_implicit_navier_stokes(...)
  @jax.named_call
  def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
    """Computes state at time `t + dt` using first order time integration."""
    # Collect the acceleration terms
    convection = convect(v)
    accelerations = [convection]
    if viscosity is not None:
      diffusion_ = tuple(diffuse(u, viscosity / density) for u in v)
      accelerations.append(diffusion_)
    if forcing is not None:
      # TODO(shoyer): include time in state?
      force = forcing(v)
      accelerations.append(tuple(f / density for f in force))
    dvdt = sum_fields(*accelerations)
    # Update v by taking a time step
    v = tuple(
        grids.GridVariable(u.array + dudt * dt, u.bc)
        for u, dudt in zip(v, dvdt))
    # Pressure projection to incompressible velocity field
    v = pressure_projection(v, pressure_solve)
    return v
  return navier_stokes_step
```
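
To make the bookkeeping in this step concrete, here is a minimal plain-array sketch of the same first-order update. It assumes that `sum_fields` adds the acceleration tuples component-wise (written here with `jax.tree_util.tree_map`), replaces `GridVariable` with bare `jnp` arrays, and omits the pressure projection entirely; `euler_update` is a hypothetical helper name, not part of JAX-CFD.

```python
import jax
import jax.numpy as jnp


def sum_fields(*args):
    # Component-wise sum of several tuples of arrays (assumed behaviour of the
    # helper used in the step function above).
    return jax.tree_util.tree_map(lambda *a: sum(a), *args)


def euler_update(v, accelerations, dt):
    # Hypothetical helper: forward Euler, v_{n+1} = v_n + dt * sum(accelerations).
    # No pressure projection is applied in this sketch.
    dvdt = sum_fields(*accelerations)
    v_next = tuple(u + dudt * dt for u, dudt in zip(v, dvdt))
    return v_next, dvdt


v = (jnp.ones(4), jnp.zeros(4))                       # two velocity components
convection = (jnp.full(4, 0.5), jnp.full(4, -1.0))    # toy acceleration terms
diffusion = (jnp.full(4, 0.5), jnp.full(4, 3.0))
v_next, dvdt = euler_update(v, [convection, diffusion], dt=0.1)
```

Note that `dvdt` is already available as a local value inside the step; the difficulty raised in this issue is only that the step's signature returns `v` alone.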

This function returns only the velocity. A rollout then looks like this:

```python
step_fn = funcutils.repeated(
    collocated.equations.semi_implicit_navier_stokes(
        density=density, viscosity=viscosity, dt=dt, grid=grid),
    steps=inner_steps)
rollout_fn = jax.jit(funcutils.trajectory(step_fn, outer_steps))
_, trajectory = jax.device_get(rollout_fn(v0))
```

where v0 is a GridVariableVector.
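
As far as I understand it, `funcutils.repeated` applies the step `inner_steps` times, and `funcutils.trajectory` scans that over `outer_steps` and stacks each outer state. The sketch below re-implements that idea in plain JAX under hypothetical names (`repeated_sketch`, `trajectory_sketch`, `toy_step`); it is not the library's source, and the toy step (halving the velocity) stands in for the Navier-Stokes step.

```python
import jax
import jax.numpy as jnp


def repeated_sketch(step_fn, steps):
    # Apply step_fn `steps` times and keep only the final state.
    def wrapped(x):
        return jax.lax.fori_loop(0, steps, lambda _, s: step_fn(s), x)
    return wrapped


def trajectory_sketch(step_fn, steps):
    # Run step_fn `steps` times with lax.scan and stack every resulting state.
    def wrapped(x0):
        def body(carry, _):
            nxt = step_fn(carry)
            return nxt, nxt          # (new carry, value to stack)
        return jax.lax.scan(body, x0, xs=None, length=steps)
    return wrapped


def toy_step(v):
    # Stand-in for navier_stokes_step: halve each velocity component.
    return tuple(0.5 * u for u in v)


step_fn = repeated_sketch(toy_step, steps=2)          # inner_steps = 2
rollout_fn = jax.jit(trajectory_sketch(step_fn, 3))   # outer_steps = 3
v0 = (jnp.ones(2), 2.0 * jnp.ones(2))
final, trajectory = rollout_fn(v0)
```

Because the carry of `lax.scan` is exactly what the step returns, whatever the step emits is the only thing the trajectory can record, which is why extra outputs have to be smuggled into the state.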

My temp hack

Suppose we also want the time derivative as an output. My (ugly) workaround is to append dummy fields after the velocity in the state tuple and have the step return both:

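```python
  @jax.named_call
  def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
    # 2D state is (u, v, du/dt, dv/dt): keep the velocity, drop the old derivative.
    v, _ = (v[0], v[1]), (v[2], v[3])
    ...
    dvdt = sum_fields(*accelerations)
    ...
    v = pressure_projection(v, pressure_solve)
    # Note: dvdt is the acceleration computed before the pressure projection.
    # Wrap each component with the boundary condition of the matching velocity.
    dvdt = tuple(grids.GridVariable(dudt, u.bc) for u, dudt in zip(v, dvdt))
    # Tuple concatenation: return (u, v, du/dt, dv/dt).
    return v + dvdt
  return navier_stokes_step
```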

Then append dummy `GridVariable`s after `v0` and call the same step function in the rollout:

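```python
# Zero-filled placeholders with the same offsets and boundary conditions as v0.
vt0 = tuple(
    grids.GridVariable(
        grids.GridArray(jnp.zeros_like(u.data), u.offset, grid), u.bc)
    for u in v0)
v0 += vt0  # v0 is now (u, v, 0, 0)
```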
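
A possible tidier variant of this hack, sketched on plain arrays: since JAX loops accept any pytree, the state can be a nested tuple `(v, dvdt)` rather than a flat 4-tuple, which removes the manual `v[0], v[1]` slicing and works in any dimension. The zero-filled placeholder can be built with `jax.tree_util.tree_map(jnp.zeros_like, v0)`. I believe `GridVariable` is registered as a pytree whose only leaf is the data array, in which case the same `tree_map` call should yield zero-filled `GridVariable`s with the original offsets and boundary conditions, but that is an assumption I have not verified. `toy_step_with_aux` is hypothetical and uses a toy acceleration `du/dt = -u`.

```python
import jax
import jax.numpy as jnp


def toy_step_with_aux(state, dt=0.1):
    # state is a nested pytree (v, dvdt); the incoming dvdt is ignored.
    v, _ = state
    dvdt = tuple(-u for u in v)                 # toy acceleration: du/dt = -u
    v_next = tuple(u + dt * dudt for u, dudt in zip(v, dvdt))
    return (v_next, dvdt)


v0 = (jnp.ones(3), 2.0 * jnp.ones(3))
# Dummy derivative with the same structure, shapes and dtypes as v0.
dvdt0 = jax.tree_util.tree_map(jnp.zeros_like, v0)
state = (v0, dvdt0)

# JAX loop primitives accept this nested state unchanged.
final_v, final_dvdt = jax.lax.fori_loop(
    0, 2, lambda _, s: toy_step_with_aux(s), state)
```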

Question

Is there any template in JAX for doing this?
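
One general JAX pattern that may fit, offered as a sketch rather than as an existing JAX-CFD feature: the body of `jax.lax.scan` returns `(carry, y)`, where only `carry` is fed back into the next step and every `y` is stacked along a new leading axis. If the step returns the derivative as `y`, its input stays the velocity alone and no dummy zero-filled fields are needed. `step_with_aux` and `rollout_with_aux` are hypothetical names, and the toy acceleration `du/dt = -u` stands in for the real physics.

```python
import jax
import jax.numpy as jnp


def step_with_aux(v, dt=0.1):
    # Hypothetical step: returns the next state plus an auxiliary output
    # (here the time derivative) that is NOT fed back into the next step.
    dvdt = tuple(-u for u in v)
    v_next = tuple(u + dt * dudt for u, dudt in zip(v, dvdt))
    return v_next, dvdt


def rollout_with_aux(step_fn, steps):
    # lax.scan carries only the velocity and stacks (velocity, aux) per step.
    def body(v, _):
        v_next, aux = step_fn(v)
        return v_next, (v_next, aux)
    return lambda v0: jax.lax.scan(body, v0, xs=None, length=steps)


v0 = (jnp.ones(2),)
final_v, (traj_v, traj_dvdt) = jax.jit(rollout_with_aux(step_with_aux, 3))(v0)
```

One caveat: a state-to-state wrapper such as `funcutils.repeated` has no slot for `y`, so with this pattern the inner-step loop would also need to be written with `lax.scan` (or the auxiliary value taken from the last inner step only).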

@scaomath

scaomath commented Apr 20, 2024

@pnorgaard @kochkov92 Is outputting extra field variables something the team is considering adding?
