In [2]:
# Install tf-transformers from github

# Roberta2Roberta + Summarization + Xsum

This tutorial contains code to fine-tune a Roberta2Roberta encoder-decoder model for summarization.

In this notebook:
- Load the data + create ```tf.data.Dataset``` using TFWriter
- Load and warmstart Roberta base and use it to create a Summarization Model
- Train using ```tf.keras.Model.fit``` and ```Custom Trainer``` 
- Minimize the LM loss
- Evaluate ROUGE score
- Run in production with a faster ```tf.SavedModel```, without needing the model architecture code

In [3]:
import datasets
import json
import os
import glob
import time

from tf_transformers.models import EncoderDecoderModel
from transformers import RobertaTokenizer
from tf_transformers.data import TFWriter, TFReader, TFProcessor
from tf_transformers.losses import cross_entropy_loss
from tf_transformers.core import optimization, SimpleTrainer
from absl import logging
# Surface absl INFO logs (TFWriter progress, model/checkpoint loading messages).
logging.set_verbosity(logging.INFO)

### Load Tokenizer

In [5]:
# Tokenizer for the roberta-base vocabulary; used both to encode documents and to decode generated summaries.
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path='roberta-base')


### Load XSum summarization data using Hugging Face datasets

In [4]:
# Directory holding the XSum DatasetDict (loaded with `datasets.load_from_disk`).
# Override it with the XSUM_DATA_DIR environment variable instead of editing the notebook.
XSUM_DATA_DIR = os.environ.get("XSUM_DATA_DIR", "/mnt/home/PRE_MODELS/HuggingFace_models/datasets/xsum/")
examples = datasets.load_from_disk(XSUM_DATA_DIR)
train_examples = examples["train"]


### Parse train data

In [38]:
encoder_max_length = 512  # max encoder tokens per document, including the CLS and SEP tokens
decoder_max_length = 64   # max decoder tokens per summary, including the CLS and SEP tokens


def parse_train(records=None, tok=None):
    """Yield one teacher-forcing training example per XSum record.

    The encoder sees ``CLS document SEP`` truncated to ``encoder_max_length``
    tokens. The summary is encoded as ``CLS summary SEP`` (at most
    ``decoder_max_length`` tokens) and shifted for teacher forcing: the decoder
    input drops the last token and the labels drop the first, so decoder
    position ``i`` is trained to predict summary token ``i + 1``.

    Args:
        records: Iterable of dicts with ``'document'`` and ``'summary'`` strings.
            Defaults to the notebook-level ``train_examples``.
        tok: Tokenizer exposing ``cls_token``, ``sep_token``, ``tokenize`` and
            ``convert_tokens_to_ids``. Defaults to the notebook-level ``tokenizer``.

    Yields:
        A fresh dict with keys ``encoder_input_ids``, ``encoder_input_mask``,
        ``encoder_input_type_ids``, ``decoder_input_ids``,
        ``decoder_input_type_ids``, ``labels`` and ``labels_mask``, each mapping
        to a list of ints.
    """
    if records is None:
        records = train_examples
    if tok is None:
        tok = tokenizer
    for record in records:
        # Reserve 2 slots for the CLS and SEP tokens wrapped around the text.
        encoder_tokens = [tok.cls_token] + tok.tokenize(record['document'])[: encoder_max_length - 2] + [tok.sep_token]
        input_ids = tok.convert_tokens_to_ids(encoder_tokens)

        decoder_tokens = [tok.cls_token] + tok.tokenize(record['summary'])[: decoder_max_length - 2] + [tok.sep_token]
        decoder_ids = tok.convert_tokens_to_ids(decoder_tokens)
        num_decoder_steps = len(decoder_ids) - 1  # length after the one-token shift

        # The decoder needs no input mask: by default it runs in causal mask mode.
        yield {
            'encoder_input_ids': input_ids,
            'encoder_input_mask': [1] * len(input_ids),
            'encoder_input_type_ids': [0] * len(input_ids),
            'decoder_input_ids': decoder_ids[:-1],  # every token except the last
            'decoder_input_type_ids': [0] * num_decoder_steps,
            'labels': decoder_ids[1:],  # every token except the first
            'labels_mask': [1] * num_decoder_steps,
        }
        
# Write the training features to TFRecords with TFWriter.
# (For smaller data, TFProcessor can build the dataset directly instead.)

# Every feature is a variable-length list of ints.
schema = {
    feature_name: ("var_len", "int")
    for feature_name in (
        "encoder_input_ids",
        "encoder_input_mask",
        "encoder_input_type_ids",
        "decoder_input_ids",
        "decoder_input_type_ids",
        "labels",
        "labels_mask",
    )
}

tfrecord_train_dir = '../OFFICIAL_TFRECORDS/bbc_xsum/roberta/train'
tfrecord_filename = 'bbc_xsum'
tfwriter = TFWriter(
    schema=schema,
    file_name=tfrecord_filename,
    model_dir=tfrecord_train_dir,
    tag='train',
    overwrite=True,
)
tfwriter.process(parse_fn=parse_train())

INFO:absl:Wrote 1000 tfrecods
INFO:absl:Wrote 2000 tfrecods
INFO:absl:Wrote 3000 tfrecods
INFO:absl:Wrote 4000 tfrecods
INFO:absl:Wrote 5000 tfrecods
INFO:absl:Wrote 6000 tfrecods
INFO:absl:Wrote 7000 tfrecods
INFO:absl:Wrote 8000 tfrecods
INFO:absl:Wrote 9000 tfrecods
INFO:absl:Wrote 10000 tfrecods
INFO:absl:Wrote 11000 tfrecods
INFO:absl:Wrote 12000 tfrecods
INFO:absl:Wrote 13000 tfrecods
INFO:absl:Wrote 14000 tfrecods
INFO:absl:Wrote 15000 tfrecods
INFO:absl:Wrote 16000 tfrecods
INFO:absl:Wrote 17000 tfrecods
INFO:absl:Wrote 18000 tfrecods
INFO:absl:Wrote 19000 tfrecods
INFO:absl:Wrote 20000 tfrecods
INFO:absl:Wrote 21000 tfrecods
INFO:absl:Wrote 22000 tfrecods
INFO:absl:Wrote 23000 tfrecods
INFO:absl:Wrote 24000 tfrecods
INFO:absl:Wrote 25000 tfrecods
INFO:absl:Wrote 26000 tfrecods
INFO:absl:Wrote 27000 tfrecods
INFO:absl:Wrote 28000 tfrecods
INFO:absl:Wrote 29000 tfrecods
INFO:absl:Wrote 30000 tfrecods
INFO:absl:Wrote 31000 tfrecods
INFO:absl:Wrote 32000 tfrecods
INFO:absl:Wrote 3

### Read TFRecords using TFReader

In [7]:
# Read the training TFRecords back as a batched tf.data.Dataset.

# schema.json sits next to the .tfrecord shards in tfrecord_train_dir
# (presumably written there by the TFWriter call above).
# Use a context manager so the schema file handle is closed deterministically.
with open(os.path.join(tfrecord_train_dir, "schema.json")) as schema_file:
    schema = json.load(schema_file)
all_files = glob.glob(os.path.join(tfrecord_train_dir, "*.tfrecord"))
# Fail loudly here rather than with an obscure error deeper inside TFReader.
assert all_files, "No .tfrecord files found in {}".format(tfrecord_train_dir)
tf_reader = TFReader(schema=schema,
                     tfrecord_files=all_files)

# Model inputs (x) and loss targets (y) are split by key.
x_keys = ['encoder_input_ids', 'encoder_input_type_ids', 'encoder_input_mask', 'decoder_input_ids', 'decoder_input_type_ids']
y_keys = ['labels', 'labels_mask']
batch_size = 8
# NOTE(review): `keys` lists only x_keys even though y_keys are consumed as labels;
# confirm against TFReader.read_record whether it should be x_keys + y_keys.
train_dataset = tf_reader.read_record(auto_batch=True,
                                      keys=x_keys,
                                      batch_size=batch_size,
                                      x_keys=x_keys,
                                      y_keys=y_keys,
                                      shuffle=True,
                                      drop_remainder=True
                                      )

In [8]:
# Peek at a single batch: model inputs (keyed by x_keys) and labels (keyed by y_keys).
for sample in train_dataset.take(1):
    batch_inputs, batch_labels = sample
    print(batch_inputs, batch_labels)

{'decoder_input_ids': <tf.Tensor: shape=(8, 31), dtype=int32, numpy=
array([[    0,   133,  3940,    23, 19931,   230,  4183,  4869,   108,
         5494, 13700,    13,  5345,    34,   554,    63, 21727,     4,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0],
       [    0,   250,    92,   471,  3254,    16,   145,  2952,    13,
           10,   881, 25209,   334,    15,    10,  5411,  2946,  1602,
           25,   145,    15,     5,  3543,     9,     5,   232,     4,
            0,     0,     0,     0],
       [    0,  4030,   291, 17055,  8905,    33,    57,  2942,    11,
         1667,     9, 12426,    11,    10,  2311,     7,   847,  2078,
         4971,    11,  4860,   911,     4,     0,     0,     0,     0,
            0,     0,     0,     0],
       [    0, 16764,  3121, 13398,    34,  1147, 23750,    10,   780,
         1612,  1967,  1887,     7, 21662,    41,  1704, 16010,    18,
         2040,     9,  2429, 25711,    

### Load Roberta2Roberta (Encoder Decoder Model)

In [9]:
import tensorflow as tf

# Pretrained roberta-base checkpoint used to warm-start the encoder
# (the logs below also report the decoder being warm-started: 197/202 variables).
ROBERTA_CHECKPOINT_DIR = '/mnt/home/PRE_MODELS/LegacyAI_models/checkpoints/roberta-base/'

model_layer, model, config = EncoderDecoderModel(
    model_name='roberta-base',
    is_training=True,
    encoder_checkpoint_dir=ROBERTA_CHECKPOINT_DIR,
)

INFO:absl:Overwride mask_mode with user_defined
INFO:absl:Initialized Variables
INFO:absl:Overwride mask_mode with causal
INFO:absl:Initialized Variables
INFO:absl:Succesful: Model checkpoints matched
INFO:absl:Encoder loaded succesfully from /mnt/home/PRE_MODELS/LegacyAI_models/checkpoints/roberta-base/
INFO:absl:Warm started decoder 197/202 variables
INFO:absl:Inputs -->
INFO:absl:encoder_input_ids ---> Tensor("input_ids:0", shape=(None, None), dtype=int32)
INFO:absl:encoder_input_mask ---> Tensor("input_mask:0", shape=(None, None), dtype=int32)
INFO:absl:encoder_input_type_ids ---> Tensor("input_type_ids:0", shape=(None, None), dtype=int32)
INFO:absl:decoder_input_ids ---> Tensor("decoder_input_ids:0", shape=(None, None), dtype=int32)
INFO:absl:decoder_input_type_ids ---> Tensor("decoder_input_type_ids:0", shape=(None, None), dtype=int32)
INFO:absl:Initialized Variables
INFO:absl:Inputs -->
INFO:absl:encoder_input_ids ---> Tensor("input_ids:0", shape=(None, None), dtype=int32)
INFO:

### Define Loss

Loss function is simple.
* labels: 2D (batch_size x sequence_length)
* logits: 3D (batch_size x sequence_length x vocab_size)
* label_weights: 2D (batch_size x sequence_length) — a 0/1 mask; positions with weight 0 (e.g. padding) do not contribute to the loss

In [10]:
def lm_loss(y_true_dict, y_pred_dict):
    """Masked token-level cross-entropy between decoder logits and summary labels.

    Args:
        y_true_dict: Batch targets; reads ``'labels'`` and ``'labels_mask'``.
        y_pred_dict: Model outputs; reads ``'token_logits'``.

    Returns:
        Whatever ``cross_entropy_loss`` returns for these labels/logits/weights.
    """
    labels = y_true_dict['labels']
    logits = y_pred_dict['token_logits']
    # labels_mask acts as per-position weights so masked positions add no loss.
    label_weights = y_true_dict['labels_mask']
    return cross_entropy_loss(labels=labels, logits=logits, label_weights=label_weights)


### Define Optimizer

**PRO TIP**: These models are very sensitive to the optimizer settings, especially the learning rate, so experiment to find a good combination.

In [11]:
# Size of the XSum train split written above; keep in sync with len(train_examples).
train_data_size = 204045
learning_rate = 1e-05
EPOCHS = 3
# Integer floor division, same as int(train_data_size / batch_size) without a float round-trip.
steps_per_epoch = train_data_size // batch_size
num_train_steps = steps_per_epoch * EPOCHS
# Warm up over the first 10% of training steps.
warmup_steps = int(0.1 * num_train_steps)

# Create an optimizer with a learning-rate schedule.
optimizer_type = 'adamw'
adam_beta2 = 0.997
adam_epsilon = 1e-09
optimizer, learning_rate_fn = optimization.create_optimizer(learning_rate,
                                                            num_train_steps,  # single source of truth for total steps
                                                            warmup_steps,
                                                            optimizer_type=optimizer_type,
                                                            decay_function='linear',
                                                            adam_beta_2=adam_beta2,
                                                            adam_epsilon=adam_epsilon)

INFO:absl:using Adamw optimizer


### Train Using Keras :-)

- ```compile2``` lets you feed both the model outputs and the batch labels from the dataset directly into the loss function, without any extra plumbing.

Note: For ```compile2```, `loss` must be None and the loss function must be passed via `custom_loss`. Metrics are not supported for the time being.

In [12]:
# Keras fit: this is a short smoke run (2 epochs x 5 steps).
# For full training, raise `epochs` and/or drop `steps_per_epoch`
# so each epoch consumes the whole dataset.

# compile2 feeds model outputs and dataset labels straight into the loss:
# `loss` stays None and the loss function goes in `custom_loss`, keyed by model output.
keras_loss_fn = {'token_logits': lm_loss}
model.compile2(optimizer=optimizer, loss=None, custom_loss=keras_loss_fn)
history = model.fit(train_dataset, epochs=2, steps_per_epoch=5)

Epoch 1/2
















Epoch 2/2


### Train using SimpleTrainer (part of tf-transformers)

In [None]:
# Custom training loop with tf-transformers' SimpleTrainer.
# gradient_accumulation_steps can be set if required, though it hurt performance in my runs.

# Single definition of the output directory instead of repeating the path literal.
model_save_dir = "../OFFICIAL_MODELS/bbc_xsum/roberta2roberta"

history = SimpleTrainer(model=model,
                        optimizer=optimizer,
                        loss_fn=lm_loss,
                        # Repeating EPOCHS + 1 times is important; presumably it keeps the
                        # iterator from running dry before the last step (TODO confirm).
                        dataset=train_dataset.repeat(EPOCHS + 1),
                        epochs=EPOCHS,
                        num_train_examples=train_data_size,
                        batch_size=batch_size,
                        steps_per_call=100,
                        gradient_accumulation_steps=None)
model.save_checkpoint(model_save_dir)

### Save Models 

You can save models as checkpoints using the ```.save_checkpoint``` method, which is available on all ```LegacyModels```.

In [13]:
# Output directory for the checkpoint; the SavedModel exports below are written beneath it too.
model_save_dir = "../OFFICIAL_MODELS/bbc_xsum/roberta2roberta"
model.save_checkpoint(model_save_dir)  # reloaded later for auto-regressive inference

### Load the model for Text Generation (Auto-Regressive)

1. To use any model for auto-regressive tasks, we have to provide **"pipeline_mode='auto-regressive'"**

tf-transformers will handle everything for you internally


In [14]:
import tensorflow as tf

# Rebuild the encoder-decoder for generation:
# - is_training=False turns dropout off (see the use_dropout log line below);
# - pipeline_mode='auto-regressive' adds the decoder cache inputs used during decoding.
model_layer, model, config = EncoderDecoderModel(
    model_name='roberta-base',
    is_training=False,
    pipeline_mode='auto-regressive',
)

# Restore the fine-tuned weights into the inference graph.
model.load_checkpoint(model_save_dir)

INFO:absl:Overwride mask_mode with user_defined
INFO:absl:We are overwriding `is_training` is False to                         `is_training` to True with `use_dropout` is False, no effects on your inference pipeline
INFO:absl:Initialized Variables
INFO:absl:Overwride mask_mode with causal
INFO:absl:Initialized Variables
INFO:absl:Inputs -->
INFO:absl:encoder_input_ids ---> Tensor("input_ids:0", shape=(None, None), dtype=int32)
INFO:absl:encoder_input_mask ---> Tensor("input_mask:0", shape=(None, None), dtype=int32)
INFO:absl:encoder_input_type_ids ---> Tensor("input_type_ids:0", shape=(None, None), dtype=int32)
INFO:absl:decoder_input_ids ---> Tensor("decoder_input_ids:0", shape=(None, None), dtype=int32)
INFO:absl:decoder_input_type_ids ---> Tensor("decoder_input_type_ids:0", shape=(None, None), dtype=int32)
INFO:absl:decoder_all_cache_key ---> Tensor("all_cache_key:0", shape=(None, None, 12, None, 64), dtype=float32)
INFO:absl:decoder_all_cache_value ---> Tensor("all_cache_value:0", 


Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.EncoderDecoder object at 0x7f88ac6b7220> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f886029f4f0>).



Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.EncoderDecoder object at 0x7f88ac6b7220> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f886029f4f0>).
INFO:absl:Succesful: Model checkpoints matched


### Save the model as serialized version

This matters because the serialized model is significantly faster.
tf-transformers provides **save_as_serialize_module** for this.

In [15]:
# Export with tf-transformers' save_as_serialize_module, then load the SavedModel back.
saved_model_dir = "{}/saved_model".format(model_save_dir)
model.save_as_serialize_module(saved_model_dir)

loaded = tf.saved_model.load(saved_model_dir)

### Parse validation data

We use ```TFProcessor``` to create validation data, because dev data is small

In [16]:
# Reuse the XSum DatasetDict loaded at the top of the notebook rather than
# re-reading the same hardcoded directory from disk.
dev_examples = examples['validation']

# Must match the lengths used when building the training TFRecords.
encoder_max_length = 512
decoder_max_length = 64


def parse_dev(records=None, tok=None):
    """Yield encoder-only inputs for each dev record (summaries are not encoded).

    Each document becomes ``CLS document SEP``, truncated to
    ``encoder_max_length`` tokens, exactly as for training.

    Args:
        records: Iterable of dicts with a ``'document'`` string. Defaults to the
            notebook-level ``dev_examples``.
        tok: Tokenizer exposing ``cls_token``, ``sep_token``, ``tokenize`` and
            ``convert_tokens_to_ids``. Defaults to the notebook-level ``tokenizer``.

    Yields:
        A fresh dict with ``encoder_input_ids``, ``encoder_input_mask`` and
        ``encoder_input_type_ids``, each a list of ints.
    """
    if records is None:
        records = dev_examples
    if tok is None:
        tok = tokenizer
    for record in records:
        # Reserve 2 slots for the CLS and SEP tokens wrapped around the text.
        tokens = [tok.cls_token] + tok.tokenize(record['document'])[: encoder_max_length - 2] + [tok.sep_token]
        input_ids = tok.convert_tokens_to_ids(tokens)
        yield {
            'encoder_input_ids': input_ids,
            'encoder_input_mask': [1] * len(input_ids),
            'encoder_input_type_ids': [0] * len(input_ids),
        }
        
# The dev split is small, so build the dataset directly with TFProcessor
# (no TFRecords), then pad and batch it 32 examples at a time.
tf_processor = TFProcessor()
unbatched_dev = tf_processor.process(parse_fn=parse_dev())
dev_dataset = tf_processor.auto_batch(unbatched_dev, batch_size=32)

INFO:absl:Processed  10000 examples so far
INFO:absl:Total individual observations/examples written is 11332


In [17]:
# Peek at one padded batch of dev inputs.
# (Fixed the `dev_datasetset` typo, which raised NameError on a fresh kernel.)
for batch_inputs in dev_dataset.take(1):
    print(batch_inputs)

{'encoder_input_ids': <tf.Tensor: shape=(32, 512), dtype=int32, numpy=
array([[   0,  133,  247, ...,    0,    0,    0],
       [   0, 3084,  758, ..., 1341,  239,    2],
       [   0, 1121,   10, ...,    0,    0,    0],
       ...,
       [   0,  133,  517, ..., 8874,    9,    2],
       [   0,  500,  636, ...,    0,    0,    0],
       [   0,  487, 1116, ...,    0,    0,    0]], dtype=int32)>, 'encoder_input_mask': <tf.Tensor: shape=(32, 512), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'encoder_input_type_ids': <tf.Tensor: shape=(32, 512), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}


### Text-Generation for dev dataset

1. For **EncoderDecoder** models like Roberta2Roberta, Bert2GPT, T5, BART use **TextDecoderSeq2Seq**
2. For single-stack models (encoder-only or decoder-only) like GPT2, BERT, Roberta use **TextDecoder**

In [None]:
from tf_transformers.text import TextDecoderSeq2Seq

# `model` can be either the SavedModel (`loaded`) or the Keras model.
# Measured here: SavedModel ~1200 seconds vs Keras model ~2300 seconds,
# so always prefer the SavedModel for faster inference in production.
decoder = TextDecoderSeq2Seq(model=loaded,
                             decoder_start_token_id=tokenizer.cls_token_id,  # decoder always expects a start token id
                             decoder_input_type_ids=0  # constant type id fed to the decoder
                             )

# Greedy decoding over the whole dev set.
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
start_time = time.perf_counter()
predicted_summaries = []
for batch_inputs in dev_dataset:
    model_outputs = decoder.decode(batch_inputs,
                                   mode='greedy',
                                   max_iterations=decoder_max_length,  # same 64-token budget used in training
                                   eos_id=tokenizer.sep_token_id)

    # Axis 1 of predicted_ids is squeezed out before decoding
    # (presumably a per-example sequence/beam axis of size 1 for greedy decoding).
    output_summaries = tokenizer.batch_decode(tf.squeeze(model_outputs['predicted_ids'], 1), skip_special_tokens=True)
    predicted_summaries.extend(output_summaries)
end_time = time.perf_counter()
print("Time taken is {}".format(end_time - start_time))

### Evaluate ROUGE score using Huggingface datasets metric

In [1]:
original_summaries = [item['summary'] for item in dev_examples]
# Predictions and references are paired by position, so the counts must agree.
assert len(predicted_summaries) == len(original_summaries), (
    "Got {} predictions for {} references".format(len(predicted_summaries), len(original_summaries)))

rouge = datasets.load_metric("rouge")
# One compute() call for all three ROUGE variants instead of scoring the whole dev set three times.
rouge_scores = rouge.compute(predictions=predicted_summaries,
                             references=original_summaries,
                             rouge_types=["rouge1", "rouge2", "rougeL"])
rouge_output1 = rouge_scores["rouge1"].mid
rouge_output2 = rouge_scores["rouge2"].mid
rouge_outputL = rouge_scores["rougeL"].mid
print("Rouge1", rouge_output1)
print("Rouge2", rouge_output2)
print("RougeL", rouge_outputL)

Rouge1 Score(precision=0.4030931388183621, recall=0.36466254213804195, fmeasure=0.37530511219642493)
Rouge2 Score(precision=0.16782261295821255, recall=0.15203057122838165, fmeasure=0.1561568579032115)
RougeL Score(precision=0.327351036176015, recall=0.2969254630660535, fmeasure=0.30522124104859427)

SyntaxError: invalid syntax (<ipython-input-1-355446ca4e30>, line 10)

### Evaluate ROUGE score using Google rouge_score library

In [None]:
from rouge_score import rouge_scorer
from rouge_score import scoring

# Pairs are matched by position; a mismatch would otherwise raise IndexError
# or silently score only a prefix of the references.
assert len(predicted_summaries) == len(original_summaries), (
    "Got {} predictions for {} references".format(len(predicted_summaries), len(original_summaries)))

# NOTE: rougeLsum splits on newlines; with single-line summaries it behaves like a summary-level rougeL.
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeLsum"], use_stemmer=True)
aggregator = scoring.BootstrapAggregator()

for reference, prediction in zip(original_summaries, predicted_summaries):
    # RougeScorer.score expects (target, prediction) in that order.
    aggregator.add_scores(scorer.score(reference, prediction))

# Bootstrap-aggregated low/mid/high confidence bounds per ROUGE type.
print("Aggregated scores", aggregator.aggregate())

{'rouge1': AggregateScore(low=Score(precision=0.4140161871825283, recall=0.37430234257946526, fmeasure=0.38546441367482415), mid=Score(precision=0.4170448579572875, recall=0.377157245705547, fmeasure=0.3882117449428194), high=Score(precision=0.4201437437963473, recall=0.37995542626327267, fmeasure=0.39091534567006364)),
 'rouge2': AggregateScore(low=Score(precision=0.1691630552674417, recall=0.1529611842223162, fmeasure=0.1573395483096894),mid=Score(precision=0.17186887531131723, recall=0.1556336324896489, fmeasure=0.15993720424744032), high=Score(precision=0.17473545118625014, recall=0.1582221241761583, fmeasure=0.16259483307298692)),
 'rougeLsum': AggregateScore(low=Score(precision=0.33240856683967057, recall=0.301583552325229, fmeasure=0.3101081853974168), mid=Score(precision=0.33521687141849427, recall=0.3040523188928863, fmeasure=0.3125425409744611), high=Score(precision=0.33844188019812493, recall=0.3066532822640133, fmeasure=0.31522376216006714))}

### In Production
1. Let's see how we can deploy this model in production

In [35]:
from tf_transformers.text import TextDecoderSeq2Seq
from tf_transformers.data import pad_dataset

# 1. Load the SavedModel exported earlier.
saved_model_path = "{}/saved_model".format(model_save_dir)
loaded = tf.saved_model.load(saved_model_path)

# 2. Wrap it in a seq2seq decoding helper.
decoder = TextDecoderSeq2Seq(
    model=loaded,
    decoder_start_token_id=tokenizer.cls_token_id,  # the decoder always needs a start token id
    decoder_input_type_ids=0,  # type id fed to the decoder at every step
)

# 3. Convert text to inputs

# tokenizer_fn converts raw text -> model inputs.
# Make sure it returns a dict mapping each key -> list of lists.
# pad_dataset is a decorator, which automatically takes care of padding.

# You can write your own function instead; the model only expects inputs in the format above.
@pad_dataset
def tokenizer_fn(texts):
    """Encode raw documents into (unpadded) encoder inputs for the decoder.

    Each text becomes ``CLS text SEP`` truncated to ``encoder_max_length``
    tokens; padding is left to the ``pad_dataset`` decorator.

    Args:
        texts: Iterable of raw document strings.

    Returns:
        Dict mapping ``encoder_input_ids``, ``encoder_input_mask`` and
        ``encoder_input_type_ids`` to one list of ints per text.
    """
    encoded_texts = []
    for text in texts:
        # Reserve 2 slots for the CLS and SEP tokens wrapped around the text.
        tokens = [tokenizer.cls_token] + tokenizer.tokenize(text)[: encoder_max_length - 2] + [tokenizer.sep_token]
        encoded_texts.append(tokenizer.convert_tokens_to_ids(tokens))

    return {
        'encoder_input_ids': encoded_texts,
        'encoder_input_mask': [[1] * len(ids) for ids in encoded_texts],
        'encoder_input_type_ids': [[0] * len(ids) for ids in encoded_texts],
    }
        
    
# 4. Examples
# Two inputs to sanity-check the summarizer:
# text1 is encyclopedic prose about tulips.
text1 = '''Tulips (Tulipa) form a genus of spring-blooming perennial herbaceous bulbiferous geophytes (having bulbs as storage organs). The flowers are usually large, showy and brightly colored, generally red, pink, yellow, or white (usually in warm colors). They often have a different colored blotch at the base of the tepals (petals and sepals, collectively), internally. Because of a degree of variability within the populations, and a long history of cultivation, classification has been complex and controversial. The tulip is a member of the lily family, Liliaceae, along with 14 other genera, where it is most closely related to Amana, Erythronium and Gagea in the tribe Lilieae. There are about 75 species, and these are divided among four subgenera. The name "tulip" is thought to be derived from a Persian word for turban, which it may have been thought to resemble. Tulips originally were found in a band stretching from Southern Europe to Central Asia, but since the seventeenth century have become widely naturalised and cultivated (see map). In their natural state they are adapted to steppes and mountainous areas with temperate climates. Flowering in the spring, they become dormant in the summer once the flowers and leaves die back, emerging above ground as a shoot from the underground bulb in early spring.

Originally growing wild in the valleys of the Tian Shan Mountains, tulips were cultivated in Constantinople as early as 1055. By the 15th century, tulips were among the most prized flowers; becoming the symbol of the Ottomans.[2] While tulips had probably been cultivated in Persia from the tenth century, they did not come to the attention of the West until the sixteenth century, when Western diplomats to the Ottoman court observed and reported on them. They were rapidly introduced into Europe and became a frenzied commodity during Tulip mania. Tulips were frequently depicted in Dutch Golden Age paintings, and have become associated with the Netherlands, the major producer for world markets, ever since. In the seventeenth century Netherlands, during the time of the Tulip mania, an infection of tulip bulbs by the tulip breaking virus created variegated patterns in the tulip flowers that were much admired and valued. While truly broken tulips do not exist anymore, the closest available specimens today are part of the group known as the Rembrandts – so named because Rembrandt painted some of the most admired breaks of his time.[3]'''


# text2 is a film review (Drishyam 2).
text2 = '''By any yardstick, the 2013 blockbuster Drishyam is a hard act to follow. Writer-director Jeethu Joseph’s crime thriller starring Mohanlal, Meena, Asha Sharath and Siddique was so well-rounded in the writing and execution of its murder-and-subsequent-cover-up mystery and such a box-office superhit that it was remade in Tamil, Telugu, Hindi and Kannada, headlined by some of the biggest male stars of those industries, in addition to foreign revisitations in Sinhalese and Mandarin.

At the time, Jeethu was questioned about his script drawing on Japanese novelist Keigo Higashino’s The Devotion of Suspect X, but he denied the charge and said he was inspired instead by a real-life incident. Be that as it may, Drishyam 2: The Resumption is all the redemption he needs. In a country that does not have a great track record with whodunnits, pulling off a brilliant howdunnit and howhegotawaywithit like Drishyam was an achievement. Returning with a howhe’sstillgettingawaywithit and actually pulling it off is nothing short of incredible.

Drishyam 2 is a surprisingly satisfying sequel to a spectacular first film.

Jeethu Joseph’s new crime drama is set in the same Kerala town where the events of its precursor took place. Georgekutty (Mohanlal) is now the owner of a cinema theatre. His prosperity is reflected in the larger, posher house he currently occupies with his wife Rani (Meena) and their daughters Anju (Ansiba) and Anu (Esther Anil) on the same land where they earlier lived. He is still movie crazy. Rani and he are still a committed couple yet constantly sniping at each other as before. And they are still a rock-solid team in the upbringing of their girls.

The difference between then and now is twofold. First, the townsfolk had backed the family when IG Geetha Prabhakar (Asha Sharath) got after them on the suspicion that they killed her son. They are not so supportive any more, driven as they are by jealousy at Georgekutty’s rise in life.

Second, the experiences of Drishyam have had a deep psychological impact on both Rani and Anju. Rani is tormented by Georgekutty’s refusal to ever discuss what happened back then. The first half of Drishyam 2 constructs their continuing trauma and gradually establishes the fact that the police never gave up on the case. The second half is about the resumed investigation.'''

# 5. Choose the decoding strategy (greedy here), decode, and turn ids back into text.
batch_inputs = tokenizer_fn([text1, text2])
model_outputs = decoder.decode(
    batch_inputs,
    mode='greedy',
    max_iterations=64,
    eos_id=tokenizer.sep_token_id,
)

predicted_ids = tf.squeeze(model_outputs['predicted_ids'], 1)
output_summaries = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)

In [36]:
# Display the summaries generated for text1 and text2.
output_summaries

["The world's most famous flower, the tulip, is one of the most important and important species of flower.",
 'The sequel to the cult film Drishyam is a timely sequel to one of the most popular films in the country.']

### Advanced

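**TextDecoderSeq2Seq** (used above) internally runs a Python for loop over decoding steps.

Can we do better? If we use ```tf.while_loop``` instead (as **TextDecoderSerializableSeq2Seq** does), we can save the whole model, decoding loop included, as a serialized model.
This not only improves speed but also makes life much easier in production.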

In [44]:
# Save the end-to-end decoder (model + greedy decoding loop) as one serialized model.

from tf_transformers.text import TextDecoderSerializableSeq2Seq
from tf_transformers.core import LegacyModule

decoder_layer = TextDecoderSerializableSeq2Seq(
    model=model,
    decoder_start_token_id=tokenizer.cls_token_id,  # the decoder always needs a start token id
    decoder_input_type_ids=0,  # type id fed to the decoder at every step
    mode="greedy",
    max_iterations=64,
    eos_id=tokenizer.sep_token_id,
)
decoder_model = decoder_layer.get_model()
decoder_module = LegacyModule(decoder_model)
decoder_module.save("{}/saved_decoder_model".format(model_save_dir))

### In Production (Advanced) - Just 2 lines of code.

In [47]:
# 1. Load the serialized end-to-end decoder.
decoder_serialized = tf.saved_model.load("{}/saved_decoder_model".format(model_save_dir))

# 2. Run it on the same tokenized batch (text1, text2) and decode ids to text.
model_outputs2 = decoder_serialized(**batch_inputs)
output_summaries2 = tokenizer.batch_decode(tf.squeeze(model_outputs2['predicted_ids'], 1), skip_special_tokens=True)

# The serialized decoder must reproduce the TextDecoderSeq2Seq summaries exactly.
assert output_summaries == output_summaries2, (output_summaries, output_summaries2)

# Successful :-)