# Introduction to `Jax` Array Operations

(This tutorial has been adapted from [this NumPy tutorial](https://numpy.org/doc/stable/user/quickstart.html).)

`Jax` is a Python library for high-performance numerical computing and array operations. In previous classes you may have been encouraged to write loops and if/else expressions; with `Jax` we will take a different approach. We will avoid writing explicit loops and if/else expressions and instead rely on `Jax` library calls that operate on whole arrays at once.
This is because plain Python code, which the interpreter executes one step at a time, is slow, whereas `Jax` hands whole-array operations to optimized, compiled code (which can also run on GPUs and TPUs) and is very, very fast. `Jax` also has some other features we will rely on later.

The main part of `Jax` you'll be interacting with (for now) comes from:

In [1]:
import jax.numpy as jnp

This portion of `Jax` introduces its main building block: multidimensional arrays. You may be wondering why we need arrays for ML. As it turns out, all of the code we will write can be conveniently expressed in terms of array operations, and `Jax` will make these operations fast for us, so we can spend more time thinking and less time waiting for models to fit to data.
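To make the loop-free style described at the start of this tutorial concrete, here is a small sketch that squares a handful of numbers twice: once with a Python loop and once with a single array operation. The values and variable names are made up for illustration, and the example only shows the difference in style; it does not measure speed.

```python
import jax.numpy as jnp

values = [0.0, 1.0, 2.0, 3.0, 4.0]  # illustrative input data

# Loop-based style: build the result one element at a time in plain Python.
squares_loop = []
for v in values:
    squares_loop.append(v * v)

# Array style: one expression squares every element of the array at once.
x = jnp.array(values)
squares_array = x * x

print(squares_loop)
print(squares_array)
```

Both versions compute the same numbers; the array version simply expresses the whole computation as one operation that `Jax` can execute efficiently.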

## Properties of Arrays

So what does an array look like? Consider the following example:

In [2]:
a = jnp.array(
    [[1., 0., 0.],
     [0., 1., 2.]]
)

print(a)

[[1. 0. 0.]
 [0. 1. 2.]]


The above array has 2 dimensions (also called axes). The first axis has a length of 2 (the rows) and the second axis has a length of 3 (the columns). You can determine the dimensionality, shape, and size of the array as follows:

In [3]:
print('Number of dimensions:', a.ndim)
print('Shape (length along each axis/dimension):', a.shape)
print('Size (total number of elements):', a.size)

Number of dimensions: 2
Shape (length along each axis/dimension): (2, 3)
Size (total number of elements): 6
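These three attributes are related: `ndim` is the number of entries in `shape`, and `size` is the product of those entries. A minimal sketch checking this on a 3-dimensional array (the shape `(2, 3, 4)` is an arbitrary choice for illustration):

```python
import math

import jax.numpy as jnp

# A 3-dimensional array: 2 blocks, each with 3 rows of 4 columns.
x = jnp.zeros((2, 3, 4))

# ndim is the length of the shape tuple...
assert x.ndim == len(x.shape)
# ...and size is the product of the lengths along every axis (2 * 3 * 4).
assert x.size == math.prod(x.shape)

print(x.ndim, x.shape, x.size)  # 3 (2, 3, 4) 24
```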


Arrays can store different types of elements (e.g. ints, floats, etc.), although all elements of a single array share one type. You can determine which type of element an array stores via its `dtype` attribute:

In [4]:
array_of_ints = jnp.array([1, 2, 3])
print('An array of ints has dtype:', array_of_ints.dtype)

array_of_floats = jnp.array([1.0, 2.0, 3.0])
print('An array of floats has dtype:', array_of_floats.dtype)

An array of ints has dtype: int32
An array of floats has dtype: float32
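Because an array has a single `dtype`, mixing ints and floats in one list gives an array whose elements are all floats. A short sketch of this, plus requesting a `dtype` explicitly instead of letting `Jax` deduce it. The exact float width of `mixed` (32 or 64 bits) depends on whether 64-bit mode (`jax_enable_x64`) is turned on, so the sketch only prints it rather than asserting a particular value.

```python
import jax.numpy as jnp

# Mixing ints and floats: the whole array gets one floating-point dtype,
# so the ints 1 and 3 are stored as 1.0 and 3.0.
mixed = jnp.array([1, 2.5, 3])
print(mixed, mixed.dtype)

# The dtype can also be requested explicitly instead of being deduced.
explicit = jnp.array([1, 2, 3], dtype=jnp.float32)
print(explicit.dtype)  # float32
```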


Lastly, you can convert an array to a different element type with `astype`:

In [5]:
print('Array of ints converted to an array of floats:', array_of_ints.astype('float32'))

Array of ints converted to an array of floats: [1. 2. 3.]
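Converting in the other direction, from floats to ints, is worth a quick check because the fractional part is dropped (the value is rounded toward zero, not to the nearest integer). This sketch also passes a dtype object (`jnp.int32`) instead of a string; `astype` accepts both. The input values are arbitrary.

```python
import jax.numpy as jnp

floats = jnp.array([1.7, -1.7, 2.0])

# astype accepts a dtype object as well as a string such as 'int32'.
as_ints = floats.astype(jnp.int32)

# Float-to-int conversion truncates toward zero: 1.7 -> 1, -1.7 -> -1, 2.0 -> 2.
print(as_ints)

# astype returns a new array; the original array keeps its float dtype.
print(floats.dtype)
```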


## Array Creation

There are several ways to create arrays. For example, you can create an array from a regular Python list or tuple using the `jnp.array` function. The `dtype` of the resulting array is deduced from the types of the elements in the sequence:

In [6]:
array_from_list = jnp.array([1.2, 3.5, 5.1])
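A tuple works the same way as a list, and nested sequences become additional dimensions, with each inner list forming one row. A minimal sketch (the numbers are arbitrary):

```python
import jax.numpy as jnp

# A tuple is accepted just like a list.
from_tuple = jnp.array((1, 2, 3))

# Nested lists add a dimension: 2 inner lists of 3 elements -> shape (2, 3).
from_nested = jnp.array([[1, 2, 3], [4, 5, 6]])

print(from_tuple.shape)   # (3,)
print(from_nested.shape)  # (2, 3)
```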

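The function `jnp.zeros` creates an array full of zeros, and the function `jnp.ones` creates an array full of ones. The function `jnp.empty` exists for compatibility with NumPy: in NumPy it returns an array whose contents are whatever happened to be in memory, but `Jax` cannot create uninitialized arrays, so `jnp.empty` simply returns an array of zeros (as the output below shows). All three functions take the shape of the desired array. (Note that by default, the `dtype` of the created array is `float32`, rather than NumPy's `float64`, unless 64-bit mode is enabled in `Jax`; the `dtype` can also be specified via the keyword argument `dtype`.)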

In [7]:
array_of_zeros = jnp.zeros((3, 4))
print('Array of zeros:')
print(array_of_zeros)
print('')

array_of_ones = jnp.ones((2, 3, 4))
print('Array of ones:')
print(array_of_ones)
print('')

empty_array = jnp.empty((2, 3))
print('Empty array:')
print(empty_array)

Array of zeros:
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]

Array of ones:
[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]

Empty array:
[[0. 0. 0.]
 [0. 0. 0.]]
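A short sketch of a few related creation options: passing `dtype` to get integers, `jnp.full` to fill an array with an arbitrary constant, and `jnp.ones_like`, which copies the shape and `dtype` of an existing array. The shapes and the fill value `7.0` are arbitrary choices.

```python
import jax.numpy as jnp

# Request an integer dtype instead of the floating-point default.
int_zeros = jnp.zeros((2, 3), dtype=jnp.int32)

# jnp.full fills an array of the given shape with any constant value.
sevens = jnp.full((2, 3), 7.0)

# The *_like variants take the shape (and dtype) from an existing array.
ones_same_shape = jnp.ones_like(int_zeros)

print(int_zeros.dtype)        # int32
print(sevens)
print(ones_same_shape.shape)  # (2, 3)
```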


To create sequences of numbers, we can use the `jnp.arange` function, which is analogous to Python's built-in `range` but returns an array. By default, it counts from 0 up to, but not including, the given number. A similar function you may want to check out is `jnp.linspace`.

In [8]:
print('From 0 up to (not including) 10, step 1:', jnp.arange(10.0))
print('From 4 up to (not including) 20, step 5:', jnp.arange(4.0, 20.0, 5.0))

From 0 up to (not including) 10, step 1: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
From 4 up to (not including) 20, step 5: [ 4.  9. 14. 19.]
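The difference between `jnp.arange` and `jnp.linspace` is in what you specify: `arange` takes a step size and excludes the stop value, while `linspace` takes the number of points and, by default, includes the stop value. With non-integer steps, rounding can make the length of an `arange` result surprising, which is one reason to prefer `linspace` when you know how many points you want. A minimal sketch over the interval from 0 to 1:

```python
import jax.numpy as jnp

# arange(start, stop, step): the stop value 1.0 is excluded.
stepped = jnp.arange(0.0, 1.0, 0.25)

# linspace(start, stop, num): 5 evenly spaced points, stop value included.
spaced = jnp.linspace(0.0, 1.0, 5)

print(stepped)  # values 0, 0.25, 0.5, 0.75
print(spaced)   # values 0, 0.25, 0.5, 0.75, 1
```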


## Array Operations

Arithmetic operators on arrays apply elementwise: the operation is applied independently to each element of the original array (or, for two arrays of the same shape, to each pair of matching elements), and a new array is created and filled with the results.

In [9]:
a = jnp.array([20, 30, 40, 50]).astype('float32')
b = jnp.arange(4.0)

print('a           =', a)
print('b           =', b)
print('')
print('a - b       =', a - b)
print('b^2         =', b ** 2.0) # the double-star means power
print('10 * sin(a) =', 10.0 * jnp.sin(a))
print('a < 35      =', a < 35)
print('a * b       =', a * b)

a           = [20. 30. 40. 50.]
b           = [0. 1. 2. 3.]

a - b       = [20. 29. 38. 47.]
b^2         = [0. 1. 4. 9.]
10 * sin(a) = [ 9.129453  -9.880316   7.4511313 -2.6237485]
a < 35      = [ True  True False False]
a * b       = [  0.  30.  80. 150.]
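Since `*` is elementwise, `a * b` is not a dot product. A short sketch contrasting it with `jnp.dot` (equivalently, the `@` operator), which multiplies matching elements and then sums them into a single number; the arrays are the same `a` and `b` as above.

```python
import jax.numpy as jnp

a = jnp.array([20.0, 30.0, 40.0, 50.0])
b = jnp.arange(4.0)

# `*` multiplies matching elements and keeps the shape: [0, 30, 80, 150].
elementwise = a * b

# jnp.dot (or @) multiplies matching elements AND sums them:
# 0 + 30 + 80 + 150 = 260.
dot = jnp.dot(a, b)
also_dot = a @ b

print(elementwise.shape)  # (4,)
print(dot.shape)          # ()  (a zero-dimensional array, i.e. a scalar)
```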


## Shape Manipulation

## Indexing