In [None]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import os
#!pip install -q -U trax
import sys

# Optional: run against a local trax checkout instead of the installed package.
# Set TRAX_PROJECT_ROOT to the directory that contains the `trax` package
# (for example a 'src' directory).
project_root = os.environ.get('TRAX_PROJECT_ROOT', '')

if project_root:
    # Only touch sys.path when a root was actually provided: inserting the
    # empty default would put the current working directory first instead.
    # The guard also keeps re-runs of this cell from stacking duplicates.
    if sys.path[:1] != [project_root]:
        sys.path.insert(0, project_root)
    # Option to verify the import path
    print(f"Python will look for packages in: {sys.path[0]}")
else:
    print("TRAX_PROJECT_ROOT is not set; sys.path left unchanged.")

# Load trax now that the search path has been configured
import trax

# Report the file trax was loaded from, to confirm which copy is in use
trax_location = trax.__file__
print(f"Imported trax from: {trax_location}")

In [None]:
from trax import fastmath
from trax.fastmath.jax import jax

# Use the JAX backend.
fastmath.set_backend(fastmath.Backend.JAX.value)

# Show the active backend name and the devices JAX can see. `fastmath` is
# referenced through this cell's own import rather than via `trax.fastmath`,
# so the cell does not depend on the `trax` name from an earlier cell.
print(fastmath.backend_name())
print(jax.devices())

In [None]:
# https://www.tensorflow.org/datasets/catalog/fashion_mnist
from trax.data.preprocessing import inputs as preprocessing
from trax.data.loader.tf import base as dataset

# Both splits are read from the same TFDS dataset with the same feature keys;
# only the `train` flag differs.
tfds_name = 'fashion_mnist'
feature_keys = ('image', 'label')

train_stream = dataset.TFDS(tfds_name, keys=feature_keys, train=True)()
eval_stream = dataset.TFDS(tfds_name, keys=feature_keys, train=False)()

In [None]:
# Number of examples per batch, shared by the training and eval pipelines.
batch_size = 8

# Training pipeline: shuffle the examples, then group them into batches.
train_data_pipeline = preprocessing.Serial(
    preprocessing.Shuffle(),
    preprocessing.Batch(batch_size),
)
# Evaluation pipeline: batching only.
eval_data_pipeline = preprocessing.Batch(batch_size)

# Apply each pipeline to its raw example stream.
train_batches_stream = train_data_pipeline(train_stream)
eval_batches_stream = eval_data_pipeline(eval_stream)

In [None]:
# Pull one batch off the training stream and show the shape of each element.
example_batch = next(train_batches_stream)
example_shapes = [element.shape for element in example_batch]
print(f'batch shape (image, label) = {example_shapes}')

In [None]:
from trax import layers as tl


def get_model(n_output_classes=10):
    """Build the convolutional classifier used in this notebook.

    The network converts its input to float, applies two identical stages of
    Conv -> LayerNorm -> Relu -> MaxPool (with 32 and then 64 filters),
    flattens the result and finishes with a Dense layer.

    Args:
        n_output_classes: Size of the final Dense layer.

    Returns:
        A `tl.Serial` combinator holding the layers in the order above.
    """

    def conv_stage(n_filters):
        # One stage; only the filter count differs between stages.
        return [
            tl.Conv(n_filters, (3, 3), (1, 1), 'SAME'),
            tl.LayerNorm(),
            tl.Relu(),
            tl.MaxPool(),
        ]

    layers = [tl.ToFloat()]
    for n_filters in (32, 64):
        layers.extend(conv_stage(n_filters))
    layers.append(tl.Flatten())
    layers.append(tl.Dense(n_output_classes))

    return tl.Serial(*layers)

In [None]:
from trax.learning.supervised import training
from trax import optimizers as optimizers

# Training task: cross-entropy loss optimised with Adam on the batched
# training stream.
training_loss = tl.CategoryCrossEntropy()
adam_optimizer = optimizers.Adam(0.01)
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=training_loss,
    optimizer=adam_optimizer,
    n_steps_per_checkpoint=100,
)

# Evaluation task: report cross-entropy and accuracy over 20 eval batches.
eval_metrics = [tl.CategoryCrossEntropy(), tl.CategoryAccuracy()]
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=eval_metrics,
    n_eval_batches=20,
)

In [None]:
# Where the loop writes its outputs, and how many steps to train for.
loop_output_dir = './cnn_model'
n_train_steps = 100

model = get_model()

# Tie the model to the training task and the single evaluation task.
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir=loop_output_dir,
)

training_loop.run(n_train_steps)

In [None]:
import shutil

# Remove the training loop's output directory; errors during removal
# (for example, the directory not existing) are ignored.
loop_dir = training_loop.output_dir
shutil.rmtree(loop_dir, ignore_errors=True)