<a href="https://colab.research.google.com/github/gbdionne/toneclone/blob/main/spectrogramCNN_alt7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!mkdir /content/final_datasets

!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/final_real.h5" "/content/final_datasets/final_real.h5"
!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/final_real.csv" "/content/final_datasets/final_real.csv"

!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/test_extra_TRM_DLY.h5" "/content/final_datasets/test_extra_TRM_DLY.h5"
!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/test_extra_TRM_DLY.csv" "/content/final_datasets/test_extra_TRM_DLY.csv"

!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/validate_extra_TRM_DLY.h5" "/content/final_datasets/validate_extra_TRM_DLY.h5"
!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/validate_extra_TRM_DLY.csv" "/content/final_datasets/validate_extra_TRM_DLY.csv"

!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/train_extra_TRM_DLY.h5" "/content/final_datasets/train_extra_TRM_DLY.h5"
!cp "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/train_extra_TRM_DLY.csv" "/content/final_datasets/train_extra_TRM_DLY.csv"

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import h5py
import pandas as pd
from sklearn.metrics import classification_report
import torchaudio.transforms as T
import random
import warnings

warnings.filterwarnings("ignore", message="Downcasting object dtype arrays on .fillna")

class SpectrogramDataset(Dataset):
    """
    Custom dataset for spectrogram data with data augmentation.
    Includes:
    - Random Gaussian noise
    - Pitch shifting using torch.roll() with zero-padding (prevents wrapping)
    """

    def __init__(self, hdf5_file, csv_file, augment=True, noise_level=0.03, pitch_shift_range=(-0.5, 0.5)):
        """
        Args:
            hdf5_file (str): Path to the HDF5 file containing spectrograms.
            csv_file (str): Path to CSV file with labels.
            augment (bool): Whether to apply data augmentation.
            noise_level (float): Standard deviation of Gaussian noise to add.
            pitch_shift_range (tuple): Min/max semitones for pitch shifting.
        """
        self.hdf5_file_path = hdf5_file
        self.labels = pd.read_csv(csv_file)

        # Manually define only important columns
        self.label_map = [
            'overdrive', 'distortion', 'fuzz', 'tremolo', 'phaser',
            'flanger', 'chorus', 'delay', 'hall_reverb', 'plate_reverb',
            'octaver', 'auto_filter'
        ]

        # Drop all non-label columns
        self.labels = self.labels[['key'] + self.label_map]

        self.hdf5_file = None  # Open HDF5 file once per worker

        self.augment = augment
        self.noise_level = noise_level
        self.pitch_shift_range = pitch_shift_range

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Open HDF5 file per worker to avoid threading issues
        if self.hdf5_file is None:
            self.hdf5_file = h5py.File(self.hdf5_file_path, "r", swmr=True)

        # Retrieve spectrogram
        key = self.labels.iloc[idx]['key']
        spectrogram = torch.tensor(self.hdf5_file[key][()], dtype=torch.float32).unsqueeze(0)

        # Retrieve labels
        label_values = self.labels.iloc[idx][1:].fillna(0).astype(float).values
        label = torch.tensor(label_values, dtype=torch.float32)

        # Data augmentation
        if self.augment:
            spectrogram = self.add_noise(spectrogram)
            spectrogram = self.pitch_shift(spectrogram)

        return spectrogram, label

    def add_noise(self, spectrogram):
        """Adds Gaussian noise where noise level is randomly chosen between 0 and self.noise_level."""
        noise_level = random.uniform(0, self.noise_level)  # Random noise per sample
        noise = torch.randn_like(spectrogram) * noise_level  # Scale noise
        return spectrogram + noise

    def pitch_shift(self, spectrogram):
        """Shifts spectrogram frequency bins using torch.roll() with zero padding."""
        semitone_shift = random.uniform(*self.pitch_shift_range)  # Random shift between min/max
        shift_bins = int(semitone_shift / 12 * spectrogram.shape[-2])  # Convert semitone shift to frequency bins

        # Apply frequency bin shift using torch.roll() with zero-padding
        shifted = torch.roll(spectrogram, shifts=shift_bins, dims=-2)  # Shift along frequency axis

        if shift_bins > 0:  # Shift up (higher pitch)
            shifted[..., :shift_bins, :] = 0  # Zero-pad low frequencies
        elif shift_bins < 0:  # Shift down (lower pitch)
            shifted[..., shift_bins:, :] = 0  # Zero-pad high frequencies

        return shifted

    def __del__(self):
        if self.hdf5_file is not None:
            self.hdf5_file.close()

In [4]:
class spectrogramCNN(nn.Module):
    def __init__(self, num_classes):
        super(spectrogramCNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(512)

        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected layers
        self.fc1 = nn.Linear(512, 256)
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)  # Max pooling

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.bn5(self.conv5(x)))
        x = F.max_pool2d(x, 2)

        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x) # Dropout
        x = self.fc2(x)
        return x

In [5]:
# Initialize dataset from HD5F and csv file

h5_train_path = '/content/final_datasets/train_extra_TRM_DLY.h5'
csv_train_path = '/content/final_datasets/train_extra_TRM_DLY.csv'

h5_val_path = '/content/final_datasets/validate_extra_TRM_DLY.h5'
csv_val_path = '/content/final_datasets/validate_extra_TRM_DLY.csv'

model_save_path = "/content/drive/MyDrive/Capstone 210/Models/final_multi_effects_alt7.mod"

train_dataset = SpectrogramDataset(h5_train_path, csv_train_path)
val_dataset = SpectrogramDataset(h5_val_path, csv_val_path)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=12, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=6, pin_memory=True)

num_classes = len(train_dataset.label_map)

model = spectrogramCNN(num_classes).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8577)  # 0.0001 → 0.00001 over 15 epochs
#optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9, weight_decay=1e-4)

# Training loop
num_epochs = 15
print_freq = 50
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (spectrograms, labels) in enumerate(train_loader):
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(spectrograms)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (batch_idx + 1) % print_freq == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Update learning rate
    scheduler.step()
    print(f"Updated Learning Rate: {scheduler.get_last_lr()}")

    # Validation step
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for spectrograms, labels in val_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Compute accuracy
            predicted = (torch.sigmoid(outputs) > 0.5).float()  # Convert logits to binary predictions

            # Store for metric computation
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_loader)

    # Convert lists to numpy arrays for metric calculations
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="macro", zero_division=0)
    recall = recall_score(all_labels, all_preds, average="macro", zero_division=0)
    f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)

    # Print classification report
    class_names = train_dataset.label_map
    print(classification_report(all_labels, all_preds, target_names=class_names))

    print(f"\nValidation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}\n")

    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

Epoch [1/15], Batch [50/4502], Loss: 0.3647
Epoch [1/15], Batch [100/4502], Loss: 0.2850
Epoch [1/15], Batch [150/4502], Loss: 0.2740
Epoch [1/15], Batch [200/4502], Loss: 0.2295
Epoch [1/15], Batch [250/4502], Loss: 0.1917
Epoch [1/15], Batch [300/4502], Loss: 0.1696
Epoch [1/15], Batch [350/4502], Loss: 0.2025
Epoch [1/15], Batch [400/4502], Loss: 0.1460
Epoch [1/15], Batch [450/4502], Loss: 0.1250
Epoch [1/15], Batch [500/4502], Loss: 0.1440
Epoch [1/15], Batch [550/4502], Loss: 0.1223
Epoch [1/15], Batch [600/4502], Loss: 0.1519
Epoch [1/15], Batch [650/4502], Loss: 0.1155
Epoch [1/15], Batch [700/4502], Loss: 0.1433
Epoch [1/15], Batch [750/4502], Loss: 0.0894
Epoch [1/15], Batch [800/4502], Loss: 0.1006
Epoch [1/15], Batch [850/4502], Loss: 0.1094
Epoch [1/15], Batch [900/4502], Loss: 0.0981
Epoch [1/15], Batch [950/4502], Loss: 0.1130
Epoch [1/15], Batch [1000/4502], Loss: 0.0726
Epoch [1/15], Batch [1050/4502], Loss: 0.0951
Epoch [1/15], Batch [1100/4502], Loss: 0.0603
Epoch [1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [2/15], Batch [50/4502], Loss: 0.0258
Epoch [2/15], Batch [100/4502], Loss: 0.0150
Epoch [2/15], Batch [150/4502], Loss: 0.0261
Epoch [2/15], Batch [200/4502], Loss: 0.0185
Epoch [2/15], Batch [250/4502], Loss: 0.0517
Epoch [2/15], Batch [300/4502], Loss: 0.0313
Epoch [2/15], Batch [350/4502], Loss: 0.0118
Epoch [2/15], Batch [400/4502], Loss: 0.0234
Epoch [2/15], Batch [450/4502], Loss: 0.0243
Epoch [2/15], Batch [500/4502], Loss: 0.0263
Epoch [2/15], Batch [550/4502], Loss: 0.0380
Epoch [2/15], Batch [600/4502], Loss: 0.0244
Epoch [2/15], Batch [650/4502], Loss: 0.0363
Epoch [2/15], Batch [700/4502], Loss: 0.0211
Epoch [2/15], Batch [750/4502], Loss: 0.0250
Epoch [2/15], Batch [800/4502], Loss: 0.0141
Epoch [2/15], Batch [850/4502], Loss: 0.0080
Epoch [2/15], Batch [900/4502], Loss: 0.0415
Epoch [2/15], Batch [950/4502], Loss: 0.0143
Epoch [2/15], Batch [1000/4502], Loss: 0.0115
Epoch [2/15], Batch [1050/4502], Loss: 0.0173
Epoch [2/15], Batch [1100/4502], Loss: 0.0280
Epoch [2

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [3/15], Batch [50/4502], Loss: 0.0146
Epoch [3/15], Batch [100/4502], Loss: 0.0162
Epoch [3/15], Batch [150/4502], Loss: 0.0249
Epoch [3/15], Batch [200/4502], Loss: 0.0146
Epoch [3/15], Batch [250/4502], Loss: 0.0188
Epoch [3/15], Batch [300/4502], Loss: 0.0129
Epoch [3/15], Batch [350/4502], Loss: 0.0308
Epoch [3/15], Batch [400/4502], Loss: 0.0231
Epoch [3/15], Batch [450/4502], Loss: 0.0168
Epoch [3/15], Batch [500/4502], Loss: 0.0066
Epoch [3/15], Batch [550/4502], Loss: 0.0075
Epoch [3/15], Batch [600/4502], Loss: 0.0194
Epoch [3/15], Batch [650/4502], Loss: 0.0410
Epoch [3/15], Batch [700/4502], Loss: 0.0048
Epoch [3/15], Batch [750/4502], Loss: 0.0114
Epoch [3/15], Batch [800/4502], Loss: 0.0167
Epoch [3/15], Batch [850/4502], Loss: 0.0096
Epoch [3/15], Batch [900/4502], Loss: 0.0051
Epoch [3/15], Batch [950/4502], Loss: 0.0039
Epoch [3/15], Batch [1000/4502], Loss: 0.0279
Epoch [3/15], Batch [1050/4502], Loss: 0.0098
Epoch [3/15], Batch [1100/4502], Loss: 0.0213
Epoch [3

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [4/15], Batch [50/4502], Loss: 0.0121
Epoch [4/15], Batch [100/4502], Loss: 0.0099
Epoch [4/15], Batch [150/4502], Loss: 0.0100
Epoch [4/15], Batch [200/4502], Loss: 0.0188
Epoch [4/15], Batch [250/4502], Loss: 0.0040
Epoch [4/15], Batch [300/4502], Loss: 0.0086
Epoch [4/15], Batch [350/4502], Loss: 0.0020
Epoch [4/15], Batch [400/4502], Loss: 0.0202
Epoch [4/15], Batch [450/4502], Loss: 0.0048
Epoch [4/15], Batch [500/4502], Loss: 0.0086
Epoch [4/15], Batch [550/4502], Loss: 0.0061
Epoch [4/15], Batch [600/4502], Loss: 0.0235
Epoch [4/15], Batch [650/4502], Loss: 0.0167
Epoch [4/15], Batch [700/4502], Loss: 0.0093
Epoch [4/15], Batch [750/4502], Loss: 0.0029
Epoch [4/15], Batch [800/4502], Loss: 0.0037
Epoch [4/15], Batch [850/4502], Loss: 0.0049
Epoch [4/15], Batch [900/4502], Loss: 0.0036
Epoch [4/15], Batch [950/4502], Loss: 0.0141
Epoch [4/15], Batch [1000/4502], Loss: 0.0272
Epoch [4/15], Batch [1050/4502], Loss: 0.0118
Epoch [4/15], Batch [1100/4502], Loss: 0.0048
Epoch [4

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [5/15], Batch [50/4502], Loss: 0.0085
Epoch [5/15], Batch [100/4502], Loss: 0.0131
Epoch [5/15], Batch [150/4502], Loss: 0.0081
Epoch [5/15], Batch [200/4502], Loss: 0.0064
Epoch [5/15], Batch [250/4502], Loss: 0.0025
Epoch [5/15], Batch [300/4502], Loss: 0.0080
Epoch [5/15], Batch [350/4502], Loss: 0.0063
Epoch [5/15], Batch [400/4502], Loss: 0.0023
Epoch [5/15], Batch [450/4502], Loss: 0.0226
Epoch [5/15], Batch [500/4502], Loss: 0.0028
Epoch [5/15], Batch [550/4502], Loss: 0.0097
Epoch [5/15], Batch [600/4502], Loss: 0.0141
Epoch [5/15], Batch [650/4502], Loss: 0.0069
Epoch [5/15], Batch [700/4502], Loss: 0.0392
Epoch [5/15], Batch [750/4502], Loss: 0.0101
Epoch [5/15], Batch [800/4502], Loss: 0.0075
Epoch [5/15], Batch [850/4502], Loss: 0.0026
Epoch [5/15], Batch [900/4502], Loss: 0.0020
Epoch [5/15], Batch [950/4502], Loss: 0.0024
Epoch [5/15], Batch [1000/4502], Loss: 0.0093
Epoch [5/15], Batch [1050/4502], Loss: 0.0033
Epoch [5/15], Batch [1100/4502], Loss: 0.0051
Epoch [5

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [6/15], Batch [50/4502], Loss: 0.0148
Epoch [6/15], Batch [100/4502], Loss: 0.0025
Epoch [6/15], Batch [150/4502], Loss: 0.0004
Epoch [6/15], Batch [200/4502], Loss: 0.0064
Epoch [6/15], Batch [250/4502], Loss: 0.0019
Epoch [6/15], Batch [300/4502], Loss: 0.0013
Epoch [6/15], Batch [350/4502], Loss: 0.0057
Epoch [6/15], Batch [400/4502], Loss: 0.0116
Epoch [6/15], Batch [450/4502], Loss: 0.0034
Epoch [6/15], Batch [500/4502], Loss: 0.0011
Epoch [6/15], Batch [550/4502], Loss: 0.0008
Epoch [6/15], Batch [600/4502], Loss: 0.0018
Epoch [6/15], Batch [650/4502], Loss: 0.0035
Epoch [6/15], Batch [700/4502], Loss: 0.0015
Epoch [6/15], Batch [750/4502], Loss: 0.0142
Epoch [6/15], Batch [800/4502], Loss: 0.0248
Epoch [6/15], Batch [850/4502], Loss: 0.0142
Epoch [6/15], Batch [900/4502], Loss: 0.0010
Epoch [6/15], Batch [950/4502], Loss: 0.0056
Epoch [6/15], Batch [1000/4502], Loss: 0.0023
Epoch [6/15], Batch [1050/4502], Loss: 0.0196
Epoch [6/15], Batch [1100/4502], Loss: 0.0043
Epoch [6

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [7/15], Batch [50/4502], Loss: 0.0052
Epoch [7/15], Batch [100/4502], Loss: 0.0039
Epoch [7/15], Batch [150/4502], Loss: 0.0012
Epoch [7/15], Batch [200/4502], Loss: 0.0085
Epoch [7/15], Batch [250/4502], Loss: 0.0151
Epoch [7/15], Batch [300/4502], Loss: 0.0035
Epoch [7/15], Batch [350/4502], Loss: 0.0069
Epoch [7/15], Batch [400/4502], Loss: 0.0024
Epoch [7/15], Batch [450/4502], Loss: 0.0017
Epoch [7/15], Batch [500/4502], Loss: 0.0095
Epoch [7/15], Batch [550/4502], Loss: 0.0001
Epoch [7/15], Batch [600/4502], Loss: 0.0178
Epoch [7/15], Batch [650/4502], Loss: 0.0003
Epoch [7/15], Batch [700/4502], Loss: 0.0040
Epoch [7/15], Batch [750/4502], Loss: 0.0003
Epoch [7/15], Batch [800/4502], Loss: 0.0023
Epoch [7/15], Batch [850/4502], Loss: 0.0023
Epoch [7/15], Batch [900/4502], Loss: 0.0043
Epoch [7/15], Batch [950/4502], Loss: 0.0055
Epoch [7/15], Batch [1000/4502], Loss: 0.0094
Epoch [7/15], Batch [1050/4502], Loss: 0.0021
Epoch [7/15], Batch [1100/4502], Loss: 0.0133
Epoch [7

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [8/15], Batch [50/4502], Loss: 0.0007
Epoch [8/15], Batch [100/4502], Loss: 0.0044
Epoch [8/15], Batch [150/4502], Loss: 0.0022
Epoch [8/15], Batch [200/4502], Loss: 0.0131
Epoch [8/15], Batch [250/4502], Loss: 0.0093
Epoch [8/15], Batch [300/4502], Loss: 0.0024
Epoch [8/15], Batch [350/4502], Loss: 0.0103
Epoch [8/15], Batch [400/4502], Loss: 0.0020
Epoch [8/15], Batch [450/4502], Loss: 0.0040
Epoch [8/15], Batch [500/4502], Loss: 0.0049
Epoch [8/15], Batch [550/4502], Loss: 0.0002
Epoch [8/15], Batch [600/4502], Loss: 0.0045
Epoch [8/15], Batch [650/4502], Loss: 0.0094
Epoch [8/15], Batch [700/4502], Loss: 0.0004
Epoch [8/15], Batch [750/4502], Loss: 0.0017
Epoch [8/15], Batch [800/4502], Loss: 0.0002
Epoch [8/15], Batch [850/4502], Loss: 0.0119
Epoch [8/15], Batch [900/4502], Loss: 0.0060
Epoch [8/15], Batch [950/4502], Loss: 0.0056
Epoch [8/15], Batch [1000/4502], Loss: 0.0050
Epoch [8/15], Batch [1050/4502], Loss: 0.0114
Epoch [8/15], Batch [1100/4502], Loss: 0.0020
Epoch [8

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [9/15], Batch [50/4502], Loss: 0.0013
Epoch [9/15], Batch [100/4502], Loss: 0.0053
Epoch [9/15], Batch [150/4502], Loss: 0.0027
Epoch [9/15], Batch [200/4502], Loss: 0.0172
Epoch [9/15], Batch [250/4502], Loss: 0.0188
Epoch [9/15], Batch [300/4502], Loss: 0.0026
Epoch [9/15], Batch [350/4502], Loss: 0.0011
Epoch [9/15], Batch [400/4502], Loss: 0.0007
Epoch [9/15], Batch [450/4502], Loss: 0.0016
Epoch [9/15], Batch [500/4502], Loss: 0.0097
Epoch [9/15], Batch [550/4502], Loss: 0.0049
Epoch [9/15], Batch [600/4502], Loss: 0.0037
Epoch [9/15], Batch [650/4502], Loss: 0.0017
Epoch [9/15], Batch [700/4502], Loss: 0.0050
Epoch [9/15], Batch [750/4502], Loss: 0.0024
Epoch [9/15], Batch [800/4502], Loss: 0.0007
Epoch [9/15], Batch [850/4502], Loss: 0.0071
Epoch [9/15], Batch [900/4502], Loss: 0.0125
Epoch [9/15], Batch [950/4502], Loss: 0.0026
Epoch [9/15], Batch [1000/4502], Loss: 0.0006
Epoch [9/15], Batch [1050/4502], Loss: 0.0014
Epoch [9/15], Batch [1100/4502], Loss: 0.0021
Epoch [9

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [10/15], Batch [50/4502], Loss: 0.0090
Epoch [10/15], Batch [100/4502], Loss: 0.0007
Epoch [10/15], Batch [150/4502], Loss: 0.0094
Epoch [10/15], Batch [200/4502], Loss: 0.0007
Epoch [10/15], Batch [250/4502], Loss: 0.0040
Epoch [10/15], Batch [300/4502], Loss: 0.0015
Epoch [10/15], Batch [350/4502], Loss: 0.0006
Epoch [10/15], Batch [400/4502], Loss: 0.0019
Epoch [10/15], Batch [450/4502], Loss: 0.0003
Epoch [10/15], Batch [500/4502], Loss: 0.0053
Epoch [10/15], Batch [550/4502], Loss: 0.0018
Epoch [10/15], Batch [600/4502], Loss: 0.0006
Epoch [10/15], Batch [650/4502], Loss: 0.0015
Epoch [10/15], Batch [700/4502], Loss: 0.0017
Epoch [10/15], Batch [750/4502], Loss: 0.0001
Epoch [10/15], Batch [800/4502], Loss: 0.0038
Epoch [10/15], Batch [850/4502], Loss: 0.0013
Epoch [10/15], Batch [900/4502], Loss: 0.0012
Epoch [10/15], Batch [950/4502], Loss: 0.0004
Epoch [10/15], Batch [1000/4502], Loss: 0.0194
Epoch [10/15], Batch [1050/4502], Loss: 0.0010
Epoch [10/15], Batch [1100/4502],

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [11/15], Batch [50/4502], Loss: 0.0131
Epoch [11/15], Batch [100/4502], Loss: 0.0030
Epoch [11/15], Batch [150/4502], Loss: 0.0008
Epoch [11/15], Batch [200/4502], Loss: 0.0006
Epoch [11/15], Batch [250/4502], Loss: 0.0001
Epoch [11/15], Batch [300/4502], Loss: 0.0114
Epoch [11/15], Batch [350/4502], Loss: 0.0046
Epoch [11/15], Batch [400/4502], Loss: 0.0095
Epoch [11/15], Batch [450/4502], Loss: 0.0010
Epoch [11/15], Batch [500/4502], Loss: 0.0006
Epoch [11/15], Batch [550/4502], Loss: 0.0012
Epoch [11/15], Batch [600/4502], Loss: 0.0135
Epoch [11/15], Batch [650/4502], Loss: 0.0011
Epoch [11/15], Batch [700/4502], Loss: 0.0002
Epoch [11/15], Batch [750/4502], Loss: 0.0020
Epoch [11/15], Batch [800/4502], Loss: 0.0054
Epoch [11/15], Batch [850/4502], Loss: 0.0050
Epoch [11/15], Batch [900/4502], Loss: 0.0006
Epoch [11/15], Batch [950/4502], Loss: 0.0036
Epoch [11/15], Batch [1000/4502], Loss: 0.0005
Epoch [11/15], Batch [1050/4502], Loss: 0.0007
Epoch [11/15], Batch [1100/4502],

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [12/15], Batch [50/4502], Loss: 0.0049
Epoch [12/15], Batch [100/4502], Loss: 0.0051
Epoch [12/15], Batch [150/4502], Loss: 0.0010
Epoch [12/15], Batch [200/4502], Loss: 0.0008
Epoch [12/15], Batch [250/4502], Loss: 0.0006
Epoch [12/15], Batch [300/4502], Loss: 0.0047
Epoch [12/15], Batch [350/4502], Loss: 0.0027
Epoch [12/15], Batch [400/4502], Loss: 0.0146
Epoch [12/15], Batch [450/4502], Loss: 0.0011
Epoch [12/15], Batch [500/4502], Loss: 0.0006
Epoch [12/15], Batch [550/4502], Loss: 0.0012
Epoch [12/15], Batch [600/4502], Loss: 0.0017
Epoch [12/15], Batch [650/4502], Loss: 0.0110
Epoch [12/15], Batch [700/4502], Loss: 0.0006
Epoch [12/15], Batch [750/4502], Loss: 0.0015
Epoch [12/15], Batch [800/4502], Loss: 0.0045
Epoch [12/15], Batch [850/4502], Loss: 0.0007
Epoch [12/15], Batch [900/4502], Loss: 0.0006
Epoch [12/15], Batch [950/4502], Loss: 0.0039
Epoch [12/15], Batch [1000/4502], Loss: 0.0029
Epoch [12/15], Batch [1050/4502], Loss: 0.0014
Epoch [12/15], Batch [1100/4502],

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [13/15], Batch [50/4502], Loss: 0.0155
Epoch [13/15], Batch [100/4502], Loss: 0.0020
Epoch [13/15], Batch [150/4502], Loss: 0.0052
Epoch [13/15], Batch [200/4502], Loss: 0.0036
Epoch [13/15], Batch [250/4502], Loss: 0.0031
Epoch [13/15], Batch [300/4502], Loss: 0.0055
Epoch [13/15], Batch [350/4502], Loss: 0.0055
Epoch [13/15], Batch [400/4502], Loss: 0.0015
Epoch [13/15], Batch [450/4502], Loss: 0.0001
Epoch [13/15], Batch [500/4502], Loss: 0.0035
Epoch [13/15], Batch [550/4502], Loss: 0.0012
Epoch [13/15], Batch [600/4502], Loss: 0.0048
Epoch [13/15], Batch [650/4502], Loss: 0.0006
Epoch [13/15], Batch [700/4502], Loss: 0.0073
Epoch [13/15], Batch [750/4502], Loss: 0.0121
Epoch [13/15], Batch [800/4502], Loss: 0.0095
Epoch [13/15], Batch [850/4502], Loss: 0.0002
Epoch [13/15], Batch [900/4502], Loss: 0.0002
Epoch [13/15], Batch [950/4502], Loss: 0.0001
Epoch [13/15], Batch [1000/4502], Loss: 0.0002
Epoch [13/15], Batch [1050/4502], Loss: 0.0081
Epoch [13/15], Batch [1100/4502],

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [14/15], Batch [50/4502], Loss: 0.0039
Epoch [14/15], Batch [100/4502], Loss: 0.0029
Epoch [14/15], Batch [150/4502], Loss: 0.0038
Epoch [14/15], Batch [200/4502], Loss: 0.0154
Epoch [14/15], Batch [250/4502], Loss: 0.0003
Epoch [14/15], Batch [300/4502], Loss: 0.0004
Epoch [14/15], Batch [350/4502], Loss: 0.0010
Epoch [14/15], Batch [400/4502], Loss: 0.0154
Epoch [14/15], Batch [450/4502], Loss: 0.0014
Epoch [14/15], Batch [500/4502], Loss: 0.0101
Epoch [14/15], Batch [550/4502], Loss: 0.0005
Epoch [14/15], Batch [600/4502], Loss: 0.0028
Epoch [14/15], Batch [650/4502], Loss: 0.0004
Epoch [14/15], Batch [700/4502], Loss: 0.0108
Epoch [14/15], Batch [750/4502], Loss: 0.0015
Epoch [14/15], Batch [800/4502], Loss: 0.0221
Epoch [14/15], Batch [850/4502], Loss: 0.0126
Epoch [14/15], Batch [900/4502], Loss: 0.0001
Epoch [14/15], Batch [950/4502], Loss: 0.0008
Epoch [14/15], Batch [1000/4502], Loss: 0.0043
Epoch [14/15], Batch [1050/4502], Loss: 0.0033
Epoch [14/15], Batch [1100/4502],

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch [15/15], Batch [50/4502], Loss: 0.0028
Epoch [15/15], Batch [100/4502], Loss: 0.0002
Epoch [15/15], Batch [150/4502], Loss: 0.0010
Epoch [15/15], Batch [200/4502], Loss: 0.0067
Epoch [15/15], Batch [250/4502], Loss: 0.0029
Epoch [15/15], Batch [300/4502], Loss: 0.0050
Epoch [15/15], Batch [350/4502], Loss: 0.0007
Epoch [15/15], Batch [400/4502], Loss: 0.0008
Epoch [15/15], Batch [450/4502], Loss: 0.0015
Epoch [15/15], Batch [500/4502], Loss: 0.0001
Epoch [15/15], Batch [550/4502], Loss: 0.0042
Epoch [15/15], Batch [600/4502], Loss: 0.0101
Epoch [15/15], Batch [650/4502], Loss: 0.0036
Epoch [15/15], Batch [700/4502], Loss: 0.0006
Epoch [15/15], Batch [750/4502], Loss: 0.0002
Epoch [15/15], Batch [800/4502], Loss: 0.0005
Epoch [15/15], Batch [850/4502], Loss: 0.0295
Epoch [15/15], Batch [900/4502], Loss: 0.0002
Epoch [15/15], Batch [950/4502], Loss: 0.0301
Epoch [15/15], Batch [1000/4502], Loss: 0.0024
Epoch [15/15], Batch [1050/4502], Loss: 0.0006
Epoch [15/15], Batch [1100/4502],

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load test dataset
h5_test_path = "/content/final_datasets/test_extra_TRM_DLY.h5"
csv_test_path = "/content/final_datasets/test_extra_TRM_DLY.csv"

model_load_path = "/content/drive/MyDrive/Capstone 210/Models/final_multi_effects_alt7.mod"

test_dataset = SpectrogramDataset(h5_test_path, csv_test_path)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=12, pin_memory=True)

num_classes = len(test_dataset.label_map)

# Load a saved model for test dataset metrics
model = spectrogramCNN(num_classes).to(device)
model.load_state_dict(torch.load(model_load_path, map_location=device))
model.eval()
print("Model loaded successfully.")

print("\nEvaluating with external test dataset...")

model.eval()
criterion = nn.BCEWithLogitsLoss()
test_loss = 0.0
test_preds, test_labels = [], []

with torch.no_grad():
    for spectrograms, labels in test_loader:
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        outputs = model(spectrograms)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Convert logits to binary predictions
        predicted = (torch.sigmoid(outputs) > 0.5).float()

        test_preds.extend(predicted.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

test_loss /= len(test_loader)

# Compute test metrics
test_preds = np.array(test_preds)
test_labels = np.array(test_labels)
test_accuracy = accuracy_score(test_labels, test_preds)
test_precision = precision_score(test_labels, test_preds, average="macro", zero_division=0)
test_recall = recall_score(test_labels, test_preds, average="macro", zero_division=0)
test_f1 = f1_score(test_labels, test_preds, average="macro", zero_division=0)

print(f"\nTest Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1-score: {test_f1:.4f}\n")

# Print classification report
class_names = test_dataset.label_map
print(classification_report(test_labels, test_preds, target_names=class_names))

  model.load_state_dict(torch.load(model_load_path, map_location=device))


Model loaded successfully.

Evaluating with external test dataset...

Test Loss: 0.0050, Accuracy: 0.9835, Precision: 0.9936, Recall: 0.9932, F1-score: 0.9934

              precision    recall  f1-score   support

   overdrive       1.00      0.99      1.00      3028
  distortion       1.00      0.99      1.00      4544
        fuzz       1.00      1.00      1.00      5300
     tremolo       1.00      1.00      1.00      4542
      phaser       1.00      1.00      1.00      4542
     flanger       1.00      0.99      0.99      3028
      chorus       1.00      1.00      1.00      5300
       delay       0.99      1.00      0.99      8328
 hall_reverb       0.97      0.98      0.98      3788
plate_reverb       0.98      0.99      0.99      3028
     octaver       0.99      0.99      0.99      2271
 auto_filter       1.00      0.99      1.00      3785

   micro avg       0.99      0.99      0.99     51484
   macro avg       0.99      0.99      0.99     51484
weighted avg       0.99     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load test dataset
# h5_test_path = "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/final_real.h5"
# csv_test_path = "/content/drive/MyDrive/Capstone 210/Data/Final Datasets/final_real.csv"

h5_test_path = "/content/final_datasets/final_real.h5"
csv_test_path = "/content/final_datasets/final_real.csv"

model_load_path = "/content/drive/MyDrive/Capstone 210/Models/final_multi_effects_alt7.mod"

test_dataset = SpectrogramDataset(h5_test_path, csv_test_path)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=12, pin_memory=True)

num_classes = len(test_dataset.label_map)

# Load a saved model for test dataset metrics
model = spectrogramCNN(num_classes).to(device)
model.load_state_dict(torch.load(model_load_path, map_location=device))
model.eval()
print("Model loaded successfully.")

print("\nEvaluating with external test dataset...")

model.eval()
criterion = nn.BCEWithLogitsLoss()
test_loss = 0.0
test_preds, test_labels = [], []

with torch.no_grad():
    for spectrograms, labels in test_loader:
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        outputs = model(spectrograms)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Convert logits to binary predictions
        predicted = (torch.sigmoid(outputs) > 0.5).float()

        test_preds.extend(predicted.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

test_loss /= len(test_loader)

# Compute test metrics
test_preds = np.array(test_preds)
test_labels = np.array(test_labels)
test_accuracy = accuracy_score(test_labels, test_preds)
test_precision = precision_score(test_labels, test_preds, average="macro", zero_division=0)
test_recall = recall_score(test_labels, test_preds, average="macro", zero_division=0)
test_f1 = f1_score(test_labels, test_preds, average="macro", zero_division=0)

print(f"\nTest Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1-score: {test_f1:.4f}\n")

# Print classification report
class_names = test_dataset.label_map
print(classification_report(test_labels, test_preds, target_names=class_names))

  model.load_state_dict(torch.load(model_load_path, map_location=device))


Model loaded successfully.

Evaluating with external test dataset...

Test Loss: 0.0951, Accuracy: 0.7859, Precision: 0.9025, Recall: 0.9216, F1-score: 0.9048

              precision    recall  f1-score   support

   overdrive       0.84      0.79      0.81      3432
  distortion       0.99      0.97      0.98      5148
        fuzz       0.99      0.96      0.97      6006
     tremolo       0.87      0.99      0.92      4290
      phaser       0.99      0.96      0.98      5148
     flanger       0.99      0.73      0.84      3432
      chorus       0.95      0.95      0.95      6006
       delay       0.93      0.93      0.93      7722
 hall_reverb       0.89      0.98      0.93      5148
plate_reverb       0.96      0.87      0.91      3432
     octaver       0.57      0.99      0.73      2574
 auto_filter       0.87      0.94      0.90      4290

   micro avg       0.91      0.93      0.92     56628
   macro avg       0.90      0.92      0.90     56628
weighted avg       0.92     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [9]:
from google.colab import runtime
runtime.unassign()