In [1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
!pip install -q -U trax




In [3]:
import trax
# Use the tensorflow-numpy backend.
# The print below echoes the backend name reported by trax.fastmath so the
# selection is recorded in the notebook output.
trax.fastmath.set_backend('tensorflow-numpy')
print(trax.fastmath.backend_name())

tensorflow-numpy


In [4]:
# Fashion-MNIST from TensorFlow Datasets:
# https://www.tensorflow.org/datasets/catalog/fashion_mnist
# `keys` selects the ('image', 'label') fields of each example; `train`
# chooses between the training and evaluation data. The trailing () call
# turns the object returned by TFDS(...) into the actual example stream.
train_stream = trax.data.TFDS(
    'fashion_mnist',
    keys=('image', 'label'),
    train=True,
)()
eval_stream = trax.data.TFDS(
    'fashion_mnist',
    keys=('image', 'label'),
    train=False,
)()

In [5]:
# Single source of truth for the batch size so the training and evaluation
# pipelines cannot silently drift apart.
BATCH_SIZE = 8

# Training pipeline: shuffle the examples, then group them into batches.
train_data_pipeline = trax.data.Serial(
    trax.data.Shuffle(),
    trax.data.Batch(BATCH_SIZE),
)

train_batches_stream = train_data_pipeline(train_stream)

# Evaluation pipeline: batching only, no shuffling.
eval_data_pipeline = trax.data.Batch(BATCH_SIZE)

eval_batches_stream = eval_data_pipeline(eval_stream)

In [6]:
# Pull one batch off the training stream (this consumes it from the stream)
# and report the shape of every array it contains as a sanity check.
example_batch = next(train_batches_stream)
print(
    f'batch shape (image, label) = {[array.shape for array in example_batch]}'
)

batch shape (image, label) = [(8, 28, 28, 1), (8,)]


In [7]:
from trax import layers as tl
from trax.models.resnet import Resnet50

def get_model(n_output_classes=10):
  """Builds a small two-stage convolutional classifier.

  The network is ToFloat, followed by two Conv(3x3, stride 1, 'SAME') ->
  LayerNorm -> Relu -> MaxPool stages with 32 and 64 filters, then Flatten
  and a Dense layer with `n_output_classes` outputs.

  NOTE(review): MaxPool() is used with its default arguments. The trainable
  weight count logged by the training run (451658) appears consistent with a
  26x26x64 feature map reaching Dense, i.e. pooling that does not halve the
  spatial size -- confirm this is intended.

  Args:
    n_output_classes: Width of the final Dense layer.

  Returns:
    A `tl.Serial` combinator containing the layers described above.
  """
  def conv_stage(n_filters):
    # One Conv -> LayerNorm -> Relu -> MaxPool stage as a flat list of layers.
    return [
        tl.Conv(n_filters, (3, 3), (1, 1), 'SAME'),
        tl.LayerNorm(),
        tl.Relu(),
        tl.MaxPool(),
    ]

  layers = [tl.ToFloat()]
  layers += conv_stage(32)
  layers += conv_stage(64)
  layers += [tl.Flatten(), tl.Dense(n_output_classes)]
  return tl.Serial(*layers)

In [8]:
from trax.supervised import training

# Training task: consumes the batched training stream and minimises
# categorical cross-entropy with Adam.
# NOTE(review): the recorded run shows the train loss spiking at step 100;
# a learning rate of 0.01 may be on the high side -- confirm.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(learning_rate=0.01),
    n_steps_per_checkpoint=100,
)

# Evaluation task: reports cross-entropy and accuracy on the batched
# evaluation stream, using 20 batches per evaluation.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[
        tl.CategoryCrossEntropy(),
        tl.CategoryAccuracy(),
    ],
    n_eval_batches=20,
)

In [9]:
import os

# Build a fresh model with the default number of output classes.
model = get_model()

# Wire the model to the training task and the evaluation task.
# NOTE(review): output_dir is a fixed relative path; presumably artifacts
# from an earlier run in the same directory could affect a re-run -- confirm.
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir='./cnn_model',
)

# Train for 1000 steps.
training_loop.run(n_steps=1000)


Step      1: Total number of trainable weights: 451658
Step      1: Ran 1 train steps in 1.20 secs
Step      1: train CategoryCrossEntropy |  2.94750214
Step      1: eval  CategoryCrossEntropy |  211.32588081
Step      1: eval      CategoryAccuracy |  0.12500000

Step    100: Ran 99 train steps in 1.60 secs
Step    100: train CategoryCrossEntropy |  33.01021576
Step    100: eval  CategoryCrossEntropy |  4.50655540
Step    100: eval      CategoryAccuracy |  0.61250000

Step    200: Ran 100 train steps in 1.53 secs
Step    200: train CategoryCrossEntropy |  1.78586197
Step    200: eval  CategoryCrossEntropy |  0.89368055
Step    200: eval      CategoryAccuracy |  0.76250000

Step    300: Ran 100 train steps in 0.98 secs
Step    300: train CategoryCrossEntropy |  0.81385994
Step    300: eval  CategoryCrossEntropy |  0.64747319
Step    300: eval      CategoryAccuracy |  0.77500000

Step    400: Ran 100 train steps in 0.95 secs
Step    400: train CategoryCrossEntropy |  0.59235722
Step    