In [1]:
# Mount Google Drive so that files under /content/drive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install tensorflow==2.16.1

Collecting tensorflow==2.16.1
  Downloading tensorflow-2.16.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting ml-dtypes~=0.3.1 (from tensorflow==2.16.1)
  Downloading ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 (from tensorflow==2.16.1)
  Downloading protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes)
Collecting tensorboard<2.17,>=2.16 (from tensorflow==2.16.1)
  Downloading tensorboard-2.16.2-py3-none-any.whl.metadata (1.6 kB)
Collecting numpy<2.0.0,>=1.26.0 (from tensorflow==2.16.1)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m


## Player setup

In [None]:
from IPython.display import Audio, display
import os

def play_audio(file_path, show_controls=True):
    """Show an inline audio player for ``file_path`` when the file exists.

    Prints the file's base name and displays a non-autoplaying player. If the
    path does not exist, a "file not found" message is printed instead.
    ``show_controls`` is accepted but not used by this function.
    """
    if not os.path.exists(file_path):
        print(f"❌ File not found: {file_path}")
        return

    track_name = os.path.basename(file_path)
    print(f"🎵 Playing: {track_name}")
    display(Audio(file_path, autoplay=False))

In [None]:
from pathlib import Path
import librosa

# Base folder on the mounted Drive; replace the placeholder with the real folder.
folder_path = Path('/content/drive/MyDrive/<insert_path_here>')

# File names of the tracks inside folder_path (left blank here; fill in before running).
track_1 = ""
track_2 = ""
track_3 = ""

# Full paths built by joining the folder with each track name.
full_path_1 = folder_path / track_1
full_path_2 = folder_path / track_2
full_path_3 = folder_path / track_3

In [None]:
play_audio(full_path_2)

# Remove Silence

In [None]:
from pydub import AudioSegment
from pydub.silence import detect_nonsilent
from functools import reduce

def strip_silence(audio_path):
    """Remove silent stretches from an audio file, overwriting it in place.

    Non-silent ranges are detected with a 500 ms minimum silence length and a
    -50 dBFS threshold, concatenated in order, and exported back to
    ``audio_path``.

    If no non-silent range is found the file is left untouched, so the
    original audio is not replaced by an empty export.

    NOTE: the export always uses ``format='mp3'``, whatever the extension of
    ``audio_path`` is.
    """
    sound = AudioSegment.from_file(audio_path)
    nonsilent_ranges = detect_nonsilent(
        sound, min_silence_len=500, silence_thresh=-50)

    if not nonsilent_ranges:
        # Everything is below the threshold: keep the original file as is.
        return

    stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]],
                      nonsilent_ranges, AudioSegment.empty())
    stripped.export(audio_path, format='mp3')


# Audio Feature Extraction

In [None]:
"""
Audio processing functionality for chorus detection.
"""

import os
import numpy as np
from typing import List, Tuple
import librosa
from sklearn.preprocessing import StandardScaler


# Constants
SR = 12000  # default sample rate for AudioFeature and process_audio
HOP_LENGTH = 128  # default hop length for AudioFeature and process_audio
MAX_FRAMES = 300  # default number of frames kept per meter segment in pad_song
MAX_METERS = 201  # default number of meter segments kept in pad_song
N_FEATURES = 15  # default number of features per frame in pad_song


class AudioFeature:
    """Extracts frame-level audio features and a meter grid for one audio file.

    Typical use: construct with a path, call ``extract_features_from_audio()``
    to populate the feature attributes, then ``create_meter_grid()`` to get
    the frame indices that delimit meters.
    """

    def __init__(self, audio_path, sr=SR, hop_length=HOP_LENGTH):
        """Store the configuration; every computed attribute starts as ``None``.

        Parameters:
        - audio_path: path of the audio file to load
        - sr: sample rate requested when loading the audio
        - hop_length: hop length passed to the librosa feature calls
        """
        self.audio_path = audio_path
        self.sr = sr
        self.hop_length = hop_length
        self.y = None
        self.y_harm = self.y_perc = None
        self.beats = None
        self.chromagram = self.chroma_acts = None
        self.combined_features = None
        self.key = self.mode = None
        self.mel_acts = self.melspectrogram = None
        self.meter_grid = None
        self.mfccs = self.mfcc_acts = None
        self.n_frames = None
        self.onset_env = None
        self.rms = None
        self.spectrogram = None
        self.tempo = None
        self.tempogram = self.tempogram_acts = None
        # Number of beats per meter used when building the meter grid.
        self.time_signature = 4

    def detect_key(self, chroma_vals: np.ndarray) -> Tuple[str, str]:
        """Detect the key and mode (major or minor) from 12 pitch-class values.

        ``chroma_vals`` is correlated against every rotation of a fixed major
        and a fixed minor profile; the best-correlated rotation gives the key.
        The result is also stored in ``self.key`` and ``self.mode``.

        Returns:
        - (key, mode): note name and ``'major'`` or ``'minor'``
        """
        note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
        major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
        minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])

        # Normalize profiles
        major_profile /= np.linalg.norm(major_profile)
        minor_profile /= np.linalg.norm(minor_profile)

        # Calculate correlations for all possible keys
        major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
        minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]

        # Find best match
        max_major_idx = np.argmax(major_correlations)
        max_minor_idx = np.argmax(minor_correlations)

        # Ties go to 'minor' because the comparison is strict.
        self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
        self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
        return self.key, self.mode

    def calculate_ki_chroma(self, waveform: np.ndarray, sr: int, hop_length: int) -> np.ndarray:
        """Calculate a normalized, key-invariant chromagram.

        The chromagram is min-max scaled, the key is detected from its
        per-pitch-class sums, and the pitch-class axis is rotated according to
        the detected key before normalizing along axis 1.
        """
        chromagram = librosa.feature.chroma_cqt(y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
        # NOTE(review): a constant chromagram (max == min) would divide by zero here.
        chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min())

        chroma_vals = np.sum(chromagram, axis=1)
        key, mode = self.detect_key(chroma_vals)

        key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
        # Major: rotate the detected key down to index 0. Minor: rotate by key_idx + 3 (mod 12)
        # instead -- presumably aligning to the relative major; TODO confirm.
        shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12

        return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)


    def extract_features_from_audio(self):
        """
        Extract features from audio file using the same process as training.

        Populates the waveform, spectrogram, RMS, mel, chroma, tempogram and
        MFCC attributes, then builds ``combined_features`` (frames x features)
        and ``n_frames``.
        """
        print(f"Processing: {self.audio_path}")

        # Load audio and separate harmonic/percussive components
        self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
        print(f"Loaded audio: {len(self.y)/self.sr:.1f} seconds")

        self.y_harm, self.y_perc = librosa.effects.hpss(self.y)

        # Extract spectrogram and RMS
        self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
        self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)

        # Extract mel spectrogram and its components
        self.melspectrogram = librosa.feature.melspectrogram(
            y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
        self.mel_acts = librosa.decompose.decompose(
            self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)

        # Extract chromagram and its components
        self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
        self.chroma_acts = librosa.decompose.decompose(
            self.chromagram, n_components=4, sort=True)[1].astype(np.float32)

        # Extract onset envelope and tempogram (clipped to be non-negative before decomposition)
        self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
        self.tempogram = np.clip(librosa.feature.tempogram(
            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, None)
        self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1]

        # Extract MFCCs and components; the offset makes every coefficient non-negative
        self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=20, hop_length=self.hop_length)
        self.mfccs += abs(np.min(self.mfccs))
        self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=4, sort=True)[1].astype(np.float32)

        # Combine features with weighted normalization
        features = [self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts]
        feature_names = ['rms', 'mel_acts', 'chroma_acts', 'tempogram_acts', 'mfcc_acts']

        # Calculate weights for each feature type: inversely proportional to its row count,
        # scaled so that the weights sum to 1
        dims = {name: feature.shape[0] for feature, name in zip(features, feature_names)}
        total_inv_dim = sum(1 / dim for dim in dims.values())
        weights = {name: 1 / (dims[name] * total_inv_dim) for name in feature_names}

        # Standardize and weight features
        std_weighted_features = [
            StandardScaler().fit_transform(feature.T).T * weights[name]
            for feature, name in zip(features, feature_names)
        ]

        self.combined_features = np.concatenate(std_weighted_features, axis=0).T.astype(np.float32)
        self.n_frames = len(self.combined_features)

    def create_meter_grid(self):
        """Create a grid based on the meter of the song, using tempo and beats.

        Requires ``extract_features_from_audio()`` to have run (uses
        ``onset_env`` and ``n_frames``). Stores and returns ``meter_grid``.
        """
        tempo, self.beats = librosa.beat.beat_track(
            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
        # beat_track may hand back the tempo as a one-element array depending on the
        # librosa version; keep a plain float so the arithmetic below stays scalar.
        self.tempo = float(np.atleast_1d(tempo)[0])

        # Adjust tempo to reasonable range
        if self.tempo < 75:
            self.tempo *= 2
        elif self.tempo > 160:
            self.tempo /= 2

        self.meter_grid = self._create_meter_grid()
        return self.meter_grid

    def _create_meter_grid(self) -> np.ndarray:
        """Helper function to create a meter grid for the song.

        Returns frame indices starting at 0, then every ``time_signature``-th
        beat of a regular beat grid, with ``n_frames`` appended last.
        """
        seconds_per_beat = 60 / self.tempo
        beat_interval = int(librosa.time_to_frames(
            seconds_per_beat, sr=self.sr, hop_length=self.hop_length))

        # Find best matching start beat
        if len(self.beats) >= 3:
            best_match = max(
                (1 - abs(np.mean(self.beats[i:i+3]) - beat_interval) / beat_interval, self.beats[i])
                for i in range(len(self.beats) - 2)
            )[1]
            # NOTE(review): best_match is the beat frame (element [1] of the tuple), not the
            # score, so this threshold compares a frame index with 0.95. Left unchanged on
            # purpose: altering it would change the grid relative to the training pipeline.
            anchor_frame = best_match if best_match > 0.95 else self.beats[0]
        else:
            anchor_frame = self.beats[0] if len(self.beats) > 0 else 0

        first_beat_time = librosa.frames_to_time(anchor_frame, sr=self.sr, hop_length=self.hop_length)

        # Calculate beat times forward and backward
        time_duration = librosa.frames_to_time(self.n_frames, sr=self.sr, hop_length=self.hop_length)
        beat_times_forward = np.arange(first_beat_time, time_duration, seconds_per_beat)
        beat_times_backward = np.arange(first_beat_time - seconds_per_beat, -seconds_per_beat, -seconds_per_beat)


        # Create beat times in both directions
        beat_grid = np.concatenate((np.array([0.0]), beat_times_backward[::-1], beat_times_forward))
        meter_indices = np.arange(0, len(beat_grid), self.time_signature)
        meter_grid = beat_grid[meter_indices]

        # Ensure grid starts at 0
        if meter_grid[0] != 0.0:
            meter_grid = np.insert(meter_grid, 0, 0.0)

        # Convert to frames and add final frame
        meter_grid_frames = librosa.time_to_frames(meter_grid, sr=self.sr, hop_length=self.hop_length)
        meter_grid_frames = np.append(meter_grid_frames, self.n_frames)

        return meter_grid_frames


def segment_data_meters(data: np.ndarray, meter_grid: List[int]) -> List[np.ndarray]:
    """Split ``data`` into consecutive slices bounded by adjacent meter-grid entries.

    Slice ``k`` is ``data[meter_grid[k]:meter_grid[k + 1]]``. A grid with
    fewer than two entries produces an empty list.
    """
    segments = []
    for start, stop in zip(meter_grid[:-1], meter_grid[1:]):
        segments.append(data[start:stop])
    return segments


def positional_encoding(position: int, d_model: int) -> np.ndarray:
    """Return the sinusoidal encoding vector of length ``d_model`` for ``position``.

    Even indices hold the sine and the following odd index holds the cosine of
    ``position / 10000 ** (even_index / d_model)``.
    """
    encoding = np.zeros(d_model)
    even_idx = 0
    while even_idx < d_model:
        angle = position / (10000 ** (even_idx / d_model))
        encoding[even_idx] = np.sin(angle)
        odd_idx = even_idx + 1
        if odd_idx < d_model:
            encoding[odd_idx] = np.cos(angle)
        even_idx += 2
    return encoding


def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
    """Add frame-level and meter-level positional encodings to every segment.

    Each frame receives ``0.1 *`` the encoding of its index within the segment
    plus ``0.2 *`` the encoding of the segment's index in ``segments``.

    Parameters:
    - segments (list): one array per meter, iterated frame by frame

    Returns:
    - list: new arrays with the same shape and dtype as the inputs
    """
    encoded_segments = []
    for meter_idx, meter_segment in enumerate(segments):
        meter_encoded = np.zeros_like(meter_segment)
        # The meter-level term is identical for every frame of the segment, so it is
        # computed once (on the first frame) instead of once per frame.
        meter_pos_encoding = None
        for frame_idx, frame in enumerate(meter_segment):
            if meter_pos_encoding is None:
                meter_pos_encoding = positional_encoding(meter_idx, frame.shape[0]) * 0.2
            frame_pos_encoding = positional_encoding(frame_idx, frame.shape[0]) * 0.1
            meter_encoded[frame_idx] = frame + frame_pos_encoding + meter_pos_encoding
        encoded_segments.append(meter_encoded)
    return encoded_segments


def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
    """
    Build a fixed-size, zero-padded 3D array from a list of encoded segments.

    Parameters:
    - encoded_segments (list): List of encoded data segments
    - max_frames (int): Maximum number of frames per segment
    - max_meters (int): Maximum number of meters
    - n_features (int): Number of features per frame

    Returns:
    - np.ndarray: Padded 3D array of shape (max_meters, max_frames, n_features)
    """
    song_tensor = np.zeros((max_meters, max_frames, n_features))

    # Pairing with range(max_meters) drops any segment beyond the meter limit.
    for meter_idx, segment in zip(range(max_meters), encoded_segments):
        n_seg_frames = segment.shape[0]

        if n_seg_frames > max_frames:
            # Too long: keep max_frames evenly spaced frames.
            keep = np.linspace(0, n_seg_frames - 1, max_frames, dtype=int)
            song_tensor[meter_idx, :, :] = segment[keep, :]
            continue

        # Fits: copy it in, leaving the remaining frames at zero.
        song_tensor[meter_idx, :n_seg_frames, :] = segment

    return song_tensor


def process_audio(audio_path, trim_silence=True, sr=SR, hop_length=HOP_LENGTH):
    """Turn an audio file into a model-ready batch for chorus detection.

    When ``trim_silence`` is true, ``strip_silence`` is first applied to
    ``audio_path``. Features and the meter grid are then extracted, the
    features are cut per meter, positionally encoded and padded.

    Returns:
    - (padded_song, audio_features): the padded array with a leading batch
      axis and the ``AudioFeature`` instance, or ``(None, None)`` if any step
      raised (the error is printed).
    """
    try:
        if trim_silence:
            strip_silence(audio_path)

        # Feature extraction has to run before the meter grid can be built.
        extractor = AudioFeature(audio_path, sr=sr, hop_length=hop_length)
        extractor.extract_features_from_audio()
        grid = extractor.create_meter_grid()

        # Per-meter segments -> positional encoding -> fixed-size padding.
        per_meter = segment_data_meters(extractor.combined_features, grid)
        padded = pad_song(apply_hierarchical_positional_encoding(per_meter))

        # Leading axis of size 1 is the batch dimension expected by the model.
        return padded[np.newaxis, ...], extractor
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None, None

In [None]:
padded_song, audio_features = process_audio(full_path_2)

In [None]:
print(padded_song)
print(audio_features)

# Load CRNN Model and its helper methods

In [None]:
crnn_model = "best_model_V3.h5"
MODEL_PATH = folder_path / crnn_model

In [None]:
"""
Model functionality for chorus detection.
"""

import os
import numpy as np
import tensorflow as tf
import librosa


def create_crnn_model(max_frames_per_meter=300, max_meters=201, n_features=15):
    """
    Recreate the exact CRNN model architecture from the repo

    Parameters:
    - max_frames_per_meter: number of frames in each meter of the input
    - max_meters: number of meters in the input sequence
    - n_features: number of features per frame

    Returns:
    - tf.keras.Model taking (max_meters, max_frames_per_meter, n_features)
      inputs and producing one sigmoid output per meter. The model is not
      compiled and carries no pretrained weights.
    """
    # Frame-level feature extractor (CNN part): three Conv1D + MaxPooling1D stages, then flatten
    frame_input = tf.keras.layers.Input(shape=(max_frames_per_meter, n_features))
    conv1 = tf.keras.layers.Conv1D(filters=128, kernel_size=3, activation='relu', padding='same')(frame_input)
    pool1 = tf.keras.layers.MaxPooling1D(pool_size=2, padding='same')(conv1)
    conv2 = tf.keras.layers.Conv1D(filters=256, kernel_size=3, activation='relu', padding='same')(pool1)
    pool2 = tf.keras.layers.MaxPooling1D(pool_size=2, padding='same')(conv2)  # pools the output of conv2
    conv3 = tf.keras.layers.Conv1D(filters=256, kernel_size=3, activation='relu', padding='same')(pool2)
    pool3 = tf.keras.layers.MaxPooling1D(pool_size=2, padding='same')(conv3)
    frame_features = tf.keras.layers.Flatten()(pool3)
    frame_feature_model = tf.keras.Model(inputs=frame_input, outputs=frame_features)

    # Full model with LSTM: the CNN is applied to every meter, followed by a bidirectional LSTM
    meter_input = tf.keras.layers.Input(shape=(max_meters, max_frames_per_meter, n_features))
    time_distributed = tf.keras.layers.TimeDistributed(frame_feature_model)(meter_input)
    # Masking acts on the CNN output, i.e. on the flattened per-meter feature vectors
    masking_layer = tf.keras.layers.Masking(mask_value=0.0)(time_distributed)
    lstm_out = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(256, return_sequences=True))(masking_layer)
    output = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1, activation='sigmoid'))(lstm_out)

    model = tf.keras.Model(inputs=meter_input, outputs=output)
    return model

# Update your load function:
def load_CRNN_model(model_path: str = MODEL_PATH) -> tf.keras.Model:
    """Build the CRNN architecture and load pre-trained weights into it.

    Returns the model on success. If building the model or loading the
    weights raises, the error is printed and ``None`` is returned.
    """
    try:
        # Only the weights are read from disk; the architecture is rebuilt in code.
        crnn = create_crnn_model()
        crnn.load_weights(model_path)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    print("Model loaded successfully!")
    return crnn


def smooth_predictions(data: np.ndarray) -> np.ndarray:
    """Apply smoothing to model predictions to reduce jitter.

    Three passes are applied:
    1. a centred moving average of width 3 (shorter at the edges),
    2. thresholding at 0.5, where a run shorter than 2 entries is relabelled
       with the value of the run that follows it,
    3. a final run shorter than 2 entries is flipped to the opposite value,
       unless that run is the whole input.

    Parameters:
    - data (np.ndarray): 1D array of prediction scores

    Returns:
    - np.ndarray: integer array of 0/1 labels with the same shape as ``data``;
      an empty input yields an empty integer array.
    """
    # Nothing to smooth; indexing element 0 below would fail on an empty input.
    if len(data) == 0:
        return np.zeros_like(data, dtype=int)

    # First pass: Moving average
    window_size = 3
    smoothed = np.zeros_like(data)
    for i in range(len(data)):
        window_start = max(0, i - window_size // 2)
        window_end = min(len(data), i + window_size // 2 + 1)
        smoothed[i] = np.mean(data[window_start:window_end])

    # Second pass: Eliminate short segments
    min_segment_length = 2
    current_segment_length = 1
    current_value = smoothed[0] > 0.5
    binary_smoothed = np.zeros_like(smoothed, dtype=int)
    binary_smoothed[0] = int(current_value)

    for i in range(1, len(smoothed)):
        new_value = smoothed[i] > 0.5
        if new_value == current_value:
            current_segment_length += 1
        else:
            # If the run that just ended is too short, absorb it into the new run
            if current_segment_length < min_segment_length:
                for j in range(i - current_segment_length, i):
                    binary_smoothed[j] = int(new_value)
            current_value = new_value
            current_segment_length = 1
        binary_smoothed[i] = int(current_value)

    # Third pass: Fix final segment if too short. A run that covers the whole input
    # has no preceding run to merge into, so it is left as predicted.
    if current_segment_length < min_segment_length and current_segment_length < len(smoothed):
        for j in range(len(smoothed) - current_segment_length, len(smoothed)):
            binary_smoothed[j] = int(not current_value)

    return binary_smoothed


def make_predictions(model, processed_audio, audio_features):
    """Run the model on one processed song and report the detected chorus sections.

    Predictions are truncated to the number of meters in
    ``audio_features.meter_grid``, smoothed, and runs of consecutive meters
    labelled 1 are printed as time ranges.

    Returns:
    - (smoothed_predictions, chorus_start_times, chorus_end_times): the
      smoothed 0/1 labels and the start/end time of each chorus run (both
      lists are empty when no chorus is found).
    """
    raw = model.predict(processed_audio).squeeze()

    # Keep only predictions that correspond to real meters of this song.
    usable = min(len(audio_features.meter_grid) - 1, len(raw))
    smoothed_predictions = smooth_predictions(raw[:usable])

    # Time of every meter boundary, used to report start/end times.
    grid_times = librosa.frames_to_time(
        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)

    chorus_indices = np.where(smoothed_predictions == 1)[0]
    chorus_start_times, chorus_end_times = [], []

    if len(chorus_indices) == 0:
        print("No choruses detected in this audio file.")
        return smoothed_predictions, chorus_start_times, chorus_end_times

    # Split the index list wherever two neighbouring indices are not adjacent.
    breaks = np.where(np.diff(chorus_indices) != 1)[0] + 1
    runs = np.split(chorus_indices, breaks)

    print("\nDetected chorus sections:")
    for number, run in enumerate(runs, start=1):
        start_time = grid_times[run[0]]
        # A run ends at the boundary that closes its last meter.
        end_time = grid_times[run[-1] + 1]
        chorus_start_times.append(start_time)
        chorus_end_times.append(end_time)

        start_min, start_sec = divmod(start_time, 60)
        end_min, end_sec = divmod(end_time, 60)
        print(f"Chorus {number}: {int(start_min)}:{start_sec:05.2f} - {int(end_min)}:{end_sec:05.2f}")

    return smoothed_predictions, chorus_start_times, chorus_end_times

In [None]:
print(MODEL_PATH)
padded_song, audio_features = process_audio(full_path_2)
model = load_CRNN_model()
print(model)

# Test predictions

In [None]:
smoothed_predictions, chorus_start_times, chorus_end_times = make_predictions(model, padded_song, audio_features)
# Print the prediction array returned above, not the smooth_predictions function object.
print(smoothed_predictions, chorus_start_times, chorus_end_times)

# Change last layer

In [None]:
model.summary() # Original Model

In [None]:
  def modify_model_for_multiclass(model, class_names):
      """Replace the model's final layer with a softmax head over ``class_names``.

      Every layer of ``model`` is set to non-trainable (this modifies the
      passed-in model), and a new TimeDistributed Dense softmax layer is
      attached to the output of its second-to-last layer.

      Parameters:
      - model: the existing Keras model to build on
      - class_names: list of class labels; its length sets the output size

      Returns:
      - A new compiled Keras model (Adam optimizer, masked categorical
        crossentropy loss, masked accuracy metric).
      """
      num_classes = len(class_names)

      # Freeze every layer of the incoming model; the new output layer is created afterwards
      for layer in model.layers:
          layer.trainable = False

      # Branch off the second-to-last layer, bypassing the original final layer
      x = model.layers[-2].output

      # Add dropout before the output layer -> didn't help on small dataset. Maybe try it again when we have more training data.
      # x = tf.keras.layers.TimeDistributed(
      #     tf.keras.layers.Dropout(0.3)
      # )(x)

      # Add new multiclass output layer
      new_output = tf.keras.layers.TimeDistributed(
          tf.keras.layers.Dense(num_classes, activation='softmax')
      )(x)

      # Custom loss function that ignores padding
      def masked_categorical_crossentropy(y_true, y_pred):
          # Count, per timestep, the entries of y_true that differ from -1.
          # NOTE(review): only rows made entirely of -1 are treated as padding; all-zero rows are NOT masked.
          mask = tf.reduce_sum(tf.cast(tf.not_equal(y_true, -1.0), tf.float32), axis=-1)
          mask = tf.cast(tf.not_equal(mask, 0), tf.float32) # Mask is 1 where at least one entry is not -1, 0 otherwise

          # Calculate categorical crossentropy
          cce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)

          # Apply mask
          masked_cce = cce * mask

          # Return average loss (only over non-padded elements)
          return tf.reduce_sum(masked_cce) / (tf.reduce_sum(mask) + tf.keras.backend.epsilon()) # Add epsilon for numerical stability

      def masked_accuracy(y_true, y_pred):
          # Same mask as in the loss: 0 only for timesteps whose y_true entries are all -1
          mask = tf.reduce_sum(tf.cast(tf.not_equal(y_true, -1.0), tf.float32), axis=-1)
          mask = tf.cast(tf.not_equal(mask, 0), tf.float32) # Mask is 1 where at least one entry is not -1, 0 otherwise

          # Get predicted and true class indices for every timestep
          y_pred_classes = tf.argmax(y_pred, axis=-1)
          y_true_classes = tf.argmax(y_true, axis=-1)

          # Zero out the class indices at masked timesteps
          y_true_masked = y_true_classes * tf.cast(mask, tf.int64)
          y_pred_masked = y_pred_classes * tf.cast(mask, tf.int64)

          # Count matches only on non-padded elements
          correct = tf.cast(tf.equal(y_pred_masked, y_true_masked), tf.float32) * mask

          return tf.reduce_sum(correct) / (tf.reduce_sum(mask) + tf.keras.backend.epsilon()) # Add epsilon for numerical stability

      # Create and compile new model
      new_model = tf.keras.Model(inputs=model.input, outputs=new_output)
      new_model.compile(
          optimizer='adam',
          loss=masked_categorical_crossentropy,
          metrics=[masked_accuracy]
      )

      return new_model

In [None]:
new_model = modify_model_for_multiclass(model, ["O", "A", "B","C"])

In [None]:
new_model.summary() # New model with 4 classes on output

# Setup Data for fine tuning

In [None]:
def create_pitch_shifted_audio_files_with_mapping(original_files, pitch_shifts=[-2, -1, 1, 2]):
    """Write pitch-shifted copies of each file and return the files plus an ID mapping.

    The returned list starts with the original paths, followed by one
    ``<stem>_pitch<shift>.wav`` path per (shift, file) pair, written next to
    the original. Files that already exist are not regenerated. The mapping
    goes from list position to a song ID string: ``"1"``, ``"2"``, ... for
    originals and ``"<n>_pitch<shift>"`` for shifted copies.
    """
    import librosa
    import soundfile as sf

    # Originals first, with 1-based string IDs.
    augmented_files = list(original_files)
    song_id_mapping = {pos: str(pos + 1) for pos in range(len(augmented_files))}

    for shift in pitch_shifts:
        suffix = f"_pitch{shift:+d}"
        for song_number, audio_path in enumerate(original_files, start=1):
            new_path = audio_path.parent / f"{audio_path.stem}{suffix}.wav"

            if new_path.exists():
                print(f"Skipping (already exists): {new_path}")
            else:
                # Load at the native sample rate, shift, and save alongside the original.
                y, sr = librosa.load(audio_path, sr=None)
                y_shifted = librosa.effects.pitch_shift(y, sr=sr, n_steps=shift)
                sf.write(new_path, y_shifted, sr)
                print(f"Created: {new_path}")

            # The next free list position is the key for this file.
            song_id_mapping[len(augmented_files)] = f"{song_number}{suffix}"
            augmented_files.append(new_path)

    return augmented_files, song_id_mapping

# Then, augment the LABEL data
def augment_label_data_with_time_adjustment(original_data, pitch_shifts=[-2, -1, 1, 2]):
    """Duplicate label entries for every pitch-shifted copy of a song.

    For each shift, every entry is copied with its ``SongID`` suffixed by
    ``_pitch<shift>`` (e.g. ``1_pitch-2``). Start and end times are kept as
    they are: ``librosa.effects.pitch_shift`` returns audio of the same
    length as its input, so the annotations stay aligned without rescaling.

    Parameters:
    - original_data: list of dicts with at least a ``SongID`` key
    - pitch_shifts: semitone shifts that were applied to the audio

    Returns:
    - A new list with the original entries followed by the copies; the input
      list and its entries are not modified.
    """
    augmented_data = original_data.copy()

    for shift in pitch_shifts:
        for entry in original_data:
            new_entry = entry.copy()
            new_entry['SongID'] = f"{entry['SongID']}_pitch{shift:+d}"
            augmented_data.append(new_entry)

    return augmented_data

In [None]:
def create_time_stretched_audio_files_with_mapping(original_files, stretch_factors=[0.8, 0.9, 1.1, 1.2]):
    """Write time-stretched copies of each file and return the files plus an ID mapping.

    The returned list starts with the original paths, followed by one
    ``<stem>_tempo<factor>x.wav`` path per (factor, file) pair, written next
    to the original; the factor is formatted with one decimal. Files that
    already exist are not regenerated. The mapping goes from list position to
    a song ID string: ``"1"``, ``"2"``, ... for originals and
    ``"<n>_tempo<factor>x"`` for stretched copies.
    """
    import librosa
    import soundfile as sf

    # Originals first, with 1-based string IDs.
    augmented_files = list(original_files)
    song_id_mapping = {pos: str(pos + 1) for pos in range(len(augmented_files))}

    for factor in stretch_factors:
        suffix = f"_tempo{factor:.1f}x"
        for song_number, audio_path in enumerate(original_files, start=1):
            new_path = audio_path.parent / f"{audio_path.stem}{suffix}.wav"

            if new_path.exists():
                print(f"Skipping (already exists): {new_path}")
            else:
                # Load at the native sample rate, stretch, and save alongside the original.
                y, sr = librosa.load(audio_path, sr=None)
                y_stretched = librosa.effects.time_stretch(y, rate=factor)
                sf.write(new_path, y_stretched, sr)
                print(f"Created: {new_path}")

            # The next free list position is the key for this file.
            song_id_mapping[len(augmented_files)] = f"{song_number}{suffix}"
            augmented_files.append(new_path)

    return augmented_files, song_id_mapping

def augment_label_data_with_tempo_adjustment(original_data, stretch_factors=[0.8, 0.9, 1.1, 1.2]):
    """Duplicate label entries for every tempo-stretched copy, rescaling their times.

    For each factor, every entry is copied with its ``SongID`` suffixed by
    ``_tempo<factor>x`` (one decimal) and its start/end times multiplied by
    ``1 / factor``. Returns a new list: the original entries followed by the
    copies; the input list and its entries are not modified.
    """
    augmented_data = original_data.copy()

    for factor in stretch_factors:
        # A factor below 1 slows the audio down, so timestamps grow by 1 / factor.
        time_ratio = 1.0 / factor
        tag = f"_tempo{factor:.1f}x"

        for entry in original_data:
            scaled = entry.copy()
            scaled.update({
                'SongID': f"{entry['SongID']}{tag}",
                'start_time': entry['start_time'] * time_ratio,
                'end_time': entry['end_time'] * time_ratio,
            })
            augmented_data.append(scaled)

    return augmented_data

In [None]:
import pandas as pd

# O is Other
# A is low energy
# B is high energy
# C is Breakdown

folder_path = Path('/content/drive/MyDrive/<insert_your_path_here>')

class_names = ['O', 'A', 'B', 'C']

track_1 = ""
track_2 = ""
track_3 = ""
track_4 = ""
track_5 = ""
track_6 = ""
track_7 = ""
track_8 = ""
track_9 = ""
track_10 = ""
track_11 = ""
track_12 = ""
track_13 = ""
track_14 = ""
track_15 = ""
track_16 = ""
track_17 = ""
track_18 = ""
track_19 = ""
track_20 = ""
track_21 = ""
track_22 = ""
track_23 = ""
track_24 = ""
track_25 = ""


# Full path of every track: the Drive folder joined with the track's file name.
full_path_1 = folder_path / track_1
full_path_2 = folder_path / track_2
full_path_3 = folder_path / track_3
full_path_4 = folder_path / track_4
full_path_5 = folder_path / track_5
full_path_6 = folder_path / track_6
full_path_7 = folder_path / track_7
full_path_8 = folder_path / track_8
full_path_9 = folder_path / track_9
full_path_10 = folder_path / track_10
full_path_11 = folder_path / track_11
full_path_12 = folder_path / track_12
full_path_13 = folder_path / track_13
full_path_14 = folder_path / track_14
full_path_15 = folder_path / track_15
full_path_16 = folder_path / track_16
full_path_17 = folder_path / track_17
full_path_18 = folder_path / track_18
full_path_19 = folder_path / track_19
full_path_20 = folder_path / track_20
full_path_21 = folder_path / track_21
full_path_22 = folder_path / track_22
full_path_23 = folder_path / track_23
# Paths 24 and 25 use their own track names (they previously reused track_23).
full_path_24 = folder_path / track_24
full_path_25 = folder_path / track_25

tracks_for_finetuning = [
    full_path_3,
    full_path_2,
    full_path_1,
    full_path_4,
    full_path_5,
    full_path_6,
    full_path_7,
    full_path_8,
    full_path_9,
    full_path_10,
    full_path_11,
    full_path_12,
    full_path_13,
    full_path_14,
    full_path_15,
    full_path_16,
    full_path_17,
    full_path_18,
    full_path_19,
    full_path_20,
    full_path_21,
    full_path_22,
    full_path_23,
    full_path_24,
    full_path_25,
]

data = [
    {'SongID': 1, 'start_time': 0, 'end_time': 57, 'label': 'O'},
    {'SongID': 1, 'start_time': 57, 'end_time': 109, 'label': 'B'},
    {'SongID': 1, 'start_time': 109, 'end_time': 116, 'label': 'C'},
    {'SongID': 1, 'start_time': 116, 'end_time': 131, 'label': 'B'},
    {'SongID': 1, 'start_time': 131, 'end_time': 174, 'label': 'A'},
    {'SongID': 1, 'start_time': 174, 'end_time': 205, 'label': 'C'},
    {'SongID': 1, 'start_time': 205, 'end_time': 241, 'label': 'A'},
    {'SongID': 1, 'start_time': 241, 'end_time': 277, 'label': 'A'},
    {'SongID': 1, 'start_time': 277, 'end_time': 295, 'label': 'O'},

    # ... Add more here

]

# 1. Create pitch-shifted audio files
# augmented_audio_files, song_id_mapping = create_pitch_shifted_audio_files_with_mapping(
#     tracks_for_finetuning, [-2, 5]
# )
# time stretching without pitch shifting
augmented_audio_files, song_id_mapping = create_time_stretched_audio_files_with_mapping(
    tracks_for_finetuning, [0.8, 0.9, 1.1, 1.2]
)
print(augmented_audio_files)

# # 2. Augment label data
# augmented_data = augment_label_data_with_time_adjustment(data, [-2, 5])
# print(augmented_data)

# for time stretching
augmented_data = augment_label_data_with_tempo_adjustment(data, [0.8, 0.9, 1.1, 1.2])


df = pd.DataFrame(augmented_data)
csv_path = '/content/test_labels_1song.csv'
df.to_csv(csv_path, index=False)


# Process CSV and extract features

In [None]:
def create_multiclass_labels_from_csv(csv_path, song_id, meter_grid_frames, class_names,
                                      sr=12000, hop_length=128):
    """
    Derive one integer class label per meter for a song from its CSV annotations.

    The annotated sections are first painted onto a frame-level label track,
    then every meter (the span between two consecutive grid entries) receives
    the label that covers most of its frames.

    Parameters:
    - csv_path: Path to the annotation CSV (columns SongID, start_time, end_time, label)
    - song_id: SongID whose rows should be used
    - meter_grid_frames: Meter boundaries expressed as frame indices
    - class_names: Ordered class names; a name's position is its class index
    - sr: Sample rate used for the time -> frame conversion
    - hop_length: Hop length used for the time -> frame conversion

    Returns:
    - Array with one class index per meter (len(meter_grid_frames) - 1 entries),
      or None when the CSV holds no rows for `song_id`.
    """
    def to_frame(seconds):
        # Same conversion for every timestamp so sections and grid share one scale
        return int(librosa.time_to_frames(seconds, sr=sr, hop_length=hop_length))

    # Class name -> class index
    label_to_int = {name: idx for idx, name in enumerate(class_names)}
    print(f"Label mapping: {label_to_int}")

    # Keep only the annotations that belong to the requested song
    annotations = pd.read_csv(csv_path)
    song_rows = annotations.loc[annotations['SongID'] == song_id].copy()
    if song_rows.empty:
        print(f"No data found for song ID {song_id}")
        return None

    # Frame-level label track; frames not covered by any section stay at class 0
    frame_labels = np.zeros(to_frame(song_rows['end_time'].max()), dtype=int)

    for start_s, end_s, label in zip(song_rows['start_time'],
                                     song_rows['end_time'],
                                     song_rows['label']):
        first_frame, last_frame = to_frame(start_s), to_frame(end_s)
        class_idx = label_to_int.get(label)
        if class_idx is None:
            print(f"Warning: Unknown label '{label}' - using default class 0")
            continue
        frame_labels[first_frame:last_frame] = class_idx

    # Majority vote inside every meter; meters beyond the annotated range get class 0
    n_labelled = len(frame_labels)
    meter_labels = []
    for left, right in zip(meter_grid_frames[:-1], meter_grid_frames[1:]):
        lo = int(left)
        hi = min(int(right), n_labelled)
        window = frame_labels[lo:hi] if lo < n_labelled else frame_labels[:0]
        meter_labels.append(np.bincount(window).argmax() if len(window) > 0 else 0)

    return np.array(meter_labels)

def process_song_multiclass(audio_path, song_id, csv_path, class_names):
    """
    Turn one audio file into encoded per-meter segments plus aligned labels.

    Parameters:
    - audio_path: Path of the audio file to analyse
    - song_id: SongID used to look up this song's annotations in the CSV
    - csv_path: Path of the annotation CSV
    - class_names: Ordered class names (position == class index)

    Returns:
    - (encoded_segments, labels), or (None, None) when the CSV contains no
      annotations for `song_id`.
    """
    # Extract features using the AudioFeature class
    audio_features = AudioFeature(audio_path)
    audio_features.extract_features_from_audio()
    meter_grid = audio_features.create_meter_grid()

    # Resolve the labels first: a song without annotations is dropped by the
    # caller anyway, so there is no point in segmenting and encoding it.
    labels = create_multiclass_labels_from_csv(csv_path, song_id, meter_grid, class_names)
    if labels is None:
        return None, None

    # Segment the features per meter and add the positional encoding
    feature_segments = segment_data_meters(audio_features.combined_features, meter_grid)
    encoded_segments = apply_hierarchical_positional_encoding(feature_segments)

    return encoded_segments, labels

In [None]:
def _infer_feature_count(encoded_segments_list, default=15):
    """Return the feature width of the first available segment, else `default`."""
    for segments in encoded_segments_list:
        for segment in segments:
            return np.asarray(segment).shape[-1]
    return default


def _fill_padded_song(target, segments, max_frames):
    """Write a song's segments into the pre-zeroed (max_meters, max_frames, n_features) `target`."""
    max_meters = target.shape[0]
    for meter_idx, segment in enumerate(segments):
        if meter_idx >= max_meters:
            break  # surplus meters are dropped
        segment = np.asarray(segment)
        n_frames = segment.shape[0]
        if n_frames <= max_frames:
            target[meter_idx, :n_frames, :] = segment
        else:
            # Too long: keep max_frames evenly spaced frames instead of cutting the tail
            picks = np.linspace(0, n_frames - 1, max_frames, dtype=int)
            target[meter_idx, :, :] = segment[picks, :]


def prepare_multiclass_data(encoded_segments_list, labels_list, class_names, max_frames=500, max_meters=201,
                            n_features=None):
    """
    Pad songs to a fixed shape and one-hot encode their per-meter labels.

    Parameters:
    - encoded_segments_list: One entry per song; each entry is a sequence of
      2-D arrays (frames, features), one per meter
    - labels_list: One entry per song; integer class index per meter
    - class_names: Ordered class names; only their count is used here
    - max_frames: Frames kept per meter (shorter meters are zero padded,
      longer ones are subsampled)
    - max_meters: Meters kept per song (shorter songs are padded, longer truncated)
    - n_features: Feature width of a segment. When None it is taken from the
      first segment found, falling back to 15 if there are no segments at all

    Returns:
    - X: float array of shape (n_songs, max_meters, max_frames, n_features)
    - y: float32 array of shape (n_songs, max_meters, num_classes); labelled
      meters hold a one-hot row, padded meters hold -1.0 in every column
    """
    encoded_segments_list = list(encoded_segments_list)
    labels_list = list(labels_list)
    num_classes = len(class_names)
    if n_features is None:
        n_features = _infer_feature_count(encoded_segments_list)

    # Songs and labels are paired positionally; surplus entries on either side are ignored.
    n_songs = min(len(encoded_segments_list), len(labels_list))

    # Allocate the outputs once and fill them in place. This keeps the shapes
    # well defined even for zero songs and avoids holding a second full copy
    # of the data while converting a list of arrays.
    X = np.zeros((n_songs, max_meters, max_frames, n_features))
    y = np.full((n_songs, max_meters, num_classes), -1.0, dtype=np.float32)  # -1.0 marks padding

    for song_idx, (segments, labels) in enumerate(zip(encoded_segments_list, labels_list)):
        _fill_padded_song(X[song_idx], segments, max_frames)

        # Truncate labels if longer than max_meters before encoding
        label_idx = np.asarray(labels[:max_meters], dtype=np.int64).reshape(-1)
        n_labels = label_idx.shape[0]
        if n_labels > 0:
            # One-hot encode: clear the padding marker, then set the class column
            y[song_idx, :n_labels, :] = 0.0
            y[song_idx, np.arange(n_labels), label_idx] = 1.0

    return X, y

# Setup our X and y

In [None]:
# Run every augmented audio file through feature extraction and label alignment.
# `song_id_mapping[i]` supplies the SongID used to look up the annotations of
# the i-th file in `csv_path`.
all_segments = []
all_labels = []

for i, audio_file in enumerate(augmented_audio_files):
    song_id = song_id_mapping[i]
    segments, labels = process_song_multiclass(audio_file, song_id, csv_path, class_names)
    # process_song_multiclass returns (None, None) when the CSV has no rows for
    # this song; such files are left out of the dataset.
    if segments is not None:
        all_segments.append(segments)
        all_labels.append(labels)

# Padding / one-hot encoding is not done here; it happens per split in a later cell:
# X, y = prepare_multiclass_data(all_segments, all_labels, class_names)

In [None]:
# Sanity check: both lists are appended to together, so the two counts should match.
print(f"Number of songs processed: {len(all_segments)}")
print(f"Number of labels processed: {len(all_labels)}")
# print(X, y)

In [None]:
from sklearn.model_selection import train_test_split

# Split data at the song level: each list element is one processed audio file,
# so a file ends up entirely in either the training or the validation set.
# NOTE(review): the split is random over files. If the time-stretched variants
# of one original track are separate list elements, they can land on both
# sides of the split and make validation scores optimistic -- consider
# splitting by original track (e.g. a group-aware splitter). TODO confirm.
train_segments, val_segments, train_labels, val_labels = train_test_split(
    all_segments, all_labels, test_size=0.25, random_state=42
)

print(f"Train songs: {len(train_segments)}, Validation songs: {len(val_segments)}")

# Prepare (pad) the training and validation data separately.
# MAX_FRAMES / MAX_METERS are defined elsewhere in the notebook and fix the
# per-meter and per-song dimensions of the padded arrays.
print("Padding training data...")
X_train, y_train = prepare_multiclass_data(
    train_segments, train_labels, class_names, max_frames=MAX_FRAMES, max_meters=MAX_METERS
)
print(f"Padded Training shapes: X={X_train.shape}, y={y_train.shape}")

print("Padding validation data...")
X_val, y_val = prepare_multiclass_data(
    val_segments, val_labels, class_names, max_frames=MAX_FRAMES, max_meters=MAX_METERS
)
print(f"Padded Validation shapes: X={X_val.shape}, y={y_val.shape}")

In [None]:
# def log_training_run(model, X_train, y_train, X_val, y_val, epochs, batch_size, comment=""):
#     """
#     Trains the model and logs the training history with a comment to a CSV file.
#     """
#     print(f"Starting training run: {comment}")
#     history = model.fit(
#         X_train, y_train,
#         epochs=epochs,
#         batch_size=batch_size,
#         verbose=1,
#         validation_data=(X_val, y_val)
#     )

#     print("\nHistory keys after training:")
#     print(history.history.keys()) # Print keys to debug

#     # Prepare log entry
#     log_entry = {
#         'timestamp': pd.Timestamp.now(),
#         'comment': comment,
#         'epochs': epochs,
#         'batch_size': batch_size,
#         'train_loss': history.history['loss'][-1],
#         'train_accuracy': history.history['masked_accuracy'][-1], # Use the correct key names
#         'val_loss': history.history['val_loss'][-1],
#         'val_accuracy': history.history['val_masked_accuracy'][-1] # Use the correct key names
#     }

#     # Create DataFrame and save to CSV
#     log_df = pd.DataFrame([log_entry])

#     # Append to CSV, create file with header if it doesn't exist
#     if not os.path.exists(log_file_path):
#         log_df.to_csv(log_file_path, index=False, header=True)
#     else:
#         log_df.to_csv(log_file_path, index=False, header=False, mode='a')


#     print(f"Training run logged to {log_file_path}")

#     return history

import pandas as pd
import os

def log_training_run(model, X_train, y_train, X_val, y_val, epochs, batch_size,
                     comment="", class_weights=None, log_path=None):
    """
    Trains the model and appends a one-row summary of the run to a CSV log.

    Parameters:
    - model: Object exposing a Keras-style ``fit(**kwargs)`` that returns a
      history object with a ``history`` dict of per-epoch metric lists
    - X_train, y_train: Training inputs and targets
    - X_val, y_val: Validation inputs and targets
    - epochs: Number of training epochs
    - batch_size: Batch size
    - comment: Free-text description stored with the run
    - class_weights: Dict with class weights {0: weight, 1: weight, ...} or None for no weighting
    - log_path: CSV file to append to. Defaults to the notebook-level
      ``log_file_path`` so existing calls keep working

    Returns:
    - The history object returned by ``model.fit``
    """
    # Resolve the destination *before* training: a missing `log_file_path`
    # should fail now, not after the run has finished and its history is lost.
    if log_path is None:
        log_path = log_file_path

    # Add class weights info to comment if provided
    if class_weights is not None:
        weights_str = ", ".join([f"{k}:{v}" for k, v in class_weights.items()])
        comment = f"{comment} [weights: {weights_str}]"

    print(f"Starting training run: {comment}")

    # Build fit parameters
    fit_params = {
        'x': X_train,
        'y': y_train,
        'epochs': epochs,
        'batch_size': batch_size,
        'verbose': 1,
        'validation_data': (X_val, y_val)
    }

    # Add class weights if provided
    # NOTE(review): confirm the installed Keras version accepts `class_weight`
    # for the target shape used in this notebook; per-timestep targets may need
    # `sample_weight` instead.
    if class_weights is not None:
        fit_params['class_weight'] = class_weights
        print(f"Using class weights: {class_weights}")

    # Train the model
    history = model.fit(**fit_params)
    metrics = history.history

    print("\nHistory keys after training:")
    print(metrics.keys())  # Print keys to debug

    def last_value(*candidate_keys):
        """Final-epoch value of the first metric present, or None if none was recorded."""
        for key in candidate_keys:
            values = metrics.get(key)
            if values is not None and len(values) > 0:
                return values[-1]
        return None

    # Prepare log entry. Metric names depend on how the model was compiled, so
    # a missing key is logged as empty instead of raising after training.
    log_entry = {
        'timestamp': pd.Timestamp.now(),
        'comment': comment,
        'epochs': epochs,
        'batch_size': batch_size,
        'class_weights': str(class_weights) if class_weights is not None else 'None',
        'train_loss': last_value('loss'),
        'train_accuracy': last_value('masked_accuracy', 'accuracy'),
        'val_loss': last_value('val_loss'),
        'val_accuracy': last_value('val_masked_accuracy', 'val_accuracy')
    }

    # Append to CSV; write the header only when starting a new (or empty) file
    log_df = pd.DataFrame([log_entry])
    write_header = not os.path.exists(log_path) or os.path.getsize(log_path) == 0
    log_df.to_csv(log_path, index=False, header=write_header, mode='a')

    print(f"Training run logged to {log_path}")

    return history


# Predefined class weight strategies, keyed by strategy name.
# Each strategy maps a class index to the weight applied to that class.
CLASS_WEIGHT_STRATEGIES = {
    'boost_A_strategy': {
        0: 1.5,   # O (Intro/Outro) - boost recall (40.7% is low)
        1: 2.0,   # A (Medium Energy) - boost recall (19.6% is terrible)
        2: 0.7,   # B (High Energy) - REDUCE to improve precision
        3: 1.0    # C (Breakdown) - normal (decent performance)
    },
}

# Human-readable names used only for the summary printed below. They live under
# their own name so the notebook-wide `class_names` (the label vocabulary used
# for training and evaluation) is not overwritten by running this cell.
# NOTE(review): these descriptions do not agree with the per-class comments in
# the strategy above (e.g. "A(High)" vs "A (Medium Energy)") -- confirm which
# description of each class is the intended one.
CLASS_WEIGHT_DISPLAY_NAMES = ['O(Other)', 'A(High)', 'B(Breakdown)', 'C(Low)']

# Print a summary of every available strategy
print("Available class weight strategies:")
for strategy_name, weights in CLASS_WEIGHT_STRATEGIES.items():
    print(f"\n{strategy_name}:")
    for class_idx, weight in weights.items():
        print(f"  {CLASS_WEIGHT_DISPLAY_NAMES[class_idx]}: {weight}")

print("\n" + "="*50)
print("USAGE EXAMPLES:")
print("="*50)


In [None]:
# CSV file that log_training_run appends its run summaries to.
log_file_path = '/content/training_logs_17-8-25.csv'

# Unweighted baseline, kept for reference:
# log_training_run(new_model, X_train, y_train, X_val, y_val, 15, 1, comment="25 songs labelled + time stretching *4 - 4 classes - finetuned")

# Fine-tune with the class-weighted strategy: 15 epochs, one song per batch.
log_training_run(
    new_model,
    X_train, y_train,
    X_val, y_val,
    epochs=15,
    batch_size=1,
    comment="25 songs + time stretch *4 - aggressive high energy",
    class_weights=CLASS_WEIGHT_STRATEGIES['boost_A_strategy'],
)

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

def evaluate_per_song_and_overall(model, X_val, y_val, class_names):
    """
    Print the accuracy of every validation song and return an overall report.

    Parameters:
    - model: Trained model; ``model.predict(X_val)`` must yield per-meter class scores
    - X_val: Padded validation inputs, one entry per song along the first axis
    - y_val: Padded validation targets of shape (songs, meters, classes); labelled
      meters are one-hot rows, padded meters are rows whose sum is not positive
    - class_names: Ordered class names (position == class index)

    Returns:
    - Dict form of sklearn's classification_report computed over the labelled
      meters of all songs, with one entry per name in `class_names`
    """
    y_pred = model.predict(X_val)

    # A positive row sum marks a real (one-hot) label; padded rows are skipped.
    labelled = np.sum(y_val, axis=-1) > 0
    true_classes = np.argmax(y_val, axis=-1)
    # Only the meters that exist in y_val are compared
    pred_classes = np.argmax(y_pred, axis=-1)[:, :y_val.shape[1]]

    # Per-song metrics
    for song_idx in range(X_val.shape[0]):
        keep = labelled[song_idx]
        if not keep.any():
            # Nothing to score; avoid calling accuracy_score on empty input
            print(f"Song {song_idx + 1} accuracy: n/a (no labelled meters)")
            continue
        song_accuracy = accuracy_score(true_classes[song_idx][keep], pred_classes[song_idx][keep])
        print(f"Song {song_idx + 1} accuracy: {song_accuracy:.3f}")

    # Overall report (all songs combined). Passing `labels` keeps the report
    # aligned with `class_names` even when a class never occurs in the
    # validation data; `zero_division=0` reports such classes as 0.
    return classification_report(
        true_classes[labelled],
        pred_classes[labelled],
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )

# Look at precision and recall for our classes

In [None]:
# Evaluate the fine-tuned model on the validation split.
# `class_report` is the dict returned by evaluate_per_song_and_overall and is
# indexed by class name below.
print("\nPer-class performance analysis:")
class_report = evaluate_per_song_and_overall(new_model, X_val, y_val, class_names)

# One summary line per class
for class_name in class_names:
    metrics = class_report[class_name]
    print(f"{class_name}: Precision={metrics['precision']:.3f}, Recall={metrics['recall']:.3f}, F1={metrics['f1-score']:.3f}")

# Precision

# Make new prediction

In [None]:
def predict_song(audio_path, model, class_names):
    """
    Run the trained model on one audio file and return a class per meter.

    Parameters:
    - audio_path: path to the audio file
    - model: trained model exposing ``predict`` (e.g. new_model)
    - class_names: list of str, class labels used in training (not used in
      the computation itself)

    Returns:
    - predicted_classes: array with the highest-scoring class index of every
      meter position the model outputs
    - meter_grid: the ``meter_grid`` attribute of the extracted audio features
    Both are None when the audio could not be processed.
    """
    # Same preprocessing as the training pipeline
    model_input, features = process_audio(audio_path)
    if model_input is None:
        print("❌ Failed to process song")
        return None, None

    # Drop the leading batch axis of the prediction -> (meters, classes)
    per_meter_scores = model.predict(model_input).squeeze(axis=0)

    # Highest-scoring class per meter
    return per_meter_scores.argmax(axis=-1), features.meter_grid

In [None]:
# Label vocabulary; the position of a name is the class index it stands for.
class_names = [ "O", "A", "B", "C"]  # same as you used in new_model

# Folder that holds the songs to run inference on (placeholder, fill in).
new_path = Path('/content/drive/MyDrive/<insert your path here>')

# File name of the song inside `new_path` (placeholder, fill in). While this is
# empty, `full_path_wip` is just the folder itself.
new_song = ""

full_path_wip = new_path / new_song


# Predict one class index per meter for the song
pred_labels, meter_grid = predict_song(full_path_wip, new_model, class_names)

print(pred_labels)

# Visualize Prediction

In [None]:
"""
Visualization utilities for multiclass energy level detection.
"""

import os
import numpy as np
import librosa
from matplotlib import pyplot as plt


def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
    """Overlay a dashed grey vertical line on `ax` at every meter grid time."""
    line_style = dict(color='grey', linestyle='--', linewidth=1, alpha=0.6)
    for boundary in meter_grid_times:
        ax.axvline(x=boundary, **line_style)


def plot_multiclass_predictions(audio_features, multiclass_predictions, class_names,
                                title=None, save_path=None):
    """
    Plot the audio waveform and overlay the predicted class of every meter.

    Parameters:
    - audio_features: AudioFeature object; the attributes y, y_harm, y_perc,
      sr, hop_length, meter_grid and audio_path are read
    - multiclass_predictions: Array of predicted class indices, one per meter
    - class_names: List of class names; index i names class i in the legend
    - title: Optional title for the plot (default: based on audio filename)
    - save_path: Optional path to save the plot image (default: don't save)

    Returns:
    - fig: The matplotlib figure object
    """
    # Make sure the display submodule is loaded; older librosa releases do not
    # load it on a plain `import librosa`.
    import librosa.display

    # Close figures left open by earlier calls so repeated use does not pile them up
    plt.close('all')

    meter_grid_times = librosa.frames_to_time(
        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)

    # Create exactly one figure (a second plt.subplots() call used to leave an
    # empty extra figure behind).
    fig, ax = plt.subplots(figsize=(15, 4), dpi=96)

    # Display waveform components
    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
                             alpha=0.8, ax=ax, color='deepskyblue', label='Harmonic')
    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
                             alpha=0.7, ax=ax, color='plum', label='Percussive')
    plot_meter_lines(ax, meter_grid_times)

    # Shading colour per class index; any other index is drawn in black
    class_colors = {
        0: 'lightgray',
        1: 'red',
        2: 'blue',
        3: 'orange'
    }

    # Classes that already have a legend entry
    classes_in_legend = set()

    # Shade every meter with the colour of its predicted class. Predictions
    # beyond the last grid interval have no time span and are ignored.
    for prediction, start_time, end_time in zip(multiclass_predictions,
                                                meter_grid_times[:-1],
                                                meter_grid_times[1:]):
        color = class_colors.get(prediction, 'black')
        class_name = class_names[prediction] if prediction < len(class_names) else 'unknown'

        # Add to legend only once per class. Membership in a set also works
        # for indices outside class_names, which previously raised KeyError.
        label = None if prediction in classes_in_legend else class_name.title()
        classes_in_legend.add(prediction)

        ax.axvspan(start_time, end_time, color=color, alpha=0.4, label=label)

    # Configure plot appearance
    duration = len(audio_features.y) / audio_features.sr
    ax.set_xlim([0, duration])
    ax.set_ylabel('Amplitude')

    # Set plot title
    if title:
        ax.set_title(title)
    else:
        audio_file_name = os.path.basename(audio_features.audio_path)
        ax.set_title(f'Energy Level Predictions for {os.path.splitext(audio_file_name)[0]}')

    # Add legend
    ax.legend(loc='upper right')

    # Set time-based x-axis labels, one tick every 30 seconds
    xticks = np.arange(0, duration, 30)
    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)
    ax.set_xlabel('Time (mm:ss)')

    plt.tight_layout()

    # Save if path is provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    return fig


def plot_energy_timeline(audio_features, multiclass_predictions, class_names,
                        title=None, save_path=None):
    """
    Draw the predicted classes as a timeline of horizontal bars, one row per class.

    Parameters:
    - audio_features: AudioFeature object; the attributes y, sr, hop_length,
      meter_grid and audio_path are read
    - multiclass_predictions: Array of predicted class indices, one per meter
    - class_names: List of class names used for the y-axis tick labels
    - title: Optional title for the plot (default: based on audio filename)
    - save_path: Optional path to save the plot image (default: don't save)

    Returns:
    - fig: The matplotlib figure object
    """
    total_seconds = len(audio_features.y) / audio_features.sr
    boundaries = librosa.frames_to_time(
        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)

    fig, ax = plt.subplots(figsize=(15, 3), dpi=96)

    # class index -> (bar colour, row on the y axis)
    row_style = {
        0: ('lightgray', 0),
        1: ('red', 3),
        2: ('blue', 1),
        3: ('orange', 2),
    }

    # One bar per meter; predictions without a closing grid boundary are not drawn
    for prediction, left, right in zip(multiclass_predictions, boundaries[:-1], boundaries[1:]):
        bar_color, row = row_style.get(prediction, ('black', 0))
        ax.barh(row, right - left, left=left, height=0.8,
                color=bar_color, alpha=0.8, edgecolor='white', linewidth=0.5)

    # Axis ranges and one y tick per class row, labelled in class-index order
    ax.set_xlim([0, total_seconds])
    ax.set_ylim([-0.5, 3.5])
    ax.set_yticks([row for _, row in row_style.values()])
    ax.set_yticklabels([class_names[idx].title() for idx in sorted(row_style)])

    # Fall back to a title derived from the audio file name
    if not title:
        file_stem = os.path.splitext(os.path.basename(audio_features.audio_path))[0]
        title = f'Energy Level Timeline for {file_stem}'
    ax.set_title(title)

    # mm:ss tick every 30 seconds
    tick_positions = np.arange(0, total_seconds, 30)
    tick_texts = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in tick_positions]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_texts)
    ax.set_xlabel('Time (mm:ss)')
    ax.set_ylabel('Energy Level')

    ax.grid(True, axis='x', alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    return fig


def plot_energy_distribution(multiclass_predictions, class_names, title=None):
    """
    Show how often each class was predicted, as a pie chart next to a bar chart.

    Parameters:
    - multiclass_predictions: Array of predicted class indices
    - class_names: List of class names; index i names class i
    - title: Optional overall title for the figure

    Returns:
    - fig: The matplotlib figure object
    """
    fig, (pie_ax, bar_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Classes that actually occur, with the number of meters assigned to each
    present_classes, meter_counts = np.unique(multiclass_predictions, return_counts=True)

    # Same per-class colours as the timeline plot
    palette = ['lightgray', 'red', 'blue', 'orange']
    shown_names = [class_names[idx].title() for idx in present_classes]
    shown_colors = [palette[idx] for idx in present_classes]

    # Left panel: share of each class
    pie_ax.pie(meter_counts, labels=shown_names, colors=shown_colors,
               autopct='%1.1f%%', startangle=90)
    pie_ax.set_title('Energy Level Distribution')

    # Right panel: absolute counts
    bar_ax.bar(shown_names, meter_counts, color=shown_colors, alpha=0.8)
    bar_ax.set_title('Energy Level Counts')
    bar_ax.set_ylabel('Number of Meters')
    bar_ax.tick_params(axis='x', rotation=45)

    if title:
        fig.suptitle(title)

    plt.tight_layout()
    plt.show()
    return fig


# Usage example:
def visualize_predictions_complete(audio_features, predicted_classes, class_names):
    """
    Render all three prediction plots for one song.

    The predictions are first cut down to the number of meters the song
    actually has (grid intervals), so padded positions are not drawn.

    Returns the trimmed predictions.
    """
    # Number of real meters: grid intervals, capped by the predictions available
    n_real_meters = min(len(audio_features.meter_grid) - 1, len(predicted_classes))
    trimmed = predicted_classes[:n_real_meters]

    print(f"Creating visualizations for {n_real_meters} meters...")

    # 1) waveform overlay, 2) per-class timeline, 3) distribution charts
    plot_multiclass_predictions(audio_features, trimmed, class_names)
    plot_energy_timeline(audio_features, trimmed, class_names)
    plot_energy_distribution(trimmed, class_names)

    return trimmed

In [None]:
# Re-run preprocessing to obtain the AudioFeature object for plotting; the
# padded model input (first return value) is not needed here.
# NOTE(review): predict_song already processed this file once, so the audio is
# analysed twice.
_, audio_features = process_audio(full_path_wip)

visualize_predictions_complete(audio_features, pred_labels, class_names)

In [None]:

# File name of the second song inside `new_path` (placeholder, fill in).
# NOTE(review): the following cell repeats these two assignments verbatim.
new_song_2 = ""

full_path_wip_2 = new_path / new_song_2

In [None]:

# Second song: predict, then visualize (same steps as for the first song).
new_song_2 = ""

full_path_wip_2 = new_path / new_song_2

# One predicted class index per meter position
pred_labels, meter_grid = predict_song(full_path_wip_2, new_model, class_names)

# Preprocess again to get the AudioFeature object needed by the plots
_, audio_features = process_audio(full_path_wip_2)

visualize_predictions_complete(audio_features, pred_labels, class_names)

In [None]:

# Third song: predict, then visualize (same steps as for the other songs).
new_song_3 = ""

full_path_wip_3 = new_path / new_song_3

print(full_path_wip_3)

# NOTE(review): the path is passed as str here but as a Path object to
# process_audio below and in the other cells -- confirm which form is intended.
pred_labels, meter_grid = predict_song(str(full_path_wip_3), new_model, class_names)

# Preprocess again to get the AudioFeature object needed by the plots
_, audio_features = process_audio(full_path_wip_3)

visualize_predictions_complete(audio_features, pred_labels, class_names)

In [None]:
# Listen to the third song with the inline audio player defined near the top of the notebook
play_audio(full_path_wip_3)

# Pickle dataset - boilerplate that needs adjusting

In [None]:
# import tensorflow as tf
# import numpy as np
# import glob
# import os

# # --- Placeholder for your custom audio processing functions ---
# # It's assumed these functions are defined elsewhere and work correctly.
# # For this script to be runnable, you must provide their actual implementations.

# SR = 22050  # Example Sample Rate
# HOP_LENGTH = 512 # Example Hop Length

# def strip_silence(audio_path):
#     """Placeholder: Implement your silence trimming logic here."""
#     # print(f"Stripping silence from {audio_path}...")
#     pass

# class AudioFeature:
#     """Placeholder: Implement your AudioFeature class."""
#     def __init__(self, audio_path, sr, hop_length):
#         self.path = audio_path
#         self.sr = sr
#         self.hop_length = hop_length
#         self.combined_features = np.random.rand(128, 500).astype(np.float32) # Example shape

#     def extract_features_from_audio(self):
#         # print("Extracting features...")
#         pass

#     def create_meter_grid(self):
#         # print("Creating meter grid...")
#         return np.random.rand(500).astype(np.float32) # Example shape

# def segment_data_meters(features, grid):
#     """Placeholder: Implement your segmentation logic."""
#     # print("Segmenting data...")
#     return np.random.rand(10, 128, 50).astype(np.float32) # Example shape

# def apply_hierarchical_positional_encoding(segments):
#     """Placeholder: Implement your encoding logic."""
#     # print("Applying encoding...")
#     return segments + np.random.rand(*segments.shape).astype(np.float32)

# def pad_song(encoded_segments):
#     """Placeholder: Implement your padding logic."""
#     # print("Padding song...")
#     # Example: pad to a fixed number of segments, e.g., 12
#     padded = np.zeros((12, 128, 50), dtype=np.float32)
#     num_segments = min(12, encoded_segments.shape[0])
#     padded[:num_segments, :, :] = encoded_segments[:num_segments, :, :]
#     return padded

# # --- End of Placeholders ---


# def process_audio(audio_path, trim_silence=True, sr=SR, hop_length=HOP_LENGTH):
#     """
#     Processes a single audio file using your custom pipeline.
#     This function now takes a byte string path from tf.py_function.
#     """
#     # Decode the byte string path to a regular Python string
#     audio_path = audio_path.decode('utf-8')

#     try:
#         # 1. Optionally strip silence
#         if trim_silence:
#             strip_silence(audio_path)

#         # 2. Extract audio features
#         audio_features = AudioFeature(audio_path, sr=sr, hop_length=hop_length)
#         audio_features.extract_features_from_audio()
#         meter_grid = audio_features.create_meter_grid()

#         # 3. Segment and pad the data
#         feature_segments = segment_data_meters(audio_features.combined_features, meter_grid)
#         encoded_segments = apply_hierarchical_positional_encoding(feature_segments)
#         padded_song = pad_song(encoded_segments)

#         # Note: The original returned np.expand_dims(padded_song, axis=0)
#         # We remove the batch dimension here because the tf.data pipeline will add it later.
#         return padded_song
#     except Exception as e:
#         print(f"Error processing audio at path {audio_path}: {e}")
#         # Return a zero-array or handle error appropriately for your model
#         # The shape must match the expected output shape.
#         return np.zeros((12, 128, 50), dtype=np.float32)


# def load_and_process_item(file_path):
#     """
#     A wrapper function to load audio and extract the label.
#     This function will be mapped over the dataset of file paths.
#     """
#     # 1. Extract Label from filename
#     # The file_path is a tf.string tensor, so we use TensorFlow string operations
#     parts = tf.strings.split(file_path, os.path.sep)
#     filename = parts[-1]
#     label_str = tf.strings.split(filename, "_")[-1]
#     label_str = tf.strings.split(label_str, ".")[0] # Assumes .wav or similar
#     label = tf.strings.to_number(label_str, out_type=tf.int32)

#     # 2. Process audio using your custom Python function
#     # We use tf.py_function to wrap the Python code.
#     # We must define the return type (Tout) for TensorFlow to build the graph.
#     processed_audio = tf.py_function(
#         func=process_audio,
#         inp=[file_path],
#         Tout=tf.float32
#     )

#     # Set the shape of the output tensor, which is required after tf.py_function
#     # Update this shape to match the exact output of your `pad_song` function.
#     processed_audio.set_shape([12, 128, 50])

#     return processed_audio, label


# def create_preprocessed_dataset(audio_dir, output_path, file_extension="*.wav"):
#     """
#     Creates and saves a preprocessed TensorFlow dataset.

#     Args:
#         audio_dir (str): Path to the folder containing audio files.
#         output_path (str): Path to save the processed tf.data.Dataset.
#         file_extension (str): The file extension pattern to search for (e.g., "*.wav", "*.mp3").
#     """
#     print(f"🔍 Searching for audio files in: {audio_dir}")
#     file_paths = glob.glob(os.path.join(audio_dir, file_extension))

#     if not file_paths:
#         print("⚠️ No audio files found. Please check the `audio_dir` and `file_extension`.")
#         return

#     print(f"✅ Found {len(file_paths)} files. Creating dataset...")

#     # 1. Create a dataset from the list of file paths
#     path_ds = tf.data.Dataset.from_tensor_slices(file_paths)

#     # 2. Use .map() to apply the processing function to each file
#     # `num_parallel_calls=tf.data.AUTOTUNE` processes multiple files in parallel for speed.
#     audio_ds = path_ds.map(load_and_process_item, num_parallel_calls=tf.data.AUTOTUNE)

#     # 3. Save the processed dataset to disk
#     print(f"💾 Saving dataset to: {output_path}")
#     # The new, stable API for saving datasets
#     audio_ds.save(output_path)
#     print("✨ Dataset creation complete!")


# if __name__ == "__main__":
#     # --- Configuration ---
#     # Define the path to your folder with raw audio files
#     SAMPLES_FOLDER = full_path_wip_2
#     # Define where you want to save the processed dataset
#     OUTPUT_DATASET_PATH = '/content/drive/MyDrive/DSR-AI-MENTOR'

#     # Create the dataset
#     create_preprocessed_dataset(SAMPLES_FOLDER, OUTPUT_DATASET_PATH, file_extension="*.wav")

#     # --- Example of how to load and use the dataset ---
#     print("\n--- Verifying the saved dataset ---")
#     if os.path.exists(OUTPUT_DATASET_PATH):
#         loaded_dataset = tf.data.Dataset.load(OUTPUT_DATASET_PATH)

#         # Take one sample to inspect its shape and type
#         for audio_data, label in loaded_dataset.take(1):
#             print(f"Sample audio data shape: {audio_data.shape}")
#             print(f"Sample audio data type: {audio_data.dtype}")
#             print(f"Sample label: {label.numpy()}")
#             print(f"Sample label type: {label.dtype}")

#         # You can now use this loaded_dataset for training
#         # For example:
#         # loaded_dataset = loaded_dataset.shuffle(1024).batch(32).prefetch(tf.data.AUTOTUNE)
#         # model.fit(loaded_dataset, epochs=10)
#     else:
#         print(f"Could not find the saved dataset at {OUTPUT_DATASET_PATH}")


