<table class="tfo-notebook-buttons" align="left">
  <td>
    <a href="https://colab.research.google.com/github/martin-fabbri/colab-notebooks/blob/master/deeplearning.ai/nlp/c3_w2_jax_numpy_calculating_perplexity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>    
  </td>
  <td>
    <a href="https://github.com/martin-fabbri/colab-notebooks/blob/master/deeplearning.ai/nlp/c3_w2_jax_numpy_calculating_perplexity.ipynb" target="_parent"><img src="https://raw.githubusercontent.com/martin-fabbri/colab-notebooks/master/assets/github.svg" alt="View On Github"/></a>  </td>
</table>

# Working with JAX numpy and calculating perplexity: Ungraded Lecture Notebook

Normally you would import `numpy` and rename it as `np`. 

However, in this week's assignment you will notice that this convention has been changed. 

In the assignment, standard `numpy` is not renamed and `trax.fastmath.numpy` is renamed as `np`. The code cells in this notebook import it as `jnp` instead, so the two libraries are easy to tell apart.

The rationale behind this change is that you will be using Trax's numpy (which is compatible with JAX) far more often. Trax's numpy supports most of the same functions as regular numpy, so the change won't be noticeable in most cases.

In [1]:
%%capture
!pip install trax==1.3.1

In [8]:
%%capture
import numpy
import trax
import trax.fastmath.numpy as jnp

from trax import layers as tl

trax.supervised.trainer_lib.init_random_number_generators(32)
numpy.random.seed(32)

In [3]:
!pip list | grep 'trax\|numpy\|jax'

jax                           0.2.7                
jaxlib                        0.1.57+cuda101       
numpy                         1.19.5               
trax                          1.3.1                


One important change to take into consideration is that the type of the resulting objects depends on which numpy you use. With regular numpy you get `numpy.ndarray`, but with Trax's numpy you get `jax.interpreters.xla.DeviceArray`. Each type can be converted into the other. So if you find error logs mentioning the DeviceArray type, don't worry about it: treat it as you would an ndarray and march ahead.

You can get a randomized numpy array by using the `numpy.random.random()` function.

This is one of the functionalities that Trax's numpy does not currently support in the same way as the regular numpy. 

In [4]:
numpy_array = numpy.random.random((5, 10))
print(f"The regular numpy array looks like this:\n\n {numpy_array}\n")
print(f"It is of type: {type(numpy_array)}")

The regular numpy array looks like this:

 [[0.85888927 0.37271115 0.55512878 0.95565655 0.7366696  0.81620514
  0.10108656 0.92848807 0.60910917 0.59655344]
 [0.09178413 0.34518624 0.66275252 0.44171349 0.55148779 0.70371249
  0.58940123 0.04993276 0.56179184 0.76635847]
 [0.91090833 0.09290995 0.90252139 0.46096041 0.45201847 0.99942549
  0.16242374 0.70937058 0.16062408 0.81077677]
 [0.03514717 0.53488673 0.16650012 0.30841038 0.04506241 0.23857613
  0.67483453 0.78238275 0.69520163 0.32895445]
 [0.49403187 0.52412136 0.29854125 0.46310814 0.98478429 0.50113492
  0.39807245 0.72790532 0.86333097 0.02616954]]

It is of type: <class 'numpy.ndarray'>


You can easily cast regular numpy arrays or lists into Trax numpy arrays using the `trax.fastmath.numpy.array()` function:

In [6]:
trax_numpy_array = jnp.array(numpy_array)
print(f"The trax numpy array looks like this:\n\n {trax_numpy_array}\n")
print(f"It is of type: {type(trax_numpy_array)}")

The trax numpy array looks like this:

 [[0.8588893  0.37271115 0.55512875 0.9556565  0.7366696  0.81620514
  0.10108656 0.9284881  0.60910916 0.59655344]
 [0.09178413 0.34518623 0.6627525  0.44171348 0.5514878  0.70371246
  0.58940125 0.04993276 0.56179184 0.7663585 ]
 [0.91090834 0.09290995 0.9025214  0.46096042 0.45201847 0.9994255
  0.16242374 0.7093706  0.16062407 0.81077677]
 [0.03514718 0.5348867  0.16650012 0.30841038 0.04506241 0.23857613
  0.67483455 0.7823827  0.69520164 0.32895446]
 [0.49403188 0.52412134 0.29854125 0.46310815 0.9847843  0.50113493
  0.39807245 0.72790533 0.86333096 0.02616954]]

It is of type: <class 'jax.interpreters.xla._DeviceArray'>
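
The last digits of the Trax array printed above differ slightly from those of the regular numpy array. This is consistent with JAX storing floats in 32-bit precision by default, whereas `numpy.random.random()` returns 64-bit floats. The sketch below only mimics that precision drop with an explicit cast in plain numpy; it is an illustrative stand-in and does not call `jnp.array` itself.

```python
import numpy

numpy.random.seed(32)
numpy_array = numpy.random.random((5, 10))  # float64 by default

# Explicit cast used here as a stand-in for the conversion to a 32-bit array.
array_32 = numpy_array.astype(numpy.float32)

# Largest absolute difference introduced by the cast.
max_abs_diff = numpy.abs(numpy_array - array_32).max()

print(numpy_array.dtype, array_32.dtype)
print(f"max abs difference: {max_abs_diff:.2e}")
```

The difference is tiny (well below `1e-7` for values in `[0, 1)`), which is why the two printouts agree to roughly seven digits.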


Hopefully you now understand the differences (and similarities) between these two versions of numpy. **Great!**

The previous section was a quick look at Trax's numpy. However, this notebook also aims to teach you how to calculate the perplexity of a trained model.

## Calculating Perplexity

The perplexity is a metric that measures how well a probability model predicts a sample and it is commonly used to evaluate language models. It is defined as: 

$$P(W) = \sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{i-1})}}$$

As an implementation trick, you would usually take the log of that formula. This lets you use the log probabilities that the `RNN` outputs directly, and it turns exponents into products and products into sums, which makes the computation simpler and more efficient. You should also take care of the padding: padded positions must be excluded when calculating the perplexity, otherwise the measure would look artificially good. The algebra behind this process is explained next:


$$\log P(W) = {\log\big(\sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{i-1})}}\big)}$$

$$ = {\log\big({\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{i-1})}}\big)^{\frac{1}{N}}}$$ 

$$ = {\log\big({\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\big)^{-\frac{1}{N}}} $$
$$ = -\frac{1}{N}{\log\big({\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\big)} $$
$$ = -\frac{1}{N}{\big({\sum_{i=1}^{N}{\log P(w_i| w_1,...,w_{i-1})}}\big)} $$
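
The derivation above can be checked numerically on a toy sentence. The following is a minimal sketch in plain numpy, not the assignment's implementation: the probabilities and the padding mask are made-up values, and the names `token_log_probs` and `is_real_token` are hypothetical.

```python
import numpy

# Hypothetical probabilities the model assigned to each target token.
# The last two positions are padding and must not count towards the metric.
token_probs = numpy.array([0.5, 0.25, 0.125, 0.5, 0.9, 0.9])
token_log_probs = numpy.log(token_probs)
is_real_token = numpy.array([1.0, 1.0, 1.0, 1.0, 0.0, 0.0])

# N counts only the non-padding tokens.
n_real = is_real_token.sum()

# log P(W) = -(1/N) * sum_i log P(w_i | w_1, ..., w_{i-1}), padding masked out.
log_perplexity = -(token_log_probs * is_real_token).sum() / n_real
perplexity = numpy.exp(log_perplexity)

# Cross-check against the N-th root definition, using the real tokens only.
real_probs = token_probs[:4]
direct = numpy.prod(1.0 / real_probs) ** (1.0 / len(real_probs))

print(f"log perplexity: {log_perplexity:.4f}")
print(f"perplexity (via logs): {perplexity:.4f}")
print(f"perplexity (direct):   {direct:.4f}")
```

Both routes give the same value, about 3.36 for these numbers. If the mask were dropped, the two high-probability padding positions would pull the perplexity down and make the model look better than it is.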

You will be working with a real example from this week's assignment. The example is made up of:
   - `predictions` : batch of tensors corresponding to lines of text predicted by the model.
   - `targets` : batch of actual tensors corresponding to lines of text.
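
To make the roles of these two tensors concrete, here is a hedged sketch on toy data. It assumes that `predictions` holds log probabilities with shape `(batch, seq_len, vocab)`, that `targets` holds integer token ids with shape `(batch, seq_len)`, and that id `0` marks padding. These shapes, the toy sizes, and the helper names are assumptions made for illustration, not values taken from the assignment.

```python
import numpy

batch_size, seq_len, vocab_size = 2, 4, 5  # hypothetical toy sizes

# Build fake log probabilities with a log-softmax over the vocabulary axis.
rng = numpy.random.default_rng(0)
logits = rng.normal(size=(batch_size, seq_len, vocab_size))
predictions = logits - numpy.log(numpy.exp(logits).sum(axis=-1, keepdims=True))

# Token ids for two lines of text; 0 is assumed to be the padding id.
targets = numpy.array([[3, 1, 4, 0],
                       [2, 2, 0, 0]])

# One-hot encode the targets, then pick the log probability of each target token.
one_hot_targets = numpy.eye(vocab_size)[targets]                 # (batch, seq_len, vocab)
target_log_probs = (predictions * one_hot_targets).sum(axis=-1)  # (batch, seq_len)

# Mask out padding and average over the real tokens of each line.
non_pad = (targets != 0).astype(numpy.float64)
log_ppx_per_line = -(target_log_probs * non_pad).sum(axis=-1) / non_pad.sum(axis=-1)

print(log_ppx_per_line)
```

Multiplying by the one-hot targets keeps only the log probability assigned to the correct token at each position, which is the term that appears inside the sum of the log perplexity formula.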

In [7]:
predictions = n

