In [1]:
# Install required package
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.7/5.6 MB[0m [31m19.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━[0m [32m3.6/5.6 MB[0m [31m51.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.6/5.6 MB[0m [31m58.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m39.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m

In [2]:
import os
import numpy as np
import pretty_midi
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from google.colab import drive
import random

# Temporarily depending on performance- NEW
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [3]:
# Mount Google Drive
drive.mount('/content/drive')

# Define paths
base_path = '/content/drive/MyDrive/NEW PROJECT LIST/DL COURSE SUMMER/Group Project/Composer_Dataset/NN_midi_files_extended'
train_path = os.path.join(base_path, 'train')
test_path = os.path.join(base_path, 'test')
composers = ['bach', 'bartok', 'chopin', 'mozart']
composer_to_idx = {composer: idx for idx, composer in enumerate(composers)}


Mounted at /content/drive


In [4]:
# Custom Dataset
class MidiDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        midi_file = self.file_paths[idx]
        label = self.labels[idx]

        # Load and preprocess MIDI
        try:
            midi_data = pretty_midi.PrettyMIDI(midi_file)
            piano_roll = midi_data.get_piano_roll(fs=100)  # Shape: (128, time_steps)

            # Transpose to (time_steps, 128) if needed
            piano_roll = piano_roll.T  # Shape: (time_steps, 128)

            # Normalize
            piano_roll = piano_roll / 127.0

            # Pad or truncate to fixed time length (1000)
            target_length = 1000
            if piano_roll.shape[0] < target_length:
                pad_width = ((0, target_length - piano_roll.shape[0]), (0, 0))
                piano_roll = np.pad(piano_roll, pad_width, mode='constant')
            else:
                piano_roll = piano_roll[:target_length, :]

            # Ensure pitch dimension is exactly 128
            if piano_roll.shape[1] != 128:
                if piano_roll.shape[1] < 128:
                    pad_width = ((0, 0), (0, 128 - piano_roll.shape[1]))
                    piano_roll = np.pad(piano_roll, pad_width, mode='constant')
                else:
                    piano_roll = piano_roll[:, :128]

            # Add channel dimension: (1, time_steps, pitches)
            piano_roll = np.expand_dims(piano_roll, axis=0)  # Shape: (1, 1000, 128)

            # Apply data augmentation
            if self.transform:
                piano_roll = self.transform(piano_roll)

            return torch.FloatTensor(piano_roll), torch.LongTensor([label])[0]
        except Exception as e:
            print(f"Error processing {midi_file}: {e}")
            return torch.zeros((1, 1000, 128)), torch.LongTensor([label])[0]

In [5]:
# Data augmentation functions
def time_shift(piano_roll, max_shift=75):  # Increased max_shift
    shift = random.randint(-max_shift, max_shift)
    shifted = np.roll(piano_roll, shift, axis=1)  # Shift along time axis
    if shift > 0:
        shifted[:, :shift, :] = 0  # Zero-pad the start
    elif shift < 0:
        shifted[:, shift:, :] = 0  # Zero-pad the end
    return shifted

def pitch_shift(piano_roll, max_shift=10):  # Increased max_shift
    shift = random.randint(-max_shift, max_shift)
    shifted = np.roll(piano_roll, shift, axis=2)  # Shift along pitch axis
    if shift > 0:
        shifted[:, :, :shift] = 0  # Zero-pad the start
    elif shift < 0:
        shifted[:, :, shift:] = 0  # Zero-pad the end
    return shifted

def add_noise(piano_roll, noise_factor=0.1):  # Increased noise_factor
    noise = np.random.normal(0, noise_factor, piano_roll.shape)
    return np.clip(piano_roll + noise, 0, 1)

def tempo_variation(piano_roll, factor=0.2):  # New transformation
    scale = 1 + random.uniform(-factor, factor)
    time_steps = piano_roll.shape[1]
    new_time_steps = int(time_steps * scale)
    if new_time_steps < 1:
        new_time_steps = 1
    # Interpolate along the time axis for each pitch
    rescaled = np.zeros_like(piano_roll)  # Shape: (1, 1000, 128)
    for i in range(piano_roll.shape[2]):  # Iterate over pitches
        interpolated = np.interp(
            np.linspace(0, time_steps, new_time_steps),
            np.arange(time_steps),
            piano_roll[0, :, i]
        )
        # Resize to original time_steps (1000) using interpolation or truncation
        if new_time_steps > time_steps:
            rescaled[0, :, i] = np.interp(
                np.linspace(0, new_time_steps, time_steps),
                np.arange(new_time_steps),
                interpolated
            )
        else:
            rescaled[0, :new_time_steps, i] = interpolated
            rescaled[0, new_time_steps:, i] = 0  # Zero-pad if shorter
    return rescaled

def augment_data(piano_roll):
    if random.random() > 0.3:  # Increased application frequency
        piano_roll = time_shift(piano_roll)
    if random.random() > 0.3:
        piano_roll = pitch_shift(piano_roll)
    if random.random() > 0.3:
        piano_roll = add_noise(piano_roll)
    if random.random() > 0.5:
        piano_roll = tempo_variation(piano_roll)
    return piano_roll

In [6]:
# Load MIDI files
def load_midi_files(data_path):
    file_paths = []
    labels = []

    for composer in composers:
        composer_path = os.path.join(data_path, composer)
        for file in os.listdir(composer_path):
            if file.endswith('.mid') or file.endswith('.midi'):
                file_paths.append(os.path.join(composer_path, file))
                labels.append(composer_to_idx[composer])

    return file_paths, labels

In [7]:
# CNN Model
class ComposerCNN(nn.Module):
    def __init__(self, num_classes=4):
        super(ComposerCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(5, 5), stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Calculate the size of flattened features
        self.flatten_size = 256 * (1000 // 8) * (128 // 8)

        self.fc_layers = nn.Sequential(
            nn.Linear(self.flatten_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.7),  # Increased dropout
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

In [8]:
# Load data
train_files, train_labels = load_midi_files(train_path)
test_files, test_labels = load_midi_files(test_path)

In [9]:
# Split training data into train and validation sets
train_files, val_files, train_labels, val_labels = train_test_split(train_files, train_labels, test_size=0.2, random_state=42)

In [10]:
# Create datasets
train_dataset = MidiDataset(train_files, train_labels, transform=augment_data)
val_dataset = MidiDataset(val_files, val_labels)
test_dataset = MidiDataset(test_files, test_labels)

In [11]:
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)  # Reduced batch size
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [12]:
# Initialize model, loss, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ComposerCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Reduced learning rate
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Learning rate scheduler

In [13]:
# Training loop with early stopping
num_epochs = 50
best_val_loss = float('inf')
patience, trials = 20, 0  # was 10, but this was underperforming so I adjusted it to 20.

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    val_loss = val_loss / len(val_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {running_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}')

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        trials = 0
    else:
        trials += 1
        if trials >= patience:
            print(f'Early stopping triggered after epoch {epoch+1}')
            break

    scheduler.step()

Epoch 1/50, Train Loss: 21.0012, Val Loss: 9.8513
Epoch 2/50, Train Loss: 17.9280, Val Loss: 16.8986
Epoch 3/50, Train Loss: 13.2961, Val Loss: 6.4396
Epoch 4/50, Train Loss: 13.9754, Val Loss: 15.4359
Epoch 5/50, Train Loss: 12.5262, Val Loss: 18.7953
Epoch 6/50, Train Loss: 11.8264, Val Loss: 29.6960
Epoch 7/50, Train Loss: 9.3842, Val Loss: 22.7654
Epoch 8/50, Train Loss: 11.7012, Val Loss: 36.0533
Epoch 9/50, Train Loss: 8.8249, Val Loss: 24.9890
Epoch 10/50, Train Loss: 6.1273, Val Loss: 21.4801
Epoch 11/50, Train Loss: 6.2106, Val Loss: 26.9299
Epoch 12/50, Train Loss: 7.2173, Val Loss: 37.7424
Epoch 13/50, Train Loss: 5.0696, Val Loss: 33.1504
Epoch 14/50, Train Loss: 3.6833, Val Loss: 43.2616
Epoch 15/50, Train Loss: 3.9319, Val Loss: 47.8091
Epoch 16/50, Train Loss: 2.7175, Val Loss: 47.5252
Epoch 17/50, Train Loss: 3.4549, Val Loss: 44.2628
Epoch 18/50, Train Loss: 3.2459, Val Loss: 42.0771
Epoch 19/50, Train Loss: 3.2057, Val Loss: 32.6244
Epoch 20/50, Train Loss: 2.8620, Va

In [14]:
# Evaluation
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

In [15]:
# Generate and print classification report
print('\nClassification Report:\n')
print(classification_report(all_labels, all_preds, target_names=composers))


Classification Report:

              precision    recall  f1-score   support

        bach       1.00      1.00      1.00         4
      bartok       0.67      1.00      0.80         4
      chopin       0.50      0.50      0.50         4
      mozart       0.50      0.25      0.33         4

    accuracy                           0.69        16
   macro avg       0.67      0.69      0.66        16
weighted avg       0.67      0.69      0.66        16

