# 1) Install and Verify Dependencies

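This section lists the required packages (the install command below is commented out — uncomment it when running in a fresh environment) and checks whether the optional `ffmpeg` system binary is available; `pydub` relies on it to decode audio formats beyond plain WAV.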

- Core: librosa, tensorflow, scikit-learn, matplotlib, seaborn
- Utilities: pydub (fallback loader), resampy (resampling), soundfile
- Optional: ffmpeg (system binary) for broader audio format support with pydub

In [None]:
# (Optional) Install packages if missing. Uncomment if running in a fresh environment.
# %pip install -q librosa tensorflow scikit-learn matplotlib seaborn pydub resampy soundfile

import shutil
import sys

# Locate the ffmpeg binary on PATH; pydub uses it for broader format support.
ffmpeg_path = shutil.which("ffmpeg")

print("Python:", sys.version)
print("FFmpeg found:", ffmpeg_path is not None)

In [None]:
# 2) Imports, GPU Check, and Deterministic Seeds
import os
import random
import numpy as np
import pandas as pd
import librosa
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Reproducibility: seed the Python, NumPy and TensorFlow random generators
# with the same value so shuffles, splits and weight initialisation repeat.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# GPU memory growth (optional): enable it on every visible GPU; failures are
# reported but not fatal.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU detected; running on CPU")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"GPUs available: {len(gpus)}; memory growth enabled")
    except Exception as e:
        print("Could not set memory growth:", e)

In [None]:
# 3) Configure Dataset Paths and Target Genres
DATASET_PATH = "data/Data/genres_original"
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop',
          'jazz', 'metal', 'pop', 'reggae', 'rock']
SAMPLE_RATE = 22050  # Hz; default rate extract_features loads audio at
DURATION = 30  # seconds; every clip is padded/trimmed to this length

# Validate dataset path. An explicit raise (instead of `assert`) keeps the
# check active even when Python runs with optimisations (-O), which strips
# assert statements.
if not os.path.isdir(DATASET_PATH):
    raise FileNotFoundError(f"Dataset path not found: {DATASET_PATH}")

# Validate genre folders (missing ones are only warned about; later cells
# skip them).
missing = [g for g in GENRES
           if not os.path.isdir(os.path.join(DATASET_PATH, g))]
if missing:
    print("Warning: missing genre folders:", missing)
else:
    print("All target genre folders present")

In [None]:
# 4) Explore Dataset Files
from collections import defaultdict

# Up to five example audio paths per genre folder that exists on disk.
per_genre_paths = defaultdict(list)

for genre in GENRES:
    gpath = os.path.join(DATASET_PATH, genre)
    if not os.path.isdir(gpath):
        continue  # already reported as missing in the config cell
    files = sorted(f for f in os.listdir(gpath)
                   if f.lower().endswith((".wav", ".mp3", ".flac", ".ogg")))
    # Indexing the defaultdict registers the genre even when `files` is
    # empty, so folders without audio are reported below instead of being
    # silently dropped.
    per_genre_paths[genre].extend(os.path.join(gpath, f) for f in files[:5])

for genre, items in per_genre_paths.items():
    print(f"{genre} (showing {len(items)} files):")
    for p in items:
        print("  ", p)
    if not items:
        print("   [No audio files found]")

In [None]:
# 5) Robust Audio Loading and Mel-Spectrogram Extraction
import warnings


def extract_features(file_path, sr=SAMPLE_RATE, duration=DURATION):
    """
    Load an audio file and return its dB-scaled Mel-spectrogram.

    Loading is attempted in order with:
      1. ``librosa.load`` (default resampler);
      2. ``librosa.load`` with ``res_type='kaiser_fast'`` (warnings muted
         for this attempt only);
      3. pydub (any channel count, down-mixed to mono), resampled with
         resampy when the file's rate differs from ``sr``.
    If every attempt fails, the failure is printed and an all-zero signal is
    used so batch extraction can continue.

    The signal is zero-padded or truncated to exactly ``int(sr * duration)``
    samples, then converted to a 128-band Mel-spectrogram
    (n_fft=2048, hop_length=512) in dB relative to its own peak.

    Args:
        file_path: path to the audio file.
        sr: target sample rate in Hz.
        duration: clip length in seconds.

    Returns:
        np.ndarray of shape (128, n_frames).
    """
    target_len = int(sr * duration)
    y = None
    try:
        y, _ = librosa.load(file_path, sr=sr, duration=duration)
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        try:
            # Scope the suppression to this retry; a bare
            # warnings.filterwarnings('ignore') would mute every warning for
            # the rest of the session.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                y, _ = librosa.load(
                    file_path, sr=sr, duration=duration, res_type='kaiser_fast')
        except Exception as e2:
            print(f"Second attempt failed: {e2}")
            try:
                y = _load_mono_with_pydub(file_path, sr)
            except Exception as e3:
                print(f"All loading methods failed for {file_path}: {e3}")
                y = np.zeros(target_len, dtype=np.float32)

    if len(y) < target_len:
        y = np.pad(y, (0, target_len - len(y)))
    elif len(y) > target_len:
        y = y[:target_len]

    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=2048, hop_length=512, n_mels=128)
    mel_db = librosa.power_to_db(mel, ref=np.max)
    return mel_db


def _load_mono_with_pydub(file_path, sr):
    """Decode with pydub; return mono float32 samples at ``sr`` Hz, scaled to roughly [-1, 1]."""
    from pydub import AudioSegment
    audio = AudioSegment.from_file(file_path)
    samples = np.array(audio.get_array_of_samples())
    # Take the full-scale value from the *integer* sample dtype before
    # down-mixing: averaging yields float64, for which np.iinfo raises.
    full_scale = float(np.iinfo(samples.dtype).max)
    if audio.channels > 1:
        # Samples are interleaved frame by frame: [c0, c1, ..., c0, c1, ...].
        samples = samples.reshape((-1, audio.channels)).mean(axis=1)
    y = samples.astype(np.float32) / full_scale
    if sr != audio.frame_rate:
        import resampy
        y = resampy.resample(y, audio.frame_rate, sr)
    return y

In [None]:
# 6) Dataset Assembly and Label Encoding
from typing import List, Tuple, Dict


def prepare_dataset(dataset_path: str, genres: List[str], min_samples_per_class: int | None = None):
    """
    Extract a Mel-spectrogram for every audio file under
    ``dataset_path/<genre>`` and integer-encode the genre labels.

    Files are processed in sorted name order so the resulting sample order,
    and therefore any seeded split done on it, is reproducible across
    machines (``os.listdir`` order is arbitrary).

    Args:
        dataset_path: root folder containing one sub-folder per genre.
        genres: genre folder names to load; missing folders are skipped.
        min_samples_per_class: despite its name, an upper bound. Once a genre
            has this many successfully extracted samples, its remaining files
            are skipped. ``None`` (or 0) loads every file.

    Returns:
        X: float32 array (N, 128, T) of spectrograms.
        y: int64 array (N,) of encoded labels.
        label_encoder: the fitted LabelEncoder.
        samples_per_genre: dict mapping genre -> number of samples loaded.
    """
    features = []
    labels = []
    samples_per_genre: Dict[str, int] = {g: 0 for g in genres}

    for genre in genres:
        genre_path = os.path.join(dataset_path, genre)
        if not os.path.isdir(genre_path):
            print(f"Skipping missing genre: {genre}")
            continue
        files = sorted(f for f in os.listdir(genre_path) if f.lower().endswith(
            (".wav", ".mp3", ".flac", ".ogg")))
        for fname in files:
            fpath = os.path.join(genre_path, fname)
            try:
                mel = extract_features(fpath)
            except Exception as e:
                print(f"Error processing {fpath}: {e}")
                continue
            features.append(mel)
            labels.append(genre)
            samples_per_genre[genre] += 1
            if min_samples_per_class and samples_per_genre[genre] >= min_samples_per_class:
                break

    for genre, count in samples_per_genre.items():
        print(f"{genre}: {count} samples")

    X = np.array(features, dtype=np.float32)
    le = LabelEncoder()
    y = np.asarray(le.fit_transform(labels), dtype=np.int64)
    return X, y, le, samples_per_genre

In [None]:
# 7) Data Quality Check and Visualizations

def plot_mel_spectrogram(mel_spectrogram, title='Mel Spectrogram'):
    """Display a (mel bands x frames) dB spectrogram with a colour bar.

    The lowest Mel band is drawn at the bottom (``origin='lower'``).
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(mel_spectrogram, aspect='auto',
                      origin='lower', cmap='viridis')
    fig.colorbar(image, ax=ax, format='%+2.0f dB')
    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_ylabel('Mel Frequency')
    fig.tight_layout()
    plt.show()


def check_data_quality(X: np.ndarray, threshold: float = -60.0):
    """Flag spectrograms whose overall mean level falls below ``threshold``.

    Each flagged sample is printed with its mean value, followed by a
    one-line summary.

    Returns:
        list[int]: indices of the flagged samples, in ascending order.
    """
    sample_means = [np.mean(spec) for spec in X]
    flagged = [idx for idx, level in enumerate(sample_means) if level < threshold]

    for idx in flagged:
        print(
            f"Warning: Sample {idx} may have low quality (mean dB: {sample_means[idx]:.2f})")

    if not flagged:
        print("All samples passed quality check.")
    else:
        print(
            f"Found {len(flagged)} potentially problematic samples out of {len(X)}")
    return flagged


def check_data_balance(y: np.ndarray, label_encoder: LabelEncoder):
    """Plot and print per-class sample counts and warn about imbalance.

    Args:
        y: integer-encoded labels (as produced by ``label_encoder``).
        label_encoder: fitted encoder whose ``classes_`` name each label.

    Returns:
        (class_counts, class_names): counts aligned index-for-index with
        ``label_encoder.classes_``.
    """
    class_names = label_encoder.classes_
    # minlength keeps counts aligned with class_names even when the highest
    # encoded labels are absent from y (e.g. when checking a split subset);
    # without it plt.bar would receive arrays of different lengths.
    class_counts = np.bincount(y, minlength=len(class_names))

    plt.figure(figsize=(10, 6))
    plt.bar(class_names, class_counts)
    plt.title('Class Distribution')
    plt.xlabel('Genre')
    plt.ylabel('Number of Samples')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    for name, count in zip(class_names, class_counts):
        print(f"{name}: {count} samples")

    # A class with zero samples is the most extreme imbalance; the original
    # ratio test skipped it entirely (and would divide by zero otherwise).
    empty_classes = [name for name, count in zip(class_names, class_counts) if count == 0]
    if empty_classes:
        print(f"\nWarning: no samples for classes: {empty_classes}")
    elif class_counts.size and class_counts.max() / class_counts.min() > 1.5:
        print("\nWarning: Data imbalance detected. Consider resampling or using class weights.")

    return class_counts, class_names

In [None]:
# 8) Train/Validation/Test Split and Tensor Reshape

def create_train_test_data(X: np.ndarray, y: np.ndarray, test_size: float = 0.2, val_size: float = 0.2):
    """Split into stratified train/validation/test sets and add a channel axis.

    The test set is carved out first (``test_size`` of all samples). The
    validation set is then ``val_size`` of the *remaining* portion, so with
    the defaults the final proportions are 64% / 16% / 20%. Both splits are
    stratified on the labels and seeded with ``SEED``.

    Returns:
        X_train, X_val, X_test (each with a trailing channel axis of size 1),
        y_train, y_val, y_test.
    """
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=SEED
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_size, stratify=y_rest, random_state=SEED
    )

    # (N, freq_bins, time_frames) -> (N, freq_bins, time_frames, 1) for Conv2D.
    X_train, X_val, X_test = (np.expand_dims(part, axis=-1)
                              for part in (X_train, X_val, X_test))
    return X_train, X_val, X_test, y_train, y_val, y_test

In [None]:
# 9) CRNN Model Definition
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dropout, Reshape, Dense, LSTM, Bidirectional
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model


def create_crnn_model(input_shape, num_classes, learning_rate=8e-4):
    """Build and compile a CRNN genre classifier.

    Architecture: five 3x3 convolution layers (with batch norm, pooling and
    dropout), followed by two stacked bidirectional LSTMs over the time axis,
    two ReLU dense layers and a softmax output. All hidden layers use L2
    weight regularisation.

    Args:
        input_shape: per-sample shape, presumably (n_mels, time_frames, 1) as
            produced by create_train_test_data — the Permute below assumes
            frequency comes before time.
        num_classes: number of output classes.
        learning_rate: Adam learning rate.

    Returns:
        A compiled ``tf.keras.Model`` (sparse categorical cross-entropy,
        accuracy metric), so labels are expected as integer class ids.
    """
    def reg(factor):
        return tf.keras.regularizers.l2(factor)

    def conv_bn(tensor, filters, l2_factor):
        # 3x3 same-padded ReLU convolution followed by batch normalisation.
        tensor = Conv2D(filters, (3, 3), padding='same', activation='relu',
                        kernel_regularizer=reg(l2_factor))(tensor)
        return BatchNormalization()(tensor)

    def pool_and_drop(tensor, pool_size, rate):
        tensor = MaxPooling2D(pool_size)(tensor)
        return Dropout(rate)(tensor)

    inputs = Input(shape=input_shape)

    # Convolutional front-end.
    x = pool_and_drop(conv_bn(inputs, 32, 0.0015), (2, 2), 0.25)
    x = pool_and_drop(conv_bn(x, 64, 0.0015), (2, 2), 0.25)
    x = conv_bn(x, 128, 0.0015)
    x = pool_and_drop(conv_bn(x, 128, 0.0015), (2, 2), 0.3)
    x = pool_and_drop(conv_bn(x, 256, 0.002), (4, 4), 0.4)

    # (batch, freq, time, ch) -> (batch, time, freq, ch) -> (batch, time, freq*ch)
    # so each pooled time step becomes one feature vector for the RNN.
    x = tf.keras.layers.Permute((2, 1, 3))(x)
    x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(x)

    # Recurrent back-end: the first BiLSTM keeps the sequence, the second
    # summarises it into a single vector.
    for keep_sequence in (True, False):
        x = Bidirectional(LSTM(128, return_sequences=keep_sequence,
                               recurrent_dropout=0.1,
                               recurrent_regularizer=reg(0.002)))(x)
        x = Dropout(0.4)(x)

    # Dense classifier head.
    for units in (256, 128):
        x = Dense(units, activation='relu', kernel_regularizer=reg(0.002))(x)
        x = Dropout(0.5)(x)

    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

In [None]:
# 10) Training Utilities: Augmentation, Callbacks, and LR Schedule

def train_crnn_model(model, X_train, y_train, X_val=None, y_val=None, X_test=None,
                     batch_size=16, epochs=100,
                     model_path='crnn_music_genre_model', class_weights=None,
                     validation_split=None, max_T=259):
    """Train ``model`` with on-the-fly augmentation and save the results.

    Inputs are trimmed to their first ``max_T`` time frames. Validation uses
    ``(X_val, y_val)`` when both are given. Otherwise a seeded hold-out of
    ``validation_split`` (default 0.2) is taken from the training data; it is
    stratified when every class has at least two samples. Keras'
    own ``validation_split`` cannot be used here because training data comes
    from a generator.

    Side effects: writes ``{model_path}_best.keras`` (best val_accuracy),
    ``{model_path}_final.keras``, ``{model_path}_history.npy``,
    ``{model_path}_training_history.png`` and TensorBoard logs under
    ``./logs/{model_path}``; shows the history plot.

    Args:
        model: compiled Keras model whose input width matches ``max_T``.
        X_train, y_train: training spectrograms (N, freq, time, 1) and labels.
        X_val, y_val: optional validation set.
        X_test: optional test set; only trimmed and returned.
        batch_size, epochs: passed to ``fit``.
        model_path: prefix for every saved artifact.
        class_weights: optional ``{class_id: weight}`` for ``fit``.
        validation_split: hold-out fraction used when no validation set is
            given.
        max_T: number of leading time frames kept.

    Returns:
        (history, X_train, X_val, X_test) with the inputs time-trimmed
        (X_val/X_test stay ``None`` if not provided).
    """
    # Trim spectrogram time dimension to a consistent width.
    X_train = X_train[:, :, :max_T, :]
    if X_val is not None:
        X_val = X_val[:, :, :max_T, :]
    if X_test is not None:
        X_test = X_test[:, :, :max_T, :]

    if X_val is not None and y_val is not None:
        X_fit, y_fit = X_train, y_train
        validation_data = (X_val, y_val)
    else:
        # Hold out a shuffled split manually. NOTE(review): if X_train was
        # oversampled, duplicates can land on both sides of this split.
        X_fit, X_hold, y_fit, y_hold = _holdout_split(
            X_train, y_train, validation_split or 0.2)
        validation_data = (X_hold, y_hold)

    # Fill augmented-in regions with the quietest level seen in the data.
    fill_value = float(X_fit.min()) if X_fit.size else 0.0
    datagen = _make_spectrogram_augmenter(fill_value)

    history = model.fit(
        datagen.flow(X_fit, y_fit, batch_size=batch_size),
        steps_per_epoch=max(1, len(X_fit) // batch_size),
        epochs=epochs,
        validation_data=validation_data,
        callbacks=_make_training_callbacks(model_path),
        class_weight=class_weights,
        verbose=1
    )

    model.save(f'{model_path}_final.keras')
    np.save(f'{model_path}_history.npy', history.history)
    _plot_training_history(history, model_path)

    return history, X_train, X_val, X_test


def _holdout_split(X, y, fraction):
    """Seeded shuffled hold-out split, stratified on ``y`` when possible."""
    try:
        return train_test_split(X, y, test_size=fraction, stratify=y,
                                random_state=SEED)
    except ValueError:
        # Stratification needs >= 2 samples per class; fall back to shuffling.
        return train_test_split(X, y, test_size=fraction, random_state=SEED)


def _make_spectrogram_augmenter(fill_value=0.0):
    """ImageDataGenerator applying small random shifts, zoom, rotation and brightness."""
    return tf.keras.preprocessing.image.ImageDataGenerator(
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15,
        rotation_range=5,
        brightness_range=[0.8, 1.2],
        fill_mode='constant',
        # Value painted into regions exposed by shift/zoom/rotation. Inputs
        # from extract_features are dB relative to each clip's peak
        # (power_to_db with ref=np.max), so the default cval=0 would inject
        # the loudest possible level; the caller passes the data minimum.
        cval=fill_value,
        horizontal_flip=False
    )


def _make_training_callbacks(model_path):
    """Early stopping, best-checkpointing, LR decay on plateau, and TensorBoard."""
    return [
        EarlyStopping(monitor='val_loss', patience=25,
                      restore_best_weights=True, verbose=1),
        ModelCheckpoint(f'{model_path}_best.keras', monitor='val_accuracy',
                        save_best_only=True, mode='max', verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.15,
                          patience=5, min_lr=1e-7, verbose=1),
        tf.keras.callbacks.TensorBoard(
            log_dir=f'./logs/{model_path}', histogram_freq=1, update_freq='epoch')
    ]


def _plot_training_history(history, model_path):
    """Plot train/validation accuracy and loss curves and save them as PNG."""
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history.history.get('accuracy', []))
    plt.plot(history.history.get('val_accuracy', []))
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(['Train', 'Validation'])

    plt.subplot(1, 2, 2)
    plt.plot(history.history.get('loss', []))
    plt.plot(history.history.get('val_loss', []))
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['Train', 'Validation'])
    plt.tight_layout()
    plt.savefig(f'{model_path}_training_history.png')
    plt.show()

In [None]:
# 11) Class Imbalance Handling (Weights and Oversampling)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils import resample


def compute_class_weights(y: np.ndarray):
    """Return scikit-learn 'balanced' class weights as a plain dict.

    Weights are computed only for labels that occur in ``y``. Keys are
    Python ints and values Python floats, the form passed as
    ``class_weight`` to ``model.fit`` in this notebook.
    """
    present_labels = np.unique(y)
    balanced = compute_class_weight(
        class_weight='balanced', classes=present_labels, y=y)
    return dict(zip(map(int, present_labels), map(float, balanced)))


def oversample_minority_classes(X: np.ndarray, y: np.ndarray, target_class_count: int | None = None):
    """Randomly duplicate samples of under-represented classes.

    Every class with fewer than ``target_class_count`` samples is resampled
    with replacement (seeded with ``SEED``) up to exactly that count. Larger
    classes are kept unchanged, never undersampled. The output is grouped by
    class in ascending label order and is not shuffled.

    Args:
        X: samples, indexed along axis 0.
        y: 1-D labels aligned with ``X``.
        target_class_count: desired per-class size; defaults to the size of
            the largest class.

    Returns:
        (X_out, y_out). Empty inputs are returned unchanged; previously they
        crashed on ``bincount([]).max()``.
    """
    if len(y) == 0:
        return X, y

    classes, counts = np.unique(y, return_counts=True)
    if target_class_count is None:
        target_class_count = int(counts.max())

    X_parts = []
    y_parts = []
    for class_id, count in zip(classes, counts):
        mask = y == class_id  # computed once per class
        X_class, y_class = X[mask], y[mask]
        if count < target_class_count:
            X_class, y_class = resample(
                X_class, y_class, replace=True,
                n_samples=target_class_count, random_state=SEED
            )
        X_parts.append(X_class)
        y_parts.append(y_class)

    return np.concatenate(X_parts, axis=0), np.concatenate(y_parts, axis=0)

In [None]:
# 12) Model Training and History Plots
# Prepare data
X, y, label_encoder, samples_per_genre = prepare_dataset(DATASET_PATH, GENRES)
print("Features shape:", X.shape, "Labels shape:", y.shape)

# Quality & balance
_ = check_data_quality(X, threshold=-60)
_ = check_data_balance(y, label_encoder)

# Split
X_train, X_val, X_test, y_train, y_val, y_test = create_train_test_data(X, y)
print("Train:", X_train.shape, "Val:", X_val.shape, "Test:", X_test.shape)

# Handle imbalance. The training set is balanced by oversampling, so class
# weights are derived from the labels actually fed to fit(). Weights computed
# on the original (or full, test-inclusive) distribution would correct the
# imbalance a second time on already-balanced data.
X_train_bal, y_train_bal = oversample_minority_classes(X_train, y_train)
print("Balanced train shape:", X_train_bal.shape)

class_weights = compute_class_weights(y_train_bal)
print("Class weights:", class_weights)

# Build model. train_crnn_model keeps only the first 259 time frames, so the
# model's declared input width must match the trimmed width rather than the
# raw spectrogram width.
MAX_T = 259
input_shape = (X_train_bal.shape[1], min(X_train_bal.shape[2], MAX_T), 1)
num_classes = len(np.unique(y))
model = create_crnn_model(input_shape, num_classes, learning_rate=8e-4)
model.summary()

# Train
history, X_train_proc, X_val_proc, X_test_proc = train_crnn_model(
    model,
    X_train_bal, y_train_bal,
    X_val, y_val,
    X_test,
    batch_size=16,
    epochs=150,
    model_path='crnn_music_genre_model',
    class_weights=class_weights
)

In [None]:
# 13) Evaluation: Metrics and Confusion Matrix
from sklearn.metrics import classification_report, confusion_matrix


def evaluate_model(model, X_test, y_test, label_encoder):
    """Report test-set metrics, plot the confusion matrix and list mistakes.

    Args:
        model: trained classifier whose ``predict`` returns one row of class
            probabilities per sample (softmax output).
        X_test: test inputs, already shaped/trimmed for ``model``.
        y_test: integer-encoded true labels.
        label_encoder: fitted LabelEncoder used to encode the labels.

    Side effects: prints a classification report, saves the confusion matrix
    to ``confusion_matrix.png`` and shows it, and prints up to five
    misclassified examples with the predicted class probability.
    """
    y_pred = model.predict(X_test)
    y_pred_cls = np.argmax(y_pred, axis=1)
    y_true = np.asarray(y_test)

    class_names = label_encoder.classes_
    # Pin the label set to every encoded class so the report and the matrix
    # line up with class_names even when some class is absent from this split
    # (otherwise target_names and the inferred labels can disagree and raise).
    labels = np.arange(len(class_names))

    print("Classification Report:")
    print(classification_report(y_true, y_pred_cls, labels=labels,
                                target_names=class_names))

    cm = confusion_matrix(y_true, y_pred_cls, labels=labels)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.show()

    mis = np.flatnonzero(y_pred_cls != y_true)
    if mis.size > 0:
        shown = mis[:5]
        print(f"\nMisclassified examples ({len(shown)} shown):")
        for idx in shown:
            # classes_[k] is the label LabelEncoder encodes as k.
            true_label = class_names[y_true[idx]]
            pred_label = class_names[y_pred_cls[idx]]
            conf = float(np.max(y_pred[idx]) * 100)
            print(
                f"Index {idx}: predicted {pred_label} ({conf:.2f}%) but true {true_label}")


# Run evaluation on the time-trimmed test set returned by train_crnn_model
_ = evaluate_model(model=model, X_test=X_test_proc,
                   y_test=y_test, label_encoder=label_encoder)

In [None]:
# 14) Save, Load, and Run Inference on New Audio
from tensorflow.keras.models import load_model

MODEL_BEST = 'crnn_music_genre_model_best.keras'

# Prefer the checkpoint written by ModelCheckpoint during training; if it is
# not on disk, fall back to the model object trained in this session.
if not os.path.isfile(MODEL_BEST):
    print(f"Best model not found at {MODEL_BEST}; using current trained model")
    best_model = model
else:
    best_model = load_model(MODEL_BEST)


def preprocess_audio_for_inference(filepath: str, max_T: int = 259):
    """Turn one audio file into a single-sample model input.

    Extracts the Mel-spectrogram, keeps its first ``max_T`` time frames
    (matching the trimming applied during training) and adds batch and
    channel axes, giving shape (1, 128, T, 1).
    """
    spectrogram = extract_features(filepath)[:, :max_T]
    return spectrogram[None, :, :, None]


# Provide your own audio path from the repo, e.g., an MP3 in audio/
example_audio = 'audio/OMG - NewJeans.mp3'
if not os.path.isfile(example_audio):
    print(f"Example audio not found at {example_audio}; skip demo.")
else:
    sample = preprocess_audio_for_inference(example_audio)
    # predict returns a batch of probability rows; take the only row.
    probs = best_model.predict(sample)[0]
    pred_idx = int(np.argmax(probs))
    pred_label = label_encoder.inverse_transform([pred_idx])[0]
    confidence_pct = float(np.max(probs)) * 100
    print(f"Predicted: {pred_label} (confidence {confidence_pct:.2f}%)")