Set of exercises generated by Gemini 2.5 with the prompt:
> give me a set of 10-15 exercises for a beginner in Jax willing to learn about Automatic Vectorization. Try to cover the important topics and concepts.

In [1]:
import jax
import jax.numpy as jnp

def squared(x):
    return x**2

## Section 1 - Basics of JAX

Ex 61: Use `jax.vmap` to create a function `batched_square` that applies `squared` to a JAX array element-wise.

In [6]:
batched_square = jax.vmap(squared)

# Test the function
x = jnp.array([1, 2, 3])
print(batched_square(x))  # Output: [1 4 9]

# For a 2-D input, vmap maps over axis 0, so `squared` receives one row at a time.
x2 = jnp.array([[1., 2., 3], [4.5, 5, 6]])
print(batched_square(x2))  # Output: [[1. 4. 9.], [20.25 25. 36.]]

[1 4 9]
[[ 1.    4.    9.  ]
 [20.25 25.   36.  ]]
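To make explicit what `jax.vmap` does with the 2-D `x2`, the sketch below compares it with a plain Python loop over rows, then nests `vmap` so that `squared` only ever receives scalars. The loop comparison and the nested `scalar_only` function are illustrations added here, not part of the original exercise.

```python
import jax
import jax.numpy as jnp

def squared(x):
    return x**2

batched_square = jax.vmap(squared)
x2 = jnp.array([[1., 2., 3.], [4.5, 5., 6.]])

# vmap over axis 0 behaves like calling `squared` on each row and stacking.
loop_result = jnp.stack([squared(row) for row in x2])
vmap_result = batched_square(x2)
print(jnp.allclose(loop_result, vmap_result))  # True

# Nesting vmap maps over both axes, so `squared` sees one scalar per call.
scalar_only = jax.vmap(jax.vmap(squared))
print(scalar_only(x2).shape)  # (2, 3)
```

Because `x**2` is already element-wise, all three forms give the same numbers here; the difference only matters for functions that genuinely expect a scalar or a fixed-rank input.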


Ex 62: Function with multiple scalar arguments.
Write a function `g(x, y) = x + y`. Use `jax.vmap` to apply `g` element-wise to two arrays of the same shape.
What happens if the arrays have different shapes?


In [9]:
def g(x, y):
    return x + y

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

batched_g = jax.vmap(g)
print(batched_g(x, y))  # Output: [5 7 9]

x2 = jnp.array([1., 2.])
print(batched_g(x2, y))  # raises ValueError: mapped axis sizes 2 and 3 differ


[5 7 9]


ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 2: axis 0 of argument x of type float32[2];
  * one axis had size 3: axis 0 of argument y of type int32[3]
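The error above comes from the mapped axis only. A minimal sketch, assuming the same `g`, that catches the `ValueError` and then shows that arrays of different overall shapes are accepted as long as their mapped axes (axis 0 here) have the same size. The arrays `m` and `y` are hypothetical inputs chosen for illustration.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    return x + y

batched_g = jax.vmap(g)

# Sizes along the mapped axis (2 vs 3) disagree, so vmap raises ValueError.
raised = False
try:
    batched_g(jnp.array([1., 2.]), jnp.array([4., 5., 6.]))
except ValueError as err:
    raised = True
    print("ValueError:", str(err).splitlines()[0])

# Only the mapped-axis sizes must agree; the per-call shapes may differ.
# Each call here sees a row of shape (2,) plus a scalar, which broadcast.
m = jnp.array([[1., 2.], [3., 4.], [5., 6.]])  # shape (3, 2)
y = jnp.array([10., 20., 30.])                 # shape (3,)
out = batched_g(m, y)
print(out)  # rows: [11. 12.], [23. 24.], [35. 36.]
```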

Ex 63: Mapping over specific axes.
Consider the function `h(scalar, vector) = scalar * vector`.
Take an array of scalars `s = jnp.array([1., 2., 3.])` and a single vector `v = jnp.array([10., 20.])`.
Use `jax.vmap` with `in_axes` to apply `h` so that each scalar in `s` multiplies the entire vector `v`.

Hint: one entry of `in_axes` should be `0` and the other `None`, to indicate that the vector is not mapped over.


In [10]:
def h(scalar, vector):
    return scalar * vector

s = jnp.array([1., 2., 3.])
v = jnp.array([10., 20.])

batched_h = jax.vmap(h, in_axes=(0, None))
print(batched_h(s, v))  # Output: [[10. 20.], [20. 40.], [30. 60.]]





[[10. 20.]
 [20. 40.]
 [30. 60.]]
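For this particular `h`, `in_axes=(0, None)` produces the same result as ordinary NumPy-style broadcasting, where `s` gets a trailing axis so that each scalar pairs with every entry of `v`. The sketch below checks that equivalence; the name `manual` is introduced here for illustration.

```python
import jax
import jax.numpy as jnp

def h(scalar, vector):
    return scalar * vector

s = jnp.array([1., 2., 3.])
v = jnp.array([10., 20.])

batched_h = jax.vmap(h, in_axes=(0, None))

# Manual broadcast: s has shape (3, 1), v has shape (2,), result is (3, 2).
manual = s[:, None] * v
print(jnp.array_equal(batched_h(s, v), manual))  # True
print(manual.shape)  # (3, 2)
```

The advantage of `vmap` is that `h` stays written for a single scalar and a single vector; the manual version requires reasoning about axes inside the function body.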


Ex 64: Mapping over the other argument.
* Modify the previous exercise: you now have a single scalar `s_single = 5.` and an array of vectors `vs = jnp.array([[1., 2.], [3., 4.], [5., 6.]])`. Apply `h` so that `s_single` multiplies each vector in `vs`.

In [15]:
s_single = 5.
vs = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
batched_h2 = jax.vmap(h, in_axes=(None, 0))
print(batched_h2(s_single, vs))  # Output: [[5. 10.], [15. 20.], [25. 30.]]


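# With in_axes=(None, 0), the whole of s_vec1 is passed as `scalar` and
# s_vec2 is mapped, so call i computes s_vec1 * s_vec2[i].
s_vec1 = jnp.array([1., 2., 3.])
s_vec2 = jnp.array([4., 5., 6.])
batched_h3 = jax.vmap(h, in_axes=(None, 0))
print(batched_h3(s_vec1, s_vec2))  # Output: [[4. 8. 12.], [5. 10. 15.], [6. 12. 18.]]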



[[ 5. 10.]
 [15. 20.]
 [25. 30.]]
[[ 4.  8. 12.]
 [ 5. 10. 15.]
 [ 6. 12. 18.]]
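The last cell produces a 3x3 table rather than element-wise products because only `s_vec2` is mapped. A short sketch contrasting the three `in_axes` choices for the same two vectors: mapping both arguments pairs them up element-wise, while mapping just one of them gives an outer-product-shaped result. The names `paired` and `rows` are introduced here for illustration.

```python
import jax
import jax.numpy as jnp

def h(scalar, vector):
    return scalar * vector

s_vec1 = jnp.array([1., 2., 3.])
s_vec2 = jnp.array([4., 5., 6.])

# Map both arguments together: call i receives s_vec1[i] and s_vec2[i].
paired = jax.vmap(h, in_axes=(0, 0))(s_vec1, s_vec2)
print(paired)  # element-wise products: 4., 10., 18.

# Map only the first argument: each s_vec1[i] scales the whole s_vec2.
rows = jax.vmap(h, in_axes=(0, None))(s_vec1, s_vec2)
print(rows)  # rows: [4. 5. 6.], [8. 10. 12.], [12. 15. 18.]

# Mapping only the second argument (as in batched_h3) gives the transpose.
cols = jax.vmap(h, in_axes=(None, 0))(s_vec1, s_vec2)
print(jnp.array_equal(cols, rows.T))  # True
```

Since `in_axes=(0, 0)` is the default, `jax.vmap(h)(s_vec1, s_vec2)` gives the same result as `paired`.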


Ex 65: Controlling output structure.
* Write a function `k(x) = (x, x*2, x*3)` that returns a tuple of scalars.
* Use `jax.vmap` to apply it to the array `inputs = jnp.array([1, 2, 3])`.
* By default, vmap stacks each output along axis 0; observe this.
* Experiment with `out_axes` to change the output structure.

In [24]:
def k(x):
    return (x, x*2, x*3)

inputs = jnp.array([1, 2, 3])

batched_k = jax.vmap(k)
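print(batched_k(inputs))  # Output: a tuple of three arrays: ([1 2 3], [2 4 6], [3 6 9])

# Same as the default: each per-example output is a scalar, so every batched
# output is 1-D and axis 0 is the only place the batch axis can go.
batched_k2 = jax.vmap(k, out_axes=(0, 0, 0))
print(batched_k2(inputs))  # Output: identical to batched_k(inputs)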

# TODO: revisit this.

(Array([1, 2, 3], dtype=int32), Array([2, 4, 6], dtype=int32), Array([3, 6, 9], dtype=int32))
(Array([1, 2, 3], dtype=int32), Array([2, 4, 6], dtype=int32), Array([3, 6, 9], dtype=int32))
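`out_axes` has a visible effect only when each per-example output has at least one axis of its own. A sketch, assuming two hypothetical helpers: `k_vec` returns the three values of `k` as one length-3 vector, and `k_mixed` returns one batched and one unbatched output. `inputs` is extended to four elements here so the batch axis (size 4) is distinguishable from the output axis (size 3).

```python
import jax
import jax.numpy as jnp

def k_vec(x):
    # Hypothetical variant of `k`: one length-3 vector instead of a tuple.
    return jnp.stack([x, x * 2, x * 3])

def k_mixed(x):
    # Hypothetical helper: the second output does not depend on x.
    return x * 2, jnp.arange(3)

inputs = jnp.array([1, 2, 3, 4])

# Default out_axes=0: the batch axis comes first, row i is k_vec(inputs[i]).
default = jax.vmap(k_vec)(inputs)
# out_axes=1: the batch axis comes last, column i is k_vec(inputs[i]).
swapped = jax.vmap(k_vec, out_axes=1)(inputs)
print(default.shape, swapped.shape)  # (4, 3) (3, 4)
print(jnp.array_equal(swapped, default.T))  # True

# out_axes can be a tuple matching the output tuple; None is allowed only
# for outputs that are not batched, which are returned without a batch axis.
doubled, shared = jax.vmap(k_mixed, out_axes=(0, None))(inputs)
print(doubled.shape, shared.shape)  # (4,) (3,)
```

Setting `None` for an output that does depend on the mapped input (for example the first output of `k_mixed`) is rejected by `vmap`, since there would be no single value to return.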
