**Update - 23rd Dec, 2021**

We have completed the TF-JAX tutorials series: 10 notebooks that cover the fundamental aspects of both TensorFlow and JAX. Here are the links to the notebooks along with the GitHub repo details:

### TensorFlow Notebooks:

* [TF_JAX_Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
* [TF_JAX_Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)
* [TF_JAX_Tutorials - Part 3](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3)

### JAX Notebooks:

* [TF_JAX_Tutorials - Part 4 (JAX and DeviceArray)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray)
* [TF_JAX_Tutorials - Part 5 (Pure Functions in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-5-pure-functions-in-jax/)
* [TF_JAX_Tutorials - Part 6 (PRNG in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-6-prng-in-jax/)
* [TF_JAX_Tutorials - Part 7 (JIT in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-7-jit-in-jax)
* [TF_JAX_Tutorials - Part 8 (Vmap and Pmap)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-8-vmap-pmap)
* [TF_JAX_Tutorials - Part 9 (Autodiff in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-9-autodiff-in-jax)
* [TF_JAX_Tutorials - Part 10 (Pytrees in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-10-pytrees-in-jax)

### GitHub repo with all notebooks in one place
https://github.com/AakashKumarNain/TF_JAX_tutorials

---

<img src="https://i.ytimg.com/vi/yjprpOoH5c8/maxresdefault.jpg" width="300" height="300" align="center"/>

Welcome to another TensorFlow/JAX tutorial. This is the third tutorial in this series. If you haven't looked at the previous tutorials, I would highly recommend checking them out.

1. [TF-JAX Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
2. [TF-JAX Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)

**Note:** The tutorials are in the following format:
1. TF Fundamentals (2-3 notebooks)
2. JAX Fundamentals (2-3 notebooks)
3. Advanced TF (2-3 notebooks)
4. Advanced JAX (2-3 notebooks)


Today we will be taking a deep dive into a very important topic: **`Gradients`**

`Automatic Differentiation` and `Gradients` are such important concepts that they deserve a few dedicated chapters. Understanding every bit of them isn't necessary, but the more you dive into them, the more you will appreciate their beauty. I am planning to do an `advanced` tutorial on these topics if there is enough interest from the readers. Do let me know in the comments section what you think.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

seed=1234
np.random.seed(seed)
tf.random.set_seed(seed)
%config IPCompleter.use_jedi = False

## Automatic Differentiation and Gradients

Let's say you apply a sequence of operations on an input in a *forward* pass. To differentiate automatically, you need some sort of mechanism to figure out:
1. What operations were applied in the forward pass?
2. What was the order in which the operations were applied?

For autodiff, you need to keep track of both of these. Different frameworks implement the same idea in different ways, but the fundamentals remain the same.
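To make these two requirements concrete, here is a toy, pure-Python sketch of a tape. The names `ToyTape`, `ToyVar`, and `toy_gradient` are hypothetical and not part of TensorFlow or JAX, and only scalar `+` and `*` are supported. Every operation appends a record (what was computed plus its local derivatives) in execution order, and replaying the records in reverse order applies the chain rule.

```python
class ToyTape:
    """Records operations in the order they run during the forward pass."""

    def __init__(self):
        # Each record is (output, [(input, d_output/d_input), ...])
        self.records = []


class ToyVar:
    """A scalar value that logs every operation it takes part in."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __add__(self, other):
        out = ToyVar(self.value + other.value, self.tape)
        # d(a + b)/da = 1, d(a + b)/db = 1
        self.tape.records.append((out, [(self, 1.0), (other, 1.0)]))
        return out

    def __mul__(self, other):
        out = ToyVar(self.value * other.value, self.tape)
        # d(a * b)/da = b, d(a * b)/db = a
        self.tape.records.append((out, [(self, other.value), (other, self.value)]))
        return out


def toy_gradient(tape, target, sources):
    """Reverse-mode AD: walk the records backwards, applying the chain rule."""
    grads = {id(target): 1.0}
    for out, inputs in reversed(tape.records):
        upstream = grads.get(id(out), 0.0)
        for inp, local in inputs:
            grads[id(inp)] = grads.get(id(inp), 0.0) + upstream * local
    return [grads.get(id(s), 0.0) for s in sources]


tape = ToyTape()
x = ToyVar(3.0, tape)
y = ToyVar(4.0, tape)
z = x * y + x * x  # z = xy + x^2

dz_dx, dz_dy = toy_gradient(tape, z, [x, y])
print(dz_dx, dz_dy)  # 10.0 3.0  (dz/dx = y + 2x, dz/dy = x)
```

Walking the records in reverse is what makes the order matter: by the time a record is processed, the gradient flowing into its output has already been fully accumulated.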


### Gradients in TensorFlow

TensorFlow provides the `tf.GradientTape` API for automatic differentiation. Any relevant operation executed inside the `GradientTape` context gets recorded for gradient computation. To compute gradients, you need to do the following:

1. Record operations inside the `tf.GradientTape` context
2. Compute the gradients using `GradientTape.gradient(target, sources)`

Let's code up a few examples for this.

In [2]:
# We will initialize a few Variables here

x = tf.Variable(3.0)
y = tf.Variable(4.0)

print(f"Variable x: {x}")
print(f"Is x trainable?: {x.trainable}")
print(f"\nVariable y: {y}")
print(f"Is y trainable?: {y.trainable}")

Variable x: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>
Is x trainable?: True

Variable y: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>
Is y trainable?: True


We will do a simple operation here: **`z = x * y`**, and calculate the gradients of `z` wrt `x` and `y`. We are taking a simple example so that readers can easily verify that things are working as expected. We will work on more complex examples in a bit.

In [3]:
# Remember we need to execute the operations inside the context
# of GradientTape so that we can record them

with tf.GradientTape() as tape:
    z = x * y
    
dx, dy = tape.gradient(z, [x, y])

print(f"Input Variable x: {x.numpy()}")
print(f"Input Variable y: {y.numpy()}")
print(f"Output z: {z}\n")

# dz / dx
print(f"Gradient of z wrt x: {dx}")

# dz / dy
print(f"Gradient of z wrt y: {dy}")

Input Variable x: 3.0
Input Variable y: 4.0
Output z: 12.0

Gradient of z wrt x: 4.0
Gradient of z wrt y: 3.0


In [4]:
# Remember we need to execute the operations inside the context
# of GradientTape so that we can record them

with tf.GradientTape() as tape:
    z = 2 * tf.pow(x, 2) + 5 * tf.pow(y, 3)

dx, dy = tape.gradient(z, [x, y])

print(f"Input Variable x: {x.numpy()}")
print(f"Input Variable y: {y.numpy()}")
print(f"Output z: {z}\n")

# dz/dx = 4 * x = 12
print(f"Gradient of z wrt x: {dx}")

# dz/dy = 15 * y^2 = 240
print(f"Gradient of z wrt y: {dy}")

Input Variable x: 3.0
Input Variable y: 4.0
Output z: 338.0

Gradient of z wrt x: 12.0
Gradient of z wrt y: 240.0


Easy enough! Similarly, you can calculate the gradients of many variables wrt some computation, say a `loss`, by passing all the trainable variables involved in that computation as a nested structure (a list or a dictionary, for example). The returned gradients follow the same nested structure as the sources passed to the tape.
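Here is a minimal sketch of the nested case, assuming a toy loss `(w * inputs + b)^2` with parameters kept in a dict (the keys `"w"` and `"b"` and the names `params` and `inputs` are illustrative). The gradients come back as a dict with the same keys.

```python
import tensorflow as tf

# Toy parameters stored in a dict (the names "w" and "b" are illustrative)
params = {"w": tf.Variable(2.0), "b": tf.Variable(1.0)}
inputs = tf.constant(3.0)

with tf.GradientTape() as tape:
    # Toy loss: (w * inputs + b)^2 = (2 * 3 + 1)^2 = 49
    loss = tf.square(params["w"] * inputs + params["b"])

# Passing a dict as `sources` returns a dict of gradients with the same keys
grads = tape.gradient(loss, params)
for name, grad in grads.items():
    print(name, grad.numpy())
# dloss/dw = 2 * (w * inputs + b) * inputs = 42.0
# dloss/db = 2 * (w * inputs + b) = 14.0
```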

What happens if we compute the gradients wrt `x` and `y` in two separate calls?

In [5]:
with tf.GradientTape() as tape:
    z = x * y

try:
    dx = tape.gradient(z, x)
    dy = tape.gradient(z, y)

    print(f"Gradient of z wrt x: {dx}")
    print(f"Gradient of z wrt y: {dy}")
except Exception as ex:
    print("ERROR! ERROR! ERROR!\n")
    print(type(ex).__name__, ex)

ERROR! ERROR! ERROR!

RuntimeError A non-persistent GradientTape can only be used tocompute one set of gradients (or jacobians)


**What just happened?**<br>
As soon as `GradientTape.gradient(...)` is called, all the resources held by a (non-persistent) `GradientTape` are released. So once you have computed the gradients, you can't call `gradient(...)` on the same tape again.

**What's the solution then?**<br>
The solution is to set the `persistent` argument to `True`. This allows multiple calls to the `gradient()` method; the resources are released only when the tape object is garbage collected. Let's try the above example again.

In [6]:
# Set the persistent argument
with tf.GradientTape(persistent=True) as tape:
    z = x * y

try:
    dx = tape.gradient(z, x)
    dy = tape.gradient(z, y)

    print(f"Gradient of z wrt x: {dx}")
    print(f"Gradient of z wrt y: {dy}")
except Exception as ex:
    print("ERROR! ERROR! ERROR!\n")
    print(type(ex).__name__, ex)

Gradient of z wrt x: 4.0
Gradient of z wrt y: 3.0
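A persistent tape is also handy when you want gradients of several *different* targets from a single recording. In this sketch (the names `w`, `w2`, `w4` are illustrative), both `w^2` and `w^4` are differentiated from one tape, and the reference to the tape is dropped afterwards so its resources can be garbage collected.

```python
import tensorflow as tf

w = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    w2 = w * w      # w^2
    w4 = w2 * w2    # w^4

# Two different targets from the same recording
dw2_dw = tape.gradient(w2, w)  # 2w = 6.0
dw4_dw = tape.gradient(w4, w)  # 4w^3 = 108.0
print(dw2_dw.numpy(), dw4_dw.numpy())

# Drop the reference so the tape's resources can be garbage collected
del tape
```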


In [7]:
# What if one of the Variables is non-trainable?
# Let's make y non-trainable in the above example and run
# the computation again

x = tf.Variable(3.0)
y = tf.Variable(4.0, trainable=False)

with tf.GradientTape() as tape:
    z = x * y
    
dx, dy = tape.gradient(z, [x, y])

print(f"Variable x: {x}")
print(f"Is x trainable?: {x.trainable}")
print(f"\nVariable y: {y}")
print(f"Is y trainable?: {y.trainable}\n")

print(f"Gradient of z wrt x: {dx}")
print(f"Gradient of z wrt y: {dy}")

Variable x: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>
Is x trainable?: True

Variable y: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>
Is y trainable?: False

Gradient of z wrt x: 4.0
Gradient of z wrt y: None


**Note:** An important point to remember: never mix the `topology` of `dtypes` when recording operations for AD and computing the gradients. By `topology`, I mean don't mix `float`, `int`, and `string` types. In fact, you can't take a gradient through any op that has an `int` or `string` dtype. Let's take an example to make this clear.

In [8]:
# Note the dtypes

x = tf.Variable(3.0, dtype=tf.float32)
y = tf.Variable(4, dtype=tf.int32)

with tf.GradientTape() as tape:
    z = x * tf.cast(y, x.dtype)
    
dx, dy = tape.gradient(z, [x, y])

print(f"Input Variable x: {x}")
print(f"Input Variable y: {y}")
print(f"Output z: {z}\n")

print(f"Gradient of z wrt x: {dx}")
print(f"Gradient of z wrt y: {dy}")

Input Variable x: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>
Input Variable y: <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=4>
Output z: 12.0

Gradient of z wrt x: 4.0
Gradient of z wrt y: None


In [9]:
# There is no gradient flow defined for int and string types

x = tf.Variable(3, dtype=tf.int32)
y = tf.Variable(4, dtype=tf.int32)

with tf.GradientTape() as tape:
    z = x * y
    
dx, dy = tape.gradient(z, [x, y])

print(f"Input Variable x: {x}")
print(f"Input Variable y: {y}")
print(f"Output z: {z}\n")

print(f"Gradient of z wrt x: {dx}")
print(f"Gradient of z wrt y: {dy}")

Input Variable x: <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3>
Input Variable y: <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=4>
Output z: 12

Gradient of z wrt x: None
Gradient of z wrt y: None


You might say that your code is correct, and in some sense it is, but you will get `None` as the gradient for any variable with an `int` (or `string`) dtype, even when it is cast to `float` before being used. Before we move on to a very important concept, let's summarize what we have learned so far:

1. `tf.GradientTape` is the API for doing AD in TensorFlow
2. To compute gradients using a tape, we need to:
   * Record the relevant operations inside the context of the tape
   * Compute the gradients by calling the `GradientTape.gradient(...)` method
3. If you wish to call the `gradient(...)` method multiple times, make sure to pass `persistent=True` to `GradientTape`
4. If a `non-trainable` variable is involved in the computation, the gradient wrt that variable will be `None`
5. Mixing `dtype` topology is a blunder. Your code will run, but the gradients will silently be `None`!
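Since several of these mistakes surface only as a silent `None`, one option is to wrap `tape.gradient` so that a missing gradient raises an error instead. The helper `checked_gradient` below is hypothetical (not a TensorFlow API), and it assumes `sources` is a flat list of `tf.Variable` objects.

```python
import tensorflow as tf


def checked_gradient(tape, target, sources):
    """Like tape.gradient, but raises instead of silently returning None.

    Hypothetical helper: `sources` is assumed to be a flat list of tf.Variables.
    """
    grads = tape.gradient(target, sources)
    missing = [s.name for s, g in zip(sources, grads) if g is None]
    if missing:
        raise ValueError(f"No gradient flowed to: {missing}")
    return grads


a = tf.Variable(3.0, name="a")
b = tf.Variable(4, dtype=tf.int32, name="b")  # int dtype: no gradient possible

# Failing case: the int variable gets no gradient, so the helper raises
error_message = None
with tf.GradientTape() as tape:
    c = a * tf.cast(b, a.dtype)
try:
    checked_gradient(tape, c, [a, b])
except ValueError as err:
    error_message = str(err)
    print(error_message)

# Passing case: only float sources that the target depends on
with tf.GradientTape() as tape:
    c = a * a
(da,) = checked_gradient(tape, c, [a])
print(da.numpy())  # 6.0
```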


### Fine-grained control

A few questions that come to mind naturally after seeing the above examples:

1. How to access all the objects that are being watched?
2. How to stop the flow of a gradient through a specific Variable/path?
3. What if you don't want to watch all the variables inside the `GradientTape` context?
4. What if you want to watch something that isn't inside the context?

Let's look at an example for each of these cases to understand them better.

#### 1. Accessing all the watched objects

In [10]:
x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")
t = tf.Variable(tf.random.normal(shape=(2, 2)), name="t")
print(t)
with tf.GradientTape() as tape:
    z = x * y + t

print("Tape is watching all of these:")
for var in tape.watched_variables():
    print(f"{var.name} and its value is {var.numpy()}")

<tf.Variable 't:0' shape=(2, 2) dtype=float32, numpy=
array([[ 0.8369314 , -0.7342977 ],
       [ 1.0402943 ,  0.04035992]], dtype=float32)>
Tape is watching all of these:
x:0 and its value is 3.0
y:0 and its value is 4.0
t:0 and its value is [[ 0.8369314  -0.7342977 ]
 [ 1.0402943   0.04035992]]
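It can be just as useful to see what the tape does *not* report. In this sketch (variable names are illustrative), one trainable variable is never read inside the context, and one variable is read but was created with `trainable=False`; neither shows up in `watched_variables()`.

```python
import tensorflow as tf

p = tf.Variable(1.0, name="p")
q = tf.Variable(2.0, name="q")
unused = tf.Variable(3.0, name="unused")                   # never read in the context
frozen = tf.Variable(4.0, name="frozen", trainable=False)  # read, but non-trainable

with tf.GradientTape() as tape:
    r = p * q * frozen

watched_names = {v.name for v in tape.watched_variables()}
print(sorted(watched_names))  # ['p:0', 'q:0']
```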


#### 2. Stopping the gradients

In [11]:
# The ugly way

x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")

with tf.GradientTape(persistent=True) as tape:
    z = x * y
    
    # Stop the gradient flow
    with tape.stop_recording():
        zz = x*x + y*y

dz_dx, dz_dy = tape.gradient(z, [x, y])
dzz_dx, dzz_dy = tape.gradient(zz, [x, y])

print(f"Gradient of z wrt x: {dz_dx}")
print(f"Gradient of z wrt y: {dz_dy}\n")
print(f"Gradient of zz wrt x: {dzz_dx}")
print(f"Gradient of zz wrt y: {dzz_dy}")

Gradient of z wrt x: 4.0
Gradient of z wrt y: 3.0

Gradient of zz wrt x: None
Gradient of zz wrt y: None


A better way to stop the gradient flow is to use `tf.stop_gradient(...)`. Why?
1. It doesn't require access to the tape
2. It is cleaner, with much better semantics

In [12]:
# The better way!

x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")

with tf.GradientTape() as tape:
    z = x * tf.stop_gradient(y)

dz_dx, dz_dy = tape.gradient(z, [x, y])
print(f"Gradient of z wrt x: {dz_dx}")
print(f"Gradient of z wrt y: {dz_dy}")

Gradient of z wrt x: 4.0
Gradient of z wrt y: None
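Because `tf.stop_gradient(...)` works on individual expressions, you can block the gradient through just one occurrence of a variable. In this sketch (the name `v` is illustrative), `v * v` has derivative `2v`, while `v * tf.stop_gradient(v)` treats the second factor as a constant, so its derivative is just `v`.

```python
import tensorflow as tf

v = tf.Variable(3.0, name="v")

with tf.GradientTape(persistent=True) as tape:
    full = v * v                      # both factors contribute: d/dv = 2v
    half = v * tf.stop_gradient(v)    # second factor is a constant: d/dv = v

d_full = tape.gradient(full, v)
d_half = tape.gradient(half, v)
print(d_full.numpy(), d_half.numpy())  # 6.0 3.0
del tape
```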


#### 3. Selecting what to watch inside the context

By default, `GradientTape` automatically watches any trainable variables that are accessed inside the context. If you want gradients for selected variables only, you can disable automatic tracking by passing `watch_accessed_variables=False` to the tape constructor.

In [13]:
# Both variables are trainable
x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")

# Telling the tape: Hey! I will tell you what to record.
# Don't start recording automatically!
with tf.GradientTape(watch_accessed_variables=False) as tape:
    # Watch y but not x
    tape.watch(y)
    z = x * y
dz_dx, dz_dy = tape.gradient(z, [x, y])


print(f"Gradient of z wrt x: {dz_dx}")
print(f"Gradient of z wrt y: {dz_dy}")

Gradient of z wrt x: None
Gradient of z wrt y: 3.0


In [14]:
# What if something that you wanted to watch,
# wasn't present in the computation done inside 
# the context?

x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")
t = tf.Variable(5.0, name="t")

# Telling the tape: Hey! I will tell you what to record.
# Don't start recording automatically!
with tf.GradientTape(watch_accessed_variables=False) as tape:
    # Watch x but not y
    tape.watch(x)
    z = x * y
    
    # `t` isn't involved in any computation here
    # but what if we want to record it as well
    tape.watch(t)

print("Tape is watching only the objects you asked it to watch")
for var in tape.watched_variables():
    print(f"{var.name} and its value is {var.numpy()}")

Tape is watching only the objects you asked it to watch
x:0 and its value is 3.0
t:0 and its value is 5.0
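Watching `t` in the cell above doesn't make the target depend on it, so asking for its gradient returns `None`. If zeros are more convenient for such unconnected sources, `GradientTape.gradient` accepts an `unconnected_gradients` argument. A minimal sketch, with illustrative variable names:

```python
import tensorflow as tf

m = tf.Variable(3.0, name="m")
n = tf.Variable(4.0, name="n")
k = tf.Variable(5.0, name="k")  # watched, but not used to compute the target

with tf.GradientTape(persistent=True) as tape:
    tape.watch(k)
    out = m * n

# Default behaviour: an unconnected source gets None
grad_default = tape.gradient(out, k)
# Ask for zeros instead of None for unconnected sources
grad_zero = tape.gradient(out, k, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print(grad_default, grad_zero.numpy())  # None 0.0
del tape
```

Returning zeros avoids `None` checks downstream, but it also hides the "silent failure" signal discussed earlier, so use it deliberately.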


#### Multiple Tapes
You can use more than one `GradientTape` to record different objects. Several tapes can be active at the same time, and each one computes gradients only for what it watches.

In [15]:
x = tf.Variable(3.0, name="x")
y = tf.Variable(4.0, name="y")

with tf.GradientTape(watch_accessed_variables=False) as tape_for_x, \
     tf.GradientTape(watch_accessed_variables=False) as tape_for_y:
    # Watching different variables with different tapes
    tape_for_x.watch(x)
    tape_for_y.watch(y)
    
    z = x * y

dz_dx = tape_for_x.gradient(z, x)
dz_dy = tape_for_y.gradient(z, y)
print(f"Gradient of z wrt x: {dz_dx}")
print(f"Gradient of z wrt y: {dz_dy}")

Gradient of z wrt x: 4.0
Gradient of z wrt y: 3.0


#### Higher order gradients

Any computation done inside the `GradientTape` context gets recorded. If the computation involves gradient calculations, those get recorded as well. This makes it easy to compute `higher-order` gradients using the same API. Check this out:

In [16]:
x = tf.Variable(3.0, name="x")

with tf.GradientTape() as tape1:
    with tf.GradientTape() as tape0:
        y = x * x * x
    first_order_grad = tape0.gradient(y, x)
second_order_grad = tape1.gradient(first_order_grad, x)

print(f"Variable x: {x.numpy()}")
print("\nEquation is y = x^3")
print(f"First Order Gradient wrt x (3 * x^2): {first_order_grad}")
print(f"Second Order Gradient wrt x (6 * x): {second_order_grad}")

Variable x: 3.0

Equation is y = x^3
First Order Gradient wrt x (3 * x^2): 27.0
Second Order Gradient wrt x (6 * x): 18.0
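Nested tapes also give mixed partial derivatives. The sketch below (variable names are illustrative) differentiates `f = u^2 * w` wrt `u` with the inner tape, then differentiates that result wrt `w` with the outer tape. This works because the inner `gradient(...)` call runs inside the outer tape's context and is therefore recorded.

```python
import tensorflow as tf

u = tf.Variable(3.0, name="u")
w = tf.Variable(4.0, name="w")

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        f = u * u * w                        # f = u^2 * w
    df_du = inner_tape.gradient(f, u)        # 2 * u * w = 24.0
# df_du depends on w and was computed inside outer_tape, so it was recorded
d2f_dwdu = outer_tape.gradient(df_du, w)     # 2 * u = 6.0

print(df_du.numpy(), d2f_dwdu.numpy())
```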


#### Gotchas

Let's look at a few things that you **should** be aware of so that your code doesn't fail silently!

We have already seen that gradients aren't defined for `int` or `string` dtypes. Here are a few other things to watch out for:

In [17]:
# What happens when you try to take the gradient wrt a Tensor?
x = tf.constant(3.0)

with tf.GradientTape() as tape:
    y = x * x
    
dy_dx = tape.gradient(y, x)

print(x)
print("\nGradient of y wrt x: ", dy_dx)

tf.Tensor(3.0, shape=(), dtype=float32)

Gradient of y wrt x:  None


In [18]:
# Let's modify the above code a bit

x = tf.constant(3.0)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * x
    
dy_dx = tape.gradient(y, x)

print(x)
print("\nGradient of y wrt x: ", dy_dx)

tf.Tensor(3.0, shape=(), dtype=float32)

Gradient of y wrt x:  tf.Tensor(6.0, shape=(), dtype=float32)


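Woah! What just happened? Pause for a minute and think about why the behavior changed. By default, the tape automatically watches only the trainable `tf.Variable`s accessed inside the context. A `tf.Tensor` such as `tf.constant(3.0)` isn't watched unless you explicitly call `tape.watch(...)` on it, which is why the first version returned `None`.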
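`tape.watch(...)` is not limited to tensors. As a sketch (variable names are illustrative), it can also make the tape track a `tf.Variable` created with `trainable=False`, which would otherwise get a `None` gradient, as in the non-trainable example earlier.

```python
import tensorflow as tf

g = tf.Variable(3.0, name="g")
h = tf.Variable(4.0, name="h", trainable=False)

with tf.GradientTape() as tape:
    tape.watch(h)  # explicitly watch the non-trainable variable
    prod = g * h

dprod_dg, dprod_dh = tape.gradient(prod, [g, h])
print(dprod_dg.numpy(), dprod_dh.numpy())  # 4.0 3.0
```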

#### States and Gradients

`GradientTape` can only read from the current state, not from the history that led to it. State blocks gradient calculations from going farther back. Let's look at an example to make this clearer.

In [19]:
x = tf.Variable(3.0)
y = tf.Variable(4.0)

with tf.GradientTape() as tape:
    # Change the state of x in place: x = x + y = 3 + 4 = 7
    x.assign_add(y)

    # Let's do some computation, e.g. z = x * x
    # The tape only sees the current value of x (7.0), not how it got there,
    # so dz/dx = 2 * x = 2 * 7 = 14
    z = x * x

dx = tape.gradient(z, x)
print("Gradients of z wrt x: ", dx)

Gradients of z wrt x:  tf.Tensor(14.0, shape=(), dtype=float32)


In [20]:
x = tf.Variable(3.0)
y = tf.Variable(4.0)

with tf.GradientTape() as tape:
    # Change the state of x in place: x = x + y = 3 + 4 = 7
    x.assign_add(y)

    # z = x * x depends only on the current value of x (7.0);
    # the assign_add that involved y is not part of the recorded history,
    # so dz/dx = 2 * x = 14
    z = x * x

dx, dy = tape.gradient(z, [x, y])
print("Gradients of z wrt x: ", dx)
print("Gradients of z wrt y: ", dy)  # None: no gradient flows back to y

Gradients of z wrt x:  tf.Tensor(14.0, shape=(), dtype=float32)
Gradients of z wrt y:  None


In [21]:
x = tf.Variable(3.0)
y = tf.Variable(4.0)

with tf.GradientTape() as tape:
    # No state change this time: express the computation functionally
    # z = (x + y)^2, so dz/dx = dz/dy = 2 * (x + y) = 2 * 7 = 14
    z = (x + y) * (x + y)
    
dx,dy = tape.gradient(z, [x,y])
print("Gradients of z wrt x: ", dx)
print("Gradients of z wrt y: ", dy) 

Gradients of z wrt x:  tf.Tensor(14.0, shape=(), dtype=float32)
Gradients of z wrt y:  tf.Tensor(14.0, shape=(), dtype=float32)
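Because in-place state changes cut the recorded history, a common pattern is to compute gradients inside the tape and update variables only afterwards, outside the tape. Below is a minimal gradient-descent sketch of that pattern; the loss `(theta - 1)^2`, the learning rate, and the step count are illustrative assumptions.

```python
import tensorflow as tf

theta = tf.Variable(3.0, name="theta")
learning_rate = 0.1  # illustrative value

for _ in range(100):
    with tf.GradientTape() as tape:
        loss = tf.square(theta - 1.0)   # minimized at theta = 1
    grad = tape.gradient(loss, theta)   # 2 * (theta - 1)
    # Update the state *outside* the tape, after the gradient is computed
    theta.assign_sub(learning_rate * grad)

print(theta.numpy())  # close to 1.0
```

Each step shrinks the distance to the minimum by a factor of `1 - 2 * learning_rate = 0.8`, so after 100 steps `theta` is very close to 1.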


That's it for part 3! I hope you liked the content and that it has given you a much clearer picture of automatic differentiation and gradient calculation. We will be looking at other things in the next tutorial!<br>


**References**:
1. https://www.tensorflow.org/guide/autodiff
2. https://keras.io/getting_started/intro_to_keras_for_researchers/