In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors as prep
import functools
import t5
import gin
import sentencepiece as spm
from glob import glob
import os

# SentencePiece model: passed to the T5 task registration as its vocabulary
# and loaded into `sp` for local use.
vocab = 'sp10m.cased.t5.model'

# Operative gin config of the pretrained base model.
gin.parse_config_file('pretrained_models_base_operative_config.gin')

# Loading the model is kept as the cell's last expression, so its return value
# is what the cell displays.
sp = spm.SentencePieceProcessor()
sp.Load(vocab)

True

In [2]:
# Source corpora to convert. Order is significant: later cells index into this
# list by position (e.g. files[-2]).
files = [
    f'../../pure-text/{name}.txt'
    for name in (
        'dumping-parliament',
        'filtered-dumping-wiki',
        'filtered-dumping-cleaned-common-crawl',
        'filtered-dumping-academia',
    )
] + ['../../news/cleaned-news.txt']

In [3]:
import re

def cleaning(string):
    """Collapse ``string`` onto a single line.

    Newlines and tabs are turned into spaces, every run of spaces is squeezed
    to a single space, and surrounding whitespace is stripped. With no tabs or
    newlines left, the result can be written as one TSV field.
    """
    one_line = string.translate({ord('\n'): ' ', ord('\t'): ' '})
    squeezed = re.sub(r'[ ]+', ' ', one_line)
    return squeezed.strip()

In [9]:
def read_paragraphs(path, encoding='utf-8'):
    """Read ``path`` and return its blank-line-separated paragraphs.

    Consecutive non-blank lines are joined with a single space into one
    paragraph. A line counts as blank when it is empty or whitespace-only, so
    whitespace-only lines separate paragraphs instead of being merged into
    them. Otherwise they can yield paragraphs that clean down to ''.

    Args:
        path: text file to read.
        encoding: file encoding. Pinned explicitly instead of relying on the
            platform locale, so reads agree with the UTF-8 TSVs written later.

    Returns:
        List of paragraph strings, in file order.
    """
    paragraphs, current = [], []
    with open(path, encoding=encoding) as fopen:
        # Iterating the file streams it line by line; universal-newline mode
        # has already normalised line endings to '\n'.
        for line in fopen:
            line = line.rstrip('\n')
            if line.strip():
                current.append(line)
            elif current:
                paragraphs.append(' '.join(current))
                current = []
    if current:
        paragraphs.append(' '.join(current))
    return paragraphs


# Preview the paragraph split on a single corpus.
results = read_paragraphs(files[-2])

In [13]:
def _iter_paragraphs(fileobj):
    """Yield blank-line-separated paragraphs from an open text file.

    Consecutive non-blank lines are joined with a single space. Empty and
    whitespace-only lines both end the current paragraph. As a result, no
    paragraph consists solely of whitespace, which would clean to '' and
    produce an empty training row.
    """
    lines = []
    for line in fileobj:
        line = line.rstrip('\n')
        if line.strip():
            lines.append(line)
        elif lines:
            yield ' '.join(lines)
            lines = []
    if lines:
        yield ' '.join(lines)


# Convert every corpus into '<basename>-pair.tsv' in the working directory.
# Each row is one cleaned paragraph followed by an empty second column, which
# matches the two-column layout parsed by pair_dataset. Input is streamed
# rather than read whole, because the largest corpus holds millions of
# paragraphs.
for path in files:
    out_path = f'{os.path.split(path)[1]}-pair.tsv'
    n_paragraphs = 0
    # Explicit UTF-8 so the read side matches the UTF-8 text GFile writes.
    with open(path, encoding='utf-8') as fin:
        with tf.io.gfile.GFile(out_path, 'w') as fout:
            for paragraph in _iter_paragraphs(fin):
                fout.write('%s\t\n' % cleaning(paragraph))
                n_paragraphs += 1
    print(path, n_paragraphs)

../../pure-text/dumping-parliament.txt 69702
../../pure-text/filtered-dumping-wiki.txt 363578
../../pure-text/filtered-dumping-cleaned-common-crawl.txt 8915464
../../pure-text/filtered-dumping-academia.txt 963765
../../news/cleaned-news.txt 173012


In [14]:
def pair_dataset(split, shuffle_files = False):
    """Build the raw text dataset for the ``pair_dataset`` T5 task.

    Reads every ``*pair.tsv`` file in the current working directory and yields
    one ``{'text': <scalar tf.string>}`` example per TSV row, taken from the
    row's first column.

    Args:
        split: requested split name. Ignored; every split reads the same files.
        shuffle_files: accepted to match the T5 ``dataset_fn`` signature;
            ignored.

    Returns:
        A ``tf.data.Dataset`` of ``{'text': scalar tf.string}`` dicts.

    Raises:
        FileNotFoundError: if no ``*pair.tsv`` file exists in the working
            directory. This replaces silently building a dataset with no
            input files.
    """
    del split, shuffle_files  # the same files back every split
    # glob() returns matches in filesystem-dependent order; sort them so the
    # example order is reproducible across machines and runs.
    filenames = sorted(glob('*pair.tsv'))
    if not filenames:
        raise FileNotFoundError(
            "no '*pair.tsv' files found in %s" % os.getcwd()
        )
    ds = tf.data.TextLineDataset(filenames)

    ds = ds.map(
        functools.partial(
            tf.io.decode_csv,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        ),
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    # Only the first column carries text; zip() drops the trailing (empty)
    # second column.
    ds = ds.map(lambda *ex: dict(zip(['text'], ex)))
    return ds


# Drop any earlier registration first so this cell can be re-run in the same
# kernel (presumably TaskRegistry.add rejects a duplicate name; confirm).
t5.data.TaskRegistry.remove('pair_dataset')
t5.data.TaskRegistry.add(
    'pair_dataset',
    splits = ['train'],
    dataset_fn = pair_dataset,
    sentencepiece_model_path = vocab,
    # Turns each {'text': ...} example into an 'nsp:'-prefixed input whose
    # target is a next-sentence label (e.g. 'next').
    text_preprocessor = [prep.next_sentence_prediction],
    metric_fns = [t5.evaluation.metrics.accuracy],
    # NOTE(review): this call logged "get_sentencepiece_model_path is
    # deprecated"; newer t5 releases presumably expect the vocabulary to be
    # passed another way. Confirm against the installed t5 version.
)

  "get_sentencepiece_model_path is deprecated. Please pass the mixture or "


In [15]:
# Fetch the registered NSP task and build its preprocessed, tokenized dataset.
# The split must be one the task declares (splits=['train']). The old value
# 'qa.tsv' looked like a copy-paste leftover. pair_dataset ignores the split,
# so the examples are presumably unchanged.
# NOTE(review): the name `nq_task` looks inherited from another notebook; it
# is kept in case other cells refer to it.
nq_task = t5.data.TaskRegistry.get("pair_dataset")
ds = nq_task.get_dataset(split='train', sequence_length={"inputs": 1024, "targets": 1024})

      lambda x: x[text_key], num_parallel_calls=tf.data.experimental.AUTOTUNE)

If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.
      lambda x: x[text_key], num_parallel_calls=tf.data.experimental.AUTOTUNE)

If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.


In [16]:
# iter() keeps next(r) working across tfds versions. Older as_numpy returned a
# generator, while newer releases may return an iterable that is not itself an
# iterator. Calling iter() on a generator returns the generator unchanged.
r = iter(tfds.as_numpy(ds))

In [19]:
# Pull one preprocessed example. 'inputs' holds the 'nsp:'-prefixed token ids
# and 'targets' the label ids (e.g. b'next'); both also appear in plaintext.
next(r)

{'inputs_plaintext': b'nsp: Dr Mahathir, yang juga Pengerusi Pakatan Harapan berkata, kajian semula terhadap M63 itu pasti akan menimbulkan semula keperluan penilaian semula terhadap beberapa amalan yang telah terlaksana sebelum ini. Ini merupakan kali pertama Perdana Menteri ke negeri ini selepas beliau dipilih semula mengetuai kerajaan persekutuan pada Pilihan Raya Umum yang lepas.',
 'inputs': array([ 2532,  7175,    50,   122,   386,    14,    17,    42,   758,
          919,   831,    60,    14,   577,   196,   109,   190,  2963,
           24,   808,    38,  1834,   196,   646,  2907,   196,   109,
          103,  1436,    17,    33, 21075,   137,    20,     3,   297,
           34,   252,   131,   203,    82,    30,    86,    20,    98,
           71,  1345,   196,  4479,    72,  2835,    23,   866,   489,
         1339,    17,  1155,     3,     1]),
 'targets_plaintext': b'next',
 'targets': array([7426,    1])}