In [1]:
pip install torch librosa numpy

Collecting librosa
  Downloading librosa-0.10.2.post1-py3-none-any.whl.metadata (8.6 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting scipy>=1.2.0 (from librosa)
  Downloading scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl.metadata (60 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.8/60.8 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scikit-learn>=0.20.0 (from librosa)
  Downloading scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl.metadata (13 kB)
Collecting joblib>=0.14 (from librosa)
  Downloading joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (2.7 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.12.1-py2.py3-none-macosx_11_0_arm64.whl.metadata (14 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-no

In [1]:
import os
import librosa
import audioread
import scipy.signal
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
# 1. Convert audio to Mel spectrogram using librosa
def audio_to_melspectrogram(file_path, sr=22050, n_mels=64):
    try:
        # Use audioread to read the audio file
        with audioread.audio_open(file_path) as input_file:
            total_samples = input_file.samplerate * input_file.duration
            y = np.zeros(int(total_samples))

            i = 0
            for buf in input_file:
                buf_array = np.frombuffer(buf, dtype=np.int16) / 32768.0  # Convert buffer to float
                y[i:i+len(buf_array)] = buf_array
                i += len(buf_array)
        
        # If the sample rate is different from what we need, resample the audio
        if input_file.samplerate != sr:
            y = librosa.resample(y, input_file.samplerate, sr)

        # Generate Mel Spectrogram
        spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
        spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
        return spectrogram_db
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None  # In case of error, return None

In [3]:
# 2. Custom Dataset class to handle GTZAN data
class GenreDataset(Dataset):
    def __init__(self, file_paths, labels):
        self.file_paths = file_paths
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        spectrogram = audio_to_melspectrogram(self.file_paths[idx])
        if spectrogram is None:
            # Return None if the file is corrupted
            return None
        
        spectrogram = torch.tensor(spectrogram, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        spectrogram = F.interpolate(spectrogram.unsqueeze(0), size=(64, 64)).squeeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return spectrogram, label

In [4]:
# 3. CNN Model for Genre Classification
class GenreClassificationCNN(nn.Module):
    def __init__(self, num_classes=10):  # GTZAN has 10 genres
        super(GenreClassificationCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)  # Can adjust based on input size
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 8 * 8)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [5]:
def collate_fn(batch):
    # Filter out corrupted files (None values)
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None
    return torch.utils.data.dataloader.default_collate(batch)

In [6]:
# 4. Prepare GTZAN Dataset - Mapping genres to labels and loading the data
def create_dataloader(audio_dir, batch_size=32):
    genre_to_label = {
        "blues": 0,
        "classical": 1,
        "country": 2,
        "disco": 3,
        "hiphop": 4,
        "jazz": 5,
        "metal": 6,
        "pop": 7,
        "reggae": 8,
        "rock": 9
    }
    file_paths, labels = [], []
    
    # Traverse the GTZAN directories and gather file paths and labels
    for genre in genre_to_label:
        genre_folder = os.path.join(audio_dir, genre)
        for file_name in os.listdir(genre_folder):
            if file_name.endswith(".wav"):  # Only consider wav files
                file_paths.append(os.path.join(genre_folder, file_name))
                labels.append(genre_to_label[genre])

    # Create Dataset and DataLoader
    dataset = GenreDataset(file_paths, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)

In [7]:
# 5. Training loop for the model
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # Skip empty batches
            if data is None:
                continue
            
            inputs, labels = data

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch [{epoch + 1}], Batch [{i + 1}], Loss: {running_loss / 100:.4f}')
                running_loss = 0.0

    print('Finished Training')

In [8]:
# Main function
if __name__ == "__main__":
    # 6. Path to the GTZAN dataset folder
    audio_dir = "/Users/duncanroepke/Downloads/Data/genres_original"  # This point to the folder where "blues", "classical", etc. are located

    # 7. Create DataLoader from GTZAN dataset
    train_loader = create_dataloader(audio_dir, batch_size=32)

    # 8. Initialize model, loss function, and optimizer
    num_classes = 10  # GTZAN has 10 genres
    model = GenreClassificationCNN(num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 9. Train the model
    train_model(model, train_loader, criterion, optimizer, num_epochs=10)

    # 10. Save the trained model
    torch.save(model.state_dict(), "genre_classification_cnn.pth")

Finished Training
