Most of the code is borrowed directly from the official JAX repository.
- E1 Dot product
- E2 @jit: JIT-compiling a sequence of operations
- E3 grad(): Taking derivatives
- E4 vmap(): Automatic batch computation
- E5 pmap(): Parallel programming for multiple devices
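For E1 (dot product), here is a minimal sketch assuming only `jax.numpy` and `jax.random`. The array names are illustrative and do not come from the example files. JAX dispatches work asynchronously, so `block_until_ready()` is used to wait for the result, which matters when timing the product.

```python
import jax.numpy as jnp
from jax import random

# Vector dot product: 1*4 + 2*5 + 3*6 = 32
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
v = jnp.dot(a, b)

# Matrix product on random data; block_until_ready() waits for the
# asynchronously dispatched computation to finish (useful when timing).
key = random.PRNGKey(0)
x = random.normal(key, (500, 500), dtype=jnp.float32)
m = jnp.dot(x, x.T).block_until_ready()
```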
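For E2 (`@jit`), the sketch below JIT-compiles a short chain of elementwise operations (a SELU activation). It is illustrative only. The constants are rounded SELU values, and `selu` / `selu_jit` are hypothetical names. The decorated function is traced and compiled on its first call for a given input shape and dtype. Later calls with the same shape and dtype reuse the compiled version.

```python
import jax
import jax.numpy as jnp

ALPHA, LMBDA = 1.67, 1.05  # rounded SELU constants (illustrative)

def selu(x):
    # Scaled exponential linear unit: compare, exp, multiply, subtract and
    # select, a sequence of elementwise ops that XLA can fuse when jitted.
    return LMBDA * jnp.where(x > 0, x, ALPHA * jnp.exp(x) - ALPHA)

# Decorator form of jax.jit; equivalent to selu_jit = jax.jit(selu)
@jax.jit
def selu_jit(x):
    return selu(x)

x = jnp.linspace(-3.0, 3.0, 7)
eager = selu(x)                              # op-by-op execution
compiled = selu_jit(x).block_until_ready()   # compiled execution
```

The compiled and eager results should agree up to floating-point tolerance. The speedup comes from fusing the operations into one kernel instead of dispatching them one at a time.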
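For E3 (`grad()`), a minimal sketch: `jax.grad` takes a scalar-valued function and returns a function that computes its gradient. Nesting `grad` gives higher-order derivatives. The function names here are hypothetical. The result is checked against the closed-form derivative of the logistic function, σ'(x) = σ(x)(1 − σ(x)).

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    # Scalar output is required by grad()
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(3.0)
dx = jax.grad(sum_logistic)(x)

# Closed-form gradient for comparison: s * (1 - s)
s = 1.0 / (1.0 + jnp.exp(-x))
expected = s * (1.0 - s)

# Second derivative by nesting grad: d^2/dt^2 t^3 = 6t, so 12.0 at t = 2
f = lambda t: t ** 3
d2f = jax.grad(jax.grad(f))(2.0)
```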
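For E4 (`vmap()`), the sketch below writes a function for a single vector and then lets `jax.vmap` map it over a leading batch axis. The function and array names are hypothetical. A plain Python loop computes the same result more slowly and serves as a reference.

```python
import jax
import jax.numpy as jnp

mat = jnp.arange(6.0).reshape(2, 3)     # fixed (2, 3) matrix
batch = jnp.arange(12.0).reshape(4, 3)  # 4 vectors of length 3

def apply_matrix(v):
    # Written for one vector of shape (3,); returns shape (2,)
    return jnp.dot(mat, v)

# Reference: manual batching with a Python loop
looped = jnp.stack([apply_matrix(v) for v in batch])

# vmap maps over axis 0 of `batch` (in_axes=0 by default), no loop needed
batched = jax.vmap(apply_matrix)(batch)
```

`in_axes` can be passed to `jax.vmap` to map over a different axis or to leave some arguments unbatched.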
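For E5 (`pmap()`), a minimal sketch that runs on whatever devices are present. On a plain CPU machine `jax.local_device_count()` is usually 1, so the parallelism is trivial but the code still runs. `pmap` maps the leading axis of its input across devices, so that axis is sized to the device count. The axis name `"i"` is an arbitrary label that the collective `jax.lax.psum` uses to sum across devices.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)  # one row per device

# Each device doubles its own row
doubled = jax.pmap(lambda v: 2.0 * v)(x)

# Collective: every device receives the elementwise sum of all rows
totals = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)
```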