In [1]:
from speechbrain.dataio.sampler import DynamicBatchSampler
import torch
import torchaudio
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []


In [2]:
import re
from g2p_en import G2p
g2p = G2p()
import numpy as np

PHONE_DEF = [
    'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B',  'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH'
]
PHONE_DEF_SIL = PHONE_DEF + ['SIL']

def phoneToId(p):
    return PHONE_DEF_SIL.index(p)

def convert_to_phonemes(transcript):
    
    thisTranscription = transcript.strip()
    thisTranscription = re.sub(r'[^a-zA-Z\- \']', '', thisTranscription)
    thisTranscription = thisTranscription.replace('--', '').lower()
    addInterWordSymbol = True

    phonemes = []
    
    for p in g2p(thisTranscription):
        if addInterWordSymbol and p==' ':
            phonemes.append('SIL')
        p = re.sub(r'[0-9]', '', p)  # Remove stress
        if re.match(r'[A-Z]+', p):  # Only keep phonemes
            phonemes.append(p)

    #add one SIL symbol at the end so there's one at the end of each word
    if addInterWordSymbol:
        phonemes.append('SIL')
        
    seqLen = len(phonemes)
    maxSeqLen = 500
    seqClassIDs = np.zeros([maxSeqLen]).astype(np.int32)
    seqClassIDs[0:seqLen] = [phoneToId(p) + 1 for p in phonemes]
    return seqClassIDs, len(phonemes)

# Custom collate function with padding
def collate_fn(batch):
    
    waveforms = [item[0].squeeze(0) for item in batch]  # Remove channel dimension
    lengths = torch.tensor([wav.shape[0] for wav in waveforms])
    lengths = lengths / max(lengths)
    
    # Pad sequences to match longest in batch
    padded_waveforms = torch.nn.utils.rnn.pad_sequence(
        waveforms, 
        batch_first=True
    )
    
    # Process batch in a single list comprehension (avoids intermediate lists)
    processed_batch = [convert_to_phonemes(item[2]) for item in batch]

    # Unpack using numpy for tensor conversion
    transcripts, transcript_lengths = zip(*processed_batch)

    # Convert using numpy stacking for better performance
    transcripts = torch.from_numpy(np.stack(transcripts))  # For multi-dimensional arrays
    transcript_lengths = torch.as_tensor(np.array(transcript_lengths), dtype=torch.long)
    
    return padded_waveforms, transcripts, lengths, transcript_lengths

In [3]:

import torchaudio
from torch.utils.data import DataLoader
import numpy as npx

# Dynamic Batch Sampler with Bucketing
class DynamicBatchSampler:
    def __init__(self, lengths, batch_size, shuffle=True, bucket_size=4000,
                  min_bucket_size=16):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.bucket_size = bucket_size

        # Create buckets based on length ranges
        self.buckets = {}
        for idx, length in enumerate(lengths):
            
            # // returns floor(length / bucket_size)
            # so all inputs from T:T+bucket_size in length
            # are placed in the same bucket. 
            bucket_id = length // bucket_size
            if bucket_id not in self.buckets:
                self.buckets[bucket_id] = []
            self.buckets[bucket_id].append(idx)
            
        # combine small buckets
        prev_bucket_id = None
        
        for bucket_id, vals in sorted(self.buckets.items()):
            
            # if bucket is too small
            if len(vals) < min_bucket_size:
                # check if previous bucket exists
                if prev_bucket_id is not None:
                    
                    # merge small bucket into previous big bucket
                    self.buckets[prev_bucket_id].extend(vals)
                   
                # if no valid previous bucket, move ids
                # to the next available bucket
                elif prev_bucket_id is None: 
                    self.buckets[bucket_id+1].extend(vals)
                    
                # delete small bucket
                del self.buckets[bucket_id]
                    
            # if bucket is big enough, mark it as the last
            # valid bucket
            else: 
                prev_bucket_id = bucket_id
                    
                    
                

    def __iter__(self):
        
        # shuffles inputs within a bucket
        if self.shuffle:
            np.random.shuffle(batches)
            for bucket in self.buckets.values():
                np.random.shuffle(bucket)

        # divides each bucket into batches
        batches = []
        for bucket in self.buckets.values():
            for i in range(0, len(bucket), self.batch_size):
                batches.append(bucket[i:i + self.batch_size])


        return iter(batches)

    def __len__(self):
        return sum(len(bucket) // self.batch_size for bucket in self.buckets.values())


    def print_bucket_sizes(self):
        print("Number of examples in each bucket:")
        for bucket_id, bucket in self.buckets.items():
            print(f"Bucket {bucket_id}: {len(bucket)} examples")
    #


batch_sampler = DynamicBatchSampler(
    lengths=np.load('/data/LLMs/librispeech/LibriSpeech/train-clean-100/lengths.npy'),
    batch_size=64,
    shuffle=True,
    bucket_size=8000,
)

trainLoader = DataLoader(
    trainDataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
    num_workers=4,
)


SyntaxError: positional argument follows keyword argument (3304482115.py, line 93)

In [54]:
batch_sampler.print_bucket_sizes()

Number of examples in each bucket:
Bucket 28: 2740 examples
Bucket 31: 3066 examples
Bucket 27: 2269 examples
Bucket 29: 3273 examples
Bucket 25: 1430 examples
Bucket 30: 3361 examples
Bucket 19: 424 examples
Bucket 20: 429 examples
Bucket 7: 353 examples
Bucket 23: 756 examples
Bucket 26: 1805 examples
Bucket 32: 1289 examples
Bucket 24: 1061 examples
Bucket 16: 354 examples
Bucket 33: 474 examples
Bucket 14: 326 examples
Bucket 22: 623 examples
Bucket 13: 312 examples
Bucket 18: 361 examples
Bucket 6: 330 examples
Bucket 9: 348 examples
Bucket 11: 293 examples
Bucket 15: 335 examples
Bucket 21: 546 examples
Bucket 5: 304 examples
Bucket 10: 327 examples
Bucket 17: 368 examples
Bucket 8: 322 examples
Bucket 4: 259 examples
Bucket 34: 53 examples
Bucket 12: 324 examples
Bucket 3: 24 examples
