In [1]:
import os

# Set the credentials environment variable for this process to the JSON file
# 'mesolitica-tpu.json' (a relative path, so it is resolved against the
# working directory). Any value already present is overwritten.
os.environ.update({'GOOGLE_APPLICATION_CREDENTIALS': 'mesolitica-tpu.json'})

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors as prep
import functools
import t5

In [3]:
import gin

# Load the gin bindings stored in the operative-config file on GCS. This runs
# before any task is built, so bindings it contains apply to later cells.
# NOTE(review): presumably this supplies the parameters of the `unsupervised`
# token preprocessor used below — confirm against the .gin file.
gin.parse_config_file('gs://mesolitica-tpu-general/t5-data/pretrained_models_base_operative_config.gin')

In [4]:
# GCS path of the SentencePiece model; every task registered in this notebook
# passes it as `sentencepiece_model_path`.
vocab = 'gs://mesolitica-tpu-general/t5-data-v2/sp10m.cased.ms-en.model'

In [5]:
def dumping_dataset(split, shuffle_files = False):
    """Build the raw-text corpus dataset from the tab-separated dump files.

    Four fixed dump files are read together with every file matching the
    ``00.jsonl-*.translated.txt.tsv`` pattern in the same bucket folder. Each
    line is split on tabs into two string columns.

    Args:
        split: Requested split name. Never consulted; the same files are read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` of dicts with string entries ``'title'`` and
        ``'text'``.
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; with use_quote_delim off,
        # double quotes are ordinary characters rather than field quoting.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['title', 'text'], columns))

    translated = tf.io.gfile.glob('gs://mesolitica-tpu-general/t5-data-v2/00.jsonl-*.translated.txt.tsv')
    paths = [
        'gs://mesolitica-tpu-general/t5-data-v2/dumping-news.txt.tsv',
        'gs://mesolitica-tpu-general/t5-data-v2/dumping-parliament.txt.tsv',
        'gs://mesolitica-tpu-general/t5-data-v2/filtered-dumping-academia.txt.tsv',
        'gs://mesolitica-tpu-general/t5-data-v2/filtered-dumping-wiki.txt.tsv',
    ] + translated

    lines = tf.data.TextLineDataset(paths)
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


# Re-register from scratch: drop any task already stored under this name
# (presumably so the cell can be re-run), then add the current definition.
t5.data.TaskRegistry.remove('dumping_dataset')
t5.data.TaskRegistry.add(
    'dumping_dataset',
    splits = ['train'],
    dataset_fn = dumping_dataset,
    # Only the 'text' column is carried over (as 'targets'); 'inputs' has no
    # source column. NOTE(review): t5's rekey is expected to emit an empty
    # string for a None source — confirm against the installed t5 version.
    text_preprocessor = functools.partial(
        prep.rekey,
        key_map = {'inputs': None, 'targets': 'text'},
    ),
    token_preprocessor = prep.unsupervised,
    sentencepiece_model_path = vocab,
    metric_fns = [],
)

In [6]:
# Smoke test: fetch the registered task, build its 'train' split and print the
# first five preprocessed examples.
nq_task = t5.data.TaskRegistry.get('dumping_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 512, targets = 512),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(5)):
    print(ex)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 
{'inputs': array([   14,   494,   270,     7,   883,    17,   964,    24,   272,
       10496,     3,  4500,    83,  1105,    14,   158,  2255,    41,
         307,    14,  5400, 32099, 10735,  5763,    34,   157, 10237,
        8159,    17,  9804,     3,   418,  1232,  7478,   208, 24093,
          24, 11009,   983,  9447,  3914,  2607, 22725,  7313, 12451,
        1327,  2821,   125, 22725,  6166,  1056, 32098,     3,  7884,
        3540, 32097,  7329,    42,  3099,  3105,    62,  7925,   179,
           3,   276,  1773,    73,    25,  4282, 11762, 32096,  1650,
        1434,  1451,     3,  1865,   522,    37,    14,   179,     3,
        1727, 32095,  4282, 11762,  6989,    14,    87, 11479,  4282,
       32094,  6989,    14,    87,   195,  1974,   200,   793,  2427,
     

In [7]:
def question_dataset(split, shuffle_files = False):
    """Load the question-answering TSV as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line.
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/qa.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


def question_preprocessor(ds):
    """Turn question/answer examples into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the question prefixed with
        ``'soalan: '`` and ``'targets'`` is the answer unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['soalan: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)


# Replace any task already stored as 'question_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('question_dataset')
t5.data.TaskRegistry.add(
    'question_dataset',
    splits = ['train'],
    dataset_fn = question_dataset,
    text_preprocessor = [question_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [8]:
# Smoke test: print the first five preprocessed examples of the task.
nq_task = t5.data.TaskRegistry.get('question_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 256, targets = 32),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(5)):
    print(ex)

{'inputs_plaintext': b'soalan: yang memenangi liga juara sebanyak 2020 pada tahun 2013?', 'inputs': array([2762,   31,   17, 1477, 4983, 5805,  739, 3229,   33,   53,  646,
         77,    1]), 'targets_plaintext': b'Orang India Mumbai', 'targets': array([ 1033,   688, 17801,     1])}
{'inputs_plaintext': b'soalan: di mana jaren jackson senior bermain bola keranjang kolej?', 'inputs': array([ 2762,    31,    24,   185,    13,  3410,   258,    13, 15469,
         729,  2000,   964,  1078,  7951,  4509,    77,     1]), 'targets_plaintext': b'Universiti Georgetown', 'targets': array([ 1908, 17322,     1])}
{'inputs_plaintext': b'soalan: siapa yang menyanyikan anak lelaki yang baik?', 'inputs': array([ 2762,    31,  1652,    17, 10835,   270,   257,    17,   187,
          77,     1]), 'targets_plaintext': b'Cockerel Chorus', 'targets': array([ 858, 5389,  607, 4782, 5901,    1])}
{'inputs_plaintext': b'soalan: siapa penyanyi asal saya menembak sheriff?', 'inputs': array([ 2762,    31,  16

In [9]:
def pair_dataset(split, shuffle_files = False):
    """Load every ``*pair.tsv`` file in the bucket folder as ``{'text'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same files are read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict with a single string feature
        ``'text'`` per input line.
    """
    del shuffle_files

    def split_columns(line):
        # Each line is decoded as two tab-delimited string fields; double
        # quotes are treated as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        # zip stops at the shorter sequence, so only the first decoded column
        # is kept (as 'text'); the second column is dropped here.
        return dict(zip(['text'], columns))

    paths = tf.io.gfile.glob('gs://mesolitica-tpu-general/t5-data-v2/*pair.tsv')
    lines = tf.data.TextLineDataset(paths)
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


# Replace any task already stored as 'pair_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('pair_dataset')
t5.data.TaskRegistry.add(
    'pair_dataset',
    splits = ['train'],
    dataset_fn = pair_dataset,
    text_preprocessor = [t5.data.preprocessors.next_sentence_prediction],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [10]:
# Smoke test: print the first five preprocessed examples of the task.
nq_task = t5.data.TaskRegistry.get('pair_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 256, targets = 32),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(5)):
    print(ex)

      lambda x: x[text_key], num_parallel_calls=tf.data.experimental.AUTOTUNE)

If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.
      lambda x: x[text_key], num_parallel_calls=tf.data.experimental.AUTOTUNE)

If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.
{'inputs_plaintext': b'nsp: Bahasa yang ditekankan, dipetik dengan baik di Serbin. menyiratkan bahawa pihak yang tidak berada dalam persaingan langsung.', 'inputs': array([   13,   152,  4615,    31,  4730,    17, 21621,   103,    14,
       23968,    28,   187,    24, 25847,   153,     3, 27406,    56,
         432,    17,    30,   288,    36,  5563,  1320,     3,     1]), 'targets_plaintext': b'next', 'targets': array([554,   1])}
{'inputs_plaintext': b'nsp: Walaupun ragu-ragu mengenai kecenderungan *C. auris* untuk menyebabkan jangkitan saluran kencing atau empyema seperti yang dinyatakan oleh CDC, kami mendapati bahawa

In [11]:
def news_dataset(split, shuffle_files = False):
    """Load the news-title TSV as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/newstitle.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


def news_preprocessor(ds):
    """Turn news examples into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the ``'question'`` value
        prefixed with ``'tajuk: '`` and ``'targets'`` is the ``'answer'``
        value unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['tajuk: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)


# Replace any task already stored as 'news_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('news_dataset')
t5.data.TaskRegistry.add(
    'news_dataset',
    splits = ['train'],
    dataset_fn = news_dataset,
    text_preprocessor = [news_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [12]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('news_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 1024, targets = 1024),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b'tajuk: Mahkamah Tinggi di sini hari ini menetapkan 1 Mac bagi pengurusan kes petisyen pilihan raya bagi kerusi Parlimen Kimanis dan Sipitang yang dimenangi calon Barisan Nasional (BN) pada Pilihan Raya Umum Ke-14 (PRU-14), 9 Mei lalu. Timbalan Pendaftar Mahkamah Tinggi Kota Kinabalu, Cindy Mc Juce Balitus menetapkan tarikh itu bagi Parlimen Kimanis yang akan didengar di hadapan Hakim Datuk Lee Heng Cheong di Mahkamah Tinggi di sini. Dalam pendengaran pengurusan kes dalam kamar itu, Suruhanjaya Pilihan Raya (SPR) diwakili peguam Faizal Sarbi dan Abdul Fikri Jaafar. Calon Parti Warisan Sabah (WARISAN) bagi Parlimen Kimanis, Datuk Karim Bujang diwakili Syaiful Sufeyyan Sidin dan Ahli Parlimen Kimanis, Datuk Seri Anifah Aman diwakili Wilson Chang Khai Sim serta Rizwandean M Borhan. Kelmarin, Mahkamah Persekutuan di Putrajaya memerintahkan perbicaraan petisyen pilihan raya diadakan bagi kerusi Parlimen Kimanis dan Sipitang yang dimenangi calon BN pada PRU-14. Panel li

In [13]:
def summarization_dataset(split, shuffle_files = False):
    """Load the summarization TSV as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/summarization.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


def summarization_preprocessor(ds):
    """Turn summarization examples into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the ``'question'`` value
        prefixed with ``'ringkasan: '`` and ``'targets'`` is the ``'answer'``
        value unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['ringkasan: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)


# Replace any task already stored as 'summarization_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('summarization_dataset')
t5.data.TaskRegistry.add(
    'summarization_dataset',
    splits = ['train'],
    dataset_fn = summarization_dataset,
    text_preprocessor = [summarization_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [14]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('summarization_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 1024, targets = 1024),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b"ringkasan: Seorang wanita Texas dijatuhi hukuman penjara seumur hidup tanpa pembebasan bersyarat Isnin kerana mencekik anak perempuan tunangnya hingga mati setelah dia menemui teks-teks perkauman di telefonnya dan mengakhiri hubungan itu. Juri mendapati Melinda Muniz, kini berusia 26 tahun, membunuh Grace Ford pada 9 Januari 2014 - beberapa jam selepas bapa Grace, seorang veteran Perang Iraq, menyuruhnya berpindah keluar dari pangsapuri Plano, Texas, mereka. 'Anda adalah jahat tulen,' nenek Grace berkata kepada Muniz dalam satu kenyataan impak di mahkamah, WFAA melaporkan. 'Anda tidak akan dapat menyentuh anak lagi. 'Perbicaraan pembunuhan: Melinda Muniz, 26, (kiri) didakwa membunuh anak perempuan tunangnya, Grace (kanan) Grace, dua tahun, yang juga mengalami kecederaan parah pada alat kelaminnya, meninggal dunia ketika dia dikeluarkan sokongan hidup tiga hari kemudian di hospital. Muniz tidak memberi keterangan pada perbicaraan, di mana juri mendengar dia telah 

In [15]:
def similarity_dataset(split, shuffle_files = False):
    """Load the SNLI and MNLI TSV files as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same files are read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    paths = [
        'gs://mesolitica-tpu-general/t5-data-v2/snli.tsv',
        'gs://mesolitica-tpu-general/t5-data-v2/mnli.tsv',
    ]
    lines = tf.data.TextLineDataset(paths)
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


def similarity_preprocessor(ds):
    """Rename the example fields to the text-to-text keys.

    No task prefix is added by this function; the values are passed through
    as they are.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts with ``'inputs'`` (the ``'question'`` value) and
        ``'targets'`` (the ``'answer'`` value).
    """
    def rename_fields(example):
        return dict(inputs = example['question'], targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(rename_fields, num_parallel_calls = autotune)


# Replace any task already stored as 'similarity_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('similarity_dataset')
t5.data.TaskRegistry.add(
    'similarity_dataset',
    splits = ['train'],
    dataset_fn = similarity_dataset,
    text_preprocessor = [similarity_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [16]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('similarity_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 256, targets = 32),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b'ayat1: Seorang lelaki sedang duduk di kerusi hijau, bercakap di telefon, dan bekerja di komputer riba, dan terdapat tapak pembinaan bersebelahan dengan bangunannya. ayat2: Lelaki itu boleh mendengar pembinaan yang sedang berjalan.', 'inputs': array([13694,   201,    31,  1064,   257,   579,  1625,    24,  3160,
        4359,    14,   962,    24,  1034,    14,    22,   616,    24,
        2411, 15021,    14,    22,   697,  6932,  2825, 17609,    28,
        1527,    38,     3, 13694,   215,    31,  5561,    37,   150,
        1605,  2825,    17,   579,  1047,     3,     1]), 'targets_plaintext': b'neutral', 'targets': array([12712,     1])}


In [17]:
def en_ms_dataset(split, shuffle_files = False):
    """Load the ``en-ms.tsv`` translation pairs as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/en-ms.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


def en_ms_preprocessor(ds):
    """Turn English-to-Malay pairs into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the ``'question'`` value
        prefixed with ``'terjemah Inggeris ke Melayu: '`` and ``'targets'`` is
        the ``'answer'`` value unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['terjemah Inggeris ke Melayu: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)


# Replace any task already stored as 'en_ms_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('en_ms_dataset')
t5.data.TaskRegistry.add(
    'en_ms_dataset',
    splits = ['train'],
    dataset_fn = en_ms_dataset,
    text_preprocessor = [en_ms_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [18]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('en_ms_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 1024, targets = 1024),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b"terjemah Inggeris ke Melayu: Right, it's not even an email.", 'inputs': array([   13, 26087,  2040,    55,  1550,    31,  8471,    14,    43,
          12,    16,    69,   318,    80,  4083,     3,     1]), 'targets_plaintext': b'Betul, itu bukan nya e-mail.', 'targets': array([17232,    14,    37,   232,    13,    38,    13,    81,     7,
        4114,     3,     1])}


In [19]:
def ms_en_dataset(split, shuffle_files = False):
    """Load the ``ms-en.tsv`` translation pairs as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/ms-en.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)


def ms_en_preprocessor(ds):
    """Turn Malay-to-English pairs into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the ``'question'`` value
        prefixed with ``'terjemah Melayu ke Inggeris: '`` and ``'targets'`` is
        the ``'answer'`` value unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['terjemah Melayu ke Inggeris: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)


# Replace any task already stored as 'ms_en_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('ms_en_dataset')
t5.data.TaskRegistry.add(
    'ms_en_dataset',
    splits = ['train'],
    dataset_fn = ms_en_dataset,
    text_preprocessor = [ms_en_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [20]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('ms_en_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 1024, targets = 1024),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b'terjemah Melayu ke Inggeris: Meliputi pelbagai genre, filem-filemnya termasuk "Near Dark" (1987), "Point Break" (1991), "Strange Days" (1995), "K-19: The Widowmaker" (2002), "The Hurt Locker" (2008), "Zero Dark Thirty" (2012), dan "Detroit" (2017).', 'inputs': array([   13, 26087,  1550,    55,  2040,    31,   777,  5610, 13612,
         879,  7664,    14,   492,     7,  9733,    38,   293,    13,
           6, 13934,   382, 11786,     6,    13,     4,   756,  4406,
           5,    14,    13,     6, 21341,    13, 17035,     6,    13,
           4,   756,  5569,     5,    14,    13,     6,  8138, 10484,
        1714,    16,     6,    13,     4,   756,  2920,     5,    14,
          13,     6,   471,     7,   756,    31,    35,  3075, 13524,
        9965,     6,    13,     4, 23882,     5,    14,    13,     6,
         198,    13, 16025,    41,  1743,  5389,     6,    13,     4,
       14089,     5,    14,    13,     6,  1757,  5418, 11786, 26881,
           6,   

In [23]:
def knowledge_graph_dataset(split, shuffle_files = False):
    """Load the knowledge-graph TSV as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/knowledge-graph.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)

def knowledge_graph_preprocessor(ds):
    """Turn knowledge-graph examples into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the ``'question'`` value
        prefixed with ``'grafik pengetahuan: '`` and ``'targets'`` is the
        ``'answer'`` value unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['grafik pengetahuan: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)

# Replace any task already stored as 'knowledge_graph_dataset' with the
# current definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('knowledge_graph_dataset')
t5.data.TaskRegistry.add(
    'knowledge_graph_dataset',
    splits = ['train'],
    dataset_fn = knowledge_graph_dataset,
    text_preprocessor = [knowledge_graph_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [24]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('knowledge_graph_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 1024, targets = 1024),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b'grafik pengetahuan: Connie Dierking bermain untuk Sacramento Kings dan Philadelphia Tapers dalam Liga Bola Keranjang Kebangsaan. Dia mewakili Amerika Syarikat.', 'inputs': array([12333,  5836,    31,  1597,  5217,   208,   125,  2921,   964,
          25, 19188, 13423,    22,  4093, 17082,   287,    36,  4529,
        7910, 31548,  1697,     3,   160,  2860,   209,   487,     3,
           1]), 'targets_plaintext': b"Connie Dierking country for sport United States, member of sports team Philadelphia Tapers, competition class men's basketball, member of sports team Sacramento Kings.", 'targets': array([ 1597,  5217,   208,   125,  2921,   286,    29,  6177,   485,
         891,    14,  1988,    18,  3364,   650,  4093, 17082,   287,
          14,  3081,  2816,   440,    12,    16,  9272,    14,  1988,
          18,  3364,   650, 19188, 13423,     3,     1])}


In [25]:
def paraphrase_dataset(split, shuffle_files = False):
    """Load the paraphrase TSV as ``{'question', 'answer'}`` examples.

    Args:
        split: Requested split name. Never consulted; the same file is read
            whatever value is passed.
        shuffle_files: Accepted for signature compatibility and discarded.

    Returns:
        A ``tf.data.Dataset`` yielding one dict of two string features per
        input line (first column under ``'question'``, second under
        ``'answer'``).
    """
    del shuffle_files

    def split_columns(line):
        # Two tab-delimited string fields per line; double quotes are treated
        # as ordinary characters.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_features(*columns):
        return dict(zip(['question', 'answer'], columns))

    lines = tf.data.TextLineDataset(
        ['gs://mesolitica-tpu-general/t5-data-v2/paraphrase.tsv']
    )
    columns = lines.map(
        split_columns,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return columns.map(to_features)

def paraphrase_preprocessor(ds):
    """Turn paraphrase examples into text-to-text form.

    Args:
        ds: Dataset whose elements are dicts providing ``'question'`` and
            ``'answer'``.

    Returns:
        Dataset of dicts where ``'inputs'`` is the ``'question'`` value
        prefixed with ``'parafrasa: '`` and ``'targets'`` is the ``'answer'``
        value unchanged.
    """
    def add_prefix(example):
        prefixed = tf.strings.join(['parafrasa: ', example['question']])
        return dict(inputs = prefixed, targets = example['answer'])

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(add_prefix, num_parallel_calls = autotune)

# Replace any task already stored as 'paraphrase_dataset' with the current
# definition (remove first, presumably so the cell can be re-run).
t5.data.TaskRegistry.remove('paraphrase_dataset')
t5.data.TaskRegistry.add(
    'paraphrase_dataset',
    splits = ['train'],
    dataset_fn = paraphrase_dataset,
    text_preprocessor = [paraphrase_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
    postprocess_fn = t5.data.postprocessors.lower_text,
)

In [26]:
# Smoke test: print the first preprocessed example of the task.
nq_task = t5.data.TaskRegistry.get('paraphrase_dataset')
ds = nq_task.get_dataset(
    sequence_length = dict(inputs = 1024, targets = 1024),
    split = 'train',
)
for ex in tfds.as_numpy(ds.take(1)):
    print(ex)

{'inputs_plaintext': b'parafrasa: Pada bulan November, Royals memperoleh CF Coco Crisp dari Boston Red Sox sebagai pertukaran untuk RP Ramon Ramirez.', 'inputs': array([  445,  4435,   722,    31,   206,   204,   664,    14,  2586,
          16,  1243,    13, 13547, 19576, 16469,  4615,    42,  1010,
        1975, 11434,    85,  4137,    25,    13, 15240, 24817, 23171,
           3,     1]), 'targets_plaintext': b'Pada bulan November, Royals CF Coco Crisp diperoleh dari Boston Red Sox sebagai pertukaran untuk RP Ramon Ramirez.', 'targets': array([  206,   204,   664,    14,  2586,    16,    13, 13547, 19576,
       16469,  4615,  5578,    42,  1010,  1975, 11434,    85,  4137,
          25,    13, 15240, 24817, 23171,     3,     1])}
