In [None]:
def _write_segment_wav(segment_path, samples, sample_rate):
    '''Write one audio segment to ``segment_path`` as a wav file, creating its parent directory if needed.'''
    # soundfile is imported at function scope (as the original code did) so the
    # dependency is only required once a segment is actually written.
    import soundfile as sf

    # Create the directory of the *segment path*, not just the output root, so
    # that filenames carrying a sub-directory component can be written too.
    parent_dir = os.path.dirname(segment_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    sf.write(segment_path, samples, sample_rate)


def load_segments(df, source_dir, segments_dir, cap_per_class=100, sr=32000, segment_sec=5.0,
                 threshold_factor=0.5, hoplen=512, noise_reduce=False):
    '''
    Extracts fixed-length audio segments from source files and saves them as .wav files.

    All readable files are loaded first. Segments are then taken window by window:
    window 0 of every file is processed before window 1 of any file, and so on.
    A file stops being processed once its class has reached ``cap_per_class`` or
    once it has no full window left. A window whose mean RMS is below the file's
    threshold is skipped (not saved, not counted).

    Args:
        df (pd.DataFrame): DataFrame containing 'filename' and 'class_id' columns.
        source_dir (str): Directory where the original audio files are located.
        segments_dir (str): Directory to save extracted audio segments as .wav files.
        cap_per_class (int): Maximum number of segments per class_id. Defaults to 100.
        sr (int): Target sampling rate. Defaults to 32000.
        segment_sec (float): Duration of each segment in seconds. Defaults to 5.0.
        threshold_factor (float): RMS threshold factor, forwarded to
            ``util.get_rmsThreshold``. Defaults to 0.5.
        hoplen (int): Hop length for RMS calculation (both the per-file threshold
            and the per-segment check). Defaults to 512.
        noise_reduce (bool): Whether to apply noise reduction. Defaults to False.

    Returns:
        pd.DataFrame: DataFrame with 'filename' and 'class_id' for all saved segments.
            The two columns are present even when no segment was saved.

    Raises:
        ValueError: If ``sr * segment_sec`` does not yield at least one sample.
    '''

    samples_per_segment = int(sr * segment_sec)
    if samples_per_segment <= 0:
        # Without this guard a zero-length segment surfaces as a misleading
        # "Error loading ..." message for every single file.
        raise ValueError(
            f"sr * segment_sec must yield at least one sample per segment "
            f"(got sr={sr}, segment_sec={segment_sec})"
        )

    # Frame length shared by the threshold computation and the per-segment RMS
    # check, so that both values are computed with the same framing.
    rms_frame_len = 2048

    # Initialize tracking structures
    active_audios = []  # One dict per loaded file: audio_data, class_id, filename, max_segments, threshold, sr
    class_counts = {}   # Track segments per class_id
    segment_records = []  # Final output records

    # Load all audio files and calculate thresholds
    print("Loading audio files and calculating thresholds...")
    for _, row in df.iterrows():
        filename = row['filename']
        class_id = row['class_id']
        audio_path = os.path.join(source_dir, filename)

        try:
            y, srate = util.lbrs_loading(audio_path, sr=sr, mono=True)
            threshold = util.get_rmsThreshold(y, frame_len=rms_frame_len, hop_len=hoplen,
                                              thresh_factor=threshold_factor)
            max_segments = len(y) // samples_per_segment
        except Exception as e:
            # Best-effort batch loading: report the file and keep going.
            print(f"Error loading {filename}: {e}")
            continue

        if max_segments > 0:  # Only add if audio has at least one full segment
            active_audios.append({
                'audio_data': y,
                'class_id': class_id,
                'filename': filename,
                'max_segments': max_segments,
                'threshold': threshold,
                'sr': srate,
            })
            class_counts.setdefault(class_id, 0)

    print(f"Loaded {len(active_audios)} audio files successfully.")

    # Cycle through segments
    segment_index = 0
    segments_saved = 0

    while active_audios:
        print(f"Processing segment window {segment_index} ({segment_index * segment_sec}s - {(segment_index + 1) * segment_sec}s)")

        audios_to_remove = []
        start_sample = segment_index * samples_per_segment
        end_sample = start_sample + samples_per_segment

        for i, audio_info in enumerate(active_audios):
            class_id = audio_info['class_id']

            # Check if this class has reached its cap
            if class_counts[class_id] >= cap_per_class:
                audios_to_remove.append(i)
                continue

            # Check if this audio has enough data for current segment
            if segment_index >= audio_info['max_segments']:
                audios_to_remove.append(i)
                continue

            # Extract segment
            segment = audio_info['audio_data'][start_sample:end_sample]

            # Check RMS threshold, using the same frame/hop sizes that were
            # passed to the threshold computation.
            seg_rms = np.mean(
                lbrs.feature.rms(y=segment, frame_length=rms_frame_len, hop_length=hoplen)[0]
            )
            if seg_rms < audio_info['threshold']:
                # Too quiet: skip this window but keep the file active.
                continue

            # Apply noise reduction if requested
            if noise_reduce:
                segment = util.reduce_noise_seg(segment, sr=audio_info['sr'],
                                                filename=audio_info['filename'], class_id=class_id)

            # Build the segment name from the filename minus its extension only;
            # splitting on the first '.' would truncate names such as "a.b.ogg".
            stem = os.path.splitext(audio_info['filename'])[0]
            segment_filename = f"{stem}_seg{segment_index:03d}.wav"
            segment_path = os.path.join(segments_dir, segment_filename)

            # Save as wav file
            _write_segment_wav(segment_path, segment, audio_info['sr'])

            # Record this segment
            segment_records.append({
                'filename': segment_filename,
                'class_id': class_id,
            })

            class_counts[class_id] += 1
            segments_saved += 1

            print(f"Saved {segment_filename} (Class {class_id}: {class_counts[class_id]}/{cap_per_class})")

        # Remove audios that are done; highest index first so the remaining
        # indices stay valid while popping.
        for i in sorted(audios_to_remove, reverse=True):
            removed_audio = active_audios.pop(i)
            print(f"Removed {removed_audio['filename']} from processing")

        segment_index += 1

    print("\nSegment extraction complete!")
    print(f"Total segments saved: {segments_saved}")
    print("Segments per class:")
    for class_id, count in class_counts.items():
        print(f"  Class {class_id}: {count}")

    # Explicit columns keep the documented schema even when nothing was saved.
    return pd.DataFrame(segment_records, columns=['filename', 'class_id'])
