# AirIO Quickstart Example

This notebook demonstrates the creation of a basic `Task` with two
preprocessing steps. It performs the following actions:

1. Load the [IMDB reviews][imdb_reviews] dataset.
2. Map the raw data to a format suitable for training.
3. Tokenize the text using SeqIO's
   [`SentencePieceVocabulary`][seqio_vocabularies].

Finally, we call the task's `get_dataset()` method and print the first few
records to show their contents after all transformation steps.


[imdb_reviews]: https://www.tensorflow.org/datasets/catalog/imdb_reviews
[seqio_vocabularies]: https://github.com/google/seqio/blob/main/seqio/vocabularies.py

In [None]:
import functools
from typing import Dict

from absl import app
from airio import data_sources
from airio import dataset_providers
from airio import tokenizer
import grain.python as grain
from seqio import vocabularies

First we load a SeqIO vocabulary for tokenization.

In [None]:
DEFAULT_SPM_PATH = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model"
DEFAULT_VOCAB = vocabularies.SentencePieceVocabulary(DEFAULT_SPM_PATH)

The raw data contains two fields per record: `{"text", "label"}`. For training, we need different field names: `{"inputs", "targets"}`. The values in each record also need to be in a suitable format, so we first define a function for converting the raw data.

This function performs the following:
* Remaps the field names to `{"inputs", "targets"}`
* Adds the prefix `"imdb "` to the raw text
* Maps the raw label from an integer to `{"negative", "positive", "invalid"}`

In [None]:
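def _imdb_preprocessor(raw_example: Dict[str, bytes]) -> Dict[str, str]:
  """Converts a raw IMDB record into text-to-text `inputs` and `targets`."""
  # The review text arrives as UTF-8 bytes; decode it and add the task prefix.
  final_example = {"inputs": "imdb " + raw_example["text"].decode("utf-8")}
  # The label is an integer class ID; compare it as a string.
  raw_label = str(raw_example["label"])
  if raw_label == "0":
    final_example["targets"] = "negative"
  elif raw_label == "1":
    final_example["targets"] = "positive"
  else:
    # Any other value is mapped to "invalid".
    final_example["targets"] = "invalid"
  return final_example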
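
To make the mapping concrete, the sketch below applies the same logic to a hand-written record. The sample record, the lookup table, and the name `preprocess_sketch` are illustrative assumptions; they are not part of AirIO or the IMDB dataset, and real records come from TFDS.

```python
from typing import Any, Dict

# Hypothetical lookup table mirroring the if/elif/else chain above.
_LABEL_NAMES = {"0": "negative", "1": "positive"}


def preprocess_sketch(raw_example: Dict[str, Any]) -> Dict[str, str]:
  """Illustrative stand-in for the notebook's preprocessor."""
  text = raw_example["text"].decode("utf-8")
  label = str(raw_example["label"])
  return {
      "inputs": "imdb " + text,
      # Any label other than 0 or 1 falls back to "invalid".
      "targets": _LABEL_NAMES.get(label, "invalid"),
  }


# Made-up record in the raw {"text", "label"} layout.
sample = {"text": b"A wonderful film.", "label": 1}
result = preprocess_sketch(sample)
print(result)
# → {'inputs': 'imdb A wonderful film.', 'targets': 'positive'}
```

The dictionary lookup is only a compact alternative to the branch chain; both produce the same three target strings.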

Next we define a task that uses this function as a preprocessor, followed by tokenization.

In [None]:
task = dataset_providers.Task(
    name="dummy_airio_task",
    source=data_sources.TfdsDataSource(
        tfds_name="imdb_reviews/plain_text:1.0.0", splits=["train"]
    ),
    preprocessors=[
        grain.MapOperation(map_function=_imdb_preprocessor),
        grain.MapOperation(
            functools.partial(
                tokenizer.tokenize,
                tokenizer_configs={
                    "inputs": tokenizer.TokenizerConfig(vocab=DEFAULT_VOCAB),
                    "targets": tokenizer.TokenizerConfig(vocab=DEFAULT_VOCAB),
                },
            )
        ),
    ],
)
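
Conceptually, the preprocessors run in list order on each record, so the tokenizer sees the output of the mapping step. The sketch below imitates that chaining in plain Python. It does not use AirIO or Grain: `apply_in_order`, `add_prefix`, and `toy_tokenize` are hypothetical names, and the whitespace tokenizer with a hand-built vocabulary is only a stand-in for SentencePiece, whose real token IDs will differ.

```python
import functools
from typing import Any, Callable, Dict, Sequence

Example = Dict[str, Any]


def apply_in_order(
    example: Example, steps: Sequence[Callable[[Example], Example]]
) -> Example:
  """Feeds the output of each step into the next one."""
  for step in steps:
    example = step(example)
  return example


def add_prefix(example: Example) -> Example:
  """Toy mapping step: prepends the task prefix to the input text."""
  return {**example, "inputs": "imdb " + example["inputs"]}


def toy_tokenize(
    example: Example, vocab: Dict[str, int], keys: Sequence[str]
) -> Example:
  """Toy tokenizer: splits on whitespace and looks up integer IDs."""
  out = dict(example)
  for key in keys:
    # Unknown words map to ID 0 in this toy scheme.
    out[key] = [vocab.get(word, 0) for word in example[key].split()]
  return out


# Hand-built vocabulary, purely for illustration.
toy_vocab = {"imdb": 1, "great": 2, "movie": 3, "positive": 4}
steps = [
    add_prefix,
    functools.partial(toy_tokenize, vocab=toy_vocab, keys=("inputs", "targets")),
]
tokenized = apply_in_order(
    {"inputs": "great movie", "targets": "positive"}, steps
)
print(tokenized)
# → {'inputs': [1, 2, 3], 'targets': [4]}
```

As in the task definition, `functools.partial` binds the configuration up front so that each step takes a single record as its only remaining argument.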

Now we can retrieve an iterator over the task's dataset to view the combined effect of the transformations on each record.

In [None]:
ds = task.get_dataset()

In [None]:
count = 0
for element in ds:
  for k, v in element.items():
    print(f"    {k}: {v}")
  print("  -------------------------")
  count += 1
  if count >= dataset_providers.DEFAULT_NUM_RECORDS_TO_INSPECT:
    break
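
The manual counter can also be expressed with `itertools.islice`, which stops after a fixed number of items from any iterable. The following is a minimal sketch on stand-in data: `fake_ds` and `num_to_inspect` are hypothetical placeholders for the task's dataset and for `DEFAULT_NUM_RECORDS_TO_INSPECT`.

```python
import itertools

# Stand-in for the task's dataset: a lazy stream of made-up records.
fake_ds = ({"inputs": [i, i + 1], "targets": [i % 2]} for i in range(1000))
num_to_inspect = 2  # Placeholder for the number of records to print.

# islice pulls only the first `num_to_inspect` records from the stream.
inspected = list(itertools.islice(fake_ds, num_to_inspect))
for element in inspected:
  for k, v in element.items():
    print(f"    {k}: {v}")
  print("  -------------------------")
```

Because `islice` is lazy, the remaining records are never read, which matters when the underlying dataset is large.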