In [1]:
import jax.numpy as jnp

In [2]:
def selu(x, alpha = 1.67, lmbda = 1.05):
    """Scaled exponential linear unit, applied elementwise.

    Computes ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` everywhere else.

    Args:
        x: Input scalar or array.
        alpha: Scale of the exponential (non-positive) branch.
        lmbda: Overall output scale.

    Returns:
        The activated values as a JAX array.
    """
    # expm1 evaluates exp(x) - 1 directly. Forming alpha * exp(x) - alpha
    # cancels to exactly 0 for small negative x in float32.
    negative_branch = alpha * jnp.expm1(x)
    return lmbda * jnp.where(x > 0, x, negative_branch)

In [3]:
x = jnp.arange(5.0)

In [5]:
selu(x)

Array([0.       , 1.05     , 2.1      , 3.1499999, 4.2      ], dtype=float32)

In [10]:
from jax import random

In [11]:
# Fixed seed so the sampled benchmark input is the same on every run.
key = random.key(1701)
# The shape is meant to be a tuple: `(1_000_000)` is only a parenthesized int,
# so the trailing comma is needed to spell a one-element shape.
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

8.22 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
from jax import jit
# Wrap selu with jax.jit to obtain a compiled version of the same function.
selu_jit = jit(selu)
# Call once and discard the result before timing -- presumably a warm-up so
# that one-off compilation cost is not part of the measurement below.
_ = selu_jit(x)
%timeit selu_jit(x).block_until_ready()

2.13 ms ± 346 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
# Derive two keys from `key` so the two samples below use different keys.
key1, key2 = random.split(key)
# A (150, 100) matrix and a batch of 10 vectors of length 100 to apply it to.
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

In [19]:
def apply_matrix(x):
    """Return ``jnp.dot(mat, x)`` using the notebook-level matrix ``mat``."""
    product = jnp.dot(mat, x)
    return product

In [20]:
def naively_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to each entry of ``v_batched`` in a Python loop.

    Args:
        v_batched: Iterable of inputs; iterating it yields one input per call.

    Returns:
        The per-entry results combined with ``jnp.stack``.
    """
    per_entry_results = []
    for v in v_batched:
        per_entry_results.append(apply_matrix(v))
    return jnp.stack(per_entry_results)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
4.91 ms ± 477 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
import jax

In [22]:
# Note: rebinds `x`, which earlier cells used for other arrays.
# Here it is the 5-element signal 0..4 for the convolution example.
x = jnp.arange(5)
# Three-tap kernel applied to each window of `x`.
w = jnp.array([2.,3.,4.])

In [23]:
def convolve(x,w):
    """Slide the kernel ``w`` over the 1-D signal ``x`` and take dot products.

    Each output element is ``jnp.dot(x[i:i + len(w)], w)`` for every window
    that fits entirely inside ``x`` (no padding, kernel not flipped).

    Args:
        x: 1-D array holding the signal.
        w: 1-D array holding the kernel; any length, not only 3.

    Returns:
        1-D array with ``len(x) - len(w) + 1`` entries, or an empty array
        when the kernel is longer than the signal.
    """
    width = len(w)
    # The window width follows the kernel instead of being fixed at 3, so
    # kernels of other lengths no longer fail with a shape mismatch.
    windows = [x[start:start + width] for start in range(len(x) - width + 1)]
    return jnp.array([jnp.dot(window, w) for window in windows])

In [24]:
convolve(x, w)

Array([11., 20., 29.], dtype=float32)

In [26]:
# Build a batch of two identical signals and two identical kernels by
# stacking along a new leading axis.
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

In [37]:
def manually_batch_convolve(xs, ws):
    """Convolve matching rows of ``xs`` and ``ws`` one at a time.

    Args:
        xs: Batch of signals, indexed along the leading axis.
        ws: Batch of kernels, indexed along the leading axis.

    Returns:
        The per-row ``convolve`` results combined with ``jnp.stack``.
    """
    batch_size = xs.shape[0]
    per_row = [convolve(xs[row], ws[row]) for row in range(batch_size)]
    return jnp.stack(per_row)
        

In [38]:
manually_batch_convolve(xs, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [34]:
# Let jax.vmap produce the batched version of `convolve` instead of looping
# over the batch by hand.
auto_batch_convolve = jax.vmap(convolve)

In [None]:
auto_batch_convolve(xs, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [39]:
import jax
from jax import numpy as jnp
from jax import grad

In [40]:
grad_tnh = grad(jnp.tanh)

In [42]:
grad_tnh(2.0)

Array(0.07065082, dtype=float32, weak_type=True)

In [43]:
key = jax.random.key(0)
def sigmoid(x):
    """Logistic sigmoid written in terms of ``tanh``: ``0.5 * (tanh(x / 2) + 1)``."""
    half_x = x / 2
    shifted = jnp.tanh(half_x) + 1
    return 0.5 * shifted

def predict(w, b, inputs):
    """Return ``sigmoid(jnp.dot(inputs, w) + b)``.

    Args:
        w: Weights combined with ``inputs`` via ``jnp.dot``.
        b: Bias added to the dot product.
        inputs: Input features.
    """
    pre_activation = jnp.dot(inputs, w) + b
    return sigmoid(pre_activation)

In [47]:
# Toy dataset: a (4, 3) array of normal samples and one boolean label per row.
inputs =  jax.random.normal(key = key, shape = (4,3))
targets = jnp.array([True, True, False, True])

def loss(W, b):
    """Negative log-likelihood of the notebook-level ``targets``.

    Predictions come from ``predict(W, b, inputs)``, where ``inputs`` and
    ``targets`` are read from the enclosing scope.

    Args:
        W: Weights passed to ``predict``.
        b: Bias passed to ``predict``.

    Returns:
        ``-sum(log(p))``, where ``p`` is the predicted probability assigned
        to each example's actual label.
    """
    preds = predict(W, b, inputs)
    # Probability given to the observed label: `preds` where the target is
    # True, `1 - preds` where it is False.
    true_term = preds * targets
    false_term = (1 - preds) * (1 - targets)
    label_probs = true_term + false_term
    log_likelihood = jnp.sum(jnp.log(label_probs))
    return -log_likelihood

In [49]:
# Split into three keys: `key` is rebound, and W and b each get their own.
key, W_key, b_key = jax.random.split(key, 3)
# Length-3 weight vector and scalar bias (shape ()).
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

In [50]:
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')

W_grad=Array([-1.2433283 ,  1.3099571 , -0.96137655], dtype=float32)
W_grad=Array([-1.2433283 ,  1.3099571 , -0.96137655], dtype=float32)
b_grad=Array(-1.7784319, dtype=float32)
W_grad=Array([-1.2433283 ,  1.3099571 , -0.96137655], dtype=float32)
b_grad=Array(-1.7784319, dtype=float32)


In [51]:
# value_and_grad returns the loss together with the gradients for arguments
# 0 and 1 (W and b) from a single call.
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
# Recompute the loss directly for comparison with the value printed above.
print('loss value', loss(W, b))


loss value 4.0585675
loss value 4.0585675


In [52]:
import jax
from jax import numpy as jnp

In [53]:
@jax.jit
def f(x):
    """Return ``jnp.sin(x)``, calling Python ``print`` on the input and result.

    The recorded cell output shows ``Traced<ShapedArray(...)>`` objects for
    both prints rather than numbers, i.e. under ``jax.jit`` the built-in
    ``print`` sees tracer placeholders instead of concrete values.
    """
    print("print(x ->)", x)
    y = jnp.sin(x)
    print("print(y ->)", y)
    return y

result = f(2.)

print(x ->) Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y ->) Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [59]:
@jax.jit
def f(x):
    """Return ``jnp.sin(x)``, reporting input and result with ``jax.debug.print``.

    Rebinds the name ``f`` from the earlier cell. Unlike the built-in
    ``print`` used there, the recorded output here shows concrete numbers.
    """
    # The format placeholders are filled from the keyword arguments.
    jax.debug.print("jax debug print -> {x}", x = x)
    y = jnp.sin(x)
    jax.debug.print("jax debug print -> {y}", y =y)
    return y
    
result = f(2.)

jax debug print -> 2.0
jax debug print -> 0.9092974066734314


In [67]:
key = jax.random.key(10)

In [68]:
# Draw three samples, using a fresh subkey for each draw.
for i in range(3):
    # Split the current key into a key to carry forward and a subkey to use now.
    new_key, subkey = random.split(key)
    # Drop the old name -- presumably so the already-split key cannot be
    # reused by accident.
    del key
    
    val = random.normal(subkey)
    # The subkey has been used for its one draw; discard it as well.
    del subkey
    
    print(f"draw {i}: {val}")
    # Carry the unused half of the split into the next iteration.
    key = new_key

draw 0: 0.7978776097297668
draw 1: 0.2311430275440216
draw 2: 0.8755434155464172


In [69]:
# A handful of nested Python containers (plus one array) whose leaves are
# inspected in the cells below.
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

In [70]:
example_trees

[[1, 'a', <object at 0x725632185260>],
 (1, (2, 3), ()),
 [1, {'k1': 2, 'k2': (3, 4)}, 5],
 {'a': 2, 'b': (2, 3)},
 Array([1, 2, 3], dtype=int32)]

In [73]:
for pytree in example_trees:
    # Flatten each container into its list of leaves.
    leaves = jax.tree.leaves(pytree)
    # NOTE(review): the recorded output shows the int leaves as Array(...)
    # while 'a' and object() print unchanged. Presumably jax.debug.print
    # converts numeric values to arrays for display; confirm, and consider
    # plain print(leaves) to show the raw leaves.
    jax.debug.print('leaves is {a}', a = leaves)

leaves is [Array(1, dtype=int32, weak_type=True), 'a', <object object at 0x725632185260>]
leaves is [Array(1, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(3, dtype=int32, weak_type=True)]
leaves is [Array(1, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(3, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(5, dtype=int32, weak_type=True)]
leaves is [Array(2, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(3, dtype=int32, weak_type=True)]
leaves is [Array([1, 2, 3], dtype=int32)]


In [74]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

In [75]:
jax.tree.map(lambda x: x**2, list_of_lists)

[[1, 4, 9], [1, 4], [1, 4, 9, 16]]

In [76]:
import numpy as np

In [81]:
def init_mlp_params(layer_widths, rng = None):
    """Create randomly initialised parameters for a fully connected network.

    Args:
        layer_widths: Sequence of layer sizes; each consecutive pair
            ``(n_in, n_out)`` produces one layer.
        rng: Optional ``numpy.random.Generator`` used to sample the weights,
            which makes the result reproducible. When omitted, NumPy's global
            random state is used, exactly as before.

    Returns:
        A list with one dict per layer, holding ``'weights'`` of shape
        ``(n_in, n_out)`` (normal samples scaled by ``sqrt(2 / n_in)``) and
        ``'biases'`` of shape ``(n_out,)`` filled with ones.
    """
    # Fall back to the legacy global sampler so existing callers see no change.
    sample_normal = np.random.normal if rng is None else rng.normal
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(dict(
            weights = sample_normal(size = (n_in, n_out)) * np.sqrt(2/n_in),
            biases = np.ones(shape = (n_out,)),
        ))
    return params

In [82]:
params = init_mlp_params([1, 128, 128,1])

In [83]:
*hidden, last = params

In [87]:
def forward(params, x):
    """Run the multilayer perceptron described by ``params`` on ``x``.

    Args:
        params: Sequence of layer dicts with ``'weights'`` and ``'biases'``
            entries. Every layer except the last is followed by a ReLU.
        x: Input whose trailing dimension matches the first layer's
            ``weights.shape[0]``.

    Returns:
        The output of the final layer, which is affine with no activation.
    """
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
    # Activations go on the left, as in the hidden layers. With weights laid
    # out as (n_in, n_out), `weights @ x` either fails or yields the wrong shape.
    return x @ last['weights'] + last['biases']

In [88]:
def loss_fn(params, x, y):
    """Mean squared error between ``forward(params, x)`` and ``y``."""
    residual = forward(params, x) - y
    return jnp.mean(residual ** 2)

# Step size multiplied into the gradients by `update`.
learning_rate = 0.0001

In [89]:
@jax.jit
def update(params, x, y):
    """Take one gradient-descent step on ``loss_fn``.

    Args:
        params: Pytree of parameters accepted by ``loss_fn``.
        x: Inputs forwarded to ``loss_fn``.
        y: Targets forwarded to ``loss_fn``.

    Returns:
        A pytree structured like ``params`` holding ``p - learning_rate * g``
        for every parameter leaf ``p`` and its gradient ``g``.
    """
    def descend(param, gradient):
        # `learning_rate` is read from the enclosing notebook scope.
        return param - learning_rate * gradient

    gradients = jax.grad(loss_fn)(params, x, y)
    return jax.tree.map(descend, params, gradients)