In [2]:
import sys
sys.path.append("/home/jovyan/TF_NEW/tf-transformers/src/")



In [3]:
import os
import tempfile
import json
import glob
import datasets
import shutil
import tensorflow as tf

from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

from tf_transformers.data import TFReader, TFWriter
from tf_transformers.models import Classification_Model
from tf_transformers.losses import cross_entropy_loss_for_classification
from model import get_model, get_tokenizer, get_optimizer, get_trainer

In [4]:
with initialize(config_path="conf/"):
    cfg = compose(config_name="config", overrides=["data.take_sample=true", "+glue=mrpc"])
    print(cfg)

{'data': {'train_batch_size': 32, 'eval_batch_size': 64, 'take_sample': True, 'max_seq_length': 128}, 'trainer': {'type': 'gpu', 'dtype': 'fp32', 'num_gpus': 2, 'tpu_address': None, 'epochs': 3, 'strategy': 'mirrored'}, 'optimizer': {'learning_rate': 2e-05, 'loss_type': None}, 'model': {'is_training': True, 'use_dropout': True}, 'glue': {'task': {'name': 'mrpc'}, 'data': {'name': 'mrpc', 'num_classes': 2}}}


In [5]:
# Steps

# 1. Download the data
# 2. Prepare TFRecords
# 3. Read TFrecords to tf.data
# 4. Train the model

In [6]:
def get_classification_model(num_classes, return_all_layer_outputs, is_training, use_dropout):
        
    def model_fn():
        model = get_model(return_all_layer_outputs, is_training, use_dropout)
        classification_model = Classification_Model(model,
                                                    num_classes, 
                                                    use_all_layers=return_all_layer_outputs, 
                                                    is_training=is_training, 
                                                    use_dropout=use_dropout)
        classification_model = classification_model.get_model()
        return classification_model
    return model_fn

In [7]:
# Convert data to features using specific length
# into a temp dir (and log it as well for monitoring)

def get_dataset(data, batch_size, tokenizer, max_seq_length, mode, tfrecord_dir, take_sample=False):
    
    if mode not in ["train", "eval"]:
        raise ValueError("Inavlid mode `{}` specified. Available mode is ['train', 'eval']".format(mode))
    
    def get_tfrecord_example(data):
        result = {}
        for f in data:
            input_ids_s1 = [tokenizer.cls_token] + tokenizer.tokenize(f['sentence1'])[: max_seq_length-2] + [tokenizer.sep_token] # -2 to add CLS and SEP
            input_ids_s1 = tokenizer.convert_tokens_to_ids(input_ids_s1)
            input_type_ids_s1 = [0] * len(input_ids_s1) # 0 for s1

            input_ids_s2 = tokenizer.tokenize(f['sentence2'])[: max_seq_length-1] + [tokenizer.sep_token] # -1 to add SEP
            input_ids_s2 = tokenizer.convert_tokens_to_ids(input_ids_s2)
            input_type_ids_s2 = [1] * len(input_ids_s2)
            
            # concatanate two sentences
            input_ids =  input_ids_s1 + input_ids_s2
            input_type_ids = input_type_ids_s1 + input_type_ids_s2
            input_mask = [1] * len(input_ids) # 1 for s2
            
            result = {}
            result['input_ids'] = input_ids
            result['input_mask'] = input_mask
            result['input_type_ids'] = input_type_ids

            result['labels'] = f['label']
            yield result
            
    schema = {
        "input_ids": ("var_len", "int"),
        "input_mask": ("var_len", "int"),
        "input_type_ids": ("var_len", "int"),
        "labels": ("var_len", "int"),
    }
    
    # Create a temp dir
    if mode == "train":
        # Write tf records
        train_data_dir = os.path.join(tfrecord_dir,"train")        
        tfrecord_filename = 'mrpc'
        tfwriter = TFWriter(schema=schema, 
                            file_name=tfrecord_filename, 
                            model_dir=train_data_dir,
                            tag='train',
                            overwrite=False
                     )
        data_train = data['train']
        # Take sample
        if take_sample:
            data_train = data_train.select(range(500))
            
        tfwriter.process(parse_fn=get_tfrecord_example(data_train))
        
        # Read tfrecord to dataset
        schema = json.load(open("{}/schema.json".format(train_data_dir)))
        stats  = json.load(open('{}/stats.json'.format(train_data_dir)))
        all_files = glob.glob("{}/*.tfrecord".format(train_data_dir))
        tf_reader = TFReader(schema=schema, 
                            tfrecord_files=all_files)

        x_keys = ['input_ids', 'input_type_ids', 'input_mask']
        y_keys = ['labels']
        train_dataset = tf_reader.read_record(auto_batch=True, 
                                           keys=x_keys,
                                           batch_size=batch_size, 
                                           x_keys = x_keys, 
                                           y_keys = y_keys,
                                           shuffle=True, 
                                           drop_remainder=True
                                          )
        return train_dataset, stats['total_records']
    if mode == "eval":
        # Write tfrecords
        eval_data_dir = os.path.join(tfrecord_dir,"eval")
        tfrecord_filename = 'mrpc'
        tfwriter = TFWriter(schema=schema, 
                            file_name=tfrecord_filename, 
                            model_dir=eval_data_dir,
                            tag='dev',
                            overwrite=False
                            )
        data_eval = data['validation']
        # Take sample
        if take_sample:
            data_eval = data_eval.select(range(500))
        tfwriter.process(parse_fn=get_tfrecord_example(data_eval))
        
        
        # Read tfrecord to dataset
        schema = json.load(open("{}/schema.json".format(eval_data_dir)))
        stats  = json.load(open('{}/stats.json'.format(eval_data_dir)))
        all_files = glob.glob("{}/*.tfrecord".format(eval_data_dir))
        tf_reader = TFReader(schema=schema, 
                            tfrecord_files=all_files)

        x_keys = ['input_ids', 'input_type_ids', 'input_mask']
        y_keys = ['labels']
        eval_dataset = tf_reader.read_record(auto_batch=True, 
                                           keys=x_keys,
                                           batch_size=batch_size, 
                                           x_keys = x_keys, 
                                           y_keys = y_keys,
                                           shuffle=False, 
                                           drop_remainder=False
                                          )
        return eval_dataset, stats['total_records']

In [8]:
def get_loss(loss_type):

    if loss_type and loss_type == 'joint':

        def loss_fn(y_true_dict, y_pred_dict):
            """Joint loss over all layers"""
            loss_dict = {}
            loss_holder = []
            for layer_count, per_layer_output in enumerate(y_pred_dict['class_logits']):

                loss = cross_entropy_loss_for_classification(
                    labels=tf.squeeze(y_true_dict['labels'], axis=1),
                    logits=per_layer_output
                )
                loss_dict['loss_{}'.format(layer_count + 1)] = loss
                loss_holder.append(loss)
            # Mean over batch
            loss_dict['loss'] = tf.reduce_mean(loss_holder, axis=0)
            return loss_dict

    else:

        def loss_fn(y_true_dict, y_pred_dict):
            """last layer loss"""
            loss_dict = {}
            loss = cross_entropy_loss_for_classification(
                labels=tf.squeeze(y_true_dict['labels'], axis=1),
                logits=y_pred_dict['class_logits']
            )
            loss_dict['loss'] = loss
            return loss_dict

    return loss_fn

In [9]:
cfg

{'data': {'train_batch_size': 32, 'eval_batch_size': 64, 'take_sample': True, 'max_seq_length': 128}, 'trainer': {'type': 'gpu', 'dtype': 'fp32', 'num_gpus': 2, 'tpu_address': None, 'epochs': 3, 'strategy': 'mirrored'}, 'optimizer': {'learning_rate': 2e-05, 'loss_type': None}, 'model': {'is_training': True, 'use_dropout': True}, 'glue': {'task': {'name': 'mrpc'}, 'data': {'name': 'mrpc', 'num_classes': 2}}}

In [10]:

# Data specific configuration
max_seq_len = cfg.data.max_seq_length
take_sample = cfg.data.take_sample
max_seq_length = cfg.data.max_seq_length
train_batch_size = cfg.data.train_batch_size
eval_batch_size  = cfg.data.eval_batch_size

# Trainer specifics
device = cfg.trainer.type
num_gpus = cfg.trainer.num_gpus
tpu_address = cfg.trainer.tpu_address
dtype = cfg.trainer.dtype
epochs = cfg.trainer.epochs
strategy = cfg.trainer.strategy

# Optimizer
learning_rate = cfg.optimizer.learning_rate
loss_type = cfg.optimizer.loss_type
return_all_layer_outputs = False
if loss_type and loss_type == 'joint':
    return_all_layer_outputs = True

# Core data specifics
data_name = cfg_task.glue.data.name
num_classes = cfg_task.glue.data.num_classes

# Model specific
is_training = cfg.model.is_training
use_dropout = cfg.model.use_dropout

In [11]:
# Load tokenizer
tokenizer = get_tokenizer()

# Load data
data = datasets.load_dataset("glue", data_name)
tfrecord_dir = tempfile.mkdtemp()

train_dataset, total_train_examples = get_dataset(data, train_batch_size,tokenizer, max_seq_length, "train", tfrecord_dir, take_sample)
eval_dataset, total_eval_examples  = get_dataset(data, eval_batch_size,tokenizer, max_seq_len, "eval", tfrecord_dir, take_sample)

Reusing dataset glue (/home/jovyan/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
INFO:absl:Total individual observations/examples written is 500 in 0.33753085136413574 seconds
INFO:absl:All writer objects closed
INFO:absl:Total individual observations/examples written is 500 in 0.3376038074493408 seconds
INFO:absl:All writer objects closed


In [12]:
# Load optimizer
optimizer_fn = get_optimizer(learning_rate, total_train_examples, train_batch_size, epochs)

# Load trainer
# trainer = get_trainer(device, dtype, strategy, num_gpus, tpu_address)


In [13]:
# Load model function
model_fn = get_classification_model(num_classes, 
                                 return_all_layer_outputs, is_training, use_dropout)
# Load loss function 
train_loss_fn = get_loss(loss_type)

In [14]:
for (batch_inputs, batch_labels) in train_dataset.take(1):
    print(batch_inputs['input_ids'].shape, batch_labels['labels'].shape)

(32, 95) (32, 1)


In [14]:
class Callback():
    
    def __init__(self):
        pass
    
    def call(trainer_kwargs):
        
        for k, v in trainer_kwargs.items():
            print(k, '-->', v)
callback = Callback()

In [15]:
# coding=utf-8
# Copyright 2021 TF-Transformers Authors.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import tqdm
from absl import logging

from tf_transformers.core import keras_utils
from tf_transformers.core.distribute_utils import get_distribution_strategy
from tf_transformers.core.performance_utils import (
    configure_optimizer,
    get_tf_dtype,
    is_float16,
    set_mixed_precision_policy,
)


def flat_metric_dict(metric_dict):
    """Flatten the dict"""
    dict_flatten = {}
    dict_flatten['steps'] = list(metric_dict.keys())
    for _key, value in metric_dict.items():
        for sub_key, sub_value in value.items():
            if sub_key not in dict_flatten:
                dict_flatten[sub_key] = [sub_value]
            else:
                dict_flatten[sub_key].append(sub_value)
    return dict_flatten


def save_model_checkpoints(model, overwrite_checkpoint_dir, model_checkpoint_dir, max_number_of_models):
    # Model checkpoint
    if not overwrite_checkpoint_dir:
        import os

        if os.path.exists(model_checkpoint_dir):
            raise FileExistsError("Model directory exists")

    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(checkpoint, directory=model_checkpoint_dir, max_to_keep=max_number_of_models)
    return manager


def get_loss_metric_dict(training_loss_names, validation_loss_names):

    training_loss_dict_metric = {name: tf.keras.metrics.Mean(name, dtype=tf.float32) for name in training_loss_names}
    training_loss_dict_metric["learning_rate"] = tf.keras.metrics.Mean(
        "learning_rate", dtype=tf.float32
    )  # We store learning rate here and reset after every global steps

    validation_loss_dict_metric = {}
    if validation_loss_names:
        validation_loss_dict_metric = {
            name: tf.keras.metrics.Mean(name, dtype=tf.float32) for name in validation_loss_names
        }

    return training_loss_dict_metric, validation_loss_dict_metric


def get_and_reset_metric_from_dict(metric_dict):
    if not metric_dict:
        return {}
    metric_result = {name: metric.result().numpy() for name, metric in metric_dict.items()}
    for _name, metric in metric_dict.items():
        metric.reset_states()
    return metric_result


def get_tensorboard_writers(model_checkpoint_dir):
    train_log_dir = model_checkpoint_dir + "/logs/train"
    test_log_dir = model_checkpoint_dir + "/logs/dev"
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)
    return train_summary_writer, test_summary_writer


def train_and_eval(
    model,
    optimizer,
    strategy,
    epochs,
    steps_per_epoch,
    steps_per_call,
    train_dataset_iter,
    train_loss_fn,
    GLOBAL_BATCH_SIZE,
    training_loss_dict_metric,
    validation_dataset_distributed,
    validation_loss_fn,
    validation_loss_dict_metric,
    validation_interval_steps,
    mixed_precision,
    callbacks,
    callbacks_interval_steps,
    trainer_kwargs,
    checkpoint_manager,
    model_checkpoint_dir,
    model_save_interval_steps,
):
    def save_model(epoch_end=False):
        if not epoch_end:
            if model_save_interval_steps:
                if global_step % model_save_interval_steps == 0:
                    checkpoint_manager.save()
                    logging.info("Model saved at step {}".format(global_step))
        else:
            checkpoint_manager.save()
            logging.info("Model saved at epoch {}".format(epoch))

    # @tf.function(experimental_relax_shapes=True)
    def write_metrics(metric_dict, writer, step):
        # @tf.function
        def _write(step):
            # other model code would go here
            with writer.as_default():
                for name, result in metric_dict.items():
                    tf.summary.scalar(name, result, step=step)

        _write(step)
        writer.flush()

    def compute_loss(batch_labels, model_outputs):
        """Loss computation which takes care of loss reduction based on GLOBAL_BATCH_SIZE"""
        per_example_loss = train_loss_fn(batch_labels, model_outputs)
        per_example_loss_averaged = {}
        # Inplace update
        # Avergae loss per global batch size , recommended
        for name, loss in per_example_loss.items():
            per_example_loss_averaged[name] = tf.nn.compute_average_loss(loss, global_batch_size=GLOBAL_BATCH_SIZE)
        return per_example_loss_averaged

    def compute_loss_valid(batch_labels, model_outputs):
        """Validation Loss computation which takes care of loss reduction based on GLOBAL_BATCH_SIZE"""
        per_example_loss = validation_loss_fn(batch_labels, model_outputs)
        per_example_loss_averaged = {}
        # Inplace update
        # Avergae loss per global batch size , recommended
        for name, loss in per_example_loss.items():
            per_example_loss_averaged[name] = tf.nn.compute_average_loss(loss, global_batch_size=GLOBAL_BATCH_SIZE)
        return per_example_loss_averaged

    # Train Functions
    @tf.function
    def do_train(iterator):
        """The step function for one training step"""

        def train_step(dist_inputs):
            """The computation to run on each device."""
            batch_inputs, batch_labels = dist_inputs
            with tf.GradientTape() as tape:
                model_outputs = model(batch_inputs)
                loss = compute_loss(batch_labels, model_outputs)
                tf.debugging.check_numerics(loss['loss'], message='Loss value is either NaN or inf')
                if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
                    loss_scaled = {name: optimizer.get_scaled_loss(loss_value) for name, loss_value in loss.items()}
                # TODO
                # Scales down the loss for gradients to be invariant from replicas.
                # loss = loss / strategy.num_replicas_in_sync
            if mixed_precision:
                scaled_gradients = tape.gradient(loss_scaled["loss"], model.trainable_variables)
                grads = optimizer.get_unscaled_gradients(scaled_gradients)
            else:
                grads = tape.gradient(loss["loss"], model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            # training_loss.update_state(loss * strategy.num_replicas_in_sync)
            return loss

        for _ in tf.range(tf.convert_to_tensor(steps_per_call)):
            dist_inputs = next(iterator)
            loss = strategy.run(train_step, args=(dist_inputs,))
            # strategy reduce
            loss = {
                name: strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_value, axis=None)
                for name, loss_value in loss.items()
            }
            for name, loss_value in loss.items():
                training_loss = training_loss_dict_metric[name]
                training_loss.update_state(loss_value)
            # Get current learning rate
            if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
                current_lr = optimizer._optimizer._decayed_lr(tf.float32)
            else:
                current_lr = optimizer._decayed_lr(tf.float32)
            training_loss_dict_metric["learning_rate"].update_state(current_lr)
            # training_result = get_and_reset_metric_from_dict(training_loss_dict_metric)

    # do validation
    def do_validation(validation_dataset_distributed):
        """Validation step"""

        @tf.function
        def _validate_step(dist_inputs):

            batch_inputs, batch_labels = dist_inputs
            model_outputs = model(batch_inputs)
            loss = compute_loss_valid(batch_labels, model_outputs)
            return loss

        if not epoch_end:
            if (
                validation_dataset_distributed
                and validation_loss_fn
                and validation_interval_steps
                and (global_step % validation_interval_steps == 0)
            ):
                logging.info("Validation in progress at step {} . . . .".format(global_step))
                with tqdm.tqdm(validation_dataset_distributed, unit=" Val batch ") as val_batches:
                    for dist_inputs in val_batches:
                        loss = strategy.run(_validate_step, args=(dist_inputs,))
                        for name, loss_value in loss.items():
                            loss_value = strategy.reduce(tf.distribute.ReduceOp.SUM, loss_value, axis=None)
                            validation_loss = validation_loss_dict_metric[name]
                            validation_loss.update_state(loss_value)

                validation_result = get_and_reset_metric_from_dict(validation_loss_dict_metric)
                validation_history[global_step] = validation_result
                write_metrics(validation_result, val_summary_writer, global_step)
                logging.info("Validation result at step {}".format(validation_result))
                print("\n")
        else:
            if validation_dataset_distributed and validation_loss_fn:
                logging.info("Validation in progress at epoch end {} . . . .".format(epoch))
                with tqdm.tqdm(validation_dataset_distributed, unit=" Val batch ") as val_batches:
                    for dist_inputs in val_batches:
                        loss = strategy.run(_validate_step, args=(dist_inputs,))
                        for name, loss_value in loss.items():
                            loss_value = strategy.reduce(tf.distribute.ReduceOp.SUM, loss_value, axis=None)
                            validation_loss = validation_loss_dict_metric[name]
                            validation_loss.update_state(loss_value)

                validation_result = get_and_reset_metric_from_dict(validation_loss_dict_metric)
                write_metrics(validation_result, val_summary_writer, global_step)
                # validation_history[global_step] = validation_result
                logging.info("Validation result at epoch {} is {}".format(epoch, validation_result))
                print("\n")

    def do_callbacks(callbacks):
        """Call callbacks"""
        if not epoch_end:
            callback_scores = None
            if callbacks and callbacks_interval_steps:
                logging.info("Callbacks in progress at step {} . . . .".format(global_step))
                callback_scores = []
                for callback, callback_steps in zip(callbacks, callbacks_interval_steps):
                    if callback_steps and (global_step % callback_steps == 0):
                        score = callback(trainer_kwargs)
                        callback_scores.append(score)
                    else:
                        callback_scores.append(None)
            return callback_scores
        else:
            callback_scores = None
            if callbacks:
                logging.info("Callbacks in progress at epoch end {} . . . .".format(epoch))
                callback_scores = []
                for callback in callbacks:
                    score = callback(trainer_kwargs)
                    callback_scores.append(score)

                    # Try to write a callback scores (only on epoch end)
                    # If we are returning a dict like {'exact_match': 81} or
                    # {'rougue-1': 30} etc . . . .
                    if score and isinstance(score, dict):
                        write_metrics(score, val_summary_writer, epoch)
            return callback_scores

    # Loop starts here
    # Get Tensorboard writers
    train_summary_writer, val_summary_writer = get_tensorboard_writers(model_checkpoint_dir)
    validation_history = {}
    training_history = {}
    global_step = 0
    epoch_end = False
    STEPS = steps_per_epoch // steps_per_call
    print("STEPS", STEPS)
    for epoch in range(1, epochs + 1):
        # start_epoch_time = time.time()
        with tqdm.trange(STEPS, unit="batch ") as tepoch:
            for step in tepoch:
                steps_covered = (step + 1) * steps_per_call
                global_step += steps_per_call
                print("Started epoch {} and step {}".format(epoch, global_step))
                tepoch.set_description(
                    "Epoch {}/{} --- Step {}/{} --- ".format(epoch, epochs, steps_covered, steps_per_epoch)
                )
                # Call Train
                do_train(train_dataset_iter)
                print("Train done")
                # Call Validation
                do_validation(validation_dataset_distributed)
                print("Val done")
                # Call Callbacks
                callback_scores = do_callbacks(callbacks)

                # Train Metrics
                training_result = get_and_reset_metric_from_dict(training_loss_dict_metric)
                training_history[global_step] = training_result
                write_metrics(training_result, train_summary_writer, global_step)
                # training_result["learning_rate"] = learning_rate_holder.result().numpy()
                # learning_rate_holder.reset_states()
                tepoch.set_postfix(**training_result)

                # Save model
                save_model()

        # Do after every epoch
        epoch_end = True
        save_model(epoch_end)
        #do_validation(validation_dataset_distributed)
        #callback_scores = do_callbacks(callbacks)
        epoch_end = False

    # Flatten the results
    training_history = flat_metric_dict(training_history)
    validation_history = flat_metric_dict(validation_history)
    return training_history, validation_history, callback_scores


class GPUTrainer:
    def __init__(
        self,
        distribution_strategy,
        num_gpus=0,
        all_reduce_alg=None,
        num_packs=1,
        tpu_address=None,
        dtype='fp32',
        loss_scale='dynamic',
    ):

        self.distribution_strategy = get_distribution_strategy(
            distribution_strategy=distribution_strategy,
            num_gpus=num_gpus,
            all_reduce_alg=all_reduce_alg,
            num_packs=num_packs,
            tpu_address=tpu_address,
        )

        self.num_replicas = self.distribution_strategy.num_replicas_in_sync
        self._dtype = get_tf_dtype(dtype)

        # Setting dtype policy
        set_mixed_precision_policy(self._dtype)
        self.use_float16 = is_float16(self._dtype)
        self.loss_scale = loss_scale

        # # TODO
        # if self.use_tpu:
        # params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
        # else:
        # logging.info("Running transformer with num_gpus = %d", num_gpus)

        # Add keras utils threads

    def run(
        self,
        model_fn,
        optimizer_fn,
        train_dataset,
        train_loss_fn,
        epochs,
        steps_per_epoch,
        model_checkpoint_dir,
        batch_size,
        training_loss_names=None,
        validation_loss_names=None,
        validation_dataset=None,
        validation_loss_fn=None,
        validation_interval_steps=None,
        steps_per_call=100,
        enable_xla=True,
        callbacks=None,
        callbacks_interval_steps=None,
        overwrite_checkpoint_dir=False,
        max_number_of_models=10,
        model_save_interval_steps=None,
        repeat_dataset=True,
        latest_checkpoint=None,
    ):

        if steps_per_epoch:
            logging.info("Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.")
        if callbacks:
            if callbacks_interval_steps is None:
                callbacks_interval_steps = [None for callback in callbacks]
            assert len(callbacks) == len(callbacks_interval_steps)

        # Enable XLA
        keras_utils.set_session_config(enable_xla=enable_xla)
        logging.info("Policy: ----> {}".format(keras_utils.get_policy_name()))
        logging.info("Strategy: ---> {}".format(self.distribution_strategy))
        logging.info("Num GPU Devices: ---> {}".format(self.distribution_strategy.num_replicas_in_sync))

        tf.keras.backend.clear_session()

        # Under Strategy Scope
        with self.distribution_strategy.scope():
            # Model
            model = model_fn()

            # Optimizer
            optimizer = optimizer_fn()

            optimizer = configure_optimizer(optimizer, use_float16=self.use_float16, loss_scale=self.loss_scale)

        # We use this to avoid inferring names from loss functions
        _training_loss_names = ['loss']
        _validation_loss_names = ['loss']
        if training_loss_names:
            _training_loss_names += training_loss_names
        if validation_loss_names:
            _validation_loss_names += validation_loss_names
        # Make unique names
        training_loss_names = list(set(_training_loss_names))
        validation_loss_names = list(set(_validation_loss_names))
        # Checkpoint manager
        checkpoint_manager = save_model_checkpoints(
            model, overwrite_checkpoint_dir, model_checkpoint_dir, max_number_of_models
        )

        # Try to load latest checkpoint
        model.load_checkpoint(checkpoint_dir=model_checkpoint_dir, checkpoint_path=latest_checkpoint, opt=optimizer)

        # Get metric dicts before distributing the dataset
        # ddistributed datasets has no attribute .take
        training_loss_dict_metric, validation_loss_dict_metric = get_loss_metric_dict(
            training_loss_names, validation_loss_names
        )
        # Distribute dataset
        if not repeat_dataset:
            train_dataset_distributed = self.distribution_strategy.experimental_distribute_dataset(
                train_dataset.repeat(epochs + 1)
            )
        else:
            train_dataset_distributed = self.distribution_strategy.experimental_distribute_dataset(
                train_dataset.repeat()
            )
        validation_dataset_distributed = None
        if validation_dataset:
            validation_dataset_distributed = self.distribution_strategy.experimental_distribute_dataset(
                validation_dataset
            )

        # Make train dataset iterator
        train_dataset_distributed = iter(train_dataset_distributed)

        history = {}
        training_history, validation_history, callback_scores = train_and_eval(
            model,
            optimizer,
            self.distribution_strategy,
            epochs,
            steps_per_epoch,
            steps_per_call,
            train_dataset_distributed,
            train_loss_fn,
            batch_size,
            training_loss_dict_metric,
            validation_dataset_distributed,
            validation_loss_fn,
            validation_loss_dict_metric,
            validation_interval_steps,
            self.use_float16,
            callbacks,
            callbacks_interval_steps,
            locals(),
            checkpoint_manager,
            model_checkpoint_dir,
            model_save_interval_steps,
        )
        history['training_history'] = training_history
        history['validation_hsitory'] = validation_history
        history['callbacks'] = callback_scores

        # Save json
        return history


In [16]:
distribution_strategy = "mirrored"
num_gpus = 2

In [17]:
trainer = GPUTrainer(
        distribution_strategy,
        num_gpus=num_gpus,
        all_reduce_alg=None,
        num_packs=1,
        tpu_address=None,
        dtype='fp32',
        loss_scale='dynamic',
    )

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [19]:
model_checkpoint_dir = "/tmp/model_ckpt/"
history = trainer.run(
    model_fn = model_fn,
    optimizer_fn = optimizer_fn,
    train_dataset = train_dataset,
    train_loss_fn = train_loss_fn,
    epochs = 2,
    steps_per_epoch = 100,
    model_checkpoint_dir=model_checkpoint_dir,
    batch_size=train_batch_size,
    training_loss_names=None,
    validation_loss_names=None,
    validation_dataset=eval_dataset,
    validation_loss_fn=train_loss_fn,
    validation_interval_steps=None,
    steps_per_call=1,
    enable_xla=False,
    callbacks=[callback],
    callbacks_interval_steps=None,
    overwrite_checkpoint_dir=True,
    max_number_of_models=10,
    model_save_interval_steps=None,
    repeat_dataset=True,
    latest_checkpoint=None,
)

INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.
INFO:absl:Policy: ----> float32
INFO:absl:Strategy: ---> <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7ff612223790>
INFO:absl:Num GPU Devices: ---> 2
You are using a model of type albert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/albert-base-v2/ckpt-1
INFO:absl:Using Adamw optimizer
INFO:absl:No checkpoint found in /tmp/model_ckpt/
Epoch 1/2 --- Step 1/100 --- :   0%|          | 0/100 [00:00<?, ?batch /s]

STEPS 100
Started epoch 1 and step 1








INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1






INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).










INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:absl:Callbacks in progress at step 1 . . . .
Epoch 1/2 --- Step 2/100 --- :   1%|          | 1/100 [00:22<37:38, 22.81s/batch , learning_rate=4.44e-7, loss=0.426]INFO:absl:Callbacks in progress at step 2 . . . .
Epoch 1/2 --- Step 3/100 --- :   2%|▏         | 2/100 [00:22<26:08, 16.01s/batch , learning_rate=8.89e-7, loss=0.359]

Train done
Val done
Started epoch 1 and step 2
Train done
Val done
Started epoch 1 and step 3


INFO:absl:Callbacks in progress at step 3 . . . .
Epoch 1/2 --- Step 4/100 --- :   3%|▎         | 3/100 [00:23<18:10, 11.24s/batch , learning_rate=1.33e-6, loss=0.413]INFO:absl:Callbacks in progress at step 4 . . . .
Epoch 1/2 --- Step 5/100 --- :   4%|▍         | 4/100 [00:23<12:38,  7.90s/batch , learning_rate=1.78e-6, loss=0.366]

Train done
Val done
Started epoch 1 and step 4
Train done
Val done
Started epoch 1 and step 5


INFO:absl:Callbacks in progress at step 5 . . . .
Epoch 1/2 --- Step 6/100 --- :   5%|▌         | 5/100 [00:23<08:49,  5.57s/batch , learning_rate=2.22e-6, loss=0.364]INFO:absl:Callbacks in progress at step 6 . . . .
Epoch 1/2 --- Step 7/100 --- :   6%|▌         | 6/100 [00:23<06:09,  3.93s/batch , learning_rate=2.67e-6, loss=0.42] 

Train done
Val done
Started epoch 1 and step 6
Train done
Val done
Started epoch 1 and step 7


INFO:absl:Callbacks in progress at step 7 . . . .
Epoch 1/2 --- Step 8/100 --- :   7%|▋         | 7/100 [00:23<04:19,  2.79s/batch , learning_rate=3.11e-6, loss=0.379]INFO:absl:Callbacks in progress at step 8 . . . .
Epoch 1/2 --- Step 9/100 --- :   8%|▊         | 8/100 [00:23<03:03,  1.99s/batch , learning_rate=3.56e-6, loss=0.385]

Train done
Val done
Started epoch 1 and step 8
Train done
Val done
Started epoch 1 and step 9


INFO:absl:Callbacks in progress at step 9 . . . .
Epoch 1/2 --- Step 10/100 --- :   9%|▉         | 9/100 [00:23<02:10,  1.43s/batch , learning_rate=4e-6, loss=0.345]  INFO:absl:Callbacks in progress at step 10 . . . .
Epoch 1/2 --- Step 11/100 --- :  10%|█         | 10/100 [00:23<01:33,  1.04s/batch , learning_rate=4.44e-6, loss=0.357]

Train done
Val done
Started epoch 1 and step 10
Train done
Val done
Started epoch 1 and step 11


INFO:absl:Callbacks in progress at step 11 . . . .
Epoch 1/2 --- Step 12/100 --- :  11%|█         | 11/100 [00:24<01:08,  1.31batch /s, learning_rate=4.89e-6, loss=0.341]INFO:absl:Callbacks in progress at step 12 . . . .
Epoch 1/2 --- Step 13/100 --- :  12%|█▏        | 12/100 [00:24<00:50,  1.75batch /s, learning_rate=5.33e-6, loss=0.388]

Train done
Val done
Started epoch 1 and step 12
Train done
Val done
Started epoch 1 and step 13


INFO:absl:Callbacks in progress at step 13 . . . .
Epoch 1/2 --- Step 14/100 --- :  13%|█▎        | 13/100 [00:24<00:37,  2.29batch /s, learning_rate=5.78e-6, loss=0.375]INFO:absl:Callbacks in progress at step 14 . . . .
Epoch 1/2 --- Step 15/100 --- :  14%|█▍        | 14/100 [00:24<00:29,  2.93batch /s, learning_rate=6.22e-6, loss=0.365]

Train done
Val done
Started epoch 1 and step 14
Train done
Val done
Started epoch 1 and step 15


INFO:absl:Callbacks in progress at step 15 . . . .
Epoch 1/2 --- Step 16/100 --- :  15%|█▌        | 15/100 [00:24<00:23,  3.63batch /s, learning_rate=6.67e-6, loss=0.318]INFO:absl:Callbacks in progress at step 16 . . . .
Epoch 1/2 --- Step 17/100 --- :  16%|█▌        | 16/100 [00:24<00:19,  4.37batch /s, learning_rate=7.11e-6, loss=0.285]

Train done
Val done
Started epoch 1 and step 16
Train done
Val done
Started epoch 1 and step 17


INFO:absl:Callbacks in progress at step 17 . . . .
Epoch 1/2 --- Step 18/100 --- :  17%|█▋        | 17/100 [00:24<00:16,  4.99batch /s, learning_rate=7.56e-6, loss=0.322]INFO:absl:Callbacks in progress at step 18 . . . .
Epoch 1/2 --- Step 19/100 --- :  18%|█▊        | 18/100 [00:24<00:14,  5.66batch /s, learning_rate=8e-6, loss=0.264]   

Train done
Val done
Started epoch 1 and step 18
Train done
Val done
Started epoch 1 and step 19


INFO:absl:Callbacks in progress at step 19 . . . .
Epoch 1/2 --- Step 20/100 --- :  19%|█▉        | 19/100 [00:25<00:12,  6.24batch /s, learning_rate=8.44e-6, loss=0.351]INFO:absl:Callbacks in progress at step 20 . . . .
Epoch 1/2 --- Step 21/100 --- :  20%|██        | 20/100 [00:25<00:11,  6.74batch /s, learning_rate=8.89e-6, loss=0.273]

Train done
Val done
Started epoch 1 and step 20
Train done
Val done
Started epoch 1 and step 21


INFO:absl:Callbacks in progress at step 21 . . . .
Epoch 1/2 --- Step 22/100 --- :  21%|██        | 21/100 [00:25<00:10,  7.28batch /s, learning_rate=9.33e-6, loss=0.323]INFO:absl:Callbacks in progress at step 22 . . . .
Epoch 1/2 --- Step 23/100 --- :  22%|██▏       | 22/100 [00:25<00:10,  7.63batch /s, learning_rate=9.78e-6, loss=0.347]

Train done
Val done
Started epoch 1 and step 22
Train done
Val done
Started epoch 1 and step 23


INFO:absl:Callbacks in progress at step 23 . . . .
Epoch 1/2 --- Step 24/100 --- :  23%|██▎       | 23/100 [00:25<00:09,  7.77batch /s, learning_rate=1.02e-5, loss=0.31] INFO:absl:Callbacks in progress at step 24 . . . .
Epoch 1/2 --- Step 25/100 --- :  24%|██▍       | 24/100 [00:25<00:09,  8.01batch /s, learning_rate=1.07e-5, loss=0.311]

Train done
Val done
Started epoch 1 and step 24
Train done
Val done
Started epoch 1 and step 25


INFO:absl:Callbacks in progress at step 25 . . . .
Epoch 1/2 --- Step 26/100 --- :  25%|██▌       | 25/100 [00:25<00:09,  8.10batch /s, learning_rate=1.11e-5, loss=0.302]INFO:absl:Callbacks in progress at step 26 . . . .
Epoch 1/2 --- Step 27/100 --- :  26%|██▌       | 26/100 [00:25<00:09,  8.11batch /s, learning_rate=1.16e-5, loss=0.263]

Train done
Val done
Started epoch 1 and step 26
Train done
Val done
Started epoch 1 and step 27


INFO:absl:Callbacks in progress at step 27 . . . .
Epoch 1/2 --- Step 28/100 --- :  27%|██▋       | 27/100 [00:25<00:08,  8.13batch /s, learning_rate=1.2e-5, loss=0.304] INFO:absl:Callbacks in progress at step 28 . . . .
Epoch 1/2 --- Step 29/100 --- :  28%|██▊       | 28/100 [00:26<00:08,  8.14batch /s, learning_rate=1.24e-5, loss=0.322]

Train done
Val done
Started epoch 1 and step 28
Train done
Val done
Started epoch 1 and step 29


INFO:absl:Callbacks in progress at step 29 . . . .
Epoch 1/2 --- Step 30/100 --- :  29%|██▉       | 29/100 [00:26<00:08,  8.27batch /s, learning_rate=1.29e-5, loss=0.329]INFO:absl:Callbacks in progress at step 30 . . . .
Epoch 1/2 --- Step 31/100 --- :  30%|███       | 30/100 [00:26<00:08,  8.16batch /s, learning_rate=1.33e-5, loss=0.316]

Train done
Val done
Started epoch 1 and step 30
Train done
Val done
Started epoch 1 and step 31


INFO:absl:Callbacks in progress at step 31 . . . .
Epoch 1/2 --- Step 32/100 --- :  31%|███       | 31/100 [00:26<00:08,  8.17batch /s, learning_rate=1.38e-5, loss=0.279]INFO:absl:Callbacks in progress at step 32 . . . .
Epoch 1/2 --- Step 33/100 --- :  32%|███▏      | 32/100 [00:26<00:08,  8.17batch /s, learning_rate=1.42e-5, loss=0.347]

Train done
Val done
Started epoch 1 and step 32
Train done
Val done
Started epoch 1 and step 33


INFO:absl:Callbacks in progress at step 33 . . . .
Epoch 1/2 --- Step 34/100 --- :  33%|███▎      | 33/100 [00:26<00:08,  8.37batch /s, learning_rate=1.47e-5, loss=0.328]INFO:absl:Callbacks in progress at step 34 . . . .
Epoch 1/2 --- Step 35/100 --- :  34%|███▍      | 34/100 [00:26<00:07,  8.29batch /s, learning_rate=1.51e-5, loss=0.245]

Train done
Val done
Started epoch 1 and step 34
Train done
Val done
Started epoch 1 and step 35


INFO:absl:Callbacks in progress at step 35 . . . .
Epoch 1/2 --- Step 36/100 --- :  35%|███▌      | 35/100 [00:26<00:07,  8.25batch /s, learning_rate=1.56e-5, loss=0.319]INFO:absl:Callbacks in progress at step 36 . . . .
Epoch 1/2 --- Step 37/100 --- :  36%|███▌      | 36/100 [00:27<00:07,  8.35batch /s, learning_rate=1.6e-5, loss=0.28]  

Train done
Val done
Started epoch 1 and step 36
Train done
Val done
Started epoch 1 and step 37


INFO:absl:Callbacks in progress at step 37 . . . .
Epoch 1/2 --- Step 38/100 --- :  37%|███▋      | 37/100 [00:27<00:07,  8.41batch /s, learning_rate=1.64e-5, loss=0.303]INFO:absl:Callbacks in progress at step 38 . . . .
Epoch 1/2 --- Step 39/100 --- :  38%|███▊      | 38/100 [00:27<00:07,  8.48batch /s, learning_rate=1.69e-5, loss=0.288]

Train done
Val done
Started epoch 1 and step 38
Train done
Val done
Started epoch 1 and step 39


INFO:absl:Callbacks in progress at step 39 . . . .
Epoch 1/2 --- Step 40/100 --- :  39%|███▉      | 39/100 [00:27<00:07,  8.42batch /s, learning_rate=1.73e-5, loss=0.277]INFO:absl:Callbacks in progress at step 40 . . . .
Epoch 1/2 --- Step 41/100 --- :  40%|████      | 40/100 [00:27<00:07,  8.38batch /s, learning_rate=1.78e-5, loss=0.269]

Train done
Val done
Started epoch 1 and step 40
Train done
Val done
Started epoch 1 and step 41


INFO:absl:Callbacks in progress at step 41 . . . .
Epoch 1/2 --- Step 42/100 --- :  41%|████      | 41/100 [00:27<00:07,  8.30batch /s, learning_rate=1.82e-5, loss=0.322]INFO:absl:Callbacks in progress at step 42 . . . .
Epoch 1/2 --- Step 43/100 --- :  42%|████▏     | 42/100 [00:27<00:07,  8.04batch /s, learning_rate=1.87e-5, loss=0.306]

Train done
Val done
Started epoch 1 and step 42
Train done
Val done
Started epoch 1 and step 43


INFO:absl:Callbacks in progress at step 43 . . . .
Epoch 1/2 --- Step 44/100 --- :  43%|████▎     | 43/100 [00:27<00:07,  8.06batch /s, learning_rate=1.91e-5, loss=0.321]INFO:absl:Callbacks in progress at step 44 . . . .
Epoch 1/2 --- Step 45/100 --- :  44%|████▍     | 44/100 [00:28<00:06,  8.11batch /s, learning_rate=1.96e-5, loss=0.254]

Train done
Val done
Started epoch 1 and step 44
Train done
Val done
Started epoch 1 and step 45


INFO:absl:Callbacks in progress at step 45 . . . .
Epoch 1/2 --- Step 46/100 --- :  45%|████▌     | 45/100 [00:28<00:06,  8.11batch /s, learning_rate=0, loss=0.24]       INFO:absl:Callbacks in progress at step 46 . . . .
Epoch 1/2 --- Step 47/100 --- :  46%|████▌     | 46/100 [00:28<00:06,  8.15batch /s, learning_rate=0, loss=0.282]

Train done
Val done
Started epoch 1 and step 46
Train done
Val done
Started epoch 1 and step 47


INFO:absl:Callbacks in progress at step 47 . . . .
Epoch 1/2 --- Step 48/100 --- :  47%|████▋     | 47/100 [00:28<00:06,  7.91batch /s, learning_rate=0, loss=0.256]INFO:absl:Callbacks in progress at step 48 . . . .
Epoch 1/2 --- Step 49/100 --- :  48%|████▊     | 48/100 [00:28<00:06,  8.19batch /s, learning_rate=0, loss=0.314]

Train done
Val done
Started epoch 1 and step 48
Train done
Val done
Started epoch 1 and step 49


INFO:absl:Callbacks in progress at step 49 . . . .
Epoch 1/2 --- Step 50/100 --- :  49%|████▉     | 49/100 [00:28<00:06,  8.26batch /s, learning_rate=0, loss=0.3]  INFO:absl:Callbacks in progress at step 50 . . . .
Epoch 1/2 --- Step 51/100 --- :  50%|█████     | 50/100 [00:28<00:06,  8.19batch /s, learning_rate=0, loss=0.275]

Train done
Val done
Started epoch 1 and step 50
Train done
Val done
Started epoch 1 and step 51


INFO:absl:Callbacks in progress at step 51 . . . .
Epoch 1/2 --- Step 52/100 --- :  51%|█████     | 51/100 [00:28<00:05,  8.28batch /s, learning_rate=0, loss=0.297]INFO:absl:Callbacks in progress at step 52 . . . .
Epoch 1/2 --- Step 53/100 --- :  52%|█████▏    | 52/100 [00:29<00:05,  8.23batch /s, learning_rate=0, loss=0.321]

Train done
Val done
Started epoch 1 and step 52
Train done
Val done
Started epoch 1 and step 53


INFO:absl:Callbacks in progress at step 53 . . . .
Epoch 1/2 --- Step 54/100 --- :  53%|█████▎    | 53/100 [00:29<00:05,  8.21batch /s, learning_rate=0, loss=0.264]INFO:absl:Callbacks in progress at step 54 . . . .
Epoch 1/2 --- Step 55/100 --- :  54%|█████▍    | 54/100 [00:29<00:05,  8.29batch /s, learning_rate=0, loss=0.313]

Train done
Val done
Started epoch 1 and step 54
Train done
Val done
Started epoch 1 and step 55


INFO:absl:Callbacks in progress at step 55 . . . .
Epoch 1/2 --- Step 56/100 --- :  55%|█████▌    | 55/100 [00:29<00:05,  8.29batch /s, learning_rate=0, loss=0.254]INFO:absl:Callbacks in progress at step 56 . . . .
Epoch 1/2 --- Step 57/100 --- :  56%|█████▌    | 56/100 [00:29<00:05,  8.29batch /s, learning_rate=0, loss=0.283]

Train done
Val done
Started epoch 1 and step 56
Train done
Val done
Started epoch 1 and step 57


INFO:absl:Callbacks in progress at step 57 . . . .
Epoch 1/2 --- Step 58/100 --- :  57%|█████▋    | 57/100 [00:29<00:05,  8.26batch /s, learning_rate=0, loss=0.322]INFO:absl:Callbacks in progress at step 58 . . . .
Epoch 1/2 --- Step 59/100 --- :  58%|█████▊    | 58/100 [00:29<00:05,  8.28batch /s, learning_rate=0, loss=0.257]

Train done
Val done
Started epoch 1 and step 58
Train done
Val done
Started epoch 1 and step 59


INFO:absl:Callbacks in progress at step 59 . . . .
Epoch 1/2 --- Step 60/100 --- :  59%|█████▉    | 59/100 [00:29<00:04,  8.24batch /s, learning_rate=0, loss=0.289]INFO:absl:Callbacks in progress at step 60 . . . .
Epoch 1/2 --- Step 61/100 --- :  60%|██████    | 60/100 [00:29<00:04,  8.22batch /s, learning_rate=0, loss=0.312]

Train done
Val done
Started epoch 1 and step 60
Train done
Val done
Started epoch 1 and step 61


INFO:absl:Callbacks in progress at step 61 . . . .
Epoch 1/2 --- Step 62/100 --- :  61%|██████    | 61/100 [00:30<00:04,  8.21batch /s, learning_rate=0, loss=0.293]INFO:absl:Callbacks in progress at step 62 . . . .
Epoch 1/2 --- Step 63/100 --- :  62%|██████▏   | 62/100 [00:30<00:04,  8.19batch /s, learning_rate=0, loss=0.254]

Train done
Val done
Started epoch 1 and step 62
Train done
Val done
Started epoch 1 and step 63


INFO:absl:Callbacks in progress at step 63 . . . .
Epoch 1/2 --- Step 64/100 --- :  63%|██████▎   | 63/100 [00:30<00:04,  7.98batch /s, learning_rate=0, loss=0.261]INFO:absl:Callbacks in progress at step 64 . . . .
Epoch 1/2 --- Step 65/100 --- :  64%|██████▍   | 64/100 [00:30<00:04,  8.07batch /s, learning_rate=0, loss=0.261]

Train done
Val done
Started epoch 1 and step 64
Train done
Val done
Started epoch 1 and step 65


INFO:absl:Callbacks in progress at step 65 . . . .
Epoch 1/2 --- Step 66/100 --- :  65%|██████▌   | 65/100 [00:30<00:04,  8.12batch /s, learning_rate=0, loss=0.34] INFO:absl:Callbacks in progress at step 66 . . . .
Epoch 1/2 --- Step 67/100 --- :  66%|██████▌   | 66/100 [00:30<00:04,  8.13batch /s, learning_rate=0, loss=0.302]

Train done
Val done
Started epoch 1 and step 66
Train done
Val done
Started epoch 1 and step 67


INFO:absl:Callbacks in progress at step 67 . . . .
Epoch 1/2 --- Step 68/100 --- :  67%|██████▋   | 67/100 [00:30<00:03,  8.27batch /s, learning_rate=0, loss=0.271]INFO:absl:Callbacks in progress at step 68 . . . .
Epoch 1/2 --- Step 69/100 --- :  68%|██████▊   | 68/100 [00:30<00:03,  8.25batch /s, learning_rate=0, loss=0.284]

Train done
Val done
Started epoch 1 and step 68
Train done
Val done
Started epoch 1 and step 69


INFO:absl:Callbacks in progress at step 69 . . . .
Epoch 1/2 --- Step 70/100 --- :  69%|██████▉   | 69/100 [00:31<00:03,  8.28batch /s, learning_rate=0, loss=0.296]INFO:absl:Callbacks in progress at step 70 . . . .
Epoch 1/2 --- Step 71/100 --- :  70%|███████   | 70/100 [00:31<00:03,  8.36batch /s, learning_rate=0, loss=0.314]

Train done
Val done
Started epoch 1 and step 70
Train done
Val done
Started epoch 1 and step 71


INFO:absl:Callbacks in progress at step 71 . . . .
Epoch 1/2 --- Step 72/100 --- :  71%|███████   | 71/100 [00:31<00:03,  8.54batch /s, learning_rate=0, loss=0.318]INFO:absl:Callbacks in progress at step 72 . . . .
Epoch 1/2 --- Step 73/100 --- :  72%|███████▏  | 72/100 [00:31<00:03,  8.45batch /s, learning_rate=0, loss=0.328]

Train done
Val done
Started epoch 1 and step 72
Train done
Val done
Started epoch 1 and step 73


INFO:absl:Callbacks in progress at step 73 . . . .
Epoch 1/2 --- Step 74/100 --- :  73%|███████▎  | 73/100 [00:31<00:03,  8.35batch /s, learning_rate=0, loss=0.288]INFO:absl:Callbacks in progress at step 74 . . . .
Epoch 1/2 --- Step 75/100 --- :  74%|███████▍  | 74/100 [00:31<00:03,  8.42batch /s, learning_rate=0, loss=0.352]

Train done
Val done
Started epoch 1 and step 74
Train done
Val done
Started epoch 1 and step 75


INFO:absl:Callbacks in progress at step 75 . . . .
Epoch 1/2 --- Step 76/100 --- :  75%|███████▌  | 75/100 [00:31<00:03,  8.33batch /s, learning_rate=0, loss=0.237]INFO:absl:Callbacks in progress at step 76 . . . .
Epoch 1/2 --- Step 77/100 --- :  76%|███████▌  | 76/100 [00:31<00:02,  8.40batch /s, learning_rate=0, loss=0.323]

Train done
Val done
Started epoch 1 and step 76
Train done
Val done
Started epoch 1 and step 77


INFO:absl:Callbacks in progress at step 77 . . . .
Epoch 1/2 --- Step 78/100 --- :  77%|███████▋  | 77/100 [00:32<00:02,  8.47batch /s, learning_rate=0, loss=0.323]INFO:absl:Callbacks in progress at step 78 . . . .
Epoch 1/2 --- Step 79/100 --- :  78%|███████▊  | 78/100 [00:32<00:02,  8.32batch /s, learning_rate=0, loss=0.324]

Train done
Val done
Started epoch 1 and step 78
Train done
Val done
Started epoch 1 and step 79


INFO:absl:Callbacks in progress at step 79 . . . .
Epoch 1/2 --- Step 80/100 --- :  79%|███████▉  | 79/100 [00:32<00:02,  8.26batch /s, learning_rate=0, loss=0.262]INFO:absl:Callbacks in progress at step 80 . . . .
Epoch 1/2 --- Step 81/100 --- :  80%|████████  | 80/100 [00:32<00:02,  8.01batch /s, learning_rate=0, loss=0.235]

Train done
Val done
Started epoch 1 and step 80
Train done
Val done
Started epoch 1 and step 81


INFO:absl:Callbacks in progress at step 81 . . . .
Epoch 1/2 --- Step 82/100 --- :  81%|████████  | 81/100 [00:32<00:02,  8.05batch /s, learning_rate=0, loss=0.262]INFO:absl:Callbacks in progress at step 82 . . . .
Epoch 1/2 --- Step 83/100 --- :  82%|████████▏ | 82/100 [00:32<00:02,  8.12batch /s, learning_rate=0, loss=0.244]

Train done
Val done
Started epoch 1 and step 82
Train done
Val done
Started epoch 1 and step 83


INFO:absl:Callbacks in progress at step 83 . . . .
Epoch 1/2 --- Step 84/100 --- :  83%|████████▎ | 83/100 [00:32<00:02,  8.34batch /s, learning_rate=0, loss=0.316]INFO:absl:Callbacks in progress at step 84 . . . .
Epoch 1/2 --- Step 85/100 --- :  84%|████████▍ | 84/100 [00:32<00:01,  8.41batch /s, learning_rate=0, loss=0.29] 

Train done
Val done
Started epoch 1 and step 84
Train done
Val done
Started epoch 1 and step 85


INFO:absl:Callbacks in progress at step 85 . . . .
Epoch 1/2 --- Step 86/100 --- :  85%|████████▌ | 85/100 [00:33<00:01,  8.33batch /s, learning_rate=0, loss=0.255]INFO:absl:Callbacks in progress at step 86 . . . .
Epoch 1/2 --- Step 87/100 --- :  86%|████████▌ | 86/100 [00:33<00:01,  8.32batch /s, learning_rate=0, loss=0.255]

Train done
Val done
Started epoch 1 and step 86
Train done
Val done
Started epoch 1 and step 87


INFO:absl:Callbacks in progress at step 87 . . . .
Epoch 1/2 --- Step 88/100 --- :  87%|████████▋ | 87/100 [00:33<00:01,  8.27batch /s, learning_rate=0, loss=0.307]INFO:absl:Callbacks in progress at step 88 . . . .
Epoch 1/2 --- Step 89/100 --- :  88%|████████▊ | 88/100 [00:33<00:01,  8.24batch /s, learning_rate=0, loss=0.248]

Train done
Val done
Started epoch 1 and step 88
Train done
Val done
Started epoch 1 and step 89


INFO:absl:Callbacks in progress at step 89 . . . .
Epoch 1/2 --- Step 90/100 --- :  89%|████████▉ | 89/100 [00:33<00:01,  8.25batch /s, learning_rate=0, loss=0.25] INFO:absl:Callbacks in progress at step 90 . . . .
Epoch 1/2 --- Step 91/100 --- :  90%|█████████ | 90/100 [00:33<00:01,  8.24batch /s, learning_rate=0, loss=0.337]

Train done
Val done
Started epoch 1 and step 90
Train done
Val done
Started epoch 1 and step 91


INFO:absl:Callbacks in progress at step 91 . . . .
Epoch 1/2 --- Step 92/100 --- :  91%|█████████ | 91/100 [00:33<00:01,  8.22batch /s, learning_rate=0, loss=0.299]INFO:absl:Callbacks in progress at step 92 . . . .
Epoch 1/2 --- Step 93/100 --- :  92%|█████████▏| 92/100 [00:33<00:00,  8.31batch /s, learning_rate=0, loss=0.304]

Train done
Val done
Started epoch 1 and step 92
Train done
Val done
Started epoch 1 and step 93


INFO:absl:Callbacks in progress at step 93 . . . .
Epoch 1/2 --- Step 94/100 --- :  93%|█████████▎| 93/100 [00:33<00:00,  8.31batch /s, learning_rate=0, loss=0.27] INFO:absl:Callbacks in progress at step 94 . . . .
Epoch 1/2 --- Step 95/100 --- :  94%|█████████▍| 94/100 [00:34<00:00,  8.27batch /s, learning_rate=0, loss=0.237]

Train done
Val done
Started epoch 1 and step 94
Train done
Val done
Started epoch 1 and step 95


INFO:absl:Callbacks in progress at step 95 . . . .
Epoch 1/2 --- Step 96/100 --- :  95%|█████████▌| 95/100 [00:34<00:00,  8.37batch /s, learning_rate=0, loss=0.308]INFO:absl:Callbacks in progress at step 96 . . . .
Epoch 1/2 --- Step 97/100 --- :  96%|█████████▌| 96/100 [00:34<00:00,  8.30batch /s, learning_rate=0, loss=0.339]

Train done
Val done
Started epoch 1 and step 96
Train done
Val done
Started epoch 1 and step 97


INFO:absl:Callbacks in progress at step 97 . . . .
Epoch 1/2 --- Step 98/100 --- :  97%|█████████▋| 97/100 [00:34<00:00,  8.32batch /s, learning_rate=0, loss=0.308]INFO:absl:Callbacks in progress at step 98 . . . .
Epoch 1/2 --- Step 99/100 --- :  98%|█████████▊| 98/100 [00:34<00:00,  8.26batch /s, learning_rate=0, loss=0.252]

Train done
Val done
Started epoch 1 and step 98
Train done
Val done
Started epoch 1 and step 99


INFO:absl:Callbacks in progress at step 99 . . . .
Epoch 1/2 --- Step 100/100 --- :  99%|█████████▉| 99/100 [00:34<00:00,  8.22batch /s, learning_rate=0, loss=0.252]INFO:absl:Callbacks in progress at step 100 . . . .
Epoch 1/2 --- Step 100/100 --- : 100%|██████████| 100/100 [00:34<00:00,  2.87batch /s, learning_rate=0, loss=0.331]

Train done
Val done
Started epoch 1 and step 100
Train done
Val done



INFO:absl:Model saved at epoch 1
Epoch 2/2 --- Step 1/100 --- :   0%|          | 0/100 [00:00<?, ?batch /s]INFO:absl:Callbacks in progress at step 101 . . . .
Epoch 2/2 --- Step 2/100 --- :   1%|          | 1/100 [00:00<00:12,  8.15batch /s, learning_rate=0, loss=0.267]

Started epoch 2 and step 101
Train done
Val done
Started epoch 2 and step 102


INFO:absl:Callbacks in progress at step 102 . . . .
Epoch 2/2 --- Step 3/100 --- :   2%|▏         | 2/100 [00:00<00:12,  7.92batch /s, learning_rate=0, loss=0.277]INFO:absl:Callbacks in progress at step 103 . . . .
Epoch 2/2 --- Step 4/100 --- :   3%|▎         | 3/100 [00:00<00:12,  8.00batch /s, learning_rate=0, loss=0.275]

Train done
Val done
Started epoch 2 and step 103
Train done
Val done
Started epoch 2 and step 104


INFO:absl:Callbacks in progress at step 104 . . . .
Epoch 2/2 --- Step 5/100 --- :   4%|▍         | 4/100 [00:00<00:11,  8.26batch /s, learning_rate=0, loss=0.292]INFO:absl:Callbacks in progress at step 105 . . . .
Epoch 2/2 --- Step 6/100 --- :   5%|▌         | 5/100 [00:00<00:11,  8.26batch /s, learning_rate=0, loss=0.331]

Train done
Val done
Started epoch 2 and step 105
Train done
Val done
Started epoch 2 and step 106


INFO:absl:Callbacks in progress at step 106 . . . .
Epoch 2/2 --- Step 7/100 --- :   6%|▌         | 6/100 [00:00<00:11,  8.20batch /s, learning_rate=0, loss=0.266]INFO:absl:Callbacks in progress at step 107 . . . .
Epoch 2/2 --- Step 8/100 --- :   7%|▋         | 7/100 [00:00<00:11,  8.21batch /s, learning_rate=0, loss=0.303]

Train done
Val done
Started epoch 2 and step 107
Train done
Val done
Started epoch 2 and step 108


INFO:absl:Callbacks in progress at step 108 . . . .
Epoch 2/2 --- Step 9/100 --- :   8%|▊         | 8/100 [00:00<00:11,  8.20batch /s, learning_rate=0, loss=0.306]INFO:absl:Callbacks in progress at step 109 . . . .
Epoch 2/2 --- Step 10/100 --- :   9%|▉         | 9/100 [00:01<00:11,  8.22batch /s, learning_rate=0, loss=0.304]

Train done
Val done
Started epoch 2 and step 109
Train done
Val done
Started epoch 2 and step 110


INFO:absl:Callbacks in progress at step 110 . . . .
Epoch 2/2 --- Step 11/100 --- :  10%|█         | 10/100 [00:01<00:10,  8.20batch /s, learning_rate=0, loss=0.293]INFO:absl:Callbacks in progress at step 111 . . . .
Epoch 2/2 --- Step 12/100 --- :  11%|█         | 11/100 [00:01<00:10,  8.15batch /s, learning_rate=0, loss=0.232]

Train done
Val done
Started epoch 2 and step 111
Train done
Val done
Started epoch 2 and step 112


INFO:absl:Callbacks in progress at step 112 . . . .
Epoch 2/2 --- Step 13/100 --- :  12%|█▏        | 12/100 [00:01<00:10,  8.24batch /s, learning_rate=0, loss=0.32] INFO:absl:Callbacks in progress at step 113 . . . .
Epoch 2/2 --- Step 14/100 --- :  13%|█▎        | 13/100 [00:01<00:10,  8.34batch /s, learning_rate=0, loss=0.285]

Train done
Val done
Started epoch 2 and step 113
Train done
Val done
Started epoch 2 and step 114


INFO:absl:Callbacks in progress at step 114 . . . .
Epoch 2/2 --- Step 15/100 --- :  14%|█▍        | 14/100 [00:01<00:10,  8.52batch /s, learning_rate=0, loss=0.374]INFO:absl:Callbacks in progress at step 115 . . . .
Epoch 2/2 --- Step 16/100 --- :  15%|█▌        | 15/100 [00:01<00:10,  8.46batch /s, learning_rate=0, loss=0.253]

Train done
Val done
Started epoch 2 and step 115
Train done
Val done
Started epoch 2 and step 116


INFO:absl:Callbacks in progress at step 116 . . . .
Epoch 2/2 --- Step 17/100 --- :  16%|█▌        | 16/100 [00:01<00:09,  8.49batch /s, learning_rate=0, loss=0.268]INFO:absl:Callbacks in progress at step 117 . . . .
Epoch 2/2 --- Step 18/100 --- :  17%|█▋        | 17/100 [00:02<00:09,  8.37batch /s, learning_rate=0, loss=0.31] 

Train done
Val done
Started epoch 2 and step 117
Train done
Val done
Started epoch 2 and step 118


INFO:absl:Callbacks in progress at step 118 . . . .
Epoch 2/2 --- Step 19/100 --- :  18%|█▊        | 18/100 [00:02<00:09,  8.35batch /s, learning_rate=0, loss=0.244]INFO:absl:Callbacks in progress at step 119 . . . .
Epoch 2/2 --- Step 20/100 --- :  19%|█▉        | 19/100 [00:02<00:10,  8.09batch /s, learning_rate=0, loss=0.234]

Train done
Val done
Started epoch 2 and step 119
Train done
Val done
Started epoch 2 and step 120


INFO:absl:Callbacks in progress at step 120 . . . .
Epoch 2/2 --- Step 21/100 --- :  20%|██        | 20/100 [00:02<00:09,  8.14batch /s, learning_rate=0, loss=0.276]INFO:absl:Callbacks in progress at step 121 . . . .
Epoch 2/2 --- Step 22/100 --- :  21%|██        | 21/100 [00:02<00:09,  7.93batch /s, learning_rate=0, loss=0.267]

Train done
Val done
Started epoch 2 and step 121
Train done
Val done
Started epoch 2 and step 122


INFO:absl:Callbacks in progress at step 122 . . . .
Epoch 2/2 --- Step 23/100 --- :  22%|██▏       | 22/100 [00:02<00:09,  7.93batch /s, learning_rate=0, loss=0.26] INFO:absl:Callbacks in progress at step 123 . . . .
Epoch 2/2 --- Step 24/100 --- :  23%|██▎       | 23/100 [00:02<00:09,  7.99batch /s, learning_rate=0, loss=0.305]

Train done
Val done
Started epoch 2 and step 123
Train done
Val done
Started epoch 2 and step 124


INFO:absl:Callbacks in progress at step 124 . . . .
Epoch 2/2 --- Step 25/100 --- :  24%|██▍       | 24/100 [00:02<00:09,  7.80batch /s, learning_rate=0, loss=0.286]INFO:absl:Callbacks in progress at step 125 . . . .
Epoch 2/2 --- Step 26/100 --- :  25%|██▌       | 25/100 [00:03<00:09,  8.00batch /s, learning_rate=0, loss=0.259]

Train done
Val done
Started epoch 2 and step 125
Train done
Val done
Started epoch 2 and step 126


INFO:absl:Callbacks in progress at step 126 . . . .
Epoch 2/2 --- Step 27/100 --- :  26%|██▌       | 26/100 [00:03<00:08,  8.26batch /s, learning_rate=0, loss=0.36] INFO:absl:Callbacks in progress at step 127 . . . .
Epoch 2/2 --- Step 28/100 --- :  27%|██▋       | 27/100 [00:03<00:08,  8.24batch /s, learning_rate=0, loss=0.39]

Train done
Val done
Started epoch 2 and step 127
Train done
Val done
Started epoch 2 and step 128


INFO:absl:Callbacks in progress at step 128 . . . .
Epoch 2/2 --- Step 29/100 --- :  28%|██▊       | 28/100 [00:03<00:08,  8.19batch /s, learning_rate=0, loss=0.253]INFO:absl:Callbacks in progress at step 129 . . . .
Epoch 2/2 --- Step 30/100 --- :  29%|██▉       | 29/100 [00:03<00:08,  8.14batch /s, learning_rate=0, loss=0.317]

Train done
Val done
Started epoch 2 and step 129
Train done
Val done
Started epoch 2 and step 130


INFO:absl:Callbacks in progress at step 130 . . . .
Epoch 2/2 --- Step 31/100 --- :  30%|███       | 30/100 [00:03<00:08,  8.16batch /s, learning_rate=0, loss=0.302]INFO:absl:Callbacks in progress at step 131 . . . .
Epoch 2/2 --- Step 32/100 --- :  31%|███       | 31/100 [00:03<00:08,  8.18batch /s, learning_rate=0, loss=0.232]

Train done
Val done
Started epoch 2 and step 131
Train done
Val done
Started epoch 2 and step 132


INFO:absl:Callbacks in progress at step 132 . . . .
Epoch 2/2 --- Step 33/100 --- :  32%|███▏      | 32/100 [00:03<00:08,  8.21batch /s, learning_rate=0, loss=0.232]INFO:absl:Callbacks in progress at step 133 . . . .
Epoch 2/2 --- Step 34/100 --- :  33%|███▎      | 33/100 [00:04<00:08,  8.27batch /s, learning_rate=0, loss=0.279]

Train done
Val done
Started epoch 2 and step 133
Train done
Val done
Started epoch 2 and step 134


INFO:absl:Callbacks in progress at step 134 . . . .
Epoch 2/2 --- Step 35/100 --- :  34%|███▍      | 34/100 [00:04<00:07,  8.36batch /s, learning_rate=0, loss=0.312]INFO:absl:Callbacks in progress at step 135 . . . .
Epoch 2/2 --- Step 36/100 --- :  35%|███▌      | 35/100 [00:04<00:07,  8.34batch /s, learning_rate=0, loss=0.258]

Train done
Val done
Started epoch 2 and step 135
Train done
Val done
Started epoch 2 and step 136


INFO:absl:Callbacks in progress at step 136 . . . .
Epoch 2/2 --- Step 37/100 --- :  36%|███▌      | 36/100 [00:04<00:07,  8.28batch /s, learning_rate=0, loss=0.347]INFO:absl:Callbacks in progress at step 137 . . . .
Epoch 2/2 --- Step 38/100 --- :  37%|███▋      | 37/100 [00:04<00:07,  8.35batch /s, learning_rate=0, loss=0.267]

Train done
Val done
Started epoch 2 and step 137
Train done
Val done
Started epoch 2 and step 138


INFO:absl:Callbacks in progress at step 138 . . . .
Epoch 2/2 --- Step 39/100 --- :  38%|███▊      | 38/100 [00:04<00:07,  8.34batch /s, learning_rate=0, loss=0.269]INFO:absl:Callbacks in progress at step 139 . . . .
Epoch 2/2 --- Step 40/100 --- :  39%|███▉      | 39/100 [00:04<00:07,  8.27batch /s, learning_rate=0, loss=0.304]

Train done
Val done
Started epoch 2 and step 139
Train done
Val done
Started epoch 2 and step 140


INFO:absl:Callbacks in progress at step 140 . . . .
Epoch 2/2 --- Step 41/100 --- :  40%|████      | 40/100 [00:04<00:07,  8.22batch /s, learning_rate=0, loss=0.253]INFO:absl:Callbacks in progress at step 141 . . . .
Epoch 2/2 --- Step 42/100 --- :  41%|████      | 41/100 [00:04<00:07,  8.42batch /s, learning_rate=0, loss=0.354]

Train done
Val done
Started epoch 2 and step 141
Train done
Val done
Started epoch 2 and step 142


INFO:absl:Callbacks in progress at step 142 . . . .
Epoch 2/2 --- Step 43/100 --- :  42%|████▏     | 42/100 [00:05<00:06,  8.37batch /s, learning_rate=0, loss=0.317]INFO:absl:Callbacks in progress at step 143 . . . .
Epoch 2/2 --- Step 44/100 --- :  43%|████▎     | 43/100 [00:05<00:07,  8.08batch /s, learning_rate=0, loss=0.271]

Train done
Val done
Started epoch 2 and step 143
Train done
Val done
Started epoch 2 and step 144


INFO:absl:Callbacks in progress at step 144 . . . .
Epoch 2/2 --- Step 45/100 --- :  44%|████▍     | 44/100 [00:05<00:06,  8.23batch /s, learning_rate=0, loss=0.307]INFO:absl:Callbacks in progress at step 145 . . . .
Epoch 2/2 --- Step 46/100 --- :  45%|████▌     | 45/100 [00:05<00:06,  8.24batch /s, learning_rate=0, loss=0.277]

Train done
Val done
Started epoch 2 and step 145
Train done
Val done
Started epoch 2 and step 146


INFO:absl:Callbacks in progress at step 146 . . . .
Epoch 2/2 --- Step 47/100 --- :  46%|████▌     | 46/100 [00:05<00:06,  8.23batch /s, learning_rate=0, loss=0.339]INFO:absl:Callbacks in progress at step 147 . . . .
Epoch 2/2 --- Step 48/100 --- :  47%|████▋     | 47/100 [00:05<00:06,  8.26batch /s, learning_rate=0, loss=0.262]

Train done
Val done
Started epoch 2 and step 147
Train done
Val done
Started epoch 2 and step 148


INFO:absl:Callbacks in progress at step 148 . . . .
Epoch 2/2 --- Step 49/100 --- :  48%|████▊     | 48/100 [00:05<00:06,  8.22batch /s, learning_rate=0, loss=0.285]INFO:absl:Callbacks in progress at step 149 . . . .
Epoch 2/2 --- Step 50/100 --- :  49%|████▉     | 49/100 [00:05<00:06,  8.25batch /s, learning_rate=0, loss=0.232]

Train done
Val done
Started epoch 2 and step 149
Train done
Val done
Started epoch 2 and step 150


INFO:absl:Callbacks in progress at step 150 . . . .
Epoch 2/2 --- Step 51/100 --- :  50%|█████     | 50/100 [00:06<00:06,  8.24batch /s, learning_rate=0, loss=0.297]INFO:absl:Callbacks in progress at step 151 . . . .
Epoch 2/2 --- Step 52/100 --- :  51%|█████     | 51/100 [00:06<00:06,  8.02batch /s, learning_rate=0, loss=0.283]

Train done
Val done
Started epoch 2 and step 151
Train done
Val done
Started epoch 2 and step 152


INFO:absl:Callbacks in progress at step 152 . . . .
Epoch 2/2 --- Step 53/100 --- :  52%|█████▏    | 52/100 [00:06<00:05,  8.04batch /s, learning_rate=0, loss=0.284]INFO:absl:Callbacks in progress at step 153 . . . .
Epoch 2/2 --- Step 54/100 --- :  53%|█████▎    | 53/100 [00:06<00:05,  8.03batch /s, learning_rate=0, loss=0.279]

Train done
Val done
Started epoch 2 and step 153
Train done
Val done
Started epoch 2 and step 154


INFO:absl:Callbacks in progress at step 154 . . . .
Epoch 2/2 --- Step 55/100 --- :  54%|█████▍    | 54/100 [00:06<00:05,  8.10batch /s, learning_rate=0, loss=0.293]INFO:absl:Callbacks in progress at step 155 . . . .
Epoch 2/2 --- Step 56/100 --- :  55%|█████▌    | 55/100 [00:06<00:05,  8.14batch /s, learning_rate=0, loss=0.331]

Train done
Val done
Started epoch 2 and step 155
Train done
Val done
Started epoch 2 and step 156


INFO:absl:Callbacks in progress at step 156 . . . .
Epoch 2/2 --- Step 57/100 --- :  56%|█████▌    | 56/100 [00:06<00:05,  8.38batch /s, learning_rate=0, loss=0.369]INFO:absl:Callbacks in progress at step 157 . . . .
Epoch 2/2 --- Step 58/100 --- :  57%|█████▋    | 57/100 [00:06<00:05,  8.32batch /s, learning_rate=0, loss=0.29] 

Train done
Val done
Started epoch 2 and step 157
Train done
Val done
Started epoch 2 and step 158


INFO:absl:Callbacks in progress at step 158 . . . .
Epoch 2/2 --- Step 59/100 --- :  58%|█████▊    | 58/100 [00:07<00:05,  8.20batch /s, learning_rate=0, loss=0.322]INFO:absl:Callbacks in progress at step 159 . . . .
Epoch 2/2 --- Step 60/100 --- :  59%|█████▉    | 59/100 [00:07<00:04,  8.29batch /s, learning_rate=0, loss=0.291]

Train done
Val done
Started epoch 2 and step 159
Train done
Val done
Started epoch 2 and step 160


INFO:absl:Callbacks in progress at step 160 . . . .
Epoch 2/2 --- Step 61/100 --- :  60%|██████    | 60/100 [00:07<00:04,  8.36batch /s, learning_rate=0, loss=0.297]INFO:absl:Callbacks in progress at step 161 . . . .
Epoch 2/2 --- Step 62/100 --- :  61%|██████    | 61/100 [00:07<00:04,  8.34batch /s, learning_rate=0, loss=0.276]

Train done
Val done
Started epoch 2 and step 161
Train done
Val done
Started epoch 2 and step 162


INFO:absl:Callbacks in progress at step 162 . . . .
Epoch 2/2 --- Step 63/100 --- :  62%|██████▏   | 62/100 [00:07<00:04,  8.34batch /s, learning_rate=0, loss=0.258]INFO:absl:Callbacks in progress at step 163 . . . .
Epoch 2/2 --- Step 64/100 --- :  63%|██████▎   | 63/100 [00:07<00:04,  8.40batch /s, learning_rate=0, loss=0.338]

Train done
Val done
Started epoch 2 and step 163
Train done
Val done
Started epoch 2 and step 164


INFO:absl:Callbacks in progress at step 164 . . . .
Epoch 2/2 --- Step 65/100 --- :  64%|██████▍   | 64/100 [00:07<00:04,  8.37batch /s, learning_rate=0, loss=0.243]INFO:absl:Callbacks in progress at step 165 . . . .
Epoch 2/2 --- Step 66/100 --- :  65%|██████▌   | 65/100 [00:07<00:04,  8.30batch /s, learning_rate=0, loss=0.315]

Train done
Val done
Started epoch 2 and step 165
Train done
Val done
Started epoch 2 and step 166


INFO:absl:Callbacks in progress at step 166 . . . .
Epoch 2/2 --- Step 67/100 --- :  66%|██████▌   | 66/100 [00:08<00:04,  8.26batch /s, learning_rate=0, loss=0.344]INFO:absl:Callbacks in progress at step 167 . . . .
Epoch 2/2 --- Step 68/100 --- :  67%|██████▋   | 67/100 [00:08<00:03,  8.33batch /s, learning_rate=0, loss=0.321]

Train done
Val done
Started epoch 2 and step 167
Train done
Val done
Started epoch 2 and step 168


INFO:absl:Callbacks in progress at step 168 . . . .
Epoch 2/2 --- Step 69/100 --- :  68%|██████▊   | 68/100 [00:08<00:03,  8.03batch /s, learning_rate=0, loss=0.249]INFO:absl:Callbacks in progress at step 169 . . . .
Epoch 2/2 --- Step 70/100 --- :  69%|██████▉   | 69/100 [00:08<00:03,  8.10batch /s, learning_rate=0, loss=0.265]

Train done
Val done
Started epoch 2 and step 169
Train done
Val done
Started epoch 2 and step 170


INFO:absl:Callbacks in progress at step 170 . . . .
Epoch 2/2 --- Step 71/100 --- :  70%|███████   | 70/100 [00:08<00:03,  8.33batch /s, learning_rate=0, loss=0.33] INFO:absl:Callbacks in progress at step 171 . . . .
Epoch 2/2 --- Step 72/100 --- :  71%|███████   | 71/100 [00:08<00:03,  8.27batch /s, learning_rate=0, loss=0.328]

Train done
Val done
Started epoch 2 and step 171
Train done
Val done
Started epoch 2 and step 172


INFO:absl:Callbacks in progress at step 172 . . . .
Epoch 2/2 --- Step 73/100 --- :  72%|███████▏  | 72/100 [00:08<00:03,  8.35batch /s, learning_rate=0, loss=0.294]INFO:absl:Callbacks in progress at step 173 . . . .
Epoch 2/2 --- Step 74/100 --- :  73%|███████▎  | 73/100 [00:08<00:03,  8.32batch /s, learning_rate=0, loss=0.294]

Train done
Val done
Started epoch 2 and step 173
Train done
Val done
Started epoch 2 and step 174


INFO:absl:Callbacks in progress at step 174 . . . .
Epoch 2/2 --- Step 75/100 --- :  74%|███████▍  | 74/100 [00:08<00:03,  8.25batch /s, learning_rate=0, loss=0.257]INFO:absl:Callbacks in progress at step 175 . . . .
Epoch 2/2 --- Step 76/100 --- :  75%|███████▌  | 75/100 [00:09<00:03,  8.21batch /s, learning_rate=0, loss=0.284]

Train done
Val done
Started epoch 2 and step 175
Train done
Val done
Started epoch 2 and step 176


INFO:absl:Callbacks in progress at step 176 . . . .
Epoch 2/2 --- Step 77/100 --- :  76%|███████▌  | 76/100 [00:09<00:02,  8.19batch /s, learning_rate=0, loss=0.294]INFO:absl:Callbacks in progress at step 177 . . . .
Epoch 2/2 --- Step 78/100 --- :  77%|███████▋  | 77/100 [00:09<00:02,  8.23batch /s, learning_rate=0, loss=0.298]

Train done
Val done
Started epoch 2 and step 177
Train done
Val done
Started epoch 2 and step 178


INFO:absl:Callbacks in progress at step 178 . . . .
Epoch 2/2 --- Step 79/100 --- :  78%|███████▊  | 78/100 [00:09<00:02,  8.21batch /s, learning_rate=0, loss=0.264]INFO:absl:Callbacks in progress at step 179 . . . .
Epoch 2/2 --- Step 80/100 --- :  79%|███████▉  | 79/100 [00:09<00:02,  8.29batch /s, learning_rate=0, loss=0.305]

Train done
Val done
Started epoch 2 and step 179
Train done
Val done
Started epoch 2 and step 180


INFO:absl:Callbacks in progress at step 180 . . . .
Epoch 2/2 --- Step 81/100 --- :  80%|████████  | 80/100 [00:09<00:02,  8.31batch /s, learning_rate=0, loss=0.246]INFO:absl:Callbacks in progress at step 181 . . . .
Epoch 2/2 --- Step 82/100 --- :  81%|████████  | 81/100 [00:09<00:02,  8.32batch /s, learning_rate=0, loss=0.235]

Train done
Val done
Started epoch 2 and step 181
Train done
Val done
Started epoch 2 and step 182


INFO:absl:Callbacks in progress at step 182 . . . .
Epoch 2/2 --- Step 83/100 --- :  82%|████████▏ | 82/100 [00:09<00:02,  8.32batch /s, learning_rate=0, loss=0.275]INFO:absl:Callbacks in progress at step 183 . . . .
Epoch 2/2 --- Step 84/100 --- :  83%|████████▎ | 83/100 [00:10<00:02,  8.29batch /s, learning_rate=0, loss=0.3]  

Train done
Val done
Started epoch 2 and step 183
Train done
Val done
Started epoch 2 and step 184


INFO:absl:Callbacks in progress at step 184 . . . .
Epoch 2/2 --- Step 85/100 --- :  84%|████████▍ | 84/100 [00:10<00:01,  8.28batch /s, learning_rate=0, loss=0.318]INFO:absl:Callbacks in progress at step 185 . . . .
Epoch 2/2 --- Step 86/100 --- :  85%|████████▌ | 85/100 [00:10<00:01,  8.23batch /s, learning_rate=0, loss=0.289]

Train done
Val done
Started epoch 2 and step 185
Train done
Val done
Started epoch 2 and step 186


INFO:absl:Callbacks in progress at step 186 . . . .
Epoch 2/2 --- Step 87/100 --- :  86%|████████▌ | 86/100 [00:10<00:01,  7.99batch /s, learning_rate=0, loss=0.282]INFO:absl:Callbacks in progress at step 187 . . . .
Epoch 2/2 --- Step 88/100 --- :  87%|████████▋ | 87/100 [00:10<00:01,  8.04batch /s, learning_rate=0, loss=0.316]

Train done
Val done
Started epoch 2 and step 187
Train done
Val done
Started epoch 2 and step 188


INFO:absl:Callbacks in progress at step 188 . . . .
Epoch 2/2 --- Step 89/100 --- :  88%|████████▊ | 88/100 [00:10<00:01,  8.08batch /s, learning_rate=0, loss=0.35] INFO:absl:Callbacks in progress at step 189 . . . .
Epoch 2/2 --- Step 90/100 --- :  89%|████████▉ | 89/100 [00:10<00:01,  8.23batch /s, learning_rate=0, loss=0.304]

Train done
Val done
Started epoch 2 and step 189
Train done
Val done
Started epoch 2 and step 190


INFO:absl:Callbacks in progress at step 190 . . . .
Epoch 2/2 --- Step 91/100 --- :  90%|█████████ | 90/100 [00:10<00:01,  8.25batch /s, learning_rate=0, loss=0.286]INFO:absl:Callbacks in progress at step 191 . . . .
Epoch 2/2 --- Step 92/100 --- :  91%|█████████ | 91/100 [00:11<00:01,  8.26batch /s, learning_rate=0, loss=0.289]

Train done
Val done
Started epoch 2 and step 191
Train done
Val done
Started epoch 2 and step 192


INFO:absl:Callbacks in progress at step 192 . . . .
Epoch 2/2 --- Step 93/100 --- :  92%|█████████▏| 92/100 [00:11<00:00,  8.36batch /s, learning_rate=0, loss=0.283]INFO:absl:Callbacks in progress at step 193 . . . .
Epoch 2/2 --- Step 94/100 --- :  93%|█████████▎| 93/100 [00:11<00:00,  8.53batch /s, learning_rate=0, loss=0.349]

Train done
Val done
Started epoch 2 and step 193
Train done
Val done
Started epoch 2 and step 194


INFO:absl:Callbacks in progress at step 194 . . . .
Epoch 2/2 --- Step 95/100 --- :  94%|█████████▍| 94/100 [00:11<00:00,  8.56batch /s, learning_rate=0, loss=0.314]INFO:absl:Callbacks in progress at step 195 . . . .
Epoch 2/2 --- Step 96/100 --- :  95%|█████████▌| 95/100 [00:11<00:00,  8.42batch /s, learning_rate=0, loss=0.295]

Train done
Val done
Started epoch 2 and step 195
Train done
Val done
Started epoch 2 and step 196


INFO:absl:Callbacks in progress at step 196 . . . .
Epoch 2/2 --- Step 97/100 --- :  96%|█████████▌| 96/100 [00:11<00:00,  8.34batch /s, learning_rate=0, loss=0.303]INFO:absl:Callbacks in progress at step 197 . . . .
Epoch 2/2 --- Step 98/100 --- :  97%|█████████▋| 97/100 [00:11<00:00,  8.29batch /s, learning_rate=0, loss=0.255]

Train done
Val done
Started epoch 2 and step 197
Train done
Val done
Started epoch 2 and step 198


INFO:absl:Callbacks in progress at step 198 . . . .
Epoch 2/2 --- Step 99/100 --- :  98%|█████████▊| 98/100 [00:11<00:00,  8.01batch /s, learning_rate=0, loss=0.286]INFO:absl:Callbacks in progress at step 199 . . . .
Epoch 2/2 --- Step 100/100 --- :  99%|█████████▉| 99/100 [00:12<00:00,  8.01batch /s, learning_rate=0, loss=0.304]

Train done
Val done
Started epoch 2 and step 199
Train done
Val done
Started epoch 2 and step 200


INFO:absl:Callbacks in progress at step 200 . . . .
Epoch 2/2 --- Step 100/100 --- : 100%|██████████| 100/100 [00:12<00:00,  8.22batch /s, learning_rate=0, loss=0.278]
INFO:absl:Model saved at epoch 2


Train done
Val done


In [19]:
        # Under Strategy Scope
        with trainer.distribution_strategy.scope():
            # Model
            model = model_fn()

            # Optimizer
            optimizer = optimizer_fn()

            optimizer = configure_optimizer(optimizer, use_float16=False, loss_scale="dynamic")


You are using a model of type albert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/albert-base-v2/ckpt-1
INFO:absl:Using Adamw optimizer


In [20]:
    def compute_loss(batch_labels, model_outputs):
        """Loss computation which takes care of loss reduction based on GLOBAL_BATCH_SIZE"""
        per_example_loss = train_loss_fn(batch_labels, model_outputs)
        per_example_loss_averaged = {}
        # Inplace update
        # Avergae loss per global batch size , recommended
        for name, loss in per_example_loss.items():
            per_example_loss_averaged[name] = tf.nn.compute_average_loss(loss, global_batch_size=GLOBAL_BATCH_SIZE)
        return per_example_loss_averaged

In [36]:
    # Train Functions
    @tf.function
    def do_train(iterator):
        """The step function for one training step"""

        def train_step(dist_inputs):
            """The computation to run on each device."""
            batch_inputs, batch_labels = dist_inputs
            with tf.GradientTape() as tape:
                model_outputs = model(batch_inputs)
                loss = compute_loss(batch_labels, model_outputs)
                tf.debugging.check_numerics(loss['loss'], message='Loss value is either NaN or inf')
                if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
                    loss_scaled = {name: optimizer.get_scaled_loss(loss_value) for name, loss_value in loss.items()}
                # TODO
                # Scales down the loss for gradients to be invariant from replicas.
                # loss = loss / strategy.num_replicas_in_sync
            if mixed_precision:
                scaled_gradients = tape.gradient(loss_scaled["loss"], model.trainable_variables)
                grads = optimizer.get_unscaled_gradients(scaled_gradients)
            else:
                grads = tape.gradient(loss["loss"], model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            # training_loss.update_state(loss * strategy.num_replicas_in_sync)
            return loss

        for _ in tf.range(tf.convert_to_tensor(steps_per_call)):
            dist_inputs = next(iterator)
            loss = strategy.run(train_step, args=(dist_inputs,))
            # strategy reduce
            loss = {
                name: strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_value, axis=None)
                for name, loss_value in loss.items()
            }
            
            t_loss.update_state(loss['loss'])
            # training_result = get_and_reset_metric_from_dict(training_loss_dict_metric)

In [22]:
train_dataset_distributed = trainer.distribution_strategy.experimental_distribute_dataset(train_dataset)
train_dataset_distributed = iter(train_dataset_distributed)

In [34]:
steps_per_call = 1
GLOBAL_BATCH_SIZE = 32
mixed_precision = False
strategy = trainer.distribution_strategy
t_loss = tf.keras.metrics.Mean("loss", dtype=tf.float32)

In [37]:
l = do_train(train_dataset_distributed)









INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1






INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [38]:
t_loss.result()

<tf.Tensor: shape=(), dtype=float32, numpy=0.36904955>

INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.
INFO:absl:XLA enabled
INFO:absl:Policy: ----> float32
INFO:absl:Strategy: ---> <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fbe2c4b1280>
INFO:absl:Num GPU Devices: ---> 2
You are using a model of type albert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/albert-base-v2/ckpt-1
INFO:absl:Using Adamw optimizer
INFO:absl:No checkpoint found in /tmp/model_ckpt/
Epoch 1/2 --- Step 50/100 --- :   0%|          | 0/2 [00:00<?, ?batch /s]

STEPS 2
Started epoch 1 and step 50








INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1






INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).










INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 26 all-reduces with algorithm = nccl, num_packs = 1






INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [None]:
shutil.rmtree(tfrecord_dir)
shutil.rmtree(model_checkpoint_dir)