# Trax: Ungraded Lecture Notebook

In this notebook you'll get to know the Trax framework and learn about some of its basic building blocks.



## Background

### Why Trax and not TensorFlow or PyTorch?

TensorFlow and PyTorch are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often comes at the cost of verbose syntax and extra coding time.

Trax is much more concise. It runs on a TensorFlow backend but lets you train models with one-line commands. Trax also runs end to end, letting you get data, build a model and train it, each with a single terse statement. This means you can focus on learning instead of spending hours on the idiosyncrasies of a large framework's implementation.

### Why not Keras then?

Keras has been part of TensorFlow itself since TensorFlow 2.0. Trax, on the other hand, is well suited to implementing new state-of-the-art models such as Transformer, Reformer and BERT, because it is actively maintained by the Google Brain team for advanced deep learning tasks. It also runs smoothly on CPUs, GPUs and TPUs with comparatively few code changes.

### How to Code in Trax
Building models in Trax relies on two key concepts: **layers** and **combinators**.
Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, allowing you to build layers and models of any complexity.
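To make the two concepts concrete, here is a minimal pure-Python sketch of the idea. The classes `Layer` and `Serial` and the factories `Relu` and `Scale` are hypothetical stand-ins written for illustration (capitalized only to echo Trax's naming), not the Trax API: a layer is a callable that transforms its input, and a combinator is itself a layer built from other layers. Real Trax layers such as `tl.Dense`, `tl.Relu` and the `tl.Serial` combinator also carry trainable weights and operate on arrays, which this sketch leaves out.

```python
from typing import Callable, List

Vector = List[float]


class Layer:
    """Hypothetical minimal layer: a named function from inputs to outputs."""

    def __init__(self, fn: Callable[[Vector], Vector], name: str):
        self.fn = fn
        self.name = name

    def __call__(self, x: Vector) -> Vector:
        return self.fn(x)


class Serial(Layer):
    """Hypothetical combinator: runs sublayers in order, each feeding the next."""

    def __init__(self, *layers: Layer):
        self.layers = layers
        names = ", ".join(layer.name for layer in layers)
        super().__init__(self._forward, f"Serial[{names}]")

    def _forward(self, x: Vector) -> Vector:
        for layer in self.layers:
            x = layer(x)
        return x


def Relu() -> Layer:
    """Element-wise max(0, v)."""
    return Layer(lambda x: [max(0.0, v) for v in x], "Relu")


def Scale(c: float) -> Layer:
    """Element-wise multiplication by a constant (stands in for a weighted layer)."""
    return Layer(lambda x: [c * v for v in x], f"Scale({c})")


model = Serial(Scale(2.0), Relu())
print(model.name)               # Serial[Scale(2.0), Relu]
print(model([-1.0, 0.5, 3.0]))  # [0.0, 1.0, 6.0]

# A composite layer is still a Layer, so it can be nested inside another combinator.
bigger = Serial(model, Scale(0.5))
print(bigger([-1.0, 0.5, 3.0]))  # [0.0, 0.5, 3.0]
```

Because `Serial` is itself a `Layer`, composites nest freely; that closure property is what lets a small set of combinators build layers and models of any complexity.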

### Trax, JAX, TensorFlow and Tensor2Tensor

You already know that Trax uses TensorFlow as a backend, but it also uses the JAX library to speed up computation. You can view JAX as an enhanced and optimized version of numpy.

**Watch out for assignments that contain the line `import trax.fastmath.numpy as np`. If you see this line, remember that when calling `np` you are really calling Trax’s version of numpy, which is compatible with JAX.**

As a result, where you used to encounter the type `numpy.ndarray`, you will now find the type `jax.interpreters.xla.DeviceArray` (newer JAX releases replace it with the unified `jax.Array` type).
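Because Trax's `np` mirrors the numpy API, code written against that API runs unchanged whichever module the name is bound to; only the type of the returned array differs. The sketch below makes the module an explicit parameter (`xp`, a naming convention chosen here, not something Trax requires) and runs it with regular numpy. If `jax.numpy` happens to be installed, it also runs the same function on a JAX array and converts the result back with `numpy.asarray`; that branch is an assumption about your environment and is skipped otherwise.

```python
import numpy as onp  # regular numpy, aliased so `np` stays free for another module


def softmax(x, xp):
    """Numerically stable softmax using only functions shared by numpy-like modules.

    `xp` is the array module to use, e.g. `numpy` or `jax.numpy`.
    """
    shifted = x - xp.max(x)  # subtract the max to avoid overflow in exp
    exps = xp.exp(shifted)
    return exps / xp.sum(exps)


x = onp.array([1.0, 2.0, 3.0])
probs = softmax(x, onp)
print(type(probs))  # <class 'numpy.ndarray'>

try:
    import jax.numpy as jnp  # optional: only used if JAX is installed
except ImportError:
    jnp = None

if jnp is not None:
    probs_jax = softmax(jnp.asarray(x), jnp)
    print(type(probs_jax))  # a JAX array type; its exact name depends on the JAX version
    probs_back = onp.asarray(probs_jax)  # back to numpy.ndarray for numpy-only code
    print(type(probs_back))  # <class 'numpy.ndarray'>
```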

Tensor2Tensor is another name you might have heard. It started as an end-to-end solution, much like Trax, but it grew unwieldy and complicated. You can view Trax as its improved successor, designed to be faster and simpler.

### Resources

- Trax source code can be found on GitHub: [Trax](https://github.com/google/trax)
- JAX library: [JAX](https://jax.readthedocs.io/en/latest/index.html)


## Installing Trax

Trax depends on JAX and other libraries that are not yet supported on [Windows](https://github.com/google/jax/blob/1bc5896ee4eab5d7bb4ec6f161d8b2abb30557be/README.md#installation) but work well on Ubuntu and macOS. If you are working on Windows, we suggest installing Trax under WSL2.

The officially maintained documentation is [trax-ml](https://trax-ml.readthedocs.io/en/latest/), not to be confused with the unrelated [TraX](https://trax.readthedocs.io/en/latest/index.html) project.

```
!pip install trax==1.3.9
```

```
Collecting trax==1.3.9
  Using cached trax-1.3.9-py2.py3-none-any.whl (629 kB)
Collecting t5 (from trax==1.3.9)
  Using cached t5-0.9.4-py2.py3-none-any.whl (164 kB)
INFO: pip is looking at multiple versions of trax to determine which version is compatible with other requirements. This could take a while.
ERROR: Could not find a version that satisfies the requirement tensorflow-text (from trax) (from versions: none)
ERROR: No matching distribution found for tensorflow-text
```


## Imports

```python
import numpy as np  # regular ol' numpy

from trax import layers as tl  # core building block
from trax import shapes  # data signatures: dimensionality and type
from trax import fastmath  # uses jax, offers numpy on steroids
```

```
ValueError: no signature found for builtin function <built-in function asarray>
```
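The `shapes` import above is described as handling data signatures, meaning dimensionality and type. A signature describes an input only by its shape and dtype, without its values, which is enough to work out the sizes of a layer's weights before any real data arrives (in Trax this role is played by `shapes.ShapeDtype` and `shapes.signature`, as in `model.init(shapes.signature(x))`). The sketch below illustrates the idea with numpy alone; `ShapeDtypeSketch`, `signature_sketch` and `dense_weight_shape` are hypothetical names, not Trax functions.

```python
from typing import NamedTuple, Tuple

import numpy as np


class ShapeDtypeSketch(NamedTuple):
    """Hypothetical data signature: the shape and dtype of an input, but no values."""
    shape: Tuple[int, ...]
    dtype: np.dtype


def signature_sketch(x) -> ShapeDtypeSketch:
    """Summarize an array-like input by its dimensionality and element type."""
    arr = np.asarray(x)
    return ShapeDtypeSketch(arr.shape, arr.dtype)


def dense_weight_shape(sig: ShapeDtypeSketch, n_units: int) -> Tuple[int, int]:
    """Weight-matrix shape for a dense layer mapping the last input axis to n_units."""
    return (sig.shape[-1], n_units)


batch = np.zeros((32, 128), dtype=np.float32)  # e.g. 32 examples with 128 features each
sig = signature_sketch(batch)
print(sig.shape, sig.dtype)         # (32, 128) float32
print(dense_weight_shape(sig, 10))  # (128, 10)
```

Only the signature, never the values in `batch`, determines the weight shape, which is why a framework can allocate and initialize weights from a signature alone.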