<a href="https://colab.research.google.com/github/mkaramib/Trax/blob/main/learn_trax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Trax is a deep learning library developed by the Google Brain team. On this page, I show some experiments with Trax.

In [None]:
import os 
import numpy as np

# Install Trax

In this stage, we need to install Trax. The following cell installs and imports it.

In [None]:
#@title
# Install and import Trax

!pip install -q -U trax
import trax

# Pre-trained Model

In this section, I experiment with a pre-trained English-to-German machine translation model provided by the Trax team.


In [None]:
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
# sentence = 'It is nice to learn new things today!'  # Alternative input.
sentence = 'I love to learn Trax.'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)

Ich liebe Trax zu lernen.
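The following toy sketch of subword tokenization over a stream makes the tokenize/detokenize round trip more concrete. It makes two assumptions. First, it uses a tiny, hypothetical vocabulary (`TOY_VOCAB`). Second, it applies a simplified greedy longest-match rule. The real `ende_32k.subword` vocabulary is far larger, and its encoder also handles word boundaries and escaping, so this illustrates the idea rather than reproducing Trax's implementation. Like `trax.data.tokenize`, the hypothetical `tokenize_stream` consumes an iterator of strings and lazily yields one integer array per string.

```python
from typing import Dict, Iterable, Iterator, List

import numpy as np

# Hypothetical toy vocabulary (piece -> ID); not the real ende_32k.subword file.
TOY_VOCAB: Dict[str, int] = {
    "<pad>": 0, "<EOS>": 1,
    "I": 2, " ": 3, "love": 4, "to": 5, "learn": 6,
    "Tr": 7, "ax": 8, ".": 9,
}


def encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Split `text` into vocabulary pieces, greedily taking the longest match."""
    max_piece = max(len(piece) for piece in vocab)
    ids: List[int] = []
    i = 0
    while i < len(text):
        # Try the longest candidate first and shrink until a piece matches.
        for end in range(min(len(text), i + max_piece), i, -1):
            piece = text[i:end]
            if piece in vocab:
                ids.append(vocab[piece])
                i = end
                break
        else:
            raise ValueError(f"no vocabulary piece matches {text[i]!r}")
    return ids


def tokenize_stream(texts: Iterable[str], vocab: Dict[str, int]) -> Iterator[np.ndarray]:
    """Lazily map each string in a stream to a 1-D integer array."""
    for text in texts:
        yield np.array(encode(text, vocab), dtype=np.int64)


def detokenize_ids(ids: Iterable[int], vocab: Dict[str, int]) -> str:
    """Invert the vocabulary mapping and concatenate the pieces."""
    inverse = {idx: piece for piece, idx in vocab.items()}
    return "".join(inverse[int(idx)] for idx in ids)


tokens = list(tokenize_stream(iter(["I love to learn Trax."]), TOY_VOCAB))[0]
print(tokens)                              # [2 3 4 3 5 3 6 3 7 8 9]
print(detokenize_ids(tokens, TOY_VOCAB))   # I love to learn Trax.
```

Note how `Trax` is split into the pieces `Tr` and `ax`. A subword vocabulary can represent words it never stores whole, which is how a fixed vocabulary of roughly 32k pieces can cover open-ended text.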
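The decoding step relies on three conventions that are easy to miss:

- The model works on batches, so `tokenized[None, :]` adds a leading batch axis.
- Decoding is autoregressive: each new token is fed back in until an end-of-sequence (EOS) token appears.
- `temperature` controls how each next token is chosen.

The plain-NumPy sketch below illustrates these ideas under explicit assumptions. `toy_model` is a hypothetical stand-in for the Transformer and ignores the source sentence. `EOS_ID = 1` is an assumed ID. The sampling rule is the standard greedy / softmax-with-temperature scheme, not a copy of Trax's `autoregressive_sample`.

```python
import numpy as np

VOCAB_SIZE = 6  # hypothetical toy vocabulary size
EOS_ID = 1      # assumed end-of-sequence ID for this toy example


def toy_model(prefix: np.ndarray) -> np.ndarray:
    """Hypothetical decoder stand-in: (batch, t) token IDs -> (batch, VOCAB_SIZE) logits.

    It strongly prefers token t + 2 for the first three steps, then EOS.
    """
    batch, t = prefix.shape
    logits = np.zeros((batch, VOCAB_SIZE))
    logits[:, t + 2 if t < 3 else EOS_ID] = 5.0
    return logits


def sample_next_token(logits: np.ndarray, temperature: float,
                      rng: np.random.Generator) -> np.ndarray:
    """Choose one token ID per batch row."""
    if temperature == 0.0:
        # Greedy decoding: always take the highest-scoring token.
        return np.argmax(logits, axis=-1)
    # Divide by the temperature, then apply a numerically stable softmax.
    # Higher temperatures flatten the distribution, giving more diverse samples.
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(probs.shape[-1], p=row) for row in probs])


def toy_autoregressive_decode(model, batch_size: int, temperature: float,
                              max_len: int = 10, seed: int = 0) -> np.ndarray:
    """Append one chosen token at a time until every row ends in EOS."""
    rng = np.random.default_rng(seed)
    out = np.zeros((batch_size, 0), dtype=np.int64)
    for _ in range(max_len):
        next_ids = sample_next_token(model(out), temperature, rng)
        out = np.concatenate([out, next_ids[:, None]], axis=1)  # feed back in
        if np.all(out[:, -1] == EOS_ID):  # simplification: no per-row bookkeeping
            break
    return out


tokens = np.array([2, 3, 4])
print(tokens[None, :].shape)   # (1, 3): a batch holding one sequence
decoded = toy_autoregressive_decode(toy_model, batch_size=1, temperature=0.0)
print(decoded)                 # [[2 3 4 1]]
print(decoded[0][:-1])         # [2 3 4]: batch axis and trailing EOS removed
```

With `temperature=0.0`, as in the notebook, the output is deterministic, which suits a reproducible demo. A positive temperature trades that determinism for variety.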


# Trax and NumPy

One of the key features of Trax is speed, which it achieves in part through `trax.fastmath.numpy`, a fast, NumPy-compatible API backed by JAX.

In [None]:
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix =\n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
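Since `fastnp` is meant to mirror the NumPy API, the cell above can be cross-checked with plain NumPy. This sketch (the helper name `tanh_of_column_sums` is hypothetical) also spells out what `dot(vector, matrix)` computes here. Multiplying a vector of ones by a matrix sums each column, so the product is `[12, 15, 18]`, and `tanh` of such large values is very close to 1.

```python
from typing import Tuple

import numpy as np


def tanh_of_column_sums(matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Plain-NumPy version of the fastnp cell: returns (ones @ matrix, tanh of it)."""
    vector = np.ones(matrix.shape[0])
    product = np.dot(vector, matrix)  # ones-vector times matrix = column sums
    return product, np.tanh(product)


matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
product, tanh = tanh_of_column_sums(matrix)
print(f'product = {product}')                      # product = [12. 15. 18.]
print(np.allclose(product, matrix.sum(axis=0)))    # True
print(np.allclose(tanh, 1.0))                      # True: tanh saturates for large inputs
```

Swapping `np` for `fastnp` with the JAX backend should give matching values up to floating-point precision (JAX defaults to 32-bit floats). What changes is where and how fast the computation runs, not the result.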