## Data Manipulation
In deep learning we build a function $y = f(x)$ to make predictions about real-world cases. This requires that inputs, outputs, and internal state all be represented as numbers: specifically scalars, vectors, and matrices. On the engineering side, NumPy calls the general concept an `ndarray`, while PyTorch and TensorFlow call it a `tensor`. This notebook demonstrates how to represent these numbers and perform basic operations on them with the `jax.numpy` API.

### Getting Started 

In [1]:
import jax
import jax.numpy as jnp

In [2]:
x = jnp.arange(12)
x

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [3]:
x.size

12

We can access a tensor's shape (the length along each axis) by inspecting its shape attribute. Because we are dealing with a vector here, the shape contains just a single element and is identical to the size.

In [4]:
x.shape

(12,)

We can change the shape of a tensor without altering its size or values, by invoking `reshape`.

In [5]:
X = x.reshape(2, 6)
X

Array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11]], dtype=int32)
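
When only one axis length is known, `reshape` can work out the other one. The sketch below assumes the same 12-element vector as above and passes `-1` for the axis to be inferred; the names `X_rows` and `X_cols` are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(12)

# -1 asks reshape to infer that axis length from the total size (12 / 2 = 6).
X_rows = x.reshape(2, -1)
# Here the number of rows is inferred instead (12 / 6 = 2).
X_cols = x.reshape(-1, 6)

print(X_rows.shape, X_cols.shape)
```

Both calls give the same (2, 6) layout as `x.reshape(2, 6)`; the total number of elements never changes.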

Practitioners often need to work with tensors initialized to contain all 0s or 1s. We can construct a tensor with all elements set to 0 and a shape of (2, 3, 4) via the zeros function.

In [6]:
jnp.zeros((2, 3, 4))

Array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]], dtype=float32)

Similarly, we can create a tensor with all 1s by invoking ones.

In [7]:
jnp.ones((3, 4, 5))

Array([[[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],

       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],

       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]], dtype=float32)

We often wish to sample each element randomly (and independently) from a given probability distribution. For example, the parameters of neural networks are often initialized randomly. The following snippet creates a tensor with elements drawn from a standard Gaussian (normal) distribution with mean 0 and standard deviation 1.

In [8]:
jax.random.normal(jax.random.PRNGKey(0), (5, 6))

Array([[-0.28371066,  0.9368162 , -1.0050073 ,  1.4165013 ,  1.0543301 ,
         0.9108127 ],
       [-0.42656708,  0.986188  , -0.5575324 ,  0.01532502, -2.078568  ,
         0.5548371 ],
       [ 0.91423655,  0.5744596 ,  0.7227863 ,  0.12106175, -0.3237354 ,
         1.6234998 ],
       [ 0.24500391, -1.3809781 , -0.6111237 ,  0.14037248,  0.84100425,
        -1.0943578 ],
       [-1.077502  , -1.1396457 , -0.593338  , -0.15576515, -0.38321444,
        -1.1144515 ]], dtype=float32)
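
Unlike NumPy, JAX takes the random state as an explicit key argument, so the same key always produces the same samples. The following is a minimal sketch of how one might draw a second, different tensor by splitting the key with `jax.random.split`; the variable names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing one key reproduces exactly the same samples.
a1 = jax.random.normal(key, (5, 6))
a2 = jax.random.normal(key, (5, 6))

# Splitting derives fresh keys, so the next draw differs from the first.
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (5, 6))

print(bool(jnp.array_equal(a1, a2)), bool(jnp.array_equal(a1, b)))
```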

Finally, we can construct tensors with exact values for each element by supplying (possibly nested) Python list(s) containing numerical literals. Here, we construct a matrix with a list of lists, where the outermost list corresponds to axis 0, and the inner list corresponds to axis 1.

In [9]:
jnp.array([[1, 6, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

Array([[1, 6, 4, 3],
       [1, 2, 3, 4],
       [4, 3, 2, 1]], dtype=int32)

### Indexing and Slicing

In [10]:
X[-1], X[1:3]

(Array([ 6,  7,  8,  9, 10, 11], dtype=int32),
 Array([[ 6,  7,  8,  9, 10, 11]], dtype=int32))

In [11]:
# JAX arrays are immutable. jax.numpy.ndarray.at index
# update operators create a new array with the corresponding
# modifications made
X_new_1 = X.at[1, 2].set(17)
X_new_1

Array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7, 17,  9, 10, 11]], dtype=int32)

If we want to assign multiple elements the same value, we apply the indexing inside `.at[...]` and pass the value to `set`, which returns a new array. For instance, `[:2, :]` accesses the first and second rows, where `:` takes all the elements along axis 1 (column). While we discussed indexing for matrices, this also works for vectors and for tensors of more than two dimensions.

In [12]:
X_new_2 = X_new_1.at[:2, :].set(12)
X_new_2

Array([[12, 12, 12, 12, 12, 12],
       [12, 12, 12, 12, 12, 12]], dtype=int32)
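
The `.at[...]` property offers more than `set`. The sketch below, which rebuilds the 2 x 6 matrix from earlier under the same name `X`, shows that NumPy-style item assignment is rejected and that `.at[...].add` also returns a new array while leaving the original untouched. The flag `rejected` is only there for illustration.

```python
import jax.numpy as jnp

X = jnp.arange(12).reshape(2, 6)

# NumPy-style item assignment is not supported on JAX arrays.
rejected = False
try:
    X[1, 2] = 17
except TypeError:
    rejected = True

X_set = X.at[1, 2].set(17)    # new array with one element replaced
X_add = X.at[0, :].add(100)   # new array with 100 added to every entry of row 0

print(rejected, X[1, 2], X_set[1, 2], X_add[0, 0])
```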

### Operations

Now that we know how to construct tensors and how to read from and write to their elements, we can begin to manipulate them with various mathematical operations. Among the most useful of these are the elementwise operations. These apply a standard scalar operation to each element of a tensor. For functions that take two tensors as inputs, elementwise operations apply some standard binary operator on each pair of corresponding elements. We can create an elementwise function from any function that maps from a scalar to a scalar.

In mathematical notation, we denote such unary scalar operators (taking one input) by the signature $f: \mathbb{R} \rightarrow \mathbb{R}$. This just means that the function maps from any real number to some other real number. Most standard operators, including unary ones like $e^x$, can be applied elementwise.

In [13]:
jnp.exp(x)

Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
       5.4598148e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
       2.9809580e+03, 8.1030840e+03, 2.2026467e+04, 5.9874145e+04],      dtype=float32)

The common standard arithmetic operators for addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), and exponentiation (`**`) have all been lifted to elementwise operations for identically-shaped tensors of arbitrary shape.

In [14]:
x = jnp.array([1, 2, 4, 8])
y = jnp.array([2, 3, 4, 2])

x + y, x - y, x * y, x / y, x ** y

(Array([ 3,  5,  8, 10], dtype=int32),
 Array([-1, -1,  0,  6], dtype=int32),
 Array([ 2,  6, 16, 16], dtype=int32),
 Array([0.5      , 0.6666667, 1.       , 4.       ], dtype=float32),
 Array([  1,   8, 256,  64], dtype=int32))

We can also concatenate multiple tensors, stacking them end-to-end to form a larger one.

We just need to provide a list of tensors and tell the system along which axis to concatenate. The example below shows what happens when we concatenate two matrices along rows (axis 0) and along columns (axis 1). We can see that the first output's axis-0 length (6) is the sum of the two input tensors' axis-0 lengths (3 + 3), while the second output's axis-1 length (8) is the sum of the two input tensors' axis-1 lengths (4 + 4).

In [15]:
X = jnp.arange(12, dtype=jnp.float16).reshape((3, 4))
Y = jax.random.normal(jax.random.PRNGKey(0), (3, 4))

jnp.concatenate((X, Y), axis=0), jnp.concatenate((X, Y), axis=1)

(Array([[ 0.        ,  1.        ,  2.        ,  3.        ],
        [ 4.        ,  5.        ,  6.        ,  7.        ],
        [ 8.        ,  9.        , 10.        , 11.        ],
        [ 1.1901639 , -1.0996888 ,  0.44367844,  0.5984697 ],
        [-0.39189556,  0.69261974,  0.46018356, -2.068578  ],
        [-0.21438177, -0.9898306 , -0.6789304 ,  0.27362573]],      dtype=float32),
 Array([[ 0.        ,  1.        ,  2.        ,  3.        ,  1.1901639 ,
         -1.0996888 ,  0.44367844,  0.5984697 ],
        [ 4.        ,  5.        ,  6.        ,  7.        , -0.39189556,
          0.69261974,  0.46018356, -2.068578  ],
        [ 8.        ,  9.        , 10.        , 11.        , -0.21438177,
         -0.9898306 , -0.6789304 ,  0.27362573]], dtype=float32))
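
The shape arithmetic is easier to see without the printed values. This sketch uses a `float16` matrix and a `float32` matrix of ones (a stand-in for the random `Y` above) and inspects only shapes and dtype; the names `rows` and `cols` are illustrative.

```python
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float16).reshape((3, 4))
Y = jnp.ones((3, 4), dtype=jnp.float32)

rows = jnp.concatenate((X, Y), axis=0)  # stack vertically: 3 + 3 rows
cols = jnp.concatenate((X, Y), axis=1)  # stack horizontally: 4 + 4 columns

print(rows.shape, cols.shape)  # (6, 4) (3, 8)
print(rows.dtype)              # float32, the wider of the two input dtypes
```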

Sometimes, we want to construct a binary tensor via logical statements. Take `X == Y` as an example. For each position `i, j`, if `X[i, j]` and `Y[i, j]` are equal, then the corresponding entry in the result is `True`; otherwise it is `False`.

In [16]:
X == Y

Array([[False, False, False, False],
       [False, False, False, False],
       [False, False, False, False]], dtype=bool)
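
If 0/1 values are needed rather than booleans, the mask can be cast. The following sketch uses a hand-written `Y` (an assumption, chosen so that some entries match) and also counts the matches by summing the mask.

```python
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
# Hypothetical Y with a few entries equal to X, for illustration only.
Y = jnp.array([[0.0, 1.0, 5.0, 3.0],
               [4.0, 0.0, 6.0, 7.0],
               [9.0, 9.0, 9.0, 11.0]])

mask = X == Y                     # boolean tensor
as_int = mask.astype(jnp.int32)   # True -> 1, False -> 0
num_equal = mask.sum()            # number of matching positions

print(as_int)
print(num_equal)  # 8
```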

1. Summing all the elements in the tensor yields a tensor with only one element.
2. Taking the mean of all the elements in the tensor also yields a single-element tensor.
3. Cumulatively summing the elements yields the running totals over the flattened tensor.

In [17]:
X.sum()

Array(66., dtype=float16)

In [18]:
X.mean()

Array(5.5, dtype=float16)

In [19]:
X.cumsum()

Array([ 0.,  1.,  3.,  6., 10., 15., 21., 28., 36., 45., 55., 66.],      dtype=float16)

In [20]:
# additional operations
X.std(), X.max(), X.min()

(Array(3.451, dtype=float16),
 Array(11., dtype=float16),
 Array(0., dtype=float16))
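
The reductions above collapse the whole tensor. As a small extension not covered in the cells above, the same methods accept an `axis` argument to reduce along one axis only; this sketch uses `float32` and the illustrative names `col_sums` and `row_means`.

```python
import jax.numpy as jnp

X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))

col_sums = X.sum(axis=0)    # collapse axis 0, leaving one value per column
row_means = X.mean(axis=1)  # collapse axis 1, leaving one value per row

print(col_sums)   # [12. 15. 18. 21.]
print(row_means)  # [1.5 5.5 9.5]
```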

### Broadcasting

By now, you know how to perform elementwise binary operations on two tensors of the same shape. Under certain conditions, even when shapes differ, we can still perform elementwise binary operations by invoking the broadcasting mechanism. Broadcasting works according to the following two-step procedure: (i) expand one or both arrays by copying elements along axes with length 1 so that after this transformation, the two tensors have the same shape; (ii) perform an elementwise operation on the resulting arrays.

In [21]:
a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))

a, b

(Array([[0],
        [1],
        [2]], dtype=int32),
 Array([[0, 1]], dtype=int32))

Since `a` and `b` are $3 \times 1$ and $1 \times 2$ matrices, respectively, their shapes do not match up. Broadcasting produces a larger $3 \times 2$ matrix by replicating matrix `a` along the columns and matrix `b` along the rows before adding them elementwise.

In [22]:
a + b

Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)

In [23]:
# or use numpy API
jnp.add(a, b)

Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)
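
The two-step procedure can be spelled out by hand. This sketch uses `jnp.broadcast_to` to perform the expansion step explicitly and then checks that the elementwise sum matches the implicit `a + b`; the names `a_full`, `b_full`, and `manual` are illustrative.

```python
import jax.numpy as jnp

a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))

# Step (i): expand both operands to the common shape (3, 2).
a_full = jnp.broadcast_to(a, (3, 2))
b_full = jnp.broadcast_to(b, (3, 2))

# Step (ii): ordinary elementwise addition on same-shaped arrays.
manual = a_full + b_full

print(a_full)
print(b_full)
print(bool(jnp.array_equal(manual, a + b)))  # True
```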

### Saving Memory

Running operations can cause new memory to be allocated to host results. For example, if we write `Y = Y + X`, we dereference the tensor that `Y` used to point to and instead point `Y` at the newly allocated memory. We can demonstrate this with Python's `id()` function, which returns the identity of the referenced object (in CPython, its memory address). Note that after we run `Y = Y + X`, `id(Y)` refers to a different object. That is because Python first evaluates `Y + X`, allocating new memory for the result, and then points `Y` to this new location in memory.

In [24]:
before = id(Y)
Y = Y + X
id(Y) == before

False

This might be undesirable for two reasons. First, we do not want to run around allocating memory unnecessarily all the time. In machine learning, we often have hundreds of megabytes of parameters and update all of them multiple times per second. Whenever possible, we want to perform these updates in place. Second, we might point at the same parameters from multiple variables. If we do not update in place, we must be careful to update all of these references, lest we spring a memory leak or inadvertently refer to stale parameters.
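
The second concern, stale references, can be reproduced in a few lines. In this sketch the names `params` and `alias` are hypothetical. Because JAX arrays are immutable, the update always produces a new array, and any other name still bound to the old one keeps the old values. As far as the JAX documentation describes it, buffer reuse is left to the compiler inside `jax.jit` rather than done by the user.

```python
import jax.numpy as jnp

params = jnp.zeros(3)
alias = params        # a second name bound to the same array

# The update allocates a new array and rebinds only `params`.
params = params + 1

print(alias is params)  # False
print(alias, params)    # alias still holds the old zeros
```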

### Conversion to Other Python Objects

In [25]:
A = jax.device_get(X)
B = jax.device_put(A)
type(A), type(B)

(numpy.ndarray, jaxlib.xla_extension.ArrayImpl)

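To convert a size-1 tensor to a Python scalar, we can invoke the `item` function or Python's built-in functions such as `float` and `int`. The `jnp.float16` and `jnp.int16` calls in the example below cast to another dtype and return arrays rather than Python scalars.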

In [26]:
a = jnp.array([3.5])
a, a.item(), jnp.float16(a), jnp.int16(a)

(Array([3.5], dtype=float32),
 3.5,
 Array([3.5], dtype=float16),
 Array([3], dtype=int16))
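
For comparison, here is a minimal sketch of the built-in route. It indexes the single element first so that `float` and `int` receive a zero-dimensional array, which is an assumption made to stay safe across JAX versions; the result names are illustrative.

```python
import jax.numpy as jnp

a = jnp.array([3.5])

as_item = a.item()      # Python float via the array method
as_float = float(a[0])  # Python float via the built-in
as_int = int(a[0])      # Python int, truncated toward zero

print(type(as_item), as_float, as_int)
```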

### Additional Exercises
1. Run the code in this section. Change the conditional statement X == Y to X < Y or X > Y, and then see what kind of tensor you can get.
2. Replace the two tensors that operate by element in the broadcasting mechanism with other shapes, e.g., 3-dimensional tensors. Is the result the same as expected?

In [27]:
# 1
X < Y, X > Y

(Array([[ True, False,  True,  True],
        [False,  True,  True, False],
        [False, False, False,  True]], dtype=bool),
 Array([[False,  True, False, False],
        [ True, False, False,  True],
        [ True,  True,  True, False]], dtype=bool))

In [28]:
# 2
a = jnp.arange(27).reshape((3, 3, 3))
b = jnp.arange(3).reshape((3, 1, 1))
a, b

(Array([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],
 
        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],
 
        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]], dtype=int32),
 Array([[[0]],
 
        [[1]],
 
        [[2]]], dtype=int32))

In [29]:
a + b

Array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[10, 11, 12],
        [13, 14, 15],
        [16, 17, 18]],

       [[20, 21, 22],
        [23, 24, 25],
        [26, 27, 28]]], dtype=int32)