In [1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [1]:
#!pip install -q -U trax
import sys
import os

# For example, if trax is inside a 'src' directory
project_root = os.environ.get('TRAX_PROJECT_ROOT', '')
sys.path.insert(0, project_root)

# Option to verify the import path
print(f"Python will look for packages in: {sys.path[0]}")

# Import trax
import trax

# Verify the source of the imported package
print(f"Imported trax from: {trax.__file__}")

Python will look for packages in: /raid/mmironczuk/projects/trax-upgrade
Imported trax from: /raid/mmironczuk/projects/trax-upgrade/trax/__init__.py


In [2]:
from trax import fastmath
from trax.fastmath.jax import jax

# Use the tensorflow-numpy backend.
fastmath.set_backend(fastmath.Backend.JAX.value)
print(trax.fastmath.backend_name())
print(jax.devices())

2025-04-10 11:54:19.998674: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-10 11:54:20.027312: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-10 11:54:20.036206: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-10 11:54:20.056125: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


jax
[CudaDevice(id=0)]


In [3]:
# https://www.tensorflow.org/datasets/catalog/fashion_mnist
from trax.data.preprocessing import inputs as preprocessing
from trax.data.loader.tf import base as dataset

train_stream = dataset.TFDS('fashion_mnist', keys=('image', 'label'), train=True)()
eval_stream = dataset.TFDS('fashion_mnist', keys=('image', 'label'), train=False)()

2025-04-10 11:54:28.637868: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30801 MB memory:  -> device: 0, name: Tesla V100-DGXS-32GB, pci bus id: 0000:0f:00.0, compute capability: 7.0


In [4]:
train_data_pipeline = preprocessing.Serial(
    preprocessing.Shuffle(),
    preprocessing.Batch(8),
)

train_batches_stream = train_data_pipeline(train_stream)

eval_data_pipeline = preprocessing.Batch(8)
eval_batches_stream = eval_data_pipeline(eval_stream)

In [5]:
example_batch = next(train_batches_stream)
print(f'batch shape (image, label) = {[x.shape for x in example_batch]}')

batch shape (image, label) = [(8, 28, 28, 1), (8,)]


In [6]:
from trax import layers as tl


def get_model(n_output_classes=10):
    model = tl.Serial(
        tl.ToFloat(),

        tl.Conv(32, (3, 3), (1, 1), 'SAME'),
        tl.LayerNorm(),
        tl.Relu(),
        tl.MaxPool(),

        tl.Conv(64, (3, 3), (1, 1), 'SAME'),
        tl.LayerNorm(),
        tl.Relu(),
        tl.MaxPool(),

        tl.Flatten(),
        tl.Dense(n_output_classes),
    )
    return model

In [7]:
from trax.learning.supervised import training
from trax import optimizers as optimizers

train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CategoryCrossEntropy(),
    optimizer=optimizers.Adam(0.01),
    n_steps_per_checkpoint=100,
)

eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CategoryCrossEntropy(), tl.CategoryAccuracy()],
    n_eval_batches=20,
)

In [8]:
model = get_model()

training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir='./cnn_model')

training_loop.run(1000)




Step      1: Total number of trainable weights: 451658
Step      1: Ran 1 train steps in 1.11 secs
Step      1: train CategoryCrossEntropy |  3.80618763


  with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf:
  with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:


Step      1: eval  CategoryCrossEntropy |  247.44056854
Step      1: eval      CategoryAccuracy |  0.11875000

Step    100: Ran 99 train steps in 4.19 secs
Step    100: train CategoryCrossEntropy |  40.71861649
Step    100: eval  CategoryCrossEntropy |  3.85425627
Step    100: eval      CategoryAccuracy |  0.59375000

Step    200: Ran 100 train steps in 3.44 secs
Step    200: train CategoryCrossEntropy |  1.70089340
Step    200: eval  CategoryCrossEntropy |  0.73954285
Step    200: eval      CategoryAccuracy |  0.75000000

Step    300: Ran 100 train steps in 3.45 secs
Step    300: train CategoryCrossEntropy |  0.92829901
Step    300: eval  CategoryCrossEntropy |  0.86131594
Step    300: eval      CategoryAccuracy |  0.71875000

Step    400: Ran 100 train steps in 3.41 secs
Step    400: train CategoryCrossEntropy |  0.61508113
Step    400: eval  CategoryCrossEntropy |  0.55881782
Step    400: eval      CategoryAccuracy |  0.82500000

Step    500: Ran 100 train steps in 3.42 secs
Step   

In [10]:
import shutil

shutil.rmtree(training_loop.output_dir)

FileNotFoundError: [Errno 2] No such file or directory: './cnn_model'