In [1]:
import operator
import time
from collections import Counter
from itertools import groupby
from operator import itemgetter

import mlx.data as dx
import numpy as np
import torch
from mlx.data.core import CharTrie
from mlx.data.datasets import librispeech
from mlx.data.datasets.librispeech import load_librispeech_tarfile
from pfl.data.dataset import Dataset
from pfl.data.federated_dataset import FederatedDataset
from pfl.data.pytorch import PyTorchTensorDataset
from pfl.data.sampling import MinimizeReuseUserSampler
from pfl.model.pytorch import PyTorchModel
from tqdm import tqdm

In [2]:
class AttributeDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``d.key`` is equivalent to ``d["key"]``. A missing key raises
    ``AttributeError`` on attribute access (not ``KeyError``), because
    ``hasattr``, ``getattr(obj, name, default)`` and the copy/pickle protocols
    probe for attributes and only treat ``AttributeError`` as "not present".
    """

    def __getattr__(self, name):
        # Only reached when regular attribute lookup has already failed.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    # Attribute assignment stores the value as a dictionary item.
    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

# Experiment options, collected in a single attribute-accessible dictionary.
args = AttributeDict(
    batch_size=2,
    batch_policy='dynamic',
    max_target_len=100,  # 400
    threads=10,
    input_key='file',
    target_key='transcription',
    target_nopad=False,
    target_nosil=False,
    tar=True,
    federated_cohort_size=10,
    local_num_epochs=1,
)

# Stand-alone settings kept outside of `args`.
random_seed, num_local_devices = 123, 1

In [3]:
def construct_eng_char_trie_for_ctc(additional_chars):
    """Build a ``CharTrie`` over a lower-case English character set.

    Symbols are inserted in this order: ``"@"`` (used as the blank), space,
    apostrophe, the letters ``a``-``z``, and finally every element of
    ``additional_chars``.

    Args:
        additional_chars: Iterable of extra symbols to insert after the
            built-in ones. A falsy value (e.g. ``''`` or ``None``) adds nothing.

    Returns:
        The populated ``CharTrie``.
    """
    symbols = ["@", " ", "'"]  # "@" is the blank symbol
    symbols.extend(chr(code) for code in range(ord("a"), ord("z") + 1))
    if additional_chars:
        symbols.extend(additional_chars)

    char_trie = CharTrie()
    for symbol in symbols:
        char_trie.insert(symbol)
    return char_trie


def load_librispeech_slices(root=None, split="dev-clean", quiet=False, validate_download=True, trie=None, args=None):
    """Load a librispeech split from the TAR archive as one dataset per user.

    Args:
        root (Path or str, optional): The directory to load/save the data. If
            none is given the ``~/.cache/mlx.data/librispeech`` is used.
        split (str): The split to use. It should be one of dev-clean,
            dev-other, test-clean, test-other, train-clean-100,
            train-clean-360, train-other-500 .
        quiet (bool): If true do not show download (and possibly decompression)
            progress.
        validate_download (bool): Forwarded to ``load_librispeech_tarfile``.
        trie: Character trie used to tokenize the lower-cased transcripts.
            Required.
        args: Additional arguments passed via flags. ``args.threads`` and
            ``args.max_target_len`` are read. Required.

    Returns:
        dict: Maps each user id (``int``) to a ``PyTorchTensorDataset`` built
        from that user's ``input``, ``target``, ``input_length`` and
        ``target_length`` arrays (in this order).

    Raises:
        ValueError: If ``trie`` or ``args`` is not provided.
    """
    # Fail fast, before the archive is fetched or scanned, instead of hitting
    # an opaque attribute error deep inside the pipeline.
    if trie is None:
        raise ValueError("load_librispeech_slices requires `trie` to tokenize the transcripts")
    if args is None:
        raise ValueError("load_librispeech_slices requires `args` (threads, max_target_len)")

    def _to_audio_and_transcript2(sample):
        # Each line is "<utterance id> <transcript>".
        file_part, transcript = bytes(sample["sample"]).split(b" ", 1)

        # The dash-separated fields of the id become the directories of the
        # audio path; the file name is the full id plus ".flac".
        parts = file_part.split(b"-")
        parts[-1] = file_part + b".flac"
        audio_path = b"/".join(parts)

        # Prepare the transcript
        transcript = transcript.lower()

        # User id: third field from the end of the id.
        user_id = int(parts[-3])

        return {"audio_file": audio_path, "transcript": transcript, "user_id": user_id}

    target = str(
        load_librispeech_tarfile(
            root=root, split=split, quiet=quiet, validate_download=validate_download
        )
    )
    prefix = f"LibriSpeech/{split}"

    start = time.perf_counter()

    # Stage 1: transcripts and user ids only. No audio is read here, so the
    # samples can be inspected cheaply to work out the per-user grouping.
    metadata = (
        dx.files_from_tar(target)
        .to_stream()
        .sample_transform(lambda s: s if bytes(s["file"]).endswith(b".txt") else dict())
        .read_from_tar(target, "file", "samples")
        .line_reader_from_key("samples", "sample", from_memory=True)
        .sample_transform(_to_audio_and_transcript2)
        .prefetch(args.threads, args.threads)
        .to_buffer()
    )

    # Sort samples by user id so that each user occupies one contiguous run.
    # The sort is stable so the order inside a user is deterministic.
    user_ids = np.array([int(x["user_id"]) for x in metadata], dtype=np.int64)
    order = np.argsort(user_ids, kind="stable")
    # np.unique returns ids in ascending order, matching the sorted buffer.
    _, samples_per_user = np.unique(user_ids, return_counts=True)

    # Stage 2: attach audio loading and tokenization to the sorted buffer and
    # batch with one (variable-sized) batch per user.
    dset = (
        metadata.perm(order)
        .read_from_tar(target, "audio_file", "audio", prefix=prefix)
        .load_audio("audio", from_memory=True, output_key="input")
        .pad_to_multiple("input", 0, 16000, 0, "input")
        .shape("input", "input_length", 0)
        .tokenize("transcript", trie, ignore_unk=True, output_key="target")
        .shape("target", "target_length", 0)
        .pad_to_size("target", 0, args.max_target_len, 0)
        .batch(samples_per_user.tolist())
    )

    end = time.perf_counter()
    print('preprocessing time:', end - start)
    start = time.perf_counter()

    slices = {}
    for item in dset:
        # All samples of a batch share the user id; store it as a plain int.
        user_id = int(item['user_id'][0])
        # NOTE(review): torch.Tensor(...) builds default-dtype (float) tensors,
        # so targets and lengths are stored as floats too - confirm that the
        # consumer casts them as needed.
        slices[user_id] = PyTorchTensorDataset(
            [torch.Tensor(item['input']),
             torch.Tensor(item['target']),
             torch.Tensor(item['input_length']),
             torch.Tensor(item['target_length'])],
            user_id=user_id)

    end = time.perf_counter()
    print('slice processing time:', end - start)

    return slices

In [4]:
print('args:', args)

# Time the complete load: trie construction plus building the per-user slices.
# perf_counter is monotonic, so the measurement is unaffected by clock changes.
start = time.perf_counter()

trie = construct_eng_char_trie_for_ctc('')
# `dset` is the dictionary returned by load_librispeech_slices (user id -> dataset).
dset = load_librispeech_slices(split='dev-clean', trie=trie, args=args)

end = time.perf_counter()

print('total time:', end - start)

args: {'batch_size': 2, 'batch_policy': 'dynamic', 'max_target_len': 100, 'threads': 10, 'input_key': 'file', 'target_key': 'transcription', 'target_nopad': False, 'target_nosil': False, 'tar': True, 'federated_cohort_size': 10, 'local_num_epochs': 1}
preprocessing time: 6.562301158905029
slice processing time: 3.944700002670288
total time: 10.510547876358032


In [5]:
# List the available user ids, then show the dataset object stored for user 1919.
print(list(dset))
print(dset[1919])

[84, 174, 251, 422, 652, 777, 1272, 1462, 1673, 1919, 1988, 1993, 2035, 2078, 2086, 2277, 2412, 2428, 2803, 2902, 3000, 3081, 3170, 3536, 3576, 3752, 3853, 5338, 5536, 5694, 5895, 6241, 6295, 6313, 6319, 6345, 7850, 7976, 8297, 8842]
<pfl.data.pytorch.PyTorchTensorDataset object at 0x177ec02e0>


In [6]:
# Walk over the data of user 1919 via `.iter(5)` and report the shape of every
# tensor in each batch.
for batch in dset[1919].iter(5):
    shapes = [tensor.shape for tensor in batch]
    print('batch tensor shapes:', shapes)

batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1]), torch.Size([5, 384]), torch.Size([5]), torch.Size([5])]
batch tensor shapes: [torch.Size([5, 432000, 1

In [7]:
# Wrap the per-user slices in a FederatedDataset; the sampler is given the
# list of all user ids.
federated_dataset = FederatedDataset.from_slices(
    dset,
    user_sampler=MinimizeReuseUserSampler(list(dset)),
)

In [8]:
# Request a cohort via get_cohort(5) and print the tensor shapes of each
# returned client dataset.
for client_dataset, _ in federated_dataset.get_cohort(5):
    tensors = client_dataset.raw_data.raw_data
    print([tensor.shape for tensor in tensors])

[torch.Size([65, 320000, 1]), torch.Size([65, 294]), torch.Size([65]), torch.Size([65])]
[torch.Size([59, 448000, 1]), torch.Size([59, 307]), torch.Size([59]), torch.Size([59])]
[torch.Size([75, 400000, 1]), torch.Size([75, 345]), torch.Size([75]), torch.Size([75])]
[torch.Size([36, 528000, 1]), torch.Size([36, 505]), torch.Size([36]), torch.Size([36])]
[torch.Size([71, 304000, 1]), torch.Size([71, 319]), torch.Size([71]), torch.Size([71])]


In [9]:
# Request a further cohort via get_cohort(5) and print the tensor shapes of
# each returned client dataset.
for client_dataset, _ in federated_dataset.get_cohort(5):
    tensors = client_dataset.raw_data.raw_data
    print([tensor.shape for tensor in tensors])

[torch.Size([82, 304000, 1]), torch.Size([82, 313]), torch.Size([82]), torch.Size([82])]
[torch.Size([73, 480000, 1]), torch.Size([73, 362]), torch.Size([73]), torch.Size([73])]
[torch.Size([94, 272000, 1]), torch.Size([94, 251]), torch.Size([94]), torch.Size([94])]
[torch.Size([42, 320000, 1]), torch.Size([42, 321]), torch.Size([42]), torch.Size([42])]
[torch.Size([64, 432000, 1]), torch.Size([64, 384]), torch.Size([64]), torch.Size([64])]
