A minimal framework for handling task input and output processing. It is
heavily inspired by Google's [SeqIO](https://github.com/google/seqio) but is
not written with `tf.data`. For the time being, it uses HuggingFace's
[datasets](https://github.com/huggingface/datasets) library as the backbone.
To install, run:

```bash
git clone https://github.com/gabeorlanski/taskio.git
cd taskio
pip install -r requirements.txt
pip install -e .
```
Each `Task` is made up of four key elements:

* A `SPLIT_MAPPING` that maps a split name (e.g. `train`, `validation`) to some
  key value.
* A `tokenizer` for automatically encoding and decoding the inputs.
* Two lists of callables, `preprocessors` and `postprocessors`, used for
  preprocessing and postprocessing respectively. Each callable must take a
  single dictionary argument. (More advanced things can be done with
  `functools.partial`; see the sketch after this list.)
* A set of `metric_fns`, a list of callables. Each function must have the
  signature `predictions: List[str], targets: List[str]`.
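As a purely illustrative sketch, the callables might look like the following.
The function names, the `max_chars` argument passed through
`functools.partial`, and the assumption that a metric function returns a
dictionary of scores are all hypothetical, not part of the library:

```python
from functools import partial
from typing import Dict, List


def lowercase_input(sample: Dict) -> Dict:
    # Preprocessor: receives a single dict and returns it (possibly modified).
    sample['input_sequence'] = sample['input_sequence'].lower()
    return sample


def truncate_input(sample: Dict, max_chars: int) -> Dict:
    # A preprocessor that needs extra arguments can be wrapped with
    # functools.partial so that it still takes a single dict when called.
    sample['input_sequence'] = sample['input_sequence'][:max_chars]
    return sample


def strip_whitespace(sample: Dict) -> Dict:
    # Postprocessor: same single-dict signature; the exact keys available
    # depend on the task, 'target' is used here only for illustration.
    sample['target'] = sample['target'].strip()
    return sample


def exact_match(predictions: List[str], targets: List[str]) -> Dict:
    # Metric function with the required (predictions, targets) signature;
    # returning a dict of scores is an assumption made for this example.
    correct = sum(p == t for p, t in zip(predictions, targets))
    return {"exact_match": correct / len(targets) if targets else 0.0}


preprocessors = [lowercase_input, partial(truncate_input, max_chars=512)]
postprocessors = [strip_whitespace]
metric_fns = [exact_match]
```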
To create your own task, you must first subclass the `Task` class:

```python
from typing import Dict

from datasets import Dataset

from tio import Task


@Task.register('example')
class ExampleTask(Task):
    SPLIT_MAPPING = {
        "train"     : "path to the train file",
        "validation": "path to the validation file"
    }

    @staticmethod
    def map_to_standard_entries(sample: Dict) -> Dict:
        sample['input_sequence'] = sample['input']
        sample['target'] = sample['output']
        return sample

    def dataset_load_fn(self, split: str) -> Dataset:
        # This is only an example and will not work
        return Dataset.load_dataset(self.SPLIT_MAPPING[split])
```
The first step is to register your task in the `Task` registry (inspired by
AllenNLP's `Registrable`). Then you must set the `SPLIT_MAPPING` and override
two functions:

* `map_to_standard_entries`: when preprocessing and postprocessing, the `Task`
  class expects two columns, `input_sequence` and `target`. This function maps
  the raw input to those columns.
* `dataset_load_fn`: loads the dataset for a given split (a working sketch is
  shown below).
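For a rough idea of what a working `dataset_load_fn` could look like, the
sketch below assumes each split is stored as a JSON lines file with `input`
and `output` fields and loads it with HuggingFace's `datasets.load_dataset`
using the `json` builder; the file paths, the task name `example_json`, and
the storage format are all assumptions, not requirements of the library:

```python
from datasets import Dataset, load_dataset

from tio import Task


@Task.register('example_json')
class ExampleJSONTask(Task):
    # Hypothetical paths; any mapping from split name to file path works here.
    SPLIT_MAPPING = {
        "train"     : "data/train.jsonl",
        "validation": "data/validation.jsonl"
    }

    @staticmethod
    def map_to_standard_entries(sample):
        sample['input_sequence'] = sample['input']
        sample['target'] = sample['output']
        return sample

    def dataset_load_fn(self, split: str) -> Dataset:
        # The 'json' builder loads a single file into the default 'train'
        # split, so requesting split='train' returns a Dataset object.
        return load_dataset(
            'json',
            data_files=self.SPLIT_MAPPING[split],
            split='train'
        )
```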
To actually use the task and get the dataset:

```python
from tio import Task

task = Task.get_task(
    name='example',
    tokenizer=tokenizer,
    preprocessors=preprocessors,
    postprocessors=postprocessors,
    metric_fns=metric_fns
)
tokenized_dataset = task.get_split("train")
...
metrics = task.evaluate(
    **task.postprocess_raw_tokens(predictions, tokenized_dataset['labels'])
)
```
TODO: Make this less clunky