<a href="https://colab.research.google.com/github/gbdionne/toneclone/blob/main/SpectrogramDatasetBuilder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import os
import numpy as np
import pandas as pd
import librosa
import librosa.display
import wave
import h5py
import scipy.signal
from scipy import stats
from scipy.io.wavfile import write

class SpectrogramDatasetBuilder:
    DSP_FEATURES = [
        'modulation_freq', 'modulation_strength',
        'spectral_centroid', 'spectral_flatness',
        'freq_rolloff', 'spectral_bandwidth', 'zcr',
        'crest_factor', 'flat_top_indicator',
        'clipping_score'
    ]

    EFFECT_LABELS = {
        "ODV": "overdrive", "DST": "distortion", "FUZ": "fuzz", "TRM": "tremolo",
        "PHZ": "phaser", "FLG": "flanger", "CHR": "chorus", "DLY": "delay", "HLL": "hall_reverb",
        "PLT": "plate_reverb", "OCT": "octaver", "FLT": "auto_filter"
    }

    def __init__(self, data_dir, output_dir, sample_length=10, overlap=5, sample_rate=32000, num_mels=128, n_fft=2048, hop_length=512):
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.sample_length = sample_length
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.num_mels = num_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        os.makedirs(output_dir, exist_ok=True)
        self.label_file = os.path.join(output_dir, "test.csv")

        if not os.path.exists(self.label_file):
            columns = ['key'] + self.DSP_FEATURES + list(self.EFFECT_LABELS.values())
            pd.DataFrame(columns=columns).to_csv(self.label_file, index=False)

    def extract_effect_labels(self, filename):
        labels = {name: 0 for name in self.EFFECT_LABELS.values()} # Initialize multi-hot row to all 0s
        effects_found = []

        for code, name in self.EFFECT_LABELS.items():
            if code in filename:
                effects_found.append(name)
            elif 'CLN' in filename:
                effects_found.append('clean')

        if 'clean' in effects_found: # Example of clean guitar, no effects, return all 0 labels
            return labels
        elif effects_found:
            for effect in effects_found:
                if effect != 'clean':
                    labels[effect] = 1
            return labels
        else: # No valid effects found
            raise ValueError(f"No valid effects found in filename: {filename}")

    def split_wav(self, file_path):
        with wave.open(file_path, 'rb') as wav:
            num_channels = wav.getnchannels()
            sample_rate = wav.getframerate()
            num_frames = wav.getnframes()

            audio_data = np.frombuffer(wav.readframes(num_frames), dtype=np.int16)
            if num_channels == 2:
                audio_data = audio_data.reshape(-1, 2).mean(axis=1) # Convert stereo to mono if necessary

            samples_per_segment = self.sample_rate * self.sample_length
            overlap_samples = int(self.sample_rate * (self.overlap / 100))
            step_size = samples_per_segment - overlap_samples

            segments = []
            start = 0
            while start + samples_per_segment <= len(audio_data):
                segment_data = audio_data[start:start + samples_per_segment]
                segments.append(segment_data)
                start += step_size

            return segments, sample_rate

    def extract_am_modulation_strength(self, y, sr):
        """ Extracts dominant AM modulation frequency and its relative strength. """

        # Compute amplitude envelope
        analytic_signal = scipy.signal.hilbert(y)
        amplitude_envelope = np.abs(analytic_signal)

        # Compute FFT of the amplitude envelope
        fft_vals = np.abs(np.fft.rfft(amplitude_envelope))
        fft_freqs = np.fft.rfftfreq(len(amplitude_envelope), 1 / sr)

        # Focus on modulation frequencies between 4-20 Hz (common tremolo range)
        mod_range = (fft_freqs >= 4) & (fft_freqs <= 20)

        if np.any(mod_range):
            # Find dominant modulation frequency
            dominant_mod_freq = fft_freqs[mod_range][np.argmax(fft_vals[mod_range])]

            # Energy at dominant frequency
            dominant_energy = np.max(fft_vals[mod_range])

            # Compute total modulation energy across all frequencies
            total_mod_energy = np.sum(fft_vals)

            # Normalize modulation strength (percentage of total modulation energy at dominant frequency)
            mod_strength = dominant_energy / (total_mod_energy + 1e-8)  # Avoid division by zero
        else:
            dominant_mod_freq = 0  # No clear modulation detected
            mod_strength = 0

        return dominant_mod_freq, mod_strength

    def extract_autocorrelation_delay_fft(self, y, sr, max_lag_ms=300):
        """Detects delay effects using windowed autocorrelation analysis."""

        max_lag_samples = int((max_lag_ms / 1000) * sr)  # Convert max lag to samples

        # Use FFT-based cross-correlation for speed
        autocorr = scipy.signal.fftconvolve(y, y[::-1], mode='full')
        autocorr = autocorr[len(y)-1:len(y)-1+max_lag_samples]  # Keep only positive lags

        # Normalize autocorrelation to avoid scale dependency
        autocorr /= np.max(np.abs(autocorr) + 1e-8)

        # Convert lag indices to time in milliseconds
        lag_times = np.arange(len(autocorr)) / sr * 1000  # Convert to ms

        # Find peaks with minimum prominence and distance
        peak_indices, peak_props = scipy.signal.find_peaks(
            autocorr, height=0.1, distance=int(sr * 0.02)  # 20ms min distance
        )

        if peak_indices.size > 0:
            dominant_peak_idx = peak_indices[0]  # First peak is the dominant delay time
            dominant_delay_time = lag_times[dominant_peak_idx]
            dominant_echo_strength = peak_props['peak_heights'][0]

            # Normalize echo strength
            normalized_echo_strength = dominant_echo_strength / (np.sum(autocorr) + 1e-8)

            # Count strong peaks (multiple echoes)
            echo_count = len(peak_indices)
        else:
            dominant_delay_time = 0
            normalized_echo_strength = 0
            echo_count = 0

        return dominant_delay_time, normalized_echo_strength, echo_count

    def extract_spectral_features(self, y, sr):
        """ Compute spectral centroid and flatness features. """
        # librosa expects float32
        y = y.astype(np.float32) / (np.max(np.abs(y)) + np.finfo(np.float32).eps)

        spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
        spectral_flatness = np.mean(librosa.feature.spectral_flatness(y=y))
        return spectral_centroid, spectral_flatness

    def extract_freq_rolloff(self, y, sr):
        """ Compute frequency rolloff feature. """
        y = y.astype(np.float32) / (np.max(np.abs(y)) + np.finfo(np.float32).eps)
        freq_rolloff = np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr))
        return freq_rolloff

    def extract_spectral_bandwidth(self, y, sr):
        """ Compute spectral bandwidth feature. """
        y = y.astype(np.float32) / (np.max(np.abs(y)) + np.finfo(np.float32).eps)
        spectral_bandwidth = np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr))
        return spectral_bandwidth

    def extract_zero_crossing_rate(self, y):
        """ Compute zero crossing rate feature. """
        y = y.astype(np.float32) / (np.max(np.abs(y)) + np.finfo(np.float32).eps)
        return np.mean(librosa.feature.zero_crossing_rate(y))

    def extract_crest_factor(self, y):
        """ Compute crest factor feature (Peak-to_RMS Ratio), ignore silent samples. """
        threshold_ratio = 0.01
        # Get max absolute amplitude
        peak_amplitude = np.max(np.abs(y))

        # Set threshold based on the max amplitude
        min_valid_amplitude = threshold_ratio * peak_amplitude

        # Select samples that exceed the threshold
        valid_samples = y[np.abs(y) >= min_valid_amplitude]

        if valid_samples.size == 0:
            return np.nan  # No valid samples found

        # Compute RMS on valid samples
        rms_amplitude = np.sqrt(np.mean(valid_samples**2))

        if rms_amplitude == 0:
            return np.nan  # Avoid division by zero

        return peak_amplitude / rms_amplitude

    def extract_flat_top_indicator(self, y, threshold=0.95):
        """ Compute flat top indicator feature. """
        peak_amplitude = np.max(np.abs(y))
        if peak_amplitude == 0:
            return 0.0  # Avoid division by zero (silent signal case)

        # Count samples within threshold of peak amplitude
        flat_top_samples = np.sum(np.abs(y) >= threshold * peak_amplitude)
        total_samples = len(y)

        return flat_top_samples / total_samples

    def clipping_score(self, y, silence_threshold_ratio=0.005):
        """
        Computes a Clipping Score based on kurtosis, attempts to ignore silence
        """
        # Convert to NumPy array
        x = np.asarray(y, dtype=float)

        # Remove DC offset (center the signal)
        x = x - np.mean(x)

        # Compute RMS-based silence threshold (instead of peak-based)
        rms_amplitude = np.sqrt(np.mean(x**2))
        silence_threshold = silence_threshold_ratio * rms_amplitude

        # Define valid (non-silent) samples
        valid_samples = x[np.abs(x) >= silence_threshold]

        if valid_samples.size == 0:
            return 0.0  # If all samples are silent, return 0 (no clipping detected)

        # Compute kurtosis on non-silent samples
        k = stats.kurtosis(valid_samples, fisher=False, bias=False)

        if k == 0 or np.isnan(k):
            return 0.0  # Avoid division errors

        # Compute Clipping Score as 1/kurtosis
        return 1.0 / k

    def generate_spectrogram(self, audio_segment, sample_rate, spectrogram_type='mel'):
        audio_float = audio_segment.astype(np.float32) / (np.max(np.abs(audio_segment))  + np.finfo(np.float32).eps)
        if spectrogram_type == 'mel':
            sgrm = librosa.feature.melspectrogram(y=audio_float, sr=sample_rate, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.num_mels)
        else:
            sgrm = librosa.stft(audio_float, n_fft=self.n_fft, hop_length=self.hop_length)
        return librosa.amplitude_to_db(sgrm, ref=np.max)

    def process_data(self, spectrogram_type='mel'):
        label_buffer = []
        hdf5_path = os.path.join(self.output_dir, "test.h5")

        effect_columns = list(self.EFFECT_LABELS.values())  # Extract effect names
        all_columns = ['key'] + self.DSP_FEATURES + list(self.EFFECT_LABELS.values())

        with h5py.File(hdf5_path, 'w') as h5f:
            for file in os.listdir(self.data_dir):
                if file.endswith(".wav"):
                    print(f"Processing file: {file}")
                    file_path = os.path.join(self.data_dir, file)
                    segments, sample_rate = self.split_wav(file_path)
                    effect_labels = self.extract_effect_labels(file)  # Extract effect labels (e.g., {'distortion': 1})

                    for i, segment in enumerate(segments):
                        # sgrm_db = self.generate_spectrogram(segment, sample_rate, spectrogram_type)
                        key = f"{file}_seg_{i+1}"
                        # h5f.create_dataset(key, data=sgrm_db)

                        # Compute AM modulation features
                        dominant_mod_freq, mod_strength = self.extract_am_modulation_strength(segment, sample_rate)
                        #dominant_mod_freq, mod_strength = 0, 0

                        # Compute spectral features
                        spectral_centroid, spectral_flatness = self.extract_spectral_features(segment, sample_rate)
                        #spectral_centroid, spectral_flatness = 0, 0

                        # Compute Frequency Rolloff
                        freq_rolloff = self.extract_freq_rolloff(segment, sample_rate)

                        # Compute Spectral Bandwidth
                        spectral_bandwidth = self.extract_spectral_bandwidth(segment, sample_rate)

                        # Compute Zero Crossing Rate
                        zero_crossing_rate = self.extract_zero_crossing_rate(segment)

                        # Compute Crest Factor
                        crest_factor = self.extract_crest_factor(segment)

                        # Compute Flat Top Indicator
                        flat_top_indicator = self.extract_flat_top_indicator(segment)

                        # Compute Clipping Score
                        clipping_score = self.clipping_score(segment)

                        # Ensure all columns exist and are initialized to 0
                        label_row = {col: 0 if col in effect_columns else None for col in all_columns}
                        #print(f"Label row init: {label_row}")

                        # Explicitly update effect labels
                        for effect in effect_columns:
                            label_row[effect] = effect_labels.get(effect, 0)
                        #print(f"Label row after adding effects: {label_row}")

                        # Add computed features (ensuring no overwrites)
                        label_row.update({
                            'key': key,
                            'modulation_freq': round(float(dominant_mod_freq), 2),
                            'modulation_strength': round(float(mod_strength), 5),
                            'spectral_centroid': round(float(spectral_centroid), 2),
                            'spectral_flatness': round(float(spectral_flatness), 5),
                            'freq_rolloff': round(float(freq_rolloff), 5),
                            'spectral_bandwidth': round(float(spectral_bandwidth), 1),
                            'zcr': round(float(zero_crossing_rate), 5),
                            'crest_factor': round(float(crest_factor), 5),
                            'flat_top_indicator': round(float(flat_top_indicator), 5),
                            'clipping_score': round(float(self.clipping_score(segment)), 5)
                        })

                        #print(f"Label row after updating with key and DSP features: {label_row}")
                        label_buffer.append(label_row)

                        # Write in 100-row batches for efficiency
                        if len(label_buffer) >= 100:
                            df = pd.DataFrame(label_buffer, columns=all_columns)  # Ensure correct column order
                            df.to_csv(self.label_file, mode='a', header=False, index=False)
                            print(f"Wrote {len(label_buffer)} rows to CSV file.")
                            label_buffer.clear()

        # Final batch write to ensure all rows are saved
        if label_buffer:
            df = pd.DataFrame(label_buffer, columns=all_columns)  # Maintain column consistency
            df.to_csv(self.label_file, mode='a', header=False, index=False)
            print(f"Final batch: Wrote {len(label_buffer)} rows to CSV file.")

        print("Data processing completed. Spectrograms saved to HDF5 file.")

In [4]:
!mkdir /content/wav_files
!cp -r "/content/drive/MyDrive/Capstone 210/Data/Distortion Types" "/content/wav_files"

In [9]:
data_dir = "/content/wav_files/Distortion Types"
output_dir = "/content/drive/MyDrive/Capstone 210/Output"
data_loader = SpectrogramDatasetBuilder(data_dir, output_dir)
data_loader.process_data()

Processing file: final_test_ODV.wav


  rms_amplitude = np.sqrt(np.mean(valid_samples**2))


Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Processing file: final_test_DST.wav
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Processing file: final_test_FUZ.wav
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Wrote 100 rows to CSV file.
Final batch: Wrote 71 rows to CSV file.
Data processing completed. Spectrograms saved to HDF5 file.


In [15]:
from google.colab import runtime
runtime.unassign()