In [13]:
from PIL import Image
import numpy as np
import torch

def load_spectrogram(image_path):
    """Load a spectrogram image as a single-channel float32 tensor.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        torch.Tensor of shape ``(1, H, W)`` and dtype float32. The image is
        converted to 8-bit grayscale ('L') and divided by 255, so values lie
        in [0, 1].
    """
    # The context manager closes the file handle deterministically instead of
    # relying on Pillow's auto-close / garbage collection.
    with Image.open(image_path) as img:
        pixels = np.array(img.convert('L')) / 255.0  # Normalize to [0, 1]
    # Prepend a channel dimension: (H, W) -> (1, H, W)
    return torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)

def load_midi(midi_path, allow_pickle=False):
    """Load a preprocessed MIDI feature array (``.npy``) as a float32 tensor.

    Args:
        midi_path: Path to a ``.npy`` file written by ``np.save``.
        allow_pickle: Forwarded to ``np.load``. It defaults to False because
            unpickling can execute arbitrary code. Only object arrays need
            pickling, and ``torch.tensor`` cannot convert an object array to a
            float tensor anyway. Pass True only for files you trust.

    Returns:
        torch.Tensor with the stored array's shape and dtype float32.

    Raises:
        ValueError: from ``np.load`` if the file holds an object array and
            ``allow_pickle`` is False.
    """
    midi_array = np.load(midi_path, allow_pickle=allow_pickle)
    return torch.tensor(midi_array, dtype=torch.float32)

In [14]:
# Sanity-check one preprocessed MIDI file (the first axis is presumably time; confirm).
sample_midi = (
    'midi-processed-values/2004/'
    'MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi.npy'
)
midi = load_midi(sample_midi)
print(midi.shape)

torch.Size([500, 17])


In [15]:
# Sanity-check the matching spectrogram image: expect [channel, height, width].
sample_spectrogram = (
    'spectrograms/2004/'
    'MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.png'
)
image = load_spectrogram(sample_spectrogram)
print(image.shape)

torch.Size([1, 500, 1400])


In [16]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import os

def custom_collate_fn(batch):
    """Collate a list of ``(spectrogram, midi)`` samples into two batch tensors.

    Each field is stacked along a new leading dimension with ``torch.stack``.
    No padding is applied. All spectrograms in a batch must share one shape,
    and so must all MIDI tensors; otherwise ``torch.stack`` raises.

    Returns:
        ``(spectrograms, midi_files)`` with shapes ``[B, *spec_shape]`` and
        ``[B, *midi_shape]``.
    """
    def _stack_field(position):
        # Pull one field out of every sample and stack the results.
        return torch.stack([sample[position] for sample in batch])

    spectrograms = _stack_field(0)
    midi_files = _stack_field(1)
    return spectrograms, midi_files


class SpectrogramMIDIDataset(Dataset):
    """Pairs spectrogram images with preprocessed MIDI arrays, one folder per year.

    Files are paired by *stem*: the part of the filename before the first
    ``.``. For example, ``X_wav.png`` pairs with ``X_wav.midi.npy``. Pairing is
    not done by position in a sorted listing, so a missing or extra file in one
    folder cannot silently shift every subsequent pair.

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store`` or
    ``.ipynb_checkpoints``) are ignored. Files without a partner are skipped
    with a warning.

    Args:
        spectrogram_dir: Root directory containing one sub-folder per year.
        midi_dir: Root directory containing one sub-folder per year.
        years: Iterable of year folder names to include.

    Raises:
        ValueError: if two visible files in the same folder share a stem,
            which would make the pairing ambiguous.
    """

    def __init__(self, spectrogram_dir, midi_dir, years):
        import warnings  # local import keeps this cell's dependencies self-contained

        self.spectrogram_paths = []
        self.midi_paths = []

        # Iterate through each year folder
        for year in years:
            spectrogram_year_dir = os.path.join(spectrogram_dir, year)
            midi_year_dir = os.path.join(midi_dir, year)

            spec_by_stem = self._index_by_stem(spectrogram_year_dir)
            midi_by_stem = self._index_by_stem(midi_year_dir)

            # Dicts keep insertion (sorted-filename) order. When every file has
            # a partner, the pairs and their order match a sorted/zip pairing.
            for stem, spec_file in spec_by_stem.items():
                midi_file = midi_by_stem.get(stem)
                if midi_file is None:
                    continue
                self.spectrogram_paths.append(os.path.join(spectrogram_year_dir, spec_file))
                self.midi_paths.append(os.path.join(midi_year_dir, midi_file))

            # Stems present in only one of the two folders.
            unmatched = sorted(spec_by_stem.keys() ^ midi_by_stem.keys())
            if unmatched:
                warnings.warn(
                    f"Year {year!r}: skipped {len(unmatched)} file(s) without a "
                    f"spectrogram/MIDI partner, e.g. {unmatched[:3]}",
                    stacklevel=2,
                )

    @staticmethod
    def _index_by_stem(directory):
        """Map filename stem -> filename for the visible entries of ``directory``, in sorted order."""
        index = {}
        for name in sorted(os.listdir(directory)):
            if name.startswith('.'):
                continue
            stem = name.split('.', 1)[0]
            if stem in index:
                raise ValueError(
                    f"Ambiguous pairing in {directory!r}: {index[stem]!r} and "
                    f"{name!r} share stem {stem!r}"
                )
            index[stem] = name
        return index

    def __len__(self):
        """Number of matched (spectrogram, MIDI) pairs."""
        return len(self.spectrogram_paths)

    def __getitem__(self, idx):
        """Return the ``(spectrogram_tensor, midi_tensor)`` pair at ``idx``."""
        spectrogram = load_spectrogram(self.spectrogram_paths[idx])
        midi = load_midi(self.midi_paths[idx])
        return spectrogram, midi

In [25]:
# Define paths (relative to the notebook's working directory)
spectrogram_dir = "./spectrograms"
midi_dir = "./midi-processed-values"

# Available year folders: 2004, 2006, 2008, 2009, 2011, 2013, 2014, 2015,
# 2017, 2018. Whole years are held out for testing.
training_years = ["2004", "2006", "2009", "2013", "2014", "2015", "2017", "2018"]
testing_years = ["2008", "2011"]
assert not set(training_years) & set(testing_years), "a year is in both train and test splits"

# Reproducibility: seed torch's global RNG, which later cells (e.g. model init)
# draw from. Give the DataLoader its own generator so the shuffle order repeats.
SEED = 42
BATCH_SIZE = 32
torch.manual_seed(SEED)

# Create dataset
dataset = SpectrogramMIDIDataset(spectrogram_dir, midi_dir, training_years)

# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

In [21]:
# Pull a single batch to sanity-check shapes before training.
# custom_collate_fn stacks without padding, so these are the raw per-sample shapes.
spectrograms, midi_files = next(iter(dataloader))
print("Spectrograms batch shape:", spectrograms.shape)
print("MIDI batch shape:", midi_files.shape)

# The model uses spectrogram dim 2 as its sequence axis and emits one MIDI
# step per position, so this axis must match MIDI dim 1 for the MSE loss.
assert spectrograms.shape[2] == midi_files.shape[1], "spectrogram and MIDI time axes differ"

Batch 1
Padded spectrograms shape: torch.Size([32, 1, 500, 1400])
Padded MIDI files shape: torch.Size([32, 500, 17])


In [23]:
import torch
import torch.nn as nn
class SpectrogramToMIDIModel(nn.Module):
    """2D-CNN encoder + LSTM decoder mapping a spectrogram to per-step MIDI features.

    The input is treated as ``[batch, channels, time, freq]``. Dim 2 becomes
    the LSTM sequence axis, and every (channel, freq) value at a step is
    flattened into that step's feature vector.

    Args:
        input_channels: Channels of the input image (1 for grayscale).
        freq_bins: Size of the last input dim (image width).
        midi_features: Size of each output step's feature vector.
        hidden_size: LSTM hidden size (default 256).
        num_layers: Number of stacked LSTM layers (default 2).
        freq_pool: Max-pool factor applied along the last (frequency) axis
            after the encoder. The default of 1 means no pooling, and the
            parameters are identical to a model without this option.

            Note that the LSTM input is ``64 * freq_bins`` wide: 89,600 for
            1400 bins, giving ~92M input-to-hidden weights in the first layer.
            ``freq_pool > 1`` shrinks that width by the pool factor.

    Raises:
        ValueError: if ``freq_pool`` is not in ``[1, freq_bins]``.
    """

    def __init__(self, input_channels=1, freq_bins=1400, midi_features=17,
                 hidden_size=256, num_layers=2, freq_pool=1):
        super().__init__()
        if not 1 <= freq_pool <= freq_bins:
            raise ValueError(f"freq_pool must be in [1, {freq_bins}], got {freq_pool}")

        encoder_channels = 64

        # Spectrogram encoder. The 3x3 kernels with padding 1 keep both spatial
        # dims unchanged, so the time axis length is preserved.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(32, encoder_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
        )

        # Optional frequency-only pooling. Neither module has parameters, so
        # the state_dict keys and the parameter-init RNG sequence are unchanged.
        if freq_pool > 1:
            self.pool = nn.MaxPool2d(kernel_size=(1, freq_pool))
        else:
            self.pool = nn.Identity()
        pooled_freq_bins = freq_bins // freq_pool

        # Temporal decoder (LSTM)
        self.lstm = nn.LSTM(
            input_size=encoder_channels * pooled_freq_bins,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

        # Final projection to MIDI features
        self.fc = nn.Linear(hidden_size, midi_features)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, input_channels, time, freq_bins]`` to ``[batch, time, midi_features]``."""
        features = self.encoder(x)          # [batch, 64, time, freq_bins]
        features = self.pool(features)      # [batch, 64, time, freq_bins // freq_pool]

        # Move time to dim 1, then flatten channels and frequency per step.
        features = features.permute(0, 2, 1, 3)  # [batch, time, 64, freq']
        features = features.reshape(features.size(0), features.size(1), -1)  # [batch, time, 64*freq']

        lstm_out, _ = self.lstm(features)   # [batch, time, hidden_size]
        return self.fc(lstm_out)            # [batch, time, midi_features]

In [24]:
import torch.optim as optim

# Model hyperparameters chosen to match the loaded data: spectrograms are
# [1, 500, 1400] and MIDI arrays are [500, 17] in the sample cells.
input_channels = 1     # grayscale image channel
frequency_bins = 1400  # spectrogram image width
output_dim = 17        # MIDI feature columns per step

model = SpectrogramToMIDIModel(
    input_channels=input_channels,
    freq_bins=frequency_bins,
    midi_features=output_dim,
)


In [None]:
import torch.optim as optim

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim docs recommend, so the optimizer wraps the on-device parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function and optimizer
criterion = nn.MSELoss()  # Mean Squared Error for regression
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Learning rate

# Training loop
num_epochs = 10  # Number of epochs to train
log_every = 10   # Print the windowed average loss every `log_every` batches

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum over the current logging window; reset after each print
    epoch_loss = 0.0    # sum over every batch of the epoch; never reset mid-epoch

    for batch_idx, (spectrograms, midi_files) in enumerate(dataloader):
        spectrograms = spectrograms.to(device)
        midi_files = midi_files.to(device)

        # Ensure input shape: [batch, 1, time, freq]
        if spectrograms.dim() == 3:
            spectrograms = spectrograms.unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(spectrograms)  # [batch, time, midi_features]
        loss = criterion(outputs, midi_files)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        if batch_idx % log_every == log_every - 1:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx + 1}/{len(dataloader)}], Loss: {running_loss / log_every:.4f}")
            running_loss = 0.0

    # Average over the whole epoch. running_loss only holds the last partial
    # window at this point, so it must not be used here.
    num_batches = max(len(dataloader), 1)  # guard against an empty dataloader
    print(f"Epoch [{epoch + 1}/{num_epochs}] completed, Average Loss: {epoch_loss / num_batches:.4f}")