<a href="https://colab.research.google.com/github/johntiger1/hugging-face-generation/blob/master/Copy_of_t5_trivia_working_aug4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The T5 Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2019 The T5 Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Fine-Tuning the Text-To-Text Transfer Transformer (T5) for Closed-Book Question Answering
## _Or: What does T5 know?_

*The following tutorial guides you through the process of fine-tuning a pre-trained T5 model, evaluating its accuracy, and using it for prediction,
all on a free Google Cloud TPU <a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>.*

### Background

T5 was introduced in the paper [_Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer_](https://arxiv.org/abs/1910.10683). In that paper, we provided a comprehensive picture of how we pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.

We pre-trained T5 on a mixture of supervised and unsupervised tasks with the majoriy of data coming from an unlabeled dataset we developed called [C4](https://www.tensorflow.org/datasets/catalog/c4). C4 is based on a massive scrape of the web produced by [Common Crawl](https://commoncrawl.org). Loosely speaking, pre-training on C4 ideally gives T5 an understanding of natural language in addition to general world knowledge.

### How can we assess what T5 knows?

As the name implies, T5 is a text-to-text model, which enables us to train it on arbitrary tasks involving a textual input and output. As we showed in our paper, a huge variety of NLP tasks can be cast in this format, including translation, summarization, and even classification and regression tasks.

One way to use this text-to-text framework is on reading comprehension problems, where the model is fed some context along with a question and is trained to predict the question's answer. For example, we might feed the model the text from the Wikipedia article about [Hurrican Connie](https://en.wikipedia.org/wiki/Hurricane_Connie) along with the question "On what date did Hurricane Connie occur?" and train the model to predict the answer "August 3rd, 1955".
A related task is open-domain question answering (QA) where the model is not provided with this oracle context. Typically, open-domain QA systems include a mechanism to look up information in an external knowledge source. This setting is similar to an "open-book" exam.

In this notebook, we'll be training T5 on a variant of this task which we call **closed-book question answering**. In closed-book QA, we feed the model a question *without any context or access to external knowledge* and train it to predict the answer. Since the model doesn't receive any context, the primary way it can learn to answer these questions is based on the "knowledge" it obtained during pre-training. We don't expect T5 to contain super specific information, so we will be focusing on two question-answering datasets which largely include trivia questions (i.e. facts about well-known subjects). [Similar](https://arxiv.org/abs/1909.01066) [investigations](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) have recently been done to test the knowledge stored by BERT and GPT-2.

T5 was not pre-trained on closed-book QA, so in this notebook we'll first create two new tasks and then use the [`t5`](https://github.com/google-research/text-to-text-transfer-transformer) library to fine-tune, evaluate, and obtain predictions from T5. In the end, T5's performance on closed-book QA can give us a sense of what kind (and how much) information T5 managed to learn during pre-training.

## State-of-the-art Results
We published a [more in-depth investigation](https://arxiv.org/abs/2002.08910) of closed-book QA with T5 where we achieved SOTA on open-domain variants of WebQuestions and TriviaQA in addition to surpisingly strong results on Natural Questions. The code in this notebook is a simplified version of those experiments but still produces good results.

For code to reproduce our best results, please see the [t5_closed_book_qa](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa) repo.


### Caveats

* While we provide instructions for running on a [Cloud TPU](https://cloud.google.com/tpu/) via Colab for free, a [Google Cloud Storage (GCS)](http://console.cloud.google.com/storage) bucket is required for storing model parameters and data. The [GCS free tier](https://cloud.google.com/free/) provides 5 GB of storage, which should be enough to train the `large` model and smaller but not the `3B` or `11B` parameter models. You can use part of your initial $300 credit to get more space.
* The Cloud TPU provided by Colab (a `v2-8`) does not have enough memory to fine-tune the `11B` parameter model. For this model, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).


# Set Up

<h3><a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>  &nbsp;&nbsp;Train on TPU</h3>




   1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the `BASE_DIR` parameter in the following form. There is a [free tier](https://cloud.google.com/free/) if you do not yet have an account.
 
   1. On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
   1. Run the following cell and follow instructions to:
    *  Set up a Colab TPU running environment
    *   Verify that you are connected to a TPU device
    *   Upload your credentials to TPU to access your GCS bucket


In [None]:
print("Installing dependencies...")
%tensorflow_version 2.1
!pip install t5==0.6.2

Installing dependencies...
`%tensorflow_version` only switches the major version: 1.x or 2.x.
You set: `2.1`. This will be interpreted as: `2.x`.


TensorFlow 2.x selected.
Collecting t5==0.6.2
[?25l  Downloading https://files.pythonhosted.org/packages/9c/6b/21374c00746a960eceb23bcd679a88832f75b6fc2d6514b50d403e343140/t5-0.6.2-py3-none-any.whl (162kB)
[K     |████████████████████████████████| 163kB 3.5MB/s 
Collecting transformers>=2.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |████████████████████████████████| 778kB 10.2MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 17.3MB/s 
Collecting tfds-nightly
[?25l  Downloading https://f

In [None]:
import tensorflow as tf
tf.__version__

'2.3.0'

In [5]:



import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

BASE_DIR = "gs://t5-open-qa-bucket" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  print("we are auth")
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  yield
  tf.logging.set_verbosity(og_level)

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 
Setting up GCS access...
Running on TPU: grpc://10.78.242.58:8470


KeyboardInterrupt: ignored

In [None]:
import tensorflow

In [None]:
tensorflow.__version__

'2.2.0'

# Creating new Tasks and Mixture

Two core components of the T5 library are `Task` and `Mixture` objects.

A `Task` is a dataset along with preprocessing functions and evaluation metrics. A `Mixture` is a collection of `Task` objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent `Tasks`.

For this example, we will fine-tune the model to do closed-book question answering.

### Natural Questions

[Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) is a challenging corpus for open-domain QA. Each example includes a question along with an entire Wikipedia article that may or may not contain its answer. The goal is to produce the correct answer given this context. In our case, we will be ignoring the provided context in hopes that the model will learn to find the answers from the world knowledge it has acquired during pre-training.

Since the raw data splits are stored as JSONL files, we will first need to convert them to TSV format to make them parseable in TensorFlow. We will also take the opportunity to drop information we will not be using, remove questions with multiple answers, and to do a bit of cleaning of the text.

In [None]:
import gzip
import json

# Public directory of Natural Questions data on GCS.
NQ_JSONL_DIR = "gs://natural_questions/v1.0-simplified/"
NQ_SPLIT_FNAMES = {
    "train": "simplified-nq-train.jsonl.gz",
    "validation": "nq-dev-all.jsonl.gz"
}
nq_counts_path = os.path.join(DATA_DIR, "nq-counts.json")
nq_tsv_path = {
    "train": os.path.join(DATA_DIR, "nq-train.tsv"),
    "validation": os.path.join(DATA_DIR, "nq-validation.tsv")
}

covid_qa_tsv_path = {
    "train": os.path.join(DATA_DIR, "cleaned-covid-train.tsv"),
    "validation": os.path.join(DATA_DIR, "cleaned-covid-valid.tsv")

}

import pandas as pd
def csv_to_tsv(in_fname, out_fname):
  my_df = pd.load_csv(in_fname)

  

'''
We need to redo it:
formulate the COVID-19 in this format
'''
from tqdm.auto import tqdm
def nq_jsonl_to_tsv(in_fname, out_fname):

  def extract_answer(tokens, span):
    """Reconstruct answer from token span and remove extra spaces."""
    start, end = span["start_token"], span["end_token"]  
    ans = " ".join(tokens[start:end])
    # Remove incorrect spacing around punctuation.
    ans = ans.replace(" ,", ",").replace(" .", ".").replace(" %", "%")
    ans = ans.replace(" - ", "-").replace(" : ", ":").replace(" / ", "/")
    ans = ans.replace("( ", "(").replace(" )", ")")
    ans = ans.replace("`` ", "\"").replace(" ''", "\"")
    ans = ans.replace(" 's", "'s").replace("s ' ", "s' ")
    return ans

  count = 0
  with tf.io.gfile.GFile(in_fname, "rb") as infile,\
       tf.io.gfile.GFile(out_fname, "w") as outfile:
    for line in tqdm(gzip.open(infile)):
      ex = json.loads(line)
      # Remove any examples with more than one answer.
      if len(ex['annotations'][0]['short_answers']) != 1:
        continue
      # Questions in NQ do not include a question mark.
      question = ex["question_text"] + "?"
      answer_span = ex['annotations'][0]['short_answers'][0]
      # Handle the two document formats in NQ (tokens or text).
      if "document_tokens" in ex:
        tokens = [t["token"] for t in ex["document_tokens"]]
      elif "document_text" in ex:
        tokens = ex["document_text"].split(" ")
      answer = extract_answer(tokens, answer_span)
      # Write this line as <question>\t<answer>
      outfile.write("%s\t%s\n" % (question, answer)) # we just need Q TAB A pairs
      count += 1
      tf.logging.log_every_n(
          tf.logging.INFO,
          "Wrote %d examples to %s." % (count, out_fname),
          1000)
      # if count > 100 and count % 10000 == 0:
      #   print("nice, finsihed 10k examples")
      #   break
    return count

if tf.io.gfile.exists(nq_counts_path):
  # Used cached data and counts.
  tf.logging.info("Loading NQ from cache.")
  num_nq_examples = json.load(tf.io.gfile.GFile(nq_counts_path))
else:
  # Create TSVs and get counts.
  tf.logging.info("Generating NQ TSVs.")
  num_nq_examples = {}
  for split, fname in NQ_SPLIT_FNAMES.items():
    num_nq_examples[split] = nq_jsonl_to_tsv(
        os.path.join(NQ_JSONL_DIR, fname), nq_tsv_path[split])
  json.dump(num_nq_examples, tf.io.gfile.GFile(nq_counts_path, "w"))

INFO:tensorflow:Loading NQ from cache.


In [None]:
num_nq_examples

{'train': 96499, 'validation': 2338}

In [None]:
pip install gcsfs

Collecting gcsfs
  Downloading https://files.pythonhosted.org/packages/ce/5c/bc61dbd2e5b61d84486a96a64ca43512c9ac085487464562182f58406290/gcsfs-0.6.2-py2.py3-none-any.whl
Installing collected packages: gcsfs
Successfully installed gcsfs-0.6.2


In [None]:
DATA_DIR

'gs://t5-open-qa-bucket/data'

In [None]:
import pandas as pd


In [None]:
my_df = pd.read_csv(os.path.join(DATA_DIR,"cleaned_QA_COVID_19_General.csv"))




In [None]:
my_df

Unnamed: 0,question,text,answer
0,What is a coronavirus?,Coronaviruses are a large family of viruses wh...,a large family of viruses
1,what is COVID-19?,COVID-19 is the infectious disease caused by t...,infectious disease
2,what are the symptoms of COVID-19?,The most common symptoms of COVID-19 are fever...,"fever, tiredness, and dry cough"
3,How does COVID-19 spread?,People can catch COVID-19 from others who have...,spread from person to person through small dro...
4,How likely am I to catch COVID-19?,The risk depends on where you are - and more ...,in most locations the risk of catching COVID-1...
...,...,...,...
192,\nHow does the virus spread?,Officials are still learning about how COVID-1...,"From touching our mouths, noses or eyes after ..."
193,How does the virus spread?,Officials are still learning about how COVID-1...,From close contact with people who have it.\n
194,How does the virus spread?,Officials are still learning about how COVID-1...,From respiratory droplets that become airborne...
195,\nWill warm weather stop the outbreak of COVID...,Researchers are still learning about how easil...,it’s unclear


In [None]:
my_df.dtypes

question    object
text        object
answer      object
dtype: object

In [None]:
my_df = my_df.astype(pd.StringDtype())

In [None]:
my_df.dtypes

question    string
text        string
answer      string
dtype: object

In [None]:

relevant_df = my_df[["question", "answer"]].dropna()

In [None]:
len(relevant_df)

178

In [None]:
# relevant_df.iloc[432,1]

In [None]:
# incredibly, some of them are floats!!
relevant_df["question"] = relevant_df["question"].map(lambda x: x.replace("\n", " "))
relevant_df["answer"] = relevant_df["answer"].map(lambda x: x.replace("\n", " "))




In [None]:
my_df.dtypes

question    string
text        string
answer      string
dtype: object

In [None]:
relevant_df

Unnamed: 0,question,answer
0,What is a coronavirus?,a large family of viruses
1,what is COVID-19?,infectious disease
2,what are the symptoms of COVID-19?,"fever, tiredness, and dry cough"
3,How does COVID-19 spread?,spread from person to person through small dro...
4,How likely am I to catch COVID-19?,in most locations the risk of catching COVID-1...
...,...,...
192,How does the virus spread?,"From touching our mouths, noses or eyes after ..."
193,How does the virus spread?,From close contact with people who have it.
194,How does the virus spread?,From respiratory droplets that become airborne...
195,Will warm weather stop the outbreak of COVID-19?,it’s unclear


In [None]:
len(relevant_df["question"].unique())


170

In [None]:
relevant_df.groupby("question").count().sort_values(by="answer", ascending=False)

Unnamed: 0_level_0,answer
question,Unnamed: 1_level_1
How can I protect myself from getting COVID-19?,7
How can I ask for help?,2
How does the virus spread?,2
Are there any tests that I can purchase to test myself at home for COVID-19?,1
"What happens if a pregnant worker is exposed to COVID-19 while on reassignment in a healthcare setting, including a dedicated COVID-19 clinic?",1
...,...
How long will the results take?,1
How must the employer verify the state of health of the workers arriving on the work site?,1
How should I clean my environment?,1
"How to put on, use,take off and dispose of a mask?",1


In [None]:

len(relevant_df.groupby("question").count().sort_values(by="answer", ascending=False))

170

In [None]:
type(relevant_df.groupby("question").count().sort_values(by="answer", ascending=False))

pandas.core.frame.DataFrame

In [None]:
isinstance(relevant_df.groupby("question").count().sort_values(by="answer", ascending=False), pd.DataFrame)

True

In [None]:
relevant_df.groupby("question").count().sort_values(by="answer", ascending=False).index # the index is now over questions 

Index(['How can I protect myself from getting COVID-19?',
       'How can I ask for help?', 'How does the virus spread?',
       ' Are there any tests that I can purchase to test myself at home for COVID-19?',
       'What happens if a pregnant worker is exposed to COVID-19 while on reassignment in a healthcare setting, including a dedicated COVID-19 clinic?',
       'What can I do to help those who are vulnerable in Oxford?',
       'What can National Veterinary Services do with regards to companion animals?',
       'What do I do if my workplace first aid card expires?',
       'What does it really mean to self-isolate or self-quarantine? What should or shouldn't I do?',
       'What does the CNESST suggest for my employer, who has arranged a medical assessment?',
       ...
       'How does COVID-19 spread?',
       'How does the CNESST intervene in the case of a priority or non-priority business?',
       'How does the testing work?', 'How likely am I to catch COVID-19?',
       'H

In [None]:
grouped_questions_df = relevant_df.groupby("question").count().sort_values(by="answer", ascending=False)

In [None]:
# (relevant_df.groupby("question").count().sort_values(by="answer", ascending=False)) == pd.DataFrame

In [None]:
# Ensure that repeated questions are dumped in the same split
# to do this, we simply need to split up everything , and assign labels too. Give the indexes to the appropriate people.
# let's handroll it and see what happens!! 
import numpy as np

In [None]:
np.random.binomial(1, 0.05, size=(len(grouped_questions_df))) # want K binomial samples

array([0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
samples = np.random.binomial(1, 0.05, size=(len(grouped_questions_df))) 

In [None]:
samples_df = pd.DataFrame(samples, columns=["test"])

In [None]:
samples_df

Unnamed: 0,test
0,0
1,0
2,0
3,0
4,0
...,...
165,0
166,0
167,0
168,0


In [None]:
grouped_questions_df

Unnamed: 0_level_0,answer
question,Unnamed: 1_level_1
How can I protect myself from getting COVID-19?,7
How can I ask for help?,2
How does the virus spread?,2
Are there any tests that I can purchase to test myself at home for COVID-19?,1
"What happens if a pregnant worker is exposed to COVID-19 while on reassignment in a healthcare setting, including a dedicated COVID-19 clinic?",1
...,...
How long will the results take?,1
How must the employer verify the state of health of the workers arriving on the work site?,1
How should I clean my environment?,1
"How to put on, use,take off and dispose of a mask?",1


In [None]:
unique_questions_df = grouped_questions_df.reset_index()

In [None]:
unique_questions_df

Unnamed: 0,question,answer
0,How can I protect myself from getting COVID-19?,7
1,How can I ask for help?,2
2,How does the virus spread?,2
3,Are there any tests that I can purchase to te...,1
4,What happens if a pregnant worker is exposed t...,1
...,...,...
165,How long will the results take?,1
166,How must the employer verify the state of heal...,1
167,How should I clean my environment?,1
168,"How to put on, use,take off and dispose of a m...",1


In [None]:
# now, concat things, and then select the df based on them. Will require one last final join too
pd.concat((unique_questions_df,samples_df), axis=1)

Unnamed: 0,question,answer,test
0,How can I protect myself from getting COVID-19?,7,0
1,How can I ask for help?,2,0
2,How does the virus spread?,2,0
3,Are there any tests that I can purchase to te...,1,0
4,What happens if a pregnant worker is exposed t...,1,0
...,...,...,...
165,How long will the results take?,1,0
166,How must the employer verify the state of heal...,1,0
167,How should I clean my environment?,1,0
168,"How to put on, use,take off and dispose of a m...",1,0


In [None]:
assignments_df = pd.concat((unique_questions_df,samples_df), axis=1).drop(["answer"], axis=1)


In [None]:
# assignments_df.drop(["answer"], axis=1)

Unnamed: 0,question,test
0,How can I protect myself from getting COVID-19?,0
1,How can I ask for help?,0
2,How does the virus spread?,0
3,Are there any tests that I can purchase to te...,0
4,What happens if a pregnant worker is exposed t...,0
...,...,...
165,How long will the results take?,0
166,How must the employer verify the state of heal...,0
167,How should I clean my environment?,0
168,"How to put on, use,take off and dispose of a m...",0


In [None]:
expanded_assignments_df = relevant_df.merge(assignments_df, left_on=["question"], right_on=["question"], how="inner") 

In [None]:
expanded_assignments_df["test"].describe()

count    178.000000
mean       0.050562
std        0.219719
min        0.000000
25%        0.000000
50%        0.000000
75%        0.000000
max        1.000000
Name: test, dtype: float64

In [None]:
expanded_assignments_df.groupby(["question", "test"]).count().sort_values(by="answer", ascending=False)

Unnamed: 0_level_0,Unnamed: 1_level_0,answer
question,test,Unnamed: 2_level_1
How can I protect myself from getting COVID-19?,0,7
How can I ask for help?,0,2
How does the virus spread?,0,2
Are there any tests that I can purchase to test myself at home for COVID-19?,0,1
"What happens if a pregnant worker is exposed to COVID-19 while on reassignment in a healthcare setting, including a dedicated COVID-19 clinic?",0,1
...,...,...
How long will the results take?,0,1
How must the employer verify the state of health of the workers arriving on the work site?,0,1
How should I clean my environment?,0,1
"How to put on, use,take off and dispose of a mask?",0,1


In [None]:
expanded_assignments_df.groupby(["question"]).agg("sum").sort_values(by="test",ascending=False)

Unnamed: 0_level_0,test
question,Unnamed: 1_level_1
Can I offer my test for home use and/or self-collection under the Policy for Diagnostic Tests for Coronavirus Disease-2019?,1
Is there a vaccine for coronavirus?,1
Will vaccine be safe?,1
Will warm weather stop the outbreak of COVID-19?,1
Are University premises (such as libraries and museums) still open to the public?,1
...,...
How must the employer verify the state of health of the workers arriving on the work site?,0
How should I clean my environment?,0
"How to put on, use,take off and dispose of a mask?",0
How will COVID-19 respond to the warmer spring/summer weather we are fast-approaching?,0


In [None]:
unique_train_df = expanded_assignments_df[expanded_assignments_df["test"]==0].drop("test",axis=1)

In [None]:
unique_test_df = expanded_assignments_df[expanded_assignments_df["test"]==1].drop("test",axis=1)


In [None]:
unique_test_df

Unnamed: 0,question,answer
5,Should I worry about COVID-19?,quite normal for people to worry about how the...
21,Is there a vaccine for coronavirus?,Developing new vaccines takes time and they mu...
25,Will vaccine be safe?,vaccines usually have a higher bar for safety ...
35,Can I offer my test for home use and/or self-c...,Tests for Coronavirus Disease-2019 do not appl...
37,why test?,testing can confirm an infection
46,Could the arrival of spring and the warmer wea...,"It might, but so far COVID-19 has been found i..."
56,"Are there any foods, supplements, vitamins or ...","There are no foods, supplements, vitamins or n..."
161,Are University premises (such as libraries and...,the University's libraries and museums are now...
176,Will warm weather stop the outbreak of COVID-19?,it’s unclear


In [None]:
unique_train_df.groupby("question").count().sort_values(by="answer")

Unnamed: 0_level_0,answer
question,Unnamed: 1_level_1
Are there any tests that I can purchase to test myself at home for COVID-19?,1
What are the steps for a pregnant worker to apply for reassignment due to the coronavirus?,1
What can I do if I do not receive a record of employment or termination of employment?,1
What can I do to help those who are vulnerable in Oxford?,1
What can National Veterinary Services do with regards to companion animals?,1
...,...
How does the CNESST intervene in the case of a priority or non-priority business?,1
What are the measures to be implemented on construction sites to reduce contamination linked to COVID-19?,1
How does the virus spread?,2
How can I ask for help?,2


In [None]:
len(unique_test_df)

9

In [None]:
unique_train_df.to_csv(os.path.join(DATA_DIR,"cleaned-covid-train.tsv"),sep="\t",index=False, header=False)

In [None]:
unique_test_df.to_csv(os.path.join(DATA_DIR,"cleaned-covid-valid.tsv"),sep="\t",index=False, header=False)

In [None]:
# np.random.uniform(1, 0.05, size=(len(grouped_questions_df))) # 

In [None]:
# from sklearn.model_selection import train_test_split
# # randchoice k indices without replacement
# train_covid_df, valid_covid_df = train_test_split(relevant_df, test_size=0.05, random_state=0)

In [None]:
NUM_COVID_EXAMPLES = {"train":len(unique_train_df), "validation":len(unique_test_df)}

In [None]:
NUM_COVID_EXAMPLES

{'train': 169, 'validation': 9}

In [None]:
# with open(os.path.join(DATA_DIR,"abc.txt"), "w") as file:
#   file.write("hello\n")

In [None]:
with tf.io.gfile.GFile(os.path.join(DATA_DIR,"abc.txt"), "w") as file:
  file.write("hello\n")


In [None]:
# with tf.io.gfile.GFile(covid_qa_tsv_path["train"], "r") as file:
#   for line in file:
#     print(line)

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [None]:
def nq_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split]) #add in the specific examples you want here as well 
  # Split each "<question>\t<answer>" example into (question, answer) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"question": ... "answer": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["question", "answer"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(nq_dataset_fn("validation").take(1)):
  # print(ex)
  print(len(ex))
  # print(ex["question"])
  # print(ex["answer"])

A few raw validation examples...
2


In [None]:
covid_qa_tsv_path

{'train': 'gs://t5-open-qa-bucket/data/cleaned-covid-train.tsv',
 'validation': 'gs://t5-open-qa-bucket/data/cleaned-covid-valid.tsv'}

In [None]:
def covid_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(covid_qa_tsv_path[split]) #add in the specific examples you want here as well 
  # Split each "<question>\t<answer>" example into (question, answer) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"question": ... "answer": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["question", "answer"], ex)))
  
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(covid_dataset_fn("train").take(5)):
  # print(ex)
  print(len(ex))
  # print(ex["question"])
  # print(ex["answer"])


A few raw validation examples...
2
2
2
2
2


In [None]:
train_ds = covid_dataset_fn("train")

In [None]:
# # tf.compat.v1.enable_eager_execution()

# with tf.session() as sess:
#   for ex in test_ds:
#     print(ex)

Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by lowercasing it and removing quotes since the answers are sometimes formatted in odd ways. Finally, we prepend 'trivia question:' to the inputs so that the model knows what task it's trying to solve.

In [None]:
def trivia_preprocessor(ds):
  def normalize_text(text):
    """Lowercase and remove quotes from a TensorFlow string."""
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"question": ..., "answer": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["trivia question: ", normalize_text(ex["question"])]),
        "targets": normalize_text(ex["answer"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

Finally, we put everything together to create a `Task`.

In [None]:
t5.data.TaskRegistry.add(
    "nq_context_free",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=nq_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[trivia_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_nq_examples
)

Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [None]:
nq_task = t5.data.TaskRegistry.get("nq_context_free")
ds = nq_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

A few preprocessed validation examples...
{'inputs_plaintext': b'trivia question: where do royal families get their money from?', 'inputs': array([22377,   822,    10,   213,   103, 11268,  1791,   129,    70,
         540,    45,    58,     1]), 'targets_plaintext': b'the hereditary revenues of the crown', 'targets': array([    8,   160, 11272,  1208, 14609,    13,     8, 10228,     1])}
{'inputs_plaintext': b'trivia question: where does the term dog and pony show come from?', 'inputs': array([22377,   822,    10,   213,   405,     8,  1657,  1782,    11,
       26190,   504,   369,    45,    58,     1]), 'targets_plaintext': b'in the united states in the late-19th and early-20th centuries', 'targets': array([   16,     8, 18279,  2315,    16,     8,  1480,  4481,   189,
          11,   778,  7988,   189, 11653,     1])}
{'inputs_plaintext': b'trivia question: what us president is the only president to become an eagle scout?', 'inputs': array([22377,   822,    10,   125,   178,  2753,

**Note**: Instead of defining `nq_dataset_fn` and above, we also could have used the `TextLineTask` class with the `parse_tsv` preprocessor for equivalent results as follows:

```py
t5.data.TaskRegistry.add(
    "nq_context_free",
    t5.data.TextLineTask,
    split_to_filepattern=nq_tsv_path,
    text_preprocessor=[
      functools.partial(
          t5.data.preprocessors.parse_tsv,
          field_names=["question", "answer"]),
      trivia_preprocessor
    ],
    postprocess_fn=t5.data.postprocessors.lower_text, 
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples
)
```


In [None]:
t5.data.TaskRegistry.remove("covid_context_free")

In [None]:
t5.data.TaskRegistry.add(
    "covid_context_free",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=covid_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[trivia_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=NUM_COVID_EXAMPLES
)

In [None]:
covid_qa_task = t5.data.TaskRegistry.get("covid_context_free")
covid_ds = covid_qa_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(covid_ds.take(5)):
  print(ex)

A few preprocessed validation examples...
{'inputs_plaintext': b'trivia question: are university premises (such as libraries and museums) still open to the public?', 'inputs': array([22377,   822,    10,    33,  3819, 12787,    41,  4415,    38,
       12256,    11, 17385,    61,   341,   539,    12,     8,   452,
          58,     1]), 'targets_plaintext': b"the university's libraries and museums are now closed", 'targets': array([    8,  3819,    31,     7, 12256,    11, 17385,    33,   230,
        3168,     1])}
{'inputs_plaintext': b'trivia question: should i worry about covid-19?', 'inputs': array([22377,   822,    10,   225,     3,    23,  3516,    81,   576,
        6961,  4481,    58,     1]), 'targets_plaintext': b'quite normal for people to worry about how the covid-19 outbreak', 'targets': array([  882,  1389,    21,   151,    12,  3516,    81,   149,     8,
         576,  6961,  4481, 22494,     1])}
{'inputs_plaintext': b'trivia question: are there any foods, supplements,

## TriviaQA

A second dataset we will use is related to [TriviaQA](https://nlp.cs.washington.edu/triviaqa/). It is also intended for reading comprehension, but, once again, we will modify the task here by ignoring the provided context.

Since the dataset has been imported into [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/trivia_qa), we can let it handle the data parsing for us. It will take a few minutes to download and preprocess the first time, but we'll be able to access it instantly from our data directory afterward.

In [None]:
ds = tfds.load(
    "trivia_qa/unfiltered.nocontext",
    data_dir=DATA_DIR,
    # Download data locally for preprocessing to avoid using GCS space.
    download_and_prepare_kwargs={"download_dir": "./downloads"})
print("A few raw validation examples...")
for ex in tfds.as_numpy(ds["validation"].take(2)):
  print(ex)

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split None, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


A few raw validation examples...
{'answer': {'aliases': array([b'Torquemada (disambiguation)', b'Torquemada'], dtype=object), 'matched_wiki_entity_name': b'', 'normalized_aliases': array([b'torquemada', b'torquemada disambiguation'], dtype=object), 'normalized_matched_wiki_entity_name': b'', 'normalized_value': b'torquemada', 'type': b'WikipediaEntity', 'value': b'Torquemada'}, 'entity_pages': {'doc_source': array([], dtype=object), 'filename': array([], dtype=object), 'title': array([], dtype=object), 'wiki_context': array([], dtype=object)}, 'question': b'In 1483, who was appointed the first grand inquisitor of the Spanish Inquisition?', 'question_id': b'qw_16011', 'question_source': b'http://www.quizwise.com/', 'search_results': {'description': array([], dtype=object), 'filename': array([], dtype=object), 'rank': array([], dtype=int32), 'search_context': array([], dtype=object), 'title': array([], dtype=object), 'url': array([], dtype=object)}}
{'answer': {'aliases': array([b'Auster

As with Natural Questions, we need to preprocess the raw examples into `inputs` and `targets` features. We can reuse the `trivia_preprocessor` above, but first we need to convert the TriviaQA examples into the correct format, ignoring the fields we don't need for our task.

We'll then define our `Task` and print out a few preprocessed examples from the validation set.

Note that we do not need to specify the splits or number of examples since that information is provided by TFDS.

In [None]:
def tiviaqa_extract_qa(ds):
  def exract_qa(ex):
    return {
        "question": ex["question"],
        "answer": ex["answer"]["value"]
    }
  return ds.map(exract_qa, num_parallel_calls=tf.data.experimental.AUTOTUNE)

t5.data.TaskRegistry.add(
    "triviaqa_context_free",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    t5.data.TfdsTask,
    tfds_name="trivia_qa/unfiltered.nocontext:1.1.0",
    tfds_data_dir=DATA_DIR,
    text_preprocessor=[tiviaqa_extract_qa, trivia_preprocessor],
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[t5.evaluation.metrics.accuracy]
)

# Load and print a few examples.
triviaqa_task = t5.data.TaskRegistry.get("triviaqa_context_free")
ds = triviaqa_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(3)):
  print(ex)

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


A few preprocessed validation examples...
{'inputs_plaintext': b"trivia question: the pose that evans-loude used came from which of marilyn monroe's films?", 'inputs': array([22377,   822,    10,     8,  7663,    24,     3,    15,  2132,
           7,    18,    40,  1063,   221,   261,   764,    45,    84,
          13,  2774,   120,    29,  1911,    52,    32,    15,    31,
           7,  4852,    58,     1]), 'targets_plaintext': b'the seven year itch', 'targets': array([   8, 2391,  215,    3, 7059,    1])}
{'inputs_plaintext': b'trivia question: in which u s state is arches national park located just outside the city of moab?', 'inputs': array([22377,   822,    10,    16,    84,     3,    76,     3,     7,
         538,    19,  1584,  2951,  1157,  2447,  1069,   131,  1067,
           8,   690,    13,  2288,     9,   115,    58,     1]), 'targets_plaintext': b'utah', 'targets': array([  3,  76,  17,   9, 107,   1])}
{'inputs_plaintext': b'trivia question: what is a reality tv show

## Dataset Mixture

We now create a `Mixture` from the above `Tasks`, which we will fine-tune on.

There are different ways to automatically set the rate (for example, based on the number of examples using `rate_num_examples`), but we will just hardcode an equal mixture for simplicity.

In [None]:
t5.data.MixtureRegistry.remove("trivia_all")
t5.data.MixtureRegistry.add(
    "trivia_all",
    ["nq_context_free", "triviaqa_context_free", "covid_context_free"],
     default_rate=1.0
)

# Transferring to new Tasks

We are now ready to fine-tune one of the pre-trained T5 models on our new mixture of closed-book QA tasks.

First, we'll instantiate a `Model` object using the model size of your choice. Note that larger models are slower to train and use but will likely achieve higher accuracy. You also may be able to increase accuracy by training longer with more `FINETUNE_STEPS` below.


## Caveats

* Due to its memory requirements, you will not be able to train the `11B` parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).
* Due to the checkpoint size, you will not be able use the 5GB GCS free tier for the `3B` parameter models. You will need at least 25GB of space, which you can purchase with your $300 of initial credit on GCP.
* While `large` can achieve decent results, it is recommended that you fine-tune at least the `3B` parameter model.


## Define Model

In [None]:
MODEL_SIZE = "large" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warn(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

Before we continue, let's load a [TensorBoard](https://www.tensorflow.org/tensorboard) visualizer so that we can keep monitor our progress. The page should automatically update as fine-tuning and evaluation proceed.

In [None]:
# if ON_CLOUD:
#   %reload_ext tensorboard
#   import tensorboard as tb
# tb.notebook.start("--logdir " + MODELS_DIR)

## Fine-tune

We are now ready to fine-tune our model. This will take a while (~2 hours with default settings), so please be patient! The larger the model and more `FINETUNE_STEPS` you use, the longer it will take.

Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off.

In [None]:
FINETUNE_STEPS =  25000#@param {type: "integer"}

model.finetune(
    mixture_or_task_name="trivia_all",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

INFO:tensorflow:Using config: {'_model_dir': 'gs://t5-open-qa-bucket/models/large', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.111.107.154:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.111.107.154:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.111.107.154:8470', '_evaluation_master': 'grpc://10.111.107.154:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split train, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


INFO:tensorflow:enable_2d_tiling: False
INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0]
  [0 0 1]
  [1 0 0]
  [1 0 1]
  [0 1 0]
  [0 1 1]
  [1 1 0]
  [1 1 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0]]

 [[0 0 1]]

 [[0 1 0]]

 [[0 1 1]]

 [[1 0 0]]

 [[1 0 1]]

 [[1 1 0]]

 [[1 1 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('batch', 'batch'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('vocab', 'model'), ('heads', 'model'), ('d_ff', 'model')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f78d4f70ba8

In [None]:
model

<t5.models.mtf_model.MtfModel at 0x7fd38c736a58>

## Expected Results [SPOILER ALERT]

Below are the expected accuracies on the Natural Question (NQ) and TriviQA validation sets for various model sizes. The full 11B model produces the exact text of the answer 34.5% and 25.1% of the time on TriviaQA and NQ, respectively. The 3B parameter model, which is the largest that can be trained with a free Cloud TPU in Colab, achieves 29.7% and 23.7%, respectively.

In reality, the model performs better than this since requiring exact match is too strict of a metric, as you’ll see in the examples below. This helps to explain why the model appears to perform better on TriviaQA than NQ, as the latter tends to include more long-form answers extracted from the context.

Please see our [paper on closed-book QA](https://tiny.cc/t5-qa) where achieved even better results.

<img src="https://storage.googleapis.com/t5-data/assets/t5_trivia_expected.png">

## Evaluate

We now evaluate on the validation sets of the tasks in our mixture. Accuracy results will be logged and added to the TensorBoard above.

In [None]:
# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="trivia_all",
    checkpoint_steps="all"
)

INFO:tensorflow:Using config: {'_model_dir': 'gs://t5-open-qa-bucket/models/large', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.111.107.154:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.111.107.154:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.111.107.154:8470', '_evaluation_master': 'grpc://10.111.107.154:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


INFO:tensorflow:Checkpoint path gs://t5-open-qa-bucket/models/large/model.ckpt-1010900
INFO:tensorflow:Querying Tensorflow master (grpc://10.111.107.154:8470) for TPU system metadata.
INFO:tensorflow:Initializing TPU system (master: grpc://10.111.107.154:8470) to fetch topology for model parallelism. This might take a while.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 1330902181351047481)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 730113845635288855)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3841951880722935286)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 171798

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


INFO:tensorflow:enable_2d_tiling: False
INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0]
  [0 0 1]
  [1 0 0]
  [1 0 1]
  [0 1 0]
  [0 1 1]
  [1 1 0]
  [1 1 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0]]

 [[0 0 1]]

 [[0 1 0]]

 [[0 1 1]]

 [[1 0 0]]

 [[1 0 1]]

 [[1 1 0]]

 [[1 1 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('batch', 'batch'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('vocab', 'model'), ('heads', 'model'), ('d_ff', 'model')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f78bb02eac8

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('batch', 'batch'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('vocab', 'model'), ('heads', 'model'), ('d_ff', 'model')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f78bb02eac8>
INFO:tensorflow:Create pnum_tensor
INFO:tensorflow:Casting <dtype: 'int32'> to float32 for allreduce
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/k                  size 1048576      slice_size 131072       Shape[d_model=1024, heads=1024]                             
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/o                  size 1048576      slice_size 131072       Shape[heads=1024, d_model=1024

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('batch', 'batch'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('vocab', 'model'), ('heads', 'model'), ('d_ff', 'model')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f78bb02eac8>
INFO:tensorflow:Create pnum_tensor
INFO:tensorflow:Casting <dtype: 'int32'> to float32 for allreduce
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/k                  size 1048576      slice_size 131072       Shape[d_model=1024, heads=1024]                             
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/o                  size 1048576      slice_size 131072       Shape[heads=1024, d_model=1024

INFO:absl:Load dataset info from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0
INFO:absl:Reusing dataset trivia_qa (gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from gs://t5-open-qa-bucket/data/trivia_qa/unfiltered.nocontext/1.1.0


INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('batch', 'batch'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('vocab', 'model'), ('heads', 'model'), ('d_ff', 'model')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f78bb02eac8>
INFO:tensorflow:Create pnum_tensor
INFO:tensorflow:Casting <dtype: 'int32'> to float32 for allreduce
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/k                  size 1048576      slice_size 131072       Shape[d_model=1024, heads=1024]                             
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/o                  size 1048576      slice_size 131072       Shape[heads=1024, d_model=1024

In [None]:
print("OK")

OK


In [None]:
# train_covid_df[train_covid_df["question"].str.contains("start")]

In [None]:
# train_covid_df[train_covid_df["question"].str.contains("pool")] # ensure that we FORCE the QA pairs to disjoin from one another

Let's look at a few random predictions from the validation sets. Note that we measure accuracy based on an *exact match* of the predicted answer and the ground-truth answer. As a result, some of the answers are semantically correct but are counted wrong by the exact match score.

In [None]:
import random

def print_random_predictions(task_name, n=10):
  """Print n predictions from the validation split of a task."""
  # Grab the dataset for this task.
  ds = t5.data.TaskRegistry.get(task_name).get_dataset(
      split="validation",
      sequence_length={"inputs": 128, "targets": 128},
      shuffle=False)

  def _prediction_file_to_ckpt(path):
    """Extract the global step from a prediction filename."""
    return int(path.split("_")[-2])

  # Grab the paths of all logged predictions.
  prediction_files = tf.io.gfile.glob(
      os.path.join(
          MODEL_DIR,
          "validation_eval/%s_*_predictions" % task_name))
  # Get most recent prediction file by sorting by their step.
  latest_prediction_file = sorted(
      prediction_files, key=_prediction_file_to_ckpt)[-1]

  # Collect (inputs, targets, prediction) from the dataset and predictions file
  results = []
  with tf.io.gfile.GFile(latest_prediction_file) as preds:
    for ex, pred in zip(tfds.as_numpy(ds), preds):
      results.append((tf.compat.as_text(ex["inputs_plaintext"]),
                      tf.compat.as_text(ex["targets_plaintext"]),
                      pred.strip()))

  print("<== Random predictions for %s using checkpoint %s ==>\n" %
        (task_name, 
         _prediction_file_to_ckpt(latest_prediction_file)))

  for inp, tgt, pred in random.choices(results, k=10):
    print("Input:", inp)
    print("Target:", tgt)
    print("Prediction:", pred)
    print("Counted as Correct?", tgt == pred)
    print()

# print_random_predictions("triviaqa_context_free")
# print_random_predictions("nq_context_free")
print_random_predictions("covid_context_free")


<== Random predictions for covid_context_free using checkpoint 1025700 ==>

Input: trivia question: could the arrival of spring and the warmer weather affect the spread of covid-19?
Target: it might, but so far covid-19 has been found in many countries, whatever their climate
Prediction: the spread of covid-19 is under control
Counted as Correct? False

Input: trivia question: are university premises (such as libraries and museums) still open to the public?
Target: the university's libraries and museums are now closed
Prediction: all libraries and museums are closed to the public, and there will be no physical access to the collections for researchers or students. however, a wide range
Counted as Correct? False

Input: trivia question: can i offer my test for home use and/or self-collection under the policy for diagnostic tests for coronavirus disease-2019?
Target: tests for coronavirus disease-2019 do not apply to at-home testing,
Prediction: the cnesst will permit home testing for co

# Predict on given questions

In [None]:
PARSA_QA_PATH = os.path.join(DATA_DIR, "cleaned_CORD19_test_QA_df.csv")


In [None]:
additional_COVID_QA_df = pd.read_csv(PARSA_QA_PATH) # cloud read by the pandas

In [None]:
additional_COVID_QA_df

Unnamed: 0,question,text,answer
0,How large is the sample size used in COVID-19 ...,Of all 59 cases diagnosed as COVID-19 in the t...,59 cases
1,What is the incubation period of the virus?,By pooling individual data from seven countrie...,7.44 days
2,How large is the sample size used in COVID-19 ...,"However, based on 329 cases (28.48%) with rele...",329 cases
3,What is the incubation period of the virus?,Twenty-one patients with COVID-19 with GI symp...,4 days (IQR 3–7 days)
4,What is the incubation period of the virus?,The median incubation period of both male and ...,The median incubation period of both male and ...
...,...,...,...
64,What is the RR for severe infection in COVID-1...,"In univariate analyses, factors robustly assoc...",diabetes (9 studies; 3.20 [2.26-4.53]
65,What is the OR for severe infection in COVID-1...,There were significant correlations between CO...,There were significant correlations between CO...
66,What is the OR for severe infection in COVID-1...,There were significant correlations between CO...,"[OR=2.67, 95% CI (1.91, 3.74), P<0.01]"
67,What is the OR for severe infection in COVID-1...,"Arrhythmia (OR: 22.17, 95%CI 4.47-110.04), acu...","OR: 2.69, 95%CI 1."


In [None]:
# METHOD 1
# extract only the questions
additional_COVID_QA_df["question"]
covid_predict_inputs_path = os.path.join(MODEL_DIR, "predict_inputs_COVID_QA.txt")
covid_predict_outputs_path = os.path.join(MODEL_DIR, "predict_outputs_COVID_QA.txt")
with tf.io.gfile.GFile(covid_predict_inputs_path, "w") as f:
  for idx, row in additional_COVID_QA_df.iterrows():
    q = row["question"]
    print(q)
    f.write("trivia question: %s\n" % q.lower()) #TODO: can also invoke the preprocess from earlier

model.batch_size = 32  # Min size for small model on v2-8 with parallelism 1.
model.predict(
    input_file=covid_predict_inputs_path,
    output_file=covid_predict_outputs_path,
    # Select the most probable output token at each step.
    temperature=0,
)


How large is the sample size used in COVID-19 studies?
What is the incubation period of the virus?
How large is the sample size used in COVID-19 studies?
What is the incubation period of the virus?
What is the incubation period of the virus?
What is the incubation period across different age groups?
How large is the sample size used in COVID-19 studies?
What is the incubation period of the virus?
What is the incubation period across different age groups?
What is the incubation period of the virus?
How large is the sample size used in COVID-19 studies?
What is the incubation period of the virus?
How large is the sample size used in COVID-19 studies?
What is the incubation period of the virus?
What is the incubation period of the virus?
What is the asymptomatic transmission during incubation?
What is the asymptomatic transmission during incubation?
How large is the sample size used in COVID-19 studies?
What is the incubation period of the virus?
What is the proportion of patients who wer

In [None]:
# METHOD 2
# use the unique_test_df only
# drawbacks: include the fact that the answer might not be in the span. We would need a new way of evaluating it !

In [None]:
# now, we simply get the predict outputs from the list of paths we provided

In [None]:
prediction_files = sorted(tf.io.gfile.glob(covid_predict_outputs_path + "*"))
print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
with tf.io.gfile.GFile(prediction_files[-1]) as f:
  answers = f.readlines()


Predictions using checkpoint 1025700:



# Predictions (2nd wave)

In [None]:
import pandas as pd

sim_df = pd.read_csv(os.path.join(DATA_DIR,"active_learning", "sim_df_75.csv"))
sme_df = pd.read_csv(os.path.join(DATA_DIR,"active_learning", "sme_rule_based.csv"))




In [None]:
sim_df

Unnamed: 0.1,Unnamed: 0,sim_score,title,sim
0,0,0.800693,A randomized open-label trial on the use of bu...,An antigenic variant of avian infectious bronc...
1,1,0.842684,Using Patient-Specific Induced Pluripotent Ste...,Recent developments in electron microscopy tom...
2,2,0.803470,Using Patient-Specific Induced Pluripotent Ste...,"Three non-structural transmembrane proteins, n..."
3,3,0.794910,Using Patient-Specific Induced Pluripotent Ste...,Pioneering ultrastructural studies using three...
4,4,0.790332,Using Patient-Specific Induced Pluripotent Ste...,The features of 3D culture were identified by ...
...,...,...,...,...
69360,69360,0.763258,Chinese Journal of Natural Medicines Metabolis...,Epidemiological studies using feline NoV-speci...
69361,69361,0.759899,Chinese Journal of Natural Medicines Metabolis...,"Results: In this study, using real-time PCR an..."
69362,69362,0.754613,Chinese Journal of Natural Medicines Metabolis...,A nested PCR assay was used to determine the v...
69363,69363,0.752573,Chinese Journal of Natural Medicines Metabolis...,"Therefore, the purpose of this study was to qu..."


In [None]:
sme_df["questions"].head().tolist()

['How domain growth is implemented determines the long term behaviour of a cell population through its effect on spatial correlations',
 'How to make more from exposure data? An integrated machine Author Contributions Data Accessibility Statement Running Title Machine learning and pathogen exposure risk',
 'When will the battle against novel coronavirus end in Wuhan: a SEIR modeling analysis',
 'How to differentiate COVID-19 pneumonia from heart failure with computed tomography at initial medical contact during epidemic period Running title: CT imaging for COVID-19 and heart failure',
 'How does the outbreak of 2019-nCoV spread in mainland China? A retrospective analysis of the dynamic transmission routes *']

In [None]:
covid_predict_inputs_sim_df_path = os.path.join(MODEL_DIR, "predict_inputs_sim_df_75.txt")
covid_predict_outputs_sim_path = os.path.join(MODEL_DIR, "predict_outputs_sim_df_75.txt")
with tf.io.gfile.GFile(covid_predict_inputs_sim_df_path, "w") as f:
  for idx, row in sim_df.iterrows():
    q = row["title"]
    print(q)
    f.write("trivia question: %s\n" % q.lower()) #TODO: can also invoke the preprocess from earlier

model.batch_size = 256  # Min size for small model on v2-8 with parallelism 1.
model.predict(
    input_file=covid_predict_inputs_sim_df_path,
    output_file=covid_predict_outputs_sim_path,
    # Select the most probable output token at each step.
    temperature=0,
)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 [[1 0 0]]

 [[1 0 1]]

 [[1 1 0]]

 [[1 1 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('heads', 'model'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('vocab', 'model'), ('d_ff', 'model'), ('batch', 'batch')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7fd388a02eb8>
INFO:tensorflow:Create pnum_tensor
INFO:tensorflow:Casting <dtype: 'int32'> to float32 for allreduce
INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/k                  size 1048576      slice_size 131072       Shape[d_model=1024, heads=1024]                             
INFO:tensorflow:Variable decoder/block_000/la

In [None]:
covid_predict_inputs_sme_path = os.path.join(MODEL_DIR, "predict_inputs_sme_rule_based.txt")
covid_predict_outputs_sme_path = os.path.join(MODEL_DIR, "predict_outputs_sme_rule_based.txt")
with tf.io.gfile.GFile(covid_predict_inputs_sme_path, "w") as f:
  for idx, row in sme_df.iterrows():
    q = row["questions"]
    print(q)
    f.write("trivia question: %s\n" % q.lower()) #TODO: can also invoke the preprocess from earlier

model.batch_size = 256  # Min size for small model on v2-8 with parallelism 1.
model.predict(
    input_file=covid_predict_inputs_sme_path,
    output_file=covid_predict_outputs_sme_path,
    # Select the most probable output token at each step.
    temperature=0,
)

How domain growth is implemented determines the long term behaviour of a cell population through its effect on spatial correlations
How to make more from exposure data? An integrated machine Author Contributions Data Accessibility Statement Running Title Machine learning and pathogen exposure risk
When will the battle against novel coronavirus end in Wuhan: a SEIR modeling analysis
How to differentiate COVID-19 pneumonia from heart failure with computed tomography at initial medical contact during epidemic period Running title: CT imaging for COVID-19 and heart failure
How does the outbreak of 2019-nCoV spread in mainland China? A retrospective analysis of the dynamic transmission routes *
Whole genome methylation array analysis reveals new aspects in Balkan endemic nephropathy etiology
When Viruses Don't Go Viral: The Importance of Host Phylogeographic Structure in the Spatial Spread of Arenaviruses
How Change of Public Transportation Usage Reveals Fear of the SARS Virus in a City
Who

In [None]:
print("OK")

OK


In [None]:
# tying together predictions from both

In [None]:
covid_predict_outputs_sme_path + "-1025700"

'gs://t5-open-qa-bucket/models/large/predict_outputs_sme_rule_based.txt-1025700'

In [None]:
# we need a google cloud read!

In [None]:
with tf.io.gfile.GFile((covid_predict_outputs_sme_path + "-1025700")) as answers, tf.io.gfile.GFile((covid_predict_inputs_sme_path )) as questions:
  list_answers = answers.readlines()
  list_questions = questions.readlines()
  sme_df = pd.DataFrame(data={"questions":list_questions , "answers": list_answers} )
  # print(len(), len(questions.readlines()))

    
    #

In [None]:
sme_df

Unnamed: 0,questions,answers
0,trivia question: how domain growth is implemen...,how genes are transferred between different po...
1,trivia question: how to make more from exposur...,machine learning and pathogen exposure risk mi...
2,trivia question: when will the battle against ...,"january 3, 2018\n"
3,trivia question: how to differentiate covid-19...,the ct imaging method is a technique that can ...
4,trivia question: how does the outbreak of 2019...,from respiratory droplets that become airborne...
...,...,...
605,trivia question: the ethics of improving afric...,african traditional research methods\n
606,trivia question: are formyl peptide receptors ...,formyl peptide receptors (fprs) are a class of...
607,trivia question: minireview new agents modulat...,there is a growing body of evidence that sugge...
608,trivia question: pentraxins and collectins: fr...,a key player in the pathogen-host relationship\n


In [None]:
sme_out_path = os.path.join(DATA_DIR, "sme.csv")
sim_out_path = os.path.join(DATA_DIR, "sim.csv")

In [None]:
sme_df.to_csv(sme_out_path)

In [None]:
with tf.io.gfile.GFile((covid_predict_outputs_sim_path + "-1025700")) as answers, tf.io.gfile.GFile((covid_predict_inputs_sim_df_path )) as questions:
  list_answers = answers.readlines()
  list_questions = questions.readlines()
  sim_df = pd.DataFrame(data={"questions":list_questions , "answers": list_answers} )

In [None]:
sim_df

Unnamed: 0,questions,answers
0,trivia question: a randomized open-label trial...,"a randomized, placebo-controlled trial\n"
1,trivia question: using patient-specific induce...,we have used iPS cells to generate a knockout ...
2,trivia question: using patient-specific induce...,we have used iPS cells to generate a knockout ...
3,trivia question: using patient-specific induce...,we have used iPS cells to generate a knockout ...
4,trivia question: using patient-specific induce...,we have used iPS cells to generate a knockout ...
...,...,...
69360,trivia question: chinese journal of natural me...,chinese journal of natural medicines\n
69361,trivia question: chinese journal of natural me...,chinese journal of natural medicines\n
69362,trivia question: chinese journal of natural me...,chinese journal of natural medicines\n
69363,trivia question: chinese journal of natural me...,chinese journal of natural medicines\n


In [None]:
sim_df.to_csv(sim_out_path)

In [None]:
additional_COVID_QA_df["t5_answers"] = answers

In [None]:
additional_COVID_QA_df

Unnamed: 0,question,text,answer,t5_answers
0,How large is the sample size used in COVID-19 ...,Of all 59 cases diagnosed as COVID-19 in the t...,59 cases,the study sample must be at least 10 times lar...
1,What is the incubation period of the virus?,By pooling individual data from seven countrie...,7.44 days,the incubation period for most viruses range f...
2,How large is the sample size used in COVID-19 ...,"However, based on 329 cases (28.48%) with rele...",329 cases,the study must have at least one adult partici...
3,What is the incubation period of the virus?,Twenty-one patients with COVID-19 with GI symp...,4 days (IQR 3–7 days),the incubation period for most viruses range f...
4,What is the incubation period of the virus?,The median incubation period of both male and ...,The median incubation period of both male and ...,the incubation period for most viruses range f...
...,...,...,...,...
64,What is the RR for severe infection in COVID-1...,"In univariate analyses, factors robustly assoc...",diabetes (9 studies; 3.20 [2.26-4.53],rr > 0\n
65,What is the OR for severe infection in COVID-1...,There were significant correlations between CO...,There were significant correlations between CO...,etr\n
66,What is the OR for severe infection in COVID-1...,There were significant correlations between CO...,"[OR=2.67, 95% CI (1.91, 3.74), P<0.01]",a type 2 diabetes mellitus\n
67,What is the OR for severe infection in COVID-1...,"Arrhythmia (OR: 22.17, 95%CI 4.47-110.04), acu...","OR: 2.69, 95%CI 1.",etr\n


In [None]:
T5_QA_PATH = os.path.join(DATA_DIR, "T5_CORD19_cleaned_test_QA_results_df.csv")


In [None]:
additional_COVID_QA_df.to_csv(T5_QA_PATH)

In [None]:
# set the addtional_COVID_QA_df to be the test set only
# additional_COVID_QA_df = unique_test_df

In [None]:
from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def get_f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1



exact_match_score = 0
f1_score = 0
total = len(additional_COVID_QA_df.index)

for index, row in additional_COVID_QA_df.iterrows():
    answer = normalize_answer(str(row['answer']))
    predicted_answer = normalize_answer(str(row['t5_answers']))
    
    exact_match_score += get_exact_match_score(predicted_answer, answer)
    f1_score += get_f1_score(predicted_answer, answer)

f1_score = 100.0 * f1_score / total
exact_match_score = 100.0 * exact_match_score / total
print('F1 Score: ' + str(f1_score))
print('Exact Match Score: ' + str(exact_match_score))

F1 Score: 5.953619873107497
Exact Match Score: 0.0


## Predict

Now that we have fine-tuned the model, we can feed T5 arbitrary questions and have it predict the answers!

There is a significant amount of overhead in initializing the model so this may take a few minutes to run each time even though the prediction itself is quite fast.


To avoid this overhead, you might consider exporting a `SavedModel` and running it on [Cloud ML Engine](https://cloud.google.com/ml-engine/).



In [None]:
question_1 = "What is known about COVID-19 therapeutics?" #@param {type:"string"}
question_2 = "What is the most populous country in the world?" #@param {type:"string"}
question_3 = "Who are the 4 members of The Beatles?" #@param {type:"string"}
question_4 = "How many teeth do humans have?" #@param {type:"string"}

questions = [question_1, question_2, question_3, question_4]

now = time.time()
# Write out the supplied questions to text files.
predict_inputs_path = os.path.join(MODEL_DIR, "predict_inputs_%d.txt" % now)
predict_outputs_path = os.path.join(MODEL_DIR, "predict_outputs_%d.txt" % now)
# Manually apply preprocessing by prepending "triviaqa question:".
with tf.io.gfile.GFile(predict_inputs_path, "w") as f:
  for q in questions:
    f.write("trivia question: %s\n" % q.lower())

# Ignore any logging so that we only see the model's answers to the questions.
# with tf_verbosity_level('ERROR'):
model.batch_size = 8  # Min size for small model on v2-8 with parallelism 1.
model.predict(
    input_file=predict_inputs_path,
    output_file=predict_outputs_path,
    # Select the most probable output token at each step.
    temperature=0,
)

# The output filename will have the checkpoint appended so we glob to get 
# the latest.
prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + "*"))
print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
with tf.io.gfile.GFile(prediction_files[-1]) as f:
  for q, a in zip(questions, f):
    if q:
      print("Q: " + q)
      print("A: " + a)
      print()

INFO:tensorflow:Using config: {'_model_dir': 'gs://t5-open-qa-bucket/models/large', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.111.107.154:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.111.107.154:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.111.107.154:8470', '_evaluation_master': 'grpc://10.111.107.154:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker

# Export Model for Serving

As mentioned in the previous section, exporting a [`SavedModel`](https://www.tensorflow.org/guide/saved_model) can be useful for improving performance during inference or allowing your model to be deployed on a variety of platforms (e.g., TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub).

**Note:** we currently only support exporting a SavedModel that runs on both CPU and GPU, not TPU.

## Export SavedModel

We first export the SavedModel. We set a batch size of 1 for simplicity, but it may be more efficient to use a larger batch size if you want to handle multiple requests per call.

For 3B and 11B models the export will take approximately 30-45 minutes.

In [None]:
export_dir = os.path.join(MODEL_DIR, "export")

model.batch_size = 1 # make one prediction per call
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)
print("Model saved to:", saved_model_path)

INFO:tensorflow:Using config: {'_model_dir': 'gs://t5-open-qa-bucket/models/large', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.111.107.154:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.111.107.154:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.111.107.154:8470', '_evaluation_master': 'grpc://10.111.107.154:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker

## Load SavedModel

One way to test our model is to load it either in eager mode or a TF 1.x session so that we can repeatedly predict from the model without the overhead of loading the graph and weights each time.

We pay the overhead once here, but it shouldn't take more than a few minutes.


### Optional: Switch to GPU Runtime

Changing the runtime type to GPU in the `Runtime` menu above before loading the SavedModel will speed up inference by using the GPU instead of CPU.



In [None]:
#@title Optional: Run this cell to re-initialize if you switched to GPU runtime.
%tensorflow_version 2.x
!pip install tensorflow-text
from google.colab import auth
auth.authenticate_user()



In [None]:
import tensorflow as tf
import tensorflow_text  # Required to run exported model.

def load_predict_fn(model_path):
  if tf.executing_eagerly():
    print("Loading SavedModel in eager mode.")
    imported = tf.saved_model.load(model_path, ["serve"])
    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
  else:
    print("Loading SavedModel in tf 1.x graph mode.")
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    meta_graph_def = tf.compat.v1.saved_model.load(sess, ["serve"], model_path)
    signature_def = meta_graph_def.signature_def["serving_default"]
    return lambda x: sess.run(
        fetches=signature_def.outputs["outputs"].name, 
        feed_dict={signature_def.inputs["input"].name: x}
    )

predict_fn = load_predict_fn(saved_model_path)

Loading SavedModel in tf 1.x graph mode.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Restoring parameters from gs://t5-open-qa-bucket/models/large/export/1592978440/variables/variables


KeyboardInterrupt: ignored

## Predict

We can now call the predict method with different inputs each time and relatively quickly get results.

In [None]:
def answer(question):
  return predict_fn([question])[0].decode('utf-8')

for question in ["trivia question: where is the google headquarters?",
                 "trivia question: what is the most populous country in the world?",
                 "trivia question: who are the 4 members of the beatles?",
                 "trivia question: how many teeth do humans have?"]:
    print(answer(question))

## Deploy SavedModel

You can now deploy your SavedModel for serving (e.g., with [TensorFlow Serving](https://www.tensorflow.org/tfx/tutorials/serving/rest_simple)).