<a href="https://colab.research.google.com/github/isikus/qualification-project/blob/master/notebooks/3.%20Model%20training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The T5 Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
# Copyright 2019 The T5 Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Model training
In this notebook we retrain the most successful model from our research up to the same checkpoint. This notebook is based on [this](https://github.com/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb) example notebook from the T5 authors.

**Please note the following:**
1. A connection to a Google Cloud Storage bucket is required to train the model.

## Set Up

<h3><a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>  &nbsp;&nbsp;Train on TPU</h3>




1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the `BASE_DIR` parameter in the form of the following cell. There is a [free tier](https://cloud.google.com/free/) if you do not yet have an account.
1. On the main menu, click **Runtime** and select **Change runtime type**. Set "TPU" as the hardware accelerator.
1. Run the following cell and follow the instructions to:
    * set up a Colab TPU runtime environment;
    * verify that you are connected to a TPU device;
    * upload your credentials to the TPU so that it can access your GCS bucket.


In [0]:
print("Installing dependencies...")
%tensorflow_version 2.x
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

BASE_DIR = "gs://" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  yield
  tf.logging.set_verbosity(og_level)

In [0]:
from google.colab import auth
auth.authenticate_user()
project_id = 'project-id'  # @param {"type": "string"}

!gcloud config set project {project_id}
!gsutil ls

### Try to add reproducibility

In [0]:
import random
import numpy as np

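def set_seed(seed):
  # Seed Python, NumPy and the TF1 graph-level random number generators.
  # This may still not make TPU training fully deterministic.
  random.seed(seed)
  np.random.seed(seed)
  tf.set_random_seed(seed)  # `tf` is tensorflow.compat.v1 in this notebook.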

set_seed(42)

## Create new Tasks and Mixture

Two core components of the T5 library are `Task` and `Mixture` objects.

A `Task` is a dataset along with preprocessing functions and evaluation metrics. A `Mixture` is a collection of `Task` objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent `Tasks`.
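
As a rough illustration of what a mixing rate means, the sketch below draws examples from several tasks with probability proportional to each task's rate. It is a simplified, hypothetical stand-in written in plain Python (`sample_mixture` is not part of the T5 library) and ignores details such as how T5 itself computes or caps rates.

```python
import itertools
import random


def sample_mixture(task_examples, rates, num_samples, seed=0):
    """Draw (task_name, example) pairs in proportion to each task's rate.

    Hypothetical illustration only; not the T5 library's implementation.
    """
    rng = random.Random(seed)
    names = list(task_examples)
    weights = [rates[name] for name in names]
    # Cycle through each task's examples so sampling never runs dry.
    streams = {name: itertools.cycle(task_examples[name]) for name in names}
    samples = []
    for _ in range(num_samples):
        # Choose a task with probability proportional to its rate.
        name = rng.choices(names, weights=weights, k=1)[0]
        samples.append((name, next(streams[name])))
    return samples


# With a single task, as in this notebook, every example comes from that task.
print(sample_mixture({"correct_3b": ["ex1", "ex2", "ex3"]}, {"correct_3b": 1.0}, 4))
# [('correct_3b', 'ex1'), ('correct_3b', 'ex2'), ('correct_3b', 'ex3'), ('correct_3b', 'ex1')]
```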

We will define a single `Task` that corrects our examples and then build a `Mixture` from this one task.

Later on, we will repeat the process to continue model training.

In [0]:
# Paths to the training and validation TSV files in the GCS bucket.
tsv_path = {
    "train": os.path.join(DATA_DIR, "3b-fuse1.tsv"),
    "validation": os.path.join(DATA_DIR, "correct-val.tsv")
}

Next, we define a function that loads the TSV data as a `tf.data.Dataset`.
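
Before registering the dataset, it can help to check a TSV file locally. The plain-Python sketch below assumes the same `<orig_text>\t<corr_text>` layout; `parse_tsv_lines` is a hypothetical helper, and `csv.QUOTE_NONE` is used to keep quote characters literal, which we believe matches `use_quote_delim=False` in `tf.io.decode_csv`.

```python
import csv
import io


def parse_tsv_lines(text):
    """Parse tab-separated (orig_text, corr_text) lines into dicts."""
    # QUOTE_NONE keeps quote characters as ordinary text instead of delimiters.
    reader = csv.reader(io.StringIO(text), delimiter="\t", quoting=csv.QUOTE_NONE)
    # Skip blank lines, which csv.reader returns as empty rows.
    return [dict(zip(["orig_text", "corr_text"], row)) for row in reader if row]


sample = 'He go home.\tHe goes home.\n"Hi" she say\t"Hi," she said\n'
for ex in parse_tsv_lines(sample):
    print(ex)
```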

In [0]:
def corr_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path[split])
  # Split each "<orig_text>\t<corr_text>" example into (orig_text, corr_text) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"orig_text": ... "corr_text": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["orig_text", "corr_text"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(corr_dataset_fn("validation").take(5)):
  print(ex)

Now we write a preprocessing function that converts the examples in the `tf.data.Dataset` into a text-to-text format with both `inputs` and `targets` fields. The preprocessor also normalizes the text by removing quotes, since the corr_texts are sometimes formatted in odd ways. Finally, we prepend `correction: ` to the inputs so that the model knows which task it is trying to solve.
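
The effect of this preprocessing can be previewed without TensorFlow. The sketch below re-implements it with Python's `re` module; we assume `re.sub` behaves like `tf.strings.regex_replace` for this simple single-line pattern. Note that the greedy `'(.*)'` pattern removes the first and last single quote anywhere in the string, so apostrophes inside words can be dropped as well.

```python
import re


def normalize_text(text):
    # Greedy match from the first to the last single quote; keep what is between.
    return re.sub(r"'(.*)'", r"\1", text)


def to_inputs_and_targets(ex):
    # Prepend the task prefix to the input and normalize both fields.
    return {
        "inputs": "correction: " + normalize_text(ex["orig_text"]),
        "targets": normalize_text(ex["corr_text"]),
    }


print(to_inputs_and_targets({"orig_text": "'He go home.'", "corr_text": "'He goes home.'"}))
# {'inputs': 'correction: He go home.', 'targets': 'He goes home.'}
print(normalize_text("I don't think it's fine."))
# I dont think its fine.
```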

In [0]:
def correction_preprocessor(ds):
  def normalize_text(text):
    """Drop the first and last single quote (greedy match) from a TF string."""
    text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"orig_text": ..., "corr_text": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
            tf.strings.join(
                ["correction: ", normalize_text(ex["orig_text"])]),
        "targets": normalize_text(ex["corr_text"])
    }
  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

Finally, we put everything together to create a `Task`.

In [0]:
# Remove any earlier registration so that this cell can be re-run safely.
t5.data.TaskRegistry.remove("correct_3b")
t5.data.TaskRegistry.add(
    "correct_3b",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=corr_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[correction_preprocessor],
    # Use the same vocabulary that was used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # Lowercase predictions and targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text,
    # Evaluate with exact-match accuracy, BLEU and ROUGE.
    metric_fns=[t5.evaluation.metrics.accuracy,
                t5.evaluation.metrics.bleu,
                t5.evaluation.metrics.rouge]
)

Let's look at a few pre-processed examples from the validation set.

In [0]:
corr_task = t5.data.TaskRegistry.get("correct_3b")
ds = corr_task.get_dataset(split="validation", sequence_length={"inputs": 512, "targets": 512})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

## Dataset Mixture

We now create a `Mixture` from the above `Task`, which we will fine-tune on.

In [0]:
t5.data.MixtureRegistry.remove("correct_3b_all")
t5.data.MixtureRegistry.add(
    "correct_3b_all",
    ["correct_3b"],
    default_rate=1.0
)

## Transferring to new Tasks

We are now ready to fine-tune one of the pre-trained T5 models on our new mixture.

First, we'll instantiate a `Model`.


## Caveats

* Due to its memory requirements, you will not be able to train the `11B` parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).
* Due to the checkpoint size, you will not be able to use the 5GB GCS free tier for the `3B` parameter model. You will need at least 25GB of space, which you can purchase with your $300 of initial credit on GCP.
* While `large` can achieve decent results, it is recommended that you fine-tune at least the `3B` parameter model.
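
A back-of-the-envelope calculation shows why the free tier is too small. The sketch assumes roughly 3e9 parameters stored at 4 bytes (float32) or 2 bytes (bfloat16) each, and counts the raw weights only; optimizer state and any extra checkpoints kept on disk would add to this.

```python
def weights_size_gb(num_params, bytes_per_param):
    """Approximate size of the raw weights in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


print(weights_size_gb(3e9, 4))  # float32: 12.0
print(weights_size_gb(3e9, 2))  # bfloat16: 6.0
```

Either figure already exceeds the 5GB free tier.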


## Define Model

In [0]:
run = "3b-retrain"  # @param {"type": "string"}

In [0]:
MODEL_SIZE = "3B" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
if run not in [None, ""]:
    MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE+"-"+run)
else:
    MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warn(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter model is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 128, 16),
    "base": (2, 64, 8),
    "large": (8, 32, 4),
    "3B": (8, 8, 1),
    "11B": (8, 8, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 512, "targets": 512},
    learning_rate_schedule=0.0025,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)
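
The `(model_parallelism, train_batch_size)` pairs above trade model splitting against data parallelism. The simplified sketch below assumes the Colab `v2-8` TPU exposes 8 cores and that cores not used to split the model are used to split the batch, as in a Mesh TensorFlow `model:N,batch:M` layout; `batch_layout` is a hypothetical helper, not a T5 function.

```python
NUM_CORES = 8  # assumption: Colab v2-8 TPU (2x2 topology, 2 cores per chip)

# (model_parallelism, train_batch_size) copied from the cell above.
SETTINGS = {
    "small": (1, 128),
    "base": (2, 64),
    "large": (8, 32),
    "3B": (8, 8),
}


def batch_layout(model_size, num_cores=NUM_CORES):
    """Return (data-parallel replicas, sequences per replica)."""
    model_parallelism, batch_size = SETTINGS[model_size]
    replicas = num_cores // model_parallelism
    return replicas, batch_size // replicas


for size in SETTINGS:
    print(size, batch_layout(size))
# small (8, 16)
# base (4, 16)
# large (1, 32)
# 3B (1, 8)
```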

## Fine-tune

We are now ready to fine-tune our model. This will take a while, so please be patient!

Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off.

In [0]:
FINETUNE_STEPS = 15000
print("Finetuning for", FINETUNE_STEPS, "steps")

model.finetune(
    mixture_or_task_name="correct_3b_all",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

## Repeat the procedure

To fuse the same data, now spell-checked, into the model, we repeat the previous steps with the new training file:

In [0]:
tsv_path = {
    "train": os.path.join(DATA_DIR, "3b-fuse2.tsv"),
    "validation": os.path.join(DATA_DIR, "correct-val.tsv")
}

In [0]:
# Remove the earlier registration first; adding a Task under an existing name raises an error.
t5.data.TaskRegistry.remove("correct_3b")
t5.data.TaskRegistry.add(
    "correct_3b",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=corr_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[correction_preprocessor],
    # Use the same vocabulary that was used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # Lowercase predictions and targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text,
    # Evaluate with exact-match accuracy, BLEU and ROUGE.
    metric_fns=[t5.evaluation.metrics.accuracy,
                t5.evaluation.metrics.bleu,
                t5.evaluation.metrics.rouge]
)

In [0]:
corr_task = t5.data.TaskRegistry.get("correct_3b")
ds = corr_task.get_dataset(split="validation", sequence_length={"inputs": 512, "targets": 512})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

In [0]:
t5.data.MixtureRegistry.remove("correct_3b_all")
t5.data.MixtureRegistry.add(
    "correct_3b_all",
    ["correct_3b"],
    default_rate=1.0
)

## Fine-tune

Now we continue fine-tuning until the fine-tuning step count reaches 25,200, i.e. another 10,200 steps on the spell-checked data.
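
The step arithmetic can be sketched as follows. It assumes, based on our reading of T5's `MtfModel.finetune`, that training runs until the global step reaches the pretrained checkpoint's step plus `finetune_steps`, resuming from the latest checkpoint in `MODEL_DIR`; the pretrained step of 1,000,000 is also an assumption.

```python
PRETRAINED_STEP = 1_000_000  # assumption: global step of the released T5 checkpoint


def target_global_step(finetune_steps, pretrained_step=PRETRAINED_STEP):
    """Global step at which a finetune() call stops (under the stated assumption)."""
    return pretrained_step + finetune_steps


first_run = target_global_step(15000)   # 1015000
second_run = target_global_step(25200)  # 1025200
print("Additional steps on the spell-checked data:", second_run - first_run)  # 10200
```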

In [0]:
FINETUNE_STEPS = 25200
print("Finetuning for", FINETUNE_STEPS, "steps")

model.finetune(
    mixture_or_task_name="correct_3b_all",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

## Evaluate

We now evaluate on the validation sets of the tasks in our mixture.

In [0]:
%%time

# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="correct_3b_all",
    checkpoint_steps=-1  # use latest checkpoint
)

Let's look at a few random predictions from the validation set. Note that accuracy is measured as an *exact match* between the predicted correction and the ground-truth correction. As a result, nearly all corrections will likely be counted as incorrect.
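
To make the metric concrete, the sketch below computes exact-match accuracy after lowercasing both sides, roughly mirroring the `lower_text` postprocessor and the accuracy metric registered for this `Task`. `exact_match_accuracy` is a hypothetical helper, not T5's implementation.

```python
def exact_match_accuracy(targets, predictions):
    """Percentage of predictions equal to their target after lowercasing."""
    matches = sum(t.lower() == p.lower() for t, p in zip(targets, predictions))
    return 100.0 * matches / len(targets)


targets = ["He goes home.", "I have a cat."]
predictions = ["he goes home.", "I have the cat."]
print(exact_match_accuracy(targets, predictions))  # 50.0
```

A single differing word is enough for a long essay to count as a miss; BLEU and ROUGE, also registered for this `Task`, give partial credit instead.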

In [0]:
import random

def print_random_predictions(task_name, n=10):
  """Print n predictions from the validation split of a task."""
  # Grab the dataset for this task.
  ds = t5.data.TaskRegistry.get(task_name).get_dataset(
      split="validation",
      sequence_length={"inputs": 512, "targets": 512},
      shuffle=False)

  def _prediction_file_to_ckpt(path):
    """Extract the global step from a prediction filename."""
    return int(path.split("_")[-2])

  # Grab the paths of all logged predictions.
  prediction_files = tf.io.gfile.glob(
      os.path.join(
          MODEL_DIR,
          "validation_eval/%s_*_predictions" % task_name))
  # Get most recent prediction file by sorting by their step.
  latest_prediction_file = sorted(
      prediction_files, key=_prediction_file_to_ckpt)[-1]

  # Collect (inputs, targets, prediction) from the dataset and predictions file
  results = []
  with tf.io.gfile.GFile(latest_prediction_file) as preds:
    for ex, pred in zip(tfds.as_numpy(ds), preds):
      results.append((tf.compat.as_text(ex["inputs_plaintext"]),
                      tf.compat.as_text(ex["targets_plaintext"]),
                      pred.strip()))

  print("<== Random predictions for %s using checkpoint %s ==>\n" %
        (task_name, 
         _prediction_file_to_ckpt(latest_prediction_file)))

  for inp, tgt, pred in random.choices(results, k=n):
    print("Input:", inp)
    print("Target:", tgt)
    print("Prediction:", pred)
    print("Counted as Correct?", tgt == pred)
    print()

print_random_predictions("correct_3b")

## Predict

Now that we have fine-tuned the model, we can feed T5 arbitrary texts and have it correct them for us!

There is a significant amount of overhead in initializing the model, so each run may take a few minutes even though the prediction itself is quite fast.


In [0]:
orig_text_1 = "In 1999 theunemployment figures has had a decrease on 15 percents." #@param {type:"string"}
orig_text_2 = "I agree with the statement because part-time job gave me lots of experiences. For example, I understood following things through part-time jobs. It is hard to make money because I have to go to part-time job when I promised day and time, even if I don't want it. I have to keep time because if I delayed, I interrupted colleague's jobs and make bad relationships. I can't stay I'm center of the world because guests choice restaurant which they want to go, so I need to make them comfortable for they choose my shop. I learned from part-time jobs how important money and to make money, to make relationship. But I learned it is makes me tired to have job at the same time. I need time to have a rest for recovery. So it is interrupt my study to have too much jobs. I agree with having a part-time job but it is better to work during long vacation like summer or spring. Because student first priority is study and part-time job should not interrupt their study." #@param {type:"string"}
orig_text_3 = "Internet is a new Fenomenon in our world. Our generation is aboped to internet easy, but How previous generation react introduse in our live. \u003Cbr> On diagram is showed persent of online adult in USA who use networks. What can we see? Firstly, the people's activity over 65 year old is lower than activity of another groups. Next, we can see that then higher age than lower share of people who use comunicatin networks (such as Facebook or instagram). But it is not true for work networks (such as Linkedin). Share People with age between 40 and 64 use work networks higher then share of othor groups. About Facebook. It is most popular network in all groups. It may be explained by different funccioning of this network." #@param {type:"string"}
orig_text_4 = "I always had consider myself been very looky to be born in the  Century because  all the amaizing human inventions. \u003Cbr> As an example let's have a look and we will see of this nation  how it has developed. We just need to think about the first world wold and all the inventions created thanks to the industrial revolution to these day. \u003Cbr> However I feel that at the moment there should emerge more grups to control this development because  the industrial polution that it has been created and it seems that its very little what has been done at the moment in relation to our embiroment. \u003Cbr> I hope in the future there will be more and more people interested in this important issue." #@param {type:"string"}

orig_texts = [orig_text_1, orig_text_2, orig_text_3, orig_text_4]
#orig_texts = [correct(text) for text in orig_texts]

now = time.time()
# Write out the supplied orig_texts to text files.
predict_inputs_path = os.path.join(MODEL_DIR, "predict_inputs_%d.txt" % now)
predict_outputs_path = os.path.join(MODEL_DIR, "predict_outputs_%d.txt" % now)
# Manually apply preprocessing by lowercasing each text and prepending "correction: ".
with tf.io.gfile.GFile(predict_inputs_path, "w") as f:
  for q in orig_texts:
    f.write("correction: %s\n" % q.lower())

# Suppress logging so that we only see the model's corrections.
with tf_verbosity_level('ERROR'):
  model.batch_size = 8  # Same as the 3B training batch size set above.
  model.predict(
      input_file=predict_inputs_path,
      output_file=predict_outputs_path,
      # Select the most probable output token at each step.
      temperature=0,
  )

# The output filename will have the checkpoint appended so we glob to get 
# the latest.
prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + "*"))
print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
with tf.io.gfile.GFile(prediction_files[-1]) as f:
  for q, a in zip(orig_texts, f):
    if q:
      print("Original: " + q)
      print("Corrected: " + a.strip())
      print()

## Export SavedModel

Finally, we export the fine-tuned model as a SavedModel so that it can be loaded again later for inference.

In [0]:
%%time

export_dir = os.path.join(MODEL_DIR, "export")

model.batch_size = 1 # make one prediction per call
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)
print("Model saved to:", saved_model_path)

## Load SavedModel

In [0]:
import tensorflow as tf
import tensorflow_text  # Required to run exported model.

def load_predict_fn(model_path):
  # The eager path is disabled: tf.disable_v2_behavior() was called earlier,
  # so eager loading did not work in this notebook.
  if tf.executing_eagerly() and False:
    print("Loading SavedModel in eager mode.")
    imported = tf.saved_model.load(model_path, ["serve"])
    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
  else:
    print("Loading SavedModel in tf 1.x graph mode.")
    with tf.device('/cpu:0'):
      tf.compat.v1.reset_default_graph()
      sess = tf.compat.v1.Session()
      meta_graph_def = tf.compat.v1.saved_model.load(sess, ["serve"], model_path)
      signature_def = meta_graph_def.signature_def["serving_default"]
      return lambda x: sess.run(
          fetches=signature_def.outputs["outputs"].name, 
          feed_dict={signature_def.inputs["input"].name: x}
      )

predict_fn = load_predict_fn(saved_model_path)

## Predict

We can now call `predict_fn` repeatedly with different inputs and get results relatively quickly.

In [0]:
def infer(text):
  return predict_fn([text])[0].decode('utf-8')

In [0]:
%%time

for orig_text in ["correction: " + q for q in orig_texts]:
    print(infer(orig_text))