Derived from the introductory tutorial in the Trax documentation, available <a href="https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html">here</a>.

In [1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [14]:
import shutil

import numpy as np
import trax
from   trax import layers as tl
from   trax.fastmath import numpy as fastnp
from   trax.supervised import training

# Select the fastmath backend globally. `trax.fastmath.use_backend` is a context
# manager: called outside a `with` statement it only builds the manager object
# and does not switch anything (JAX ran before simply because it is the
# default). `set_backend` changes the global default instead.
trax.fastmath.set_backend('jax');  # or 'tensorflow-numpy'

### Fast Math

In [3]:
# A 3x3 integer matrix and an all-ones vector of length 3.
M = fastnp.array([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
v = fastnp.ones(3)
print(f'Matrix:\n{M}')
print(f'Vector: {v}')

# Vector-matrix product; with an all-ones vector this is the column sums of M.
dot_prod = v @ M
print(f'Dot product: {dot_prod}')

# Element-wise tanh of the product (saturates towards 1 for values this large).
tanh = fastnp.tanh(dot_prod)
print(f'tanh(prod): {tanh}')

Matrix:
[[1 2 3]
 [4 5 6]
 [7 8 9]]
Vector: [1. 1. 1.]
Dot product: [12. 15. 18.]
tanh(prod): [1. 1. 1.]




In [4]:
def f(x):
    """Toy function 2 * x**2, differentiated with fastmath.grad below."""
    doubled = 2. * x
    return doubled * x

In [5]:
# Derivative of f(x) = 2x^2 is 4x, so we expect 4.0 at x=1 and -8.0 at x=-2.
grad_f = trax.fastmath.grad(f)

grad_at_1 = grad_f(1.)
grad_at_minus_2 = grad_f(-2.)
print(f'grad(2x^2) at 1: {grad_at_1}')
print(f'grad(2x^2) at -2: {grad_at_minus_2}')

grad(2x^2) at 1: 4.0
grad(2x^2) at -2: -8.0


In [6]:
# Fifteen integer ids, 0..14, all below the layer's vocab_size of 20.
x = np.arange(15)
print(f'x: {x}')

# Build the embedding layer, initialise its weights from the input's shape
# signature, then apply it to get one 32-d vector per id.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
input_signature = trax.shapes.signature(x)
embedding.init(input_signature)
y = embedding(x)
print(f'y.shape: {y.shape}')

x: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
y.shape: (15, 32)


### Models

In [7]:
# Bag-of-embeddings sentiment classifier. vocab_size must cover every id
# produced by the 'en_8k.subword' tokenizer used in the data pipeline.
classifier_layers = [
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),   # average the embeddings over the sequence (length) axis
    tl.Dense(2),       # one output per class (2 classes)
    tl.LogSoftmax(),   # log-probabilities over the classes
]
mod = tl.Serial(*classifier_layers)
mod

Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]

### Data 

In [10]:
def _imdb_reviews_stream(train):
    """Return an iterator over (text, label) examples from one IMDB split."""
    return trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=train)()


train_stream = _imdb_reviews_stream(train=True)
eval_stream = _imdb_reviews_stream(train=False)

In [11]:
# Peek at one raw (text, label) example; note that this consumes it from the stream.
first_example = next(train_stream)
print(first_example)

(b'Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to enforce the law themselves, then gunfighters battling it out on the streets for control of the town? <br /><br />Nothing even remotely resembling that happened on the Canadian side of the border during the Klondike gold rush. Mr. Mann and company appear to have mistaken Dawson City for Deadwood, the Canadian North for the American Wild West.<br /><br />Canadian viewers be prepared for a Reefer Madness type of enjoyable howl with this ludicrous plot, or, to shake your head in disgust.', 0)


In [12]:
# Length buckets for batching: batch_sizes has one more entry than boundaries,
# so longer sequences get smaller batches.
bucket_boundaries = [32, 128, 512, 2048]
bucket_batch_sizes = [512, 128, 32, 8, 1]

pipeline_stages = [
    # Element 0 of each example is the review text.
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=bucket_boundaries,
                             batch_sizes=bucket_batch_sizes,
                             length_keys=[0]),
    # Adds the loss-weights array seen as the third element of each batch.
    trax.data.AddLossWeights(),
]
data_pipeline = trax.data.Serial(*pipeline_stages)

In [13]:
# Run both splits through the same pipeline, then pull one training batch
# to inspect the (inputs, targets, weights) shapes.
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
batch_shapes = [array.shape for array in example_batch]
print(f'shapes = {batch_shapes}')

shapes = [(8, 2048), (8,), (8,)]


### Supervised Training

In [16]:
# Training task: bucketed training batches, cross-entropy loss, Adam with
# learning rate 0.01, and a checkpoint every 500 steps.
train_task = training.TrainTask(labeled_data=train_batches_stream,
                                loss_layer=tl.CrossEntropyLoss(),
                                optimizer=trax.optimizers.Adam(0.01),
                                n_steps_per_checkpoint=500)
# Evaluation task: loss and accuracy computed over 20 eval batches.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20)

# Start from an empty output directory (presumably so the loop does not pick up
# checkpoints from a previous run). shutil.rmtree works on any OS and in plain
# Python, and it does not splice the path into a shell command.
# ignore_errors=True mirrors `rm -rf`: a missing directory is not an error.
output_dir = './output'
shutil.rmtree(output_dir, ignore_errors=True)

training_loop = training.Loop(
    mod, train_task, eval_tasks=[eval_task], output_dir=output_dir)
training_loop.run(500)


Step      1: Ran 1 train steps in 0.96 secs
Step      1: train CrossEntropyLoss |  0.69560391
Step      1: eval  CrossEntropyLoss |  0.70317779
Step      1: eval          Accuracy |  0.47343750

Step    500: Ran 499 train steps in 14.22 secs
Step    500: train CrossEntropyLoss |  0.51409441
Step    500: eval  CrossEntropyLoss |  0.53382829
Step    500: eval          Accuracy |  0.77968750


In [17]:
# Take the first tokenized review of the next eval batch and decode it to text.
eval_batch = next(eval_batches_stream)
example_input = eval_batch[0][0]
example_input_str = trax.data.detokenize(example_input,
                                         vocab_file='en_8k.subword')
print(f'Input: {example_input_str}')

# The model expects a leading batch axis and ends in LogSoftmax, so
# exponentiate its output to get class probabilities.
batched_input = example_input[np.newaxis, :]
sentiment_log_probs = mod(batched_input)
print(f'Sentiment probs: {np.exp(sentiment_log_probs)}')

Sentiment probs: [[0.02387173 0.9761283 ]]
