# Interleaving Datasets

Original: sanchit-gandhi \
This version: Michael Kamfonas (Farsipal)


In [11]:
# Ensure datasets is installed from main. Uncomment the following line if you face issues running this script:
# !pip install git+https://github.com/huggingface/datasets

In [12]:
from typing import List, Optional

# Import IterableDataset from `datasets` only; importing torch's IterableDataset here would shadow it
from datasets import Audio, IterableDataset, interleave_datasets, load_dataset

### Define the dataset attributes

The cell below contains the parameters of the original example, which also used VoxPopuli and Multilingual LibriSpeech. They are kept here, commented out, for reference in case you want to include these datasets.

In [13]:
#  
#dataset_names = ["mozilla-foundation/common_voice_11_0", "facebook/voxpopuli", "facebook/multilingual_librispeech", "google/fleurs"]
#dataset_config_names = ["es", "es", "spanish", "es_419"]
#text_column_names = ["sentence", "normalized_text", "text", "transcription"]

The example I use in this version shows how to combine the Common Voice 11 and FLEURS datasets for Greek (el). The modified version below produces interleaved datasets for both training and testing.

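-   The training corpus combines the train and validation splits of both datasets. With the default `all_exhausted` stopping strategy (explained below), shorter splits are oversampled, so the interleaved stream is longer than the sum of the individual splits.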
-   The test dataset will be equal to the **Common Voice 11 Test split only**. 

The parameters defined in the next cell are explained below. 

-   All parameters are in lists of the same length with one element per dataset.
-   `dataset_names` contains the Hugging Face Hub name of each dataset.
-   `dataset_config_names` contains the respective language codes (config names).
-   `text_column_names` contains the name used for the text feature (column) for each respective dataset.
-   `train_splits` and `test_splits` contain the split names used to build the train and test interleaved datasets. To interleave several splits of one dataset, join their names with a `+` sign; e.g. to merge the train and validation splits, the entry should be `"train+validation"`. The special split name `"-"` (dash) excludes a dataset from that interleaved set. This is how the test dataset is restricted to the Common Voice 11 test split, leaving out the FLEURS test split.
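The split convention can be illustrated in isolation. The sketch below uses a hypothetical helper, `parse_split_spec`, that applies the same rule the loading function uses (split on `+`, skip `-`) and shows which splits each dataset would contribute to the test set; the variable names are illustrative and chosen so they do not overwrite the notebook's own parameters.

```python
from typing import Dict, List


def parse_split_spec(spec: str) -> List[str]:
    """Hypothetical helper: turn a spec such as "train+validation" into split names.

    A lone "-" means "skip this dataset" and yields an empty list.
    """
    return [name for name in spec.split("+") if name != "-"]


example_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
example_test_splits = ["test", "-"]

# Which splits each dataset contributes to the interleaved test set
plan: Dict[str, List[str]] = {
    name: parse_split_spec(spec) for name, spec in zip(example_names, example_test_splits)
}
print(plan)
```

Only Common Voice contributes a split to the test set here; FLEURS contributes none.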

In [14]:
dataset_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
dataset_config_names = ["el", "el_gr"]
text_column_names = ["sentence", "transcription"]
train_splits = ["train+validation", "train+validation"]
test_splits = ["test", "-"]  # we want the test to come from one dataset only

In [15]:
# Alternative: the same parameters as comma-separated strings (e.g. handy when they are passed as script arguments)
dataset_name = "mozilla-foundation/common_voice_11_0,google/fleurs"
dataset_config_name = "el,el_gr"
text_column_name = "sentence,transcription"
train_split = "train+validation,train+validation"
test_split = "test,-"  # we want the test to come from one dataset only

dataset_names = dataset_name.split(",")
dataset_config_names = dataset_config_name.split(",")
text_column_names = text_column_name.split(",")
train_splits = train_split.split(",")
test_splits = test_split.split(",")



### Define the merging function

We define a function, `load_multiple_streaming_datasets`, that takes lists of dataset names and configs plus, optionally, lists of train splits, test splits and text column names. For every requested split it resamples the audio to a common sampling rate, renames the transcription column to `text` and drops all other columns. It then interleaves the results into one merged training dataset and one merged test dataset. This is all
done in _streaming mode_: as we iterate over a merged dataset, samples are loaded one by one on the fly. No data is
saved to disk.
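Renaming and dropping columns gives every source the same schema (`audio` and `text`) before interleaving. The sketch below pictures that step on plain Python dicts with a hypothetical generator, `harmonise`; it stands in for `rename_column` and `remove_columns` on a streamed split, does not touch any real dataset, and does not resample audio. The toy records and their extra columns are made up for illustration.

```python
from typing import Any, Dict, Iterable, Iterator, List


def harmonise(examples: Iterable[Dict[str, Any]], text_column: str) -> Iterator[Dict[str, Any]]:
    """Hypothetical helper: lazily map each example onto the shared {"audio", "text"} schema."""
    for example in examples:
        # Keep only the audio and the transcription, renamed to "text"
        yield {"audio": example["audio"], "text": example[text_column]}


# Toy records; values and extra columns are illustrative only
common_voice_like = [{"audio": "cv_0.mp3", "sentence": "Ύστερα είπε", "up_votes": 2}]
fleurs_like = [{"id": 7, "audio": "fl_0.wav", "transcription": "ρώτησε το πεύκο", "gender": 1}]

rows: List[Dict[str, Any]] = list(harmonise(common_voice_like, "sentence")) + list(
    harmonise(fleurs_like, "transcription")
)
print(rows)
```

Because `harmonise` is a generator, nothing is processed until the rows are requested, which loosely mirrors how streamed splits are transformed on the fly.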

We can also specify our strategy for interleaving datasets. The default in our function, `all_exhausted`, is an oversampling
strategy: dataset construction stops only once every sample in every dataset
has been added at least once. In practice, when a dataset is exhausted, it restarts from its
beginning until the stopping criterion is reached. You can specify `stopping_strategy="first_exhausted"`
for a subsampling strategy, i.e. dataset construction stops as soon as one of the datasets runs out of samples. This is also the default of `interleave_datasets` itself.
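To make the two strategies concrete, here is a small pure-Python sketch of round-robin interleaving without sampling probabilities. The helper `round_robin_interleave` is hypothetical: it mimics the behaviour described above on in-memory lists and is not the `datasets` implementation, whose exact stopping point can differ between dataset types and library versions.

```python
from typing import List, Sequence


def round_robin_interleave(
    sources: Sequence[Sequence[int]], stopping_strategy: str = "first_exhausted"
) -> List[int]:
    """Hypothetical helper: take one example from each source in turn."""
    if stopping_strategy not in ("first_exhausted", "all_exhausted"):
        raise ValueError(f"Unknown stopping_strategy: {stopping_strategy!r}")
    if not sources:
        return []
    if any(len(src) == 0 for src in sources):
        raise ValueError("Every source must contain at least one example")
    positions = [0] * len(sources)      # next index to read from each source
    exhausted = [False] * len(sources)  # has each source been read to the end at least once?
    stop = any if stopping_strategy == "first_exhausted" else all
    out: List[int] = []
    while True:
        for i, src in enumerate(sources):
            out.append(src[positions[i]])
            positions[i] += 1
            if positions[i] == len(src):
                exhausted[i] = True
                if stop(exhausted):
                    return out
                positions[i] = 0  # restart this source, i.e. oversample it


sources = [[0, 1, 2], [10, 11, 12, 13], [20, 21, 22]]
first = round_robin_interleave(sources)
every = round_robin_interleave(sources, "all_exhausted")
print(first)  # [0, 10, 20, 1, 11, 21, 2]
print(every)  # [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13]
```

With sources of 3, 4 and 3 examples, `first_exhausted` stops after 7 examples, as soon as the first list runs out, while `all_exhausted` yields 11, more than the 10 distinct examples, because the shorter lists restart until the longest one has been read to the end.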

In [20]:
def load_multiple_streaming_datasets(
    dataset_names: List[str],
    dataset_config_names: List[str],
    train_splits: Optional[List[str]] = None,
    test_splits: Optional[List[str]] = None,
    text_column_names: Optional[List[str]] = None,
    sampling_rate: Optional[int] = 16000,
    stopping_strategy: str = "all_exhausted",
    **kwargs,
) -> tuple[IterableDataset, IterableDataset]:
    """Stream, harmonise and interleave several datasets into one train and one test dataset."""
    num_datasets = len(dataset_names)

    if len(dataset_config_names) != num_datasets:
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {num_datasets} datasets and"
            f" {len(dataset_config_names)} configs."
        )
    if train_splits is not None and len(train_splits) != num_datasets:
        raise ValueError(
            f"Ensure one train_split is passed for each dataset, got {num_datasets} datasets and {len(train_splits)} splits."
        )
    if test_splits is not None and len(test_splits) != num_datasets:
        raise ValueError(
            f"Ensure one test_split is passed for each dataset, got {num_datasets} datasets and {len(test_splits)} splits."
        )
    if text_column_names is not None and len(text_column_names) != num_datasets:
        raise ValueError(
            f"Ensure one text column name is passed for each dataset, got {num_datasets} datasets and"
            f" {len(text_column_names)} text column names."
        )

    # Fall back to the conventional split and column names when none are given
    train_splits = train_splits if train_splits is not None else ["train"] * num_datasets
    test_splits = test_splits if test_splits is not None else ["test"] * num_datasets
    text_column_names = text_column_names if text_column_names is not None else ["text"] * num_datasets

    def prepare(dset_name, config_name, split_spec, text_column):
        """Load every split in split_spec ("a+b", or "-" to skip) and align it to one schema."""
        prepared = []
        for split in split_spec.split("+"):
            if split == "-":
                continue
            ds = load_dataset(dset_name, config_name, split=split, streaming=True, **kwargs)
            # Resample the audio on the fly to a common sampling rate
            ds = ds.cast_column("audio", Audio(sampling_rate=sampling_rate))
            # Give every source the same transcription column name
            if text_column != "text":
                ds = ds.rename_column(text_column, "text")
            # Drop everything except audio and text so all sources share one schema
            ds = ds.remove_columns([c for c in ds.features.keys() if c not in ("audio", "text")])
            prepared.append(ds)
        return prepared

    all_train_splits, all_test_splits = [], []
    # Iterate over the datasets we want to interleave
    for dset_name, config_name, train_spec, test_spec, text_column in zip(
        dataset_names, dataset_config_names, train_splits, test_splits, text_column_names
    ):
        all_train_splits += prepare(dset_name, config_name, train_spec, text_column)
        all_test_splits += prepare(dset_name, config_name, test_spec, text_column)

    interleaved_train_dataset = interleave_datasets(all_train_splits, stopping_strategy=stopping_strategy)
    interleaved_test_dataset = interleave_datasets(all_test_splits, stopping_strategy=stopping_strategy)

    return interleaved_train_dataset, interleaved_test_dataset

Let's apply this function to load and merge our train and test splits:

In [21]:
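# Common Voice 11 is gated on the Hub: accept its terms and log in (e.g. `huggingface-cli login`) first.
# Newer releases of `datasets` name the `use_auth_token` argument `token`.
train_ds, test_ds = load_multiple_streaming_datasets(
    dataset_names,
    dataset_config_names=dataset_config_names,
    train_splits=train_splits,
    test_splits=test_splits,
    text_column_names=text_column_names,
    use_auth_token=True,
)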

In [None]:
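print('train', train_ds.dataset_size)
print('test', test_ds.dataset_size)


train None
test None

Both sizes print as `None`: the merged streaming datasets carry no size metadata, so the number of examples is only known once the stream has been read to the end.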


### Iterate over the dataset

We iterate over the dataset, loading and merging samples on the fly. Let's print the transcriptions of the first 10 samples of our merged dataset:

In [23]:
for i, sample in enumerate(train_ds):
    print(i, sample["text"])
    if i == 9:
        break

Reading metadata...: 1914it [00:00, 5805.35it/s]


0 πρόσταξε το Βασιλόπουλο.


Reading metadata...: 1701it [00:00, 6077.88it/s]


1 Άλογο; Ο Τζοτζές έχει άλογο;
2 στη βάση του βουνού αναφέρθηκε η παρουσία σκοτεινών συννέφων που δεν σχετίζονταν με ηφαιστιακή δραστηριότητα
3 αυτά αποτελούν πλέον ξεχωριστές αρχές  οι οποίες εστιάζουν στην παροχή λύσεων σε πραγματικά προβλήματα της καθημερινότητας
4 θυμήθηκα, σαν ήλθε η ώρα, τα λόγια μου.
5 ρώτησε το πεύκο σκανδαλισμένο.
6 συχνά είναι πιο αυτόνομοι από τα συμβατικά μέλη της ομάδας καθώς οι ομάδες τους μπορεί να συναντώνται σύμφωνα με διαφορετικές χρονικές ζώνες κάτι που μπορεί να μην είναι αποδεκτό από την τοπική τους διοίκηση
7 το nhk δήλωσε επίσης ότι ο πυρηνικός σταθμός ηλεκτρικής ενέργειας κασιβαζάκι-καρίβα στον νομό νιιγκάτα λειτουργούσε κανονικά
8 Ύστερα είπε
9 Ώστε ο ήλιος είχε βασιλέψει

The following code cell is lifted from the Whisper training notebook: https://github.com/huggingface/community-events/blob/main/whisper-fine-tuning-event/fine-tune-whisper-streaming.ipynb

In [None]:
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

do_lower_case = False
do_remove_punctuation = False

normalizer = BasicTextNormalizer()

Now we define a function to normalise our transcriptions:

In [None]:
def normalize_transcriptions(batch):
    # optional pre-processing steps
    transcription = batch["text"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()
    batch["text"] = transcription
    return batch

Let's apply the data pre-processing steps to our dataset and view the first 10 samples again:

In [24]:
ds = train_ds.map(normalize_transcriptions)

for i, sample in enumerate(ds):
    print(i, sample["text"])
    if i == 9:
        break

Reading metadata...: 1914it [00:00, 8234.86it/s]


0 πρόσταξε το Βασιλόπουλο.


Reading metadata...: 1701it [00:00, 5668.92it/s]


1 Άλογο; Ο Τζοτζές έχει άλογο;
2 στη βάση του βουνού αναφέρθηκε η παρουσία σκοτεινών συννέφων που δεν σχετίζονταν με ηφαιστιακή δραστηριότητα
3 αυτά αποτελούν πλέον ξεχωριστές αρχές  οι οποίες εστιάζουν στην παροχή λύσεων σε πραγματικά προβλήματα της καθημερινότητας
4 θυμήθηκα, σαν ήλθε η ώρα, τα λόγια μου.
5 ρώτησε το πεύκο σκανδαλισμένο.
6 συχνά είναι πιο αυτόνομοι από τα συμβατικά μέλη της ομάδας καθώς οι ομάδες τους μπορεί να συναντώνται σύμφωνα με διαφορετικές χρονικές ζώνες κάτι που μπορεί να μην είναι αποδεκτό από την τοπική τους διοίκηση
7 το nhk δήλωσε επίσης ότι ο πυρηνικός σταθμός ηλεκτρικής ενέργειας κασιβαζάκι-καρίβα στον νομό νιιγκάτα λειτουργούσε κανονικά
8 Ύστερα είπε
9 Ώστε ο ήλιος είχε βασιλέψει


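Because `do_lower_case` and `do_remove_punctuation` are both `False`, the transcriptions come out unchanged: casing and punctuation are kept, so a Whisper model fine-tuned on this data will learn to predict them. Setting `do_lower_case = True` lowercases the text, and setting `do_remove_punctuation = True` passes it through `BasicTextNormalizer`, which strips punctuation and also lowercases. In either case the model would no longer learn to predict the removed features.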
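The effect of the two flags can be previewed without downloading anything. The sketch below mirrors the control flow of `normalize_transcriptions` on a plain string. `approx_basic_normalize` is a hypothetical, rough stand-in for `BasicTextNormalizer` (lowercasing, then replacing Unicode marks, symbols and punctuation with spaces); it is not the `transformers` implementation, which also removes bracketed text.

```python
import re
import unicodedata


def approx_basic_normalize(text: str) -> str:
    """Hypothetical approximation of BasicTextNormalizer (no bracket handling)."""
    text = unicodedata.normalize("NFKC", text.lower())
    # Replace marks, symbols and punctuation (Unicode categories M*, S*, P*) with spaces
    text = "".join(" " if unicodedata.category(ch)[0] in "MSP" else ch for ch in text)
    # Collapse runs of whitespace and trim the ends
    return re.sub(r"\s+", " ", text).strip()


def normalize_text(text: str, do_lower_case: bool, do_remove_punctuation: bool) -> str:
    """Same flag logic as normalize_transcriptions, applied to a single string."""
    if do_lower_case:
        text = text.lower()
    if do_remove_punctuation:
        text = approx_basic_normalize(text)
    return text


sample = "θυμήθηκα, σαν ήλθε η ώρα, τα λόγια μου."
print(normalize_text(sample, do_lower_case=False, do_remove_punctuation=False))
print(normalize_text(sample, do_lower_case=False, do_remove_punctuation=True))
```

With both flags off the sentence is returned unchanged; with `do_remove_punctuation=True` the commas and final full stop disappear while the accented Greek letters survive, since NFKC keeps them as single precomposed characters.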