In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import wave
from scipy.io import wavfile
from scipy.signal import spectrogram
from pydub import AudioSegment
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import SeparableConv2D, MaxPooling2D, Flatten, Dense, Dropout, Input
from keras.callbacks import EarlyStopping
from keras.utils import to_categorical
import tensorflow as tf

# --- Global configuration ---------------------------------------------------

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation, dropout and batch shuffling are repeatable across
# "Restart Kernel -> Run All". GPU/XLA kernels may still introduce small
# run-to-run differences.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Turn on XLA just-in-time compilation for TensorFlow graphs.
tf.config.optimizer.set_jit(True)

# Class names. load_data() expects one sub-directory per entry and uses the
# position in this list as the integer label, so the order must stay stable.
GENRES = [
    'blues', 'classical', 'country', 'disco', 'hiphop',
    'jazz', 'metal', 'pop', 'reggae', 'rock',
]

# Frame rate every clip is converted to by convert_to_standard_wav().
SAMPLE_RATE = 22050

# Number of spectrogram time columns each sample is padded/truncated to in
# extract_features(); this fixes the width of the network input.
FIXED_TIME_FRAMES = 200

def is_valid_wav(file_path):
    """Report whether `file_path` can be opened by the stdlib `wave` module.

    Args:
        file_path: Path of the file to probe.

    Returns:
        True if `wave.open` succeeds in parsing the header, False if the file
        is rejected as malformed/unsupported or ends before the header is
        complete (for example a zero-byte file).

    Other I/O failures, such as a missing file, still propagate as OSError.
    """
    try:
        with wave.open(file_path, 'rb'):
            return True
    except (wave.Error, EOFError):
        # wave.Error -> not a RIFF/WAVE file or an unsupported encoding.
        # EOFError   -> empty or truncated file; `wave` raises this instead of
        #               wave.Error when it runs out of bytes mid-header.
        return False

def convert_to_standard_wav(file_path):
    """Write a normalised copy of `file_path` next to it and return its path.

    The copy is mono, 16-bit (sample width 2 bytes) and resampled to
    SAMPLE_RATE. It is stored in the same directory as
    ``<original stem>_fixed.wav``.

    Args:
        file_path: Path to the source audio file.

    Returns:
        Path of the converted file, or None if the source fails
        `is_valid_wav` or the conversion raises. A message is printed in
        both failure cases.
    """
    if not is_valid_wav(file_path):
        print(f"Skipping {file_path} (corrupt or invalid WAV format)")
        return None

    try:
        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1).set_sample_width(2)

        # Build the output name from the file stem only. A plain
        # str.replace(".wav", ...) would also rewrite ".wav" inside directory
        # names, and would leave the path unchanged (overwriting the source)
        # when the name has no ".wav" in it.
        stem, _ = os.path.splitext(file_path)
        new_path = f"{stem}_fixed.wav"

        # export() hands back the open output file; close it explicitly
        # instead of relying on garbage collection.
        exported = audio.export(new_path, format="wav")
        exported.close()
        return new_path
    except Exception as e:
        # Best effort: one unreadable clip must not abort the whole load.
        print(f"Skipping {file_path} due to conversion error: {e}")
        return None

def extract_features(file_path):
    """Turn a WAV file into a fixed-width log-scaled spectrogram.

    Steps: read the file with `scipy.io.wavfile.read`, average across
    channels when the sample array has more than one dimension, compute a
    spectrogram with 512-sample segments, apply `log1p`, then zero-pad or
    truncate the time axis to exactly FIXED_TIME_FRAMES columns.

    Args:
        file_path: Path of the WAV file to read.

    Returns:
        The 2-D spectrogram array, or None (after printing a message) if any
        step raises.
    """
    try:
        rate, samples = wavfile.read(file_path)
        # Collapse multi-channel audio to a single channel.
        mono = samples if samples.ndim <= 1 else np.mean(samples, axis=1)

        _, _, power = spectrogram(mono, rate, nperseg=512)
        log_power = np.log1p(power)

        # Positive -> clip is too short and needs padding on the right;
        # zero or negative -> keep only the leading columns.
        missing = FIXED_TIME_FRAMES - log_power.shape[1]
        if missing > 0:
            return np.pad(log_power, ((0, 0), (0, missing)), mode='constant')
        return log_power[:, :FIXED_TIME_FRAMES]
    except Exception as e:
        print(f"Skipping {file_path} due to error: {e}")
        return None

def load_data(data_dir):
    """Load one spectrogram per clip for every genre in GENRES.

    Expects ``data_dir/<genre>/`` directories containing ``.wav`` files.
    Each original file is converted with `convert_to_standard_wav` (which
    writes ``<stem>_fixed.wav`` beside it) and the converted copy is fed to
    `extract_features`.

    Args:
        data_dir: Root directory holding one sub-directory per genre.

    Returns:
        A pair ``(features, labels)`` of NumPy arrays. ``labels[i]`` is the
        index in GENRES of the genre that ``features[i]`` came from. Clips
        that fail conversion or feature extraction are left out.
    """
    features, labels = [], []
    fixed_suffix = "_fixed.wav"

    for label, genre in enumerate(GENRES):
        genre_dir = os.path.join(data_dir, genre)
        # Sort so sample order (and therefore any seeded split downstream)
        # does not depend on the filesystem's listing order.
        file_names = sorted(os.listdir(genre_dir))
        present = set(file_names)

        for file in file_names:
            if not file.endswith('.wav'):
                continue
            file_path = os.path.join(genre_dir, file)

            if file.endswith(fixed_suffix):
                # Converted copies left by an earlier run live in the same
                # directory. If the original is still there it will be
                # converted and loaded on its own turn; loading this copy as
                # well would put the same clip in the dataset twice and leak
                # it across the train/test split.
                original_name = file[:-len(fixed_suffix)] + ".wav"
                if original_name in present:
                    continue
            else:
                file_path = convert_to_standard_wav(file_path)
                if file_path is None:
                    continue

            feature = extract_features(file_path)
            if feature is not None:
                features.append(feature)
                labels.append(label)

    return np.array(features), np.array(labels)

# --- Load data ----------------------------------------------------------------
# The dataset root can be overridden with the GENRE_DATA_DIR environment
# variable; otherwise the original hard-coded location is used.
data_dir = os.environ.get("GENRE_DATA_DIR", r'D:/Projects/DeepLearning/Data/genres_original')
X, y = load_data(data_dir)

if len(X) == 0:
    raise ValueError("No valid data was found. Ensure your dataset contains valid WAV files.")

# Add the trailing channel axis expected by the 2-D convolution layers.
X = X[..., np.newaxis]

# Keep the integer labels for stratification. Pass num_classes explicitly so
# the one-hot width always matches the output layer, even if the
# highest-indexed genre contributed no usable files.
class_ids = y
y = to_categorical(class_ids, num_classes=len(GENRES))

# --- Split ----------------------------------------------------------------------
# Hold out 20% as a test set that is never used during training, then carve a
# validation set out of the remainder for early stopping. Both splits are
# stratified so every genre keeps its proportion; this requires at least two
# samples per genre.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=class_ids)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42,
    stratify=np.argmax(y_train, axis=1))

print(f"Input Shape: {X_train.shape}")

# --- Model ----------------------------------------------------------------------
model = Sequential([
    Input(shape=(X.shape[1], X.shape[2], 1)),  # Explicit Input layer
    SeparableConv2D(32, (3, 3), activation='relu', padding='same'),
    MaxPooling2D(pool_size=(2, 2)),
    SeparableConv2D(64, (3, 3), activation='relu', padding='same'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.4),
    Dense(len(GENRES), activation='softmax')
])

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# --- Train ----------------------------------------------------------------------
# Early stopping picks the best epoch on the validation split. Monitoring the
# test split here would tune the model on the very data used for the final
# score and make that score optimistic.
early_stopping = EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)

history = model.fit(X_train, y_train, validation_data=(X_val, y_val),
                    epochs=50, batch_size=64, callbacks=[early_stopping])

# --- Evaluate ---------------------------------------------------------------------
loss, accuracy = model.evaluate(X_test, y_test)
print(f'Test Accuracy: {accuracy * 100:.2f}%')

# --- Learning curves ----------------------------------------------------------------
# One figure per metric, training vs. validation.
for metric, title, ylabel in (
    ('accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
):
    plt.plot(history.history[metric], label=metric)
    plt.plot(history.history[f'val_{metric}'], label=f'val_{metric}')
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(loc='upper left')
    plt.show()

Input Shape: (1704, 257, 200, 1)
Epoch 1/50
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 619ms/step - accuracy: 0.2001 - loss: 4.3895 - val_accuracy: 0.3286 - val_loss: 1.9463
Epoch 2/50
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 582ms/step - accuracy: 0.2848 - loss: 1.9923 - val_accuracy: 0.4695 - val_loss: 1.7448
Epoch 3/50
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 562ms/step - accuracy: 0.3845 - loss: 1.7204 - val_accuracy: 0.5587 - val_loss: 1.4485
Epoch 4/50
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 546ms/step - accuracy: 0.4476 - loss: 1.5102 - val_accuracy: 0.6455 - val_loss: 1.3147
Epoch 5/50
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 568ms/step - accuracy: 0.5677 - loss: 1.1976 - val_accuracy: 0.7230 - val_loss: 1.0599
Epoch 6/50
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 543ms/step - accuracy: 0.6306 - loss: 1.0059 - val_accuracy: 0.7864 - val_lo

In [None]:
# Persist the trained network to "music_genre_classifier.h5"; the path is
# relative, so the file lands in the kernel's current working directory.
# Relies on `model` from the training cell above having been run first.
# NOTE(review): the ".h5" extension selects Keras' HDF5 format; newer Keras
# releases recommend the native ".keras" format. Confirm what downstream
# loaders expect before changing the filename.
model.save("music_genre_classifier.h5")

