In [2]:
import jax
from jax import jit, vmap, make_jaxpr
from jax import lax
import jax.numpy as jnp
from jax import random

In [100]:
block_size = 0.5

n_blocks_x = 3
n_blocks_y = 3
inlet_idx = 3
outlet_idx = [2, 5, 8]
A = jnp.arange(1,10).reshape(3,3)

def get_block_idx(x,y, n_blocks_x):
    x_idx = jnp.floor_divide(x,block_size).astype(int)
    y_idx = jnp.floor_divide(y,block_size).astype(int)

    return 
    return x_idx, y_idx

def permeablity(a, x,y):
    x_idx, y_idx = get_block_idx(x,y)
    return a[x_idx,y_idx]

def is_inlet(x,y):
    x_idx, y_idx = get_block_idx(x,y)
    c1 = lax.cond(x_idx*y_idx == inlet_idx, lambda _: 1, lambda _: 0,None)
    c2 = lax.cond(x == 0.0, lambda _: 1, lambda _: 0, None)
    return c1*c2 != 0

def is_outlet(x,y):
    x_idx, y_idx = get_block_idx(x,y)
    c1 = lax.cond(x_idx*y_idx == inlet_idx, lambda _: 1, lambda _: 0, (x_idx,y_idx))
    c2 = lax.cond(x == n_blocks_x*block_size, lambda _: 1, lambda _: 0, (x_idx,y_idx))
    return c1*c2 != 0

In [112]:
get_block_idx(0.0,0.75)
jnp.floor_divide(0.75,0.5)
#is_inlet(0.0,0.75)

DeviceArray(1., dtype=float32)

In [108]:
key = random.PRNGKey(0)

x = random.normal(key,(100000,))
y = random.normal(key,(100000,))

f1 = permeablity
f2 = vmap(permeablity,(None,0,0))
f3 = jit(f2)

g1 = is_inlet
g2 = vmap(is_inlet,(0,0))
g3 = jit(g2)

#%timeit f1(A,x,y).block_until_ready()
#%timeit f2(A,x,y).block_until_ready()
#%timeit f3(A,x,y).block_until_ready()

# %timeit g1(inlet_idx,x,y).block_until_ready()
%timeit g2(x,y).block_until_ready()
%timeit g3(x,y).block_until_ready()

6.8 ms ± 165 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
220 µs ± 4.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


DeviceArray(True, dtype=bool, weak_type=True)

In [106]:
inlet_velocity = 1.0

def pred(a, x,y):
    u = x*y*a
    v = 1.2
    w = 1.0
    vx = 1.0
    wy = 1.0

    return u,v,w, vx,wy

def loss_inlet(v):
    return (v-inlet_velocity)**2

def loss_outlet(u):
    return (u-0.0)**2

def loss_equations(vx,wy):
    return (vx+wy)**2

def loss(x,y):
    a = permeablity(A,x,y)
    u,v,w, vx,wy = pred(a,x,y)

    loss = 0.0
    # if inlet calculate loss
    loss = lax.cond(is_inlet(x,y),lambda loss: loss + loss_inlet(v), lambda _: loss, loss)
    # if outlet calculate loss
    loss = lax.cond(is_outlet(x,y),lambda loss: loss + loss_outlet(v), lambda _: loss, loss)
    # always check divergence
    loss += loss_equations(vx,wy)

    return loss

l1 = loss
l2 = vmap(l1,(0,0))
l3 = jit(l2)

#%timeit l1(inlet_idx,x,y).block_until_ready()
%timeit l2(x,y).block_until_ready()
%timeit l3(x,y).block_until_ready()

20.5 ms ± 560 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
212 µs ± 2.53 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [91]:
key = random.PRNGKey(0)
x = random.normal(key,(100000,))
y = random.normal(key,(100000,))

def p1(x,y):
    return x > 10 and y > 10

def p2(x,y):
    c1 = lax.cond(x > 10 and y > 10, lambda _: True, lambda _: False, None)
    return c1

def p3(x,y):
    c1 = lax.cond(x > 10 * y > 10, lambda _: True, lambda _: False, None)
    return c1

def p4(x,y):
    c1 = lax.cond(x > 10, lambda _: True, lambda _: False, None)
    c2 = lax.cond(y > 10, lambda _: True, lambda _: False, None)
    return c1 and c2

def p5(x,y):
    c1 = lax.cond(x > 10, lambda _: True, lambda _: False, None)
    c2 = lax.cond(y > 10, lambda _: True, lambda _: False, None)
    return c1*c2

def p6(x,y):
    c1 = lax.cond(x > 10, lambda _: 1, lambda _: 0, None)
    c2 = lax.cond(y > 10, lambda _: 1, lambda _: 0, None)
    return (c1*c2).astype(bool)
    # return c1*c2 != 0

def p7(x,y):
    
    p = lambda x,y : x > 10 and y > 10

    return lax.cond(p(x,y),lambda _: True, lambda _: False, None)

#pv1 = jit(vmap(p1,(0,0)))
#pv2 = jit(vmap(p2,(0,0)))
#pv3 = jit(vmap(p3,(0,0)))
#pv4 = jit(vmap(p4,(0,0)))
pv5 = jit(vmap(p5,(0,0)))
pv6 = jit(vmap(p6,(0,0)))
#pv7 = jit(vmap(p7,(0,0)))

#%timeit pv1(x,y).block_until_ready()
#%timeit pv2(x,y).block_until_ready()
#%timeit pv3(x,y).block_until_ready()
#%timeit pv4(x,y).block_until_ready()
%timeit pv5(x,y).block_until_ready()
%timeit pv6(x,y).block_until_ready()
#%timeit pv7(x,y).block_until_ready()

13.4 µs ± 34.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
67.1 µs ± 578 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
