In [None]:
import pretty_midi
import numpy as np
import matplotlib.pyplot as plt


# Names of the event types that make up the encoding vocabulary.
EVENTS = [
    'note_on',
    'note_off',
    'velocity',
    'time_shift',
]
class Event:
    """A single timed event in the encoded MIDI stream.

    The constructor stores its arguments as attributes:

    * ``type``      -- the kind of event (a string such as ``'note_on'``).
    * ``value``     -- the payload of the event.
    * ``time``      -- the timestamp the event is attached to.
    * ``vel_pitch`` -- secondary sort key; when omitted (``None``) it
      defaults to ``value``.

    Events compare with ``<`` by ``time`` first and ``vel_pitch`` second,
    which is what ``list.sort()`` / ``sorted()`` rely on.
    """

    def __init__(self, type, value, time, vel_pitch = None):
        self.type = type
        self.value = value
        self.time = time
        # Without an explicit tie-break key the event's own value is used.
        self.vel_pitch = value if vel_pitch is None else vel_pitch

    def __lt__(self, other):
        """Return whether this event sorts before ``other``."""
        if self.time != other.time:
            return self.time < other.time
        # Same timestamp: fall back to the secondary key.
        return self.vel_pitch < other.vel_pitch

    def __repr__(self) -> str:
        """Return ``Event(type, value, time)``; ``vel_pitch`` is not shown."""
        return f"Event({self.type}, {self.value}, {self.time})"

    def __str__(self) -> str:
        """Return the same text as ``__repr__``."""
        return f"Event({self.type}, {self.value}, {self.time})"
    
def get_pitch_histogram(file_path, timestamp, duration=2.0):
    """Plot and return a histogram of the pitches starting in a time window.

    Notes from every instrument in the MIDI file are considered; a note is
    counted when ``timestamp <= note.start < timestamp + duration``.

    Args:
        file_path: Path of the MIDI file, passed to ``pretty_midi.PrettyMIDI``.
        timestamp: Start of the window (labelled as seconds in the plot title).
        duration: Length of the window; defaults to 2.0.

    Returns:
        A NumPy array of 88 counts, one per pitch from 21 to 108 inclusive.
        Pitches outside that range fall outside the bin edges and are not
        counted.

    Side effects:
        Draws a bar chart with matplotlib and calls ``plt.show()``.
    """
    # Load the MIDI file
    midi_data = pretty_midi.PrettyMIDI(file_path)
    window_end = timestamp + duration

    # Collect the pitch of every note that starts inside the window
    pitches = []
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            if timestamp <= note.start < window_end:
                pitches.append(note.pitch)

    # np.histogram closes the last bin on both sides, so the edges must run
    # one past the top pitch (21..109) for 107 and 108 to land in separate
    # bins; edges ending at 108 would merge those two pitches into one count.
    pitch_histogram, bins = np.histogram(pitches, bins=np.arange(21, 110))

    # Plot the histogram; bins[:-1] are the left edges, i.e. the pitches 21..108
    plt.bar(bins[:-1], pitch_histogram, width=1.0, edgecolor='black')
    plt.xlabel('Pitch')
    plt.ylabel('Count')
    plt.title(f'Pitch Histogram from {timestamp} to {timestamp + duration} seconds')
    plt.show()

    return pitch_histogram

        
    
    
# Function to encode MIDI file
def encode_midi(file_path):
    """Encode a MIDI file as a time-ordered list of ``Event`` objects.

    For every note whose pitch lies in 21..108 (inclusive) three events are
    created: a ``velocity`` event (velocity clamped to 0..127, then integer
    divided by 4) and a ``note_on`` event at the note's start, and a
    ``note_off`` event at the note's end. The events of all instruments are
    merged and sorted with ``Event.__lt__``; a ``time_shift`` event holding
    the elapsed time is then inserted between consecutive events whose
    times differ.

    Args:
        file_path: Path of the MIDI file, passed to ``pretty_midi.PrettyMIDI``.

    Returns:
        The list of events in order. No ``time_shift`` is emitted before the
        first event. The list is empty when no note is in the accepted
        pitch range.

    Side effects:
        Prints each instrument's program number and, when the instrument has
        notes, the start time of its first note.
    """
    # Load the MIDI file
    midi_data = pretty_midi.PrettyMIDI(file_path)

    # First pass: only note_on, note_off and velocity events
    note_events = []
    for instrument in midi_data.instruments:
        print(instrument.program)
        # An instrument's notes list can be empty; indexing notes[0]
        # unconditionally would raise IndexError and abort the encoding.
        if instrument.notes:
            print(instrument.notes[0].start)
        for note in instrument.notes:
            if 21 <= note.pitch <= 108:
                velocity = max(0, min(127, int(note.velocity)))
                # The velocity event carries the note's pitch as its
                # secondary sort key, the same key its note_on gets by
                # default, so the stable sort keeps the pair adjacent.
                note_events.append(Event('velocity', velocity//4, note.start, note.pitch))
                note_events.append(Event('note_on', note.pitch, note.start))
                note_events.append(Event('note_off', note.pitch, note.end))

    # Sort the encoded events by time (ties broken by vel_pitch)
    note_events.sort()

    # Second pass: insert a time_shift wherever the clock advances between
    # two consecutive events.
    all_events = []
    previous = None
    for event in note_events:
        if previous is not None:
            time_offset = event.time - previous.time
            if time_offset != 0:
                all_events.append(Event('time_shift', time_offset, event.time))
        all_events.append(event)
        previous = event
    return all_events
# Example usage: encode one MIDI file, then plot and print the pitch
# histogram for a window of that same file.

file_path = 'alb_se2.mid'
timestamp = 0.0
duration = 2.0

encoded_events = encode_midi(file_path)
print(encoded_events)

pitch_histogram = get_pitch_histogram(file_path, timestamp, duration=duration)
print(pitch_histogram)
