In [1]:
# Reload edited modules (e.g. jaxdf) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from jax import random
from jaxdf.geometry import Domain
import jax
from jax import numpy as jnp

# Master PRNG key (seed 42) split into 20 sub-keys; the cells below draw
# random fields from seeds[0] and seeds[1].
seeds = random.split(random.PRNGKey(42), num=20)

# Domain shared by every discretization below. Presumably (grid points per
# axis, exposed as `domain.N`) and (grid spacing) — TODO confirm against
# `jaxdf.geometry.Domain`.
domain = Domain((32, 35), (0.5, 0.6))

# Evaluation point for the pointwise examples.
x = jnp.array([1.0, 2.0])

In [3]:
from jaxdf.core import operator, Field

In [4]:
from jaxdf.discretization import Arbitrary
from jax import numpy as jnp
# `stax` lives in `jax.example_libraries` in current JAX releases; older
# releases only provide `jax.experimental.stax`. Try the new location first.
try:
    from jax.example_libraries import stax
except ImportError:
    from jax.experimental import stax

# Small MLP: Dense(32) -> ReLU -> Dense(1), i.e. one output value per point.
init_random_params, predict = stax.serial(
    stax.Dense(32), stax.Relu,
    stax.Dense(1))


def init_params(seed, domain):
    """Draw random MLP parameters for a field defined on `domain`.

    Args:
        seed: PRNG key forwarded to the stax initializer (e.g. an entry of `seeds`).
        domain: object exposing `N`; the network input size is `len(domain.N)`.

    Returns:
        The network parameters. stax initializers return
        ``(output_shape, params)``, so only the second element is kept.
    """
    return init_random_params(seed, (len(domain.N),))[1]


def get_fun(params, x):
    """Evaluate the MLP with parameters `params` at the point `x`."""
    return predict(params, x)

# Wrap the MLP as an `Arbitrary` discretization and draw two random
# parameter sets from it: `u` from seeds[0], `v` from seeds[1].
arbitrary_discr = Arbitrary(domain, get_fun, init_params)

u_arb_params = arbitrary_discr.random_field(seeds[0])
u_arbitrary = Field(arbitrary_discr, name='u', params=u_arb_params)

v_arb_params = arbitrary_discr.random_field(seeds[1])
v_arbitrary = Field(arbitrary_discr, name='v', params=v_arb_params)

In [5]:
from jaxdf.discretization import RealFourierSeries
# Same pair of test fields, this time on a real Fourier-series
# discretization: `u` from seeds[0], `v` from seeds[1].
fourier_discr = RealFourierSeries(domain)

u_fourier_params = fourier_discr.random_field(seeds[0])
u_fourier = Field(fourier_discr, name='u', params=u_fourier_params)

v_fourier_params = fourier_discr.random_field(seeds[1])
v_fourier = Field(fourier_discr, name='v', params=v_fourier_params)

# `__call__` (non-jittable)

In [6]:
# Evaluate the MLP-backed field directly at the point `x`.
u_arbitrary(x)

DeviceArray([-0.91762376], dtype=float32)

In [7]:
# Same pointwise evaluation for the Fourier-series field.
u_fourier(x)

DeviceArray([-0.2871864], dtype=float32)

# `get_field` (jittable)

In [8]:
# `get_field()` returns a function of (params, x): parameters are passed
# explicitly instead of being read from the Field object.
u_arbitrary.get_field()(u_arb_params, x)

DeviceArray([-0.91762376], dtype=float32)

In [9]:
# Same explicit-parameter evaluation for the Fourier-series field.
u_fourier.get_field()(u_fourier_params, x)

DeviceArray([-0.2871864], dtype=float32)

# `add`: sum of two fields (`u + v`)

In [10]:
# Value of the second operand `v` at `x`, as a reference for checking `u + v`.
v_arbitrary.get_field()(v_arb_params, x)

DeviceArray([-0.3396132], dtype=float32)

In [11]:
@operator()
def new_op(u, v):
    """Sum of the two input fields, ``u + v``.

    Calling ``new_op(u=..., v=...)`` returns an object exposing
    ``get_global_params()`` and ``get_field(i)``, as used in the cells below.
    """
    return u + v

In [12]:
# Build `u + v` on the Arbitrary discretization and evaluate it at `x`.
out_field = new_op(u=u_arbitrary, v=v_arbitrary)
global_params = out_field.get_global_params()

sum_fn = out_field.get_field(0)
arb_inputs = {"u": u_arb_params, "v": v_arb_params}
sum_fn(global_params, arb_inputs, x)

DeviceArray([-1.257237], dtype=float32)

In [13]:
# Trace the summed field to inspect the jaxpr JAX builds for it.
f = out_field.get_field(0)
trace_fn = jax.make_jaxpr(f)
trace_fn(global_params, {"u": u_arb_params, "v": v_arb_params}, x)

{ lambda  ; a b c d e f g h i.
  let j = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] i a
      k = add j b
      l = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f0514335820>
                                 num_consts=0 ] k
      m = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] l c
      n = add m d
      o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] i e
      p = add o f
      q = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                           

In [14]:
# `u + v` on the Fourier discretization, evaluated at `x`. Each keyword is
# bound to its own field so the operator inputs line up with the
# `{"u": ..., "v": ...}` parameters passed below.
out_field = new_op(u=u_fourier, v=v_fourier)
global_params = out_field.get_global_params()
out_field.get_field(0)(
    global_params,
    {"u": u_fourier_params, "v": v_fourier_params},
    x
)

DeviceArray([0.25046578], dtype=float32)

In [15]:
# On the grid, the traced Fourier `u + v` is a single `add` (see output).
f = out_field.get_field_on_grid(0)
fourier_inputs = {"u": u_fourier_params, "v": v_fourier_params}
jax.make_jaxpr(f)(global_params, fourier_inputs)

{ lambda  ; a b.
  let c = add a b
  in (c,) }

In [16]:
# Pointwise evaluation of the Fourier sum traces to a much larger jaxpr
# (broadcast / scatter ops) than evaluating on the grid.
f = out_field.get_field(0)
fourier_inputs = {"u": u_fourier_params, "v": v_fourier_params}
jax.make_jaxpr(f)(global_params, fourier_inputs, x)

{ lambda a b c d e ; f g h.
  let i = add f g
      j = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      k = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      l = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      m = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=True
                   unique_indices=True
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] j l k
      n = convert_element_type[ new_dtype=float32
                                weak_type=False ] b
      o = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 16
      p = scatter[ dimension_numbers=ScatterDimen

# `add_scalar`: field plus a constant (`u + 2`)

In [25]:
@operator()
def add_two(u):
    """Add the constant 2 to the field `u`.

    In the traced jaxprs below, the constant enters as an input argument
    (converted to float32 and added), not as a literal — presumably it is
    stored in the operator's global parameters; see ``get_global_params()``.
    """
    return u + 2

In [26]:
# `u + 2` on the Arbitrary discretization, evaluated at `x`.
out_field = add_two(u=u_arbitrary)
global_params = out_field.get_global_params()
shifted_fn = out_field.get_field(0)
shifted_fn(global_params, {"u": u_arb_params}, x)

DeviceArray([1.0823762], dtype=float32)

In [27]:
# Keep a handle on the field function and call it directly.
f = out_field.get_field(0)
global_params = out_field.get_global_params()
u_inputs = {"u": u_arb_params}
f(global_params, u_inputs, x)

DeviceArray([1.0823762], dtype=float32)

In [28]:
from jax import make_jaxpr

# Trace `u + 2` on the MLP field; the constant shows up as input `a`.
traced = make_jaxpr(f)(global_params, {"u": u_arb_params}, x)
print(traced)

{ lambda  ; a b c d e f.
  let g = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] f b
      h = add g c
      i = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f051428d0d0>
                                 num_consts=0 ] h
      j = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] i d
      k = add j e
      l = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      m = add k l
  in (m,) }


In [29]:
# `u + 2` on the Fourier discretization, evaluated at `x`.
out_field = add_two(u=u_fourier)
global_params = out_field.get_global_params()
shifted_fn = out_field.get_field(0)
shifted_fn(global_params, {"u": u_fourier_params}, x)

DeviceArray([1.712813], dtype=float32)

In [30]:
# Keep a handle on the field function and call it directly.
f = out_field.get_field(0)
global_params = out_field.get_global_params()
u_inputs = {"u": u_fourier_params}
f(global_params, u_inputs, x)

DeviceArray([1.712813], dtype=float32)

In [31]:
# Trace `u + 2` on the Fourier field evaluated at `x`.
traced = make_jaxpr(f)(global_params, {"u": u_fourier_params}, x)
print(traced)

{ lambda a b c d e ; f g h.
  let i = convert_element_type[ new_dtype=float32
                                weak_type=False ] f
      j = add g i
      k = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      l = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      m = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      n = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=True
                   unique_indices=True
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] k m l
      o = convert_element_type[ new_dtype=float32
                                weak_type=False ] b
      p = broadcast_in_dim[ broadcast_dimensi

# `elementwise`: pointwise functions of a field (`tanh(u) + 2`)

In [32]:
from jaxdf.operators import elementwise
from jax import numpy as jnp

# Lift `jnp.tanh` into a field operator applied elementwise.
Tanh = elementwise(jnp.tanh)

@operator()
def custom(u):
    """Composite operator: elementwise tanh followed by adding 2, ``tanh(u) + 2``."""
    return Tanh(u) + 2

In [33]:
# `tanh(u) + 2` on the Arbitrary discretization, evaluated at `x`.
out_field = custom(u=u_arbitrary)
global_params = out_field.get_global_params()
custom_fn = out_field.get_field(0)
custom_fn(global_params, {"u": u_arb_params}, x)

DeviceArray([1.2752286], dtype=float32)

In [34]:
# Show the operator graph: input fields, global parameters and the sequence
# of operations (Elementwise, then AddScalar).
print(out_field)

DiscretizedOperator :: [Arbitrary], ['_mO'] 

 Input fields: ('u',)

Globals: {'shared': {}, 'independent': {'AddScalar_l5': {'scalar': 2}}}

Operations:
- _k5: Arbitrary <-- Elementwise ('u',) | (none) Elementwise
- _mO: Arbitrary <-- AddScalar ('_k5',) | (independent) AddScalar_l5



In [35]:
from jax import make_jaxpr

# Trace `tanh(u) + 2` on the MLP field: tanh is applied to the network output.
f = out_field.get_field(0)
traced = make_jaxpr(f)(global_params, {"u": u_arb_params}, x)
print(traced)

{ lambda  ; a b c d e f.
  let g = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] f b
      h = add g c
      i = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f0514238160>
                                 num_consts=0 ] h
      j = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] i d
      k = add j e
      l = tanh k
      m = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      n = add l m
  in (n,) }


In [36]:
# `tanh(u) + 2` on the Fourier discretization, evaluated at `x`.
# Note: tanh is applied to the grid values before interpolating to `x`
# (see the jaxpr below), so the result need not equal
# `tanh(u_fourier(x)) + 2` exactly.
out_field = custom(u=u_fourier)
global_params = out_field.get_global_params()
custom_fn = out_field.get_field(0)
custom_fn(global_params, {"u": u_fourier_params}, x)

DeviceArray([1.718648], dtype=float32)

In [37]:
from jax import make_jaxpr

# Pointwise trace: tanh acts on the grid parameters first, then the
# Fourier interpolation to `x` follows.
f = out_field.get_field(0)
traced = make_jaxpr(f)(global_params, {"u": u_fourier_params}, x)
print(traced)

{ lambda a b c d e ; f g h.
  let i = tanh g
      j = convert_element_type[ new_dtype=float32
                                weak_type=False ] f
      k = add i j
      l = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      m = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      n = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      o = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=True
                   unique_indices=True
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] l n m
      p = convert_element_type[ new_dtype=float32
                                weak_type=False ] b
      q = broadcast_in_dim[ 

In [38]:
from jax import make_jaxpr

# On the grid the whole operator reduces to `tanh` plus a scalar `add`.
f = out_field.get_field_on_grid(0)
traced = make_jaxpr(f)(global_params, {"u": u_fourier_params})
print(traced)

{ lambda  ; a b.
  let c = tanh b
      d = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      e = add c d
  in (e,) }
