In [None]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [1]:
import warnings

# Silence every Python warning for the rest of the session so cell output
# stays short. NOTE(review): this also hides deprecation warnings.
warnings.filterwarnings('ignore')

In [2]:
import collections
import functools
import os
import sys
import time

import numpy as np
import psutil
from absl import logging

In [3]:
# Optional: to import trax from a local checkout, set the TRAX_PROJECT_ROOT
# environment variable to the directory that contains the ``trax`` package.
project_root = os.environ.get('TRAX_PROJECT_ROOT', '')
# Only touch sys.path when a root was actually given, and do not stack
# duplicate entries when this cell is re-run.
if project_root and sys.path[:1] != [project_root]:
    sys.path.insert(0, project_root)

# Option to verify the import path
print(f"Python will look for packages in: {sys.path[0]}")

# Import trax
import trax
from trax.data.encoder import encoder
from trax import fastmath
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.learning.supervised import training

# Verify the source of the imported package
print(f"Imported trax from: {trax.__file__}")

Python will look for packages in: /raid/mmironczuk/projects/trax-upgrade


2025-04-10 12:13:44.765221: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-10 12:13:44.786231: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-10 12:13:44.792704: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-10 12:13:44.809811: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Imported trax from: /raid/mmironczuk/projects/trax-upgrade/trax/__init__.py


In [4]:
class MyLoop(training.Loop):
    """``training.Loop`` whose ``run`` can be cut short by an early-stopping callback.

    A callback requests the stop by setting ``_stop_training`` on the loop;
    ``run`` checks that flag at the end of every step.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``training.Loop`` and clear the stop flag."""
        super().__init__(*args, **kwargs)
        # Set to True by an early-stopping callback to end ``run`` early.
        self._stop_training = False

    def run(self, n_steps=1):
        """Run up to ``n_steps`` training steps, stopping early on request.

        According to the original author's note this mirrors
        ``training.Loop.run`` with one addition: after the evaluation of a
        step, callbacks named ``EarlyStopping`` are consulted and the loop
        breaks as soon as ``_stop_training`` has been set.

        Args:
            n_steps: Maximum number of training steps to run in this call.
        """
        # Each call starts un-stopped (like Keras resetting ``stop_training``
        # at the start of ``fit``); otherwise a second call after an early
        # stop would break after a single step.
        self._stop_training = False
        # Pre-bind so the stop message can never hit an unbound local.
        best_step = None

        with self._open_summary_writers() as (
                train_summary_writers,
                eval_summary_writers,
        ):
            process = psutil.Process(os.getpid())
            loss_acc, step_acc = 0.0, 0
            start_time = time.time()
            optimizer_metrics_acc = collections.defaultdict(float)
            for i in range(n_steps):
                prev_task_index = self._which_task(self._step)
                self._step += 1
                task_index = self._which_task(self._step)
                task_changed = task_index != prev_task_index

                if task_changed:
                    loss_acc, step_acc = 0.0, 0

                loss, optimizer_metrics = self._run_one_step(task_index, task_changed)

                # Average the per-device values before accumulating them.
                optimizer_metrics, loss = fastmath.nested_map(
                    functools.partial(tl.mean, self._n_devices),
                    (optimizer_metrics, loss),
                )

                loss_acc += loss
                # Log loss every 50 steps, every step in memory-efficient trainers.
                if self._step % 50 == 0 or self._use_memory_efficient_trainer:
                    self._log_step("Loss: %.4f" % loss, stdout=False)
                step_acc += 1
                for metric_name, value in optimizer_metrics.items():
                    optimizer_metrics_acc[metric_name] += value

                if self._checkpoint_at(self.step):
                    self.save_checkpoint("model")
                if self._permanent_checkpoint_at(self.step):
                    self.save_checkpoint(f"model_{self.step}")
                if self._eval_at(self.step):
                    logging.info(
                        "cpu memory use (MB): %.2f",
                        process.memory_info().rss / float(1024 * 1024),
                    )
                    elapsed_time = time.time() - start_time
                    self._log_training_progress(
                        task=self._tasks[task_index],
                        total_loss=loss_acc,
                        n_steps=step_acc,
                        elapsed_time=elapsed_time,
                        optimizer_metrics=optimizer_metrics_acc,
                        summary_writer=train_summary_writers[task_index],
                    )
                    self.run_evals(eval_summary_writers)
                    # Start a fresh accumulation window after each evaluation.
                    loss_acc, step_acc = 0.0, 0
                    start_time = time.time()
                    optimizer_metrics_acc = collections.defaultdict(float)

                if self._checkpoint_at(self.step):
                    if self._checkpoint_low_metric is not None and self._at_lowest():
                        self.save_checkpoint(f"lowest_{self._checkpoint_low_metric}")
                    if self._checkpoint_high_metric is not None and self._at_highest():
                        self.save_checkpoint(f"highest_{self._checkpoint_high_metric}")

                # Use the loop index, not ``self.step``, to detect the last
                # iteration: ``self.step`` does not start at 0 when the loop
                # was restored from a checkpoint or ``run`` is called again.
                is_last_iteration = i == n_steps - 1

                for callback in self._callbacks:
                    if not callback.call_at(self.step):
                        continue
                    if callback.__class__.__name__ != 'EarlyStopping':
                        continue
                    # The early-stopping check is done here, after the eval
                    # history was updated; per the original note,
                    # callback.on_step_end executes before that update.
                    reported_step = callback.on_step_begin_with_history(self.step)
                    if reported_step is not None:
                        # Keep the last step actually reported; a callback that
                        # did not trigger returns None and must not erase it.
                        best_step = reported_step

                    if not self._stop_training and is_last_iteration:
                        self._log_step("Did not meet early stopping condition.")

                if self._stop_training:
                    # added to stop the training.
                    self._log_step(f"Early stopping... "
                                   f" the best step at {best_step}")
                    break

        self._eval_model.weights = self._model.weights

In [5]:
def callback_earlystopper(
        monitor=None,
        min_delta=0,
        patience=0,
        mode="auto",
        restore_best_checkpoint=True
):
    """Build an ``EarlyStopping`` callback class bound to the given options.

    The returned class takes the training loop as its only constructor
    argument (``EarlyStopping(loop)``).

    Args:
        monitor: Name of the evaluation metric to monitor, e.g. ``"L2Loss"``.
        min_delta: Minimum change in the monitored quantity to qualify as an
            improvement, i.e. an absolute change of less than ``min_delta``
            counts as no improvement. Only its magnitude is used.
        patience: Number of consecutive checks without improvement after which
            training is stopped. Checks happen whenever the loop evaluates, so
            ``patience`` times the evaluation interval is the number of steps
            tolerated without improvement. Values below 1 are treated as 1.
        mode: One of ``{"auto", "min", "max"}``. In ``min`` (``max``) mode
            training stops when the monitored quantity has stopped decreasing
            (increasing). In ``"auto"`` mode the direction is inferred from
            the metric name: names ending in ``"Accuracy"`` are maximised,
            everything else is minimised. Unknown values fall back to
            ``"auto"`` and a message is logged when the callback is created.
        restore_best_checkpoint: If True and training is stopped early, the
            checkpoint with the best monitored value is loaded back into the
            loop. If False, the weights from the last step are kept.

    Returns:
        The configured ``EarlyStopping`` class (not an instance).
    """

    requested_mode = mode
    if mode not in ("auto", "max", "min"):
        # No loop exists yet, so there is nothing to log through here; the
        # message is emitted from ``EarlyStopping.__init__`` instead.
        mode = "auto"

    class EarlyStopping:
        """Callback that stops training once the monitored metric stalls."""

        def __init__(self, loop):
            """Configures an early stopping.
            This is inspired by keras.callbacks.EarlyStopping.

            Args:
                loop:   training ``Loop`` from the current training.

            """

            self._loop = loop
            self.monitor = monitor
            self.patience = max(patience, 1)
            self.restore_best_checkpoint = restore_best_checkpoint

            if requested_mode != mode:
                self._loop._log_step(
                    f"Early stopping mode='{requested_mode}' is unknown, "
                    "fallback to 'auto' mode"
                )

            # Decide the direction once as a plain bool. Comparing the chosen
            # op against ``np.greater`` is unreliable because the op comes
            # from ``jnp``, which need not be the same function object.
            if mode == "max":
                maximize = True
            elif mode == "min":
                maximize = False
            else:
                maximize = isinstance(monitor, str) and monitor.endswith("Accuracy")

            self.monitor_op = jnp.greater if maximize else jnp.less
            # Signed so that ``value - min_delta`` always shifts the candidate
            # towards "worse" before it is compared with the best value.
            self.min_delta = abs(min_delta) if maximize else -abs(min_delta)

            self.wait = 0
            self.stopped_step = 1
            self.best = -jnp.inf if maximize else jnp.inf
            self.best_step = 1
            self.best_checkpoint_path = None

        def _metric_names(self):
            """Return the metric names of all evaluation tasks of the loop."""
            return [
                name
                for eval_task in self._loop._eval_tasks
                for name in eval_task.metric_names
            ]

        def _is_metric_exist(self):
            """Return True if the monitored metric is produced by an eval task."""
            return self.monitor in self._metric_names()

        def call_at(self, step):
            """Return whether the loop evaluates at ``step``."""
            return self._loop._eval_at(step)

        def on_step_begin(self, step):
            """Abort if the monitored metric is not produced by any eval task.

            Raises:
                SystemExit: if ``monitor`` is not among the eval metric names.
            """
            metric_names = self._metric_names()
            if self.monitor in metric_names:
                return

            self._loop._log_step(
                f"Early Stopping metric '{self.monitor}' " "is not in eval_tasks."
            )
            self._loop._log_step(
                "Select one of " f"them from here {metric_names}."
            )

            raise SystemExit("Monitoring metric not found.")

        def on_step_end(self, step):
            """No-op; the check is done in ``on_step_begin_with_history``."""
            pass

        def on_step_begin_with_history(self, step):
            """Update the best value from the eval history and decide on stopping.

            Returns:
                The step of the best monitored value when training should
                stop, otherwise None.
            """
            if self.restore_best_checkpoint and self.best_checkpoint_path is None:
                # Save a first checkpoint so there is always one to restore.
                self._loop.save_checkpoint("best_checkpoint")
                self.best_checkpoint_path = os.path.join(
                    self._loop._output_dir, "best_checkpoint.pkl.gz"
                )

            self.wait += 1
            current_step, current = self._get_monitor_value()

            if current is None:
                return None

            if self._is_improvement(current, self.best):
                self.best = current
                self.best_step = current_step
                self._loop.save_checkpoint("best_checkpoint")

                # reset wait
                self.wait = 0

            if self.wait >= self.patience and step > 1:
                self.stopped_step = current_step
                self._loop._stop_training = True

                if (
                        self.restore_best_checkpoint
                        and self.best_checkpoint_path is not None
                ):
                    self._loop.load_checkpoint(self.best_checkpoint_path)
                    self._loop._log_step(
                        f"Best checkpoint was restored from Step {self.best_step}."
                    )

                return self.best_step

            return None

        def _is_improvement(self, monitor_value, reference_value):
            """Return whether ``monitor_value`` beats ``reference_value`` by ``min_delta``."""
            return self.monitor_op(monitor_value - self.min_delta, reference_value)

        def _get_monitor_value(self):
            """Return the latest ``(step, value)`` of the monitored eval metric.

            Returns ``(None, None)`` when nothing has been recorded yet.
            """
            history = self._loop.history.get(
                "eval", "metrics/" + self.monitor
            )
            if len(history) == 0:
                return None, None
            step, monitor_value = history[-1]
            return step, monitor_value

    return EarlyStopping

## Linear Regression
## Generate data for linear model

In [6]:
def get_data_linear():
    """Yield an endless stream of ``(x, y)`` samples on the line ``y = 2x - 1``.

    ``x`` is a random whole number from 1 to 9 stored as a float; both
    elements of each yielded pair are one-element NumPy arrays.
    """
    slope, intercept = 2.0, 1
    while True:
        feature = float(np.random.randint(low=1, high=10))
        target = feature * slope - intercept
        yield np.array([feature]), np.array([target])

In [7]:
# Create the endless sample generator and draw one (x, y) pair as a sanity check.
data_linear = get_data_linear()
print(next(data_linear))

(array([7.]), array([13.]))


In [9]:
from trax.data.preprocessing import inputs as preprocessing

# Chain the preprocessing steps: batches of 50 samples, then loss weights.
data_pipeline = preprocessing.Serial(
    preprocessing.Batch(50),
    preprocessing.AddLossWeights(),
)
data_stream = data_pipeline(data_linear)

## Build a simple linear model

In [10]:
# A single Dense unit: the smallest model that can fit a line.
model_linear = tl.Serial(tl.Dense(1))

## Train a linear model

In [11]:
from trax import optimizers

# Training and evaluation deliberately draw from the same data_stream.
train_task = training.TrainTask(
    labeled_data=data_stream,
    loss_layer=tl.L2Loss(),
    optimizer=optimizers.SGD(0.01),
    n_steps_per_checkpoint=10,
)

eval_task = training.EvalTask(
    labeled_data=data_stream,
    metrics=[tl.L2Loss()],
    n_eval_batches=15,
)

## Add early stopping function

In [12]:
# Monitor the eval 'L2Loss'. patience keeps its default of 0, which the
# callback raises to 1, so a single check without improvement stops training.
earlystopping = callback_earlystopper(monitor='L2Loss', min_delta=1e-4)

In [13]:
# Delete the training folder
!rm -r linear_model

rm: cannot remove 'linear_model': No such file or directory


In [14]:
# Build the model and wire it, the tasks and the early-stopping callback
# into the custom loop.
model_linear = tl.Serial(tl.Dense(1))
training_loop = MyLoop(
    model=model_linear,
    tasks=train_task,
    eval_tasks=[eval_task],
    output_dir="./linear_model",
    callbacks=[earlystopping],
)

In [15]:
# Request up to 1500 steps; MyLoop.run breaks out earlier once the
# early-stopping callback sets the loop's stop flag.
training_loop.run(1500)


Step      1: Total number of trainable weights: 2
Step      1: Ran 1 train steps in 0.26 secs
Step      1: train L2Loss |  7.69807005
Step      1: eval  L2Loss |  0.63437390

Step     10: Ran 9 train steps in 0.26 secs
Step     10: train L2Loss |  0.29812446
Step     10: eval  L2Loss |  0.23468615

Step     20: Ran 10 train steps in 0.12 secs
Step     20: train L2Loss |  0.21997391
Step     20: eval  L2Loss |  0.20311840

Step     30: Ran 10 train steps in 0.08 secs
Step     30: train L2Loss |  0.19210428
Step     30: eval  L2Loss |  0.18424661

Step     40: Ran 10 train steps in 0.08 secs
Step     40: train L2Loss |  0.18213718
Step     40: eval  L2Loss |  0.17369503

Step     50: Ran 10 train steps in 0.08 secs
Step     50: train L2Loss |  0.15699960
Step     50: eval  L2Loss |  0.15382139

Step     60: Ran 10 train steps in 0.09 secs
Step     60: train L2Loss |  0.15296355
Step     60: eval  L2Loss |  0.12982111

Step     70: Ran 10 train steps in 0.09 secs
Step     70: train L2Los

## Change patience
`patience = 10` means training waits for 10 x 10 = 100 steps (`patience * n_steps_per_checkpoint`) without improvement before deciding to stop.

In [16]:
# Same monitor as before, but allow 10 consecutive checks without improvement.
earlystopping = callback_earlystopper(monitor='L2Loss', patience=10, min_delta=1e-4)

In [17]:
# Delete the training folder
!rm -r linear_model



In [18]:
# Rebuild the model and the loop so this run uses the patience=10 callback.
model_linear = tl.Serial(tl.Dense(1))
training_loop = MyLoop(
    model=model_linear,
    tasks=train_task,
    eval_tasks=[eval_task],
    output_dir="./linear_model",
    callbacks=[earlystopping],
)

In [19]:
# Request up to 1500 steps again; with patience=10 the callback tolerates
# more checks without improvement before it stops the loop.
training_loop.run(1500)


Step      1: Total number of trainable weights: 2
Step      1: Ran 1 train steps in 0.22 secs
Step      1: train L2Loss |  18.23833466
Step      1: eval  L2Loss |  2.60297537

Step     10: Ran 9 train steps in 0.08 secs
Step     10: train L2Loss |  0.56292963
Step     10: eval  L2Loss |  0.23197113

Step     20: Ran 10 train steps in 0.08 secs
Step     20: train L2Loss |  0.23490067
Step     20: eval  L2Loss |  0.22663632

Step     30: Ran 10 train steps in 0.08 secs
Step     30: train L2Loss |  0.21568297
Step     30: eval  L2Loss |  0.20940009

Step     40: Ran 10 train steps in 0.08 secs
Step     40: train L2Loss |  0.20161334
Step     40: eval  L2Loss |  0.18221980

Step     50: Ran 10 train steps in 0.08 secs
Step     50: train L2Loss |  0.17556223
Step     50: eval  L2Loss |  0.17093813

Step     60: Ran 10 train steps in 0.09 secs
Step     60: train L2Loss |  0.16914175
Step     60: eval  L2Loss |  0.15923102

Step     70: Ran 10 train steps in 0.09 secs
Step     70: train L2Lo

## Make a prediction

In [20]:
# The training data follows y = 2x - 1, so the ideal outputs are 3, 5, 19 and 87.
test_data = np.array([[2.0], [3.0], [10.0], [44.0]])
model_linear(test_data)

Array([[ 3.0123417],
       [ 5.009349 ],
       [18.988398 ],
       [86.88663  ]], dtype=float32)