# MLS 2 WebDataset

This notebook and repository will walk you through the process of converting the MLS dataset from files on disk into a streamable dataset that is easy to split across machines for training.

First off, download the dataset from https://openslr.org/94/.

Once you have it downloaded and extracted, this notebook can take over from there!

In [None]:
# Install the third-party packages this notebook relies on (webdataset, pandas,
# torch are imported below; torchaudio is presumably needed by the
# `wds.torch_audio` decoder used when reading the shards back).
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): versions are unpinned — consider pinning them for reproducibility.
%pip install webdataset pandas torch torchaudio

In [None]:
import csv
import os

import pandas as pd
import webdataset as wds
from torch.utils.data import DataLoader

Now that we have everything we will be using to import our data, let's do that!

In [None]:
# Locations used by the rest of the notebook.
#   root_path  - directory the MLS archive was extracted into.
#   output_dir - directory the generated .tar shards are written to.
# Either edit the defaults below or set the MLS_ROOT_PATH / MLS_OUTPUT_DIR
# environment variables before starting the kernel; a leading "~" is expanded
# in both cases.
root_path = os.path.expanduser(
    os.environ.get('MLS_ROOT_PATH', '/media/jstackhouse/spinner/mls_english')
)
output_dir = os.path.expanduser(
    os.environ.get('MLS_OUTPUT_DIR', '~/data/mls_english')
)

# Create the output directory (and any missing parents) if it doesn't exist yet.
os.makedirs(output_dir, exist_ok=True)

def load_split(root: str, split: str) -> pd.DataFrame:
    """Read the transcript index of one dataset split.

    Parses ``<root>/<split>/transcripts.txt`` as a header-less,
    tab-separated file and labels its two fields.

    Args:
        root: Directory that contains the split sub-directories.
        split: Name of the split sub-directory (e.g. ``"train"``).

    Returns:
        A DataFrame with string columns ``path`` and ``transcript``,
        one row per line of the transcript file.

    Raises:
        FileNotFoundError: If the transcript file does not exist.
    """
    path = os.path.join(root, split, "transcripts.txt")
    return pd.read_csv(
        path,
        sep='\t',
        header=None,
        names=['path', 'transcript'],
        # Keep both fields as text: ids must not be coerced to numbers.
        dtype=str,
        # Transcripts are free text, so a '"' is a literal character and
        # must not be interpreted as CSV quoting.
        quoting=csv.QUOTE_NONE,
        # Words such as "null", "NA" or "nan" are valid transcripts, not
        # missing values.
        keep_default_na=False,
    )

In [None]:
# Convert every split into a series of numbered tar shards
# (<output_dir>/<split>-0000.tar, <split>-0001.tar, ...).
splits = ["dev", "train", "test"]
for split in splits:
    df = load_split(root_path, split)
    # All audio of a split lives below <root_path>/<split>/audio.
    audio_root = os.path.join(root_path, split, 'audio')
    with wds.ShardWriter(f"{output_dir}/{split}-%04d.tar") as sink:
        # Walk the two columns directly instead of using DataFrame.iterrows(),
        # which builds a Series per row and is very slow on large splits.
        for utterance_id, transcript in zip(df['path'], df['transcript']):
            # The first two "_"-separated fields of the id name the nested
            # directories that hold the corresponding .flac file.
            audio_parts = utterance_id.split('_')
            flac_path = os.path.join(
                audio_root, audio_parts[0], audio_parts[1], f"{utterance_id}.flac"
            )
            with open(flac_path, "rb") as stream:
                audio_bytes = stream.read()
            # One sample per utterance: raw FLAC bytes plus its transcript,
            # keyed by the utterance id.
            sink.write({
                "__key__": utterance_id,
                "audio.flac": audio_bytes,
                "transcript.txt": transcript
            })

In [None]:
# Collect the train shards that were actually written instead of hard-coding
# a shard range, which is only valid for one particular conversion run.
url = sorted(
    os.path.join(output_dir, name)
    for name in os.listdir(output_dir)
    if name.startswith("train-") and name.endswith(".tar")
)
if not url:
    raise FileNotFoundError(
        f"No train shards found in {output_dir}; run the conversion cell first."
    )

dataset = (
    wds.WebDataset(url)
        .shuffle(100)
        .decode(wds.torch_audio)
        .to_tuple("transcript.txt", "audio.flac")
)
# The dataset already yields batches of 4, so batch_size=None turns off the
# DataLoader's own batching.
dataloader = DataLoader(dataset.batched(4), batch_size=None)

# Only show a few batches as a sanity check; printing the whole training set
# would flood the notebook output.
MAX_PREVIEW_BATCHES = 2
for batch_index, (transcripts, recordings) in enumerate(dataloader):
    if batch_index >= MAX_PREVIEW_BATCHES:
        break
    print("===========")
    print(transcripts)
    print(recordings)
    print("===========")