# Chapter 2: Preliminaries

### Dhuvi Karthikeyan

##### 1/09/2023

**Updated: 01/11/23**: This notebook follows Chapter 2: Preliminaries from the Dive into Deep Learning text (D2L) and collects notes on the sections I wanted to keep for future review. It does not cover every section, and while the section names are slightly altered, I have tried to preserve the original hierarchy, deviating only where a different order fit my mental model of the material better. The chapter is a way of getting our feet wet with deep learning frameworks: creating tensors and getting familiar with their operations. I used it as a chance to try out JAX and compare it against the PyTorch and TensorFlow APIs for the same functions.

## 2.1 Data Manipulation

In [1]:
# Import the Frameworks

import torch
import tensorflow as tf
import jax
from jax import numpy as jnp


### 2.1.1 Creating Tensors in the Three Frameworks

In [2]:
# PyTorch Instantiate a Tensor
torchx = torch.arange(12, dtype=torch.float32)
torchx

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

In [3]:
# TF Instantiate a Tensor
tfx = tf.range(12, dtype=tf.float32)
tfx

<tf.Tensor: shape=(12,), dtype=float32, numpy=
array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
      dtype=float32)>

In [4]:
# JAX Instantiate a Tensor
jaxx = jnp.arange(12)
jaxx

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [5]:
jnp.zeros((2, 3, 4)) #Also there's a JNP Ones

Array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]], dtype=float32)

In [6]:
tf.ones((2, 3, 4))

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>

In [7]:
torch.zeros((2,3,4))

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

In [8]:
tf.random.normal(shape=[3, 4])

<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[ 1.2134943 , -1.1871845 ,  0.18352774, -0.24870522],
       [ 0.85458815,  0.2468478 ,  0.526479  ,  1.4140105 ],
       [-0.14359604, -0.9495299 , -1.1702207 , -1.142411  ]],
      dtype=float32)>

In [9]:
torch.randn(3,4) # Randn is not uniform! It is standard normal w/ mu = 0 and sd = 1

tensor([[-0.8232,  1.1569, -0.2483,  1.8967],
        [-0.7940,  2.3601, -0.5117, -1.4068],
        [ 1.3473,  0.2974,  0.1927,  0.3635]])

In [10]:
# Any call of a random function in JAX requires a key to be
# specified, feeding the same key to a random function will
# always result in the same sample being generated
jax.random.normal(jax.random.PRNGKey(0), (3, 4))

Array([[ 1.1901639 , -1.0996888 ,  0.44367844,  0.5984697 ],
       [-0.39189556,  0.69261974,  0.46018356, -2.068578  ],
       [-0.21438177, -0.9898306 , -0.6789304 ,  0.27362573]],      dtype=float32)

In [11]:
torchY = torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
torchY

tensor([[2, 1, 4, 3],
        [1, 2, 3, 4],
        [4, 3, 2, 1]])

In [12]:
jaxY = jnp.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
jaxY

Array([[2, 1, 4, 3],
       [1, 2, 3, 4],
       [4, 3, 2, 1]], dtype=int32)

In [13]:
tfY = tf.constant([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
tfY

<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[2, 1, 4, 3],
       [1, 2, 3, 4],
       [4, 3, 2, 1]], dtype=int32)>

### 2.1.2 Indexing and Slicing

In [14]:
torchx.numel() == jaxx.size  # Number of Elements total (int)

True

In [15]:
tf.size(tfx) # Returns a scalar int32 tensor rather than a Python int

<tf.Tensor: shape=(), dtype=int32, numpy=12>

In [16]:
print(torchx.shape, jaxx.shape, tfx.shape) #All the same function call to get shape

torch.Size([12]) (12,) (12,)


#### Tensor Manipulation

In [17]:
torchX = torchx.reshape(3,4)
torchX

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [1]:
# We can also infer a missing dimension by passing -1

torchx.reshape(-1, 4)

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [18]:
jaxX = jaxx.reshape(3,4)
jaxX

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [19]:
tfX = tf.reshape(tfx, (3, 4))
tfX

<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)>

In [20]:
torchX[-1], torchX[1:3] # Indexing works much like numpy for all three frameworks

(tensor([ 8.,  9., 10., 11.]),
 tensor([[ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]))

In [21]:
jaxX[-1], jaxX[1:3, 0:2] # Note: a single index or slice applies to axis 0

(Array([ 8,  9, 10, 11], dtype=int32),
 Array([[4, 5],
        [8, 9]], dtype=int32))

**Tensor Mutability varies across the frameworks**

In [22]:
torchX[:2, :] = 12
torchX

tensor([[12., 12., 12., 12.],
        [12., 12., 12., 12.],
        [ 8.,  9., 10., 11.]])

In [23]:
tfX[1:3, :] = 12 #Tensors are immutable in TF and JAX
tfX

TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment

In [24]:
X_var = tf.Variable(tfX)
X_var[1, 2].assign(9)
X_var

<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  9.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)>

In [25]:
# JAX arrays are immutable. `jax.numpy.ndarray.at` index
# update operators create a new array with the corresponding
# modifications made
X_new_1 = jaxX.at[1, 2].set(17)
X_new_1

Array([[ 0,  1,  2,  3],
       [ 4,  5, 17,  7],
       [ 8,  9, 10, 11]], dtype=int32)

### 2.1.3 Tensor Ops

In [26]:
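# Unary elementwise operator (maps R -> R); jnp.exp and tf.exp are analogous.
# Note: torchX is a view of torchx, so the in-place assignment above also
# overwrote the first eight entries of torchx with 12 (hence exp(12) below).
torch.exp(torchx)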

tensor([162754.7969, 162754.7969, 162754.7969, 162754.7969, 162754.7969,
        162754.7969, 162754.7969, 162754.7969,   2980.9580,   8103.0840,
         22026.4648,  59874.1406])

In [27]:
torch.cat((torchX, torchY), dim=0), torch.cat((torchX, torchY), dim=1)

(tensor([[12., 12., 12., 12.],
         [12., 12., 12., 12.],
         [ 8.,  9., 10., 11.],
         [ 2.,  1.,  4.,  3.],
         [ 1.,  2.,  3.,  4.],
         [ 4.,  3.,  2.,  1.]]),
 tensor([[12., 12., 12., 12.,  2.,  1.,  4.,  3.],
         [12., 12., 12., 12.,  1.,  2.,  3.,  4.],
         [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]]))

In [28]:
jnp.concatenate((jaxX, jaxY), axis=0), jnp.concatenate((jaxX, jaxY), axis=1)

(Array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [ 2,  1,  4,  3],
        [ 1,  2,  3,  4],
        [ 4,  3,  2,  1]], dtype=int32),
 Array([[ 0,  1,  2,  3,  2,  1,  4,  3],
        [ 4,  5,  6,  7,  1,  2,  3,  4],
        [ 8,  9, 10, 11,  4,  3,  2,  1]], dtype=int32))

In [29]:
tf.concat([tfX, tfX], axis=0), tf.concat([tfX, tfX], axis=1) # tf.concat requires matching dtypes; tfY is int32, so tfX is reused here

(<tf.Tensor: shape=(6, 4), dtype=float32, numpy=
 array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]], dtype=float32)>,
 <tf.Tensor: shape=(3, 8), dtype=float32, numpy=
 array([[ 0.,  1.,  2.,  3.,  0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.,  8.,  9., 10., 11.]], dtype=float32)>)

In [30]:
torchX, torchX.sum(dim=0), torchX.sum(dim=1), torchX.sum() #Reduce Sum

(tensor([[12., 12., 12., 12.],
         [12., 12., 12., 12.],
         [ 8.,  9., 10., 11.]]),
 tensor([32., 33., 34., 35.]),
 tensor([48., 48., 38.]),
 tensor(134.))

In [31]:
jaxX, jaxX.sum(axis=0), jaxX.sum(axis=1), jaxX.sum()

(Array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]], dtype=int32),
 Array([12, 15, 18, 21], dtype=int32),
 Array([ 6, 22, 38], dtype=int32),
 Array(66, dtype=int32))

In [32]:
tfX, tf.reduce_sum(tfX, axis=0), tf.reduce_sum(tfX, axis=1), tf.reduce_sum(tfX)

(<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
 array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]], dtype=float32)>,
 <tf.Tensor: shape=(4,), dtype=float32, numpy=array([12., 15., 18., 21.], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 6., 22., 38.], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=66.0>)

### 2.1.4 Broadcasting

Broadcasting works the same way as in NumPy: when two arrays have different shapes, axes of length 1 are expanded (by conceptually duplicating elements along them) until both arrays have the same shape, and the elementwise operation is then applied.

In [33]:
a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))
a, b

(Array([[0],
        [1],
        [2]], dtype=int32),
 Array([[0, 1]], dtype=int32))

In [34]:
a + b

Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)

### 2.1.5 Saving Memory (In-Place Operations)

In [35]:
print(id(torchY))   # torchY + torchX allocates a new tensor, so stale references keep pointing at the old values
torchY = torchY + torchX
print("After the addition", id(torchY))

140173843992464
After the addition 140173445156592


In [36]:
torchZ = torch.zeros_like(torchY)
print('id(Z):', id(torchZ))
torchZ[:] = torchX + torchY # Slice assignment writes into the existing buffer; in TF use tf.Variable.assign(...)
print('id(Z):', id(torchZ))

id(Z): 140173445160832
id(Z): 140173445160832


In [37]:
# JAX arrays do not allow in-place operations

### 2.1.6 Conversions to Other Python Objects

Conversion to NumPy arrays is a single call in each framework: `.numpy()` in PyTorch and TensorFlow, and `jax.device_get` in JAX.

In [38]:
torchX.numpy() == tfX.numpy() # Rows 0-1 differ because torchX was overwritten with 12 above

array([[False, False, False, False],
       [False, False, False, False],
       [ True,  True,  True,  True]])

In [39]:
type(jax.device_get(jaxX)) #The opposite is device_put

numpy.ndarray

## 2.3 Linear Algebra

### 2.3.3 Matrices

In [40]:
torchY.T #Transpose

tensor([[14., 13., 12.],
        [13., 14., 12.],
        [16., 15., 12.],
        [15., 16., 12.]])

In [41]:
jaxY.T

Array([[2, 1, 4],
       [1, 2, 3],
       [4, 3, 2],
       [3, 4, 1]], dtype=int32)

In [42]:
tf.transpose(tfY)

<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[2, 1, 4],
       [1, 2, 3],
       [4, 3, 2],
       [3, 4, 1]], dtype=int32)>

### 2.3.4 Tensor

In [43]:
torch3d = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
torch3d

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [44]:
jax3d = jnp.arange(24).reshape(2, 3, 4)
jax3d

Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]], dtype=int32)

In [45]:
tf3d = tf.reshape(tf.range(24, dtype=tf.float32), (2,3,4))
tf3d

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]],

       [[12., 13., 14., 15.],
        [16., 17., 18., 19.],
        [20., 21., 22., 23.]]], dtype=float32)>

### 2.3.5 Basic Properties of Tensor Arithmetic

In [46]:
torch3d * torch3d # Hadamard Product or Element-wise product (*) is used across frameworks

tensor([[[  0.,   1.,   4.,   9.],
         [ 16.,  25.,  36.,  49.],
         [ 64.,  81., 100., 121.]],

        [[144., 169., 196., 225.],
         [256., 289., 324., 361.],
         [400., 441., 484., 529.]]])

In [47]:
jax3d + jax3d  # Elementwise operator (+) is the same across frameworks

Array([[[ 0,  2,  4,  6],
        [ 8, 10, 12, 14],
        [16, 18, 20, 22]],

       [[24, 26, 28, 30],
        [32, 34, 36, 38],
        [40, 42, 44, 46]]], dtype=int32)

### 2.3.6 Reduction

In [48]:
torch3d.mean(), jax3d.mean(axis=0), tf.reduce_mean(tf3d, axis=1) #Finding the mean

(tensor(11.5000),
 Array([[ 6.,  7.,  8.,  9.],
        [10., 11., 12., 13.],
        [14., 15., 16., 17.]], dtype=float32),
 <tf.Tensor: shape=(2, 4), dtype=float32, numpy=
 array([[ 4.,  5.,  6.,  7.],
        [16., 17., 18., 19.]], dtype=float32)>)

### 2.3.7 Non-Reduction Sum

In [49]:
torch3d.sum(dim=0, keepdims=True), torch3d.sum(dim=1, keepdims=True).shape 

(tensor([[[12., 14., 16., 18.],
          [20., 22., 24., 26.],
          [28., 30., 32., 34.]]]),
 torch.Size([2, 1, 4]))

In [50]:
torch3d

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [51]:
torch3d.sum(dim=0), torch3d.sum(dim=1), torch3d.sum(dim=2) #Adding across dim0, dim1, dim2

(tensor([[12., 14., 16., 18.],
         [20., 22., 24., 26.],
         [28., 30., 32., 34.]]),
 tensor([[12., 15., 18., 21.],
         [48., 51., 54., 57.]]),
 tensor([[ 6., 22., 38.],
         [54., 70., 86.]]))

In [52]:
torch3d.cumsum(dim=0), torch3d.cumsum(dim=1), torch3d.cumsum(dim=2) #CUMSUM across dim0, dim1, dim2 same in JAX, tf.cumsum 

(tensor([[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],
 
         [[12., 14., 16., 18.],
          [20., 22., 24., 26.],
          [28., 30., 32., 34.]]]),
 tensor([[[ 0.,  1.,  2.,  3.],
          [ 4.,  6.,  8., 10.],
          [12., 15., 18., 21.]],
 
         [[12., 13., 14., 15.],
          [28., 30., 32., 34.],
          [48., 51., 54., 57.]]]),
 tensor([[[ 0.,  1.,  3.,  6.],
          [ 4.,  9., 15., 22.],
          [ 8., 17., 27., 38.]],
 
         [[12., 25., 39., 54.],
          [16., 33., 51., 70.],
          [20., 41., 63., 86.]]]))

### 2.3.8 Dot Products

torch.dot(torch.ones(3, dtype=torch.float32), torch.zeros(3, dtype=torch.float32))

jnp.dot(jnp.ones(3, dtype=jnp.float32), jnp.zeros(3, dtype=jnp.float32))

#tf.tensordot(tf.ones(3, dtype=tf.float32), tf.zeros(3, dtype=tf.float32), axes=1) #CUBLAS error

### 2.3.9 Matrix-Vector Products

In [54]:
# Torch: A @ x (or torch.mv(A, x))
# JAX: jnp.matmul(A, x) (or A @ x)
# TF: tf.linalg.matvec(A, x) for matrix-vector, tf.matmul(A, B) for matrix-matrix

torchX@torch.ones(4, dtype=torch.float32)

tensor([48., 48., 38.])

### 2.3.10 Matrix-Matrix Products

In [None]:
torchX@torch.ones(4, 3, dtype=torch.float32) # (3, 4) @ (4, 3) -> (3, 3); inner dimensions must match

### 2.3.11 Norms

A norm maps a vector to a non-negative scalar and satisfies the following properties:

1. $ ||ax|| = |a| \, ||x|| $ (absolute homogeneity)
2. $ || x + y || \leq ||x|| + ||y||$ (triangle inequality)
3. $ ||x|| \geq 0 $, with equality if and only if $x = 0$

Lp Norm:

$$ ||x||_p = (\sum_{i=1}^n|x_i|^p)^{1/p} $$

Frobenius Norm:

$$||X||_F = \sqrt{\sum_i \sum_j x_{ij}^2} $$

In [55]:
### Norms

#L2 Norm = Euclidean Norm = Vector version of the Frobenius Norm
#L1 Norm = Manhattan Distance = Sum of the Absolute Values of Vector Elements

#Norms are arguably easier to compute in JAX because jnp.linalg.norm takes the order via the ord argument

In [56]:
# Torch and TF Norms

#torch.norm(v) and tf.norm(u) compute the Frobenius norm, which for a vector is the L2 norm
#torch.abs(v).sum() and tf.reduce_sum(tf.abs(u)) compute the L1 norm (Manhattan distance)


In [57]:
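# JAX
# For 2-D input, ord=1 is the induced matrix 1-norm (maximum absolute column sum),
# not the entrywise L1 norm, hence 4 rather than 36.

jnp.linalg.norm(jnp.ones((4,9)), ord=1)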

Array(4., dtype=float32)

In [58]:
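# For 2-D input, ord=2 is the spectral norm (largest singular value); the default is the Frobenius norm
jnp.linalg.norm(jnp.ones((4,9)), ord=2), jnp.linalg.norm(jnp.ones((4,9)))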

(Array(5.9999995, dtype=float32), Array(6., dtype=float32))
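The difference between vector and matrix norms in the cells above is easy to miss, so here is a minimal sketch using NumPy, whose `np.linalg.norm` has the same `ord` semantics as `jnp.linalg.norm`. The variable names are illustrative only.

```python
import numpy as np

# Vector norms: ord selects p in the Lp norm.
v = np.array([3.0, -4.0])
l1_vec = np.linalg.norm(v, ord=1)   # |3| + |-4| = 7
l2_vec = np.linalg.norm(v)          # sqrt(9 + 16) = 5 (default for vectors)

# Matrix norms: the same ord values mean something different for 2-D input.
M = np.ones((4, 9))
entrywise_l1 = np.abs(M).sum()          # sum of absolute entries = 36
induced_1 = np.linalg.norm(M, ord=1)    # maximum absolute column sum = 4
spectral = np.linalg.norm(M, ord=2)     # largest singular value = 6
frobenius = np.linalg.norm(M)           # sqrt(sum of squares) = sqrt(36) = 6

print(l1_vec, l2_vec)
print(entrywise_l1, induced_1, spectral, frobenius)
```

To get the entrywise L1 or L2 norm of a matrix, flatten it first (for example `M.ravel()`) or sum the absolute values directly.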

## 2.4 Calculus Refresher

### 2.4.3 Partial Derivatives and Gradients

* The gradient is the vector of partial derivatives
* For all $A \in \mathbb{R}^{m \times n}$ and $x \in \mathbb{R}^n$: $\nabla_x Ax = A^T$
* For all $A \in \mathbb{R}^{n \times m}$: $\nabla_x x^TA = A$
* For all square $A \in \mathbb{R}^{n \times n}$: $\nabla_x x^TAx = (A+A^T)x$
* $\nabla_x ||x||^2 = \nabla_x x^Tx = 2x$; the analogous result for a matrix is $\nabla_X ||X||_F^2 = 2X$
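The quadratic-form and squared-norm identities can be sanity-checked numerically. The sketch below uses central finite differences in NumPy; `numerical_grad` is a hypothetical helper written for this note, not a library function.

```python
import numpy as np

def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference estimate of the gradient of scalar f at x."""
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))   # a generic (non-symmetric) square matrix
x = rng.normal(size=4)

# grad of x^T A x should equal (A + A^T) x
g_quad = numerical_grad(lambda v: v @ A @ v, x)
# grad of ||x||^2 = x^T x should equal 2x
g_sq = numerical_grad(lambda v: v @ v, x)

print(np.allclose(g_quad, (A + A.T) @ x, atol=1e-5))
print(np.allclose(g_sq, 2 * x, atol=1e-5))
```

Note that the result is $(A + A^T)x$ rather than $2Ax$ unless $A$ is symmetric, which is why the check uses a non-symmetric matrix.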

### 2.4.4 Chain Rule

$$ y = f(\mathbf{u}) \text{ where } \mathbf{u} \in \mathbb{R}^{m}$$
$$ \mathbf{u} = g(\mathbf{x}) \text{ where } \mathbf{x} \in \mathbb{R}^{n}$$
$$ \frac{\partial y}{\partial x_i} = \frac{\partial y}{\partial u_1}\frac{\partial u_1}{\partial x_i} + \frac{\partial y}{\partial u_2}\frac{\partial u_2}{\partial x_i} + \cdots + \frac{\partial y}{\partial u_m}\frac{\partial u_m}{\partial x_i}$$

Equivalently, there is a matrix $\mathbf{A} \in \mathbb{R}^{n \times m}$ with entries $A_{ij} = \partial u_j / \partial x_i$ such that:

$$ \nabla_{\mathbf{x}} y = \mathbf{A}\nabla_{\mathbf{u}} y$$
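As a worked instance of the matrix form of the chain rule, the sketch below assumes a specific example not in the text: $\mathbf{u} = \tanh(W\mathbf{x})$ and $y = \sum_j u_j^2$. For this choice $A_{ij} = \partial u_j / \partial x_i = (1 - u_j^2) W_{ji}$ and $\nabla_{\mathbf{u}} y = 2\mathbf{u}$. The helper `numerical_grad` is hypothetical and only used to cross-check the result.

```python
import numpy as np

def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference estimate of the gradient of scalar f at x."""
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n, m = 3, 5
W = rng.normal(size=(m, n))
x = rng.normal(size=n)

g = lambda x: np.tanh(W @ x)        # u = g(x), with u in R^m
f = lambda u: np.sum(u ** 2)        # y = f(u), a scalar

u = g(x)
grad_u = 2 * u                               # gradient of y w.r.t. u, shape (m,)
A = (W * (1 - u ** 2)[:, None]).T            # A[i, j] = du_j / dx_i, shape (n, m)
grad_x = A @ grad_u                          # chain rule: grad_x y = A grad_u y

print(A.shape)
print(np.allclose(grad_x, numerical_grad(lambda v: f(g(v)), x), atol=1e-6))
```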

## 2.5 Automatic Differentiation

### 2.5.1 A Simple Function

Automatically compute the gradient of the simple function $y = 2\mathbf{x}^\top\mathbf{x}$ with respect to $\mathbf{x}$.

In [59]:
# Torch

torchz = torch.arange(4.0, dtype=torch.float32, requires_grad=True)
torchz.grad # None because backward() has not been called yet
torchu = 2 * torch.dot(torchz, torchz)
torchu

tensor(28., grad_fn=<MulBackward0>)

In [60]:
torchu.backward()
torchz.grad

tensor([ 0.,  4.,  8., 12.])

In [61]:
# PyTorch accumulates gradients, so reset torchz.grad before differentiating another function of torchz

torchz.grad.zero_()
torchu = torchz.sum()
torchu.backward()
torchz.grad

tensor([1., 1., 1., 1.])

In [62]:
# JAX

jaxz = jnp.arange(4.0) # No requires_grad flag: JAX differentiates functions, not arrays
jaxu = lambda x: 2 * jnp.dot(x,x) # The function itself is passed to jax.grad
jaxu(jaxz)

Array(28., dtype=float32)

In [63]:
from jax import grad

jaxz_grad = grad(jaxu)(jaxz)
jaxz_grad

Array([ 0.,  4.,  8., 12.], dtype=float32)

In [64]:
# Unlike Torch, jax.grad returns a fresh gradient on each call, so there is nothing to clear

jaxu = lambda x: x.sum()
jaxz_grad = grad(jaxu)(jaxz)
jaxz_grad

Array([1., 1., 1., 1.], dtype=float32)

In [65]:
# Tensorflow

tfz = tf.range(4, dtype=tf.float32)
tfz = tf.Variable(tfz)                      # Variables are watched by GradientTape automatically; a plain tensor needs t.watch(...)
with tf.GradientTape() as t:
    tfu = 2 * tf.tensordot(tfz, tfz, axes=1) # Gives the tensor with scalar value 28.0
t.gradient(tfu, tfz)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.,  4.,  8., 12.], dtype=float32)>

### 2.5.2 Backward for Non-Scalar Variables

https://github.com/yang-zhang/yang-zhang.github.io/blob/master/ds_code/pytorch-backward-gradient-examples.ipynb

The Jacobian is a generalization of the gradient. The gradient is the vector of partial derivatives of a scalar-valued output with respect to a vector-valued input, so it is what gets computed when differentiating a loss. The Jacobian applies when a vector-valued input maps to a vector-valued output: it is the matrix of all partial derivatives. In the standard convention it has one row per output component and one column per input component; the transposed layout, with rows indexed by the input, is what the chain-rule formula in 2.4.4 uses. Full Jacobians are seldom needed in ML. Instead, we usually sum the gradients of each component of y with respect to x, which is what passing a vector of ones to `backward` does.

In [66]:
#Torch
torchz.grad.zero_()
torchu = torchz * torchz
torchu.backward(gradient=torch.ones(len(torchu)))  # Faster: y.sum().backward()
torchz.grad

tensor([0., 2., 4., 6.])

In [67]:
#JAX
ajax = lambda x: x * x
# `grad` is only defined for scalar output functions
grad(lambda x: ajax(x).sum())(jaxz)

Array([0., 2., 4., 6.], dtype=float32)

In [68]:
#TF
with tf.GradientTape() as t:
    tfu = tfz * tfz
t.gradient(tfu, tfz)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 2., 4., 6.], dtype=float32)>

### 2.5.3 Detaching from Computational Graph

In [69]:
{
    'Torch': torchY.detach(),
    'TF': tf.stop_gradient(tfY),
    'JAX':jax.lax.stop_gradient(jaxY)
}

{'Torch': tensor([[14., 13., 16., 15.],
         [13., 14., 15., 16.],
         [12., 12., 12., 12.]]),
 'TF': <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
 array([[2, 1, 4, 3],
        [1, 2, 3, 4],
        [4, 3, 2, 1]], dtype=int32)>,
 'JAX': Array([[2, 1, 4, 3],
        [1, 2, 3, 4],
        [4, 3, 2, 1]], dtype=int32)}

### 2.5.4 Gradients and Python Control Flow

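A nontrivial point: gradients can still be computed when the function involves Python control flow (conditionals, loops), because the computation is recorded as it executes and differentiated along the branch that was actually taken. The result is the derivative of that executed path, so it is only meaningful where the function is differentiable.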
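To make this concrete without depending on any framework, here is a toy forward-mode sketch in plain Python. The `Dual` class is a hypothetical illustration written for this note; it is not how PyTorch, TensorFlow, or JAX implement autodiff (they use reverse mode), but it shows the same principle: the derivative follows whichever loop iterations and branch actually ran. The function `f` mirrors the style of the D2L control-flow example, adapted to scalars.

```python
class Dual:
    """Toy forward-mode number: carries a value and its derivative together."""

    def __init__(self, val, dot=0.0):
        self.val = val   # function value
        self.dot = dot   # derivative with respect to the input

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule
            return Dual(self.val * other.val,
                        self.val * other.dot + self.dot * other.val)
        # Multiplication by a plain constant
        return Dual(self.val * other, self.dot * other)

    __rmul__ = __mul__


def f(a):
    """Scalar function with a data-dependent loop and branch."""
    b = a * 2
    while abs(b.val) < 1000:   # number of iterations depends on the input
        b = b * 2
    if b.val > 0:              # branch depends on the input
        c = b
    else:
        c = 100 * b
    return c


def derivative(func, x):
    """Evaluate d func / dx at x by seeding the input derivative with 1."""
    return func(Dual(x, 1.0)).dot


for x in (3.0, -3.0):
    out = f(Dual(x, 1.0))
    # f is piecewise linear in its input, so the derivative equals f(x) / x
    print(x, out.val, out.dot, out.val / x)
```

Because `f` is piecewise linear, the derivative along the executed path equals `f(x) / x`, which is the same check D2L uses for its control-flow example.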

## 2.6 Probability and Statistics

Conditional Independence: $P(A,B|C) = P(A|C) P(B|C)$

$$ E[X+Y] = E[X] + E[Y] $$
$$ Var[X] = E[(X-E[X])^2] = E[X^2] - E[X]^2 $$
$$ Var[aX] = a^2Var[X] $$
$$ Var[aX + bY] = a^2Var[X] + b^2Var[Y] + 2abCov(X,Y)$$
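These variance identities also hold exactly for sample moments, provided the same $1/N$ normalization is used throughout. A minimal NumPy check, with made-up correlated data and arbitrary constants `a` and `b`:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=10_000)
y = 0.5 * x + rng.normal(size=10_000)   # deliberately correlated with x
a, b = 2.0, -3.0

# Var[X] = E[X^2] - E[X]^2
var_direct = np.var(x)
var_moments = np.mean(x ** 2) - np.mean(x) ** 2

# Var[aX + bY] = a^2 Var[X] + b^2 Var[Y] + 2ab Cov(X, Y)
lhs = np.var(a * x + b * y)
cov_xy = np.cov(x, y, bias=True)[0, 1]   # bias=True matches np.var's 1/N normalization
rhs = a ** 2 * np.var(x) + b ** 2 * np.var(y) + 2 * a * b * cov_xy

print(np.isclose(var_direct, var_moments), np.isclose(lhs, rhs))
```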

Covariance matrix is calculated as: $E[(x-\mu)(x-\mu)^T]$

* Aleatoric Uncertainty: Uncertainty intrinsic to the problem
* Epistemic Uncertainty: Uncertainty over model params

Chebyshev Inequality: 

$$ P(|X-\mu| \geq k\sigma) \leq \frac{1}{k^2}$$
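Chebyshev's bound is distribution-free, so it is usually loose. A quick simulation sketch with the standard library, assuming an Exponential(1) variable (mean 1, standard deviation 1) purely as an example:

```python
import random

rng = random.Random(0)
n = 100_000
mu, sigma = 1.0, 1.0   # true mean and standard deviation of Exponential(1)
samples = [rng.expovariate(1.0) for _ in range(n)]

def tail_fraction(k):
    """Empirical estimate of P(|X - mu| >= k * sigma)."""
    return sum(abs(s - mu) >= k * sigma for s in samples) / n

for k in (1.5, 2.0, 3.0):
    # Empirical tail probability next to the Chebyshev bound 1 / k^2
    print(k, tail_fraction(k), 1 / k ** 2)
```

The empirical tail probabilities sit well below $1/k^2$, which is expected: the inequality must hold for every distribution with finite variance.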