In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp datasets

# datasets
> Routines for loading/handling datasets

Many of these routines are dupes or mods from "audio-diffusion" repo by Zach Evans w/ contributions by Scott Hawley https://github.com/zqevans/audio-diffusion/blob/main/diffusion/utils.py

In [None]:
#|hide
from nbdev.showdoc import *

In [None]:
#|export
from __future__ import annotations  # for type hints, in LAION code samples
import numpy as np 
import torch
import torch.nn as nn
import torchaudio
from torchaudio import transforms as T
from torchvision import transforms as VT
import random
import os
import json
import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial
from aeiou.core import load_audio, get_audio_filenames, is_silence
from fastcore.utils import *
import webdataset as wds
import subprocess

## Augmentation routines

Not all of these are used.  Code copied from https://github.com/zqevans/audio-diffusion/blob/main/diffusion/utils.py

In [None]:
#|export
class PadCrop(nn.Module):
    "Cut a fixed-length chunk out of a (channels, samples) signal, zero-padding when the signal is too short"
    def __init__(self, 
        n_samples,           # length of chunk to extract from longer signal
        randomize=True,      # draw cropped chunk from a random position in audio file
        redraw_silence=True, # a chunk containing silence will be replaced with a new one
        silence_thresh=-60,  # threshold in dB below which we declare to be silence
        max_redraws=2        # when redrawing silences, don't do it more than this many
        ):
        super().__init__()
        store_attr()     # fastcore: copies every __init__ argument onto self
    
    def draw_chunk(self, signal):
        "draw one chunk of `n_samples` from `signal`; the tail is zeros if `signal` is shorter than that"
        n_channels, n_avail = signal.shape
        if self.randomize:
            # latest start position that still leaves a full chunk (0 if the signal is short)
            last_start = max(0, n_avail - self.n_samples)
            begin = torch.randint(0, last_start + 1, []).item()
        else:
            begin = 0
        out = signal.new_zeros([n_channels, self.n_samples])
        n_copy = min(n_avail, self.n_samples)
        out[:, :n_copy] = signal[:, begin:begin + self.n_samples]
        return out
    
    def __call__(self, signal):
        "pipeline entry point: return a chunk, redrawing up to `max_redraws` times while it is silent"
        chunk = self.draw_chunk(signal)
        if not self.redraw_silence:
            return chunk
        redraws_done = 0
        while is_silence(chunk, thresh=self.silence_thresh) and (redraws_done < self.max_redraws):
            chunk = self.draw_chunk(signal)
            redraws_done += 1
        return chunk

In [None]:
#|export    
class PhaseFlipper(nn.Module):
    "she was PHAAAAAAA-AAAASE FLIPPER, a random invert yeah (negates the signal with probability `p`)"
    def __init__(self, 
        p=0.5  # probability that phase flip will be applied
        ):
        super().__init__()
        self.p = p
    def __call__(self, signal):
        "return `-signal` with probability `p`, otherwise the signal itself"
        flip = random.random() < self.p
        if flip:
            return -signal
        return signal

In [None]:
#|export  
class FillTheNoise(nn.Module):
    "with probability `p`, adds uniform noise of a random amplitude (at most 0.25) to the signal"
    def __init__(self, 
        p=0.33       # probability that noise will be added
        ):
        super().__init__()
        self.p = p
    def __call__(self, signal):
        "return the signal plus noise, or the untouched signal when the coin flip says no"
        if not (random.random() < self.p):
            return signal
        amplitude = 0.25*random.random()          # random noise level in [0, 0.25)
        noise = 2*torch.rand_like(signal) - 1     # uniform values in [-1, 1)
        return signal + amplitude*noise

In [None]:
#|export    
class RandPool(nn.Module):
    "with probability `p`, smooths the signal with a 1-D average pool whose kernel size is random"
    def __init__(self, 
        p=0.2   # probability that pooling will be applied
        ):
        # Without this call nn.Module's internal registries (_modules, _parameters, ...) are never
        # created, and things like .eval()/.to() on an enclosing nn.Sequential fail.
        super().__init__()
        self.p, self.maxkern = p, 100   # maxkern: exclusive upper bound on the random kernel size
    def __call__(self, signal):
        """Return the pooled signal, or `signal` unchanged when pooling is not drawn.

        With stride 1 and padding 1 the pooled output has `L + 3 - ksize` samples along the
        last axis (L = input length), so its length generally differs from the input's.
        """
        if random.random() < self.p:
            # AvgPool1d requires padding <= kernel_size/2, so with padding=1 the kernel
            # must be at least 2; sizes 0 and 1 used to raise at call time.
            ksize = max(2, int(random.random()*self.maxkern))
            avger = nn.AvgPool1d(kernel_size=ksize, stride=1, padding=1)
            return avger(signal)
        return signal

In [None]:
#|export
class NormInputs(nn.Module):
    "Normalize inputs to [-1,1]. Useful for quiet inputs"
    def __init__(self, 
        do_norm=True    # controllable parameter for turning normalization on/off
        ):
        super().__init__()
        self.do_norm = do_norm
        self.eps = 1e-2   # added to the peak so an all-zero signal does not divide by zero
    def __call__(self, signal):
        "divide the whole signal by its absolute peak (+eps); returns `signal` untouched if `do_norm` is False"
        if not self.do_norm:
            return signal
        # Use the peak of |signal| over every element. The previous `torch.amax(signal,-1)[0]`
        # took the signed max of the first channel only, so negative peaks or louder later
        # channels could leave the result outside [-1,1]; it also failed on 1-D input.
        peak = signal.abs().amax()
        return signal/(peak + self.eps)

In [None]:
#|export    
class Mono(nn.Module):
    "convert audio to mono by averaging over the first (channel) axis"
    def __call__(self, signal):
        "signals with one dimension or fewer are returned as-is; otherwise the mean over dim 0"
        if signal.dim() <= 1:
            return signal
        return signal.mean(dim=0)

In [None]:
#|export
class Stereo(nn.Module):
    "convert audio to stereo: duplicate a mono channel, or keep only the first two of many channels"
    def __call__(self, signal):
        "returns a (2, s) tensor for 1-D and (c, s) inputs with c >= 1; anything else passes through unchanged"
        n_dims = signal.dim()
        if n_dims == 1:                       # s -> 2, s
            return signal.unsqueeze(0).repeat(2, 1)
        if n_dims == 2:
            n_channels = signal.shape[0]
            if n_channels == 1:               # 1, s -> 2, s
                return signal.repeat(2, 1)
            if n_channels > 2:                # ?, s -> 2, s
                return signal[:2, :]
        return signal

In [None]:
#|export    
class RandomGain(nn.Module):
    "scale audio by a gain factor drawn uniformly from [min_gain, max_gain]"
    def __init__(self, min_gain, max_gain):
        super().__init__()
        self.min_gain, self.max_gain = min_gain, max_gain

    def __call__(self, signal):
        "draw a fresh gain on every call and return the scaled signal"
        return signal * random.uniform(self.min_gain, self.max_gain)

## WebDataset support


### Background Info
cf. https://github.com/webdataset/webdataset

> WebDataset makes it easy to write I/O pipelines for large datasets. Datasets can be stored locally or in the cloud.

They use the word "shards" but never define what "shard" means.  I (S.H.) surmise they mean the groups of data files which are gathered into a series of `.tar` files -- the `.tar` files are the shards? 

cf. Video Tutorial: ["Loading Training Data with WebDataset"](https://www.youtube.com/watch?v=mTv_ePYeBhs).

The recommended usage for AWS S3 can be seen in [this GitHub Issue comment by tmbdev](https://github.com/webdataset/webdataset/issues/21#issuecomment-706008342):

```Python
url = "pipe:s3cmd get s3://bucket/dataset-{000000..000999}.tar -"
dataset = wds.Dataset(url)...
```
That URL is expecting a contiguously-numbered range of .tar files. So if the file numbers are contiguous (no gaps), then we'll have an easy time. 

### General utility: `get_s3_contents()`

In [None]:
#|export
def get_s3_contents(dataset_path, s3_url_prefix='s3://s-laion-audio/webdataset_tar', filter=''):
    "Lists the file/subdirectory names under `{s3_url_prefix}/{dataset_path}/`, keeping only names that contain `filter`"
    # `aws s3 ls` prints one entry per line; the name is the last whitespace-separated field
    listing = subprocess.run(['aws','s3','ls',f'{s3_url_prefix}/{dataset_path}/'], capture_output=True)
    last_fields = subprocess.run(['awk','{print $NF}'],input=listing.stdout, capture_output=True)
    # drop the trailing '/' that marks sub-directories
    text = last_fields.stdout.decode('utf-8').strip().replace('/','')
    names = []
    for entry in text.split('\n'):
        if entry and filter in entry:   # skip blank lines, then apply the substring filter
            names.append(entry)
    return names

Let's test that on the FSD50K dataset:

In [None]:
#| eval: false
# List the top-level entries of the FSD50K dataset under the default S3 prefix (needs AWS access)
get_s3_contents('FSD50K')

['test', 'train', 'valid']

In [None]:
#| eval: false
# List everything inside the 'test' subdirectory; no filter, so non-tar entries are included
get_s3_contents('FSD50K/test')

['0.tar',
 '1.tar',
 '10.tar',
 '11.tar',
 '12.tar',
 '13.tar',
 '14.tar',
 '15.tar',
 '16.tar',
 '17.tar',
 '18.tar',
 '19.tar',
 '2.tar',
 '3.tar',
 '4.tar',
 '5.tar',
 '6.tar',
 '7.tar',
 '8.tar',
 '9.tar',
 'sizes.json']

And let's try filtering for only tar files: 

In [None]:
#| eval: false
# `filter` is a plain substring match, so 'tar' keeps only names containing "tar"
tar_names = get_s3_contents('FSD50K/test', filter='tar')
tar_names

['0.tar',
 '1.tar',
 '10.tar',
 '11.tar',
 '12.tar',
 '13.tar',
 '14.tar',
 '15.tar',
 '16.tar',
 '17.tar',
 '18.tar',
 '19.tar',
 '2.tar',
 '3.tar',
 '4.tar',
 '5.tar',
 '6.tar',
 '7.tar',
 '8.tar',
 '9.tar']

### For contiguous file-number lists...

Maybe the range of tar numbers is contiguous. If so, let's have something to output that range:

In [None]:
#|export 
def get_contiguous_range(
    tar_names, # list of tar file names, although the .tar part is actually optional
    ):
    """Given a list of tar file names, return a brace-range string such as '{0..19}' if their
    numbers are contiguous. Otherwise return empty string.

    An empty list gives ''. A single name gives just its number (without '.tar'), so the
    result can always be used as `f"{result}.tar"`. Leading zeros are preserved as given.
    Names whose stem is not an integer raise ValueError.
    """
    # str() lets callers pass bare ints; stripping '.tar' up front keeps the
    # single-element case consistent with the range case.
    just_nums = [str(x).replace('.tar','') for x in tar_names]
    if len(just_nums) == 0: return ''
    if len(just_nums) == 1: return just_nums[0]
    just_nums.sort(key=int) # sorts numerically but meaningfully preserves leading zeros in strings
    nums_arr = np.array([int(x) for x in just_nums])
    # Every step must be exactly +1; checking only the largest step would accept duplicates.
    is_contiguous = bool(np.all(np.diff(nums_arr) == 1))
    if is_contiguous:   # {000000..000999}
        return '{' + f'{just_nums[0]}..{just_nums[-1]}' +'}'
    else:
        print("get_contiguous_range: File numbers not continuous")  # have to do more work
        return '' # empty string will signify no dice; signal for more work to be done

In [None]:
#| eval: false
# `tar_names` comes from the filtered S3 listing in an earlier cell
cont_range = get_contiguous_range(tar_names)
cont_range

'{0..19}'

Test if leading zeros are preserved:

In [None]:
#| eval: false
# Prefix every name with zeros to check that the padding survives in the returned range
get_contiguous_range(['0000'+x for x in tar_names])  

'{00000..000019}'

Test zero-element and single element versions:

In [None]:
# Edge cases: an empty list, then a one-element list (given here as a bare int rather than a file name)
print(get_contiguous_range([]))
print(get_contiguous_range([1]))


1


And show that '.tar' is optional:

In [None]:
# Names without the '.tar' suffix (and with mixed zero-padding) are accepted too
get_contiguous_range(['01','02','3']) 

'{01..3}'

....So, if a contiguous range of tar file names is available in a WebDataset directory, then we can just use the native WebDataset creation utilities and can ignore all the other %$#*& that's about to follow below. 

Let's test the simple version first:

In [None]:
#| eval: false
# Build a single "pipe:" URL that streams each shard through the AWS CLI; `cont_range` is the
# brace range computed in an earlier cell. Note this prefix ends with '/', unlike the
# default prefix of get_s3_contents.
s3_url_prefix='s3://s-laion-audio/webdataset_tar/'
url = f"pipe:aws s3 cp {s3_url_prefix}FSD50K/test/{cont_range}.tar -"  # 'aws get' is not a thing. 'aws cp' is
print(url)
dataset = wds.WebDataset(url)

pipe:aws s3 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/{0..19}.tar -


Hooray, it didn't crash! 

Try dataloader-ing that:

In [None]:
#| eval: false

## NOTE TO SELF: DON'T RUN THIS ON STABILITY CLUSTER HEADNODE
# The condition compares two different string literals, so it is always False and the body
# below never executes; it is kept only as a sketch of the intended dataloader usage.
if 'this next part fails' == 'darn it':
    loader = wds.WebLoader(dataset, num_workers=4, batch_size=8)
    #loader = loader.batched(12)
    batch = next(iter(loader))
    batch[0].shape, batch[1].shape

### Non-contiguously-numbered lists of tar files...
Because you could do a test-train-val split by moving the tar files around.
this is what all the extra code is for.

A lot of the code predating this was written by LAION who require that the `.json` file(s) for the webdataset(s) be downloaded first. So, let's write a utility for that: 

In [None]:
#|export
def download_webdataset_json(
    datasetnames,              # list of names of valid AudioDataset datasets / paths
    dataset_split={},          # keys are dataset names, values are lists of subdirs
    src_prefix='s3://s-laion-audio/webdataset_tar', # parent location where the dataset lives
    dst_prefix='./json_files', # local path to save the json
    force=False,            # Force new download even if local copy exists
    ):
    """Downloads the json info of webdataset (sub-)file sizes

    For every dataset and split, copies `{src_prefix}/{dataset}/{split}/sizes.json` to
    `{dst_prefix}/{dataset}/{split}/sizes.json` with the AWS CLI, skipping files that
    already exist locally unless `force` is set. Splits come from `dataset_split` when it
    has an entry for the dataset; otherwise they are discovered by listing the dataset on S3.
    Download failures are reported with a message and do not raise.
    """
    for dataset_name in datasetnames:
        # `dataset_split` is documented as {dataset name: [subdirs]}; a plain list of
        # subdirs is also accepted and applied to every dataset. (The default {} is never mutated.)
        if isinstance(dataset_split, dict):
            splits = dataset_split.get(dataset_name)
        else:
            splits = dataset_split
        if not splits:
            # list the same bucket/prefix we are about to download from
            splits = get_s3_contents(dataset_name, s3_url_prefix=src_prefix)
        for split in splits:
            dst_dir = f"{dst_prefix}/{dataset_name}/{split}"
            os.makedirs(dst_dir, exist_ok=True)  # make sure local dir to hold json exists
            dst = f"{dst_dir}/sizes.json"
            if force or not os.path.exists(dst):
                src = f"{src_prefix}/{dataset_name}/{split}/sizes.json"
                # argument list (no shell), so odd characters in names cannot be interpreted as shell syntax
                try:
                    result = subprocess.run(['aws', 's3', 'cp', src, dst])
                except FileNotFoundError:
                    print(f"download_webdataset_json: 'aws' CLI not found; could not download {src}")
                    continue
                if result.returncode != 0:
                    print(f"download_webdataset_json: download failed for {src}")
            #else: print("Already got it")

Test `download_webdataset_json`:

In [None]:
#| eval: false

from types import SimpleNamespace
# Stand-in for the command-line args object the LAION code expects.
# Only `args.datasetnames` is actually used in this cell.
args = SimpleNamespace(remotedata=True, datasetnames=['FSD50K'],
                       dataset_type="webdataset",
                       dataset_proportion=1, datasetpath='IDK')
download_webdataset_json(args.datasetnames, force=True)

download: s3://s-laion-audio/webdataset_tar/FSD50K/test/sizes.json to json_files/FSD50K/test/sizes.json
download: s3://s-laion-audio/webdataset_tar/FSD50K/train/sizes.json to json_files/FSD50K/train/sizes.json
download: s3://s-laion-audio/webdataset_tar/FSD50K/valid/sizes.json to json_files/FSD50K/valid/sizes.json


For non-contiguous files, we need a list of urls to every single tar file individually.  That's what this next code from LAION's CLAP repo does:

In [None]:
def get_tar_path_s3(base_s3_path:str, 
    train_valid_test:list[str], 
    dataset_names:list[str]=[''], 
    cache_path:str='', 
    recache:bool=False,
    ):
    """Code from LAION CLAP, may not keep. Builds, for each split name in `train_valid_test`,
    the list of `pipe:aws s3 ... cp` URLs for the tar files found under `base_s3_path`.

    A JSON file at `cache_path` is returned directly when it exists (unless `recache`);
    otherwise the result is computed from `aws s3 ls` and written to `cache_path` if one is given.
    """
    use_cache = os.path.isfile(cache_path) and not recache
    if use_cache:
        with open(cache_path) as f:
            print("Loading Cache")
            return json.load(f)

    # one recursive listing per dataset; an empty dataset name searches the whole base path
    listings = []
    for name in dataset_names:
        ls_cmd = f'aws s3 ls s3://{os.path.join(base_s3_path, name, "")} --recursive | grep /.*.tar'
        listings.append(os.popen(ls_cmd).read())

    # the object key is the last space-separated field of each listing line
    keys = []
    for listing in listings:
        for line in listing.split('\n'):
            keys.append(line.split(' ')[-1])

    # rebuild each key on top of base_s3_path (dropping the key's first path component)
    # and wrap it in the pipe: form that webdataset understands
    all_urls = []
    for key in keys:
        s3_path = os.path.join(base_s3_path, *key.split("/")[1:])
        all_urls.append(f'pipe:aws s3 --cli-connect-timeout 0 cp s3://{s3_path} -')

    # group by split: a URL belongs to a split when the split name occurs anywhere in it
    urls_by_split = {}
    for state in train_valid_test:
        urls_by_split[state] = [u for u in all_urls if state in u]

    if cache_path:
        with open(cache_path, 'w') as f:
            json.dump(urls_by_split, f)

    return urls_by_split

Let's grab every tar file in the entire FSD50K dataset:

In [None]:
#| eval: false
# Only the 'test' and 'valid' splits are requested here, so 'train' shards are not in the result
urls = get_tar_path_s3('s-laion-audio/webdataset_tar',['test', 'valid'], dataset_names=['FSD50K'])
print("urls =",urls)

urls = {'test': ['pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/0.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/1.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/10.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/11.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/12.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/13.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/14.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/15.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/16.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/17.tar -', 'pipe:aws s3 --

Another version that achieves the same effect: 

In [None]:
def get_tar_path_from_dataset_name(
    dataset_names, dataset_types, islocal,  dataset_path, proportion=1, ):
    """
    From LAION
    Get tar paths from dataset names and types (splits).

    For each (name, split) pair, reads `./json_files/{name}/{split}/sizes.json` and turns each
    of its keys into a path: `{dataset_path}/{name}/{split}/{key}` when `islocal`, otherwise a
    `pipe:aws s3 ... cp` URL on the LAION bucket (`dataset_path` is not used for remote data).
    Pairs without a sizes.json are skipped. When `proportion` != 1, a random sample of
    `int(proportion * count)` paths is kept per pair. Returns one flat list.
    """
    output = []
    for n in dataset_names:
        for s in dataset_types:
            sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" #  TODO:!!! location is hard-coded
            if not os.path.exists(sizefilepath_):
                continue
            with open(sizefilepath_, "r") as f:   # context manager so the file handle is closed
                sizes = json.load(f)
            if islocal:
                tmp = [f"{dataset_path}/{n}/{s}/{k}" for k in sizes.keys()]
            else:
                # TODO: add dataset_path to remote dataset in the future.
                tmp = [
                    f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
                    for k in sizes.keys()
                ]
            if proportion!=1:
                tmp = random.sample(tmp, int(proportion * len(tmp)))
            output.append(tmp)
            if not islocal:
                # NOTE(review): progress/debug print kept from the original remote branch
                print("output= ",output)
    # flatten the per-(name, split) lists into one list
    return [path for group in output for path in group]


Test ^that:

In [None]:
#| eval: false

# Relies on the sizes.json files under ./json_files, which the download_webdataset_json test cell fetches.
# With islocal=False the `dataset_path` argument does not affect the returned URLs.
train_data_tar_path = get_tar_path_from_dataset_name(
    ['FSD50K'],
    ['test','valid'],
    islocal=False,
    proportion=1.0,
    dataset_path='/fsx/shawley/data/webdataset',
)
train_data_tar_path

output=  [['pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/0.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/1.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/2.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/3.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/4.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/5.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/6.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/7.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/8.tar -', 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/9.tar -', 'pipe:aws s3 --cli-connect-ti

['pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/0.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/1.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/2.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/3.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/4.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/5.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/6.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/7.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/8.tar -',
 'pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/FSD50K/test/9.tar -',
 'pipe:aws s3 --cli-connect-ti

And now a massive data-pipelining example from LAION that will definitely get modified for this repo: 

In [None]:
#|export

# taken from LAION CLAP repo, https://github.com/LAION-AI/CLAP/blob/d2d5dae8ea8f1ee02ac40242418a36d1d567943a/src/training/data.py

from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from dataclasses import dataclass

@dataclass
class DataInfo:
    "Bundles a dataloader with its (optional) sampler; `get_wds_dataset` returns one with `sampler=None`"
    dataloader: DataLoader        # the loader to iterate over
    sampler: DistributedSampler   # sampler associated with the loader
        

def get_wds_dataset(
    args,
    model_cfg,
    is_train,
    audio_ext="flac",
    text_ext="json",
    max_len=480000,
    proportion=1.0,
    sizefilepath_=None,
    is_local=None,
):
    """
    Get a dataset for wdsdataloader.

    Builds a webdataset pipeline (shard list -> shuffle/split -> samples -> preprocess ->
    batches), wraps it in a `wds.WebLoader`, and returns `DataInfo(dataloader, None)`.
    The loader gets `num_batches` and `num_samples` attributes attached for convenience.

    `args` is read for: remotedata, train_data, val_data, train_num_samples,
    val_num_samples, parallel_eval, seed, class_index_dict, data_filling, batch_size,
    world_size, workers, horovod. `model_cfg` is accepted but not used in this function.
    `audio_ext`, `text_ext` and `max_len` are forwarded to `preprocess`.

    Raises RuntimeError when training and the number of samples can be determined
    neither from the dataset nor from `args.train_num_samples`.

    NOTE(review): this function also references `sample_prop`, `get_dataset_size`,
    `log_and_continue`, `preprocess`, `collate_fn` and the `_SHARD_SHUFFLE_*` /
    `_SAMPLE_SHUFFLE_*` constants, none of which are defined in the visible part of
    this notebook -- they presumably need to be ported from the LAION CLAP repo
    before this can run; confirm.
    """
    # `copy` and `math` are used below but are not imported at the top of this notebook.
    import copy
    import math

    if is_local is None and args.remotedata is not None:
        is_local = not args.remotedata

    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None

    if sizefilepath_ is not None:
        sizefilepath = sizefilepath_
    else:
        # default: a sizes.json sitting next to the first shard
        sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")

    if proportion != 1.0:
        num_samples, num_shards, input_shards, _ = sample_prop(
            sizefilepath, input_shards, proportion, is_local=is_local
        )
    else:
        # NOTE(review): this passes the raw `sizefilepath_` argument (possibly None), not the
        # `sizefilepath` computed above -- kept as in the original; confirm it is intended.
        num_samples, num_shards = get_dataset_size(
            input_shards, sizefilepath_=sizefilepath_, is_local=is_local
        )

    if not num_samples:
        if is_train:
            num_samples = args.train_num_samples
            if not num_samples:
                raise RuntimeError(
                    "Currently, number of dataset samples must be specified for training dataset. "
                    "Please specify via `--train-num-samples` if no dataset length info present."
                )
        else:
            num_samples = (
                args.val_num_samples or 0
            )  # eval will just exhaust the iterator if not specified

    pipeline = [wds.SimpleShardList(input_shards)]    # re. Pipeline: cf https://github.com/webdataset/webdataset#pipeline-interface
    # at this point we have an iterator over all the shards
    if is_train or args.parallel_eval:
        pipeline.extend(
            [
                wds.detshuffle(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                ),
                wds.split_by_node,
                wds.split_by_worker,
                # at this point, we have an iterator over the shards assigned to each worker at each node
                wds.tarfile_to_samples(handler=log_and_continue),
                wds.shuffle(
                    bufsize=_SAMPLE_SHUFFLE_SIZE,
                    initial=_SAMPLE_SHUFFLE_INITIAL,
                    rng=random.Random(args.seed),
                ),
                # wds.repeatedly,  # FIXME determine if this is beneficial
            ]
        )
    else:
        pipeline.extend(
            [
                wds.split_by_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
            ]
        )
    pipeline.append(
        wds.map(
            partial(
                preprocess,
                audio_ext=audio_ext,
                text_ext=text_ext,
                max_len=max_len,
                class_index_dict=copy.deepcopy(args.class_index_dict),
                data_filling=args.data_filling,
            )
        ),
    )

    pipeline.append(
        wds.batched(
            args.batch_size,
            partial=not (is_train or args.parallel_eval),
            collation_fn=collate_fn,
        )
    )

    dataset = wds.DataPipeline(*pipeline) # Instantiate list as Pipeline
    
    if is_train or args.parallel_eval:
        # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
        # roll over and repeat a few samples to get same number of full batches on each node
        global_batch_size = args.batch_size * args.world_size
        num_batches = math.ceil(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = math.ceil(
            num_batches / num_workers
        )  # per dataloader worker
        # round the totals up so they are consistent with the per-worker batch count
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(
            num_worker_batches
        )  # each worker is iterating over this
    else:
        # last batches are partial, eval is done on single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    kwargs = {}
    if args.horovod:  # multi-node training on summit
        kwargs["multiprocessing_context"] = "forkserver"

    # batching already happened inside the pipeline, hence batch_size=None here
    dataloader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
    )

    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader, None)


.....yeah no tests for that yet 

In [None]:
#| eval: false

# LAION cut n paste
# The condition compares two different string literals, so it is always False and the call
# below never runs. NOTE(review): `model_cfg` and `is_train` are not defined anywhere in the
# visible part of this notebook, so the call would need them supplied before being enabled.
if 'doesnt work yet' == 'stand by':
    train_data = get_wds_dataset(    args,
        model_cfg,
        is_train,
        audio_ext="flac",
        text_ext="json",
        max_len=480000,
        proportion=1.0,
        sizefilepath_=None,
        is_local=False,)

# AudioDataset class

The flagship class!

In [None]:
#|export
class AudioDataset(torch.utils.data.Dataset):
    """
    Reads from a tree of directories and serves up cropped bits from any and all audio files
    found therein. For efficiency, best if you "chunk" these files via chunkadelic
    modified from https://github.com/drscotthawley/audio-diffusion/blob/main/dataset/dataset.py

    Each item is one chunk of audio (after PadCrop and the `augs` transforms), clamped to [-1, 1].
    Files that fail to load, and (optionally) silent chunks, are replaced by a chunk from
    another randomly chosen file.
    """
    def __init__(self, 
        paths,             # list of strings of directory (/tree) names to draw audio files from
        sample_rate=48000, # audio sample rate in Hz
        sample_size=65536, # how many audio samples in each "chunk"
        random_crop=True,  # take chunks from random positions within files
        load_frac=1.0,     # fraction of total dataset to load
        cache_training_data=False,  # True = pre-load whole dataset into memory (not fully supported)
        num_gpus=8,        # used only when `cache_training_data=True`, to avoid duplicates,
        redraw_silence=True, # a chunk containing silence will be replaced with a new one
        silence_thresh=-60,  # threshold in dB below which we declare to be silence
        max_redraws=2,        # when redrawing silences, don't do it more than this many
        augs='Stereo(), PhaseFlipper()', # list of augmentation transforms **after PadCrop**, as a string
        verbose=False,       # whether to print notices of reasampling or not
        ):
        super().__init__()
    
        print("augs =",augs)
        # base_augs are always applied
        base_augs = 'PadCrop(sample_size, randomize=random_crop, redraw_silence=redraw_silence, silence_thresh=silence_thresh, max_redraws=max_redraws)'
        # NOTE(review): eval() executes whatever Python the `augs` string contains, and resolves
        # the names in `base_augs` from this function's arguments. Only pass trusted strings.
        self.augs = eval(f'torch.nn.Sequential( {base_augs}, {augs} )')  
        self.silence_thresh = silence_thresh
        self.redraw_silence = redraw_silence
        self.max_redraws = max_redraws
        self.sr = sample_rate
        self.cache_training_data = cache_training_data
        self.num_gpus = num_gpus   # read by get_data_range(); was previously never stored
        self.verbose = verbose

        self.filenames = get_audio_filenames(paths)
        print(f"AudioDataset:{len(self.filenames)} files found.")
        self.n_files = int(len(self.filenames)*load_frac)
        self.filenames = self.filenames[0:self.n_files]
        if cache_training_data: self.preload_files()

        self.convert_tensor = VT.ToTensor()   # not used inside this class; kept for existing callers

    def load_file_ind(self, file_list,i): # used when caching training data
        "load file number `i` of `file_list` at the dataset's sample rate, moved to CPU"
        return load_audio(file_list[i], sr=self.sr, verbose=self.verbose).cpu()

    def get_data_range(self): # for parallel runs, only grab part of the data -- OBVIATED BY CHUNKING.
        "return the (start, stop) slice of `self.filenames` that this process should cache"
        n_total = len(self.filenames)
        try:
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
        except KeyError: # we're on GPU 0 and the others haven't been initialized yet
            return 0, n_total//self.num_gpus
        interval = n_total//world_size
        return local_rank*interval, (local_rank+1)*interval

    def preload_files(self):
        "load this process's share of the files into `self.audio_files` using a process pool"
        print(f"Caching {self.n_files} input audio files:")
        wrapper = partial(self.load_file_ind, self.filenames)
        start, stop = self.get_data_range()
        # NOTE(review): audio_files[j] holds file number start+j, while get_next_chunk indexes it
        # with the dataset index directly -- part of why caching is "not fully supported".
        with Pool(processes=cpu_count()) as p:   # //8 to avoid FS bottleneck and/or too many processes (b/c * num_gpus)
            self.audio_files = list(tqdm.tqdm(p.imap(wrapper, range(start,stop)), total=stop-start))

    def __len__(self):
        return len(self.filenames)
    
    
    def get_next_chunk(self, 
        idx     # the index of the file within the list of files
        ):
        "The heart of this whole dataset routine: load (or fetch from cache), augment, clamp. Returns None on any error"
        audio_filename = self.filenames[idx]
        try:
            if self.cache_training_data:
                audio = self.audio_files[idx] # .copy()
            else:
                audio = load_audio(audio_filename, sr=self.sr, verbose=self.verbose)

            #Run augmentations on this sample (including random crop)
            if self.augs is not None:
                audio = self.augs(audio)
                
            audio = audio.clamp(-1, 1)
            return audio
        
        except Exception as e:
            # deliberate best-effort: report the bad file and let __getitem__ pick another one
            print(f'Error loading file {audio_filename}: {e}')
            return None
        
        
        
    def __getitem__(self, 
        idx     # the index of the file within the list of files
        ):
        "return one chunk; falls back to random other files when loading fails or the chunk is silent"
        audio = self.get_next_chunk(idx)
                
        # even with PadCrop set to reject silences, it could be that the whole file is silence; 
        num_redraws = 0 
        # A failed load (None) is always retried; silence is retried at most max_redraws times.
        # NOTE(review): if every file fails to load, this loop never terminates.
        while (audio is None) or (self.redraw_silence and is_silence(audio, thresh=self.silence_thresh) \
            and (num_redraws < self.max_redraws)):
            #print(f"AudioDataset.__getitem__: Got None or silence (torch.max = {torch.max(audio)})  Redrawing. Attempt {num_redraws+1} of {self.max_redraws}")
            next_idx = random.randint(0,len(self.filenames)-1)     # pick some other file at random
            audio, num_redraws = self.get_next_chunk(next_idx), num_redraws+1
               
        # the loop can only exit with a loaded chunk, so no further None handling is needed
        return audio

Quick check to catch minor errors:

In [None]:
# Smoke test: build the dataset over the bundled example files and pull the first item,
# once with the full augmentation chain and once with the stereo-only chain.
dataset = AudioDataset('examples/', augs='Stereo(), PhaseFlipper(), FillTheNoise(), NormInputs()')
signal = dataset[0]
print("signal.shape =",signal.shape)

print("\nStereo -------------")
dataset2 = AudioDataset('examples/', augs='Stereo(), PhaseFlipper()')
signal2 = dataset2[0]
print("signal2.shape =",signal2.shape)

In [None]:
#| hide
from nbdev import nbdev_export
# Write the cells marked for export out to the library module
nbdev_export()