In [None]:
# Colab-only preamble: select the TensorFlow 2.x runtime and install `t5`.
# NOTE(review): `t5` is installed without a version pin, so the notebook's
# behavior depends on whichever release pip resolves at run time.
print("Installing dependencies...")
%tensorflow_version 2.x
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5
import t5.models
import seqio

# Root GCS location for this experiment; data and checkpoints live beneath it.
BASE_DIR = "gs://pythia_t5" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
# When True, the cell connects to a Colab TPU and configures GCS credentials.
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "v2-8"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise RuntimeError rather than BaseException so that ordinary
    # `except Exception` handlers still see the failure, and chain the
    # resolver's error so the root cause stays in the traceback.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  auth.authenticate_user()
  tf.enable_eager_execution()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# Runs regardless of ON_CLOUD; the rest of the notebook uses the TF1-style API.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Keep records from the TF logger from also propagating to ancestor loggers.
  tf.get_logger().propagate = False
  # Set the root Python logger's threshold to INFO.
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TensorFlow logging verbosity inside a `with` block.

  Args:
    level: Verbosity handed to `tf.logging.set_verbosity` while the block runs.

  Yields:
    None. The verbosity that was active on entry is restored on exit, including
    when the body of the `with` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without try/finally, an exception in the body would leave the temporary
    # verbosity in effect for the rest of the session.
    tf.logging.set_verbosity(og_level)

In [None]:
# Suffix identifying the experiment variant whose checkpoints are used.
# Other variants that have been used with this notebook:
#   "_task_3_5_cr", "_task_3_10_cr", "_task_3_10_rr",
#   "_task_3_5_rr_gen_labels_large", "_task_3_5_rr_gen_labels_small",
#   "_task_3_5_cr_gen_labels_large", "_task_3_5_cr_gen_labels_small",
#   "_task_3_10_rr_labels_small"
MODEL_VARIANT_SUFFIX = "_task_3_5_rr"
# Rebuild the path from BASE_DIR instead of using `MODELS_DIR += ...`, so that
# re-running this cell does not append the suffix a second time.
MODELS_DIR = os.path.join(BASE_DIR, "models") + MODEL_VARIANT_SUFFIX
print("DATA_DIR:", DATA_DIR)
print("MODELS_DIR:", MODELS_DIR)

In [None]:
import json
print(DATA_DIR)
#counts_path = os.path.join(DATA_DIR, "counts.json")
# Maps each split name to the TSV file it is read from. Exactly one "train"
# and one "validation" entry should be active; the commented lines are the
# alternative datasets for the other experiment variants.
tsv_path = {
    #"train": os.path.join(DATA_DIR, "train-task3-5-cr-v1.2.tsv"),
    "train": os.path.join(DATA_DIR, "train-task3-5-rr-v1.2.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-5-rr-v1.2_Gen_Labels_Large.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-5-rr-v1.2_Gen_Labels_Small.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-5-cr-v1.2_Gen_Labels_Small.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-5-cr-v1.2_Gen_Labels_Large.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-10-rr-v1.2_Gen_Labels_Small.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-10-cr-v1.2.tsv"),
    #"train": os.path.join(DATA_DIR, "train-task3-10-rr-v1.2.tsv"),
    "validation": os.path.join(DATA_DIR, "test-task3-manual-5-rr.tsv"),
    #"validation": os.path.join(DATA_DIR, "test-task3-manual-10-cr.tsv"),
    #"validation": os.path.join(DATA_DIR, "test-task3-manual-5-cr.tsv"),
    #"validation": os.path.join(DATA_DIR, "test-task3-manual-10-rr.tsv"),
    #"validation": os.path.join(DATA_DIR, "test-task3-5-rr-strict.tsv"), #"test-task1-manual.tsv"),
}


In [None]:
def to_dataset_ts(split, shuffle_files=False):
  """Build the raw example dataset for `split` from its TSV file.

  Args:
    split: Key into the module-level `tsv_path` dict.
    shuffle_files: Ignored; every split is backed by a single file.

  Returns:
    A dataset with one {"ambiguous": ..., "label": ...} dict per input line.
  """
  # Nothing to shuffle: there is only one file per split.
  del shuffle_files

  # Parser turning one tab-separated "<ambiguous>\t<label>" line into two
  # string fields (quote characters are not treated specially).
  parse_tsv_line = functools.partial(
      tf.io.decode_csv,
      record_defaults=["", ""],
      field_delim="\t",
      use_quote_delim=False)

  def name_fields(*fields):
    # Pair the parsed columns with their feature names.
    return dict(zip(["ambiguous", "label"], fields))

  lines = tf.data.TextLineDataset(tsv_path[split])
  parsed = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return parsed.map(name_fields)

# Sanity check: print the first five parsed validation examples.
print("A few raw validation examples...")
for ex in tfds.as_numpy(to_dataset_ts("validation").take(5)):
  print(ex)

In [None]:
def text_preprocessor(ds):
  """Convert raw {"ambiguous", "label"} examples into model inputs/targets.

  Args:
    ds: Dataset of dicts with string "ambiguous" and "label" features.

  Returns:
    `ds` mapped to {"inputs": ..., "targets": ...} dicts. Both fields are
    lowercased and the task prefix is prepended to the input.
  """
  # Prefix prepended to every model input.
  # NOTE(review): the prefix says "task1" while the data files are named
  # task3 -- confirm this is intended before changing it, since the string is
  # part of every model input.
  task_prefix = "pythia_schema_task1: "

  def build_example(raw):
    # Lowercasing is the only normalization applied to either field.
    lowered_input = tf.strings.lower(raw["ambiguous"])
    lowered_target = tf.strings.lower(raw["label"])
    return {
        "inputs": tf.strings.join([task_prefix, lowered_input]),
        "targets": lowered_target,
    }

  return ds.map(build_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
#!pip install t5[gcp]

In [None]:
def extractValues(x):
  """Parse a stringified list such as "['a', 'b']" into its items.

  Square brackets are removed, the remainder is split on commas, and each
  piece is stripped of surrounding whitespace and of single quotes.

  Args:
    x: String to parse.

  Returns:
    A list of strings, or None when nothing is left once the brackets are
    removed.
  """
  unbracketed = x.replace('[', '').replace(']', '')
  if not unbracketed:
    return None
  return [item.strip().replace("'", "") for item in unbracketed.split(',')]

def countStats(targets, predictions):
  """Tally confusion-style counts over paired targets and predictions.

  A target is negative when it parses to nothing or lists "none"; a prediction
  is negative when it equals "none".

  Args:
    targets: Iterable of stringified label lists (parsed by `extractValues`).
    predictions: Iterable of predicted label strings aligned with `targets`.

  Returns:
    Tuple `(tn, fp, tp, fn, wrongLabels)`. `wrongLabels` counts positive
    targets whose non-"none" prediction matches none of the listed labels.
  """
  tally = {"tn": 0, "fp": 0, "tp": 0, "fn": 0, "wrong": 0}
  for tgt, pred in zip(targets, predictions):
    labels = extractValues(tgt)
    if labels is None:
      labels = ["none"]
    predicted_none = pred == "none"
    if "none" in labels:
      outcome = "tn" if predicted_none else "fp"
    elif predicted_none:
      outcome = "fn"
    elif pred in labels:
      outcome = "tp"
    else:
      outcome = "wrong"
    tally[outcome] += 1
  return tally["tn"], tally["fp"], tally["tp"], tally["fn"], tally["wrong"]

def countStatsLabelsStats(targets, predictions):
    """Count targets that carry labels and how many were predicted correctly.

    Args:
        targets: Iterable of stringified label lists (parsed by
            `extractValues`).
        predictions: Iterable of predicted label strings aligned with
            `targets`.

    Returns:
        {"count_pairs": n, "hits": k}, where `n` is the number of targets that
        parse to a label list and `k` is how many of those have a non-"none"
        prediction equal to one of their labels.
    """
    count_pairs = 0
    hits = 0
    for tgt, pred in zip(targets, predictions):
        labels = extractValues(tgt)
        if labels is None:
            # Targets with no labels are excluded from both counters.
            continue
        count_pairs += 1
        if pred != "none" and pred in labels:
            hits += 1
    return {"count_pairs": count_pairs, "hits": hits}

def my_accuracy_fn(targets, predictions):
  """Fraction of examples whose prediction is one of the target's labels.

  A target that parses to no labels is treated as ["none"], so predicting
  "none" for it counts as correct.

  Args:
    targets: Sequence of stringified label lists (parsed by `extractValues`).
    predictions: Sequence of predicted label strings aligned with `targets`.

  Returns:
    {"my_accuracy": value}; the value is 0.0 when `targets` is empty.
  """
  total = len(targets)
  if total == 0:
    # An empty evaluation set would otherwise raise ZeroDivisionError.
    return {"my_accuracy": 0.0}
  hits = 0
  for tgt, pred in zip(targets, predictions):
    valuesTgt = extractValues(tgt)
    if valuesTgt is None:
      valuesTgt = ["none"]
    # A membership test scores each example at most once, even when the same
    # label appears more than once in the target list.
    if pred in valuesTgt:
      hits += 1
  return {"my_accuracy": hits / total}

def my_precision_fn(targets, predictions):
  """Precision of label predictions, tp / (tp + fp), based on `countStats`.

  Args:
    targets: Targets forwarded to `countStats`.
    predictions: Predictions forwarded to `countStats`.

  Returns:
    {"my_precision": value}; the value is 0 when tp + fp is zero.
  """
  _, false_pos, true_pos, _, _ = countStats(targets, predictions)
  denominator = true_pos + false_pos
  value = true_pos / denominator if denominator != 0 else 0
  return {"my_precision": value}

def my_recall_fn(targets, predictions):
  """Recall of label predictions, tp / (tp + fn), based on `countStats`.

  Wrong-label predictions are not part of the denominator here.

  Args:
    targets: Targets forwarded to `countStats`.
    predictions: Predictions forwarded to `countStats`.

  Returns:
    {"my_recall": value}; the value is 0 when tp + fn is zero.
  """
  _, _, true_pos, false_neg, _ = countStats(targets, predictions)
  if true_pos + false_neg != 0:
    return {"my_recall": true_pos / (true_pos + false_neg)}
  return {"my_recall": 0}

def my_recall_fn_2(targets, predictions):
  """Recall that also counts wrong labels as misses: tp / (tp + fn + wrong).

  Args:
    targets: Targets forwarded to `countStats`.
    predictions: Predictions forwarded to `countStats`.

  Returns:
    {"my_recall_2": value}; the value is 0 when tp + fn is zero.
  """
  _, _, true_pos, false_neg, wrong = countStats(targets, predictions)
  # The guard looks only at tp + fn, not the full denominator; in that case
  # tp is 0, so the ratio would be 0 anyway.
  if true_pos + false_neg == 0:
    value = 0
  else:
    value = true_pos / (true_pos + false_neg + wrong)
  return {"my_recall_2": value}

def my_precision_binary(targets, predictions):
  """Precision where a wrong label still counts as a correct positive.

  Computed as (tp + wrongLabels) / (tp + wrongLabels + fp) from `countStats`.

  Args:
    targets: Targets forwarded to `countStats`.
    predictions: Predictions forwarded to `countStats`.

  Returns:
    {"my_precision_binary": value}; the value is 0 when the denominator is
    zero.
  """
  _, false_pos, true_pos, _, wrong = countStats(targets, predictions)
  positives = true_pos + wrong
  if positives + false_pos == 0:
    return {"my_precision_binary": 0}
  return {"my_precision_binary": positives / (positives + false_pos)}

def my_recall_binary(targets, predictions):
  """Recall where a wrong label still counts as a correct positive.

  Computed as (tp + wrongLabels) / (tp + fn + wrongLabels) from `countStats`.

  Args:
    targets: Targets forwarded to `countStats`.
    predictions: Predictions forwarded to `countStats`.

  Returns:
    {"my_recall_binary": value}; the value is 0 when the denominator is zero.
  """
  _, _, true_pos, false_neg, wrong = countStats(targets, predictions)
  denominator = true_pos + false_neg + wrong
  value = (true_pos + wrong) / denominator if denominator != 0 else 0
  return {"my_recall_binary": value}

In [None]:
#import t5

# Drop any earlier registration first so this cell can be re-run in the same
# kernel (e.g. after editing a metric function) without tripping over the
# already-registered task name.
t5.data.TaskRegistry.remove("ambiguity_context")
t5.data.TaskRegistry.add(
    "ambiguity_context",
    # Specify the task type.
    t5.data.Task,
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=to_dataset_ts,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[text_preprocessor],
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text,
    # Custom metrics defined in this notebook (instead of
    # t5.evaluation.metrics.accuracy).
    metric_fns=[
        my_accuracy_fn,
        my_precision_fn,
        my_recall_fn,
        my_recall_fn_2,
        my_precision_binary,
        my_recall_binary,
        countStatsLabelsStats,
    ],
    # Not required, but helps for mixing and auto-caching.
    #num_input_examples=num_nq_examples
)

In [None]:
# Sanity check: fetch the registered task and print five preprocessed training
# examples. The active sequence lengths match the ones given to the model
# constructor; the commented lines are the alternatives.
nq_task = t5.data.TaskRegistry.get("ambiguity_context")
#ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 128, "targets": 32})
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 512, "targets": 32})
#ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 1024, "targets": 32})
print("A few preprocessed train examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

In [None]:
# (Re-)register the "all_mix" mixture, which contains only the
# "ambiguity_context" task. Removing first keeps the cell re-runnable.
t5.data.MixtureRegistry.remove("all_mix")
t5.data.MixtureRegistry.add(
    "all_mix",
    ["ambiguity_context"],
     default_rate=1.0
)

In [None]:
# Which pre-trained T5 checkpoint to fine-tune.
MODEL_SIZE = "3B" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
# Checkpoints for this run are written to a per-size subdirectory of MODELS_DIR.
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

# Warn about (3B) or refuse (11B) sizes that may not fit the Colab resources.
if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warning(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
# Each value is (model_parallelism, train_batch_size, keep_checkpoint_max).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.

# NOTE(review): TPU_ADDRESS and TPU_TOPOLOGY are only assigned in the setup
# cell when ON_CLOUD is True, so this call needs ON_CLOUD to be set.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    #sequence_length={"inputs": 128, "targets": 32},
    sequence_length={"inputs": 512, "targets": 32}, ## good for 5 encoding training
    #sequence_length={"inputs": 1024, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=100,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

In [None]:
# Open TensorBoard on this run's model directory.
# NOTE(review): the extension is only (re)loaded when ON_CLOUD is True, but
# the %tensorboard magic below runs unconditionally -- confirm that is intended.
if ON_CLOUD:
  %reload_ext tensorboard
%tensorboard --logdir="$MODEL_DIR" --port=0

In [None]:
# Debug: show the checkpoint-retention limit currently stored on the model
# (reads a private attribute of MtfModel).
print(model._keep_checkpoint_max)

In [None]:
# Number of fine-tuning steps to run on top of the pre-trained checkpoint.
FINETUNE_STEPS =  3000#@param {type: "integer"}

# Override the checkpointing settings given to the MtfModel constructor by
# writing its private attributes directly.
# NOTE(review): this replaces the per-size keep_checkpoint_max chosen earlier
# (1 for "3B") with 50 -- confirm the bucket has room for that many
# checkpoints, and consider passing these values to the constructor instead.
model._save_checkpoints_steps=200
model._keep_checkpoint_max = 50

model.finetune(
    mixture_or_task_name="all_mix",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

In [None]:
# Debug: show the batch size the model is currently configured with.
print(model.batch_size)

In [None]:
# Evaluate the "ambiguity_context" task at a fixed list of checkpoint steps and
# write the evaluation summaries into MODEL_DIR.
# NOTE(review): the active steps are 300 apart, while the fine-tuning cell sets
# `_save_checkpoints_steps=200` -- confirm checkpoints exist at these steps (or
# how the library resolves steps that have no exact checkpoint).
#model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="ambiguity_context",
    #checkpoint_steps="all",
    checkpoint_steps=[1000300, 1000600, 1000900, 1001200, 1001500, 1001800, 1002100, 1002400, 1002700, 1003000],
    #checkpoint_steps=[1000500,1001000,1001500, 1002000, 1002500, 1003000, 1003500, 1004000, 1004500, 1005000, 1006000, 1007000, 1008000, 1009000, 1010000],
    #checkpoint_steps = [1001500],
    summary_dir = MODEL_DIR
)