
unable to train mt5 from t5x using mixtures ValueError: Dataset is missing an expected feature during input_validation validation: 'inputs' #310

Closed
StephennFernandes opened this issue Aug 24, 2022 · 3 comments

@StephennFernandes

Hey there,

I am currently pretraining an mT5 model on 23 different languages. But when I create a mixture and set the mixture name in the t5x .gin config file to train on the mixture, I get the following error:

ValueError: Dataset is missing an expected feature during input_validation validation: 'inputs'

However, when I run the individual tasks by setting them in the gin file, everything works fine.

The following is what my task.py file looks like:

```python
import functools

import seqio
import tensorflow as tf
from datasets import load_dataset
from t5.data import preprocessors
from seqio import utils

TaskRegistry = seqio.TaskRegistry
vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)


DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}



def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_path=None):
    dataset = load_dataset(dataset_path, streaming=True, use_auth_token=True)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset_path=None):
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_path=dataset_path),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_path)
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map"""
    return {**key_map, target_key: x}


TaskRegistry.add(
    "urdu_span_curruption",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_path='StephennFernandes/ciil_mega_corpus_urdu'),
        splits=("train", "validation"),
        caching_permitted=False),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        # seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption, 
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[]
)
### similar tasks for the other languages are registered here ###


#seqio mixture 3.5 
seqio.MixtureRegistry.add(
  "ciil_mix_3.5",
  ["assamese_span_curruption", "bengali_span_curruption", 
  "bhisnupuriya_span_curruption", "bodo_span_curruption", 
  "divehi_span_curruption", "dogri_span_curruption", 
  "english_span_curruption", "gujarati_span_curruption",
  "hindi_span_curruption", "kannada_span_curruption", 
  "kashmiri_span_curruption", "konkani_span_curruption", 
  "maithili_span_curruption", "malayalam_span_curruption",
  "manipuri_span_curruption", "marathi_span_curruption",
  "nepali_span_curruption", "odia_span_curruption",
  "panjabi_span_curruption", "sanskrit_span_curruption",
  "tamil_span_curruption", "telugu_span_curruption",
   "urdu_span_curruption" ],
  default_rate=3.5
)
```

Upon running the mT5 model with the mixture name in the .gin file, I get the following error:

File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 744, in _main
    train_using_gin()
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 249, in train
    train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1366, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1387, in get_dataset_inner
    ds = seqio.get_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1681, in get_dataset
    ds = feature_converter(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/feature_converters.py", line 404, in __call__
    ds = self._validate_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/feature_converters.py", line 294, in _validate_dataset
    raise ValueError("Dataset is missing an expected feature during "
ValueError: Dataset is missing an expected feature during input_validation validation: 'inputs'
```

@gauravmishra (Collaborator)

Hi @StephennFernandes, your target_to_key preprocessor doesn't look correct:

```python
functools.partial(
    target_to_key, key_map={
        "inputs": None,
        "targets": None,
    }, target_key="targets"),
```

It maps both the "inputs" and "targets" fields to nothing, so you end up with no data after this preprocessor runs. See an example of correct usage in the T5 codebase: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/tasks.py#L50
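
For illustration, a minimal sketch of what a corrected registration could look like, reusing `dataset_fn`, `target_to_key`, and `DEFAULT_OUTPUT_FEATURES` from the task file above. Here the key_map only declares "targets" (which `target_to_key` then fills with the raw text), and `output_features` declares both features, since `span_corruption` derives "inputs" from "targets". The `_v2` task name is hypothetical:

```python
import functools

import seqio
from t5.data import preprocessors

# Sketch only: assumes dataset_fn, target_to_key and DEFAULT_OUTPUT_FEATURES
# from the task file above; "urdu_span_curruption_v2" is a hypothetical name.
seqio.TaskRegistry.add(
    "urdu_span_curruption_v2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(
            dataset_fn, dataset_path="StephennFernandes/ciil_mega_corpus_urdu"),
        splits=("train", "validation"),
        caching_permitted=False),
    preprocessors=[
        # Put the raw text under "targets"; don't pre-declare a dead "inputs" key.
        functools.partial(target_to_key, key_map={"targets": None}, target_key="targets"),
        seqio.preprocessors.tokenize,
        preprocessors.span_corruption,  # builds "inputs" from the corrupted "targets"
        seqio.preprocessors.append_eos_after_trim,
    ],
    # Declare both features so the feature converter's input validation finds "inputs".
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[],
)
```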

@tarudesu commented Mar 7, 2024

@StephennFernandes Hi mate, are you done with the data sampling for pre-training mT5? I am still stuck on preprocessing the data for that (using Flax).

@StephennFernandes (Author) commented Mar 9, 2024

@tarudesu The following is how I create a seqio TaskRegistry per language and then create MixtureRegistry entries with different mixture rates, to test which rate works best on my version of the mT5 model. You can later use the name from the mixture registry to call the dataset mixture and pre-train the model.
Btw, I use Hugging Face datasets rather than TFDS to load the data into seqio; the gen_dataset function is what loads it.
The code is 1.5 years old, so a word of caution before running it with your version of the dependencies, but it should mostly work fine.

```python
import functools

import seqio
import tensorflow as tf
import t5.data
from datasets import load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

TaskRegistry = seqio.TaskRegistry

# path to the sentencepiece tokenizer vocabulary
vocabulary = seqio.SentencePieceVocabulary("/home/stephen/Desktop/t5_test_run/t5x/HI_ALL_VOCAB_32000_UNIGRAM.model", extra_ids=100)


DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(), add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(), add_eos=True)
}


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_path=None):
    """Streams raw text examples from a Hugging Face dataset saved on disk."""
    dataset = load_from_disk(dataset_path)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset_path=None):
    """Wraps gen_dataset in a tf.data.Dataset of scalar strings."""
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_path=dataset_path),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_path)
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map."""
    return {**key_map, target_key: x}


# One span-corruption task per language. The 23 registrations are identical
# except for the task name and the dataset directory, so they are registered
# in a loop (each corpus lives under CORPUS_ROOT in an upper-cased folder).
CORPUS_ROOT = "/home/stephen/Desktop/MEGA_CORPUS/COMBINED_CORPUS"
LANGUAGES = [
    "assamese", "bengali", "bhisnupuriya", "bodo", "divehi", "dogri",
    "english", "gujarati", "hindi", "kannada", "kashmiri", "konkani",
    "maithili", "malayalam", "manipuri", "marathi", "nepali", "odia",
    "panjabi", "sanskrit", "tamil", "telugu", "urdu",
]

for lang in LANGUAGES:
    TaskRegistry.add(
        f"{lang}_span_curruption",
        source=seqio.FunctionDataSource(
            dataset_fn=functools.partial(dataset_fn, dataset_path=f"{CORPUS_ROOT}/{lang.upper()}"),
            splits=("train", "validation"),
            caching_permitted=False,
            num_input_examples=None,
        ),
        preprocessors=[
            functools.partial(
                target_to_key, key_map={
                    "inputs": None,
                    "targets": None,
                }, target_key="targets"),
            seqio.preprocessors.tokenize,
            # seqio.CacheDatasetPlaceholder(),
            preprocessors.span_corruption,
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features={"targets": seqio.Feature(vocabulary=vocabulary, add_eos=True)},
        metric_fns=[]
    )

ALL_SPAN_CORRUPTION_TASKS = [f"{lang}_span_curruption" for lang in LANGUAGES]

# Defining mixtures of all language tasks, one per candidate mixing rate.
for mixture_name, rate in [("mix_3.5", 3.5), ("mix_3", 3), ("mix_2", 2), ("mix_4", 4)]:
    seqio.MixtureRegistry.add(mixture_name, ALL_SPAN_CORRUPTION_TASKS, default_rate=rate)
```
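
If you want to sanity-check one of the registered mixtures before launching t5x, something like this should work (a sketch; the sequence lengths below are arbitrary example values):

```python
import seqio

# Assumes the task/mixture registrations above have already been imported.
mixture = seqio.get_mixture_or_task("mix_3.5")
ds = mixture.get_dataset(
    sequence_length={"inputs": 512, "targets": 114},  # arbitrary lengths
    split="train",
    shuffle=True,
    seed=42,
)
for ex in ds.take(1):
    print({k: v.shape for k, v in ex.items()})
```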
