In [86]:
import os
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.nn import Transformer



In [81]:
def load_spectrogram(image_path):
    """Load a spectrogram image as a normalized, single-channel float tensor.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        A ``torch.float32`` tensor with a leading channel dimension of size 1
        (i.e. ``(1, H, W)``) whose values lie in ``[0, 1]``.
    """
    # Open through a context manager so the underlying file handle is closed
    # deterministically instead of whenever the image object is collected.
    with Image.open(image_path) as source_image:
        # 'L' is 8-bit grayscale, so pixel values are integers in 0..255.
        # convert() returns a new in-memory image that outlives the file.
        grayscale = source_image.convert('L')

    pixels = np.array(grayscale) / 255.0  # Normalize to [0, 1]
    return torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

def load_midi(midi_path):
    """Read a saved NumPy array from disk and return it as a float32 tensor.

    Args:
        midi_path: Path of a file readable by ``np.load``.

    Returns:
        A ``torch.float32`` tensor holding a copy of the loaded values.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code. Only point this at files you produced yourself.
    raw_values = np.load(midi_path, allow_pickle=True)
    return torch.tensor(raw_values, dtype=torch.float32)

In [117]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import os

def custom_collate_fn(batch):
    """Collate ``(spectrogram, midi)`` samples into two batched tensors.

    Args:
        batch: Sequence of samples; element 0 of each sample is the
            spectrogram tensor and element 1 is the MIDI tensor.

    Returns:
        A ``(spectrograms, midi_files)`` tuple in which each tensor gains a new
        leading batch dimension. ``torch.stack`` requires every tensor of a
        given field to have the same shape.
    """
    def _stack_field(position):
        # Gather the same field from every sample and stack along a new dim 0.
        return torch.stack([sample[position] for sample in batch])

    spectrogram_batch = _stack_field(0)
    midi_batch = _stack_field(1)
    return spectrogram_batch, midi_batch


class SpectrogramMIDIDataset(Dataset):
    """Dataset of (spectrogram, MIDI) pairs stored in per-year sub-folders.

    For every requested year, the entries of ``<spectrogram_dir>/<year>`` and
    ``<midi_dir>/<year>`` are sorted by name and paired by position, so the two
    folders are expected to contain matching files that sort in the same order.

    Attributes are two parallel lists: ``spectrogram_paths[i]`` and
    ``midi_paths[i]`` together form sample ``i``.
    """

    def __init__(self, spectrogram_dir, midi_dir, years):
        """Collect the paired file paths for the given years.

        Args:
            spectrogram_dir: Root folder containing one sub-folder per year
                with spectrogram images.
            midi_dir: Root folder containing one sub-folder per year with the
                saved MIDI arrays.
            years: Iterable of year folder names; non-string values such as
                ``2004`` are converted with ``str()``.

        Raises:
            OSError: If a year folder cannot be listed.
        """
        import warnings  # local import: only needed for the mismatch warning

        self.spectrogram_paths = []
        self.midi_paths = []

        for year in years:
            year = str(year)  # os.path.join rejects ints; accept them anyway
            spectrogram_year_dir = os.path.join(spectrogram_dir, year)
            midi_year_dir = os.path.join(midi_dir, year)

            spectrogram_files = self._visible_entries(spectrogram_year_dir)
            midi_files = self._visible_entries(midi_year_dir)

            # zip() below stops at the shorter list. Surface that instead of
            # silently dropping files and possibly shifting every later pair.
            if len(spectrogram_files) != len(midi_files):
                warnings.warn(
                    f"Year {year!r}: found {len(spectrogram_files)} spectrogram "
                    f"file(s) but {len(midi_files)} MIDI file(s); pairing by "
                    "sorted position, so surplus files are dropped and pairs "
                    "may be misaligned.",
                    stacklevel=2,
                )

            for spec_file, midi_file in zip(spectrogram_files, midi_files):
                self.spectrogram_paths.append(os.path.join(spectrogram_year_dir, spec_file))
                self.midi_paths.append(os.path.join(midi_year_dir, midi_file))

    @staticmethod
    def _visible_entries(directory):
        """Return the sorted entry names of ``directory``, skipping dot-prefixed ones.

        Hidden entries such as ``.ipynb_checkpoints`` or ``.DS_Store`` are not
        data files; including them would offset the positional pairing.
        """
        return sorted(name for name in os.listdir(directory) if not name.startswith("."))

    def __len__(self):
        """Return the number of (spectrogram, MIDI) pairs."""
        return len(self.spectrogram_paths)

    def __getitem__(self, idx):
        """Load and return the ``(spectrogram, midi)`` tensors for sample ``idx``."""
        spectrogram = load_spectrogram(self.spectrogram_paths[idx])
        midi = load_midi(self.midi_paths[idx])
        return spectrogram, midi

In [119]:
# Root folders; the dataset looks for one sub-folder per year inside each.
spectrogram_dir = "../spectrograms"
midi_dir = "../midi-processed-values"

# Complete year list, for reference:
#   2004, 2006, 2008, 2009, 2011, 2013, 2014, 2015, 2017, 2018
# 2008 and 2011 are held out for testing; the remaining years are for training.
training_years = ["2004", "2006", "2009", "2013", "2014", "2015", "2017", "2018"]
testing_years = ["2008", "2011"]

# Training dataset built from the training years only.
dataset = SpectrogramMIDIDataset(
    spectrogram_dir=spectrogram_dir,
    midi_dir=midi_dir,
    years=training_years,
)

# Shuffled mini-batches of 32 samples, batched by custom_collate_fn.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=custom_collate_fn,
)