In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, vjp
from jax import random
import jax

In [237]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

110 ms ± 8.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

107 ms ± 8.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

548 µs ± 91 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

157 µs ± 9.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

In [8]:
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [9]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


In [10]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256


In [14]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

In [36]:
def f(x, y):
  return jnp.append(  )

In [78]:
def f(x):
  return jnp.array( [  [ x[0,0]*x[0,0] ], [x[0,0]*x[0,0]*x[0,0]]  ] )

In [90]:
primals, f_vjp = vjp(f, jnp.array([[3.22]]))

In [91]:
primals

Array([[10.368401],
       [33.38625 ]], dtype=float32)

In [92]:
f_vjp

Partial(_HashableCallableShim(functools.partial(<function _vjp_pullback_wrapper at 0x7fc0132f4860>, 'f', [dtype('float32')], [(2, 1)], (PyTreeDef(*), PyTreeDef((*,))))), Partial(_HashableCallableShim(functools.partial(<function vjp.<locals>.unbound_vjp at 0x7fc0001c4a40>, [(ShapedArray(float32[2,1]), None)], { lambda a:i32[2] b:i32[2] c:f32[] d:f32[] e:i32[2] f:i32[2] g:f32[] h:f32[] i:i32[2]
    j:f32[] k:f32[]; l:f32[1,1]. let
    m:f32[] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=True
    ] l a
    n:f32[] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      u

In [93]:
f_vjp(jnp.array([[1.0],[0.0]]))

(Array([[6.44]], dtype=float32),)

In [97]:
x = 3.22
3*x*x

31.105200000000004

In [98]:
2*x

6.44

In [99]:
f_vjp(jnp.array([[0.0],[1.0]]))

(Array([[31.105202]], dtype=float32),)

In [100]:
primals, tangents = vjp(f, jnp.array([[3.22]]))

In [101]:
primals

Array([[10.368401],
       [33.38625 ]], dtype=float32)

In [103]:
tangents(jnp.array([[0.0],[1.0]]))

(Array([[31.105202]], dtype=float32),)

In [105]:
grad( jnp.tanh )(2.0)

Array(0.07065082, dtype=float32, weak_type=True)

In [108]:
grad( jnp.linalg.norm )( jnp.array([[1.0, 2.0],[3.0, 5.0]]) )

Array([[0.16012816, 0.32025632],
       [0.48038447, 0.8006408 ]], dtype=float32)

In [111]:
grad( jnp.linalg.inv )( jnp.array([[1.0, 2.0],[3.0, 5.0]]) )

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2, 2).

In [115]:
def inverse_first(x):
    return jnp.linalg.inv(x)[0,0]

In [125]:
grad_inverse = grad( inverse_first )

In [123]:
%timeit grad_inverse( jnp.array([[1.0, 2.0],[3.0, 5.0]]) )

94.2 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [126]:
%timeit grad_inverse( jnp.array([[1.0, 2.0],[3.0, 5.0]]) )

2.03 ms ± 44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [121]:
jnp.linalg.inv(jnp.array([[1.0, 2.0],[3.0, 5.0]]))

Array([[-5.0000014,  2.0000005],
       [ 3.0000007, -1.0000002]], dtype=float32)

In [127]:
def loop_example(x):
    for i in range(1000):
        x = x + 1
    return x

In [128]:
%timeit loop_example(2)

19.1 µs ± 806 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [129]:
loop_example_jit = jit( loop_example )

In [132]:
loop_example_jit(2)

Array(1002, dtype=int32, weak_type=True)

In [131]:
%timeit loop_example_jit(2)

2.92 µs ± 91.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [133]:
def if_example(x):
    for i in range(1000):
        if x<10:
            x = x + 1
        else:
            x = x - 0.5
    return x

In [134]:
%timeit if_example(2)

38.7 µs ± 979 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [137]:
gradd = grad(if_example)

In [140]:
gradd(2.0)

Array(1., dtype=float32, weak_type=True)

In [141]:
if_example_jit = jit( if_example )
if_example_jit(2)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 
The error occurred while tracing the function if_example at /tmp/ipykernel_19980/625940139.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [156]:
def func_1(x):
    return jnp.sin(x)

def func_main(x):
    for i in range(1000):
        x = x + func_1(x)

In [157]:
func_main_jit = jit(func_main)

In [158]:
func_main_jit(2)

In [159]:
%timeit func_main(2)

6.28 ms ± 88.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [160]:
%timeit func_main_jit(2)

122 µs ± 1.83 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [161]:
%timeit func_main_jit(2)

121 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [162]:
def generate_psd_params():
    n = 4
    N = 50
    
    diag = np.random.rand(n) + n
    off_diag = np.random.rand(int( (n**2-n)/2.0 ))
    params = np.append(diag, off_diag, axis = 0).reshape(1,-1)
    
    for i in range(1,50):
        # Diagonal elements
        params_temp = np.random.rand( int(n + (n**2 -n)/2.0) ).reshape(1,-1)
        
        # ## lower Off-diagonal
        # off_diag = np.random.rand(int( (n**2-n)/2.0 ))
        
        # params_temp = np.append(diag, off_diag, axis = 0).reshape(1,-1)
        params = np.append( params, params_temp, axis = 0 )
    
    return params

In [163]:
sigma = generate_psd_params()

In [165]:
sigma.shape

(50, 10)

In [166]:
from jax import random
key = random.PRNGKey(0)

In [169]:
random.uniform(key)

Array(0.41845703, dtype=float32)

In [170]:
random.uniform(key)

Array(0.41845703, dtype=float32)

In [171]:
random.uniform(key, shape=(3,2))

Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)

In [176]:
random.uniform(key, shape=(3,1))[:,0]

Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32)

In [183]:
a = jnp.array([[ 1, 1, 1 ],[1, 1, 1]])
b = jnp.array([[1,2,3]])

In [184]:
a

Array([[1, 1, 1],
       [1, 1, 1]], dtype=int32)

In [185]:
b

Array([[1, 2, 3]], dtype=int32)

In [186]:
a * b

Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)

In [190]:
jnp.sum(a*b, axis=1)

Array([6, 6], dtype=int32)

In [195]:
x = jnp.array([[1],[1]])

In [196]:
jnp.repeat( x, 3, axis=1 )

Array([[1, 1, 1],
       [1, 1, 1]], dtype=int32)

In [197]:
x.shape

(2, 1)

In [198]:
x

Array([[1],
       [1]], dtype=int32)

In [199]:
a

Array([[1, 1, 1],
       [1, 1, 1]], dtype=int32)

In [201]:
a[:,1]

Array([1, 1], dtype=int32)

In [250]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [251]:
def f1(x):
    return jax.scipy.linalg.sqrtm(x)[0,0]

In [252]:
f1grad = grad(f1)

In [253]:
f1_jit = jit(f1)
f1grad_jit = jit(f1grad)

In [254]:
f1(jnp.array([[1.0]]))

Array(1.+0.j, dtype=complex64)

In [255]:
f1grad(jnp.array([[1.0]]))

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "<frozen runpy>", line 173, in _run_module_as_main
  File "<frozen runpy>", line 65, in _run_code
  File "/home/hardik/Desktop/Research/FORESEE/venv311/lib/python3.11/site-packages/ipykernel_launcher.py", line 0, in <module>
  File "/home/hardik/Desktop/Research/FORESEE/venv311/lib/python3.11/site-packages/traitlets/config/application.py", line 1035, in launch_instance
    @classmethod
    
  File "/home/hardik/Desktop/Research/FORESEE/venv311/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 706, in start
    def start(self):
    
  File "/home/hardik/Desktop/Research/FORESEE/venv311/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 208, in start
    def start(self) -> None:
    
  File "/usr/lib/python3.11/asyncio/base_events.py", line 593, in run_forever
    def run_forever(self):
    
  File "/usr/lib/python3.11/asyncio/base_events.py", line 1845, in _run_once
    def _run_once(self):
    
  File "/usr/lib/python3.11/

In [256]:
jnp.diag( jnp.array([1,2]) )

Array([[1, 0],
       [0, 2]], dtype=int32)

### Scan vs For

In [53]:
a = np.linspace(1,10,10)
a

Array([ 1.       ,  2.       ,  3.       ,  4.       ,  5.       ,
        6.0000005,  7.0000005,  8.       ,  9.       , 10.       ],      dtype=float32)

In [67]:
import jax.numpy as np
import time

# a = np.array([1.0, 2, 3, 5, 7, 11, 13, 17])
a = np.linspace(1,1000,1000)


def fun1(a):
    result = []
    res = 0
    for el in a:
        res += 2*el
        result.append(res)
    return result[-1]
np.array(fun1(a))

fun1_jit=jit(fun1)
t0 = time.time()
fun1_jit(a)
print(f"time jit: {time.time()-t0}")

print(fun1_jit(a))
%timeit fun1_jit(a)

fun1_grad_jit=jit(grad(fun1))
t0 = time.time()
fun1_grad_jit(a)
print(f"time grad jit: {time.time()-t0}")

# print(fun1_grad_jit(a))
%timeit fun1_grad_jit(a)



time jit: 1.3443591594696045
1001000.0
3.58 µs ± 392 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
time grad jit: 2.8755059242248535
3.38 µs ± 897 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [69]:
from jax import lax


def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    res = res + 2*el
    return res, res  # ("carryover", "accumulated")

def cum_sum_scanned( result_init, a ):
    return lax.scan(cumsum, result_init, a)[0]

result_init = 0.0
t0 = time.time()
final, result = lax.scan(cumsum, result_init, a)
print(f"time scan: {time.time()-t0}")
%timeit lax.scan(cumsum, result_init, a)

cum_sum_scanned_jit = jit(cum_sum_scanned)
t0 = time.time()
cum_sum_scanned_jit( result_init, a )
print(f"time scan jit: {time.time()-t0}")
%timeit cum_sum_scanned_jit( result_init, a )


cum_sum_scanned_grad = jit(grad( cum_sum_scanned ))
t0 = time.time()
cum_sum_scanned_grad( result_init, a )
print(f"time grad jit : {time.time()-t0}")

%timeit cum_sum_scanned_grad( result_init, a )


result[-1]

time scan: 0.049346208572387695
209 µs ± 90.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
time scan jit: 0.019811630249023438
6.91 µs ± 284 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
time grad jit : 0.017626047134399414
2.5 µs ± 95 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Array(1001000., dtype=float32)

In [10]:
num_timesteps = 100
timesteps = np.arange(num_timesteps)

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
       68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],      dtype=int32)

In [15]:
fun1_jit(a)

[Array(1, dtype=int32),
 Array(3, dtype=int32),
 Array(6, dtype=int32),
 Array(11, dtype=int32),
 Array(18, dtype=int32),
 Array(29, dtype=int32),
 Array(42, dtype=int32),
 Array(59, dtype=int32)]

In [26]:
fun1_jit_grad = grad(fun1_jit)

In [27]:
fun1_jit_grad(a)

Array([2., 2., 2., 2., 2., 2., 2., 2.], dtype=float32)

In [28]:
a

Array([ 1.,  2.,  3.,  5.,  7., 11., 13., 17.], dtype=float32)

In [37]:
b = np.ones((3,1))

In [38]:
b[:,0]

Array([1., 1., 1.], dtype=float32)

In [40]:
b[0:3,0]

Array([1., 1., 1.], dtype=float32)

In [70]:
a = np.array([ [1,2],[3,4] ])

In [71]:
a[1]

Array([3, 4], dtype=int32)

In [72]:
a[0]

Array([1, 2], dtype=int32)

In [73]:
6/0.02

300.0

In [74]:
300*10

3000

In [75]:
3/0.02

150.0

In [92]:
def custom(a):
    x = 1.0
    def body(t, inputs):
        inputs = inputs + inputs*a
        return inputs
    return lax.fori_loop(0, 2, body, x)

In [93]:
custom(2)

Array(9., dtype=float32, weak_type=True)

In [110]:
jit(custom)(2)

Array(9., dtype=float32, weak_type=True)

In [94]:
custom_grad = grad(custom)

In [105]:
custom_grad(3.0)

Array(8., dtype=float32, weak_type=True)

In [111]:
jit(custom_grad)(3.0)

Array(8., dtype=float32, weak_type=True)

In [127]:
def new_custom(a):
    x = 1.0
    state = a*x
    def body(t, inputs):
        x, state = inputs
        y = x + state
        x = y + 0.0
        new_state = a*x
        return x, new_state
    return lax.fori_loop(0, 2, body, (x, state))[0]
new_custom_jit = jit(new_custom)

In [128]:
new_custom(5.0)

Array(36., dtype=float32, weak_type=True)

In [129]:
new_custom_jit(5.0)

Array(36., dtype=float32, weak_type=True)

In [130]:
new_custom_grad = grad(new_custom)
new_custom_grad_jit = jit(new_custom_grad)

In [131]:
new_custom_grad(4.0)

Array(10., dtype=float32, weak_type=True)

In [132]:
new_custom_grad_jit(4.0)

Array(10., dtype=float32, weak_type=True)

In [133]:
a = np.array(9)

In [134]:
a

Array(9, dtype=int32, weak_type=True)

In [126]:
a.reshape(-1,1)

Array([[9]], dtype=int32, weak_type=True)