## Homework 3: Symbolic Music Generation Using Markov Chains

**Before starting the homework:**

Please run `pip install miditok` to install the [MiDiTok](https://github.com/Natooz/MidiTok) package, which simplifies MIDI file processing by making note and beat extraction more straightforward.

You’re also welcome to experiment with other MIDI processing libraries such as [mido](https://github.com/mido/mido), [pretty_midi](https://github.com/craffel/pretty-midi) and [miditoolkit](https://github.com/YatingMusic/miditoolkit). However, with these libraries, you’ll need to handle MIDI quantization yourself, for example, converting note-on/note-off events into beat positions and durations.

In [1]:
# install some packages
# !pip install miditok
# !pip install symusic
# !pip install midiutil pretty_midi

In [2]:
# import required packages
import random
random.seed(42)
from glob import glob
from collections import defaultdict

import numpy as np
from numpy.random import choice
np.random.seed(42)  # also seed NumPy, since sample_next_note() draws with numpy.random.choice

from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile

### Load music dataset
We use a subset of the [PDMX dataset](https://zenodo.org/records/14984509) for this homework.

Please download the data through XXXXX and unzip it.

All pieces are monophonic (i.e., a single melody line) and in 4/4 time.

In [3]:
midi_files = glob('nesmdb_midi/train/*.mid')
len(midi_files)

4502

### Train a tokenizer with the REMI method in MidiTok

In [4]:
config = TokenizerConfig(num_velocities=1, use_chords=False, use_programs=False)
tokenizer = REMI(config)
tokenizer.train(vocab_size=1000, files_paths=midi_files)

### Use the trained tokenizer to get tokens for each midi file
In the REMI representation, each note is represented by four tokens, `Position, Pitch, Velocity, Duration`, e.g. `('Position_28', 'Pitch_74', 'Velocity_127', 'Duration_0.4.8')`, and a `Bar_None` token marks the beginning of a new bar.

In [5]:
midi = Score(midi_files[0])
tokens = tokenizer(midi)[0].tokens
tokens[:10]

['Bar_None',
 'Position_0',
 'Pitch_65',
 'Velocity_127',
 'Duration_0.6.8',
 'Position_6',
 'Pitch_69',
 'Velocity_127',
 'Duration_0.2.8',
 'Position_9']
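The four-token layout described above can be made explicit with a small pure-Python sketch. `group_remi_notes` is a hypothetical helper (not part of MidiTok) that assumes each note appears as consecutive `Position`, `Pitch`, `Velocity`, `Duration` tokens, which is what the monophonic example above shows. It runs on the printed token list, so no MIDI file is needed.

```python
def group_remi_notes(tokens):
    """Group a flat REMI token list into (bar, position, pitch, duration) tuples."""
    notes = []
    bar = -1  # becomes 0 at the first Bar_None token
    for i, tok in enumerate(tokens):
        if tok == "Bar_None":
            bar += 1
        elif (tok.startswith("Position_") and i + 3 < len(tokens)
              and tokens[i + 1].startswith("Pitch_")
              and tokens[i + 3].startswith("Duration_")):
            position = int(tok.split("_")[1])
            pitch = int(tokens[i + 1].split("_")[1])
            duration = tokens[i + 3].split("_")[1]  # e.g. '0.6.8'
            notes.append((bar, position, pitch, duration))
    return notes


example_tokens = ['Bar_None', 'Position_0', 'Pitch_65', 'Velocity_127', 'Duration_0.6.8',
                  'Position_6', 'Pitch_69', 'Velocity_127', 'Duration_0.2.8', 'Position_9']
print(group_remi_notes(example_tokens))
# [(0, 0, 65, '0.6.8'), (0, 6, 69, '0.2.8')]
```

The trailing `Position_9` is skipped because the rest of that note lies beyond the ten tokens shown.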

1. Write a function to extract note pitch events from a midi file; then extract all note pitch events from the dataset and output a dictionary that maps each note pitch event to the number of times it occurs in the files (e.g. {60: 120, 61: 58, …}).

`note_extraction()`
- **Input**: a midi file

- **Output**: a list of note pitch events

`note_frequency()`
- **Input**: all midi files `midi_files`

- **Output**: a dictionary that maps note pitch events to the number of times they occur, e.g. {60: 120, 61: 58, …}
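A compact way to express Q1 is to filter the `Pitch_` tokens and tally them with `collections.Counter`. The sketch below works on hand-written token lists (hypothetical data, not from the dataset) so it runs without MidiTok; `pitches_from_tokens` and `pitch_counts` are illustrative names, not the required `note_extraction`/`note_frequency` interface.

```python
from collections import Counter


def pitches_from_tokens(tokens):
    """Return the MIDI note numbers of all 'Pitch_<n>' tokens, in order."""
    return [int(t.split("_")[1]) for t in tokens if t.startswith("Pitch_")]


def pitch_counts(token_lists):
    """Count how often each pitch occurs across several token lists."""
    counts = Counter()
    for tokens in token_lists:
        counts.update(pitches_from_tokens(tokens))
    return dict(counts)


# Made-up token lists standing in for tokenizer(Score(path))[0].tokens
toy_token_lists = [
    ["Bar_None", "Position_0", "Pitch_60", "Velocity_127", "Duration_1.0.8",
     "Position_8", "Pitch_62", "Velocity_127", "Duration_1.0.8"],
    ["Bar_None", "Position_0", "Pitch_60", "Velocity_127", "Duration_2.0.8"],
]
print(pitch_counts(toy_token_lists))  # {60: 2, 62: 1}
```

`startswith("Pitch_")` is a slightly stricter filter than `'Pitch' in token`; either works for the token types printed above.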

In [6]:
def note_extraction(midi_file):
    # Q1a: Your code goes here
    note_events = []
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    for token in tokens:
        if 'Pitch' in token:
            note = int(token.split('_')[1])
            note_events.append(note)
    return note_events

In [7]:
def note_frequency(midi_files):
    # Q1b: Your code goes here
    note_counts = defaultdict(int)
    for midi_file in midi_files:
        note_events = note_extraction(midi_file)
        for note in note_events:
            note_counts[note] += 1
    return note_counts

2. Write a function to normalize the above dictionary into probability scores (e.g. {60: 0.13, 61: 0.065, …}).

`note_unigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: a dictionary that maps note pitch events to the probabilities with which they occur in the dataset, e.g. {60: 0.13, 61: 0.06, …}
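Normalization is a division by the total count. A minimal sketch with toy counts (hypothetical numbers); checking that the values sum to 1 is a cheap way to catch mistakes before moving on to the conditional models.

```python
def normalize(counts):
    """Divide each count by the total so the values sum to 1."""
    total = sum(counts.values())
    return {key: value / total for key, value in counts.items()}


probs = normalize({60: 3, 62: 1})
print(probs)  # {60: 0.75, 62: 0.25}
print(sum(probs.values()))  # sanity check: should be 1.0
```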

In [8]:
def note_unigram_probability(midi_files):
    note_counts = note_frequency(midi_files)
    
    # Q2: Your code goes here
    unigramProbabilities = {}
    counts = sum(list(note_counts.values()))
    for n in note_counts:
        unigramProbabilities[n] = note_counts[n] / counts
    return unigramProbabilities

3. Generate a table of pairwise probabilities p(next_note | previous_note) for the dataset; then write a function that randomly samples the next note given the previous note according to this distribution.

`note_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramTransitions`: key - previous_note, value - a list of next_note, e.g. {60:[62, 64, ..], 62:[60, 64, ..], ...}

  - `bigramTransitionProbabilities`: key - previous_note, value - a list of probabilities for next_note in the same order as `bigramTransitions`, e.g. {60:[0.3, 0.4, ..], 62:[0.2, 0.1, ..], ...}

`sample_next_note()`
- **Input**: a note

- **Output**: next note sampled from pairwise probabilities
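The two steps in Q3 (estimate the table, then sample from it) can be sketched on toy melodies. This is an illustration with made-up sequences and hypothetical helper names; it uses the standard library's `random.choices` rather than `numpy.random.choice`, and it builds the table once and reuses it for every draw instead of re-counting per sampled note.

```python
import random
from collections import defaultdict


def bigram_table(sequences):
    """Map each previous note to (next notes, probabilities), in first-seen order."""
    counts = defaultdict(lambda: defaultdict(int))
    for seq in sequences:
        for prev, nxt in zip(seq[:-1], seq[1:]):
            counts[prev][nxt] += 1
    table = {}
    for prev, nexts in counts.items():
        total = sum(nexts.values())
        table[prev] = (list(nexts), [c / total for c in nexts.values()])
    return table


def sample_next(table, note, rng=random):
    """Draw the next note from p(next | note)."""
    candidates, weights = table[note]
    return rng.choices(candidates, weights=weights, k=1)[0]


toy_table = bigram_table([[60, 62, 64, 62, 60], [60, 62, 60]])
print(toy_table[62])  # ([64, 60], [0.3333333333333333, 0.6666666666666666])

rng = random.Random(0)  # seeded for reproducibility
print([sample_next(toy_table, 62, rng) for _ in range(5)])
```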

In [9]:
def note_bigram_probability(midi_files):
    # Q3a: Your code goes here
    bigrams = defaultdict(int)
    
    for file in midi_files:
        note_events = note_extraction(file)
        for (note1, note2) in zip(note_events[:-1], note_events[1:]):
            bigrams[(note1, note2)] += 1
            
    bigramTransitions = defaultdict(list)
    bigramTransitionProbabilities = defaultdict(list)

    for b1,b2 in bigrams:
        bigramTransitions[b1].append(b2)
        bigramTransitionProbabilities[b1].append(bigrams[(b1,b2)])
        
    for k in bigramTransitionProbabilities:
        Z = sum(bigramTransitionProbabilities[k])
        bigramTransitionProbabilities[k] = [x / Z for x in bigramTransitionProbabilities[k]]
        
    return bigramTransitions, bigramTransitionProbabilities

In [10]:
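_bigram_model_cache = {}

def sample_next_note(note):
    # Q3b: Your code goes here
    # Build the bigram model once and reuse it: rebuilding it on every call
    # would re-tokenize the whole dataset for each sampled note.
    if 'model' not in _bigram_model_cache:
        _bigram_model_cache['model'] = note_bigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = _bigram_model_cache['model']
    next_note = choice(bigramTransitions[note], 1, p=bigramTransitionProbabilities[note])[0]
    return next_note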

4. Write a function to calculate the perplexity of your model on a midi file.

    The perplexity of a model is defined as 

    $\quad \text{exp}(-\frac{1}{N} \sum_{i=1}^N \text{log}(p(w_i|w_{i-1})))$

    where $p(w_1|w_0) = p(w_1)$ and, for $i>1$, $p(w_i|w_{i-1})$ is the pairwise probability p(next_note | previous_note).

`note_bigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value
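To see what the formula measures, it helps to evaluate it on hand-picked probabilities. In this sketch, `step_probs` stands for the list $[p(w_1), p(w_2|w_1), \dots, p(w_N|w_{N-1})]$ built from a file; the numbers are illustrative and `perplexity` is a hypothetical helper.

```python
import math


def perplexity(step_probs):
    """exp(-(1/N) * sum(log p_i)) for the N per-step probabilities p_i."""
    n = len(step_probs)
    # summing logs avoids underflow from multiplying many small probabilities
    return math.exp(-sum(math.log(p) for p in step_probs) / n)


# Worked check: if every step has probability 1/4, the model is as uncertain as
# a uniform choice among 4 notes, so the perplexity is (up to rounding) 4.
print(perplexity([0.25, 0.25, 0.25]))
# A model that is always certain (p = 1 at every step) has perplexity 1.
print(perplexity([1.0, 1.0]))
```

One practical caveat: `note_bigram_perplexity()` looks transitions up with `list.index`, which raises `ValueError` for a transition never seen in training, so evaluating a file outside `midi_files` would require some form of smoothing.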

In [11]:
def note_bigram_perplexity(midi_file):
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    
    # Q4: Your code goes here
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[note_events[0]]]
    for (note1, note2) in zip(note_events[:-1], note_events[1:]):
        index = bigramTransitions[note1].index(note2)
        prob = bigramTransitionProbabilities[note1][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

5. Implement a second-order Markov chain, i.e., one which estimates p(next_note | next_previous_note, previous_note); write a function to compute the perplexity of this new model on a midi file. 

    The perplexity of this model is defined as 

    $\quad \text{exp}(-\frac{1}{N} \sum_{i=1}^N \text{log}(p(w_i|w_{i-2}, w_{i-1})))$

    where $p(w_1|w_{-1}, w_0) = p(w_1)$, $p(w_2|w_0, w_1) = p(w_2|w_1)$, and, for $i>2$, $p(w_i|w_{i-2}, w_{i-1})$ is the probability p(next_note | next_previous_note, previous_note).


`note_trigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `trigramTransitions`: key - (next_previous_note, previous_note), value - a list of next_note, e.g. {(60, 62):[64, 66, ..], (60, 64):[60, 64, ..], ...}

  - `trigramTransitionProbabilities`: key - (next_previous_note, previous_note), value - a list of probabilities for next_note in the same order as `trigramTransitions`, e.g. {(60, 62):[0.2, 0.2, ..], (60, 64):[0.4, 0.1, ..], ...}

`note_trigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value
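The boundary cases in the Q5 definition decide which distribution scores each note. The sketch below (a hypothetical helper operating on a made-up note list) lists the context used at each step; it yields one entry per note, so the average in the formula runs over all N notes.

```python
def trigram_contexts(notes):
    """Return (model, context, note) for each position under the Q5 definition."""
    steps = []
    for i, note in enumerate(notes):
        if i == 0:
            steps.append(("unigram", (), note))          # p(w_1)
        elif i == 1:
            steps.append(("bigram", (notes[0],), note))  # p(w_2 | w_1)
        else:
            steps.append(("trigram", (notes[i - 2], notes[i - 1]), note))
    return steps


for step in trigram_contexts([60, 62, 64, 65]):
    print(step)
# ('unigram', (), 60)
# ('bigram', (60,), 62)
# ('trigram', (60, 62), 64)
# ('trigram', (62, 64), 65)
```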

In [12]:
def note_trigram_probability(midi_files):
    # Q5a: Your code goes here
    trigrams = defaultdict(int)
    for file in midi_files:
        note_events = note_extraction(file)
        for (note1, note2, note3) in zip(note_events[:-2], note_events[1:-1], note_events[2:]):
            trigrams[(note1, note2, note3)] += 1
            
    trigramTransitions = defaultdict(list)
    trigramTransitionProbabilities = defaultdict(list)

    for t1,t2,t3 in trigrams:
        trigramTransitions[(t1,t2)].append(t3)
        trigramTransitionProbabilities[(t1,t2)].append(trigrams[(t1,t2,t3)])
        
    for k in trigramTransitionProbabilities:
        Z = sum(trigramTransitionProbabilities[k])
        trigramTransitionProbabilities[k] = [x / Z for x in trigramTransitionProbabilities[k]]
        
    return trigramTransitions, trigramTransitionProbabilities

In [13]:
def note_trigram_perplexity(midi_file):
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    
    # Q5b: Your code goes here
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[note_events[0]]]
    index = bigramTransitions[note_events[0]].index(note_events[1])
    prob = bigramTransitionProbabilities[note_events[0]][index]
    perplexities.append(prob)
    
    for (note1, note2, note3) in zip(note_events[:-2], note_events[1:-1], note_events[2:]):
        index = trigramTransitions[(note1, note2)].index(note3)
        prob = trigramTransitionProbabilities[(note1, note2)][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

6. Our model currently doesn’t have any knowledge of beats. Write a function that extracts beat lengths and outputs a list of [(beat position; beat length)] values.

    Recall that each note is encoded as `Position, Pitch, Velocity, Duration` in REMI. Please keep the `Position` value as the beat position, and convert `Duration` to a beat length using the provided lookup table `duration2length` (see below).

    For example, for a note represented by four tokens `('Position_24', 'Pitch_72', 'Velocity_127', 'Duration_0.4.8')`, the extracted (beat position; beat length) value is `(24, 4)`.

    As a result, we will obtain a list like [(0,8),(8,16),(24,4),(28,4),(0,4)...], where each beat position is the previous beat position plus the previous beat length. Since each bar is divided into 32 positions by default, the beat position resets to 0 when the end of a bar is reached (e.g. 28 + 4 = 32 in the case of (28, 4)).
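The bookkeeping in Q6 can be checked on the example sequence above. This sketch assumes MidiTok duration strings have the form `beats.subbeats.resolution` (as in `'0.4.8'` and `'4.0.4'` from the `duration2length` table that follows) and that a 4/4 bar has 32 positions, i.e. 8 per beat; the helper names are hypothetical.

```python
POSITIONS_PER_BEAT = 8
POSITIONS_PER_BAR = 32


def duration_to_length(duration):
    """Convert a 'beats.subbeats.resolution' string into a number of positions."""
    beats, subbeats, resolution = map(int, duration.split("."))
    return beats * POSITIONS_PER_BEAT + subbeats * POSITIONS_PER_BEAT // resolution


def assign_positions(durations, start=0):
    """Pair each duration with its bar position, wrapping at the bar line."""
    position = start
    out = []
    for d in durations:
        length = duration_to_length(d)
        out.append((position, length))
        position = (position + length) % POSITIONS_PER_BAR  # reset to 0 at the bar line
    return out


print(assign_positions(["1.0.8", "2.0.8", "0.4.8", "0.4.8", "0.4.8"]))
# [(0, 8), (8, 16), (24, 4), (28, 4), (0, 4)]
```

The general formula agrees with every entry of `duration2length` and also handles values outside it, such as the `'0.6.8'` token printed earlier (6 positions).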

In [14]:
duration2length = {
    '0.2.8': 2,  # sixteenth note, 0.25 beat in 4/4 time signature
    '0.4.8': 4,  # eighth note, 0.5 beat in 4/4 time signature
    '1.0.8': 8,  # quarter note, 1 beat in 4/4 time signature
    '2.0.8': 16, # half note, 2 beats in 4/4 time signature
    '4.0.4': 32, # whole note, 4 beats in 4/4 time signature
}

`beat_extraction()`
- **Input**: a midi file

- **Output**: a list of (beat position; beat length) values

In [15]:
def beat_extraction(midi_file):
    # Q6: Your code goes here
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    beats = []
    
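    # stop 3 tokens early so tokens[i+3] never runs past the end of the list
    for i in range(len(tokens) - 3):
        if tokens[i].startswith('Position') and tokens[i+3].startswith('Duration'):
            position = int(tokens[i].split('_')[1])
            # Duration tokens read "beats.subbeats.resolution", e.g. '0.4.8' = 0 beats + 4/8 beat.
            # With 8 positions per beat the length is beats*8 + subbeats*8/resolution, which
            # matches duration2length and also covers values it lacks, e.g. '0.6.8' -> 6.
            num_beats, subbeats, resolution = map(int, tokens[i+3].split('_')[1].split('.'))
            length = num_beats * 8 + subbeats * 8 // resolution
            beats.append((position, length))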
    return beats

7. Implement a Markov chain that computes p(beat_length | previous_beat_length) based on the above function.

`beat_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramBeatTransitions`: key - previous_beat_length, value - a list of beat_length, e.g. {4:[8, 2, ..], 8:[8, 4, ..], ...}

  - `bigramBeatTransitionProbabilities`: key - previous_beat_length, value - a list of probabilities for beat_length in the same order as `bigramBeatTransitions`, e.g. {4:[0.3, 0.2, ..], 8:[0.4, 0.4, ..], ...}
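Q7, Q8 and Q9 all estimate the same kind of table (counts of an outcome given a context, normalized per context); only the context changes. A minimal sketch of that shared estimator, under the hypothetical name `conditional_table`, applied to the Q7 context on a made-up list of (beat position, beat length) pairs:

```python
from collections import Counter, defaultdict


def conditional_table(pairs):
    """Build (transitions, probabilities) dicts from (context, outcome) pairs."""
    counts = defaultdict(Counter)
    for context, outcome in pairs:
        counts[context][outcome] += 1
    transitions, probabilities = {}, {}
    for context, counter in counts.items():
        total = sum(counter.values())
        transitions[context] = list(counter)  # outcomes in first-seen order
        probabilities[context] = [c / total for c in counter.values()]
    return transitions, probabilities


beats = [(0, 8), (8, 16), (24, 4), (28, 4), (0, 8), (8, 8)]
# Q7 context: the previous beat length predicts the current beat length
q7_pairs = [(prev[1], cur[1]) for prev, cur in zip(beats[:-1], beats[1:])]
trans, probs = conditional_table(q7_pairs)
print(trans[8], probs[8])  # [16, 8] [0.5, 0.5]
```

Changing only the pair construction gives the other models: `(cur[0], cur[1])` for p(beat_length | beat_position) in Q8, or `((prev[1], cur[0]), cur[1])` for Q9.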

In [16]:
def beat_bigram_probability(midi_files):
    # Q7: Your code goes here
    bigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for (beat1, beat2) in zip(beats[:-1], beats[1:]):
            bigramBeat[(beat1[1], beat2[1])] += 1
            
    bigramBeatTransitions = defaultdict(list)
    bigramBeatTransitionProbabilities = defaultdict(list)

    for b1,b2 in bigramBeat:
        bigramBeatTransitions[b1].append(b2)
        bigramBeatTransitionProbabilities[b1].append(bigramBeat[(b1,b2)])
        
    for k in bigramBeatTransitionProbabilities:
        Z = sum(bigramBeatTransitionProbabilities[k])
        bigramBeatTransitionProbabilities[k] = [x / Z for x in bigramBeatTransitionProbabilities[k]]
        
    return bigramBeatTransitions, bigramBeatTransitionProbabilities

8. Implement a function to compute p(beat length | beat position), and compute the perplexity of your models from Q7 and Q8. For both models, consider only the probabilities of predicting the sequence of **beat lengths**.

`beat_pos_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramBeatPosTransitions`: key - beat_position, value - a list of beat_length

  - `bigramBeatPosTransitionProbabilities`: key - beat_position, value - a list of probabilities for beat_length in the same order as `bigramBeatPosTransitions`

`beat_bigram_perplexity()`
- **Input**: a midi file

- **Output**: two perplexity values corresponding to the models in Q7 and Q8, respectively
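Because every note carries its own position, the Q8 model needs no special case for the first note, unlike the Q7 model, whose first beat length is scored with a unigram over beat lengths. A minimal sketch on made-up (position, length) lists, with hypothetical helper names:

```python
import math
from collections import Counter, defaultdict


def position_model(beat_lists):
    """Estimate p(beat_length | beat_position) from lists of (position, length)."""
    counts = defaultdict(Counter)
    for beats in beat_lists:
        for position, length in beats:
            counts[position][length] += 1
    model = {}
    for position, counter in counts.items():
        total = sum(counter.values())
        model[position] = {length: c / total for length, c in counter.items()}
    return model


def position_perplexity(model, beats):
    """Perplexity of a (position, length) sequence under the position model."""
    log_probs = [math.log(model[position][length]) for position, length in beats]
    return math.exp(-sum(log_probs) / len(beats))


toy_train = [[(0, 8), (8, 8), (16, 16)], [(0, 16), (16, 16)]]
model = position_model(toy_train)
print(model[0])  # {8: 0.5, 16: 0.5}
print(position_perplexity(model, [(0, 8), (8, 8), (16, 16)]))  # 2 ** (1/3), about 1.26
```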

In [17]:
def beat_pos_bigram_probability(midi_files):
    # Q8a: Your code goes here
    bigramBeatPos = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for beat in beats:
            bigramBeatPos[(beat[0], beat[1])] += 1
            
    bigramBeatPosTransitions = defaultdict(list)
    bigramBeatPosTransitionProbabilities = defaultdict(list)

    for b1,b2 in bigramBeatPos:
        bigramBeatPosTransitions[b1].append(b2)
        bigramBeatPosTransitionProbabilities[b1].append(bigramBeatPos[(b1,b2)])
        
    for k in bigramBeatPosTransitionProbabilities:
        Z = sum(bigramBeatPosTransitionProbabilities[k])
        bigramBeatPosTransitionProbabilities[k] = [x / Z for x in bigramBeatPosTransitionProbabilities[k]]
        
    return bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities

In [18]:
def beat_bigram_perplexity(midi_file):
    bigramBeatTransitions, bigramBeatTransitionProbabilities = beat_bigram_probability(midi_files)
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    # Q8b: Your code goes here
    # Hint: one more probability function needs to be computed
    unigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for beat in beats:
            unigramBeat[beat[1]] += 1
    unigramBeatProbabilities = {}
    counts = sum(list(unigramBeat.values()))
    for n in unigramBeat:
        unigramBeatProbabilities[n] = unigramBeat[n] / counts
        
    beat_events = beat_extraction(midi_file)
    beats = [b[1] for b in beat_events]

    # perplexity for Q7
    perplexities = [unigramBeatProbabilities[beats[0]]]
    for (beat1, beat2) in zip(beats[:-1], beats[1:]):
        index = bigramBeatTransitions[beat1].index(beat2)
        prob = bigramBeatTransitionProbabilities[beat1][index]
        perplexities.append(prob)
    assert len(perplexities) == len(beats)
    perplexity_Q7 = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    
    # perplexity for Q8
    perplexities = []
    for (beat_position, beat_length) in beat_events:
        index = bigramBeatPosTransitions[beat_position].index(beat_length)
        prob = bigramBeatPosTransitionProbabilities[beat_position][index]
        perplexities.append(prob)
    assert len(perplexities) == len(beat_events)
    perplexity_Q8 = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    
    return perplexity_Q7, perplexity_Q8

9. Implement a Markov chain that computes p(beat_length | previous_beat_length, beat_position), and report its perplexity. 

`beat_trigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `trigramBeatTransitions`: key - (previous_beat_length, beat_position), value - a list of beat_length

  - `trigramBeatTransitionProbabilities`: key - (previous_beat_length, beat_position), value - a list of probabilities for beat_length in the same order as `trigramBeatTransitions`

`beat_trigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value
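Q9 conditions on a context that mixes two kinds of information: the previous note's length and the current note's position. The sketch below (hypothetical helper, toy input) lists the context used for each note; the first note has no predecessor, so, as `beat_trigram_perplexity()` does, it is scored with p(beat_length | beat_position) alone.

```python
def q9_steps(beats):
    """Return (model, context, beat_length) for each (position, length) pair."""
    first_position, first_length = beats[0]
    steps = [("position_only", (first_position,), first_length)]
    for prev, cur in zip(beats[:-1], beats[1:]):
        # context = (previous beat length, current beat position)
        steps.append(("trigram", (prev[1], cur[0]), cur[1]))
    return steps


for step in q9_steps([(0, 8), (8, 16), (24, 4)]):
    print(step)
# ('position_only', (0,), 8)
# ('trigram', (8, 8), 16)
# ('trigram', (16, 24), 4)
```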

In [19]:
def beat_trigram_probability(midi_files):
    # Q9a: Your code goes here
    trigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for (beat1, beat2) in zip(beats[:-1], beats[1:]):
            trigramBeat[(beat1[1], beat2[0], beat2[1])] += 1
            
    trigramBeatTransitions = defaultdict(list)
    trigramBeatTransitionProbabilities = defaultdict(list)

    for t1,t2,t3 in trigramBeat:
        trigramBeatTransitions[(t1,t2)].append(t3)
        trigramBeatTransitionProbabilities[(t1,t2)].append(trigramBeat[(t1,t2,t3)])
        
    for k in trigramBeatTransitionProbabilities:
        Z = sum(trigramBeatTransitionProbabilities[k])
        trigramBeatTransitionProbabilities[k] = [x / Z for x in trigramBeatTransitionProbabilities[k]]
        
    return trigramBeatTransitions, trigramBeatTransitionProbabilities

In [20]:
def beat_trigram_perplexity(midi_file):
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    trigramBeatTransitions, trigramBeatTransitionProbabilities = beat_trigram_probability(midi_files)
    # Q9b: Your code goes here
    beats = beat_extraction(midi_file)

    perplexities = []
    index = bigramBeatPosTransitions[beats[0][0]].index(beats[0][1])
    prob = bigramBeatPosTransitionProbabilities[beats[0][0]][index]
    perplexities.append(prob)

    for (beat1, beat2) in zip(beats[:-1], beats[1:]):
        index = trigramBeatTransitions[(beat1[1], beat2[0])].index(beat2[1])
        prob = trigramBeatTransitionProbabilities[(beat1[1], beat2[0])][index]
        perplexities.append(prob)

    assert len(perplexities) == len(beats)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    return perplexity

In [40]:
def note_quadgram_probability(midi_files):
    quadgrams = defaultdict(int)
    for file in midi_files:
        note_events = note_extraction(file)
        for (note1, note2, note3, note4) in zip(note_events[:-3], note_events[1:-2], note_events[2:-1], note_events[3:]):
            quadgrams[(note1, note2, note3, note4)] += 1
            
    quadgramTransitions = defaultdict(list)
    quadgramTransitionProbabilities = defaultdict(list)

    for t1,t2,t3,t4 in quadgrams:
        quadgramTransitions[(t1,t2,t3)].append(t4)
        quadgramTransitionProbabilities[(t1,t2,t3)].append(quadgrams[(t1,t2,t3,t4)])
        
    for k in quadgramTransitionProbabilities:
        Z = sum(quadgramTransitionProbabilities[k])
        quadgramTransitionProbabilities[k] = [x / Z for x in quadgramTransitionProbabilities[k]]
        
    return quadgramTransitions, quadgramTransitionProbabilities

10. Use the model from Q5 to generate 500 notes, and the model from Q8 to generate a beat length for each note. Save the generated music as a MIDI file named q10.mid (see the code from workbook1). Remember to reset the beat position to 0 when reaching the end of a bar.

`music_generate`
- **Input**: target length, e.g. 500

- **Output**: a midi file q10.mid

Note: one beat has a duration of 1 in MIDIUtil but spans 8 positions in MidiTok. Divide the beat length by 8 if you use MIDIUtil methods to save MIDI files.
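The position bookkeeping and the unit conversion can be sketched independently of MIDIUtil. In this illustration the distributions are toy dictionaries (hypothetical, standing in for the Q8/Q9 tables), and falling back to p(length | position) when a (previous length, position) context was never observed is an assumption added here to avoid sampling from an empty list; it is not part of the assignment.

```python
import random


def sample_lengths(n, by_prev_and_pos, by_pos, rng):
    """Sample n beat lengths, tracking the bar position (32 positions per bar)."""
    position, prev_length, lengths = 0, None, []
    for _ in range(n):
        context = (prev_length, position)
        if context in by_prev_and_pos:   # Q9-style context seen in training
            options = by_prev_and_pos[context]
        else:                            # back off to p(length | position)
            options = by_pos[position]
        length = rng.choices(list(options), weights=list(options.values()), k=1)[0]
        lengths.append(length)
        prev_length = length
        position = (position + length) % 32  # reset to 0 at each bar line
    return lengths


# Made-up distributions; in the notebook they would come from the Q8/Q9 models.
by_pos = {0: {8: 0.5, 16: 0.5}, 8: {8: 1.0}, 16: {8: 0.5, 16: 0.5}, 24: {8: 1.0}}
by_prev_and_pos = {(8, 8): {8: 1.0}}

lengths = sample_lengths(16, by_prev_and_pos, by_pos, random.Random(42))

# Convert MidiTok positions to MIDIUtil beats: 8 positions = 1 beat.
onsets, current_beat = [], 0.0
for length in lengths:
    onsets.append(current_beat)
    current_beat += length / 8
print(list(zip(onsets, [length / 8 for length in lengths]))[:4])
```

The resulting `(onset, duration)` pairs are in beats, which is the unit `MIDIFile.addNote` takes for its time and duration arguments.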

In [42]:
def music_generate(length, unigramProbabilities, bigramTransitions, bigramTransitionProbabilities, trigramTransitions, trigramTransitionProbabilities, quadgramTransitions, quadgramTransitionProbabilities, bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities, trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities):
    # Your code goes here ...
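    def sample(transitions, probabilities, context):
        # Draw one outcome for `context`; return None if the context was never
        # followed by anything in training (e.g. it only occurred at the end of pieces).
        if not transitions.get(context):
            return None
        return random.choices(transitions[context], weights=probabilities[context], k=1)[0]

    # sample notes: unigram for the first note, then the longest context that was
    # seen in training (quadgram -> trigram -> bigram -> unigram back-off)
    sampled_notes = []
    for i in range(length):
        next_note = None
        if len(sampled_notes) >= 3:
            next_note = sample(quadgramTransitions, quadgramTransitionProbabilities, tuple(sampled_notes[-3:]))
        if next_note is None and len(sampled_notes) >= 2:
            next_note = sample(trigramTransitions, trigramTransitionProbabilities, tuple(sampled_notes[-2:]))
        if next_note is None and len(sampled_notes) >= 1:
            next_note = sample(bigramTransitions, bigramTransitionProbabilities, sampled_notes[-1])
        if next_note is None:
            next_note = random.choices(list(unigramProbabilities.keys()), weights=list(unigramProbabilities.values()), k=1)[0]
        sampled_notes.append(next_note)
    
    # sample beats: p(length | position) for the first note, then
    # p(length | previous length, position), backing off when a context is unseen
    sampled_beats = []
    beat_position = 0
    for i in range(length):
        beat_length = None
        if sampled_beats:
            beat_length = sample(trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities, (sampled_beats[-1], beat_position))
        if beat_length is None:
            beat_length = sample(bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities, beat_position)
        if beat_length is None:
            # assumption: at a position never seen in training, hold the note until the bar line
            beat_length = 32 - beat_position
        sampled_beats.append(beat_length)
        beat_position = (beat_position + beat_length) % 32  # reset to 0 at the end of each bar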
    
    # save the generated music as a midi file
    midi = MIDIFile(1) # Create a MIDI file that consists of 1 track
    track = 0 # Set track number
    time = 0 # Where is the event placed (at the beginning)
    tempo = 120 # The tempo (beats per minute)
    midi.addTempo(track, time, tempo) # Add tempo information
    
    current_time = 0
    for i in range(length):
        pitch = sampled_notes[i]
        duration = sampled_beats[i] / 8
        midi.addNote(track, 0, pitch, current_time, duration, 100)
        current_time += duration

    with open("q10.mid", "wb") as f:
        midi.writeFile(f) # write MIDI file

In [41]:
# build the note models used by music_generate()
unigramProbabilities = note_unigram_probability(midi_files)
bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
quadgramTransitions, quadgramTransitionProbabilities = note_quadgram_probability(midi_files)
# build the beat models used by music_generate()
bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities = beat_trigram_probability(midi_files)

In [43]:
music_generate(200, unigramProbabilities, bigramTransitions, bigramTransitionProbabilities, trigramTransitions, trigramTransitionProbabilities, quadgramTransitions, quadgramTransitionProbabilities, bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities, trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)

In [44]:
import pretty_midi
import IPython.display as ipd

# Load MIDI file
midi_data = pretty_midi.PrettyMIDI('q10.mid')

# Convert to audio
audio_data = midi_data.synthesize()

# Play (works well in Jupyter notebooks)
ipd.Audio(audio_data, rate=44100)