In [1]:
import logging
from typing import Optional, List
import partitura as pt
import numpy as np
import os

os.chdir(r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo")

from sms.src.log import configure_logging
from sms.src.synthetic_data.utils import midi_to_note_array

logger = logging.getLogger(__name__)
configure_logging()

ModuleNotFoundError: No module named 'sms.src.synthetic_data.utils'

In [2]:
def midi_to_all_bars(
    midi_path: str,
    rest_pitch: float = -1,
    remove_rests: bool = True
) -> List[np.ndarray]:
    """
    Extracts all bars from a MIDI file and returns them as a list of note arrays.

    Args:
        midi_path (str): Path to the MIDI file.
        rest_pitch (float): Pitch value for rests.
        remove_rests (bool): Whether to remove rests from the note arrays.

    Returns:
        List[np.ndarray]: List of arrays, each representing a bar with columns [duration_beat, pitch].
    """
    score = pt.load_score(midi_path)

    if not score.parts:
        raise ValueError("No parts found in the score.")
    
    part = score.parts[0]
    note_arr = part.note_array()
    
    last_note = note_arr[-1]
    last_note_end = last_note['onset_beat'] + last_note['duration_beat']
    total_bars = int(np.ceil(last_note_end / 4))  # Assuming 4 beats per bar
    print(f"Total bars: {total_bars}")
    all_bars = []
    for bar in range(total_bars):
        bar_notes = midi_to_note_array(
            midi_path=midi_path,
            num_bars=1,
            start_bar=bar,
            rest_pitch=rest_pitch,
            remove_rests=remove_rests
        )
        all_bars.append(bar_notes)

    return all_bars

path = r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo\data\synthetic_dataset\monophonic_midis\maestro\MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_mono.mid"
all_bars = midi_to_all_bars(path)


  return f(*args, **kwargs)
  part = create_part(
  part = create_part(
  part = create_part(
  part = create_part(
  part = create_part(
  part = create_part(


Total bars: 90


[2024-10-02 16:09:05] [INFO ] Extracted notes: [(2.  , 1.25, 2.  , 1.25, 440, 275, 42, 1, 'n0')
 (3.25, 0.75, 3.25, 0.75, 715, 165, 49, 1, 'n1')]. Returning None.
[2024-10-02 16:09:08] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:08] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:13] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:16] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:18] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:19] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:19] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:21] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:09:23] [INFO ] Adjusted last note/rest

In [6]:
def midi_to_all_bars_efficient(
    midi_path: str,
    rest_pitch: float = -1,
    remove_rests: bool = True
) -> List[np.ndarray]:
    """
    Efficiently extracts all bars from a MIDI file and returns them as a list of note arrays.

    Args:
        midi_path (str): Path to the MIDI file.
        rest_pitch (float): Pitch value for rests.
        remove_rests (bool): Whether to remove rests from the note arrays.

    Returns:
        List[np.ndarray]: List of arrays, each representing a bar with columns [duration_beat, pitch].
    """
    score = pt.load_score(midi_path)

    if not score.parts:
        raise ValueError("No parts found in the score.")
    
    part = score.parts[0]
    note_arr = part.note_array()
    
    last_note = note_arr[-1]
    last_note_end = last_note['onset_beat'] + last_note['duration_beat']
    total_bars = int(np.ceil(last_note_end / 4))  # Assuming 4 beats per bar

    all_bars = []
    beats_per_bar = 4

    for bar in range(total_bars):
        start_beats = bar * beats_per_bar
        end_beats = start_beats + beats_per_bar

        # Filter notes for this bar
        bar_notes = note_arr[
            (note_arr['onset_beat'] < end_beats) &
            (note_arr['onset_beat'] + note_arr['duration_beat'] > start_beats)
        ]

        if len(bar_notes) <= 3:
            logger.info(f"Bar {bar} has 3 or fewer notes. Skipping.")
            continue

        # Sort notes by onset_beat
        bar_notes = bar_notes[np.argsort(bar_notes['onset_beat'])]

        # Process notes for this bar
        duration_pitch = []
        previous_end = start_beats

        for note in bar_notes:
            note_onset = note['onset_beat']
            note_duration = note['duration_beat']
            note_pitch = note['pitch']

            adjusted_onset = max(note_onset, start_beats)
            actual_duration = min(note_onset + note_duration, end_beats) - adjusted_onset

            if adjusted_onset > previous_end:
                rest_duration = adjusted_onset - previous_end
                if remove_rests:
                    actual_duration += rest_duration
                else:
                    duration_pitch.append([rest_duration, rest_pitch])

            duration_pitch.append([actual_duration, note_pitch])
            previous_end = adjusted_onset + actual_duration

        # Handle end rest
        if previous_end < end_beats:
            rest_duration = end_beats - previous_end
            if remove_rests:
                duration_pitch[-1][0] += rest_duration
            else:
                duration_pitch.append([rest_duration, rest_pitch])

        # Convert to numpy array and round durations
        bar_array = np.array(duration_pitch, dtype=float)
        bar_array[:, 0] = np.round(bar_array[:, 0], decimals=3)

        # Validate total duration
        total_duration = np.sum(bar_array[:, 0])
        if not np.isclose(total_duration, beats_per_bar, atol=1e-8):
            logger.warning(f"Bar {bar}: Total duration {total_duration} does not match expected duration {beats_per_bar}")
            duration_diff = beats_per_bar - total_duration
            bar_array[-1, 0] += duration_diff
            logger.info(f"Adjusted last note/rest duration by {duration_diff} to match expected duration")

        all_bars.append(bar_array)

    return all_bars

path = r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo\data\synthetic_dataset\monophonic_midis\maestro\MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_mono.mid"
all_bars = midi_to_all_bars_efficient(path)

  return f(*args, **kwargs)
  part = create_part(
  part = create_part(
  part = create_part(
  part = create_part(
  part = create_part(
  part = create_part(
[2024-10-02 16:13:52] [INFO ] Bar 0 has 3 or fewer notes. Skipping.
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expected duration
[2024-10-02 16:13:52] [INFO ] Adjusted last note/rest duration by 0.25 to match expect

In [8]:
all_bars

def test_bar_durations(bar_list: List[np.ndarray], expected_duration: float = 4.0, tolerance: float = 1e-6) -> bool:
    """
    Test if the sum of durations in each bar is equal to the expected duration.

    Args:
        bar_list (List[np.ndarray]): List of numpy arrays representing bars.
        expected_duration (float): Expected total duration of each bar (default is 4.0).
        tolerance (float): Tolerance for floating point comparison (default is 1e-6).

    Returns:
        bool: True if all bars have the correct total duration, False otherwise.
    """
    for i, bar in enumerate(bar_list):
        total_duration = np.sum(bar[:, 0])
        if not np.isclose(total_duration, expected_duration, atol=tolerance):
            print(f"Bar {i} has incorrect duration: {total_duration} (expected {expected_duration})")
            return False
    return True

test_bar_durations(all_bars)


True

In [14]:
from sms.src.synthetic_data.note_arr_mod import NoteArrayModifier, NoteArrayModifierSettings

settings = NoteArrayModifierSettings(
    transposition_semitone_range=(-4, 4),
    notes_to_pitch_shift=1,
    note_pitch_shift_range=(-4, 4),
    notes_to_scale=1,
    note_duration_scale_options=(0.5, 1.5, 2),
    notes_to_delete=1,
    notes_to_insert=1,
    insert_note_duration_options=(0.25, 0.5),
    insert_note_relative_pitch_range=(-4, 4)
)

aug_dict = {
    "use_transposition": True,
    "use_shift_selected_notes_pitch": True,
    "use_change_note_durations": True,
    "use_delete_notes": True,
    "use_insert_notes": True
}

mod = NoteArrayModifier(settings=settings)
for bar in all_bars:
    mod(bar, aug_dict)



[2024-10-02 16:30:45] [DEBUG] Transposing non-rest notes by 3 semitones.
[2024-10-02 16:30:45] [DEBUG] Shifting note at index 4 by 2 semitones.
[2024-10-02 16:30:45] [DEBUG] Scaling duration of note at index 0 by a factor of 1.5.
[2024-10-02 16:30:45] [DEBUG] Inserting note at index 0 with duration 0.25 and relative pitch -4.
[2024-10-02 16:30:45] [DEBUG] Deleting notes at indices [0].
[2024-10-02 16:30:45] [DEBUG] Truncated note 8 by 0.125 to maintain total duration.
[2024-10-02 16:30:45] [DEBUG] Transposing non-rest notes by -1 semitones.
[2024-10-02 16:30:45] [DEBUG] Shifting note at index 9 by -1 semitones.
[2024-10-02 16:30:45] [DEBUG] Scaling duration of note at index 4 by a factor of 1.5.
[2024-10-02 16:30:45] [DEBUG] Inserting note at index 1 with duration 0.5 and relative pitch -3.
[2024-10-02 16:30:45] [DEBUG] Deleting notes at indices [2].
[2024-10-02 16:30:45] [DEBUG] Removed note 9 with duration 0.25 to adjust total duration.
[2024-10-02 16:30:45] [DEBUG] Transposing non-r

In [3]:
import torch

ds = torch.load(r"data\exp2\all_chunks.pt")


  ds = torch.load(r"data\exp2\all_chunks.pt")


In [9]:
import sys

import os
import torch
import logging
from typing import Dict, List
import numpy as np
import sys

from sms.src.synthetic_data.midi_to_note_arrays import midi_to_all_bars_efficient
from sms.src.synthetic_data.note_arr_mod import NoteArrayModifier, NoteArrayModifierSettings
from sms.src.log import configure_logging
from sms.defaults import MAESTRO_PATH, MTC_PATH, MAESTRO_SEGMENTS_PATH, MTC_SEGMENTS_PATH

def augment_all_note_arrays(
        input_file: str, 
        output_file: str, 
        num_augmentations: int, 
        total_songs: int
        ) -> None:
    """
    Augments note arrays from the input file and saves the result to the output file.

    Args:
        input_file (str): Path to the input file containing note arrays.
        output_file (str): Path to save the augmented note arrays.
        num_augmentations (int): Number of times to augment each chunk.
        total_songs (int): Total number of augmented songs to generate.
    """
    settings = NoteArrayModifierSettings(
        transposition_semitone_range=(-4, 4),
        notes_to_pitch_shift=1,
        note_pitch_shift_range=(-4, 4),
        notes_to_scale=1,
        note_duration_scale_options=(0.5, 1.5, 2),
        notes_to_delete=1,
        notes_to_insert=1,
        insert_note_duration_options=(0.25, 0.5),
        insert_note_relative_pitch_range=(-4, 4)
    )

    aug_dict = {
        "use_transposition": True,
        "use_shift_selected_notes_pitch": True,
        "use_change_note_durations": True,
        "use_delete_notes": True,
        "use_insert_notes": True
    }

    modifier = NoteArrayModifier(settings=settings)

    input_chunks = torch.load(input_file)
    augmented_chunks = {}
    
    song_names = list(input_chunks.keys())
    augmented_count = 0

    # while augmented_count < total_songs:
    #     song_name = np.random.choice(song_names)
    #     chunk = np.random.choice(input_chunks[song_name])
    #     for _ in range(num_augmentations):
    #         if augmented_count >= total_songs:
    #             break
    #         augmented_chunk = modifier.modify_note_array(chunk, **aug_dict)
    #         aug_song_name = f"{song_name}_aug_{augmented_count}"
    #         augmented_chunks[aug_song_name] = augmented_chunk
    #         augmented_count += 1

    for song_name, chunks in input_chunks.items():
        if not isinstance(chunks, (list, np.ndarray)):
            chunks = [chunks]  # Convert single chunk to list
        
        for i, chunk in enumerate(chunks):
            for j in range(num_augmentations):
                if augmented_count >= total_songs:
                    break
                
                augmented_chunk = modifier(chunk, aug_dict)
                aug_song_name = f"{song_name}_chunk{i}_aug{j}"
                augmented_chunks[aug_song_name] = augmented_chunk
                augmented_count += 1
            
            if augmented_count >= total_songs:
                break
        
        if augmented_count >= total_songs:
            break

    torch.save(augmented_chunks, output_file)
    logger.info(f"Saved {len(augmented_chunks)} augmented chunks to {output_file}")

    # Estimate RAM usage
    ram_usage = sys.getsizeof(augmented_chunks) / (1024 * 1024)  # Convert to MB
    logger.info(f"Estimated RAM usage: {ram_usage:.2f} MB")


augment_all_note_arrays(
    r"data\exp2\all_chunks.pt",
    r"data\exp2\augmented_chunks.pt",
    4,
    1_000_000
)


  input_chunks = torch.load(input_file)
[2024-10-02 23:30:50] [DEBUG] Transposing non-rest notes by -1 semitones.
[2024-10-02 23:30:50] [DEBUG] Shifting note at index 2 by 3 semitones.
[2024-10-02 23:30:50] [DEBUG] Scaling duration of note at index 3 by a factor of 1.5.
[2024-10-02 23:30:50] [DEBUG] Inserting note at index 1 with duration 0.25 and relative pitch 2.
[2024-10-02 23:30:50] [DEBUG] Deleting notes at indices [1].
[2024-10-02 23:30:50] [DEBUG] Truncated note 4 by 0.125 to maintain total duration.
[2024-10-02 23:30:50] [DEBUG] Transposing non-rest notes by 0 semitones.
[2024-10-02 23:30:50] [DEBUG] Shifting note at index 4 by 0 semitones.
[2024-10-02 23:30:50] [DEBUG] Scaling duration of note at index 3 by a factor of 0.5.
[2024-10-02 23:30:50] [DEBUG] Inserting note at index 2 with duration 0.25 and relative pitch -4.
[2024-10-02 23:30:50] [DEBUG] Deleting notes at indices [3].
[2024-10-02 23:30:50] [DEBUG] Elongated last note by 0.625 to maintain total duration.
[2024-10-02

KeyboardInterrupt: 

In [5]:
import torch
import os

os.chdir(r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo")

ds = torch.load(r"augmented_chunks.pt")



  ds = torch.load(r"augmented_chunks.pt")


In [42]:
import torch
import os

os.chdir(r"C:\Users\cunn2\OneDrive\DSML\Project\thesis-repo")

selected_keys = torch.load(r"data\exp2\augmented_embeddings\selected_keys.pt")
subset_keys = torch.load(r"data\exp2\augmented_embeddings\subset_keys.pt")

for size, subset_key_list in subset_keys.items():
    selected_key_list = selected_keys[size]
    # Convert lists to sets for efficient set operations
    subset_key_set = set(subset_key_list)
    selected_key_set = set(selected_key_list)
    
    # Perform set difference
    diff_keys = subset_key_set - selected_key_set
    
    print(f"For size {size}:")
    print(f"  Number of keys in subset: {len(subset_key_set)}")
    print(f"  Number of keys in selected: {len(selected_key_set)}")
    print(f"  Number of keys in subset but not in selected: {len(diff_keys)}")
    
    if len(diff_keys) > 0:
        print("  First 5 keys in difference (or all if less than 5):")
        for key in list(diff_keys)[:5]:
            print(f"    {key}")
    print()  # Add a blank line for readability between sizes


  selected_keys = torch.load(r"data\exp2\augmented_embeddings\selected_keys.pt")
  subset_keys = torch.load(r"data\exp2\augmented_embeddings\subset_keys.pt")


For size 5k:
  Number of keys in subset: 5000
  Number of keys in selected: 100
  Number of keys in subset but not in selected: 4900
  First 5 keys in difference (or all if less than 5):
    NLB137936_01.mid_chunk9_aug1
    MIDI-Unprocessed_08_R2_2009_01_ORIG_MID--AUDIO_08_R2_2009_08_R2_2009_04_WAV_mono.mid_chunk54_aug1
    NLB074154_01.mid_chunk2_aug0
    MIDI-Unprocessed_14_R1_2008_01-05_ORIG_MID--AUDIO_14_R1_2008_wav--1_mono.mid_chunk81_aug1
    MIDI-UNPROCESSED_06-08_R1_2014_MID--AUDIO_06_R1_2014_wav--2_mono.mid_chunk8_aug2

For size 10k:
  Number of keys in subset: 10000
  Number of keys in selected: 200
  Number of keys in subset but not in selected: 9800
  First 5 keys in difference (or all if less than 5):
    NLB075172_01.mid_chunk6_aug0
    NLB140245_01.mid_chunk4_aug2
    NLB135607_01.mid_chunk0_aug2
    MIDI-UNPROCESSED_19-20_R1_2014_MID--AUDIO_20_R1_2014_wav--2_mono.mid_chunk16_aug2
    MIDI-Unprocessed_049_PIANO049_MID--AUDIO-split_07-06-17_Piano-e_2-06_wav--2_mono.mid_ch

In [44]:
pr_5k = torch.load(r"data\exp2\augmented_embeddings\transformer_pr_1_pretrain_aug_5k_embeddings.pt")

  pr_5k = torch.load(r"data\exp2\augmented_embeddings\transformer_pr_1_pretrain_aug_5k_embeddings.pt")


In [45]:
pr_5k_keys = list(pr_5k.keys())


In [46]:
len(set(selected_keys['5k']))


100

In [47]:
diff = set(pr_5k_keys) - set(selected_keys['5k'])

len(diff)

0

In [37]:
augs = torch.load(r"data\exp2\augmented_embeddings\precomputed_augmentations.pt")
aug_5k_keys = list(augs['10k'].keys())


  augs = torch.load(r"data\exp2\augmented_embeddings\precomputed_augmentations.pt")


In [41]:
len(aug_5k_keys)

diff = set(aug_5k_keys) - set(selected_keys['10k'])

len(diff)



0