In [1]:
# !pip3 install promise tensorflow-datasets dill future tensorflow-metadata
# !pip3 install tensorflow-datasets tf-sentencepiece sentencepiece tensorflow-text==1.15 tfds-nightly --no-deps
# !pip3 install t5[gcp]

In [2]:
# !pip3 install tensorflow==1.15 tensorflow-text==1.15

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors as prep
import functools
import t5
t5.__version__

'0.5.0'

In [4]:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver('node-1', 'us-central1-a', 'mesolitica-cloud')
TPU_ADDRESS = tpu.get_master()
TPU_TOPOLOGY = '2x2'
TPU_ADDRESS

'grpc://10.76.157.10:8470'

In [5]:
vocab = 'gs://mesolitica-general/t5-vocab/sp10m.cased.t5.model'

In [6]:
def dumping_dataset(split, shuffle_files=False):
    del shuffle_files
    ds = tf.data.TextLineDataset(['gs://mesolitica-general/t5-data/dumping-iium.tsv'])

    ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(lambda *ex: dict(zip(["title", "text"], ex)))
    return ds

t5.data.TaskRegistry.remove('dumping_txt')
t5.data.TaskRegistry.add(
    "dumping_txt",
    dataset_fn=dumping_dataset,
    splits=["train"],
    text_preprocessor=functools.partial(
        t5.data.preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
    token_preprocessor=t5.data.preprocessors.unsupervised,
    sentencepiece_model_path=vocab,
    metric_fns=[])

In [7]:
# nq_task = t5.data.TaskRegistry.get("dumping_txt")
# ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 1024, "targets": 32})
# for ex in tfds.as_numpy(ds.take(5)):
#     print(ex)

In [8]:
t5.data.MixtureRegistry.remove("trivia_all")
t5.data.MixtureRegistry.add(
    "trivia_all",
    ['dumping_txt'],
     default_rate=1.0
)

In [9]:
import gin

gin.parse_config_file('pretrained_models_base_operative_config.gin')

In [10]:
MODEL_SIZE = 'base'
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}['base']

In [11]:
model = t5.models.MtfModel(
    model_dir='gs://mesolitica-general/t5-base/',
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,    
    batch_size=train_batch_size,
    sequence_length={"inputs": 1024, "targets": 1024},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    iterations_per_loop=100,
)

In [12]:
model.train(mixture_or_task_name='trivia_all', steps=100)

INFO:tensorflow:Using config: {'_model_dir': 'gs://mesolitica-general/t5-base/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.76.157.10:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fb67f2ac080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.76.157.10:8470', '_evaluation_master': 'grpc://10.76.157.10:8470', '_is_chief': True, '_num_ps_replica

Exception in thread Thread-5:
Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.6/site-packages/tensorflow_core/python/distribute/cluster_resolver/tpu_cluster_resolver.py", line 476, in _fetch_cloud_tpu_metadata
    return request.execute()
  File "/home/ubuntu/.local/lib/python3.6/site-packages/googleapiclient/_helpers.py", line 134, in positional_wrapper
    return wrapped(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/googleapiclient/http.py", line 898, in execute
    raise HttpError(resp, content, uri=self.uri)
googleapiclient.errors.HttpError: <HttpError 403 when requesting https://tpu.googleapis.com/v1/projects/None/locations/None/nodes/10.76.157.10:8470?alt=json returned "Permission denied on resource project None.". Details: "[{'@type': 'type.googleapis.com/google.rpc.Help', 'links': [{'description': 'Google developer console API key', 'url': 'https://console.developers.google.com/project/None/apiui/credential'}]}]">

During ha

INFO:tensorflow:Done with copy master to slices.
INFO:tensorflow:Saving checkpoints for 0 into gs://mesolitica-general/t5-base/model.ckpt.
INFO:tensorflow:Before Save.
INFO:tensorflow:About to write a checkpoint
INFO:tensorflow:gs://mesolitica-general/t5-base/model.ckpt-0 is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Done writing checkpoint.
INFO:tensorflow:Enqueue next (100) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (100) batch(es) of data from outfeed.
INFO:tensorflow:Outfeed finished for iteration (0, 0)
INFO:tensorflow:Outfeed finished for iteration (0, 46)
INFO:tensorflow:Outfeed finished for iteration (0, 92)
INFO:tensorflow:loss = 0.6796875, step = 100
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop 