# Trax Intro

[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed.

  1. **[Run a pre-trained Transformer](#1)**: create a translator in a few lines of code
  1. **[Walkthrough](#2)**: how Trax works, how to make new models and train on your own data


In [9]:
#@title Install dependencies
#@markdown - Trax
%%capture
!pip install -Uqq trax

In [21]:
#@title Import packages
import os

import numpy as np
import trax

from trax import layers as tl

print("numpy", np.__version__)

numpy 1.19.5


<a name='1'></a>
## Run a pre-trained Transformer


Here is how you create an English-German translator in a few lines of code:

* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)
* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)
* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)
* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)
* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)

In [11]:
# create a transformer model
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512,
    d_ff=2048,
    n_heads=8,
    n_encoder_layers=6,
    max_len=2048,
    mode="predict",
)

# initialize using pre-trained weights
model.init_from_file(
    "gs://trax-ml/models/translation/ende_wmt32k.pkl.gz", 
    weights_only=True
)

In [12]:
%%time
# tokenize a sentence
sentence = "It is nice to learn new things today!"
tokenized = list(
    trax.data.tokenize(
        iter([sentence]),
        vocab_dir="gs://trax-ml/vocabs/",
        vocab_file="ende_32k.subword",
    )
)[0]

# decode from the transformer
tokenized = tokenized[None, :]
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0
)  # Higher temperature: more diverse results.

# de-tokenize
tokenized_translation = tokenized_translation[0][:-1]  # remove batch and EOS
translation = trax.data.detokenize(
    tokenized_translation,
    vocab_dir="gs://trax-ml/vocabs/",
    vocab_file="ende_32k.subword",
)
print(translation)

Es ist schön, heute neue Dinge zu lernen!
CPU times: user 12.7 s, sys: 225 ms, total: 13 s
Wall time: 13.4 s


Expected result:

```
Es ist schön, heute neue Dinge zu lernen!
[GPU time] 
1 loop, best of 3: 4.22 s per loop
CPU times: user 12.7 s, sys: 227 ms, total: 12.9 s
Wall time: 13.3 s

[TPU time] 
1 loop, best of 3: 13.8 s per loop ???
CPU times: user 49.5 s, sys: 234 ms, total: 49.7 s
Wall time: 49.7 s

[CPU time] 
1 loop, best of 3: 14.4 s per loop
CPU times: user 57.5 s, sys: 250 ms, total: 57.7 s
Wall time: 57.8 s
```
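
The `temperature` argument passed to the decoder above controls how the next token is chosen from the model's scores. The following is a minimal NumPy sketch of that idea, not Trax's implementation: the helper `sample_next_token` and the logits are made up for illustration. At temperature 0 the highest-scoring token is always taken (greedy decoding); at higher temperatures the scores are flattened before sampling, so lower-scoring tokens appear more often.

```python
import numpy as np

def sample_next_token(logits, temperature, rng):
    """Pick a token id from unnormalized scores (hypothetical helper)."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        # Greedy decoding: always take the highest-scoring token.
        return int(np.argmax(logits))
    # Divide by the temperature, then apply a numerically stable softmax.
    scaled = logits / temperature
    scaled -= scaled.max()
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
toy_logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-token vocabulary

greedy = [sample_next_token(toy_logits, 0.0, rng) for _ in range(5)]
diverse = [sample_next_token(toy_logits, 1.5, rng) for _ in range(200)]
print("temperature 0.0:", greedy)
print("temperature 1.5 picks tokens:", sorted(set(diverse)))
```

With temperature 0 the result is deterministic, which is why the translation above is reproducible from run to run.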

<a name='2'></a>
## Walkthrough

### Tensors and Fast Math

The basic units flowing through Trax models are *tensors*: multi-dimensional arrays, often called numpy arrays after `numpy`, the most widely used package for tensor operations. Trax exposes the numpy API through `trax.fastmath.numpy`.


In [17]:
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend("jax")

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f"matrix =\n{matrix}")
vector = fastnp.ones(3)
print(f"vector = {vector}")
product = fastnp.dot(vector, matrix)
print(f"product = {product}")
tanh = fastnp.tanh(product)
print(f"tanh(product) = {tanh}")

matrix =
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]


Gradients can be calculated using `trax.fastmath.grad`, which takes a function and returns a function that computes its gradient.

In [20]:
def f(x):
    return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f"grad(2x^2) at  1 = {grad_f(1.0)}")
print(f"grad(2x^2) at -2 = {grad_f(-2.0)}")

grad(2x^2) at  1 = 4.0
grad(2x^2) at -2 = -8.0
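
One way to sanity-check values like these is to compare them with a finite-difference estimate. The sketch below uses plain Python and a hypothetical helper `numerical_grad`; it is independent of Trax and only illustrates that the derivative of `2x^2` is `4x`.

```python
def f(x):
    return 2.0 * x * x

def numerical_grad(fn, x, eps=1e-4):
    """Central finite-difference estimate of d fn / dx (hypothetical helper)."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

for x in (1.0, -2.0):
    analytic = 4.0 * x  # derivative of 2x^2
    estimate = numerical_grad(f, x)
    print(f"x = {x:5.1f}  analytic = {analytic:6.2f}  numerical = {estimate:9.5f}")
```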


### Layers

Layers with trainable weights like `Embedding` need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

In [24]:
x = np.arange(15)
print(f"x = {x}")

embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

y = embedding(x)
print(f"shape of y= {y.shape}")

x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y= (15, 32)
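
The output shape follows from what an embedding does conceptually: it looks up one row of weights per input token id. Here is a NumPy sketch of that lookup, assuming random weights; the `table` variable is a stand-in and does not reflect how Trax stores or initializes its weights.

```python
import numpy as np

vocab_size, d_feature = 20, 32
rng = np.random.default_rng(0)

# Stand-in for the trainable weights: one row of d_feature values per token id.
table = rng.normal(size=(vocab_size, d_feature))

x = np.arange(15)  # token ids, shape (15,)
y = table[x]       # row lookup, shape (15, 32)

print(f"shape of y = {y.shape}")
```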


### Models

Models in Trax are built from layers, most often using the `Serial` and `Branch` combinators. You can read more about those combinators in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) and see the code for many models in `trax/models/`; for example, this is how the [Transformer Language Model](https://github.com/google/trax/blob/master/trax/models/transformer.py#L167) is implemented. Below is an example of how to build a sentiment classification model.

In [27]:
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
)

print(model)

Serial[
  Embedding_8192_256
  Mean
  Dense_2
]
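
To see how data would flow through these three layers, the following NumPy sketch traces the shapes of a forward pass. It is an illustration under assumptions: the weights are random stand-ins, the batch of token ids is made up, and the arithmetic only mirrors what embedding, averaging, and a dense layer do in general, not Trax's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_feature, n_classes = 8192, 256, 2

# Hypothetical stand-ins for the trainable weights of each layer.
embedding_table = rng.normal(size=(vocab_size, d_feature))
dense_w = rng.normal(size=(d_feature, n_classes))
dense_b = np.zeros(n_classes)

tokens = rng.integers(0, vocab_size, size=(4, 10))  # (batch, length)

embedded = embedding_table[tokens]      # (4, 10, 256)
averaged = embedded.mean(axis=1)        # (4, 256): average over the length axis
logits = averaged @ dense_w + dense_b   # (4, 2): one score per class

for name, arr in [("embedded", embedded), ("averaged", averaged), ("logits", logits)]:
    print(f"{name:9s} {arr.shape}")
```

Averaging over axis 1 is what lets the model accept reviews of any length: the length dimension disappears before the dense layer.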


### Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call `next(data_stream)` and get a tuple, e.g., `(inputs, targets)`. Trax makes it easy to use [TensorFlow Datasets](https://www.tensorflow.org/datasets), and you can also get an iterator from your own text file using the standard `open('my_file.txt')`.

In [28]:
train_stream = trax.data.TFDS(
    "imdb_reviews", keys=("text", "label"), train=True
)()
eval_stream = trax.data.TFDS(
    "imdb_reviews", keys=("text", "label"), train=False
)()
print(next(train_stream))


Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...



Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
(b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.', 0)
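
Because a stream is just a Python iterator, you can also write one yourself as a generator. Below is a minimal sketch with made-up examples; `labeled_stream` and `toy_examples` are hypothetical names and not part of Trax.

```python
def labeled_stream(examples):
    """Yield (inputs, targets) tuples forever, cycling over `examples`."""
    while True:
        for text, label in examples:
            yield text, label

# Made-up examples standing in for a real dataset.
toy_examples = [("a great film", 1), ("slow and boring", 0)]

stream = labeled_stream(toy_examples)
first = next(stream)
second = next(stream)
third = next(stream)  # wraps around to the first example
print(first, second, third)
```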


Using the `trax.data` module you can create input processing pipelines, e.g., to tokenize and shuffle your data. Pipelines are built with `trax.data.Serial`; the result is a function that you apply to a stream to get a processed stream.

In [48]:
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file="en_8k.subword", keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(
        boundaries=[32, 128, 512, 2048],
        batch_sizes=[512, 128, 32, 8, 1],
        length_keys=[0],
    ),
    trax.data.AddLossWeights(),
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f"shapes = {[x.shape for x in example_batch]}")


shapes = [(8, 1024), (8,), (8,)]
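
The `boundaries` and `batch_sizes` arguments work together: there is one more batch size than there are boundaries, because sequences longer than the last boundary need a bucket too. The plain-Python sketch below shows the bucket choice, assuming a sequence goes into the first bucket whose boundary is at least its length. The helper `bucket_for` is hypothetical and simplifies away the batching and padding that the real pipeline performs.

```python
boundaries = [32, 128, 512, 2048]
batch_sizes = [512, 128, 32, 8, 1]

def bucket_for(length):
    """Return (bucket index, batch size) for a sequence length (hypothetical helper)."""
    for i, boundary in enumerate(boundaries):
        if length <= boundary:
            return i, batch_sizes[i]
    # Longer than every boundary: the extra, final batch size applies.
    return len(boundaries), batch_sizes[-1]

for length in (20, 100, 700, 3000):
    index, batch_size = bucket_for(length)
    print(f"length {length:4d} -> bucket {index}, batch size {batch_size}")
```

Under this assumption, reviews between 513 and 2048 tokens are batched 8 at a time, which matches the leading dimension of 8 in the shapes printed above. Longer sequences get smaller batches so that the total number of tokens per batch stays roughly bounded.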


### Supervised training

When you have the model and the data, use `trax.supervised.training` to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

In [49]:
example_batch[1]

array([1, 0, 1, 1, 1, 0, 0, 1])