In [3]:
# Uncomment to download the two input files read below.
# !wget https://f000.backblazeb2.com/file/malay-dataset/qa/natural/translated-train.json
# !wget https://f000.backblazeb2.com/file/malay-dataset/qa/natural/translated-validation.json

In [2]:
# Input JSON dumps, processed in this order when building the TSV.
files = [
    'translated-train.json',
    'translated-validation.json',
]

In [4]:
import re

def cleaning(string):
    """Flatten a string onto one line with single-space separation.

    Every run of spaces, newlines and tabs is collapsed into one space, then
    leading/trailing whitespace is stripped. This keeps the value free of the
    tab and newline characters used as delimiters when it is written as a
    TSV field.
    """
    # One pass: newline/tab -> space and space-run collapsing are equivalent
    # to collapsing any run of these three characters into a single space.
    single_spaced = re.sub(r'[ \n\t]+', ' ', string)
    return single_spaced.strip()

In [5]:
import json
import tensorflow as tf

# Parallel lists: questions[i] is answered by answers[i].
questions, answers = [], []
# Number of rows dropped because they matched neither expected layout.
# Inspect this after the cell runs to see how much data was discarded.
skipped = 0

for file in files:
    print(file)
    # Be explicit about the encoding instead of relying on the locale default,
    # which is not UTF-8 on every platform.
    with open(file, encoding='utf-8') as fopen:
        data = json.load(fopen)

    for row in data:
        # A row is either "question<>answer" or "question? answer".
        # Rows with zero or several separators cannot be split unambiguously
        # into exactly two parts and are skipped (best effort, as before).
        try:
            if '<>' not in row:
                q, a = row.split('? ')
                # split() consumed the question mark; put it back.
                q = f'{q}?'
            else:
                q, a = row.split('<>')
        except (ValueError, AttributeError, TypeError):
            # ValueError: not exactly two parts after splitting.
            # AttributeError / TypeError: the row is not a string.
            # Anything else (e.g. KeyboardInterrupt, a real bug) propagates.
            skipped += 1
            continue

        questions.append(q.strip())
        answers.append(a.strip())

# One "question<TAB>answer" pair per line; cleaning() removes embedded tabs
# and newlines so each pair stays on a single two-column line.
with tf.io.gfile.GFile('qa.tsv', "w") as outfile:
    for question, answer in zip(questions, answers):
        outfile.write("%s\t%s\n" % (cleaning(question), cleaning(answer)))

translated-train.json
translated-validation.json


In [7]:
import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors as prep
import functools
import t5
import gin

# SentencePiece model file; handed to the task registration further down
# as its `sentencepiece_model_path`.
vocab = 'sp10m.cased.t5.model'

# Load gin bindings from the operative config file.
gin.parse_config_file('pretrained_models_base_operative_config.gin')

In [8]:
def question_dataset(split, shuffle_files = False):
    """Build a dataset of {'question', 'answer'} string examples from a TSV file.

    Args:
        split: used directly as the path of the TSV file to read; each line
            must hold two tab-separated columns (question, answer).
        shuffle_files: accepted for interface compatibility and ignored.

    Returns:
        A `tf.data.Dataset` whose elements are dicts with keys 'question'
        and 'answer'.
    """
    del shuffle_files  # unused

    column_names = ['question', 'answer']

    def parse_line(line):
        # Quote handling is disabled so literal quote characters in the text
        # are kept as-is; missing columns default to the empty string.
        return tf.io.decode_csv(
            line,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        )

    def to_example(*fields):
        return dict(zip(column_names, fields))

    lines = tf.data.TextLineDataset([split])
    parsed = lines.map(
        parse_line,
        num_parallel_calls = tf.data.experimental.AUTOTUNE,
    )
    return parsed.map(to_example)


def question_preprocessor(ds):
    """Convert {'question', 'answer'} examples into {'inputs', 'targets'} examples.

    'inputs' is the question prefixed with the literal 'soalan: ';
    'targets' is the answer, passed through unchanged.
    """
    prefix = 'soalan: '

    def add_prefix(example):
        inputs = tf.strings.join([prefix, example['question']])
        return {'inputs': inputs, 'targets': example['answer']}

    return ds.map(add_prefix, num_parallel_calls = tf.data.experimental.AUTOTUNE)

In [9]:
# Drop any earlier registration under this name before adding it again --
# presumably so the cell can be re-run in the same kernel (confirm against
# the t5 registry behaviour).
t5.data.TaskRegistry.remove('question_dataset')

t5.data.TaskRegistry.add(
    'question_dataset',
    # Data source and preprocessing.
    splits = ['train'],
    dataset_fn = question_dataset,
    text_preprocessor = [question_preprocessor],
    # Vocabulary and evaluation.
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
)

  "get_sentencepiece_model_path is deprecated. Please pass the mixture or "


In [10]:
# Look the task up under the name it was registered with.
nq_task = t5.data.TaskRegistry.get('question_dataset')

# NOTE(review): `split` carries the TSV path here; question_dataset uses its
# `split` argument as a file name -- presumably t5 forwards it unchanged.
ds = nq_task.get_dataset(
    sequence_length = {'inputs': 1024, 'targets': 1024},
    split = 'qa.tsv',
)

In [11]:
# tfds.as_numpy is only guaranteed to return an iterable; in newer TFDS
# releases the result is not itself an iterator, so calling next() on it
# fails. iter() yields an iterator in both cases (a generator returns itself).
r = iter(tfds.as_numpy(ds))

In [None]:
# Pull a single example from `r` and display it as the cell output.
next(r)