# MUSIC GAN

## INITIALIZATION

### Imports

In [17]:
import os
import pickle as pkl
from fractions import Fraction
from typing import List, Tuple, Optional

import keras
import music21
import tensorflow as tf

### Functions

In [18]:
def _element_to_token(element) -> Optional[Tuple[str, str]]:
    """
    Maps one music21 stream element to a (note token, duration token) pair.

    Args:
        element: An element yielded while iterating a flattened music21 score.

    Returns:
        Optional[Tuple[str, str]]: The token pair, or None when the element is not
        a key, time signature, chord, rest or note.
    """
    # Keys and time signatures carry no duration of their own
    if isinstance(element, music21.key.Key):
        return f"{element.tonic.name}:{element.mode}", '0.0'
    if isinstance(element, music21.meter.TimeSignature):
        return f"{element.ratioString}TS", '0.0'
    if isinstance(element, music21.chord.Chord):
        # Only the last pitch of the chord is kept
        # NOTE(review): presumably the highest pitch after chordify() - confirm
        return element.pitches[-1].nameWithOctave, str(element.duration.quarterLength)
    if isinstance(element, music21.note.Rest):
        return str(element.name), str(element.duration.quarterLength)
    if isinstance(element, music21.note.Note):
        return element.nameWithOctave, str(element.duration.quarterLength)
    return None


def parse_midi_files(
    file_list: List[str],
    parser: music21.converter.Converter,
    seq_len: int,
    parsed_data_path: Optional[str] = None
) -> Tuple[List[str], List[str]]:
    """
    Parses a list of MIDI files and extracts note sequences and their corresponding durations.

    The tokens of all files are concatenated into one stream (each file is preceded by a
    'START' token with duration '0.0'), and every contiguous window of `seq_len` tokens
    is emitted as one space-joined string.

    Args:
        file_list (List[str]): A list of file paths to MIDI files to be parsed.
        parser (music21.converter.Converter): A music21 parser object to read and process the MIDI files.
        seq_len (int): The length of the note and duration sequences to be generated.
        parsed_data_path (Optional[str]): Directory in which 'notes.pkl' and 'durations.pkl'
            are written (if provided). The directory is created when it does not exist.

    Returns:
        Tuple[List[str], List[str]]: A tuple containing two lists:
            - notes_list: A list of note sequences of length `seq_len`.
            - duration_list: A list of corresponding duration sequences of length `seq_len`.
          Both lists are empty when fewer than `seq_len` tokens were parsed.
    """
    notes = []  # Note tokens of all files, in parsing order
    durations = []  # Duration tokens aligned one-to-one with `notes`

    # Loop through each MIDI file in the file list
    for i, file in enumerate(file_list):
        print(i + 1, f'Parsing {file}')
        score = parser.parse(file).chordify()  # Convert the score to chords

        # Add a start token to the sequence
        notes.append('START')
        durations.append('0.0')

        # Newer music21 releases deprecate `.flat` in favour of `.flatten()`;
        # use whichever this installation offers.
        flat_score = score.flatten() if hasattr(score, 'flatten') else score.flat

        for element in flat_score:
            token = _element_to_token(element)
            if token is None:
                continue
            note_name, duration_name = token
            # Append note and duration only if both are non-empty
            if note_name and duration_name:
                notes.append(note_name)
                durations.append(duration_name)

        print(f'{len(notes)} notes parsed')  # Log the running total of parsed tokens

    # Generate sequences of notes and durations of length `seq_len`.
    # A stream of n tokens holds n - seq_len + 1 complete windows; the "+ 1"
    # keeps the final window, which would otherwise be dropped.
    print(f'Building sequences of length {seq_len}')
    window_count = len(notes) - seq_len + 1
    notes_list = [' '.join(notes[i: i + seq_len]) for i in range(window_count)]
    duration_list = [' '.join(durations[i: i + seq_len]) for i in range(window_count)]

    # Save the parsed sequences to files if `parsed_data_path` is provided
    if parsed_data_path:
        os.makedirs(parsed_data_path, exist_ok=True)
        with open(os.path.join(parsed_data_path, 'notes.pkl'), 'wb') as f:
            pkl.dump(notes_list, f)
        with open(os.path.join(parsed_data_path, 'durations.pkl'), 'wb') as f:
            pkl.dump(duration_list, f)

    return notes_list, duration_list

def _resolve_parsed_file(parsed_data_path: str, stem: str) -> str:
    """
    Returns the path of the pickle file for `stem` inside `parsed_data_path`.

    '<stem>.pkl' is preferred; the extension-less '<stem>' is used only when it
    exists and the '.pkl' file does not.
    """
    preferred = os.path.join(parsed_data_path, f'{stem}.pkl')
    legacy = os.path.join(parsed_data_path, stem)
    if os.path.exists(preferred) or not os.path.exists(legacy):
        # When neither file exists, returning the preferred name makes the
        # resulting FileNotFoundError point at the expected '.pkl' file.
        return preferred
    return legacy


def load_parsed_files(parsed_data_path: str) -> Tuple[List[str], List[str]]:
    """
    Loads the parsed note and duration sequences from pickle files.

    The files are looked up as 'notes.pkl' / 'durations.pkl' (the names written by
    `parse_midi_files`), falling back to the extension-less 'notes' / 'durations'.

    Args:
        parsed_data_path (str): The directory path where the parsed files are stored.

    Returns:
        Tuple[List[str], List[str]]: A tuple containing:
            - notes (List[str]): A list of note sequences.
            - durations (List[str]): A list of corresponding duration sequences.

    Raises:
        FileNotFoundError: If neither naming variant of a file exists.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
    with open(_resolve_parsed_file(parsed_data_path, 'notes'), 'rb') as f:
        notes = pkl.load(f)

    with open(_resolve_parsed_file(parsed_data_path, 'durations'), 'rb') as f:
        durations = pkl.load(f)

    return notes, durations

def get_midi_notes(sample_note: str, sample_duration: str) -> Optional[music21.base.Music21Object]:
    """
    Converts a sample note and its duration into a corresponding Music21 object.

    Args:
        sample_note (str): The name of the note or musical element (e.g., 'C4', 'rest',
            'C4.E4.G4' for a chord, '4/4TS' for a time signature, 'C:major' for a key,
            or the 'START' marker).
        sample_duration (str): The duration in quarter lengths, as a string accepted by
            `fractions.Fraction` (e.g., '0.5', '1', '1/3'). Ignored for time signatures,
            keys and 'START'.

    Returns:
        Optional[music21.base.Music21Object]: A TimeSignature, Key, Rest, Chord or Note
        matching `sample_note`, or None when `sample_note` is 'START'.
    """
    # Handle Time Signature (e.g., '4/4TS')
    if 'TS' in sample_note:
        return music21.meter.TimeSignature(sample_note.split('TS')[0])

    # Handle Key Signature (e.g., 'C:major', 'A:minor')
    if 'major' in sample_note or 'minor' in sample_note:
        tonic, mode = sample_note.split(':')
        return music21.key.Key(tonic, mode)

    # The start marker has no musical counterpart
    if sample_note == 'START':
        return None

    # Going through Fraction lets both decimal ('0.5') and ratio ('1/3') strings parse
    quarter_length = float(Fraction(sample_duration))

    def with_duration(element):
        # Each element gets its own Duration and instrument object
        element.duration = music21.duration.Duration(quarter_length)
        element.storedInstrument = music21.instrument.Violoncello()
        return element

    # Handle Rest
    if sample_note == 'rest':
        return with_duration(music21.note.Rest())

    # Handle Chord (e.g., 'C4.E4.G4')
    if '.' in sample_note:
        chord_notes = [
            with_duration(music21.note.Note(pitch_name))
            for pitch_name in sample_note.split('.')
        ]
        return music21.chord.Chord(chord_notes)

    # Handle Single Note (e.g., 'C4')
    return with_duration(music21.note.Note(sample_note))

### Classes

In [None]:
class SinePositionEncoding(keras.layers.Layer):
    """
    Layer that produces sinusoidal position encodings.

    The encoding mixes sine and cosine curves whose wavelengths grow geometrically
    across the feature axis, following 'Attention Is All You Need'
    (Vaswani et al., 2017). The layer only reads the shape of its input; the
    returned tensor holds the encodings themselves, not input + encodings.

    Args:
        max_wavelength (int): The maximum angular wavelength of the sine and cosine
                              functions. Default is 10,000.

    Input:
        inputs (tf.Tensor): A 3D tensor of shape [batch_size, sequence_length, hidden_size],
                            where hidden_size is the feature dimension.

    Output:
        tf.Tensor: A positional encoding tensor with the same shape as the input.

    Example:
        ```python
        seq_len = 100
        vocab_size = 1000
        embedding_dim = 32

        inputs = keras.Input((seq_len,), dtype=tf.float32)
        embedding = keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)(inputs)
        positional_encoding = SinePositionEncoding()(embedding)
        outputs = embedding + positional_encoding
        ```
    """

    def __init__(self, max_wavelength: int = 10000, **kwargs):
        """
        Store the wavelength bound and forward everything else to the base layer.

        Args:
            max_wavelength (int): The maximum angular wavelength of the sine and cosine
                                  curves used for positional encoding.
            **kwargs: Additional keyword arguments passed to the base `keras.layers.Layer` class.
        """
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Build the sinusoidal encodings matching the shape of `inputs`.

        Args:
            inputs (tf.Tensor): A 3D tensor of shape [batch_size, sequence_length, hidden_size].

        Returns:
            tf.Tensor: A tensor of the same shape as `inputs` containing the
                       sinusoidal positional encodings.
        """
        dtype = self.compute_dtype

        # Dynamic shape: the last two axes are (sequence, feature)
        full_shape = tf.shape(inputs)
        n_positions = full_shape[-2]
        n_channels = full_shape[-1]

        # Integer index of every feature channel: 0, 1, ..., n_channels - 1
        channel_ids = tf.range(n_channels)

        # Positions 0 .. n_positions - 1 in the layer's compute dtype
        positions = tf.cast(tf.range(n_positions), dtype)

        # Per-channel scale: (1 / max_wavelength) ** (2 * (i // 2) / n_channels),
        # so channels 2k and 2k + 1 share the same scale. Shape: [n_channels]
        base = tf.cast(1 / self.max_wavelength, dtype=dtype)
        exponents = tf.cast(2 * (channel_ids // 2), dtype) / tf.cast(n_channels, dtype)
        scales = tf.pow(base, exponents)

        # Outer product position x scale. Shape: [n_positions, n_channels]
        phase = positions[:, tf.newaxis] * scales[tf.newaxis, :]

        # Odd channels take the cosine, even channels the sine
        use_cos = tf.cast(channel_ids % 2, dtype)  # [0, 1, 0, 1, ...]
        use_sin = 1 - use_cos  # [1, 0, 1, 0, ...]
        encoding = tf.sin(phase) * use_sin + tf.cos(phase) * use_cos

        # Repeat the [n_positions, n_channels] table over the leading axes
        return tf.broadcast_to(encoding, full_shape)

    def get_config(self) -> dict:
        """
        Return the layer configuration used for serialization.

        Returns:
            dict: The base layer configuration extended with `max_wavelength`.
        """
        return {**super().get_config(), "max_wavelength": self.max_wavelength}
        

### Dataset

In [None]:
# Destination folder for the MIDI files; created on first run
output_dir = "data/bach-cello"
os.makedirs(output_dir, exist_ok=True)

# List of URLs to download.
# Files are named cs<suite>-<slot><tag>.mid for suites 1-6 with six files each;
# only the tag of the fifth slot differs from suite to suite.
urls = [
    f"http://www.jsbach.net/midi/cs{suite}-{movement}.mid"
    for suite, fifth in enumerate(
        ("5men", "5men", "5bou", "5bou", "5gav", "5gav"), start=1
    )
    for movement in ("1pre", "2all", "3cou", "4sar", fifth, "6gig")
]

# Function to download a file
def download_file(url: str, output_dir: str, timeout: float = 30.0) -> None:
    """
    Downloads `url` into `output_dir`, keeping the last path component as file name.

    Progress and the outcome are reported with print(); a failed download is
    reported but never raises, so a batch of downloads keeps going.

    Args:
        url (str): Address of the file to fetch.
        output_dir (str): Existing directory the file is written to.
        timeout (float): Seconds to wait for the server before giving up.
    """
    # Imported here because the notebook's import cell does not bring
    # `requests` into scope.
    import requests

    file_name = os.path.join(output_dir, os.path.basename(url))
    print(f"Downloading {file_name}...")

    try:
        # Without a timeout, requests may wait on an unresponsive server forever
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Connection-level failures (DNS, refused, timeout) are reported like HTTP errors
        print(f"❌ Failed to download: {file_name} ({err})")
        return

    if response.status_code == 200:
        with open(file_name, 'wb') as f:
            f.write(response.content)
        print(f"✅ Downloaded: {file_name}")
    else:
        print(f"❌ Failed to download: {file_name} (Status: {response.status_code})")

# Fetch the listed files one after another into the dataset directory
for url in urls:
    download_file(url=url, output_dir=output_dir)