In [2]:
import os

import joblib
import librosa
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler


In [3]:

# CNN architecture; must match the one whose weights are stored in the checkpoint.
class SimpleCNN(nn.Module):
    """Small 1-D CNN that classifies a flat feature vector.

    The feature vector is treated as a single-channel sequence of length
    ``input_dim``. It passes through two stages of Conv1d(kernel 3, padding 1)
    -> ReLU -> MaxPool1d(2), going from 1 to 32 to 64 channels. That is
    followed by a 128-unit ReLU hidden layer and a linear layer producing
    ``num_classes`` scores.

    The submodule attribute names (conv1, conv2, fc1, fc2) are the state_dict
    keys, so they must stay unchanged for saved weights to load.

    Args:
        input_dim: Length of the input feature vector.
        num_classes: Number of output classes (default 10).

    ``forward(x)`` assumes ``x`` is shaped (batch, 1, input_dim). It returns
    log-probabilities of shape (batch, num_classes).
    """

    def __init__(self, input_dim, num_classes=10):
        super().__init__()
        # kernel_size=3 with padding=1 preserves the sequence length.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # One pooling layer shared by both stages; each use floor-halves the length.
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # After two floor-halvings the length is (input_dim // 2) // 2 == input_dim // 4.
        pooled_len = (input_dim // 2) // 2
        self.fc1 = nn.Linear(64 * pooled_len, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return torch.log_softmax(self.fc2(hidden), dim=1)



In [4]:
# Trained Random Forest classifier (per its filename, fitted on the 3-second feature set).
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code —
# only load model files from a trusted source.
rf_model = joblib.load(filename='random_forest_3_sec.pkl')

In [5]:
# Rebuild the CNN architecture and load its trained weights.
input_dim = 57  # length of the vector from extract_features_from_file: 17 stats + 2 * 20 MFCC stats
num_classes = 10  # Number of genre classes
cnn_model = SimpleCNN(input_dim=input_dim, num_classes=num_classes)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine
# (the model built above lives on the CPU anyway). weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs,
# instead of allowing arbitrary objects.
cnn_model.load_state_dict(torch.load('cnn_model.pth', map_location='cpu', weights_only=True))
cnn_model.eval()  # switch to evaluation mode before inference

# Load the scaler used during training
scaler = joblib.load('scaler_3_sec.pkl')
# Fail fast if the scaler was fitted on a different number of features than the CNN expects.
assert getattr(scaler, 'n_features_in_', input_dim) == input_dim, (
    f'scaler was fitted on {scaler.n_features_in_} features, but the CNN expects {input_dim}'
)
# Load label encoder.
# NOTE(review): despite the "y_..._encoded" filename, predict_genre calls
# .inverse_transform on this object, so the file must hold the fitted encoder,
# not the encoded label array — confirm the file contents.
label_encoder = joblib.load('y_3_sec_encoded.pkl')


In [6]:
def extract_features_from_file(file_path, sr=None):
    """Compute the 57-value summary feature vector for one audio file.

    The scaler and models consume this vector positionally, so its order must
    match their training data. Layout:

    * indices 0-15: mean and variance of chroma_stft, rms, spectral_centroid,
      spectral_bandwidth, spectral_rolloff, zero_crossing_rate, and of the
      harmonic and percussive waveforms from HPSS;
    * index 16: estimated tempo;
    * indices 17-56: mean and variance of each of the 20 MFCCs, interleaved
      (mfcc1 mean, mfcc1 var, mfcc2 mean, mfcc2 var, ...).

    Args:
        file_path: Path of the audio file to load.
        sr: Sample rate to resample to on load. ``None`` (the default) keeps
            the file's native rate, matching the previous behaviour.
            NOTE(review): if the training features were computed at a fixed
            rate (librosa's default is 22050 Hz), pass that rate so
            rate-dependent features (e.g. rolloff, zero-crossing rate, MFCCs)
            are on the same scale — confirm against the training pipeline.

    Returns:
        ``np.ndarray`` of shape (57,).
    """
    y, sr = librosa.load(file_path, sr=sr)

    frame_features = (
        librosa.feature.chroma_stft(y=y, sr=sr),
        librosa.feature.rms(y=y),
        librosa.feature.spectral_centroid(y=y, sr=sr),
        librosa.feature.spectral_bandwidth(y=y, sr=sr),
        librosa.feature.spectral_rolloff(y=y, sr=sr),
        librosa.feature.zero_crossing_rate(y),
    )
    # A single HPSS pass yields both components. Calling effects.harmonic() and
    # effects.percussive() separately ran the same decomposition twice.
    harmony, perceptr = librosa.effects.hpss(y)

    # librosa >= 0.10 provides librosa.feature.tempo. librosa.beat.tempo is a
    # deprecated alias (removed in 1.0), kept only as a fallback for older releases.
    tempo_fn = getattr(librosa.feature, "tempo", None) or librosa.beat.tempo
    tempo = tempo_fn(y=y, sr=sr)[0]

    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)

    features = []
    for feat in (*frame_features, harmony, perceptr):
        features.extend((feat.mean(), feat.var()))
    features.append(tempo)
    # Interleave per-coefficient statistics: [mean_1, var_1, mean_2, var_2, ...].
    features.extend(np.column_stack((mfcc.mean(axis=1), mfcc.var(axis=1))).ravel())

    return np.array(features)

In [7]:
# Predict Genre Function
def predict_genre(file_path):
    """Predict the genre of one audio file with both trained models.

    The function extracts the feature vector with ``extract_features_from_file``
    and standardizes it with the loaded ``scaler``. It then classifies the
    vector with ``rf_model`` and with ``cnn_model``, and maps each integer
    prediction back to a label with ``label_encoder``.

    Args:
        file_path: Path of the audio file to classify.

    Returns:
        ``(rf_genre, cnn_genre)``: the labels predicted by the random forest
        and by the CNN, or ``(None, None)`` if no features were extracted.
    """
    features = extract_features_from_file(file_path)
    if features is None:
        return None, None

    # The scaler expects 2-D (n_samples, n_features) input; here that is a single row.
    features_scaled = scaler.transform([features])

    rf_prediction = rf_model.predict(features_scaled)
    rf_genre = label_encoder.inverse_transform(rf_prediction)[0]

    # Build the CNN input on the same device as the model's weights. unsqueeze(1)
    # adds the single input channel: (1, n_features) -> (1, 1, n_features).
    first_param = next(cnn_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    features_tensor = torch.tensor(features_scaled, dtype=torch.float32, device=device).unsqueeze(1)
    # Inference only: no_grad avoids building an autograd graph for the forward pass.
    with torch.no_grad():
        cnn_output = cnn_model(features_tensor)
    # Move to the CPU before .numpy(), which does not accept CUDA tensors.
    cnn_prediction = torch.argmax(cnn_output, dim=1).cpu().numpy()
    cnn_genre = label_encoder.inverse_transform(cnn_prediction)[0]

    return rf_genre, cnn_genre

In [11]:
# Example usage: classify one audio file with both models and print the results.
# Set the GENRE_AUDIO_FILE environment variable to classify a different file
# without editing this cell; the machine-specific path below is only the fallback.
file_path = os.environ.get(
    'GENRE_AUDIO_FILE',
    'C:/Users/jimon/Music/UNLIMITED LOVE RHCP-20220327T134231Z-001/UNLIMITED LOVE RHCP/03 Aquatic Mouth Dance.wav',
)
rf_genre, cnn_genre = predict_genre(file_path)
# predict_genre signals failure with None. Compare explicitly so that a falsy
# label (e.g. an integer class 0) is not misreported as a failure.
if rf_genre is not None and cnn_genre is not None:
    print(f'Random Forest Prediction: {rf_genre}')
    print(f'CNN Prediction: {cnn_genre}')
else:
    print("Failed to predict genre.")

	This function was moved to 'librosa.feature.rhythm.tempo' in librosa version 0.10.0.
	This alias will be removed in librosa version 1.0.
  tempo = librosa.beat.tempo(y=y, sr=sr)[0]


Random Forest Prediction: hiphop
CNN Prediction: pop


