In [1]:
import sys
sys.path.append("/home/jovyan/Projects/tf-transformers/src/")

In [2]:
import datasets
import json
import os
import glob
import time
import tensorflow as tf

# tf.config.set_visible_devices([], 'GPU')

from tf_transformers.models import RobertaModel, EncoderDecoder
from transformers import RobertaTokenizer
from tf_transformers.data import TFWriter, TFReader, TFProcessor
from tf_transformers.losses import cross_entropy_loss
from tf_transformers.optimization import create_optimizer
from tf_transformers.core import Trainer
from absl import logging
logging.set_verbosity("INFO")

In [3]:
import os
 
os.environ['NO_PROXY'] = '169.254.169.254'
 
os.environ['HTTP_PROXY'] = '10.239.228.20:8000'
 
os.environ['HTTPS_PROXY'] = '10.239.228.20:8000'
 
!cat /etc/resolv.conf
 
!cat ~/.wgetrcb
 
!echo "use_proxy=yes\nhttp_proxy=http.proxy.fmr.com:8000\nhttps_proxy=http.proxy.fmr.com:8000" > ~/.wgetrc
 
 
#cat ~/.wgetrc
 
!echo $HTTP_PROXY
 
!echo $HTTPS_PROXY

nameserver 172.16.0.10
search fmr-a642163.svc.gpu-cluster.local svc.gpu-cluster.local gpu-cluster.local fmr.com
options ndots:5
cat: /home/jovyan/.wgetrcb: No such file or directory
10.239.228.20:8000
10.239.228.20:8000


In [4]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
examples = datasets.load_from_disk("/mnt/home/PRE_MODELS/HuggingFace_models/datasets/xsum/")
train_examples = examples["train"]


### Train Dataset

In [9]:
encoder_max_length=512
decoder_max_length=64

def parse_train():
    result = {}
    for f in train_examples:
        input_ids = [tokenizer.cls_token] + tokenizer.tokenize(f['document'])[: encoder_max_length-2] + [tokenizer.sep_token] # -2 to add CLS and SEP
        input_ids = tokenizer.convert_tokens_to_ids(input_ids)
        input_mask = [1] * len(input_ids)
        input_type_ids = [0] * len(input_ids)

        decoder_input_ids = [tokenizer.cls_token] + tokenizer.tokenize(f['summary'])[: decoder_max_length-2] + [tokenizer.sep_token]
        decoder_input_ids = tokenizer.convert_tokens_to_ids(decoder_input_ids)
        decoder_input_type_ids = [0] * len(decoder_input_ids)

        result = {}
        result['encoder_input_ids'] = input_ids
        result['encoder_input_mask'] = input_mask
        result['encoder_input_type_ids'] = input_type_ids
        result['decoder_input_ids'] = decoder_input_ids[:-1] # except last word
        result['decoder_input_type_ids'] = decoder_input_type_ids[:-1] # except last word
        
        result['labels'] = decoder_input_ids[1:] # not including first word
        result['labels_mask'] = [1] * len(decoder_input_ids[1:])
        
        # Decoder doesnt need input_mask because by default decoder has causal mask mode

        yield result
        
# Lets write using TF Writer
# Use TFProcessor for smalled data

schema = {
    "encoder_input_ids": ("var_len", "int"),
    "encoder_input_mask": ("var_len", "int"),
    "encoder_input_type_ids": ("var_len", "int"),
    "decoder_input_ids": ("var_len", "int"),
    "decoder_input_type_ids": ("var_len", "int"),
    "labels": ("var_len", "int"),
    "labels_mask": ("var_len", "int"),
}

tfrecord_train_dir = '/tmp/bbc_xsum/roberta/train'
tfwriter = TFWriter(schema=schema, 
                    model_dir=tfrecord_train_dir,
                    tag='train',
                    overwrite=True,
                    verbose_counter=10000
                    )
tfwriter.process(parse_fn=parse_train())

Tag train


INFO:absl:Wrote 10000 tfrecods
INFO:absl:Wrote 20000 tfrecods
INFO:absl:Wrote 30000 tfrecods
INFO:absl:Wrote 40000 tfrecods
INFO:absl:Wrote 50000 tfrecods
INFO:absl:Wrote 60000 tfrecods
INFO:absl:Wrote 70000 tfrecods
INFO:absl:Wrote 80000 tfrecods
INFO:absl:Wrote 90000 tfrecods
INFO:absl:Wrote 100000 tfrecods
INFO:absl:Wrote 110000 tfrecods
INFO:absl:Wrote 120000 tfrecods
INFO:absl:Wrote 130000 tfrecods
INFO:absl:Wrote 140000 tfrecods
INFO:absl:Wrote 150000 tfrecods
INFO:absl:Wrote 160000 tfrecods
INFO:absl:Wrote 170000 tfrecods
INFO:absl:Wrote 180000 tfrecods
INFO:absl:Wrote 190000 tfrecods
INFO:absl:Wrote 200000 tfrecods
INFO:absl:Total individual observations/examples written is 204045 in 665.2274258136749 seconds
INFO:absl:All writer objects closed


### Read Train Dataset

In [5]:
# Read Data

schema = json.load(open("{}/schema.json".format(tfrecord_train_dir)))
stats = json.load(open("{}/stats.json".format(tfrecord_train_dir)))
total_train_examples = stats['total_records']

all_files = glob.glob("{}/*.tfrecord".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema, 
                    tfrecord_files=all_files)

x_keys = ['encoder_input_ids', 'encoder_input_type_ids', 'encoder_input_mask', 'decoder_input_ids', 'decoder_input_type_ids']
y_keys = ['labels', 'labels_mask']

batch_size = 32
train_dataset = tf_reader.read_record(auto_batch=True, 
                                   keys=x_keys,
                                   batch_size=batch_size, 
                                   x_keys = x_keys, 
                                   y_keys = y_keys,
                                   shuffle=True, 
                                   drop_remainder=True
                                  )

### Validation Dataset

In [36]:
dev_examples = examples['validation']

encoder_max_length=512
decoder_max_length=64

def parse_dev():
    result = {}
    for f in dev_examples:
        input_ids = [tokenizer.cls_token] + tokenizer.tokenize(f['document'])[: encoder_max_length-2] + [tokenizer.sep_token] # -2 to add CLS and SEP
        input_ids = tokenizer.convert_tokens_to_ids(input_ids)
        input_mask = [1] * len(input_ids)
        input_type_ids = [0] * len(input_ids)

        decoder_input_ids = [tokenizer.cls_token] + tokenizer.tokenize(f['summary'])[: decoder_max_length-2] + [tokenizer.sep_token]
        decoder_input_ids = tokenizer.convert_tokens_to_ids(decoder_input_ids)
        decoder_input_type_ids = [0] * len(decoder_input_ids)

        result = {}
        result['encoder_input_ids'] = input_ids
        result['encoder_input_mask'] = input_mask
        result['encoder_input_type_ids'] = input_type_ids
        result['text'] = f['summary']
        
        yield result
        

schema = {
    "encoder_input_ids": ("var_len", "int"),
    "encoder_input_mask": ("var_len", "int"),
    "encoder_input_type_ids": ("var_len", "int"),
    "text": ("var_len", "bytes")
}

tfrecord_dev_dir = '/tmp/bbc_xsum/roberta/dev'
tfwriter = TFWriter(schema=schema, 
                    model_dir=tfrecord_dev_dir,
                    tag='eval',
                    overwrite=True,
                    verbose_counter=10000
                    )
tfwriter.process(parse_fn=parse_dev())

Tag eval


INFO:absl:Wrote 10000 tfrecods
INFO:absl:Total individual observations/examples written is 11332 in 40.449267625808716 seconds
INFO:absl:All writer objects closed


### Read Validation Dataset

In [6]:
# Read Data


encoder_max_length=512
decoder_max_length=64
tfrecord_dev_dir = '/tmp/bbc_xsum/roberta/dev'

schema = json.load(open("{}/schema.json".format(tfrecord_dev_dir)))
stats = json.load(open("{}/stats.json".format(tfrecord_dev_dir)))
total_train_examples = stats['total_records']

all_files = glob.glob("{}/*.tfrecord".format(tfrecord_dev_dir))
tf_reader = TFReader(schema=schema, 
                    tfrecord_files=all_files)

x_keys = ['encoder_input_ids', 'encoder_input_type_ids', 'encoder_input_mask']
y_keys = ['text']

eval_batch_size = 4
eval_dataset = tf_reader.read_record(auto_batch=True, 
                                   keys=x_keys,
                                   batch_size=eval_batch_size, 
                                   x_keys = x_keys, 
                                   y_keys = y_keys,
                                   shuffle=False, 
                                   drop_remainder=False
                                  )

In [7]:
def model_fn():
    encoder_layer = RobertaModel.from_pretrained("roberta-base", return_layer=True)
    decoder_layer = RobertaModel.from_pretrained("roberta-base",
                                                 mask_mode="causal",
                                                 use_decoder=True,
                                                 return_layer=True)

    decoder_layer._embedding_layer = encoder_layer._embedding_layer
    decoder_layer._type_embeddings_layer = encoder_layer._type_embeddings_layer
    model_layer = EncoderDecoder(encoder=encoder_layer, decoder=decoder_layer, decoder_start_token_id=tokenizer.cls_token_id)
    model = model_layer.get_model()
    return model


def inference_model_fn():
    encoder_layer = RobertaModel.from_pretrained("roberta-base", return_layer=True)
    decoder_layer = RobertaModel.from_pretrained("roberta-base",
                                                 mask_mode="causal",
                                                 use_decoder=True,
                                                 use_auto_regressive=True,
                                                 return_layer=True)

    decoder_layer._embedding_layer = encoder_layer._embedding_layer
    decoder_layer._type_embeddings_layer = encoder_layer._type_embeddings_layer
    model_layer = EncoderDecoder(encoder=encoder_layer, decoder=decoder_layer, decoder_start_token_id=tokenizer.cls_token_id)
    model = model_layer.get_model()
    return model

In [12]:
def lm_loss(y_true_dict, y_pred_dict):
    
    loss = cross_entropy_loss(labels=y_true_dict['labels'], 
                                   logits=y_pred_dict['token_logits'], 
                                      label_weights=y_true_dict['labels_mask'])
    return {'loss': loss}

def get_optimizer(learning_rate, examples, batch_size, epochs):
    """Get optimizer"""
    steps_per_epoch = int(examples / batch_size)
    num_train_steps = steps_per_epoch * epochs
    warmup_steps = int(0.1 * num_train_steps)

    def optimizer_fn():
        optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps)
        return optimizer

    return optimizer_fn

learning_rate = 5e-05
epochs = 5

# Load optimizer
optimizer_fn = get_optimizer(
    learning_rate, total_train_examples, batch_size, epochs
)

# Trainer
distribution_strategy = "mirrored"
num_gpus = 2
trainer = Trainer(distribution_strategy, num_gpus=num_gpus)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [10]:
model_auto_regressive = inference_model_fn()
rouge_callback = TextGenerationMetricCallback(model=model_auto_regressive,
                                              tokenizer=tokenizer,
                                        decoder_kwargs = {"mode": "greedy" , "max_iterations": 64, "eos_id": -100},
                                        decoder_start_token_id=tokenizer.cls_token_id,
                                        input_type_ids=0
                                             )

You are using a model of type roberta to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/roberta-base/ckpt-1
You are using a model of type roberta to instantiate a model of type . This is not supported for all configurations of models and can yield errors.



Two checkpoint references resolved to different objects (<tf_transformers.models.roberta.roberta.RobertaEncoder object at 0x7fdd6816c7c0> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fdf2c7a9f70>).



Two checkpoint references resolved to different objects (<tf_transformers.models.roberta.roberta.RobertaEncoder object at 0x7fdd6816c7c0> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fdf2c7a9f70>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/roberta-base/ckpt-1
INFO:absl:Setting decoder_start_token_id = 0


In [None]:
# steps_per_epoch = total_train_examples//batch_size
# model_checkpoint_dir = 'roberta_seq2seq_temp'
# history = trainer.run(
#     model_fn=model_fn,
#     optimizer_fn=optimizer_fn,
#     train_dataset=train_dataset.take(3),
#     train_loss_fn=lm_loss,
#     epochs=epochs,
#     steps_per_epoch=steps_per_epoch,
#     model_checkpoint_dir=model_checkpoint_dir,
#     batch_size=batch_size,
#     validation_dataset=eval_dataset.take(5),
#     callbacks=[rouge_callback],
#     repeat_dataset=True,
# )



steps_per_epoch = 10
model_checkpoint_dir = 'roberta_seq2seq_temp'
history = trainer.run(
    model_fn=model_fn,
    optimizer_fn=optimizer_fn,
    train_dataset=train_dataset.take(3),
    train_loss_fn=lm_loss,
    epochs=3,
    steps_per_epoch=steps_per_epoch,
    model_checkpoint_dir=model_checkpoint_dir,
    batch_size=batch_size,
    validation_dataset=eval_dataset.take(5),
    callbacks=[rouge_callback],
    steps_per_call=1,
    repeat_dataset=True,
)

INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.
INFO:absl:Policy: ----> float32
INFO:absl:Strategy: ---> <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fdff07dc250>
INFO:absl:Num GPU Devices: ---> 2
You are using a model of type roberta to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/roberta-base/ckpt-1
You are using a model of type roberta to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/roberta-base/ckpt-1
INFO:absl:Setting decoder_start_token_id = 0
INFO:absl:Using Adamw optimizer
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from roberta_seq2seq_temp/ckpt-1
INFO:absl:Succesfully 









INFO:tensorflow:batch_all_reduce: 510 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 510 all-reduces with algorithm = nccl, num_packs = 1






INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).










INFO:tensorflow:batch_all_reduce: 510 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 510 all-reduces with algorithm = nccl, num_packs = 1






INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/3 --- Step 10/10 --- total examples 288: 100%|██████████| 10/10 [01:44<00:00, 10.40s/batch , learning_rate=2.82e-6, loss=10.3]
INFO:absl:Model saved at epoch 1 at roberta_seq2seq_temp/ckpt-1
INFO:absl:Callbacks in progress at epoch end 1 . . . .






INFO:absl:Callback for ('rouge',) is in progress . . . . . . . . . .



Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.encoder_decoder.EncoderDecoder object at 0x7fdf94153d00> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fdff06a4820>).



Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.encoder_decoder.EncoderDecoder object at 0x7fdf94153d00> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fdff06a4820>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from roberta_seq2seq_temp/ckpt-1






























































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































INFO:tensorflow:Assets written to: /tmp/tmp_g33vq3n/assets


INFO:tensorflow:Assets written to: /tmp/tmp_g33vq3n/assets


In [7]:
# Inference Model
model_checkpoint_dir = 'roberta_seq2seq'
model = inference_model_fn()
# model.load_checkpoint(model_checkpoint_dir)

You are using a model of type roberta to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/roberta-base/ckpt-1
You are using a model of type roberta to instantiate a model of type . This is not supported for all configurations of models and can yield errors.



Two checkpoint references resolved to different objects (<tf_transformers.models.roberta.roberta.RobertaEncoder object at 0x7f287a818d90> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f27f201eeb0>).



Two checkpoint references resolved to different objects (<tf_transformers.models.roberta.roberta.RobertaEncoder object at 0x7f287a818d90> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f27f201eeb0>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/roberta-base/ckpt-1
INFO:absl:Setting decoder_start_token_id = 0


In [9]:
# coding=utf-8
# Copyright 2021 TF-Transformers Authors.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple metric callback for some Text Generation metrics in Tensorflow 2.0"""

import tempfile

import tensorflow as tf
import tqdm
from absl import logging
from rouge_score import rouge_scorer, scoring

from tf_transformers.text import TextDecoder

_ALL_METRIC_NAMES = {'rouge': True}


class TextGenerationMetricCallback:
    def __init__(
        self,
        model,
        tokenizer,
        decoder_kwargs={"mode": "greedy", "max_iterations": 64, "eos_id": -100},
        decoder_start_token_id=None,
        input_mask_ids=-1,
        input_type_ids=-1,
        metric_name_list=('rouge',),
        validation_dataset: tf.data.Dataset = None,
    ) -> None:
        """

        Args:
            validation_dataset (tf.data.Dataset, optional): Validation dataset
        """
        for metric_name in metric_name_list:
            if metric_name not in _ALL_METRIC_NAMES:
                raise ValueError(
                    "metric {} not found in supported metric list {}".format(metric_name, _ALL_METRIC_NAMES)
                )
        self.model = model
        self.tokenizer = tokenizer
        self.decoder_kwargs = decoder_kwargs
        self.metric_name_list = metric_name_list
        self.decoder_start_token_id = decoder_start_token_id
        self.input_mask_ids = input_mask_ids
        self.input_type_ids = input_type_ids
        self.validation_dataset = validation_dataset

    def __call__(self, trainer_kwargs):
        """This is getting called inside the trainer class"""
        logging.info("Callback for {} is in progress . . . . . . . . . .".format(self.metric_name_list))
        scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeLsum"], use_stemmer=True)
        aggregator = scoring.BootstrapAggregator()
        # This is non distribute
        validation_dataset = trainer_kwargs['validation_dataset']
        model_checkpoint_dir = trainer_kwargs['model_checkpoint_dir']
        # No validation dataset has been provided
        if validation_dataset is None:
            if self.validation_dataset is None:
                raise ValueError(
                    "No validation dataset has been provided either in the trainer class, \
                                 or when callback is initialized. Please provide a validation dataset"
                )
            else:
                validation_dataset = self.validation_dataset

        # Model from trainer
        self.model.load_checkpoint(model_checkpoint_dir)
        dirpath = tempfile.mkdtemp()

        #save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
        self.model.save_transformers_serialized(dirpath, overwrite=True)

        #load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
        loaded = tf.saved_model.load(dirpath)

        decoder = TextDecoder(
            model=loaded,
            decoder_start_token_id=self.decoder_start_token_id,
            input_type_ids=self.input_type_ids,
            input_mask_ids=self.input_mask_ids,
        )

        original_summaries = []
        predicted_summaries = []
        # Save model as saved_model and load it
        for dist_inputs in tqdm.tqdm(validation_dataset):
            batch_inputs, batch_labels = dist_inputs
            decoder_outputs = decoder.decode(batch_inputs, **self.decoder_kwargs)

            predicted_ids = decoder_outputs['predicted_ids']
            predicted_ids_sliced = []
            predicted_ids = predicted_ids[:, 0, :]
            # beam or top_k_top_p
            if decoder_outputs['matched_eos_pos'].ndim == 2:
                matched_eos_pos = decoder_outputs['matched_eos_pos'][0]
            else:
                matched_eos_pos = decoder_outputs['matched_eos_pos']
            for index, single_tensor in enumerate(predicted_ids):
                eos_index = matched_eos_pos[index]
                predicted_ids_sliced.append(single_tensor[:eos_index].numpy().tolist())

            predicted_summaries_text = self.tokenizer.batch_decode(predicted_ids_sliced, skip_special_tokens=True)
            predicted_summaries.extend(predicted_summaries_text)

            if batch_labels['text'].ndim == 2:
                original_labels = [text.numpy().decode() for text in tf.squeeze(batch_labels['text'], axis=1)]
                original_summaries.extend(original_labels)
            else:
                original_labels = [text.numpy().decode() for text in batch_labels['text']]
                original_summaries.extend(original_labels)

        assert len(original_summaries) == len(predicted_summaries)
        for i in range(len(original_summaries)):
            score = scorer.score(original_summaries[i], predicted_summaries[i])
            aggregator.add_scores(score)

        result = {}
        result['rouge2_f1score_mid'] = aggregator.aggregate()['rouge2'].mid.fmeasure
        result['rouge1_f1score_mid'] = aggregator.aggregate()['rouge1'].mid.fmeasure
        result['rougel_f1score_mid'] = aggregator.aggregate()['rougeLsum'].mid.fmeasure

        return result


In [25]:
eval_dataset = eval_dataset.take(10)

In [26]:
trainer_kwargs = {"validation_dataset": eval_dataset, "model_checkpoint_dir": model_checkpoint_dir}

In [27]:
result = rouge_callback(trainer_kwargs)

INFO:absl:Callback for ('rouge',) is in progress . . . . . . . . . .



Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.encoder_decoder.EncoderDecoder object at 0x7f2608436a00> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f2600756f10>).



Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.encoder_decoder.EncoderDecoder object at 0x7f2608436a00> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f2600756f10>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from roberta_seq2seq/ckpt-5






























































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































INFO:tensorflow:Assets written to: /tmp/tmp5kh5ejlw/assets


INFO:tensorflow:Assets written to: /tmp/tmp5kh5ejlw/assets
10it [00:16,  1.67s/it]


In [28]:
result

{'rouge2_f1score_mid': 0.1418022286934874,
 'rouge1_f1score_mid': 0.35100887191172436,
 'rougel_f1score_mid': 0.2742538036079153}