In [1]:
import jax
from jax import jit, vmap, make_jaxpr
from jax import lax
import jax.numpy as jnp
from jax import random

In [15]:
# Domain discretisation: the (x, y) plane is cut into square blocks of side
# `block_size`, with n_blocks_x blocks along x and n_blocks_y along y.
block_size = 0.5

n_blocks_x = 3
n_blocks_y = 3
# Identifier of the inlet block; is_inlet compares block indices against it.
inlet_idx = 3
# Per-block permeability table, looked up as A[x_idx, y_idx] by permeablity.
# Derived from the grid size so the table shape cannot drift from the grid.
A = jnp.arange(1, n_blocks_x * n_blocks_y + 1).reshape(n_blocks_x, n_blocks_y)

def get_block_idx(x, y):
    """Map coordinates to integer block indices along each axis.

    Each coordinate is floor-divided by the module-level ``block_size`` and
    cast to an integer dtype. No range check is applied, so coordinates
    outside the grid produce negative or too-large indices.

    Returns:
        (x_idx, y_idx): integer block indices for ``x`` and ``y``.
    """
    def to_index(coord):
        # floor (not truncation) so e.g. -0.1 lands in block -1, not block 0
        return jnp.floor_divide(coord, block_size).astype(int)

    return to_index(x), to_index(y)

def permeablity(a, x, y):
    """Look up the table entry for the block containing point (x, y).

    Args:
        a: per-block table, indexed as ``a[x_idx, y_idx]``.
        x, y: coordinates (scalars or arrays).

    Returns:
        The entry (or entries) of ``a`` for the block(s) the point(s) fall in.

    NOTE(review): indices are not range-checked here. JAX wraps negative
    indices and clamps out-of-bounds ones on retrieval, so points outside
    the grid silently pick up a wrapped/edge block's value.
    """
    block = get_block_idx(x, y)
    return a[block]

def is_inlet(x, y):
    """Return whether point (x, y) lies in the inlet block.

    The block pair (x_idx, y_idx) is flattened in row-major order,
    ``x_idx * n_blocks_y + y_idx`` -- the same layout used when a table of
    shape (n_blocks_x, n_blocks_y) is indexed as ``A[x_idx, y_idx]`` -- and
    compared against ``inlet_idx``. Points outside the grid are never inlet
    points; without that check an out-of-grid pair such as (0, 3) would
    alias flat index 3.

    Note: a product ``x_idx * y_idx`` is not a block identifier (on a 3x3
    grid its in-grid values are only 0, 1, 2, 4 and several blocks collide).

    Returns:
        Boolean JAX array (0-d for scalar inputs).
    """
    x_idx, y_idx = get_block_idx(x, y)

    in_grid = (
        (x_idx >= 0) & (x_idx < n_blocks_x)
        & (y_idx >= 0) & (y_idx < n_blocks_y)
    )
    flat_idx = x_idx * n_blocks_y + y_idx

    # A plain comparison traces fine under jit/vmap; lax.cond with constant
    # True/False branches added nothing and rejects non-scalar predicates.
    return in_grid & (flat_idx == inlet_idx)

def is_outlet(x, y):
    """Return whether point (x, y) lies in the outlet region.

    Not implemented yet. Raises instead of silently returning ``None``:
    ``None`` would read as falsy ("never an outlet") in Python code and is
    not a valid ``lax.cond`` predicate, hiding the missing implementation.

    Raises:
        NotImplementedError: always, until the outlet geometry is defined.
    """
    raise NotImplementedError("is_outlet: outlet region is not defined yet")

In [10]:
key = random.PRNGKey(0)

x = random.normal(key,(100000,))
y = random.normal(key,(100000,))

f1 = permeablity
f2 = vmap(permeablity,(None,0,0))
f3 = jit(f2)

g1 = is_inlet
g2 = vmap(is_inlet,(0,0))
g3 = jit(g2)

#%timeit f1(A,x,y).block_until_ready()
#%timeit f2(A,x,y).block_until_ready()
#%timeit f3(A,x,y).block_until_ready()

# %timeit g1(inlet_idx,x,y).block_until_ready()
%timeit g2(x,y).block_until_ready()
%timeit g3(x,y).block_until_ready()

2.91 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
208 µs ± 5.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [23]:
# Target for the predicted `v` at inlet points (see loss_inlet).
inlet_velocity: float = 1.0

def pred(a, x, y):
    """Placeholder model returning the fields ``(u, v, w, vx, wy)``.

    Only ``u`` depends on the inputs (``u = x * y * a``). The other four
    values are fixed constants standing in for a real model; ``vx`` and
    ``wy`` presumably represent derivative terms used by the divergence
    loss -- TODO confirm once the real model exists.

    Returns:
        Tuple ``(u, v, w, vx, wy)``.
    """
    u = x * y * a
    # Constant stand-ins for the not-yet-modelled outputs.
    v, w, vx, wy = 1.2, 1.0, 1.0, 1.0
    return (u, v, w, vx, wy)

def loss_inlet(v):
    """Squared error between ``v`` and the module-level ``inlet_velocity``.

    Args:
        v: predicted value at an inlet point.

    Returns:
        ``(v - inlet_velocity) ** 2``.
    """
    residual = v - inlet_velocity
    return residual ** 2

def loss_equations(vx, wy):
    """Squared residual of the sum ``vx + wy``.

    Used by ``loss`` as the always-on divergence term.

    Returns:
        ``(vx + wy) ** 2``.
    """
    divergence = vx + wy
    return divergence ** 2

def loss(x, y):
    """Pointwise training loss at (x, y).

    Adds the inlet term ``loss_inlet(v)`` when (x, y) lies in the inlet block,
    and always adds the divergence term ``loss_equations(vx, wy)``.
    Intended to be applied per point (e.g. via ``vmap``): ``lax.cond`` needs
    a scalar boolean predicate.

    Returns:
        The loss for this point.
    """
    a = permeablity(A, x, y)
    u, v, w, vx, wy = pred(a, x, y)

    # Start from a float so both lax.cond branches produce the same dtype.
    total = 0.0

    # Inlet boundary term. The false branch passes the running total through
    # unchanged; returning a constant 0.0 would discard any terms accumulated
    # before this point.
    total = lax.cond(
        is_inlet(x, y),
        lambda acc: acc + loss_inlet(v),
        lambda acc: acc,
        total,
    )

    # TODO: outlet boundary term once is_outlet is implemented.

    # Divergence residual, applied at every point.
    total += loss_equations(vx, wy)

    return total

l1 = loss
l2 = vmap(l1,(0,0))
l3 = jit(l2)

#%timeit l1(inlet_idx,x,y).block_until_ready()
%timeit l2(x,y).block_until_ready()
%timeit l3(x,y).block_until_ready()

9.31 ms ± 313 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
213 µs ± 4.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [16]:
# NOTE(review): leftover inspection cell -- it only displays the is_inlet
# function object. Consider deleting it, or inspecting the traced program
# instead, e.g. make_jaxpr(is_inlet)(0.75, 0.25).
is_inlet

<function __main__.is_inlet(x, y)>