# Trax Intro

[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed.

  1. **[Run a pre-trained Transformer](#1)**: create a translator in a few lines of code
  1. **[Walkthrough](#2)**: how Trax works, how to make new models and train on your own data


In [1]:
#@title Install dependencies
#@markdown - Trax
%%capture
!pip install -Uqq trax

In [2]:
#@title Import packages
import os

import numpy as np
import trax

from trax import layers as tl

print("numpy", np.__version__)

numpy 1.19.5


<a name='1'></a>
## Run a pre-trained Transformer


Here is how you create an English-German translator in a few lines of code:

* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)
* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)
* tokenize your input sentence so it can be fed to the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)
* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)
* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)

In [3]:
# create a transformer model
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512,
    d_ff=2048,
    n_heads=8,
    n_encoder_layers=6,
    max_len=2048,
    mode="predict",
)

# initialize using pre-trained weights
model.init_from_file(
    "gs://trax-ml/models/translation/ende_wmt32k.pkl.gz", 
    weights_only=True
)

In [4]:
%%time
# tokenize a sentence
sentence = "It is nice to learn new things today!"
tokenized = list(
    trax.data.tokenize(
        iter([sentence]),
        vocab_dir="gs://trax-ml/vocabs/",
        vocab_file="ende_32k.subword",
    )
)[0]

# decode from the transformer
tokenized = tokenized[None, :]  # add a batch dimension
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0
)  # temperature=0.0 is greedy decoding; higher temperature gives more diverse results.

# de-tokenize
tokenized_translation = tokenized_translation[0][:-1]  # remove batch and EOS
translation = trax.data.detokenize(
    tokenized_translation,
    vocab_dir="gs://trax-ml/vocabs/",
    vocab_file="ende_32k.subword",
)
print(translation)

Es ist schön, heute neue Dinge zu lernen!
CPU times: user 12.5 s, sys: 939 ms, total: 13.5 s
Wall time: 17.4 s


Expected result:

```
Es ist schön, heute neue Dinge zu lernen!
[GPU time] 
1 loop, best of 3: 4.22 s per loop
CPU times: user 12.7 s, sys: 227 ms, total: 12.9 s
Wall time: 13.3 s

[TPU time] 
1 loop, best of 3: 13.8 s per loop ???
CPU times: user 49.5 s, sys: 234 ms, total: 49.7 s
Wall time: 49.7 s

[CPU time] 
1 loop, best of 3: 14.4 s per loop
CPU times: user 57.5 s, sys: 250 ms, total: 57.7 s
Wall time: 57.8 s
```
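
To make the `temperature` argument in the decoding cell above more concrete, here is a minimal NumPy sketch of temperature-controlled sampling using the Gumbel-max trick. It is an illustration of the general idea, not Trax's implementation; the helper name `sample_token` and the toy distribution are assumptions made for this example.

```python
import numpy as np

def sample_token(log_probs, temperature, rng):
    """Pick one token id per row of `log_probs` (shape: batch x vocab)."""
    # Adding Gumbel noise to log-probabilities and taking the argmax
    # draws a sample from the corresponding categorical distribution.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel = -np.log(-np.log(u))
    # The temperature scales the noise: 0.0 removes it (pure argmax).
    return np.argmax(log_probs + temperature * gumbel, axis=-1)

rng = np.random.default_rng(0)
log_probs = np.log(np.array([[0.7, 0.2, 0.1]]))  # hypothetical next-token distribution

greedy = sample_token(log_probs, temperature=0.0, rng=rng)
print(greedy)  # always the most likely token, index 0

samples = [int(sample_token(log_probs, 1.0, rng)[0]) for _ in range(1000)]
print(np.bincount(samples, minlength=3) / 1000)  # roughly follows [0.7, 0.2, 0.1]
```

With `temperature=0.0` the result is deterministic, which is why the translation above is reproducible from run to run.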

<a name='2'></a>
## Walkthrough

### Tensors and Fast Math

The basic units flowing through Trax models are *tensors*: multi-dimensional arrays, often also called numpy arrays after `numpy`, the most widely used package for tensor operations. Trax exposes the same numpy API through `trax.fastmath.numpy`, which runs on the selected backend (JAX in the cell below).


In [5]:
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend("jax")

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f"matrix =\n{matrix}")
vector = fastnp.ones(3)
print(f"vector = {vector}")
product = fastnp.dot(vector, matrix)
print(f"product = {product}")
tanh = fastnp.tanh(product)
print(f"tanh(product) = {tanh}")

matrix =
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]


Gradients can be calculated using `trax.fastmath.grad`.

In [6]:
def f(x):
    return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f"grad(2x^2) at  1 = {grad_f(1.0)}")
print(f"grad(2x^2) at -2 = {grad_f(-2.0)}")

grad(2x^2) at  1 = 4.0
grad(2x^2) at -2 = -8.0
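
As a sanity check on the values above, the derivative of `2x^2` is `4x`, which can be confirmed without any autodiff library. The sketch below uses a central finite difference in plain Python; `numerical_grad` is a hypothetical helper written for this example and is not part of Trax.

```python
def f(x):
    return 2.0 * x * x

def numerical_grad(fn, x, eps=1e-4):
    """Central-difference approximation of d(fn)/dx at x."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

for x in (1.0, -2.0):
    analytic = 4.0 * x  # d/dx of 2x^2
    approx = numerical_grad(f, x)
    print(f"x = {x:5.1f}  analytic = {analytic:5.1f}  numerical = {approx:.4f}")
```

Comparing an automatic gradient against a finite difference like this is a common way to test a new function before training with it.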


### Layers

Layers with trainable weights like `Embedding` need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

In [7]:
x = np.arange(15)
print(f"x = {x}")

embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

y = embedding(x)
print(f"shape of y= {y.shape}")

x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y= (15, 32)
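
The output shape above follows from what an embedding computes: a lookup of one row per integer id in a weight table. The NumPy sketch below mimics that lookup with a random table standing in for the layer's trainable weights; it is an illustration under that assumption, not the Trax layer itself.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_feature = 20, 32

# Stand-in for the layer's trainable weights: one row per token id.
table = rng.normal(size=(vocab_size, d_feature)).astype(np.float32)

x = np.arange(15)
y = table[x]  # integer indexing picks one row per id
print(y.shape)  # (15, 32)
```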


### Models

Models in Trax are built from layers, most often using the `Serial` and `Branch` combinators. You can read more about those combinators in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) and see the code for many models in `trax/models/`; for example, this is how the [Transformer Language Model](https://github.com/google/trax/blob/master/trax/models/transformer.py#L167) is implemented. Below is an example of how to build a sentiment classification model.

In [8]:
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # average over axis 1 (the length of the sentence)
    tl.Dense(2),      # one output per class (2 classes)
)

print(model)

Serial[
  Embedding_8192_256
  Mean
  Dense_2
]
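
To see what data flows through this stack, the following NumPy sketch reproduces the same three steps (lookup, mean over the length axis, dense projection) with randomly initialized stand-in weights. The names `embedding_table`, `dense_w`, `dense_b`, and `forward` are assumptions for this example and do not correspond to Trax internals.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_feature, n_classes = 8192, 256, 2

# Randomly initialized stand-ins for the trainable weights.
embedding_table = rng.normal(size=(vocab_size, d_feature)).astype(np.float32)
dense_w = rng.normal(size=(d_feature, n_classes)).astype(np.float32)
dense_b = np.zeros(n_classes, dtype=np.float32)

def forward(token_ids):
    h = embedding_table[token_ids]  # (batch, length) -> (batch, length, 256)
    h = h.mean(axis=1)              # average over the length axis -> (batch, 256)
    return h @ dense_w + dense_b    # one score per class -> (batch, 2)

batch = rng.integers(0, vocab_size, size=(4, 10))
print(forward(batch).shape)  # (4, 2)

n_weights = embedding_table.size + dense_w.size + dense_b.size
print(n_weights)  # 2097666
```

The weight count, 8192 × 256 + 256 × 2 + 2 = 2097666, matches the total number of trainable weights reported in the training log later in this notebook.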


### Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call `next(data_stream)` and get a tuple, e.g., `(inputs, targets)`. Trax allows you to use [TensorFlow Datasets](https://www.tensorflow.org/datasets) easily, and you can also get an iterator from your own text file using the standard `open('my_file.txt')`.

In [9]:
train_stream = trax.data.TFDS(
    "imdb_reviews", keys=("text", "label"), train=True
)()
eval_stream = trax.data.TFDS(
    "imdb_reviews", keys=("text", "label"), train=False
)()
print(next(train_stream))


(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)
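
Because a stream is just a Python iterator, a plain generator is enough to feed your own data. The sketch below assumes a hypothetical tab-separated file format (`label<TAB>text`) and uses `io.StringIO` in place of a real file; the function name `labeled_stream` is also an assumption for this example.

```python
import io

def labeled_stream(lines):
    """Yield (text, label) tuples from lines formatted as 'label<TAB>text'."""
    for line in lines:
        label, text = line.rstrip("\n").split("\t", 1)
        yield text, int(label)

# io.StringIO stands in for open('my_file.txt'); both iterate line by line.
fake_file = io.StringIO("0\tI could barely sit through it.\n1\tA moving film.\n")
stream = labeled_stream(fake_file)

print(next(stream))  # ('I could barely sit through it.', 0)
print(next(stream))  # ('A moving film.', 1)
```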


Using the `trax.data` module you can create input processing pipelines, e.g., to tokenize and shuffle your data. A pipeline built with `trax.data.Serial` is a function: you apply it to a stream to get a processed stream.

In [10]:
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file="en_8k.subword", keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(
        boundaries=[32, 128, 512, 2048],
        batch_sizes=[512, 128, 32, 8, 1],
        length_keys=[0],
    ),
    trax.data.AddLossWeights(),
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f"shapes = {[x.shape for x in example_batch]}")


shapes = [(8, 2048), (8,), (8,)]
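
The `boundaries` and `batch_sizes` arguments above pair each length range with a batch size, so that short examples are grouped into large batches and long examples into small ones. The sketch below shows one way such a mapping can be computed. Whether a length equal to a boundary belongs to the lower bucket is an assumption here, and `bucket_for` is a hypothetical helper, not a Trax function.

```python
import bisect

boundaries = [32, 128, 512, 2048]
batch_sizes = [512, 128, 32, 8, 1]

def bucket_for(length):
    """Return (bucket index, batch size) for an example of the given length."""
    # Index of the first boundary that is >= length; lengths above the
    # last boundary fall into the extra final bucket.
    idx = bisect.bisect_left(boundaries, length)
    return idx, batch_sizes[idx]

for length in (20, 32, 100, 600, 2048):
    print(length, bucket_for(length))
```

Under this reading, an example batch of shape `(8, 2048)` is consistent with the bucket for lengths between 513 and 2048, which uses batch size 8.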


### Supervised training

When you have the model and the data, use `trax.supervised.training` to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

In [13]:
from trax.supervised import training

train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20
)

output_dir = os.path.expanduser("~/output_dir/")
!rm -rf {output_dir}
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir
)

training_loop.run(2000)


Step      1: Total number of trainable weights: 2097666
Step      1: Ran 1 train steps in 1.60 secs
Step      1: train WeightedCategoryCrossEntropy |  0.69452739
Step      1: eval  WeightedCategoryCrossEntropy |  0.70902301
Step      1: eval      WeightedCategoryAccuracy |  0.47500000

Step    500: Ran 499 train steps in 15.23 secs
Step    500: train WeightedCategoryCrossEntropy |  0.50166368
Step    500: eval  WeightedCategoryCrossEntropy |  0.42592406
Step    500: eval      WeightedCategoryAccuracy |  0.80000000

Step   1000: Ran 500 train steps in 13.28 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.37094992
Step   1000: eval  WeightedCategoryCrossEntropy |  0.33120547
Step   1000: eval      WeightedCategoryAccuracy |  0.84375000

Step   1500: Ran 500 train steps in 13.04 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.34561551
Step   1500: eval  WeightedCategoryCrossEntropy |  0.48605159
Step   1500: eval      WeightedCategoryAccuracy |  0.78437500

Step   200
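
The metrics in the log above can be read more easily with a small reference computation. The NumPy sketch below implements an unweighted cross-entropy and accuracy from logits; it is a simplified stand-in written for this example (the Trax layers also take per-example loss weights), and the logits used are made up.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean negative log-likelihood of integer `targets` under softmax(logits)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def accuracy(logits, targets):
    """Fraction of examples whose highest-scoring class equals the target."""
    return (logits.argmax(axis=-1) == targets).mean()

targets = np.array([0, 1, 1, 0])

# An untrained two-class model that scores both classes equally.
uniform_logits = np.zeros((4, 2))
print(cross_entropy(uniform_logits, targets))  # ln(2), about 0.6931

# Confident, mostly correct scores give a lower loss.
good_logits = np.array([[3.0, 0.0], [0.0, 3.0], [0.0, 3.0], [0.0, 3.0]])
print(cross_entropy(good_logits, targets), accuracy(good_logits, targets))
```

The loss at step 1 in the log (about 0.69) is close to ln 2, which is what a two-class classifier produces before it has learned anything.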

After training the model, run it like any layer to get results.

In [17]:
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(
    example_input, vocab_file="en_8k.subword"
)
print(f"sample review: {example_input_str}")
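# The model ends in Dense(2), so it returns one unnormalized score per class;
# np.exp of these scores is not guaranteed to sum to 1.
sentiment_log_probs = model(example_input[None, :])  # add a batch dimension
print(f"probabilities: {np.exp(sentiment_log_probs)}")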

sample review: I was a schoolboy when I watched this film for the first time. The next day I knew that all pupils of our form watched it and all were fascinated by the film as I was. I think the same situation was in all forms of our school and in the whole Soviet Union. Later I watched it every time it was shown on TV and want to watch more. I think that comparison with "Back to the Future" or other Sci-Fi films is not appropriate. "Gost'ya iz budushchego" is unique in many ways, once you have watched it, you never forget it.<br /><br />This film is full of belief in peaceful science achievements, full of belief in the beautiful future of our world. It's not only the film, but also a forecast of many scientific inventions and achievements. The time shown in the film is the year 1984 (the year of its creation) and the year 2084 (where a schoolboy Kolya Gerasimov has traveled for some time and where his friend Alisa Seleznyova was from). The year now is 2005, many inventions and achieve

In [24]:
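example_input_str = "It is a moving film, I got touched for the story."
# NOTE: the classifier was trained on en_8k.subword token ids (vocab_size=8192),
# but this cell tokenizes with ende_32k.subword, so the ids fed to the model do
# not match its training vocabulary. Read the output below with that in mind.
example_input_tokenized = list(
    trax.data.tokenize(
        iter([example_input_str]),
        vocab_dir="gs://trax-ml/vocabs/",
        vocab_file="ende_32k.subword",
    )
)[0]
print(f"sample review: {example_input_str}")
# The model ends in Dense(2), so these are unnormalized scores, not log-probabilities.
sentiment_log_probs = model(example_input_tokenized[None, :])
print(f"probabilities: {np.exp(sentiment_log_probs)}")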

sample review: It is a moving film, I got touched for the story.
probabilities: [[17.71551     0.05883902]]


In [27]:
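example_input_str = "Completely dissapointing. I felt sleep. Simply awful."
# NOTE: the classifier was trained on en_8k.subword token ids (vocab_size=8192),
# but this cell tokenizes with ende_32k.subword, so the ids fed to the model do
# not match its training vocabulary. Read the output below with that in mind.
example_input_tokenized = list(
    trax.data.tokenize(
        iter([example_input_str]),
        vocab_dir="gs://trax-ml/vocabs/",
        vocab_file="ende_32k.subword",
    )
)[0]
print(f"sample review: {example_input_str}")
# The model ends in Dense(2), so these are unnormalized scores, not log-probabilities.
sentiment_log_probs = model(example_input_tokenized[None, :])
print(f"probabilities: {np.exp(sentiment_log_probs)}")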

sample review: Completely dissapointing. I felt sleep. Simply awful.
probabilities: [[1.9859998 0.520353 ]]
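
The printed values above do not sum to 1, which suggests that the model output is a pair of unnormalized scores (logits) rather than log-probabilities; the classifier defined earlier ends in `tl.Dense(2)` with no normalization layer. A minimal NumPy sketch of turning such scores into probabilities with a softmax follows. The logits here are hypothetical and the `softmax` helper is written for this example.

```python
import numpy as np

def softmax(logits):
    """Convert unnormalized scores to probabilities that sum to 1 per row."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical logits for one review (batch of 1, two classes).
logits = np.array([[2.0, -1.0]])

print(np.exp(logits))   # not normalized: entries can exceed 1
print(softmax(logits))  # each row sums to 1
```

Applying a function like this to the model output, instead of `np.exp` alone, would give values that can be read as class probabilities.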
