In [1]:
import pandas as pd
import muspy
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


In [2]:

class MIDIDataset(Dataset):
    def __init__(self, data_path, label_path, transform=None, max_seq_length=1000):
        self.data = pd.read_csv(label_path)  # Load labels from CSV
        self.data_path = data_path
        self.transform = transform
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_id = self.data['ID'][idx]  # Fetch file ID/name from CSV
        label = self.data['4Q'][idx]  # Fetch label
        
        file_path = os.path.join(self.data_path, f"{file_id}.mid")
        midi_data = muspy.read(file_path)  # Use MusPy to read MIDI files
        
        # Process MIDI data to extract required events
        note_on_events = []
        note_durations = []
        velocities = []
        tempos = []
        bars = []
        beats = []
        
        for track in midi_data.tracks:
            for event in track:
                if isinstance(event, muspy.Note):
                    note_on_events.append(event.pitch)
                    note_durations.append(event.duration)
                    velocities.append(event.velocity)
                elif isinstance(event, muspy.Tempo):
                    tempos.append(event.qpm)
                elif isinstance(event, muspy.TimeSignature):
                    bars.append(event.numerator)
                    beats.append(event.denominator)
        
        # Example: Extract number of note events
        num_note_events = len(note_on_events)
        
        sample = {
            'file_path': file_path,
            'label': label,
            'num_note_events': num_note_events,
            'note_on_events': note_on_events,
            'note_durations': note_durations,
            'velocities': velocities,
            'tempos': tempos,
            'bars': bars,
            'beats': beats
        }


        # Pad sequences to a fixed length
        pad_value = 0  # You can choose a different value if needed
        pad = lambda x, length: pad_sequence([torch.tensor(x)] * length, batch_first=True, padding_value=pad_value)
        
        note_on_events = pad(sample['note_on_events'], self.max_seq_length)
        note_durations = pad(sample['note_durations'], self.max_seq_length)
        velocities = pad(sample['velocities'], self.max_seq_length)
        tempos = pad(sample['tempos'], self.max_seq_length)
        bars = pad(sample['bars'], self.max_seq_length)
        beats = pad(sample['beats'], self.max_seq_length)

        sample = {
            'file_path': file_path,
            'label': label,
            'num_note_events': num_note_events,
            'note_on_events': note_on_events,
            'note_durations': note_durations,
            'velocities': velocities,
            'tempos': tempos,
            'bars': bars,
            'beats': beats
            # Add other extracted events here
        }
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample


In [3]:
# Define paths
data_path = '/workspaces/dl/Final Project/fork/data/emopia/EMOPIA_2.2/midis'
label_path = '/workspaces/dl/Final Project/fork/data/emopia/EMOPIA_2.2/label_new.csv'

# Initialize dataset and dataloader
dataset = MIDIDataset(data_path, label_path)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [4]:
batch_size = 32

In [5]:
def collate_fn(batch):
    # Get the maximum sequence length in this batch
    max_seq_length = max(len(item['note_on_events']) for item in batch)

    # Prepare sequences for padding
    padded_note_on_events = []
    # Add other sequences like note_durations, velocities, etc.

    # Pad or truncate sequences to the maximum length
    for item in batch:
        # Pad sequences to match the maximum length
        padded_note_on_events.append(
            torch.nn.functional.pad(torch.tensor(item['note_on_events']), (0, max_seq_length - len(item['note_on_events'])))
        )
        # Similarly, pad or truncate other sequences like note_durations, velocities, etc.

    # Use pad_sequence to create batch tensors
    padded_note_on_events = pad_sequence(padded_note_on_events, batch_first=True)
    # Similarly, use pad_sequence for other sequences

    return {
        'note_on_events': padded_note_on_events,
        # Add other padded sequences to the returned dictionary
    }

# Create your DataLoader with the collate_fn
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)


In [None]:
# show sample from dataloader
sample = next(iter(dataloader))

In [51]:
# store file name as a csv
file_names = []
for i in range(len(os.listdir(data_path))):
    file_names.append(os.listdir(data_path)[i][:-4])

In [52]:
# create dataframe
df = pd.DataFrame(file_names, columns=['ID'])

In [53]:
df

Unnamed: 0,ID
0,Q1_9v2WSpn4FCw_10
1,Q2_dtS02mrDMsM_1
2,Q3_3ZnxqCZ7qGg_0
3,Q4_vpTguZtJAFA_2
4,Q3_Ie5koh4qvJc_5
...,...
1066,Q1_miqLU2739dk_1
1067,Q1_0vLPYiPN7qY_1
1068,Q1_V3Y9L4UOcpk_2
1069,Q3_gvWDOIiocuE_0


In [54]:
label_df = pd.read_csv(label_path)

In [55]:
# find the file names in label.csv that are not in the df
label_df[~label_df['ID'].isin(df['ID'])]

Unnamed: 0,ID,4Q,annotator
58,Q1_9v2WSpn4FCw_5,1,A
59,Q1_9v2WSpn4FCw_6,1,A
92,Q1_JT1XJnVmABo_2,1,A
208,Q1_ldCQ6N9G6Mk_2,1,A
305,Q2_9v2WSpn4FCw_1,2,D
306,Q2_9v2WSpn4FCw_2,2,D
876,Q4_JT1XJnVmABo_1,4,C


In [49]:
missing_files

[]