<a href="https://colab.research.google.com/github/ivantor0/colab-usefulness/blob/master/Train%20heptabot%20scaled%20to%203b.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The T5 Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2019 The T5 Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Fine-Tuning the Text-To-Text Transfer Transformer (T5) for Closed-Book Question Answering
## _Or: What does T5 know?_

*The following tutorial guides you through the process of fine-tuning a pre-trained T5 model, evaluating its accuracy, and using it for prediction,
all on a free Google Cloud TPU <a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>.*

### Background

T5 was introduced in the paper [_Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer_](https://arxiv.org/abs/1910.10683). In that paper, we provided a comprehensive picture of how we pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.

We pre-trained T5 on a mixture of supervised and unsupervised tasks with the majority of data coming from an unlabeled dataset we developed called [C4](https://www.tensorflow.org/datasets/catalog/c4). C4 is based on a massive scrape of the web produced by [Common Crawl](https://commoncrawl.org). Loosely speaking, pre-training on C4 ideally gives T5 an understanding of natural language in addition to general world knowledge.

### How can we assess what T5 knows?

As the name implies, T5 is a text-to-text model, which enables us to train it on arbitrary tasks involving a textual input and output. As we showed in our paper, a huge variety of NLP tasks can be cast in this format, including translation, summarization, and even classification and regression tasks.

One way to use this text-to-text framework is on reading comprehension problems, where the model is fed some context along with a question and is trained to predict the question's answer. For example, we might feed the model the text from the Wikipedia article about [Hurricane Connie](https://en.wikipedia.org/wiki/Hurricane_Connie) along with the question "On what date did Hurricane Connie occur?" and train the model to predict the answer "August 3rd, 1955".
A related task is open-domain question answering (QA), where the model is not provided with this oracle context. Typically, open-domain QA systems include a mechanism to look up information in an external knowledge source. This setting is similar to an "open-book" exam.

In this notebook, we'll be training T5 on a variant of this task which we call **closed-book question answering**. In closed-book QA, we feed the model a question *without any context or access to external knowledge* and train it to predict the answer. Since the model doesn't receive any context, the primary way it can learn to answer these questions is based on the "knowledge" it obtained during pre-training. We don't expect T5 to contain super specific information, so we will be focusing on two question-answering datasets which largely include trivia questions (i.e. facts about well-known subjects). [Similar](https://arxiv.org/abs/1909.01066) [investigations](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) have recently been done to test the knowledge stored by BERT and GPT-2.
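
To make the difference between the two settings concrete, the sketch below builds `inputs`/`targets` pairs for a reading-comprehension example and for a closed-book example. The helper names and the prefix strings (`question:`, `context:`, `trivia question:`) are illustrative assumptions, not necessarily the exact formats used by the `t5` library.

```python
def reading_comprehension_example(question, context, answer):
    """Reading comprehension: the model sees the question and a passage."""
    return {
        "inputs": f"question: {question} context: {context}",
        "targets": answer,
    }


def closed_book_example(question, answer):
    """Closed-book QA: the model sees only the question."""
    return {
        "inputs": f"trivia question: {question}",
        "targets": answer,
    }


question = "On what date did Hurricane Connie occur?"
answer = "August 3rd, 1955"

rc = reading_comprehension_example(
    question, "<text of the Wikipedia article>", answer)
cb = closed_book_example(question, answer)

print(rc["inputs"])
print(cb["inputs"])
```

Both examples share the same target; only the information available in `inputs` changes.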

T5 was not pre-trained on closed-book QA, so in this notebook we'll first create two new tasks and then use the [`t5`](https://github.com/google-research/text-to-text-transfer-transformer) library to fine-tune, evaluate, and obtain predictions from T5. In the end, T5's performance on closed-book QA can give us a sense of what kind (and how much) information T5 managed to learn during pre-training.

## State-of-the-art Results
We published a [more in-depth investigation](https://arxiv.org/abs/2002.08910) of closed-book QA with T5 where we achieved SOTA on open-domain variants of WebQuestions and TriviaQA in addition to surprisingly strong results on Natural Questions. The code in this notebook is a simplified version of those experiments but still produces good results.

For code to reproduce our best results, please see the [t5_closed_book_qa](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa) repo.


### Caveats

* While we provide instructions for running on a [Cloud TPU](https://cloud.google.com/tpu/) via Colab for free, a [Google Cloud Storage (GCS)](http://console.cloud.google.com/storage) bucket is required for storing model parameters and data. The [GCS free tier](https://cloud.google.com/free/) provides 5 GB of storage, which should be enough to train the `large` model and smaller but not the `3B` or `11B` parameter models. You can use part of your initial $300 credit to get more space.
* The Cloud TPU provided by Colab (a `v2-8`) does not have enough memory to fine-tune the `11B` parameter model. For this model, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).


# Set Up

<h3><a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>  &nbsp;&nbsp;Train on TPU</h3>




   1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the `BASE_DIR` parameter in the following form. There is a [free tier](https://cloud.google.com/free/) if you do not yet have an account.
   1. On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
   1. Run the following cell and follow instructions to:
      * Set up a Colab TPU running environment
      * Verify that you are connected to a TPU device
      * Upload your credentials to TPU to access your GCS bucket


In [None]:
!pip install spacy==1.9.0
!python -m spacy download -d en_core_web_sm-1.2.0
!python -m spacy link en_core_web_sm en

Collecting spacy==1.9.0
  Downloading https://files.pythonhosted.org/packages/63/ce/afee53c365617e5f3e58825d71421bce14949a15f7150742d2a7b8859c53/spacy-1.9.0.tar.gz (3.4MB)
Collecting murmurhash<0.27,>=0.26
  Downloading https://files.pythonhosted.org/packages/ff/53/1f428861e59c2382e22b8839d03cc315e1a7633a827497b3d389b8d8772d/murmurhash-0.26.4.tar.gz
Collecting cymem<1.32,>=1.30
  Downloading https://files.pythonhosted.org/packages/a5/0f/d29aa68c55db37844c77e7e96143bd96651fd0f4453c9f6ee043ac846b77/cymem-1.31.2-cp36-cp36m-manylinux1_x86_64.whl
Collecting preshed<2.0.0,>=1.0.0
  Downloading https://files.pythonhosted.org/packages/12/88/57a818051f3d71e800bfb7ba4df56d3ea5793482ef11f1d2109b726f3bac/preshed-1.0.1-cp36-cp36m-manylinux1_x86_64.whl (80kB)
Collecting thinc<6.6.0,>=6.5.0
  Downloading https://files.pythonhosted.org/packages/f7/9b/78fab962e0c8b5

In [None]:
!pip install mosestokenizer

Collecting mosestokenizer
  Downloading https://files.pythonhosted.org/packages/4b/b3/c0af235b16c4f44a2828ef017f7947d1262b2646e440f85c6a2ff26a8c6f/mosestokenizer-1.1.0.tar.gz
  "Distutils was imported before Setuptools. This usage is discouraged "
Collecting openfile (from mosestokenizer)
  Downloading https://files.pythonhosted.org/packages/93/e6/805db6867faacb488b44ba8e0829ef4de151dd0499f3c5da5f4ad11698a7/openfile-0.0.7-py3-none-any.whl
Collecting uctools (from mosestokenizer)
  Downloading https://files.pythonhosted.org/packages/04/cb/70ed842d9a43460eedaa11f7503b4ab6537b43b63f0d854d59d8e150fac1/uctools-1.3.0.tar.gz
Collecting toolwrapper (from mosestokenizer)
  Downloading https://files.pythonhosted.org/packages/41/7b/34bf8fb69426d8a18bcc61081e9d126f4fcd41c3c832072bef39af1602cd/toolwrapper-2.1.0.tar.gz
Building wheels for collected packages: mosestokenizer, uctools, toolwrapper
  Running setup.py bdist_wheel for mosestokenizer ... done
  Stored in directory: /root/.cache

In [None]:
!wget https://www.comp.nus.edu.sg/~nlp/sw/m2scorer.tar.gz
!tar -xzf m2scorer.tar.gz

--2020-08-09 12:12:05--  https://www.comp.nus.edu.sg/~nlp/sw/m2scorer.tar.gz
Resolving www.comp.nus.edu.sg (www.comp.nus.edu.sg)... 45.60.31.225
Connecting to www.comp.nus.edu.sg (www.comp.nus.edu.sg)|45.60.31.225|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 22836 (22K) [application/x-gzip]
Saving to: ‘m2scorer.tar.gz’


2020-08-09 12:12:07 (83.0 KB/s) - ‘m2scorer.tar.gz’ saved [22836/22836]



In [None]:
!git clone https://github.com/keisks/jfleg

Cloning into 'jfleg'...
remote: Enumerating objects: 170, done.
remote: Total 170 (delta 0), reused 0 (delta 0), pack-reused 170
Receiving objects: 100% (170/170), 777.12 KiB | 6.42 MiB/s, done.
Resolving deltas: 100% (73/73), done.


In [None]:
!wget https://www.comp.nus.edu.sg/~nlp/conll14st/conll14st-test-data.tar.gz
!tar -xzf conll14st-test-data.tar.gz

--2020-08-09 12:12:13--  https://www.comp.nus.edu.sg/~nlp/conll14st/conll14st-test-data.tar.gz
Resolving www.comp.nus.edu.sg (www.comp.nus.edu.sg)... 45.60.31.225
Connecting to www.comp.nus.edu.sg (www.comp.nus.edu.sg)|45.60.31.225|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 643482 (628K) [application/x-gzip]
Saving to: ‘conll14st-test-data.tar.gz’


2020-08-09 12:12:16 (321 KB/s) - ‘conll14st-test-data.tar.gz’ saved [643482/643482]



In [None]:
!pip install rouge

Collecting rouge
  Downloading https://files.pythonhosted.org/packages/43/cc/e18e33be20971ff73a056ebdb023476b5a545e744e3fc22acd8c758f1e0d/rouge-1.0.0-py3-none-any.whl
Installing collected packages: rouge
Successfully installed rouge-1.0.0


In [None]:
print("Installing dependencies...")
%tensorflow_version 2.x
!pip install -q t5==0.6.0

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

BASE_DIR = "gs://ml-bucket-isikus/t5-base-model" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

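@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TF logging verbosity, restoring it on exit."""
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Restore the previous verbosity even if the wrapped block raises.
    tf.logging.set_verbosity(og_level)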

Installing dependencies...
  "Distutils was imported before Setuptools. This usage is discouraged "
  Running setup.py bdist_wheel for sacremoses ... done
Setting up GCS access...
Running on TPU: grpc://10.90.169.50:8470
Instructions for updating:
non-resource variables are not supported in the long term


In [None]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [None]:
t5.data.utils.set_global_cache_dirs([BASE_DIR, os.getcwd()])

In [None]:
from google.colab import auth
auth.authenticate_user()
project_id = 'better-record'
!gcloud config set project {project_id}
!gsutil ls

Updated property [core/project].
gs://ml-bucket-isikus/


In [None]:
import nltk
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

### Reproducibility

In [None]:
import random
import numpy as np

def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  tf.compat.v1.set_random_seed(seed)

set_seed(42)
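
As a quick sanity check of what seeding provides, the sketch below uses only the standard library (it is independent of TensorFlow and NumPy) to show that re-seeding Python's `random` module replays the same sequence of draws. The helper `draw` is a hypothetical name used only for this illustration.

```python
import random


def draw(seed, n=5):
    """Seed Python's RNG and return n pseudo-random integers."""
    random.seed(seed)
    return [random.randint(0, 999) for _ in range(n)]


first = draw(42)
second = draw(42)

# Same seed, so both lists contain the same values in the same order.
print(first == second)
```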

# Creating new Tasks and Mixture

Two core components of the T5 library are `Task` and `Mixture` objects.

A `Task` is a dataset along with preprocessing functions and evaluation metrics. A `Mixture` is a collection of `Task` objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent `Tasks`.

For this example, we will fine-tune the model to do closed-book question answering.
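
To illustrate what a mixing rate means in practice, here is a minimal pure-Python sketch that turns per-task rates into sampling probabilities. It assumes rates proportional to the number of training examples, which is only one possible choice. The example counts are made-up placeholders, and `mixing_probabilities` is a hypothetical helper, not part of the `t5` API.

```python
def mixing_probabilities(rates):
    """Normalize non-negative per-task rates into sampling probabilities."""
    total = float(sum(rates.values()))
    if total <= 0:
        raise ValueError("At least one task needs a positive rate.")
    return {name: rate / total for name, rate in rates.items()}


# Placeholder example counts, used here as rates (not real dataset sizes).
example_counts = {"task_a": 6000, "task_b": 3000, "task_c": 1000}

probs = mixing_probabilities(example_counts)
for name, p in probs.items():
    print(f"{name}: {p:.2f}")
```

With size-proportional rates, larger tasks are sampled more often; equal rates would instead sample every task uniformly regardless of size.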

### correct

[Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) is a challenging corpus for open-domain QA. Each example includes a question along with an entire Wikipedia article that may or may not contain its answer. The goal is to produce the correct answer given this context. In our case, we will be ignoring the provided context in hopes that the model will learn to find the answers from the world knowledge it has acquired during pre-training.

Since the raw data splits are stored as JSONL files, we will first need to convert them to TSV format to make them parseable in TensorFlow. We will also take the opportunity to drop information we will not be using, remove questions with multiple answers, and do a bit of cleaning of the text.

In [None]:
import gzip
import json

corr_tsv_path = {
    "train": os.path.join(DATA_DIR, "correct-train.tsv"),
    "validation": os.path.join(DATA_DIR, "correct-target.tsv")
}
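
As a small illustration of how these paths resolve, the sketch below joins a bucket prefix with the file names using `posixpath`, so the result is the same on any platform. The bucket name is a placeholder and should be replaced with your own `BASE_DIR`.

```python
import posixpath

# Placeholder bucket prefix; substitute your own BASE_DIR.
base_dir = "gs://my-bucket/t5-base-model"
data_dir = posixpath.join(base_dir, "data")

tsv_paths = {
    "train": posixpath.join(data_dir, "correct-train.tsv"),
    "validation": posixpath.join(data_dir, "correct-target.tsv"),
}
print(tsv_paths["train"])
```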

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [None]:
def corr_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(corr_tsv_path[split])
  # Split each "<orig_text>\t<corr_text>" example into (orig_text, corr_text) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"orig_text": ... "corr_text": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["orig_text", "corr_text"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(corr_dataset_fn("validation").take(5)):
  print(ex)

A few raw validation examples...
{'orig_text': b'Armageddon is my favourite science fiction movie. The plot is about how to survive in the bad situation. It presented the cooperation from everybody such as new technology, joining between american and russian astronauts and the private people who had great experience in digging. I felt of the bravehearts who make a sacrifice. The soundtrack was pretty good. I felt sad when the character played by Bruce Willis called his daughter on the earth, prior to exploding a main meteor. Armageddon was directed by Michael Bay.', 'corr_text': b'Armageddon is my favourite science fiction movie. The plot is about how to survive in a bad situation. It presented the cooperation from everybody such as new technology, cooperation between American and Russian astronauts and the private people who had great experience in digging. I felt the bravehearts who make a sacrifice. The soundtrack was pretty good. I felt sad when the character played by Bruce Willis
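
For readers less familiar with `tf.io.decode_csv`, the following plain-Python sketch mirrors what the dataset function does to each line: split on a single tab and name the two fields. It is an approximation for illustration only, not the TensorFlow code path. `parse_tsv_line` is a hypothetical helper and the sample line is invented.

```python
def parse_tsv_line(line):
    """Split one "<orig_text><TAB><corr_text>" line into a dict."""
    fields = line.rstrip("\n").split("\t")
    if len(fields) != 2:
        raise ValueError(f"Expected 2 tab-separated fields, got {len(fields)}")
    return dict(zip(["orig_text", "corr_text"], fields))


# Invented sample line with an original and a corrected sentence.
sample = "I has a cat .\tI have a cat .\n"
example = parse_tsv_line(sample)
print(example)
```

Because the notebook passes `use_quote_delim=False`, quote characters are treated as ordinary text, which is why a plain `split` is a reasonable stand-in here.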

In [None]:
bucket_name = 'ml-bucket-isikus'
!gsutil -m cp -r gs://{bucket_name}/t5-base-model/data/correct-train.tsv correct_train.tsv

with open("correct_train.tsv", "r", encoding="utf-8") as inf:
  correct_len = len([l for l in inf.read().split("\n") if l])

Copying gs://ml-bucket-isikus/t5-base-model/data/correct-train.tsv...
\ [1/1 files][ 53.0 MiB/ 53.0 MiB] 100% Done                                    
Operation completed over 1 objects/53.0 MiB.                                     
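
The cell above reads the whole file into memory to count examples, which is fine for a 53 MiB file. For larger files, a streaming count is one alternative. The sketch below assumes the same one-example-per-line layout; `count_nonempty_lines` is a hypothetical helper name, and the demo runs on a small temporary file rather than the real data.

```python
import os
import tempfile


def count_nonempty_lines(path):
    """Count non-empty lines without loading the whole file into memory."""
    with open(path, "r", encoding="utf-8") as inf:
        return sum(1 for line in inf if line.rstrip("\n"))


# Small self-contained demo on a temporary file.
with tempfile.NamedTemporaryFile(
        "w", suffix=".tsv", delete=False, encoding="utf-8") as tmp:
    tmp.write("a\tb\n\nc\td\n")
    tmp_path = tmp.name

n_examples = count_nonempty_lines(tmp_path)
os.remove(tmp_path)
print(n_examples)
```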


Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by removing single quotes that wrap it (the regex strips the first and last `'` it finds), since the texts are sometimes formatted in odd ways. Finally, we prepend `correction: ` to the inputs so that the model knows what task it's trying to solve.

In [None]:
def correction_preprocessor(ds):
  def normalize_text(text):
    """Remove quotes from a TensorFlow string."""
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"orig_text": ..., "corr_text": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["correction: ", normalize_text(ex["orig_text"])]),
        "targets": normalize_text(ex["corr_text"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
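
It can help to see what the quote-stripping pattern actually matches. The sketch below applies the same pattern with Python's `re` module as a stand-in for `tf.strings.regex_replace` (an assumption that holds for this simple pattern, although the two regex engines differ in general). Because `.*` is greedy, the pattern removes the first and the last single quote in the string, which also affects apostrophes. The sample strings are invented.

```python
import re

QUOTE_PATTERN = r"'(.*)'"


def strip_outer_single_quotes(text):
    """Drop the first and last single quote in text, keeping what lies between."""
    return re.sub(QUOTE_PATTERN, r"\1", text)


quoted = strip_outer_single_quotes("'hello world'")
apostrophes = strip_outer_single_quotes("I don't think they can't")
untouched = strip_outer_single_quotes("only one ' here")

print(quoted)
print(apostrophes)
print(untouched)
```

The second case is worth keeping in mind for correction data: a sentence with two contractions loses both apostrophes, so it may be worth checking whether this normalization is desirable for your inputs.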

Finally, we put everything together to create a `Task`.

In [None]:
t5.data.TaskRegistry.add(
    "correct",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=corr_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[correction_preprocessor],
    # Use the same vocabulary that we used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy]
)



Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [None]:
corr_task = t5.data.TaskRegistry.get("correct")
ds = corr_task.get_dataset(split="validation", sequence_length={"inputs": 512, "targets": 512})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)



A few preprocessed validation examples...
{'inputs_plaintext': b'correction: The chart illustrates the amount of illiterate males and females in the world in 2020. Overall, there is a gap that divides countries with high level of literacy and illiterate countries. <br> To begin with, it is clearly seen that there are more illiterate females than males in every area. Though, the difference may be small, for example in Latin America or may be relatively large like in East Asia and the second type is more common. All regions can be divided in two groups. The first one is with high level of illiteracy and it includes South Asia, Arab States and Sub-Saharan Afrrica. Another consist of Developed countries, Latin America/Caribbean and East Asia/Oceania and shows low level of illiterate people (less than 20). <br> To sum up, the few facts should be emphasized. In 2020 there will be two different groups of areas, depending on level of illiteracy. Besides, the amount of illiterate females will b

### conll


In [None]:
import gzip
import json

conll_tsv_path = {
    "train": os.path.join(DATA_DIR, "conll-train.tsv"),
    "validation": os.path.join(DATA_DIR, "conll-eval.tsv")
}

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [None]:
def conll_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(conll_tsv_path[split])
  # Split each "<orig_text>\t<corr_text>" example into (orig_text, corr_text) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"orig_text": ... "corr_text": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["orig_text", "corr_text"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(conll_dataset_fn("validation").take(5)):
  print(ex)

A few raw validation examples...
{'orig_text': b'sentence: Keeping the Secret of Genetic Testing parsing: ROOT_VERB_VerbForm_Tense_Aspect det_DET dobj_NOUN_Number prep_ADP compound_PROPN_NounType_Number pobj_PROPN_NounType_Number', 'corr_text': b'sentence: Keeping the Secret of Genetic Testing parsing: ROOT_VERB_VerbForm_Tense_Aspect det_DET dobj_NOUN_Number prep_ADP compound_PROPN_NounType_Number pobj_PROPN_NounType_Number'}
{'orig_text': b'sentence: What is genetic risk ? parsing: attr_NOUN_PronType ROOT_VERB_VerbForm_Tense_Number_Person amod_ADJ_Degree nsubj_NOUN_Number ?', 'corr_text': b'sentence: What is genetic risk ? parsing: attr_NOUN_PronType ROOT_VERB_VerbForm_Tense_Number_Person amod_ADJ_Degree nsubj_NOUN_Number ?'}
{'orig_text': b'sentence: Genetic risk refers more to your chance of inheriting a disorder or disease . parsing: compound_ADJ_Degree nsubj_NOUN_Number ROOT_VERB_VerbForm_Tense_Number_Person dobj_ADV_Degree prep_ADP poss_ADJ_PronType_Poss pobj_NOUN_Number prep_ADP

In [None]:
bucket_name = 'ml-bucket-isikus'
!gsutil -m cp -r gs://{bucket_name}/t5-base-model/data/conll-train.tsv conll-train.tsv

with open("conll-train.tsv", "r", encoding="utf-8") as inf:
  conll_len = len([l for l in inf.read().split("\n") if l])

Copying gs://ml-bucket-isikus/t5-base-model/data/conll-train.tsv...
/ [1/1 files][ 12.0 MiB/ 12.0 MiB] 100% Done                                    
Operation completed over 1 objects/12.0 MiB.                                     


Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by removing single quotes that wrap it (the regex strips the first and last `'` it finds), since the texts are sometimes formatted in odd ways. Finally, we prepend `conll: ` to the inputs so that the model knows what task it's trying to solve.

In [None]:
def conll_preprocessor(ds):
  def normalize_text(text):
    """Remove quotes from a TensorFlow string."""
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"orig_text": ..., "corr_text": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["conll: ", normalize_text(ex["orig_text"])]),
        "targets": normalize_text(ex["corr_text"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

Finally, we put everything together to create a `Task`.

In [None]:
t5.data.TaskRegistry.add(
    "conll",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=conll_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[conll_preprocessor],
    # Use the same vocabulary that we used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy]
)



Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [None]:
corr_task = t5.data.TaskRegistry.get("conll")
ds = corr_task.get_dataset(split="validation", sequence_length={"inputs": 512, "targets": 512})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)



A few preprocessed validation examples...
{'inputs_plaintext': b'conll: sentence: And so , he would have chosen not to undergo generic disorder testing and let the truth be mined forever . parsing: cc_CCONJ_ConjType advmod_ADV_Degree , nsubj_PRON_PronType aux_VERB_VerbType aux_VERB_VerbForm ROOT_VERB_VerbForm_Tense_Aspect neg_ADV_Degree aux_PART_PartType_VerbForm xcomp_VERB_VerbForm amod_ADJ_Degree compound_NOUN_Number dobj_NOUN_Number cc_CCONJ_ConjType conj_VERB_VerbForm det_DET nsubjpass_NOUN_Number auxpass_VERB_VerbForm ccomp_VERB_VerbForm_Tense_Aspect advmod_ADV_Degree .', 'inputs': array([  975,   195,    10,  7142,    10,   275,    78,     3,     6,
           3,    88,   133,    43,  3934,    59,    12, 17601,  8165,
        9311,  2505,    11,   752,     8,  2827,    36,  2000,    26,
        6276,     3,     5,   260,     7,    53,    10,     3,    75,
          75,   834,   254, 17752,   683,   834,  4302,   354, 25160,
           3,     9,    26,   208,  7360,   834,   188, 

### jfleg


In [None]:
import gzip
import json

jfleg_tsv_path = {
    "train": os.path.join(DATA_DIR, "jfleg-train.tsv"),
    "validation": os.path.join(DATA_DIR, "jfleg-eval.tsv")
}

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [None]:
def jfleg_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(jfleg_tsv_path[split])
  # Split each "<orig_text>\t<corr_text>" example into (orig_text, corr_text) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"orig_text": ... "corr_text": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["orig_text", "corr_text"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(jfleg_dataset_fn("validation").take(5)):
  print(ex)

A few raw validation examples...
{'orig_text': b'New and new technology has been introduced to the society .', 'corr_text': b'New and new technology has been introduced to the society .'}
{'orig_text': b'One possible outcome is that an environmentally-induced reduction in motorization levels in the richer countries will outweigh any rise in motorization levels in the poorer countries .', 'corr_text': b'One possible outcome is that an environmentally-induced reduction in motorization levels in the richer countries will outweigh any rise in motorization levels in the poorer countries .'}
{'orig_text': b'Every person needs to know a bit about math , sciences , arts , literature and history in order to stand out in society .', 'corr_text': b'Every person needs to know a bit about math , sciences , arts , literature and history in order to stand out in society .'}
{'orig_text': b'While the travel company will most likely show them some interesting sites in order for their customers to adver

In [None]:
bucket_name = 'ml-bucket-isikus'
!gsutil -m cp -r gs://{bucket_name}/t5-base-model/data/jfleg-train.tsv jfleg-train.tsv

with open("jfleg-train.tsv", "r", encoding="utf-8") as inf:
  jfleg_len = len([l for l in inf.read().split("\n") if l])

Copying gs://ml-bucket-isikus/t5-base-model/data/jfleg-train.tsv...
/ [1/1 files][570.1 KiB/570.1 KiB] 100% Done                                    
Operation completed over 1 objects/570.1 KiB.                                    


Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by removing single quotes that wrap it (the regex strips the first and last `'` it finds), since the texts are sometimes formatted in odd ways. Finally, we prepend `jfleg: ` to the inputs so that the model knows what task it's trying to solve.

In [None]:
def jfleg_preprocessor(ds):
  def normalize_text(text):
    """Remove quotes from a TensorFlow string."""
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"orig_text": ..., "corr_text": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["jfleg: ", normalize_text(ex["orig_text"])]),
        "targets": normalize_text(ex["corr_text"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
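
As a rough illustration of what `normalize_text` does, the sketch below applies the same pattern with Python's `re` module instead of `tf.strings.regex_replace`. It is an analogue rather than the TensorFlow call itself; the helper name `strip_outer_quotes` and the sample sentences are made up for the example. Because `(.*)` is greedy, only the first and last single quote on a line are removed, and a line with a single apostrophe is left untouched.

```python
import re

def strip_outer_quotes(text):
    # Same pattern as normalize_text: drop the first and last single quote
    # on a line and keep everything between them.
    return re.sub(r"'(.*)'", r"\1", text)

print(strip_outer_quotes("He said 'hello' ."))            # one quoted span
print(strip_outer_quotes("He said 'hello' to 'them' ."))  # inner quotes survive
print(strip_outer_quotes("It 's fine ."))                 # lone apostrophe: no match
```

Note that inner quotes survive when a line contains more than one quoted span, which may or may not be the intended behaviour for a given dataset.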

Finally, we put everything together to create a `Task`.

In [None]:
t5.data.TaskRegistry.add(
    "jfleg",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=jfleg_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[jfleg_preprocessor],
    # Use the same vocabulary that we used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy]
)



Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [None]:
corr_task = t5.data.TaskRegistry.get("jfleg")
ds = corr_task.get_dataset(split="validation", sequence_length={"inputs": 512, "targets": 512})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed validation examples...
{'inputs_plaintext': b'jfleg: In her salary we cant buy some car because we are planing to finish our hause in Binangonan , Rizal and we will planing to finish my study in Boston .', 'inputs': array([   3,  354,   89, 5772,   10,   86,  160, 9090,   62,   54,   17,
        805,  128,  443,  250,   62,   33,  515,   53,   12, 1992,   69,
          3, 2989,   15,   16, 7617, 1468,  106,  152,    3,    6, 2403,
        172,  138,   11,   62,   56,  515,   53,   12, 1992,   82,  810,
         16, 5053,    3,    5,    1]), 'targets_plaintext': b'In her salary we cant buy some car because we are planing to finish our hause in Binangonan , Rizal and we will planing to finish my study in Boston .', 'targets': array([  86,  160, 9090,   62,   54,   17,  805,  128,  443,  250,   62,
         33,  515,   53,   12, 1992,   69,    3, 2989,   15,   16, 7617,
       1468,  106,  152,    3,    6, 2403,  172,  138,   11,   62,   56,
        515,   53,   12, 199

### bea

This task uses data from the BEA-2019 shared task on grammatical error correction. Each example pairs a source text (`orig_text`) with its corrected counterpart (`corr_text`). As the raw examples below show, both fields hold a `sentence:` part followed by a `parsing:` part with per-token syntactic tags.

The data splits are already stored as TSV files, so we only need to record where the train and validation files live.

In [None]:
bea_tsv_path = {
    "train": os.path.join(DATA_DIR, "bea-train-strict.tsv"),
    "validation": os.path.join(DATA_DIR, "bea-eval.tsv")
}

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [None]:
def bea_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(bea_tsv_path[split])
  # Split each "<orig_text>\t<corr_text>" example into (orig_text, corr_text) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"orig_text": ... "corr_text": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["orig_text", "corr_text"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(bea_dataset_fn("validation").take(5)):
  print(ex)

A few raw validation examples...
{'orig_text': b'sentence: Dear Sir , parsing: amod_ADJ_Degree ROOT_PROPN_NounType_Number ,', 'corr_text': b'sentence: Dear Sir , parsing: amod_ADJ_Degree ROOT_PROPN_NounType_Number ,'}
{'orig_text': b'sentence: I have seen your advertisement for a job on the internet and I am writing to apply for a summer job as an instructor and keeper of children in your camp . parsing: nsubj_PRON_PronType aux_VERB_VerbForm_Tense ROOT_VERB_VerbForm_Tense_Aspect poss_ADJ_PronType_Poss dobj_NOUN_Number prep_ADP det_DET pobj_NOUN_Number prep_ADP det_DET pobj_NOUN_Number cc_CCONJ_ConjType nsubj_PRON_PronType aux_VERB_VerbForm_Tense conj_VERB_VerbForm_Tense_Aspect aux_PART_PartType_VerbForm advcl_VERB_VerbForm prep_ADP det_DET compound_NOUN_Number pobj_NOUN_Number prep_ADP det_DET pobj_NOUN_Number cc_CCONJ_ConjType conj_NOUN_Number prep_ADP pobj_NOUN_Number prep_ADP poss_ADJ_PronType_Poss pobj_NOUN_Number .', 'corr_text': b'sentence: I have seen your advertisement for a jo
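
The TSV parsing in `bea_dataset_fn` can be mirrored in plain Python to check a file by hand. The sketch below is an approximation using the standard `csv` module, not the TensorFlow pipeline; `read_pairs` and the sample text are hypothetical. `csv.QUOTE_NONE` plays a role similar to `use_quote_delim=False`: double quotes are kept as ordinary characters instead of being treated as field delimiters.

```python
import csv
import io

def read_pairs(tsv_text):
    # QUOTE_NONE keeps quote characters as literal text, similar in spirit to
    # use_quote_delim=False in tf.io.decode_csv.
    reader = csv.reader(io.StringIO(tsv_text), delimiter="\t",
                        quoting=csv.QUOTE_NONE)
    # Map each two-column row to the same dict layout the dataset uses.
    return [dict(zip(["orig_text", "corr_text"], row)) for row in reader]

# Made-up sample with two "<orig_text>\t<corr_text>" lines.
sample = 'He say "hi" .\tHe says "hi" .\nI has a cat .\tI have a cat .\n'
for pair in read_pairs(sample):
    print(pair)
```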

In [None]:
bucket_name = 'ml-bucket-isikus'
!gsutil -m cp -r gs://{bucket_name}/t5-base-model/data/bea-train-strict.tsv bea-train-strict.tsv

with open("bea-train-strict.tsv", "r", encoding="utf-8") as inf:
  bea_len = len([l for l in inf.read().split("\n") if l])

Copying gs://ml-bucket-isikus/t5-base-model/data/bea-train-strict.tsv...
\ [1/1 files][ 16.6 MiB/ 16.6 MiB] 100% Done                                    
Operation completed over 1 objects/16.6 MiB.                                     


Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by stripping the first and last single quote on a line, since the texts are sometimes formatted in odd ways. Finally, we prepend `bea: ` to the inputs so that the model knows what task it's trying to solve.

In [None]:
def bea_preprocessor(ds):
  def normalize_text(text):
    """Remove quotes from a TensorFlow string."""
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"orig_text": ..., "corr_text": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["bea: ", normalize_text(ex["orig_text"])]),
        "targets": normalize_text(ex["corr_text"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

Finally, we put everything together to create a `Task`.

In [None]:
t5.data.TaskRegistry.add(
    "bea",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=bea_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[bea_preprocessor],
    # Use the same vocabulary that we used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy]
)



Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [None]:
corr_task = t5.data.TaskRegistry.get("bea")
ds = corr_task.get_dataset(split="validation", sequence_length={"inputs": 512, "targets": 512})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed validation examples...
{'inputs_plaintext': b'bea: sentence: The answer is pork , lamb and other meats . parsing: det_DET nsubj_NOUN_Number ROOT_VERB_VerbForm_Tense_Number_Person attr_NOUN_Number , conj_NOUN_Number cc_CCONJ_ConjType amod_ADJ_Degree conj_NOUN_Number .', 'inputs': array([   36,     9,    10,  7142,    10,    37,  1525,    19, 13654,
           3,     6, 17871,    11,   119,  3604,     7,     3,     5,
         260,     7,    53,    10,    20,    17,   834,  5596,   382,
           3,    29,  7304,   354,   834,  7400,  7443,   834,   567,
        5937,    49, 10264,  6951,   834, 16174,   279,   834,  5000,
         115,  3809,    51,   834,   382,  5167,   834,   567,  5937,
          49,   834,   345, 13515,    44,    17,    52,   834,  7400,
        7443,   834,   567,  5937,    49,     3,     6,   975,   354,
         834,  7400,  7443,   834,   567,  5937,    49,     3,    75,
          75,   834,   254, 17752,   683,   834,  4302,   354, 25160,
 

## Dataset Mixture

We now create a `Mixture` from the above `Tasks`, which we will fine-tune on.

There are different ways to automatically set the rate (for example, based on the number of examples using `rate_num_examples`), but here we simply set each task's rate to the number of lines in its training file, as counted above.

In [None]:
print(correct_len)
print(conll_len)
print(jfleg_len)
print(bea_len)

32668
28346
3016
34304


In [None]:
tasks_and_weights = [
  ('correct', float(correct_len)),
  ('conll', float(conll_len)),
  ('jfleg', float(jfleg_len)),
  ('bea', float(bea_len))
]

In [None]:
t5.data.MixtureRegistry.remove("correctit_all")
t5.data.MixtureRegistry.add("correctit_all", tasks_and_weights)
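
To get a feel for what these weights mean, the sketch below normalizes the four line counts printed above into shares. It assumes the mixture samples each task in proportion to its rate; `mixing_proportions` is a hypothetical helper, not part of the `t5` library.

```python
# Line counts printed by the cell above.
task_sizes = {"correct": 32668, "conll": 28346, "jfleg": 3016, "bea": 34304}

def mixing_proportions(sizes):
    # Normalize the raw weights so that they sum to 1.
    total = float(sum(sizes.values()))
    return {name: size / total for name, size in sizes.items()}

for name, share in mixing_proportions(task_sizes).items():
    print(f"{name}: {share:.1%}")
```

Under that assumption, `jfleg` contributes only a small fraction of the training examples, so its validation score may move less than the others during fine-tuning.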

# Transferring to new Tasks

We are now ready to fine-tune one of the pre-trained T5 models on our new mixture of grammatical error correction tasks.

First, we'll instantiate a `Model` object using the model size of your choice. Note that larger models are slower to train and use but will likely achieve higher accuracy. You also may be able to increase accuracy by training longer with more `FINETUNE_STEPS` below.


## Caveats

* Due to its memory requirements, you will not be able to train the `11B` parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).
* Due to the checkpoint size, you will not be able to use the 5GB GCS free tier for the `3B` parameter models. You will need at least 25GB of space, which you can purchase with your $300 of initial credit on GCP.
* While `large` can achieve decent results, it is recommended that you fine-tune at least the `3B` parameter model.


## Define Model

In [None]:
run = "total"  # @param {"type": "string"}

In [None]:
MODEL_SIZE = "3B" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
if run not in [None, ""]:
    MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE+"-"+run)
else:
    MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warn(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter model is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 128, 16),
    "base": (2, 64, 8),
    "large": (8, 32, 4),
    "3B": (8, 8, 1),
    "11B": (8, 8, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)



## Train and evaluate

We now fine-tune on the mixture and evaluate on the validation sets of its tasks. Accuracy results will be logged and added to the TensorBoard above. The next few cells define helpers used by the training and evaluation loop.

In [None]:
from mosestokenizer import *

In [None]:
import spacy

nlp = spacy.load("en", disable=["tagger", "parser", 'ner', 'textcat', 'lemmatizer'])
tokenify = lambda snt: " ".join(str(x) for x in nlp(snt))

In [None]:
from contextlib import contextmanager
import logging

@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that will prevent any logging messages
    triggered during the body from being processed.
    :param highest_level: the maximum logging level in use.
      This would only need to be changed if a custom level greater than CRITICAL
      is defined.
    """
    # two kind-of hacks here:
    #    * can't get the highest logging level in effect => delegate to the user
    #    * can't get the current module-level override => use an undocumented
    #       (but non-private!) interface

    previous_level = logging.root.manager.disable

    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(previous_level)
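
A small, self-contained check of this context manager is sketched below. The `ListHandler` class and the `demo` logger are made up for the demonstration; the context manager is repeated so that the snippet runs on its own.

```python
import logging
from contextlib import contextmanager

@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)

class ListHandler(logging.Handler):
    """Collects messages so we can inspect what was actually emitted."""
    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())

logger = logging.getLogger("demo")
logger.propagate = False       # keep the demo output out of the root logger
logger.setLevel(logging.INFO)
handler = ListHandler()
logger.addHandler(handler)

logger.warning("before")
with all_logging_disabled():
    logger.warning("inside")   # suppressed while the block is active
logger.warning("after")

print(handler.messages)        # ['before', 'after']
```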

In [None]:
from nltk import sent_tokenize

In [None]:
def heal(insent, raw=True):
  """Undo tokenization artifacts in a model output string."""
  # Start from the input so `outsent` is defined on both code paths.
  outsent = insent.replace(chr(8263) + " ", "<")
  if not raw:
    tlist = tokenify(outsent).split(" ")
    with MosesDetokenizer('en') as detokenize:
      outsent = detokenize(tlist)
    # outsent = " ".join(s.capitalize() for s in sent_tokenize(outsent, "english"))
    outsent = outsent.replace(" - ", "-").replace(" 've", "")
    # outsent = re.sub(r'(.*? < br > )(.)(.*?)', lambda m: r'{}'.format(m.group(1)+m.group(2).upper()+m.group(3)), outsent)
  outsent = outsent.replace(" <br> ", "\n").replace(" < br > ", "\n")
  return outsent

In [None]:
import warnings
from math import ceil

from nltk.translate.gleu_score import corpus_gleu
from rouge import Rouge

rouge = Rouge()

In [None]:
!gsutil -m cp -r gs://{bucket_name}/t5-base-model/data/correct-target.tsv correct-target.tsv

Copying gs://ml-bucket-isikus/t5-base-model/data/correct-target.tsv...
- [1/1 files][ 13.4 MiB/ 13.4 MiB] 100% Done                                    
Operation completed over 1 objects/13.4 MiB.                                     


In [None]:
import signal
from contextlib import contextmanager

class TimeoutException(Exception): pass

@contextmanager
def time_limit(seconds):
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
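
The snippet below sketches how `time_limit` behaves for a fast and a slow call. `run_with_limit` is a hypothetical wrapper added for the demonstration. Note that `signal.SIGALRM` is only available on Unix-like systems and that the handler must be installed from the main thread.

```python
import signal
import time
from contextlib import contextmanager

class TimeoutException(Exception):
    pass

@contextmanager
def time_limit(seconds):
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # always cancel the pending alarm

def run_with_limit(fn, seconds):
    # Return fn()'s result, or a marker string if it takes too long.
    try:
        with time_limit(seconds):
            return fn()
    except TimeoutException:
        return "timed out"

print(run_with_limit(lambda: "done", 1))         # finishes in time
print(run_with_limit(lambda: time.sleep(3), 1))  # interrupted after about 1 s
```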

In [None]:
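import numpy as np
import pandas as pd

ft_steps = 6400
n_ckpts = 4
# linspace yields n_ckpts points including 0; dropping the 0 leaves
# n_ckpts - 1 training stops, so we iterate over steparr directly.
steparr = [int(n) for n in np.linspace(0, ft_steps, n_ckpts)[1:]]

# Task names and reference targets used by the evaluation code in the loop.
tasks = ["correct", "jfleg", "conll", "bea"]
test = pd.read_csv("correct-target.tsv", sep="\t", header=None)
test.columns = ["sources", "targets"]

for STEP in steparr: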

  print("**** Training checkpoint %s ****" % str(STEP))
  # The models from our paper are based on the Mesh Tensorflow Transformer.
  model = t5.models.MtfModel(
      model_dir=MODEL_DIR,
      tpu=TPU_ADDRESS,
      tpu_topology=TPU_TOPOLOGY,
      model_parallelism=model_parallelism,
      batch_size=train_batch_size,
      sequence_length={"inputs": 512, "targets": 512},
      learning_rate_schedule=0.0025,
      save_checkpoints_steps=STEP + 1000,
      keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
      iterations_per_loop=100,
  )

  FINETUNE_STEPS = STEP
  print("Finetuning for", FINETUNE_STEPS, "steps")

  model.finetune(
      mixture_or_task_name="correctit_all",
      pretrained_model_dir=PRETRAINED_DIR,
      finetune_steps=FINETUNE_STEPS
  )

  print("")
  print("")

  epoch = 1000000 + STEP
  print("**** Evaluating checkpoint %s ****" % str(epoch))

  with all_logging_disabled():
    model.batch_size = train_batch_size * 4
    model.eval(
        mixture_or_task_name="correctit_all",
        checkpoint_steps=[epoch]
    )

    for task in tasks:
      !gsutil -m cp -r $MODEL_DIR/validation_eval/{task}_{str(epoch)}_predictions {task}_{str(epoch)}.bin
      with open(task + "_" + str(epoch) + ".bin", "rb") as inf:
        exec(task + ' = inf.read().decode().split("\\n")[:-1]')

  print("")

  print("** Testing on raw outputs **")

  refs = list(test["targets"])
  tknzd = [tokenify(sent) for sent in refs]

  cor_df = pd.DataFrame({
      "predictions": correct,
      "references": refs,
      "tokenized": tknzd
  })

  picklename = "corr_test_tf_" + MODEL_SIZE.lower() + "_e" + str(epoch) + ".pickle"
  cor_df.to_pickle(picklename)

  !cp ./{picklename} /content/gdrive/My\ Drive/heptabot/output

  t_preds = [tokenify(sent).split(" ") for sent in correct]
  t_refs = [ref.split(" ") for ref in tknzd]
  gleu = corpus_gleu([[ref] for ref in t_refs], t_preds)
  print("GLEU:", gleu)

  try:
    with time_limit(300):
      rouge_scores = rouge.get_scores(correct, refs, avg=True, ignore_empty=False)
      print("ROUGE-L:", rouge_scores["rouge-l"])
  except TimeoutException as e:
    print("ROUGE function timed out")
  print("")

  print("** Testing on BEA-2019 **")
  outstr = ""
  for line in bea:
    line = line.replace("\n", " ")
    outstr += tokenify(line) + "\n"

  with open("ABCN.bea19.test.corr", "w", encoding="utf-8") as outtest:
    outtest.write(outstr)

  !zip bea-test-{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.zip ABCN.bea19.test.corr
  !cp ./bea-test-{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.zip /content/gdrive/My\ Drive/heptabot/output
  print("")

  print("** Testing on JFLEG **")
  %cd jfleg
  outstr = ""

  for line in jfleg:
    line = line.replace("\n", " ")
    outstr += tokenify(line) + "\n"

  with open("test.nospc.res", "w", encoding="utf-8") as outtest:
    outtest.write(outstr)
  
  !cp ./test.nospc.res /content/gdrive/My\ Drive/heptabot/output/test.nospc.{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.res

  !python ./eval/gleu.py -r ./test/test.ref[0-3] -s ./test/test.src --hyp test.nospc.res

  %cd ../
  print("")

  print("** Testing on CoNLL-2014 **")
  outstr = ""

  for line in conll:
    line = line.replace("\n", " ")
    outstr += tokenify(line) + "\n"

  with open("conll14_nospc.txt", "w", encoding="utf-8") as outtest:
    outtest.write(outstr)
  !cp ./conll14_nospc.txt /content/gdrive/My\ Drive/heptabot/output/conll14_nospc_{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.txt

  try:
    with time_limit(300):
      !python2 ./m2scorer/scripts/m2scorer.py ./conll14_nospc.txt ./conll14st-test-data/noalt/official-2014.combined.m2
  except TimeoutException as e:
    print("M2 scorer timed out")
  print("")
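
The step arithmetic in the loop above can be checked without running any training. The sketch below is a pure-Python stand-in for the `np.linspace` call; `checkpoint_steps` is a hypothetical helper, and the `1000000` offset is taken from the `epoch = 1000000 + STEP` line, which suggests that the pre-trained checkpoint sits at step 1,000,000.

```python
def checkpoint_steps(ft_steps, n_points, pretrained_step=1000000):
    # Evenly spaced points from 0 to ft_steps inclusive, with the leading 0
    # dropped and the rest truncated to integers.
    stops = [int(ft_steps * k / (n_points - 1)) for k in range(1, n_points)]
    # Pair each fine-tuning stop with the checkpoint number it is saved under.
    return [(stop, pretrained_step + stop) for stop in stops]

for stop, ckpt in checkpoint_steps(6400, 4):
    print(stop, ckpt)
```

With `ft_steps = 6400` and four points, this gives three stops, the last of which corresponds to checkpoint `1006400`.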

In [None]:
import re
import pandas as pd

os.environ["MODEL_DIR"] = MODEL_DIR

ckpts = !gsutil ls $MODEL_DIR
ckptset = set(int(re.search(r"ckpt-([0-9]+?)\.data", s).group(1)) for s in ckpts
              if re.search(r"ckpt-([0-9]+?)\.data", s))
ckptlist = sorted(list(ckptset))

tasks = ["correct", "jfleg", "conll", "bea"]
test = pd.read_csv("correct-target.tsv", sep="\t", header=None)
test.columns = ["sources", "targets"]


for epoch in ckptlist:
  print("**** Evaluating checkpoint %s ****" % str(epoch))

  with all_logging_disabled():
    model.batch_size = train_batch_size * 4
    model.eval(
        mixture_or_task_name="correctit_all",
        checkpoint_steps=[epoch]
    )

    for task in tasks:
      !gsutil -m cp -r $MODEL_DIR/validation_eval/{task}_{str(epoch)}_predictions {task}_{str(epoch)}.bin
      with open(task + "_" + str(epoch) + ".bin", "rb") as inf:
        exec(task + ' = inf.read().decode().split("\\n")[:-1]')

  print("")

  print("** Testing on raw outputs **")

  refs = list(test["targets"])
  tknzd = [tokenify(sent) for sent in refs]

  cor_df = pd.DataFrame({
      "predictions": correct,
      "references": refs,
      "tokenized": tknzd
  })

  picklename = "corr_test_tf_" + MODEL_SIZE.lower() + "_e" + str(epoch) + ".pickle"
  cor_df.to_pickle(picklename)

  !cp ./{picklename} /content/gdrive/My\ Drive/heptabot/output

  t_preds = [tokenify(sent).split(" ") for sent in correct]
  t_refs = [ref.split(" ") for ref in tknzd]
  gleu = corpus_gleu([[ref] for ref in t_refs], t_preds)
  print("GLEU:", gleu)

  try:
    with time_limit(300):
      rouge_scores = rouge.get_scores(correct, refs, avg=True, ignore_empty=False)
      print("ROUGE-L:", rouge_scores["rouge-l"])
  except TimeoutException as e:
    print("ROUGE function timed out")
  print("")

  print("** Testing on BEA-2019 **")
  outstr = ""
  for line in bea:
    line = line.replace("\n", " ")
    outstr += tokenify(line) + "\n"

  with open("ABCN.bea19.test.corr", "w", encoding="utf-8") as outtest:
    outtest.write(outstr)

  !zip bea-test-{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.zip ABCN.bea19.test.corr
  !cp ./bea-test-{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.zip /content/gdrive/My\ Drive/heptabot/output
  print("")

  print("** Testing on JFLEG **")
  %cd jfleg
  outstr = ""

  for line in jfleg:
    line = line.replace("\n", " ")
    outstr += tokenify(line) + "\n"

  with open("test.nospc.res", "w", encoding="utf-8") as outtest:
    outtest.write(outstr)
  
  !cp ./test.nospc.res /content/gdrive/My\ Drive/heptabot/output/test.nospc.{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.res

  !python ./eval/gleu.py -r ./test/test.ref[0-3] -s ./test/test.src --hyp test.nospc.res

  %cd ../
  print("")

  print("** Testing on CoNLL-2014 **")
  outstr = ""

  for line in conll:
    line = line.replace("\n", " ")
    outstr += tokenify(line) + "\n"

  with open("conll14_nospc.txt", "w", encoding="utf-8") as outtest:
    outtest.write(outstr)
  !cp ./conll14_nospc.txt /content/gdrive/My\ Drive/heptabot/output/conll14_nospc_{"tf_" + MODEL_SIZE.lower() + "_e" + str(epoch)}.txt

  try:
    with time_limit(300):
      !python2 ./m2scorer/scripts/m2scorer.py ./conll14_nospc.txt ./conll14st-test-data/noalt/official-2014.combined.m2
  except TimeoutException as e:
    print("M2 scorer timed out")
  print("")
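
The checkpoint discovery at the top of the cell above relies on a regular expression over the `gsutil ls` output. A minimal sketch of that step on a made-up listing is shown below; the bucket and file names are hypothetical, and real listings will differ.

```python
import re

def checkpoint_ids(listing):
    # Collect the step number from every "ckpt-<step>.data..." entry;
    # the set removes duplicates that come from sharded data files.
    pattern = re.compile(r"ckpt-([0-9]+?)\.data")
    matches = (pattern.search(line) for line in listing)
    return sorted({int(m.group(1)) for m in matches if m})

# Hypothetical listing; only the ".data" entries carry a matching step number.
listing = [
    "gs://example-bucket/models/model.ckpt-1002133.data-00000-of-00002",
    "gs://example-bucket/models/model.ckpt-1002133.data-00001-of-00002",
    "gs://example-bucket/models/model.ckpt-1002133.index",
    "gs://example-bucket/models/model.ckpt-1006400.data-00000-of-00002",
    "gs://example-bucket/models/operative_config.gin",
]
print(checkpoint_ids(listing))
```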

**** Evaluating checkpoint 1006400 ****



Copying gs://ml-bucket-isikus/t5-base-model/models/3B-total/validation_eval/correct_1006400_predictions...
- [1/1 files][  6.6 MiB/  6.6 MiB] 100% Done                                    
Operation completed over 1 objects/6.6 MiB.                                      
Copying gs://ml-bucket-isikus/t5-base-model/models/3B-total/validation_eval/jfleg_1006400_predictions...
/ [1/1 files][ 71.6 KiB/ 71.6 KiB] 100% Done                                    
Operation completed over 1 objects/71.6 KiB.                                     
Copying gs://ml-bucket-isikus/t5-base-model/models/3B-total/validation_eval/conll_1006400_predictions...
/ [1/1 files][158.5 KiB/158.5 KiB] 100% Done                                    
Operation completed over 1 objects/158.5 KiB.                                    
Copying gs://ml-bucket-isikus/t5-base-model/models/3B-total/validation_eval/bea_1006400_predictions...
/ [1/1 files][430.4 KiB/430.4 KiB] 100% Done                                    
Operation 