In [5]:
import jax
import jaxlib
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

## PyHEP 2020 -- Tutorial on Automatic Differentiation


Welcome to this tutorial on automatic differentiation.





## Short Preface on Linear Transformations

Before we start, let's first look at **linear transformations** from ℝᵐ → ℝⁿ:
$$y(x) = Ax$$

With a given basis, this is representable as a (rectangular) matrix:
$$y_i(x) = A_{ij}x_j$$


For a given linear problem, there are a few ways we can run this computation:


1. **full matrix computation**

   i.e. we store the full (dense) $nm$ elements of the rectangular matrix and 
   compute an explicit matrix multiplication.
   
   The computation is fully generic: the same code works for any matrix.
   
```python
    def result(matrix, vector):
        return np.matmul(matrix,vector)
```
<br/>

2. **sparse matrix computation**

   If many $A_{ij}=0$, it is wasteful to spend memory on them. Instead we can
   create a sparse matrix by
   
   * storing only the non-zero elements
   * storing a look-up table that records where those elements sit in the matrix
   
   The computation can still be kept general:

```python
    def result(sparse_matrix, vector):
        # sparse_matmul is a placeholder for any sparse matrix-vector routine
        return sparse_matmul(sparse_matrix,vector)
```

<br/>
   
3. **matrix-free computation**

    In many cases a linear program is not given explicitly as a matrix, but
    as *code*, i.e. a "black-box" function. As long as the computation in the body of
    the function keeps to (hard-coded) linear transformations, the program is linear. The matrix elements
    are no longer explicitly enumerated and stored in a data structure
    but implicitly defined in the source code.
    
    This is no longer a generic computation: each linear transformation is its own
    program. At the same time this is also the most memory-efficient representation. No
    lookup table is needed since all constants are hard-coded.
    
    
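```python
    # example values for the hard-coded constants (chosen for illustration)
    A_11, A_12, A_22 = 2.0, 1.0, 3.0

    def linear_program(vector):
        x1,x2 = vector        # unpack the input components
        z1,z2 = 0,0
        z1 += A_11*x1
        z1 += A_12*x2         # A_12 couples x2 to the first output
        z2 += A_22*x2
        return [z1,z2]
```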
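
As a rough sketch, the three representations above can be compared on one small example. The matrix values, the coordinate-style `(rows, cols, vals)` layout, and the helper names `dense_result`, `sparse_result` and `matrix_free_result` are assumptions made for illustration. A real sparse library would normally replace the hand-written accumulation.

```python
import numpy as np

# Illustrative 2x3 matrix (values are an assumption made for this sketch).
A = np.array([[2.0, 0.0, 0.0],
              [0.0, 1.0, 3.0]])

def dense_result(matrix, vector):
    # 1. full matrix computation: all n*m elements are stored
    return np.matmul(matrix, vector)

# 2. sparse computation: keep only the non-zero values plus a
#    look-up table (row and column index) of where they sit
rows, cols = np.nonzero(A)
vals = A[rows, cols]

def sparse_result(rows, cols, vals, vector, n_out):
    out = np.zeros(n_out)
    # accumulate vals[k] * vector[cols[k]] into out[rows[k]]
    np.add.at(out, rows, vals * vector[cols])
    return out

def matrix_free_result(vector):
    # 3. matrix-free computation: the constants are hard-coded
    x1, x2, x3 = vector
    z1 = 2.0 * x1
    z2 = 1.0 * x2 + 3.0 * x3
    return np.array([z1, z2])

x = np.array([1.0, 2.0, 3.0])
y_dense = dense_result(A, x)
y_sparse = sparse_result(rows, cols, vals, x, A.shape[0])
y_free = matrix_free_result(x)
print(y_dense, y_sparse, y_free)
```

All three calls describe the same linear map, so they should return the same vector. Only the way the matrix elements are stored differs.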






### Recovering Matrix Elements from matrix-free computations


#### Matrix-vector products

In the matrix-free setting, the program does not give access to the matrix elements,
but only computes "matrix-vector" products (MVP).

We can use basis vectors to recover the matrix **one column at a time**.

<img src="./assets/mvp.png" alt="A Matrix Vector Product" width="600"/>



In [6]:
def matrix_vector_product(x):
    x1,x2,x3 = x
    z1,z2 = 0,0
    z1 += 2*x1  #MVP statement 1
    z2 += 1*x2  #MVP statement 2
    z2 += 3*x3  #MVP statement 3
    return np.asarray([z1,z2])

M = np.concatenate([
    matrix_vector_product(np.asarray([1,0,0])).reshape(-1,1),
    matrix_vector_product(np.asarray([0,1,0])).reshape(-1,1),
    matrix_vector_product(np.asarray([0,0,1])).reshape(-1,1),
],axis=1)
print(f'M derived from matrix-vector products:\n{M}')

M derived from matrix-vector products:
[[2 0 0]
 [0 1 3]]
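
The column-by-column recovery can also be written as a small generic helper. This is a minimal sketch: the name `recover_matrix_from_mvp` and its argument `m` (the input dimension) are hypothetical, and the black-box program is restated here so the block runs on its own.

```python
import numpy as np

def matrix_vector_product(x):
    # same black-box linear program as in the cell above
    x1, x2, x3 = x
    z1 = 2 * x1
    z2 = 1 * x2 + 3 * x3
    return np.asarray([z1, z2])

def recover_matrix_from_mvp(mvp, m):
    # column j of the matrix is the image of the j-th basis vector e_j
    columns = [mvp(e_j) for e_j in np.eye(m, dtype=int)]
    return np.stack(columns, axis=1)

M_cols = recover_matrix_from_mvp(matrix_vector_product, m=3)
print(M_cols)
```

The helper calls the program once per input dimension, here three times.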


#### Vector-matrix products (VMP)

The same matrix induces a "dual" linear map: ℝⁿ → ℝᵐ 
$$ x_k = y_i A_{ik}$$

i.e. instead of a Matrix-vector product it's now a *vector-Matrix* product (VMP).

If one has access to a "vector-Matrix" program corresponding to a matrix $A$, one
can again, as in the MVP case, recover the matrix elements by feeding in basis vectors.

This time the matrix is built **one row at a time**.


<img src="./assets/vmp.png" alt="A Vector Matrix Product" width="600"/>

In [9]:
def vector_matrix_product(z):
    x1,x2,x3 = 0,0,0
    z1,z2 = z

    x3 += z2*3 #VMP version of statement 3
    x2 += z2*1 #VMP version of statement 2
    x1 += z1*2 #VMP version of statement 1

    return np.asarray([x1,x2,x3])


M = np.concatenate([
    vector_matrix_product(np.asarray([1,0])).reshape(1,-1),
    vector_matrix_product(np.asarray([0,1])).reshape(1,-1),
],axis=0)
print(f'M derived from vector-matrix products:\n{M}')

M derived from vector-matrix products:
[[2 0 0]
 [0 1 3]]
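
The row-by-row recovery can be sketched in the same generic way. The helper name `recover_matrix_from_vmp` and its argument `n` (the output dimension of the forward map) are hypothetical. The last lines also check, for one arbitrary vector, that the VMP agrees with multiplying by the transposed matrix.

```python
import numpy as np

def vector_matrix_product(z):
    # same black-box vector-matrix program as in the cell above
    z1, z2 = z
    x1 = z1 * 2
    x2 = z2 * 1
    x3 = z2 * 3
    return np.asarray([x1, x2, x3])

def recover_matrix_from_vmp(vmp, n):
    # row i of the matrix is the result of feeding in the i-th basis vector e_i
    rows = [vmp(e_i) for e_i in np.eye(n, dtype=int)]
    return np.stack(rows, axis=0)

M_rows = recover_matrix_from_vmp(vector_matrix_product, n=2)
print(M_rows)

# the VMP with A gives the same numbers as the MVP with A transposed
z = np.asarray([5, 7])
print(vector_matrix_product(z), M_rows.T @ z)
```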


#### Short Recap:

For a given linear transformation, characterized by a matrix $A_{ij}$, we have a forward (matrix-vector) and a backward (vector-matrix) map:
$$y_i = A_{ij}x_j$$
$$x_j = y_i A_{ij}$$

and we can use either to recover the full matrix $A_{ij}$.
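
One way to check that a forward and a backward program really belong to the same matrix, without ever building it, is the identity $z_i (A_{ij} x_j) = (z_i A_{ij}) x_j$. The sketch below tests it on randomly drawn integer vectors. The seed and value ranges are arbitrary choices, and the two programs are restated so the block is self-contained.

```python
import numpy as np

def matrix_vector_product(x):
    # forward map: R^3 -> R^2
    x1, x2, x3 = x
    return np.asarray([2 * x1, 1 * x2 + 3 * x3])

def vector_matrix_product(z):
    # backward map: R^2 -> R^3
    z1, z2 = z
    return np.asarray([z1 * 2, z2 * 1, z2 * 3])

rng = np.random.default_rng(0)       # arbitrary seed
x = rng.integers(-5, 5, size=3)      # random input vector
z = rng.integers(-5, 5, size=2)      # random output-space vector

lhs = np.dot(z, matrix_vector_product(x))   # z . (A x)
rhs = np.dot(vector_matrix_product(z), x)   # (z A) . x
print(lhs == rhs)
```

If the two programs were inconsistent, for example because one statement was transcribed with the wrong constant, the two numbers would generally differ.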

## Wide versus Tall Transformation

If you look at the code above, you'll notice that the number of calls necessary to the MVP or VMP program
is related to the dimensions of the matrix itself.

For an $n\times m$ matrix (for a map: ℝᵐ → ℝⁿ), you need $m$ calls to the "Matrix-vector" program to
build the full matrix one-column-at-a-time. Likewise you need $n$ calls to the "vector-Matrix" program
to build the matrix one-row-at-a-time.

This becomes relevant for very asymmetric maps: e.g. for scalar maps from very high-dimensional spaces,
$\mathbb{R}^{10000} \to \mathbb{R}$, the "vector-Matrix" approach is *vastly* more efficient than the
"Matrix-vector" one, since it needs a single call instead of 10000.

Similarly, for functions mapping few variables into very high-dimensional spaces, $\mathbb{R} \to \mathbb{R}^{10000}$,
it's the opposite: the "Matrix-vector" approach is much better suited than the "vector-Matrix" one.
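
The difference in cost can be made concrete by counting calls. In this sketch the class `CountingLinearMap` and the two recovery helpers are hypothetical names. The map is backed by a stored matrix only to keep the example short (in practice it would be a black-box program), and a smaller dimension of 1000 is used so it runs quickly.

```python
import numpy as np

class CountingLinearMap:
    """Wraps a matrix and counts how often each product is requested."""
    def __init__(self, A):
        self.A = np.asarray(A)
        self.mvp_calls = 0
        self.vmp_calls = 0

    def mvp(self, x):
        self.mvp_calls += 1
        return self.A @ x

    def vmp(self, y):
        self.vmp_calls += 1
        return y @ self.A

def recover_by_columns(f):
    n, m = f.A.shape
    # one MVP call per input dimension
    return np.stack([f.mvp(e) for e in np.eye(m)], axis=1)

def recover_by_rows(f):
    n, m = f.A.shape
    # one VMP call per output dimension
    return np.stack([f.vmp(e) for e in np.eye(n)], axis=0)

rng = np.random.default_rng(0)
wide = CountingLinearMap(rng.normal(size=(1, 1000)))  # a map R^1000 -> R

A_cols = recover_by_columns(wide)
A_rows = recover_by_rows(wide)
print(wide.mvp_calls, wide.vmp_calls)  # prints "1000 1"
```

Both routes reconstruct the same $1\times 1000$ matrix, but the row-wise route gets there with a single call.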




In [10]:
A = np.random.randint(0,10,size = (2,3))
A

array([[2, 4, 9],
       [6, 6, 9]])