In [None]:
import numpy as np
from torch import nn
import torchaudio
from torchaudio import transforms
from IPython.display import Audio
from torch.utils.data import DataLoader
import cv2
import random, math
from glob import glob
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [None]:
# Stratify on genre so each split keeps the class balance and contains every genre:
# AudioDataset, by default, derives its label indices from the genres present in
# the dataframe it is given, so a split missing a genre would shift those indices.
# Resulting proportions: 72% train, 8% validation, 20% test.
train_df, test_df = train_test_split(
    gtzan_df, test_size=0.2, random_state=42, stratify=gtzan_df['Genre'])
train_df, val_df = train_test_split(
    train_df, test_size=0.1, random_state=42, stratify=train_df['Genre'])

print(len(train_df))
print(len(test_df))
print(len(val_df))

In [None]:
class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(mel_spectrogram, class_id)`` pairs.

    Each row of ``dataframe`` must provide a ``'Path'`` (audio file location) and a
    ``'Genre'`` (class label). Audio is loaded and preprocessed lazily in
    ``__getitem__`` through the ``AudioUtil`` helpers defined elsewhere in the notebook.

    Args:
        dataframe: table with ``'Path'`` and ``'Genre'`` columns.
        transform: optional callable applied to the spectrogram before it is returned.
        class_names: optional explicit label vocabulary; ``class_id`` is the index of
            the genre in this list. Defaults to the sorted unique genres of
            ``dataframe``. Pass the same list to every split so label indices agree
            across train/val/test even if one split lacks a genre.

    Raises:
        ValueError: if ``dataframe`` contains a genre that is not in ``class_names``.
    """

    def __init__(self, dataframe, transform=None, class_names=None):
        self.dataframe = dataframe
        if class_names is None:
            class_names = sorted(dataframe['Genre'].unique())
        self.class_names = list(class_names)
        self.class_to_index = {name: i for i, name in enumerate(self.class_names)}

        unknown = set(dataframe['Genre']).difference(self.class_to_index)
        if unknown:
            raise ValueError(f"Genres not present in class_names: {sorted(unknown)}")

        # Zipping the two columns avoids building a Series per row as iterrows() does.
        self.file_list = [
            (path, self.class_to_index[genre])
            for path, genre in zip(dataframe['Path'], dataframe['Genre'])
        ]
        self.transform = transform

        # Preprocessing targets handed to AudioUtil below. Units follow AudioUtil's
        # conventions (presumably Hz, milliseconds, channel count) -- TODO confirm.
        self.sr = 44100
        self.duration = 5500
        self.channel = 2

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_file, class_id = self.file_list[idx]

        aud = AudioUtil.open(audio_file)
        aud = AudioUtil.resample(aud, self.sr)
        aud = AudioUtil.rechannel(aud, self.channel)
        aud = AudioUtil.pad_trunc(aud, self.duration)
        melspectrogram = AudioUtil.spectro_gram(aud)

        if self.transform is not None:
            melspectrogram = self.transform(melspectrogram)
        return melspectrogram, class_id


def create_data_loader(audio_folder, max_batch_size=16, shuffle=True, class_names=None):
    """Wrap a dataframe of audio files in an ``AudioDataset`` and a ``DataLoader``.

    Args:
        audio_folder: despite its name, a dataframe with ``'Path'`` and ``'Genre'``
            columns; it is passed straight to ``AudioDataset``.
        max_batch_size: ``batch_size`` of the returned loader.
        shuffle: whether the loader reshuffles samples every epoch.
        class_names: optional shared label vocabulary forwarded to ``AudioDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset.
    """
    audio_dataset = AudioDataset(audio_folder, class_names=class_names)
    return torch.utils.data.DataLoader(audio_dataset, batch_size=max_batch_size, shuffle=shuffle)

# Sorted so the order matches AudioDataset's default label encoding
# (sorted unique genres); unique() alone returns first-appearance order.
class_names = sorted(gtzan_df['Genre'].unique())
print(class_names)

train_loader = create_data_loader(train_df)
# Evaluation order does not matter, so val/test are not shuffled.
val_loader = create_data_loader(val_df, shuffle=False)
test_loader = create_data_loader(test_df, shuffle=False)

# Each dataset builds its own label vocabulary from its split; fail loudly if any
# split's vocabulary differs, since its integer labels would then be misaligned.
for split_name, loader in (("train", train_loader), ("val", val_loader), ("test", test_loader)):
    assert list(loader.dataset.class_names) == list(class_names), (
        f"{split_name} split has a different genre set; its label indices do not match class_names"
    )

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def show_images_from_loader(data_loader, num_images=16, class_names=None):
    """Plot randomly chosen spectrograms from ``data_loader.dataset`` in a 4-column grid.

    Samples are read directly from the dataset (the loader is not iterated), so its
    batch size and shuffling do not matter. Only ``spectrogram[0]`` (presumably the
    first channel) of each sample is passed to ``plot_spectrogram``.

    Args:
        data_loader: a DataLoader whose dataset returns ``(spectrogram, label)``.
        num_images: how many samples to show; capped at the dataset size.
        class_names: label vocabulary used for panel titles. Defaults to
            ``data_loader.dataset.class_names`` -- the list the dataset used to
            encode its integer labels -- so titles always match the labels.
    """
    dataset = data_loader.dataset
    if class_names is None:
        class_names = dataset.class_names

    # np.random.choice(..., replace=False) raises if asked for more items than exist.
    num_images = min(num_images, len(dataset))
    if num_images == 0:
        return
    indices = np.random.choice(len(dataset), num_images, replace=False)

    n_cols = 4
    n_rows = int(np.ceil(num_images / n_cols))
    # squeeze=False keeps `axes` 2-D even for a single row, so one flat view works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, idx in zip(flat_axes, indices):
        spectrogram, label = dataset[idx]
        plot_spectrogram(spectrogram[0], title=f"Label: {class_names[label]}", ax=ax)
    # Hide the unused panels of a partially filled last row.
    for ax in flat_axes[num_images:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


show_images_from_loader(train_loader)

In [None]:
# Use the current CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SpectrogramCNN_GRUNet(nn.Module):
    """CNN front-end followed by two stacked GRUs and an MLP classification head.

    Three conv(3x3) -> BatchNorm -> ReLU -> 2x2 max-pool -> dropout stages halve the
    height and width each time (floor division by 8 overall). The resulting feature
    map is read column by column along its last axis (presumably time) as a
    sequence for the GRUs. The last GRU step feeds ``fc1 -> ReLU -> dropout -> fc2``.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input tensor (default 2, matching the
            2-channel re-channelling in AudioDataset -- TODO confirm AudioUtil output).
        n_mels: input height (frequency bins). The GRU input size is
            ``128 * (n_mels // 8)``; the default 64 reproduces the original 128 * 8.
        hidden_size: hidden size of both GRU layers.

    Input shape: ``(batch, in_channels, n_mels, time)`` with ``time >= 8``.
    Output shape: ``(batch, num_classes)`` unnormalised logits.
    """

    def __init__(self, num_classes=10, in_channels=2, n_mels=64, hidden_size=68):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.drop = nn.Dropout(0.25)

        # Three 2x2 pools reduce the height to floor(n_mels / 8); each GRU step sees
        # all 128 channels of one column stacked together.
        self.gru_input_size = 128 * (n_mels // 8)
        self.gru1 = nn.GRU(input_size=self.gru_input_size, hidden_size=hidden_size, batch_first=True, num_layers=1)
        self.gru2 = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, num_layers=1)

        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.drop_fc = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.drop(x)

        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.drop(x)

        x = self.pool(torch.relu(self.bn3(self.conv3(x))))
        x = self.drop(x)

        b, c, h, w = x.size()
        # (b, c, h, w) -> (b, w, c*h): move the last axis next to batch *before*
        # flattening. A plain .view(b, w, c*h) would reinterpret the (c, h, w) memory
        # layout and mix channels, frequency bins and time steps into each GRU step.
        x = x.permute(0, 3, 1, 2).reshape(b, w, c * h)

        x, _ = self.gru1(x)
        x, _ = self.gru2(x)

        # Keep only the output of the final sequence step.
        x = x[:, -1, :]

        x = torch.relu(self.fc1(x))
        x = self.drop_fc(x)
        x = self.fc2(x)
        return x

# Seed torch's global RNG so weight initialisation, dropout masks and DataLoader
# shuffling are reproducible across Restart & Run All.
torch.manual_seed(42)

model = SpectrogramCNN_GRUNet(num_classes=len(class_names)).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Smoke-test the wiring with a dummy batch shaped (batch, 2 channels, 64 bins, time):
# conv1 takes 2 input channels and the GRU input size (128 * 8) implies 64 bins.
input_tensor = torch.randn(2, 2, 64, 128, device=device)
model.eval()
with torch.no_grad():
    assert model(input_tensor).shape == (2, len(class_names))
model.train()

In [None]:
num_epochs = 10
best_accuracy = 0.0


def evaluate_accuracy(model, data_loader, device):
    """Return the fraction of correctly classified samples in ``data_loader``.

    Switches ``model`` to eval mode (dropout off, BatchNorm running statistics) and
    runs without gradient tracking; it does not switch back to train mode. Returns
    0.0 for an empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.float().to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total if total else 0.0


for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    samples_seen = 0

    train_iterator = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for inputs, labels in train_iterator:
        inputs = inputs.float().to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion uses the default 'mean' reduction: re-weight by the actual
        # batch size so the running average stays per-sample even when the last batch
        # is smaller than train_loader.batch_size.
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
        train_iterator.set_postfix(loss=running_loss / samples_seen)

    epoch_loss = running_loss / max(samples_seen, 1)
    val_accuracy = evaluate_accuracy(model, val_loader, device)

    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}, Validation Accuracy: {val_accuracy:.2%}")

    # Checkpoint only on strict improvement in validation accuracy.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'DM_gtzan_best.pth')

print(f"Best validation accuracy: {best_accuracy:.2%}")