# TP: Numerical Algebra and Calculus Tools in Python

In this TP, we will learn how to use specialized (but not vendor-locked) linear algebra and calculus libraries in Python. The objectives are:
- To be able to translate a set of equations into Python code
- To be able to "unroll" for loops by using higher-order tensor operations
- To be able to optimize a differentiable function by performing gradient descent with automatic differentiation tools
- To be able to spot the bottleneck when translating a full algorithm into code


<div class="alert alert-success"> 
    <b>Questions are in green boxes.</b>
The maximum time you should spend on each question is given as indication only. If you take more time than that, then you should come see me.
</div>
<div class="alert alert-info" role="alert"><b>Analyses are in blue boxes.</b> You should comment on your results in these boxes (Is it good? Is it expected? Why do we get such a result? Why is it different from the previous one? etc.)
</div>

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

## 1. Euclidean norm of a vector

Let $\mathbf{x} \in \mathbb{R}^d$ be a vector. We want to compute its squared $\ell_2$ norm $\|\mathbf{x}\|^2 = \sum_i \mathbf{x}[i]^2$.

Let us first create a random vector $\mathbf{x}$ with $d = 1024$ dimensions:

In [None]:
x = np.random.randn(1024)


<div class="alert alert-success"> 
    <b>Q1.</b> Compute $\|\mathbf{x}\|^2$ by converting the sum into a for loop.
</div>

In [None]:
def l2_sqnorm_for(x):
    # your code
    return 

In [None]:
# benchmark the time it takes
%timeit l2_sqnorm_for(x)

Note that NumPy (and JAX) provide a `sum` function, and that the `*` operator on arrays performs element-wise multiplication.
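A minimal illustration of both facts on two small hand-picked vectors (the names `a` and `b` are only for this example, so that the notebook's `x` is not overwritten): `*` multiplies entry by entry, and `jnp.sum` adds up all entries of the result.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

prod = a * b            # element-wise product: [4., 10., 18.]
total = jnp.sum(prod)   # 4 + 10 + 18 = 32
print(prod, total)
```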

In [None]:
xs = jnp.sum(x)

<div class="alert alert-success"> 
    <b>Q2.</b> Write a new squared $\ell_2$ norm function that performs an element-wise multiplication followed by a call to the sum function, without any for loop.
</div>

In [None]:
def l2_sqnorm_sum(x):
    # your code
    return 

In [None]:
# benchmark it
%timeit l2_sqnorm_sum(x)

We can also note that the squared norm of $\mathbf{x}$ is equal to the dot product of $\mathbf{x}$ with itself, $\|\mathbf{x}\|^2 = \langle \mathbf{x}, \mathbf{x}\rangle$, which is available through the `jnp.dot` function.
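A short sketch of how `jnp.dot` behaves depending on the number of dimensions of its inputs (the arrays `u` and `M` are illustrative only): for two 1-D arrays it returns their inner product, while for two 2-D arrays it computes the matrix product, like `@`.

```python
import jax.numpy as jnp

u = jnp.array([1.0, 2.0, 3.0])
inner = jnp.dot(u, u)        # 1-D inputs: inner product, 1 + 4 + 9 = 14

M = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
mat = jnp.dot(M, M)          # 2-D inputs: matrix product, same as M @ M
print(inner, mat)
```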

<div class="alert alert-success"> 
    <b>Q3.</b> Write a new function using `jnp.dot`.
</div>

In [None]:
def l2_sqnorm_dot(x):
    # your code
    return 

In [None]:
%timeit l2_sqnorm_dot(x)

## 2. Distance matrix

Let us now consider sets of $n$ vectors of dimension $d$, arranged in an $n \times d$ matrix:

In [None]:
x = np.random.randn(64, 1024) # n = 64, d = 1024

We want to compute the $n \times n$ matrix that contains the squared Euclidean distance between every pair of samples: $D_{i,j} = \|x_i - x_j\|^2$.
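As a small worked example (the three points are chosen only for illustration), take $n = 3$ samples in $d = 2$ dimensions whose rows are $(0, 0)$, $(3, 4)$ and $(6, 8)$:
$$
D = \begin{pmatrix}
0 & 25 & 100 \\
25 & 0 & 25 \\
100 & 25 & 0
\end{pmatrix},
\qquad \text{e.g. } D_{1,3} = (0 - 6)^2 + (0 - 8)^2 = 100.
$$
When both sets of vectors are the same, $D$ is symmetric with a zero diagonal, which gives a quick sanity check for the functions below.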

<div class="alert alert-success"> 
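    <b>Q4.</b> Write a baseline function that computes, using for loops, the distance matrix between two sets of vectors $x_1$ (size $n_1 \times d$) and $x_2$ (size $n_2 \times d$); the result has size $n_1 \times n_2$.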
</div>

In [None]:
def sqdist_loop(x1, x2):
    n1 = len(x1)
    n2 = len(x2)
    # your code
    return 

In [None]:
%timeit sqdist_loop(x, x)

Most Python linear algebra packages that emulate NumPy support broadcasting. Broadcasting consists in combining arrays of different shapes in arithmetic operations: subject to certain constraints, the smaller array is “broadcast” across the larger one so that they have compatible shapes.

In our case, $x_1$ and $x_2$ both have size $n \times d$. If we could extend $x_1$ to size $n \times n \times d$ by repeating it along the second axis, and similarly extend $x_2$ to size $n \times n \times d$ by repeating it along the first axis, then $x_1 - x_2$ would be a 3-dimensional array whose entry $[i,j,:]$ contains $x_1[i,:] - x_2[j,:]$.

Fortunately, broadcasting spares us from manually replicating $x_1$ (resp. $x_2$) along a new dimension. All we have to do is insert a dimension of size 1, and broadcasting performs the replication. A dimension can be added by indexing with `None`, for example:

In [None]:
x[:,None,:].shape
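To see the shapes at work, here is a sketch on small random sets (the names `x1_demo` and `x2_demo` and the sizes are illustrative only; the squaring and the sum over the last axis are left for Q5). Inserting `None` at different positions makes the size-1 axes line up so that the subtraction broadcasts to an $n_1 \times n_2 \times d$ array.

```python
import numpy as np

rng = np.random.default_rng(0)
x1_demo = rng.standard_normal((5, 3))   # n1 = 5 vectors of dimension d = 3
x2_demo = rng.standard_normal((4, 3))   # n2 = 4 vectors of dimension d = 3

left = x1_demo[:, None, :]    # shape (5, 1, 3)
right = x2_demo[None, :, :]   # shape (1, 4, 3)
diff = left - right           # size-1 axes are broadcast: shape (5, 4, 3)

print(left.shape, right.shape, diff.shape)
# entry [i, j, :] holds x1_demo[i] - x2_demo[j]
print(np.allclose(diff[2, 1], x1_demo[2] - x2_demo[1]))
```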

<div class="alert alert-success"> 
    <b>Q5.</b> Write a function that has no for loops and instead uses broadcasting.
</div>

In [None]:
def sqdist_bc(x1, x2):
    # your code
    return 

In [None]:
%timeit sqdist_bc(x,x)

The main problem with broadcasting is that it may lead to huge memory consumption. Here, the intermediate $n \times n \times d$ array has to be created in memory: its size grows quadratically with $n$, and allocating it takes time.
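A quick back-of-the-envelope computation makes the quadratic growth concrete. `broadcast_bytes` is a hypothetical helper written for this estimate; it assumes 8-byte float64 entries (what `np.random.randn` produces). JAX uses float32 by default, which halves these figures.

```python
def broadcast_bytes(n1, n2, d, itemsize=8):
    """Bytes needed for the (n1, n2, d) difference array (itemsize=8 for float64)."""
    return n1 * n2 * d * itemsize

small = broadcast_bytes(64, 64, 1024)           # 33_554_432 bytes, i.e. 32 MiB
large = broadcast_bytes(10_000, 10_000, 1024)   # 819_200_000_000 bytes
print(small / 2**20, "MiB")
print(large / 2**30, "GiB")                     # roughly 763 GiB
```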

Instead, we can use the identity $\|x_i - x_j\|^2 = \|x_i\|^2 + \|x_j\|^2 - 2\langle x_i, x_j \rangle$. Combined with broadcasting, it lets us build an $n\times 1$ matrix containing the squared norms of the rows of $x_1$, a $1\times n$ matrix containing the squared norms of the rows of $x_2$, and an $n\times n$ matrix containing the dot products between all pairs; we then add the first two and subtract twice the third.
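Before vectorizing, the identity can be checked numerically on a single pair of random vectors (`xi` and `xj` are illustrative names). In floating point the two sides agree only up to rounding; for nearly identical vectors the expanded form can even come out slightly negative, so clipping the result at zero (e.g. with `np.maximum(D, 0)`) is a common safeguard.

```python
import numpy as np

rng = np.random.default_rng(0)
xi = rng.standard_normal(1024)
xj = rng.standard_normal(1024)

direct = np.sum((xi - xj) ** 2)                                  # ||xi - xj||^2
expanded = np.dot(xi, xi) + np.dot(xj, xj) - 2 * np.dot(xi, xj)  # ||xi||^2 + ||xj||^2 - 2<xi, xj>
print(direct, expanded, np.isclose(direct, expanded))
```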

<div class="alert alert-success"> 
    <b>Q6.</b>  Write a function using only dot products and no loop.
</div>

In [None]:
def sqdist_dot(x1, x2):
    # your code
    return 

In [None]:
%timeit sqdist_dot(x,x)

## 3. Einsum

There is a tricky but practical function called ``jax.numpy.einsum`` that allows the user to perform summations over arbitrary indices, provided that each operand is given exactly one index per dimension. For example, the regular matrix product between matrices A and B is defined as 
$$
C_{ij} = \sum_k A_{ik}B_{kj}
$$
and the corresponding einsum notation is then
```
C = einsum("ik, kj -> ij", A, B)
```
which means that $A$ is indexed (in order) by $i$ and $k$, while $B$ is indexed by $k$ and $j$; since the index $k$ appears in both inputs but not in the output, the products are summed over $k$.

This can be extended to an arbitrary number of indices (e.g. `einsum("ijkl, mknj -> ilmn", A, B)` multiplies, then sums over the shared indices $j$ and $k$, and arranges the remaining dimensions in the order given by the output, here $i, l, m, n$).
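Both examples above can be checked directly. In this sketch the operands `P`, `Q`, `A4` and `B4` are random arrays with illustrative sizes (named so as not to overwrite later notebook variables); one entry of the four-index result is compared with an explicit double loop over $j$ and $k$. Tolerances are loose because JAX computes in float32 by default.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)

# Matrix product: C[i, j] = sum_k P[i, k] * Q[k, j]
P = jnp.asarray(rng.standard_normal((3, 4)))
Q = jnp.asarray(rng.standard_normal((4, 5)))
C = jnp.einsum("ik,kj->ij", P, Q)
print(C.shape, jnp.allclose(C, P @ Q, rtol=1e-4, atol=1e-4))

# Four-index example from the text: sum over the shared indices j and k
A4 = rng.standard_normal((2, 3, 4, 5))   # indexed by i, j, k, l
B4 = rng.standard_normal((6, 4, 7, 3))   # indexed by m, k, n, j
C4 = jnp.einsum("ijkl,mknj->ilmn", A4, B4)
print(C4.shape)                          # (2, 5, 6, 7), i.e. (i, l, m, n)

# One entry, C4[i=0, l=1, m=2, n=3], recomputed with an explicit double loop
ref = sum(A4[0, j, k, 1] * B4[2, k, 3, j] for j in range(3) for k in range(4))
print(np.isclose(float(C4[0, 1, 2, 3]), ref, rtol=1e-4, atol=1e-4))
```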

<div class="alert alert-success"> 
    <b>Q7.</b>  Write a function using only calls to einsum.
</div>

In [None]:
def sqdist_einsum(x1, x2):
    # your code
    return 

In [None]:
%timeit sqdist_einsum(x,x)

## 4. Selection and partial updates

To benefit fully from parallel computation, we also have to get rid of all conditional statements (`if`).

Let us consider an example where we want to set to 0 all elements of an array that are above a threshold $\theta$:
$$
x[i] \leftarrow \begin{cases} x[i] & \text{if } x[i] \leq \theta, \\ 0 & \text{otherwise.} \end{cases}
$$

<div class="alert alert-success"> 
    <b>Q8.</b> Write a baseline function that uses loops and `if` statements.
</div>

In [None]:
def thresh_loop(x, theta):
    n, d = x.shape
    # your code
    return 

In [None]:
%timeit thresh_loop(x, 0.7)

Now, instead of looping, we can use parallel (vectorized) operations to get the same result. Notice that the operation can be written as the product of two terms: $x[i] \leftarrow x[i]\,\mathbb{1}_{x[i]\leq \theta}$, where $\mathbb{1}$ is the indicator function.
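On a tiny array (the values and the names `vals`, `mask` are illustrative), a comparison produces a boolean array that behaves as the indicator: in a product, `True` counts as 1 and `False` as 0. `jnp.where` gives the same selection without a multiplication, which can be preferable when the array may contain `inf` (since `inf * 0` is `nan`).

```python
import jax.numpy as jnp

vals = jnp.array([[0.2, 0.9],
                  [-1.5, 0.5]])
theta = 0.7

mask = vals <= theta                  # boolean array: [[True, False], [True, True]]
masked = vals * mask                  # booleans are promoted to 1.0 / 0.0 in the product
same = jnp.where(mask, vals, 0.0)     # equivalent selection without a product
print(masked)                         # values: [[0.2, 0.0], [-1.5, 0.5]]
```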

<div class="alert alert-success"> 
    <b>Q9.</b>  Write a function without loops that uses a boolean comparison and a product instead of `if`.
</div>

In [None]:
def thresh_bin(x, theta):
    # your code
    return 

In [None]:
%timeit thresh_bin(x, 0.7)

## 5. Optimization using gradient descent

Machine learning relies heavily on numerical optimization, and gradient descent is one of its workhorses. Several Python toolkits, such as JAX, provide automatic differentiation (autograd), which computes the gradient of a function automatically.

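As an example, we will solve $\max_w \frac{w^\top A w}{\|w\|^2}$, with $A$ positive semi-definite (p.s.d.). This ratio is the Rayleigh quotient of $A$: its maximum is the largest eigenvalue of $A$, attained at a corresponding eigenvector. Since we maximize, the loops below actually perform gradient ascent: each step adds $0.1$ times the gradient to $w$.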
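The claim about the maximum can be checked on a small p.s.d. matrix built the same way as $A$ (the names `Z`, `S` and `rayleigh` are illustrative only): the Rayleigh quotient at the top eigenvector returned by `np.linalg.eigh` equals the largest eigenvalue, and random directions never exceed it. This also gives a target value to compare the optimization curves against.

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.standard_normal((10, 5))
S = Z.T @ Z                              # 5 x 5 p.s.d. matrix, built like A = x.T @ x

def rayleigh(v, M):
    """Rayleigh quotient v^T M v / ||v||^2."""
    return v @ M @ v / (v @ v)

eigvals, eigvecs = np.linalg.eigh(S)     # eigenvalues in ascending order
v_top = eigvecs[:, -1]                   # eigenvector of the largest eigenvalue

print(rayleigh(v_top, S), eigvals[-1])   # the two values coincide
# Random directions never exceed the largest eigenvalue
samples = [rayleigh(rng.standard_normal(5), S) for _ in range(1000)]
print(max(samples) <= eigvals[-1])
```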

Let us define our objective function:

In [None]:
def loss(w, A):
    return jnp.sum(w[None,:]@(A@w)/(w[None,:]@w))

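We then define $A$ from the $64 \times 1024$ matrix `x` of Section 2 (so $A = x^\top x$ is p.s.d. by construction) and an initial value $w_0$ for $w$: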

In [None]:
A = x.T@x
w0 = np.random.randn(1024)

<div class="alert alert-success"> 
    <b>Q10.</b> Write a function that returns the gradient of the objective function with respect to its first argument.
</div>

In [None]:
def manual_grad(w, A):
    # your code
    return 
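One common way to check a hand-derived gradient such as `manual_grad` is to compare it with central finite differences, $\partial_i f(w) \approx \frac{f(w + \epsilon e_i) - f(w - \epsilon e_i)}{2\epsilon}$. The helper `finite_diff_grad` below is hypothetical (it is not part of NumPy or JAX) and is tested on $f(w) = \sum_i w_i^3$, whose gradient is $3w^2$. It assumes a 1-D float64 input and objective; with JAX's default float32, a larger `eps` (around `1e-3`) and a looser tolerance would be needed. For $d = 1024$ it costs $2d$ evaluations, so checking a few coordinates is often enough.

```python
import numpy as np

def finite_diff_grad(f, w, eps=1e-5):
    """Central finite-difference approximation of the gradient of a scalar function f at w (1-D)."""
    w = np.asarray(w, dtype=np.float64)
    g = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = eps                                  # perturb only coordinate i
        g[i] = (f(w + e) - f(w - e)) / (2 * eps)
    return g

def cube_sum(v):
    return np.sum(v ** 3)                           # gradient is 3 * v**2

w_test = np.array([1.0, -2.0, 0.5])
approx = finite_diff_grad(cube_sum, w_test)
print(approx, 3 * w_test ** 2)
print(np.allclose(approx, 3 * w_test ** 2, atol=1e-6))
```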

In [None]:
%%time
loss_value = []
w=w0
for ite in range(100):
    w = w + 0.1*manual_grad(w, A)
    loss_value.append(loss(w,A))
plt.plot(loss_value)

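Instead, we can use JAX's automatic differentiation (`jax.grad`) to obtain the gradient without deriving it by hand, which rules out errors in our derivation of the objective function. The argument `argnums=0` selects the first argument, $w$.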


In [None]:
jax_grad = jax.grad(loss, argnums=0)
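To see what `argnums` does, here is a sketch on a function whose gradients are easy to derive by hand (`scaled_sq_norm` is an illustrative name): $f(v, s) = s\,\|v\|^2$ has $\nabla_v f = 2 s v$ and $\partial_s f = \|v\|^2$. Note that `jax.grad` requires the function to return a scalar.

```python
import jax
import jax.numpy as jnp

def scaled_sq_norm(v, scale):
    return scale * jnp.dot(v, v)                  # scale * ||v||^2, a scalar

grad_v = jax.grad(scaled_sq_norm, argnums=0)      # derivative w.r.t. v: 2 * scale * v
grad_scale = jax.grad(scaled_sq_norm, argnums=1)  # derivative w.r.t. scale: ||v||^2

v_demo = jnp.array([1.0, -2.0, 3.0])
print(grad_v(v_demo, 2.0))      # 2 * 2.0 * v_demo = [4, -8, 12]
print(grad_scale(v_demo, 2.0))  # 1 + 4 + 9 = 14
```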

In [None]:
%%time
loss_value = []
w = w0
for ite in range(100):
    w = w + 0.1*jax_grad(w, A)
    loss_value.append(loss(w,A))
plt.plot(loss_value)

JAX can also return both the value and the gradient in a single call with `jax.value_and_grad`:

In [None]:
value_grad = jax.value_and_grad(loss, argnums=0)

In [None]:
%%time
loss_value = []
w = w0
for ite in range(100):
    v, g = value_grad(w, A)
    w = w + 0.1*g
    loss_value.append(v)
plt.plot(loss_value)

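But why is it slower than our manual implementation? Without compilation, every call re-runs the Python function through JAX's tracing machinery to build the forward and backward computations, and each operation is dispatched separately. Instead, we can ask JAX to compile the function once with the `@jax.jit` decorator; later calls with inputs of the same shape and dtype reuse the compiled code: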

In [None]:
@jax.jit
def loss(w, A):
    return jnp.sum(w[None,:]@(A@w)/jnp.dot(w,w))

value_grad = jax.jit(jax.value_and_grad(loss, argnums=0))  # compile value and gradient together

In [None]:
%%time
loss_value = []
w = w0
for ite in range(100):
    v, g = value_grad(w, A)
    w = w + 0.1*g
    loss_value.append(v)
plt.plot(loss_value)
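A sketch of how to time a jitted function fairly (the names `rayleigh_loss`, `vg`, `timed_call`, `M_demo` and `v_demo` are illustrative). Two assumptions matter: the first call includes tracing and XLA compilation, and JAX dispatches work asynchronously, so the timer must wait for the outputs with `block_until_ready()`. The same caveat applies to the `%timeit` measurements of the earlier `jnp` functions. Timings depend on the hardware, so only the values are checked here: with $M = 3I$ every vector is an eigenvector, so the Rayleigh quotient is 3 and the gradient is zero.

```python
import time
import jax
import jax.numpy as jnp

def rayleigh_loss(v, M):
    return jnp.dot(v, M @ v) / jnp.dot(v, v)

# Compile value and gradient together; compiled code is cached per input shape/dtype
vg = jax.jit(jax.value_and_grad(rayleigh_loss, argnums=0))

def timed_call(fn, *args):
    """Call fn and wait for every output array, so the timing covers the actual computation."""
    t0 = time.perf_counter()
    out = fn(*args)
    for leaf in jax.tree_util.tree_leaves(out):
        leaf.block_until_ready()              # JAX dispatches asynchronously
    return out, time.perf_counter() - t0

M_demo = 3.0 * jnp.eye(256)
v_demo = jnp.ones(256)

(val, grad), t_first = timed_call(vg, v_demo, M_demo)   # includes tracing + compilation
(val, grad), t_next = timed_call(vg, v_demo, M_demo)    # reuses the compiled function
print(f"first call: {t_first:.4f}s, second call: {t_next:.4f}s")
print(val)
```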