In [1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
!pip install -q -U trax

[K     |████████████████████████████████| 471kB 5.4MB/s 
[K     |████████████████████████████████| 174kB 13.4MB/s 
[K     |████████████████████████████████| 2.6MB 15.4MB/s 
[K     |████████████████████████████████| 71kB 8.5MB/s 
[K     |████████████████████████████████| 1.3MB 52.3MB/s 
[K     |████████████████████████████████| 348kB 52.7MB/s 
[K     |████████████████████████████████| 1.1MB 47.0MB/s 
[K     |████████████████████████████████| 3.6MB 48.2MB/s 
[K     |████████████████████████████████| 890kB 51.4MB/s 
[K     |████████████████████████████████| 2.9MB 47.6MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [None]:
import trax
# Use the tensorflow-numpy backend.
# The backend is switched via trax.fastmath.set_backend; printing
# backend_name() afterwards confirms which backend is active.
trax.fastmath.set_backend('tensorflow-numpy')
print(trax.fastmath.backend_name())

tensorflow-numpy


In [None]:
# https://www.tensorflow.org/datasets/catalog/fashion_mnist
# Create the raw example streams for Fashion-MNIST, keeping only the 'image'
# and 'label' features. `train=True` / `train=False` select the data used for
# training and for evaluation respectively. The trailing `()` calls the object
# returned by trax.data.TFDS to obtain the stream itself, which is later fed
# through the data pipelines.
train_stream = trax.data.TFDS('fashion_mnist', keys=('image', 'label'), train=True)()
eval_stream = trax.data.TFDS('fashion_mnist', keys=('image', 'label'), train=False)()

[1mDownloading and preparing dataset fashion_mnist/3.0.1 (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /root/tensorflow_datasets/fashion_mnist/3.0.1...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=60000.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.1.incompleteG7JCTM/fashion_mnist-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=60000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.1.incompleteG7JCTM/fashion_mnist-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

[1mDataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [None]:
# Number of examples per batch, shared by the training and evaluation
# pipelines so the two cannot silently drift apart.
BATCH_SIZE = 8

# Training pipeline: shuffle the example stream, group it into batches and
# append loss weights to every batch.
train_data_pipeline = trax.data.Serial(
    trax.data.Shuffle(),
    trax.data.Batch(BATCH_SIZE),
    trax.data.AddLossWeights(),
)

train_batches_stream = train_data_pipeline(train_stream)

# Evaluation pipeline: same batching and loss weights, but no shuffling.
eval_data_pipeline = trax.data.Serial(
    trax.data.Batch(BATCH_SIZE),
    trax.data.AddLossWeights(),
)

eval_batches_stream = eval_data_pipeline(eval_stream)

In [None]:
# Pull a single batch from the training stream to inspect its structure.
# Note that next() advances train_batches_stream, so this batch is consumed.
example_batch = next(train_batches_stream)
# The batch contains three arrays even though the label below names only two:
# the third entry is presumably the weights array added by AddLossWeights.
print(f'batch shape (image, label) = {[x.shape for x in example_batch]}')

batch shape (image, label) = [(8, 28, 28, 1), (8,), (8,)]


In [None]:
from trax import layers as tl
from trax.models.resnet import Resnet50

def get_model(n_output_classes=10):
  """Builds a small two-stage convolutional classifier.

  The network casts its input to float, applies two stages of
  Conv (3x3 kernel, stride 1, 'SAME' padding) -> LayerNorm -> Relu -> MaxPool
  with 32 and then 64 filters, flattens the result and finishes with a Dense
  projection followed by LogSoftmax.

  Args:
    n_output_classes: Number of units in the final Dense layer, i.e. the
      number of classes the model scores.

  Returns:
    A `tl.Serial` layer combining all of the layers above, in order.
  """
  def conv_stage(n_filters):
    # One Conv -> LayerNorm -> Relu -> MaxPool stage with `n_filters` filters.
    return [
        tl.Conv(n_filters, (3, 3), (1, 1), 'SAME'),
        tl.LayerNorm(),
        tl.Relu(),
        # NOTE(review): MaxPool is used with its defaults. The 451658
        # trainable weights reported by the training run are consistent with
        # these pools NOT halving the spatial size - confirm the default
        # strides are what is intended here.
        tl.MaxPool(),
    ]

  layers = [tl.ToFloat()]
  for n_filters in (32, 64):
    layers.extend(conv_stage(n_filters))
  layers += [
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  ]
  return tl.Serial(*layers)

In [None]:
from trax.supervised import training

# Training task: consumes the batched training stream and optimizes the
# cross-entropy loss with Adam.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    # NOTE(review): the recorded log shows the train loss jumping to ~30.7 at
    # step 100 before recovering; a learning rate of 0.01 may be on the high
    # side for Adam here - consider verifying with a smaller value.
    optimizer=trax.optimizers.Adam(0.01),
    # The training log reports metrics at step 1 and then every 100 steps.
    n_steps_per_checkpoint=100,
)

# Evaluation task: reports cross-entropy loss and accuracy on the batched
# evaluation stream.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    # Each evaluation uses 20 batches, i.e. 160 examples at the batch size of
    # 8 configured in the eval pipeline, so the metrics are fairly coarse.
    n_eval_batches=20,
)

In [None]:
import os

# Instantiate the CNN defined by get_model().
model = get_model()

# Tie the model to its training task and evaluation task.
# NOTE(review): output_dir is a fixed relative path; presumably checkpoints
# already present there are picked up when the cell is re-run - verify, and
# clear the directory if a from-scratch run is wanted.
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir='./cnn_model',
)

# Train for 1000 steps; metrics are logged along the way.
training_loop.run(1000)


Step      1: Total number of trainable weights: 451658
Step      1: Ran 1 train steps in 1.56 secs
Step      1: train CrossEntropyLoss |  2.06045055
Step      1: eval  CrossEntropyLoss |  177.95621605
Step      1: eval          Accuracy |  0.23125000

Step    100: Ran 99 train steps in 2.99 secs
Step    100: train CrossEntropyLoss |  30.65670776
Step    100: eval  CrossEntropyLoss |  5.53605672
Step    100: eval          Accuracy |  0.54375000

Step    200: Ran 100 train steps in 2.82 secs
Step    200: train CrossEntropyLoss |  1.93918920
Step    200: eval  CrossEntropyLoss |  0.73415423
Step    200: eval          Accuracy |  0.71250000

Step    300: Ran 100 train steps in 1.68 secs
Step    300: train CrossEntropyLoss |  0.64221692
Step    300: eval  CrossEntropyLoss |  0.67144935
Step    300: eval          Accuracy |  0.75000000

Step    400: Ran 100 train steps in 1.81 secs
Step    400: train CrossEntropyLoss |  0.60595691
Step    400: eval  CrossEntropyLoss |  0.49500983
Step    40