# Solutions

In [1]:
import jax.numpy as jnp
import chex

## Array Operations

Please solve the following using JAX library calls only (no loops!):
1. Write a function that, given $N$, returns $\sum\limits_{i=0}^{N - 1} \log(i + 1)$.
2. Write a function that, given $N$ and $M$, returns $\sum\limits_{i=0}^{N - 1} \sum\limits_{j=0}^{M - 1} \log(i \cdot j + 1)$. *Hint: one way to do this is with `jnp.tile` and `jnp.transpose`.*

In [2]:
def array_operations_q1(N):
    """Return sum_{i=0}^{N-1} log(i + 1) without any Python loop.

    For integer N this is mathematically log(N!).

    Args:
      N: number of terms; passed directly to ``jnp.arange``.

    Returns:
      A scalar JAX array holding the sum.
    """
    shifted = 1.0 + jnp.arange(N)  # the values 1, 2, ..., N
    return jnp.log(shifted).sum()

array_operations_q1(10.0)

Array(15.104412, dtype=float32)

In [6]:
def array_operations_q2(N, M):
    """Return sum_{i=0}^{N-1} sum_{j=0}^{M-1} log(i * j + 1) without loops.

    Instead of tiling, a column vector of row indices times a row vector of
    column indices broadcasts to the full (N, M) grid of products ``i * j``.

    Args:
      N: number of values of ``i``.
      M: number of values of ``j``.

    Returns:
      A scalar JAX array holding the double sum.
    """
    products = jnp.arange(N)[:, None] * jnp.arange(M)[None, :]
    return jnp.log(products + 1.0).sum()

array_operations_q2(5, 10)

Array(84.905975, dtype=float32)

## Slicing

Please solve the following using JAX library calls only (no loops!):
1. Write a function that, given an array `a` of shape $(N,)$, returns a new array `b` of shape $(N - 1,)$ in which index $i$ contains `b[i] = a[i + 1] - a[i]`.
2. Extend the function you wrote to multi-dimensional arrays, where the operation is only performed on the last dimension. That is, if an array `a` of shape $(N, M, K)$ is given, return an array `b` of shape $(N, M, K-1)$.

In [8]:
def array_slicing_q1(a):
    """Return forward differences ``b[i] = a[i + 1] - a[i]``.

    For an input of shape (N,) the result has shape (N - 1,).
    """
    later = a[1:]     # a[1], ..., a[N-1]
    earlier = a[:-1]  # a[0], ..., a[N-2]
    return later - earlier

array_slicing_q1(jnp.array([0.0, 1.0, 5.0, 10.0, 20.0]))

Array([ 1.,  4.,  5., 10.], dtype=float32)

In [11]:
def array_slicing_q2(a):
    """Forward differences along the last axis only.

    An input of shape (..., K) yields shape (..., K - 1) with
    ``b[..., k] = a[..., k + 1] - a[..., k]``; all leading axes are untouched.
    """
    upper, lower = a[..., 1:], a[..., :-1]
    return upper - lower

array_slicing_q2(jnp.array([
    [0.0, 1.0, 5.0, 10.0, 20.0],
    [2.0, 5.0, 6.0, 20.0, 30.0]
]))

Array([[ 1.,  4.,  5., 10.],
       [ 3.,  1., 14., 10.]], dtype=float32)

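## Indexing with Boolean Arrays

Please solve the following using JAX library calls only (no loops!):
1. Write a function that, given a 2-D array `a`, returns only the rows of `a` whose elements sum to a positive value.
2. Write a function that, given $N$, $x$, $y$ and $r$, returns an $N \times N$ integer array that is $1$ at every index $(i, j)$ with $(i - x)^2 + (j - y)^2 \le r^2$ and $0$ elsewhere, i.e. a filled disk of radius $r$ centred at row $x$, column $y$.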

In [13]:
def boolean_indexing_q1(a):
    """Keep only the rows of the 2-D array ``a`` whose entries sum to more than zero.

    The number of rows returned depends on the *values* in ``a``. Boolean-mask
    indexing like this therefore needs a concrete mask, so the function works
    eagerly but not under ``jax.jit``.
    """
    chex.assert_rank(a, 2)
    row_totals = jnp.sum(a, axis=-1)
    keep = row_totals > 0.0
    return a[keep]

boolean_indexing_q1(jnp.array([
    [1.0, 2.0, 3.0],
    [-1.0, -2.0, 3.0],
    [-1.0, -2.0, 4.0],
    [-1.0, -2.0, -3.0],
]))

Array([[ 1.,  2.,  3.],
       [-1., -2.,  4.]], dtype=float32)

In [20]:
def boolean_indexing_q2(N, x, y, r):
    """Return an (N, N) int32 mask of the disk of radius ``r`` centred at (x, y).

    Entry [i, j] is 1 when (i - x)**2 + (j - y)**2 <= r**2 and 0 otherwise,
    where ``i`` indexes rows and ``j`` indexes columns.
    """
    assert 0 < N
    assert 0 <= x < N
    assert 0 <= y < N
    assert 0 < r

    # A column of row indices and a row of column indices broadcast to (N, N),
    # so no explicit tiling is needed.
    row_idx = jnp.arange(N)[:, None]
    col_idx = jnp.arange(N)[None, :]
    inside = (row_idx - x) ** 2.0 + (col_idx - y) ** 2.0 <= r ** 2.0
    result = inside.astype('int32')

    chex.assert_shape(result, (N, N))
    return result

print(boolean_indexing_q2(10, 3, 4, 2))
print('')
print(boolean_indexing_q2(10, 9, 8, 2))

[[0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0]
 [0 0 0 1 1 1 0 0 0 0]
 [0 0 1 1 1 1 1 0 0 0]
 [0 0 0 1 1 1 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]]

[[0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1 1 1]
 [0 0 0 0 0 0 1 1 1 1]]


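## Indexing with Arrays of Indices

Please solve the following using JAX library calls only (no loops!):
1. Write a function that, given a square array `a` of shape $(N, N)$, returns its main diagonal, i.e. the array of shape $(N,)$ with entries `a[i, i]`.
2. Extend the function with an integer `offset` argument. It should return every valid entry `a[i, i + offset]`: the diagonal `offset` positions above the main one when `offset` is positive, or below it when `offset` is negative.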

In [26]:
def integer_indexing_q1(a):
    """Return the main diagonal ``[a[0, 0], a[1, 1], ...]`` of a square 2-D array.

    Uses integer-array indexing: passing the same index vector for rows and
    columns selects the elements a[k, k].
    """
    chex.assert_rank(a, 2)
    # Together with the rank check, a total size of n**2 means a is n x n
    # (for n > 0).
    n = a.shape[0]
    chex.assert_size(a, n ** 2)

    k = jnp.arange(n)
    result = a[k, k]

    chex.assert_shape(result, (n,))
    return result

a = jnp.arange(6 * 6).reshape(6, 6)
print(a)
print('')
print(integer_indexing_q1(a))

[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]

[ 0  7 14 21 28 35]


In [29]:
def integer_indexing_q2(a, offset=0):
    """Return the diagonal of square ``a`` lying ``offset`` steps from the main one.

    A positive ``offset`` selects a diagonal above the main diagonal (entries
    ``a[k, k + offset]``). A negative ``offset`` selects one below it (entries
    ``a[k - offset, k]``), and 0 selects the main diagonal itself. This is the
    same sign convention as ``jnp.diagonal``.

    Args:
      a: rank-2 square array of shape (n, n).
      offset: integer diagonal offset.

    Returns:
      Array of shape (max(n - |offset|, 0),). It is empty when the offset lies
      entirely outside the matrix.
    """
    chex.assert_rank(a, 2)
    chex.assert_size(a, a.shape[0] ** 2)

    n = a.shape[0]
    # Clamp at zero: with |offset| > n, the expected length would otherwise be
    # negative, and the shape check below would fail even though the (empty)
    # result is correct.
    length = max(n - abs(offset), 0)

    k = jnp.arange(length)
    rows = k + max(-offset, 0)  # below the main diagonal: start further down
    cols = k + max(offset, 0)   # above the main diagonal: start further right
    result = a[rows, cols]

    chex.assert_shape(result, (length,))
    return result

a = jnp.arange(6 * 6).reshape(6, 6)
print(a)
print('')

for offset in [-2, -1, 0, 1, 2]:
    print(f'Offset {offset}:', integer_indexing_q2(a, offset=offset))

[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]

Offset -2: [12 19 26 33]
Offset -1: [ 6 13 20 27 34]
Offset 0: [ 0  7 14 21 28 35]
Offset 1: [ 1  8 15 22 29]
Offset 2: [ 2  9 16 23]


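## Broadcasting

Please solve the following using JAX library calls only (no loops!):
1. Write a function that takes an array `a` of shape $(N, D)$ and an array `b` of shape $(M, D)$. It should return the array `c` of shape $(N, M)$ of pairwise squared Euclidean distances, $c_{ij} = \sum_k (a_{ik} - b_{jk})^2$.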

In [30]:
def broadcasting_q1(a, b):
    """Pairwise squared Euclidean distances between the rows of ``a`` and ``b``.

    Args:
      a: rank-2 array of shape (N, D).
      b: rank-2 array of shape (M, D); its last dimension must match ``a``'s.

    Returns:
      Array of shape (N, M) whose [i, j] entry is sum_k (b[j, k] - a[i, k]) ** 2.
    """
    chex.assert_rank((a, b), 2)
    chex.assert_equal_shape_suffix((a, b), 1)

    # (1, M, D) against (N, 1, D) broadcasts to (N, M, D): every row of ``a``
    # is paired with every row of ``b``.
    offsets = b[None, :, :] - a[:, None, :]
    result = (offsets ** 2.0).sum(axis=-1)

    chex.assert_shape(result, (a.shape[0], b.shape[0]))
    return result

a = jnp.array([
    [1, 2],
    [2, 4],
    [5, 6],
])

b = jnp.array([
    [5, 3],
    [4, 1],
    [6, 6],
    [7, 1],
])

broadcasting_q1(a, b)

Array([[17., 10., 41., 37.],
       [10., 13., 20., 34.],
       [ 9., 26.,  1., 29.]], dtype=float32)