In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors as prep
import functools
import t5
import gin
import sentencepiece as spm
from glob import glob
import os

# Apply the gin bindings from the operative config file before any t5 code runs.
gin.parse_config_file('pretrained_models_base_operative_config.gin')
# Path to the SentencePiece model file; loaded into `sp` right below.
vocab = 'sp10m.cased.t5.model'
sp = spm.SentencePieceProcessor()
# Last expression of the cell, so the notebook displays Load's return value.
sp.Load(vocab)

True

In [6]:
# Plain-text corpus files, given as paths relative to the current working directory.
files = [
    '../../pure-text/dumping-parliament.txt',
    '../../pure-text/filtered-dumping-wiki.txt',
    '../../pure-text/filtered-dumping-cleaned-common-crawl.txt',
    '../../pure-text/filtered-dumping-academia.txt',
    '../../news/cleaned-news.txt'
]

In [7]:
# Sanity check: show the bare file name (directory part dropped) of the first corpus.
os.path.basename(files[0])

'dumping-parliament.txt'

In [8]:
import re

def cleaning(string):
    """Flatten a piece of text onto one line with single spaces.

    Newlines and tabs are turned into spaces, runs of spaces are collapsed
    to a single space, and leading/trailing whitespace is stripped.

    Args:
        string: the text to normalise.

    Returns:
        The normalised text (possibly the empty string).
    """
    # Step 1: newline and tab each become one space.
    flattened = string.translate({ord('\n'): ' ', ord('\t'): ' '})
    # Step 2: squeeze every run of spaces down to a single space.
    squeezed = re.sub(r'[ ]+', ' ', flattened)
    # Step 3: trim whitespace at both ends.
    return squeezed.strip()

In [9]:
# Convert every corpus in `files` into a two-column TSV (empty first column,
# cleaned text second), written to the current working directory as
# '<source file name>.tsv'.
for file in files:
    filename = f'{os.path.split(file)[1]}.tsv'
    written = 0
    # Stream the source line by line instead of reading it whole and building
    # a list: the largest corpus has tens of millions of lines, so holding it
    # all in memory at once is unnecessary.
    with open(file) as fopen, tf.io.gfile.GFile(filename, 'w') as outfile:
        for line in fopen:
            text = cleaning(line)
            # Blank and whitespace-only lines clean down to '' -- skip them so
            # the TSV never contains a row whose text column is empty.
            if not text:
                continue
            # cleaning() has already removed tabs/newlines from `text`, so the
            # row is guaranteed to have exactly two tab-separated fields.
            outfile.write('%s\t%s\n' % ('', text))
            written += 1
    # Report how many rows were actually written for this corpus.
    print(file, written)

../../pure-text/dumping-parliament.txt 277157
../../pure-text/filtered-dumping-wiki.txt 2037602
../../pure-text/filtered-dumping-cleaned-common-crawl.txt 41666319
../../pure-text/filtered-dumping-academia.txt 4086649
../../news/cleaned-news.txt 3483907


In [10]:
def dumping_dataset(split, shuffle_files = False):
    """Build a tf.data pipeline over the exported TSV files.

    Each line of the two hard-coded TSV files is parsed as two tab-separated
    string fields and returned as a dict with keys 'title' and 'text'.

    Args:
        split: accepted for interface compatibility; not used here.
        shuffle_files: accepted for interface compatibility; discarded.

    Returns:
        A tf.data.Dataset of {'title': ..., 'text': ...} examples.
    """
    del shuffle_files

    tsv_paths = ['dumping-parliament.txt.tsv', 'filtered-dumping-wiki.txt.tsv']

    # Row parser: two string columns, tab-delimited, quotes treated literally.
    parse_row = functools.partial(
        tf.io.decode_csv,
        record_defaults = ['', ''],
        field_delim = '\t',
        use_quote_delim = False,
    )

    def to_example(*fields):
        # Pair the parsed columns with their feature names.
        return dict(zip(['title', 'text'], fields))

    lines = tf.data.TextLineDataset(tsv_paths)
    rows = lines.map(
        parse_row,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return rows.map(to_example)

In [11]:
# Remove any registration under this name first -- presumably so the cell can
# be re-run in the same kernel without clashing with the earlier entry.
t5.data.TaskRegistry.remove('dumping_dataset')

# Register the task: examples come from dumping_dataset, the 'text' field is
# re-keyed to 'targets' ('inputs' has no source key), and the `unsupervised`
# token preprocessor is applied on top.
t5.data.TaskRegistry.add(
    'dumping_dataset',
    splits = ['train'],
    dataset_fn = dumping_dataset,
    sentencepiece_model_path = vocab,
    text_preprocessor = functools.partial(
        prep.rekey,
        key_map = {'inputs': None, 'targets': 'text'},
    ),
    token_preprocessor = prep.unsupervised,
    metric_fns = [],
)

  "get_sentencepiece_model_path is deprecated. Please pass the mixture or "


In [12]:
# Look up the registered task and build its preprocessed dataset.
nq_task = t5.data.TaskRegistry.get("dumping_dataset")
# Request the 'train' split: it is the only split the task declares. The
# previous value ('qa.tsv') only worked because the dataset function ignores
# `split`; library versions that validate the split name would reject it.
ds = nq_task.get_dataset(split='train', sequence_length={"inputs": 1024, "targets": 1024})
# Wrap the dataset so that examples come back as NumPy arrays.
r = tfds.as_numpy(ds)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 


In [13]:
# Pull one example from the iterator so the notebook displays it for inspection.
next(r)

{'inputs': array([   38,   453,    15,     3,  1500,    80,  1439,    15,    12,
         2868, 11617, 19992,  1040,  3462,  8672,   402, 11769,  1403,
        11466,   223,   124,    18,     7,   738,    27,  1576,    72,
           33,  1638,   950,    43,   717,  1914,    37,    48,   303,
           30,    40,    92,   867,  1382, 15929,    16,  8721,    13,
         7360, 32099,    15,  9190,  6121,    15, 32098,  1937,   359,
          304,    15,     3, 32097,    25,    13,    16,  8721,    13,
         7360,    15,     4,  1483,    15,  9190,  6121, 32096,    35,
           43,  2296,   635,    21,  1907,  7442,    75,  4178,    46,
          116,   624,  1522,  2185,    22,  1112,  6569,    40,    16,
        32095,   229,   651,   231,  1973,    19,   274,   845,   537,
           15,     4, 32094,     5,   524,    73, 11429,    15,     4,
         8721,    13,  7360,   524,  1483,    15,  9190,  6121,    15,
            5,   651,  1075,  2483,  2696,    15,     3,    15, 156