In [1]:
# !wget https://malaya-dataset.s3-ap-southeast-1.amazonaws.com/qa/natural/translated-train.json
# !wget https://malaya-dataset.s3-ap-southeast-1.amazonaws.com/qa/natural/translated-validation.json

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [3]:
import json

# The file holds non-ASCII text (the preview shows characters such as "∴",
# "é" and "ü"), so decode it explicitly as UTF-8 instead of relying on the
# platform's default encoding.
with open('translated-validation.json', encoding='utf-8') as fopen:
    data = json.load(fopen)

# Peek at the first few rows.
data[:5]

['apa maksud 3 titik dalam matematik? <> Oleh itu tanda (∴) biasanya digunakan sebelum akibat logik, seperti kesimpulan silogisme',
 'siapa yang bermain pertunjukan separuh masa di super bowl 2016? <> Coldplay dengan persembahan tetamu khas Beyoncé dan Bruno Mars',
 'di mana forum ekonomi dunia diadakan tahun ini? <> Davos, sebuah resort gunung di Graubünden, di wilayah timur Alpen Switzerland',
 'siapa yang memenangi personaliti sukan tahun 2017? <> Mo Farah',
 'siapakah ketua menteri pertama bengal barat? <> Prafulla Chandra Ghosh dari Kongres Nasional India']

In [4]:
# Split every raw row into a question and an answer.
#   * Preferred form: "question <> answer".
#   * Fallback (no "<>" marker): cut after the first "? " and put the
#     question mark back on the question.
# maxsplit=1 keeps the remainder of the row in the answer, so a row whose
# answer happens to contain the separator again no longer fails with
# "too many values to unpack". A row with no separator at all still raises
# ValueError, which is the signal that the data needs a manual look.
questions, answers = [], []
for row in data:
    if '<>' in row:
        q, a = row.split('<>', 1)
    else:
        q, a = row.split('? ', 1)
        q = f'{q}?'
    questions.append(q.strip())
    answers.append(a.strip())

In [5]:
# Sanity check: the loop above appends to both lists once per row,
# so the two counts should be equal.
len(questions), len(answers)

(2338, 2338)

In [6]:
# out.tsv is read back line by line and parsed as exactly two tab-separated
# columns with quoting disabled, so a tab or line break inside a field would
# produce a malformed row. Replace those characters with plain spaces.
_TSV_UNSAFE = str.maketrans({'\t': ' ', '\n': ' ', '\r': ' '})

# One "question<TAB>answer" line per example.
with tf.io.gfile.GFile('out.tsv', "w") as outfile:
    for question, answer in zip(questions, answers):
        outfile.write("%s\t%s\n" % (question.translate(_TSV_UNSAFE),
                                    answer.translate(_TSV_UNSAFE)))

In [7]:
import t5
# Show the names of the mixtures that are registered once t5 has been imported.
print(t5.data.MixtureRegistry.names())

dict_keys(['glue_v002_proportional', 'super_glue_v102_proportional', 'glue_mnli_and_dev_v002', 'en_mix', 'all_equal', 'all_proportional', 'all_mix', 'leave_one_out_glue_v002_proportional', 'leave_one_out_super_glue_v102_proportional', 'leave_one_out_cnn_dailymail_v002', 'leave_one_out_squad_v010_allanswers', 'leave_one_out_wmt_t2t_ende_v003', 'leave_one_out_wmt15_enfr_v003', 'leave_one_out_wmt16_enro_v003', 'large_supervised_equal', 'large_supervised_proportional', 'large_translation_equal', 'squad_trivia_qa_equal', 'wsc_dpr_simple_proportional'])


In [None]:
import functools

def nq_dataset_fn(split, shuffle_files=False):
    """Build a dataset of {"question": ..., "answer": ...} examples from out.tsv.

    `split` is accepted but not used: the same file is read for every split.
    `shuffle_files` is discarded because there is only one input file.
    """
    del shuffle_files

    def to_example(question, answer):
        # Give the two parsed columns their names.
        return {"question": question, "answer": answer}

    # Each line is "question<TAB>answer"; quoting is disabled so quote
    # characters inside the text are kept literally.
    parse_line = functools.partial(
        tf.io.decode_csv,
        record_defaults=["", ""],
        field_delim="\t",
        use_quote_delim=False,
    )

    lines = tf.data.TextLineDataset(['out.tsv'])
    columns = lines.map(
        parse_line,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return columns.map(to_example)
    
def trivia_preprocessor(ds):
    """Convert question/answer examples into text-to-text examples.

    Each {"question": ..., "answer": ...} element of `ds` becomes
    {"inputs": "soalan: " + question, "targets": answer}.
    """
    def as_text_to_text(example):
        # Prefix the question with the task prompt.
        prompt = tf.strings.join(["soalan: ", example["question"]])
        return {"inputs": prompt, "targets": example["answer"]}

    return ds.map(
        as_text_to_text,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# ds = trivia_preprocessor(nq_dataset_fn('a'))
# for ex in tfds.as_numpy(ds.take(5)):
#     print(ex)

In [None]:
# Drop any earlier registration under this name before adding it again.
# NOTE(review): presumably this is what allows the cell to be re-run; confirm
# that remove() tolerates a name that has not been registered yet.
t5.data.TaskRegistry.remove('nq_context_free')
# Register the task: examples come from nq_dataset_fn and are passed through
# trivia_preprocessor.
t5.data.TaskRegistry.add(
    "nq_context_free",
    dataset_fn=nq_dataset_fn,
    # Only a "train" split is declared for this task.
    splits=["train"],
    text_preprocessor=[trivia_preprocessor],
    # Local SentencePiece model file; it is not created in the visible part of
    # this notebook, so it has to be present in the working directory.
    sentencepiece_model_path='sp10m.cased.t5.model',
    postprocess_fn=t5.data.postprocessors.lower_text, 
    metric_fns=[t5.evaluation.metrics.accuracy],
)

In [None]:
# !pip3 install tf-sentencepiece sentencepiece tensorflow-text==1.15 tfds-nightly --no-deps

In [None]:
# Look up the task registered above and build its "train" split.
nq_task = t5.data.TaskRegistry.get("nq_context_free")
# NOTE(review): sequence_length presumably caps inputs/targets at 256/32
# tokens; confirm against the t5 get_dataset documentation.
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 32})

In [None]:
# Print the first five examples of the dataset for a quick visual check.
for ex in tfds.as_numpy(ds.take(5)):
    print(ex)

In [None]:
# Re-register the mixture from scratch: drop any earlier "trivia_all", then
# add it again with "nq_context_free" as its only task at a rate of 1.0.
t5.data.MixtureRegistry.remove("trivia_all")
t5.data.MixtureRegistry.add(
    "trivia_all",
    ["nq_context_free"],
     default_rate=1.0
)