## Homework 3: Symbolic Music Generation Using Markov Chains

**Before starting the homework:**

Please run `pip install miditok` to install the [MidiTok](https://github.com/Natooz/MidiTok) package, which simplifies MIDI file processing by making note and beat extraction more straightforward.

You’re also welcome to experiment with other MIDI processing libraries such as [mido](https://github.com/mido/mido), [pretty_midi](https://github.com/craffel/pretty-midi) and [miditoolkit](https://github.com/YatingMusic/miditoolkit). However, with these libraries, you’ll need to handle MIDI quantization yourself, for example, converting note-on/note-off events into beat positions and durations.
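
If you go that route, the quantization step might look like the sketch below. It is a minimal illustration with these assumptions: 4/4 time, and the 8-positions-per-beat (32-per-bar) grid used later in this notebook. `quantize_note` is a hypothetical helper, not an API of mido, pretty_midi or miditoolkit. You would still need to read the onset/offset ticks and `ticks_per_beat` from the file with your chosen library.

```python
def quantize_note(onset_tick, offset_tick, ticks_per_beat, positions_per_beat=8, positions_per_bar=32):
    """Snap raw MIDI tick times onto a REMI-style grid.

    Hypothetical helper: assumes 4/4 time, so one bar = 4 beats = 32 positions.
    Returns (bar index, position within the bar, length in grid positions).
    """
    ticks_per_position = ticks_per_beat / positions_per_beat
    onset = round(onset_tick / ticks_per_position)
    offset = round(offset_tick / ticks_per_position)
    bar, position = divmod(onset, positions_per_bar)
    # Very short notes would round to zero length; keep at least one grid step
    length = max(1, offset - onset)
    return bar, position, length


# 480 ticks per beat: a half-beat note starting half a beat into bar 1 (0-indexed)
print(quantize_note(2160, 2400, ticks_per_beat=480))  # (1, 4, 4)
```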

In [40]:
# install some packages
# !pip install miditok
# !pip install symusic
# !pip install MIDIUtil pretty_midi

In [1]:
# import required packages
import random
random.seed(42)
from glob import glob
from collections import defaultdict

import numpy as np
from numpy.random import choice

from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile



### Load music dataset
This homework is based on a subset of the [PDMX dataset](https://zenodo.org/records/14984509); the cells below load NES game music from `nesmdb_midi/train/`.

Please download the data through XXXXX and unzip.

All pieces are monophonic (i.e., a single melody line) and in 4/4 time.

In [2]:
midi_files = glob('nesmdb_midi/train/*.mid')
len(midi_files)

4502

In [3]:
import re
from typing import List, Set

def extract_nes_game_names(midi_files: List[str]) -> List[str]:
    """
    Extract unique NES game names from a list of MIDI file paths.
    
    Args:
        midi_files: List of MIDI file paths
        
    Returns:
        List of unique game names sorted alphabetically
    """
    game_names = set()
    
    for file_path in midi_files:
        # Pattern to match: number_GameName_trackNumbers_trackName.mid
        # Examples: 
        # - 014_Argus_05_06LandingSuccess.mid
        # - 378_WaiWaiWorld2_SOS__ParsleyJou_27_28InvitationtotheWorldofDemons1.mid
        
        # Find "number_" right after a path separator, then "_number_number" near the end;
        # [\\/] accepts both Windows ("\") and POSIX ("/") separators
        match = re.search(r'[\\/](\d+)_(.+?)_(\d+)_(\d+)', file_path)
        
        if match:
            game_name = match.group(2)  # Everything between first number_ and _number_number
            game_names.add(game_name)
    
    return sorted(list(game_names))

game_names = extract_nes_game_names(midi_files)
    
print("Extracted NES Game Names:")
print("=" * 30)
for i, game in enumerate(game_names, 1):
    print(f"{i:2d}. {game}")

print(f"\nTotal games found: {len(game_names)}")

Extracted NES Game Names:
 1. 10_YardFight
 2. 1942
 3. 720_
 4. 98in1
 5. Abadox_TheDeadlyInnerWar
 6. Adam_amp_Eve
 7. AfterBurner
 8. AfterBurnerII
 9. AighinanoYogen_BalubalouknoDensetsuYori
10. AlienSyndrome
11. Aliens_Alien2
12. Argus
13. ArmWrestling
14. ArumananoKiseki
15. Athena
16. AtlantisnoNazo
17. BabelnoTou
18. BalloonFight
19. Baseball
20. BatmanReturns
21. Batman_ReturnofTheJoker
22. Batman_TheVideoGame
23. BattleCity
24. BinaryLand
25. BioMiracleBokutteUpa
26. BioSenshiDan_IncreasertonoTatakai
27. Blackjack
28. BlasterMaster
29. Bomberman
30. BombermanII
31. BuraiFighter
32. CaptainTsubasaVol_II_SuperStriker
33. Castelian
34. CastleofDragon
35. Castlevania
36. CastlevaniaIII_Dracula_sCurse
37. CastlevaniaII_Simon_sQuest
38. Chack_nPop
39. Challenger
40. ChaosWorld
41. ChesterField_EpisodeIIAnkokuShinenoChousen
42. ChoujinSentaiJetman
43. CircusCaper
44. CircusCharlie
45. CluCluLand
46. Contra
47. ContraForce
48. CrisisForce
49. DarkLord
50. DeadlyTowers
51. Deathbots
5

In [4]:
def extract_nes_game_name(midi_file_path: str) -> str:
    """
    Extract NES game name from a single MIDI file path.
    
    Args:
        midi_file_path: Path to a single MIDI file
        
    Returns:
        Game name as a string, or empty string if no match found
    """
    # Pattern to match: number_GameName_trackNumbers_trackName.mid
    # Examples: 
    # - 014_Argus_05_06LandingSuccess.mid
    # - 378_WaiWaiWorld2_SOS__ParsleyJou_27_28InvitationtotheWorldofDemons1.mid
    
    # [\\/] accepts both Windows ("\") and POSIX ("/") path separators
    match = re.search(r'[\\/](\d+)_(.+?)_(\d+)_(\d+)', midi_file_path)
    
    if match:
        game_name = match.group(2)
        return game_name
    
    return ""

print(extract_nes_game_name(midi_files[4150]))

WaiWaiWorld2_SOS__ParsleyJou


### Train a tokenizer with the REMI method in MidiTok

In [5]:
config = TokenizerConfig(num_velocities=1, use_chords=False, use_programs=False)
tokenizer = REMI(config)
tokenizer.train(vocab_size=1000, files_paths=midi_files)

### Use the trained tokenizer to get tokens for each midi file
In the REMI representation, each note is represented by four tokens: `Position, Pitch, Velocity, Duration`, e.g. `('Position_28', 'Pitch_74', 'Velocity_127', 'Duration_0.4.8')`. The `Bar_None` token marks the beginning of a new bar.
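
The four tokens of a note appear consecutively, so notes can be recovered by scanning for that pattern. The sketch below works on plain token strings like the ones printed below; `group_remi_notes` is a hypothetical helper. In REMI, notes that start at the same position typically share a single `Position` token. This simple scan therefore only captures the first of them, which is adequate for monophonic melodies.

```python
def group_remi_notes(tokens):
    """Collect consecutive (Position, Pitch, Velocity, Duration) runs as value tuples."""
    pattern = ['Position', 'Pitch', 'Velocity', 'Duration']
    notes = []
    for i in range(len(tokens) - 3):
        window = tokens[i:i + 4]
        # Token type is the part before the first underscore, e.g. 'Pitch' in 'Pitch_65'
        if [t.split('_', 1)[0] for t in window] == pattern:
            notes.append(tuple(t.split('_', 1)[1] for t in window))
    return notes


tokens = ['Bar_None', 'Position_0', 'Pitch_65', 'Velocity_127', 'Duration_0.6.8',
          'Position_6', 'Pitch_69', 'Velocity_127', 'Duration_0.2.8', 'Position_9']
print(group_remi_notes(tokens))
# [('0', '65', '127', '0.6.8'), ('6', '69', '127', '0.2.8')]
```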

In [6]:
midi = Score(midi_files[0])
tokens = tokenizer(midi)[0].tokens
tokens[:10]

['Bar_None',
 'Position_0',
 'Pitch_65',
 'Velocity_127',
 'Duration_0.6.8',
 'Position_6',
 'Pitch_69',
 'Velocity_127',
 'Duration_0.2.8',
 'Position_9']

1. Write a function to extract note pitch events from a midi file; extract all note pitch events from the dataset and output a dictionary that maps note pitch events to the number of times they occur in the files. (e.g. {60: 120, 61: 58, …}).

`note_extraction()`
- **Input**: a midi file

- **Output**: a list of note pitch events

`note_frequency()`
- **Input**: all midi files `midi_files`

- **Output**: a dictionary that maps note pitch events to the number of times they occur, e.g. {60: 120, 61: 58, …}
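
`collections.Counter` offers a compact way to do the counting. The sketch below works on lists of REMI token strings rather than MIDI files, so it runs on its own; `count_pitches` is a hypothetical name. It matches the `Pitch_` prefix rather than testing whether `'Pitch'` occurs anywhere in the token, which keeps other token types from being counted by accident.

```python
from collections import Counter


def count_pitches(token_lists):
    """Map each MIDI pitch to how often a Pitch_* token for it occurs."""
    counts = Counter()
    for tokens in token_lists:
        counts.update(int(t.split('_')[1]) for t in tokens if t.startswith('Pitch_'))
    return dict(counts)


songs = [
    ['Bar_None', 'Position_0', 'Pitch_60', 'Velocity_127', 'Duration_1.0.8',
     'Position_8', 'Pitch_62', 'Velocity_127', 'Duration_1.0.8'],
    ['Bar_None', 'Position_0', 'Pitch_60', 'Velocity_127', 'Duration_2.0.8'],
]
print(count_pitches(songs))  # {60: 2, 62: 1}
```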

In [7]:
def note_extraction(midi_file):
    # Q1a: Your code goes here
    note_events = []
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    for token in tokens:
        if token.startswith('Pitch_'):
            note = int(token.split('_')[1])
            note_events.append(note)
    return note_events

In [8]:
def note_frequency(midi_files):
    # Q1b: Your code goes here
    # Counts are keyed by (pitch, game) so that each game gets its own distribution
    note_counts = defaultdict(int)
    for midi_file in midi_files:
        note_events = note_extraction(midi_file)
        game = extract_nes_game_name(midi_file)
        for note in note_events:
            note_counts[(note, game)] += 1
    return note_counts

2. Write a function to normalize the above dictionary to produce probability scores. (e.g. {60: 0.13, 61: 0.065, …})

`note_unigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: a dictionary that maps note pitch events to probabilities they occur in the dataset, e.g. {60: 0.13, 61: 0.06, …}
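
Normalizing takes a single division by the total count. The notebook keys its counts by `(pitch, game)`, so the sketch below shows two versions. The first is the plain version from the specification. The second is a per-game version that divides by each game's own total, as `note_unigram_probability` does. The function names are illustrative.

```python
from collections import defaultdict


def normalize(counts):
    """{event: count} -> {event: probability}, summing to 1 overall."""
    total = sum(counts.values())
    return {event: c / total for event, c in counts.items()}


def normalize_per_group(counts):
    """{(event, group): count} -> probabilities that sum to 1 within each group."""
    totals = defaultdict(int)
    for (_, group), c in counts.items():
        totals[group] += c
    return {(event, group): c / totals[group] for (event, group), c in counts.items()}


print(normalize({60: 3, 62: 1}))  # {60: 0.75, 62: 0.25}
print(normalize_per_group({(60, 'A'): 3, (62, 'A'): 1, (60, 'B'): 2}))
# {(60, 'A'): 0.75, (62, 'A'): 0.25, (60, 'B'): 1.0}
```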

In [9]:
def note_unigram_probability(midi_files):
    note_counts = note_frequency(midi_files)
    
    game_totals = {}
    for (note, game) in note_counts:
        if game not in game_totals:
            game_totals[game] = 0
        game_totals[game] += note_counts[(note, game)]
    
    # Calculate probabilities relative to each game's total
    unigramProbabilities = {}
    for (note, game) in note_counts:
        unigramProbabilities[(note, game)] = note_counts[(note, game)] / game_totals[game]
    return unigramProbabilities

3. Generate a table of pairwise probabilities containing p(next_note | previous_note) for the dataset; write a function that randomly samples the next note given the previous note, according to this distribution.

`note_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramTransitions`: key - previous_note, value - a list of next_note, e.g. {60:[62, 64, ..], 62:[60, 64, ..], ...}

  - `bigramTransitionProbabilities`: key - previous_note, value - a list of probabilities for next_note in the same order of `bigramTransitions`, e.g. {60:[0.3, 0.4, ..], 62:[0.2, 0.1, ..], ...}

`sample_next_note()`
- **Input**: a note

- **Output**: next note sampled from pairwise probabilities
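
The sketch below illustrates the two-dictionary format and the sampling step on a toy pitch sequence. The notebook's own version also keys every entry by game. `bigram_tables` and `sample_next` are hypothetical names. The sketch uses `random.choices` because, unlike `numpy.random.choice`, it accepts a list of arbitrary objects (such as `(pitch, game)` tuples) to draw from.

```python
import random
from collections import defaultdict


def bigram_tables(sequence):
    """Return ({prev: [next, ...]}, {prev: [p(next | prev), ...]}) in matching order."""
    counts = defaultdict(lambda: defaultdict(int))
    for prev, nxt in zip(sequence[:-1], sequence[1:]):
        counts[prev][nxt] += 1
    transitions, probabilities = {}, {}
    for prev, nexts in counts.items():
        total = sum(nexts.values())
        transitions[prev] = list(nexts)  # insertion order, same as nexts.values()
        probabilities[prev] = [c / total for c in nexts.values()]
    return transitions, probabilities


def sample_next(prev, transitions, probabilities, rng=random):
    """Draw one next note from p(next | prev)."""
    return rng.choices(transitions[prev], weights=probabilities[prev], k=1)[0]


trans, probs = bigram_tables([60, 62, 60, 64, 60, 62])
print(trans[60], probs[60])  # [62, 64] [0.6666666666666666, 0.3333333333333333]
print(sample_next(62, trans, probs))  # 60 (the only continuation seen after 62)
```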

In [10]:
def note_bigram_probability(midi_files):
    # Q3a: Your code goes here
    bigrams = defaultdict(int)
    
    for file in midi_files:
        note_events = note_extraction(file)
        game = extract_nes_game_name(file)
        for (note1, note2) in zip(note_events[:-1], note_events[1:]):
            bigrams[(note1, note2, game)] += 1
            
    bigramTransitions = defaultdict(list)
    bigramTransitionProbabilities = defaultdict(list)

    for (note1, note2, game) in bigrams:
        bigramTransitions[(note1, game)].append((note2, game))
        bigramTransitionProbabilities[(note1, game)].append(bigrams[(note1, note2, game)])
        
    for k in bigramTransitionProbabilities:
        Z = sum(bigramTransitionProbabilities[k])
        bigramTransitionProbabilities[k] = [x / Z for x in bigramTransitionProbabilities[k]]
        
    return bigramTransitions, bigramTransitionProbabilities

In [11]:
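def sample_next_note(note):
    # Q3b: Your code goes here
    # `note` is a (pitch, game) key, matching the keys of the tables above;
    # the return value has the same form, so calls can be chained.
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    options = bigramTransitions[note]
    # numpy's choice needs a 1-D array, so sample an index rather than the tuples themselves
    index = choice(len(options), p=bigramTransitionProbabilities[note])
    return options[index]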

4. Write a function to calculate the perplexity of your model on a midi file.

    The perplexity of a model is defined as 

    $\quad \text{exp}(-\frac{1}{N} \sum_{i=1}^N \text{log}(p(w_i|w_{i-1})))$

    where $p(w_1|w_0) = p(w_1)$, $p(w_i|w_{i-1}) (i>1)$ refers to the pairwise probability p(next_note | previous_note).

`note_bigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value
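
The perplexity formula translates almost directly into code. The sketch below uses hypothetical helper names and toy tables in the same list-of-options format as Q3. It collects the per-step probabilities $p(w_1), p(w_2 \mid w_1), \dots$ and then applies $\exp(-\frac{1}{N}\sum \log p)$. As a quick sanity check, a model that assigns probability $1/k$ to every step has perplexity $k$.

```python
import math


def perplexity(step_probs):
    """exp(-(1/N) * sum(log p_i)) over the N per-step probabilities."""
    return math.exp(-sum(math.log(p) for p in step_probs) / len(step_probs))


def bigram_step_probs(notes, unigram, transitions, probabilities):
    """p(w_1) for the first note, then p(w_i | w_{i-1}) for each later note."""
    probs = [unigram[notes[0]]]
    for prev, nxt in zip(notes[:-1], notes[1:]):
        probs.append(probabilities[prev][transitions[prev].index(nxt)])
    return probs


steps = bigram_step_probs([60, 62], {60: 0.5}, {60: [62, 64]}, {60: [0.25, 0.75]})
print(steps)               # [0.5, 0.25]
print(perplexity(steps))   # about 2.83, i.e. sqrt(8)
```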

In [12]:
def note_bigram_perplexity(midi_file):
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    
    
    # Q4: Your code goes here
    # All tables are keyed by game, so look up this file's game first
    game = extract_nes_game_name(midi_file)
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[(note_events[0], game)]]
    for (note1, note2) in zip(note_events[:-1], note_events[1:]):
        index = bigramTransitions[(note1, game)].index((note2, game))
        prob = bigramTransitionProbabilities[(note1, game)][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

5. Implement a second-order Markov chain, i.e., one which estimates p(next_note | next_previous_note, previous_note); write a function to compute the perplexity of this new model on a midi file. 

    The perplexity of this model is defined as 

    $\quad \text{exp}(-\frac{1}{N} \sum_{i=1}^N \text{log}(p(w_i|w_{i-2}, w_{i-1})))$

    where $p(w_1|w_{-1}, w_0) = p(w_1)$, $p(w_2|w_0, w_1) = p(w_2|w_1)$, $p(w_i|w_{i-2}, w_{i-1}) (i>2)$ refers to the probability p(next_note | next_previous_note, previous_note).


`note_trigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `trigramTransitions`: key - (next_previous_note, previous_note), value - a list of next_note, e.g. {(60, 62):[64, 66, ..], (60, 64):[60, 64, ..], ...}

  - `trigramTransitionProbabilities`: key - (next_previous_note, previous_note), value - a list of probabilities for next_note in the same order of `trigramTransitions`, e.g. {(60, 62):[0.2, 0.2, ..], (60, 64):[0.4, 0.1, ..], ...}

`note_trigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value
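
Compared with the bigram case, the only new detail is the back-off for the first two notes given in the definition above. The first note uses $p(w_1)$ and the second uses $p(w_2 \mid w_1)$. Trigram probabilities apply from the third note on. The sketch below runs over toy tables; the names are hypothetical, and the notebook additionally keys by game.

```python
import math


def trigram_step_probs(notes, unigram, bi_trans, bi_probs, tri_trans, tri_probs):
    """Per-step probabilities for the second-order chain, backing off at the start."""
    probs = [unigram[notes[0]]]  # p(w_1)
    if len(notes) > 1:
        # p(w_2 | w_1)
        probs.append(bi_probs[notes[0]][bi_trans[notes[0]].index(notes[1])])
    for a, b, c in zip(notes[:-2], notes[1:-1], notes[2:]):
        # p(w_i | w_{i-2}, w_{i-1})
        probs.append(tri_probs[(a, b)][tri_trans[(a, b)].index(c)])
    return probs


steps = trigram_step_probs(
    [60, 62, 64],
    unigram={60: 0.5},
    bi_trans={60: [62]}, bi_probs={60: [1.0]},
    tri_trans={(60, 62): [64, 65]}, tri_probs={(60, 62): [0.4, 0.6]},
)
print(steps)  # [0.5, 1.0, 0.4]
perplexity = math.exp(-sum(math.log(p) for p in steps) / len(steps))
```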

In [13]:
def note_trigram_probability(midi_files):
    # Q5a: Your code goes here
    trigrams = defaultdict(int)
    for file in midi_files:
        note_events = note_extraction(file)
        game = extract_nes_game_name(file)
        for (note1, note2, note3) in zip(note_events[:-2], note_events[1:-1], note_events[2:]):
            trigrams[(note1, note2, note3, game)] += 1
            
    trigramTransitions = defaultdict(list)
    trigramTransitionProbabilities = defaultdict(list)

    for (note1, note2, note3, game) in trigrams:
        trigramTransitions[(note1, note2, game)].append((note3, game))
        trigramTransitionProbabilities[(note1, note2, game)].append(trigrams[(note1, note2, note3, game)])
        
    for k in trigramTransitionProbabilities:
        Z = sum(trigramTransitionProbabilities[k])
        trigramTransitionProbabilities[k] = [x / Z for x in trigramTransitionProbabilities[k]]
        
    return trigramTransitions, trigramTransitionProbabilities

In [14]:
def note_trigram_perplexity(midi_file):
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    
    # Q5b: Your code goes here
    # All tables are keyed by game, so look up this file's game first
    game = extract_nes_game_name(midi_file)
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[(note_events[0], game)]]
    index = bigramTransitions[(note_events[0], game)].index((note_events[1], game))
    prob = bigramTransitionProbabilities[(note_events[0], game)][index]
    perplexities.append(prob)
    
    for (note1, note2, note3) in zip(note_events[:-2], note_events[1:-1], note_events[2:]):
        index = trigramTransitions[(note1, note2, game)].index((note3, game))
        prob = trigramTransitionProbabilities[(note1, note2, game)][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

6. Our model currently doesn’t have any knowledge of beats. Write a function that extracts beat lengths and outputs a list of [(beat position; beat length)] values.

    Recall that each note is encoded as `Position, Pitch, Velocity, Duration` using REMI. Please keep the `Position` value for beat position, and convert `Duration` to beat length using the provided lookup table `duration2length` (see below).

    For example, for a note represented by four tokens `('Position_24', 'Pitch_72', 'Velocity_127', 'Duration_0.4.8')`, the extracted (beat position; beat length) value is `(24, 4)`.

    As a result, we will obtain a list like [(0,8),(8,16),(24,4),(28,4),(0,4)...], where the next beat position is the previous beat position + the beat length. Since each bar is divided into 32 positions by default, the beat position resets to 0 when the end of a bar is reached (e.g., 28 + 4 = 32 for (28, 4)).
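
The position bookkeeping described above is a running sum modulo 32. In the sketch below, `positions_from_lengths` is a hypothetical helper. Feeding it the lengths from the example list reproduces the example's positions:

```python
def positions_from_lengths(lengths, positions_per_bar=32, start=0):
    """Pair each beat length with the position it starts at, wrapping at bar lines."""
    pairs = []
    position = start
    for length in lengths:
        pairs.append((position, length))
        position = (position + length) % positions_per_bar
    return pairs


print(positions_from_lengths([8, 16, 4, 4, 4]))
# [(0, 8), (8, 16), (24, 4), (28, 4), (0, 4)]
```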

In [15]:
duration2length = {
    '0.2.8': 2,  # sixteenth note, 0.25 beat in 4/4 time signature
    '0.4.8': 4,  # eighth note, 0.5 beat in 4/4 time signature
    '1.0.8': 8,  # quarter note, 1 beat in 4/4 time signature
    '2.0.8': 16, # half note, 2 beats in 4/4 time signature
    '4.0.4': 32, # whole note, 4 beats in 4/4 time signature
}

`beat_extraction()`
- **Input**: a midi file

- **Output**: a list of (beat position; beat length) values
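
`duration2length` lists only five durations, but real files also contain values such as `Duration_0.6.8` (see the token dump above). The sketch below assumes MidiTok's `beats.position.resolution` reading of the Duration value: whole beats plus `position/resolution` of a beat. Under that assumption, it gives a general conversion to the 8-positions-per-beat grid and reproduces every entry of the table.

```python
def duration_to_length(duration, positions_per_beat=8):
    """Convert a Duration value like '0.4.8' or '4.0.4' to grid positions (8 per beat)."""
    beats, pos, res = (int(x) for x in duration.split('.'))
    # Exact whenever res divides positions_per_beat (e.g. res = 4 or 8)
    return beats * positions_per_beat + pos * positions_per_beat // res


duration2length = {'0.2.8': 2, '0.4.8': 4, '1.0.8': 8, '2.0.8': 16, '4.0.4': 32}
print(all(duration_to_length(d) == n for d, n in duration2length.items()))  # True
print(duration_to_length('0.6.8'))  # 6
```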

In [16]:
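def beat_extraction(midi_file):
    # Q6: returns (beat position, beat length, game) tuples; the game is kept for per-game tables
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    game = extract_nes_game_name(midi_file)
    beats = []
    
    # A note is Position, Pitch, Velocity, Duration, so its Duration token is 3 steps ahead;
    # stopping 3 tokens early keeps tokens[i + 3] in range
    for i in range(len(tokens) - 3):
        if tokens[i].startswith('Position_') and tokens[i + 3].startswith('Duration_'):
            position = int(tokens[i].split('_')[1])
            # Duration values read "beats.pos.res" (pos/res of an extra beat); rescaling to
            # 8 positions per beat reproduces duration2length (e.g. '4.0.4' -> 32)
            n_beats, pos, res = map(int, tokens[i + 3].split('_')[1].split('.'))
            length = n_beats * 8 + pos * 8 // res
            # Lengths are not wrapped modulo 32, so a whole note stays 32 instead of 0
            beats.append((position, length, game))
    return beats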

7. Implement a Markov chain that computes p(beat_length | previous_beat_length) based on the above function.

`beat_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramBeatTransitions`: key - previous_beat_length, value - a list of beat_length, e.g. {4:[8, 2, ..], 8:[8, 4, ..], ...}

  - `bigramBeatTransitionProbabilities`: key - previous_beat_length, value - a list of probabilities for beat_length in the same order of `bigramBeatTransitions`, e.g. {4:[0.3, 0.2, ..], 8:[0.4, 0.4, ..], ...}

In [17]:
def beat_bigram_probability(midi_files):
    # Q7: Your code goes here
    bigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for (beat1, beat2) in zip(beats[:-1], beats[1:]):
            bigramBeat[(beat1[1], beat2[1], beat1[2])] += 1
            
    bigramBeatTransitions = defaultdict(list)
    bigramBeatTransitionProbabilities = defaultdict(list)

    for (length1, length2, game) in bigramBeat:
        bigramBeatTransitions[(length1, game)].append((length2, game))
        bigramBeatTransitionProbabilities[(length1, game)].append(bigramBeat[(length1, length2, game)])
        
    for k in bigramBeatTransitionProbabilities:
        Z = sum(bigramBeatTransitionProbabilities[k])
        bigramBeatTransitionProbabilities[k] = [x / Z for x in bigramBeatTransitionProbabilities[k]]
        
    return bigramBeatTransitions, bigramBeatTransitionProbabilities

8. Implement a function to compute p(beat length | beat position), and compute the perplexity of your models from Q7 and Q8. For both models, we only consider the probabilities of predicting the sequence of **beat length**.

`beat_pos_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramBeatPosTransitions`: key - beat_position, value - a list of beat_length

  - `bigramBeatPosTransitionProbabilities`: key - beat_position, value - a list of probabilities for beat_length in the same order of `bigramBeatPosTransitions`

`beat_bigram_perplexity()`
- **Input**: a midi file

- **Output**: two perplexity values corresponding to the models in Q7 and Q8, respectively
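
The position-conditioned model scores every beat as $p(\text{length} \mid \text{position})$, so the first beat needs no back-off. Below is a self-contained sketch on toy `(position, length)` pairs that estimates the model and evaluates it on the same data. `position_model_perplexity` is a hypothetical name.

```python
import math
from collections import Counter


def position_model_perplexity(beats):
    """Perplexity of p(length | position), fit and scored on the same (position, length) list."""
    joint = Counter(beats)                       # counts of (position, length)
    marginal = Counter(pos for pos, _ in beats)  # counts of position alone
    log_sum = sum(math.log(joint[b] / marginal[b[0]]) for b in beats)
    return math.exp(-log_sum / len(beats))


# Two bars: lengths 8, 8, 16 and then 8, 4, 4, 16
beats = [(0, 8), (8, 8), (16, 16), (0, 8), (8, 4), (12, 4), (16, 16)]
print(position_model_perplexity(beats))  # 2 ** (2/7), about 1.22
```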

In [18]:
def beat_pos_bigram_probability(midi_files):
    # Q8a: Your code goes here
    bigramBeatPos = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for beat in beats:
            bigramBeatPos[(beat[0], beat[1], beat[2])] += 1
            
    bigramBeatPosTransitions = defaultdict(list)
    bigramBeatPosTransitionProbabilities = defaultdict(list)

    for (position, length, game) in bigramBeatPos:
        bigramBeatPosTransitions[(position, game)].append((length, game))
        bigramBeatPosTransitionProbabilities[(position, game)].append(bigramBeatPos[(position, length, game)])
        
    for k in bigramBeatPosTransitionProbabilities:
        Z = sum(bigramBeatPosTransitionProbabilities[k])
        bigramBeatPosTransitionProbabilities[k] = [x / Z for x in bigramBeatPosTransitionProbabilities[k]]
        
    return bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities

In [19]:
def beat_bigram_perplexity(midi_file):
    bigramBeatTransitions, bigramBeatTransitionProbabilities = beat_bigram_probability(midi_files)
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    # Q8b: Your code goes here
    # Hint: one more probability function needs to be computed
    unigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for beat in beats:
            unigramBeat[beat[1]] += 1
    unigramBeatProbabilities = {}
    counts = sum(list(unigramBeat.values()))
    for n in unigramBeat:
        unigramBeatProbabilities[n] = unigramBeat[n] / counts
        
    # The bigram tables are keyed by game, so use this file's game for lookups
    game = extract_nes_game_name(midi_file)
    beat_events = beat_extraction(midi_file)
    beats = [b[1] for b in beat_events]

    # perplexity for Q7
    perplexities = [unigramBeatProbabilities[beats[0]]]
    for (beat1, beat2) in zip(beats[:-1], beats[1:]):
        index = bigramBeatTransitions[(beat1, game)].index((beat2, game))
        prob = bigramBeatTransitionProbabilities[(beat1, game)][index]
        perplexities.append(prob)
    assert len(perplexities) == len(beats)
    perplexity_Q7 = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    
    # perplexity for Q8 (each beat is scored as p(beat_length | beat_position))
    perplexities = []
    for (beat_position, beat_length, _) in beat_events:
        index = bigramBeatPosTransitions[(beat_position, game)].index((beat_length, game))
        prob = bigramBeatPosTransitionProbabilities[(beat_position, game)][index]
        perplexities.append(prob)
    assert len(perplexities) == len(beat_events)
    perplexity_Q8 = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    
    return perplexity_Q7, perplexity_Q8

9. Implement a Markov chain that computes p(beat_length | previous_beat_length, beat_position), and report its perplexity. 

`beat_trigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `trigramBeatTransitions`: key - (previous_beat_length, beat_position), value - a list of beat_length

  - `trigramBeatTransitionProbabilities`: key - (previous_beat_length, beat_position), value - a list of probabilities for beat_length in the same order of `trigramBeatTransitions`

`beat_trigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value
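
Here each beat's context is the previous beat's length together with the current beat's position. The sketch below builds the two dictionaries from one list of `(position, length)` pairs. `beat_trigram_tables` is a hypothetical name, and the notebook also keys by game.

```python
from collections import Counter, defaultdict


def beat_trigram_tables(beats):
    """Estimate p(length_i | length_{i-1}, position_i) from (position, length) pairs."""
    counts = defaultdict(Counter)
    for (_, prev_length), (position, length) in zip(beats[:-1], beats[1:]):
        counts[(prev_length, position)][length] += 1
    transitions = {ctx: list(c) for ctx, c in counts.items()}
    probabilities = {ctx: [n / sum(c.values()) for n in c.values()] for ctx, c in counts.items()}
    return transitions, probabilities


beats = [(0, 8), (8, 8), (16, 16), (0, 8), (8, 4), (12, 4), (16, 16)]
trans, probs = beat_trigram_tables(beats)
print(trans[(8, 8)], probs[(8, 8)])  # [8, 4] [0.5, 0.5]
```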

In [20]:
def beat_trigram_probability(midi_files):
    # Q9a: Your code goes here
    trigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for (beat1, beat2) in zip(beats[:-1], beats[1:]):
            trigramBeat[(beat1[1], beat2[0], beat2[1], beat1[2])] += 1
            
    trigramBeatTransitions = defaultdict(list)
    trigramBeatTransitionProbabilities = defaultdict(list)

    for (length1, position2, length2, game) in trigramBeat:
        trigramBeatTransitions[(length1, position2, game)].append((length2, game))
        trigramBeatTransitionProbabilities[(length1, position2, game)].append(trigramBeat[(length1, position2, length2, game)])
        
    for k in trigramBeatTransitionProbabilities:
        Z = sum(trigramBeatTransitionProbabilities[k])
        trigramBeatTransitionProbabilities[k] = [x / Z for x in trigramBeatTransitionProbabilities[k]]
        
    return trigramBeatTransitions, trigramBeatTransitionProbabilities

In [45]:
def beat_trigram_perplexity(midi_file, game, bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities, trigramBeatTransitions, trigramBeatTransitionProbabilities):
    # bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    # trigramBeatTransitions, trigramBeatTransitionProbabilities = beat_trigram_probability(midi_files)
    # Q9b: Your code goes here
    beats = beat_extraction(midi_file)
    beat_positions, beat_lengths, _ = zip(*beats)
    
    first_beat_index = bigramBeatPosTransitions[(beat_positions[0], game)].index((beat_lengths[0], game))
    val = np.log(bigramBeatPosTransitionProbabilities[(beat_positions[0], game)][first_beat_index])
    for i in range(1, len(beat_positions)):
        index = trigramBeatTransitions[(beat_lengths[i-1],beat_positions[i],game)].index((beat_lengths[i], game))
        val += np.log(trigramBeatTransitionProbabilities[(beat_lengths[i-1],beat_positions[i],game)][index])
    
    return np.exp((-1 / len(beat_positions)) * val)

In [22]:
def note_quadgram_probability(midi_files):
    quadgrams = defaultdict(int)
    for file in midi_files:
        note_events = note_extraction(file)
        game = extract_nes_game_name(file)
        for (note1, note2, note3, note4) in zip(note_events[:-3], note_events[1:-2], note_events[2:-1], note_events[3:]):
            quadgrams[(note1, note2, note3, note4, game)] += 1
            
    quadgramTransitions = defaultdict(list)
    quadgramTransitionProbabilities = defaultdict(list)

    for (note1, note2, note3, note4, game) in quadgrams:
        quadgramTransitions[(note1, note2, note3, game)].append((note4, game))
        quadgramTransitionProbabilities[(note1, note2, note3, game)].append(quadgrams[(note1, note2, note3, note4, game)])
        
    for k in quadgramTransitionProbabilities:
        Z = sum(quadgramTransitionProbabilities[k])
        quadgramTransitionProbabilities[k] = [x / Z for x in quadgramTransitionProbabilities[k]]
        
    return quadgramTransitions, quadgramTransitionProbabilities

In [35]:
def note_quadgram_perplexity(midi_file, game, unigramProbabilities, bigramTransitions, bigramTransitionProbabilities, trigramTransitions, trigramTransitionProbabilities, quadgramTransitions, quadgramTransitionProbabilities):
    # unigramProbabilities = note_unigram_probability(midi_files)
    # bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    # trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    # quadgramTransitions, quadgramTransitionProbabilities = note_quadgram_probability(midi_files)
    
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[(note_events[0], game)]]
    index = bigramTransitions[(note_events[0], game)].index((note_events[1], game))
    prob = bigramTransitionProbabilities[(note_events[0], game)][index]
    perplexities.append(prob)
    index2 = trigramTransitions[(note_events[0], note_events[1], game)].index((note_events[2], game))
    prob2 = trigramTransitionProbabilities[(note_events[0], note_events[1], game)][index2]
    perplexities.append(prob2)
    
    for (note1, note2, note3, note4) in zip(note_events[:-3], note_events[1:-2], note_events[2:-1], note_events[3:]):
        index = quadgramTransitions[(note1, note2, note3, game)].index((note4, game))
        prob = quadgramTransitionProbabilities[(note1, note2, note3, game)][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

10. Use the model from Q5 to generate 500 notes, and the model from Q8 to generate beat lengths for each note. Save the generated music as a midi file (see code from workbook1) as q10.mid. Remember to reset the beat position to 0 when reaching the end of a bar.

`music_generate`
- **Input**: target length, e.g. 500

- **Output**: a midi file q10.mid

Note: the duration of one beat in MIDIUtil is 1, while in MidiTok is 8. Divide beat length by 8 if you use methods in MIDIUtil to save midi files.
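
Concretely, dividing each sampled beat length by 8 gives its MIDIUtil duration, and each note starts where the previous one ended. The sketch below shows that bookkeeping; `note_times` is a hypothetical helper. It produces the `time` and `duration` arguments you would then pass to `MIDIFile.addNote`:

```python
def note_times(beat_lengths, positions_per_beat=8):
    """Turn MidiTok-grid beat lengths into (start, duration) pairs measured in beats."""
    events = []
    start = 0.0
    for length in beat_lengths:
        duration = length / positions_per_beat
        events.append((start, duration))
        start += duration  # the next note begins where this one ends
    return events


print(note_times([8, 4, 4, 16]))
# [(0.0, 1.0), (1.0, 0.5), (1.5, 0.5), (2.0, 2.0)]
```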

In [24]:
def music_generate(length, file_name, game, unigramProbabilities, bigramTransitions, bigramTransitionProbabilities, trigramTransitions, trigramTransitionProbabilities, quadgramTransitions, quadgramTransitionProbabilities, bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities, trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities):
    # Your code goes here ...
    sampled_notes = []
    
    # Filter unigram probabilities for the specified game
    game_unigrams = {k: v for k, v in unigramProbabilities.items() if k[1] == game}
    
    first_note_tuple = random.choices(list(game_unigrams.keys()), weights=list(game_unigrams.values()), k=1)[0]
    first_note = first_note_tuple[0]  # Extract just the note value
    
    second_note_tuple = random.choices(bigramTransitions[first_note_tuple], weights=bigramTransitionProbabilities[first_note_tuple], k=1)[0]
    second_note = second_note_tuple[0]  # Extract just the note value
    
    third_note_tuple = random.choices(trigramTransitions[(first_note, second_note, game)], weights=trigramTransitionProbabilities[(first_note, second_note, game)], k=1)[0]
    third_note = third_note_tuple[0]  # Extract just the note value
    
    sampled_notes = [first_note, second_note, third_note]
    
    for i in range(length - 3):
        # Condition on the last three pitches plus the game (the quadgram context)
        prev_notes = tuple(sampled_notes[-3:]) + (game,)
        # .get() avoids inserting empty entries into the defaultdict tables
        options = quadgramTransitions.get(prev_notes, [])
        if len(options) == 0:
            # Unseen context: stop here and keep every note sampled so far
            length = len(sampled_notes)
            break
        next_note = random.choices(options, weights=quadgramTransitionProbabilities[prev_notes], k=1)[0]
        sampled_notes.append(next_note[0])
    
    # sample beats - filter for the specified game
    # game_beat_pos_0 = [(k, v) for k, v in zip(bigramBeatPosTransitions[(0, game)], bigramBeatPosTransitionProbabilities[(0, game)])]
    first_beat_length = random.choices(bigramBeatPosTransitions[(0, game)], weights=bigramBeatPosTransitionProbabilities[(0, game)], k=1)[0][0]
    sampled_beats = [first_beat_length]
    beat_position = first_beat_length % 32
    
    # The first beat length is already sampled, so draw length - 1 more
    for i in range(length - 1):
        key = (sampled_beats[-1], beat_position, game)
        beat_length_options = trigramBeatPosTransitions.get(key, [])
        beat_length_weights = trigramBeatPosTransitionProbabilities.get(key, [])
        if len(beat_length_options) == 0 or len(beat_length_weights) == 0:
            # Unseen (previous length, position) context: truncate to the beats sampled so far
            length = len(sampled_beats)
            break
        beat_length = random.choices(beat_length_options, weights=beat_length_weights, k=1)[0][0]
        sampled_beats.append(beat_length)
        # Advance the position and reset it to 0 at the bar line
        beat_position = (beat_position + beat_length) % 32
    
    # save the generated music as a midi file
    midi = MIDIFile(1) # Create a MIDI file that consists of 1 track
    track = 0 # Set track number
    time = 0 # Where is the event placed (at the beginning)
    tempo = 120 # The tempo (beats per minute)
    midi.addTempo(track, time, tempo) # Add tempo information
    
    current_time = 0
    for i in range(length):
        pitch = sampled_notes[i]  # Now it's already just the note value (int)
        duration = sampled_beats[i] / 8
        midi.addNote(track, 0, pitch, current_time, duration, 100)
        current_time += duration

    with open(file_name, "wb") as f:
        midi.writeFile(f) # write MIDI file
    
    print(length)

In [25]:
# sample notes
unigramProbabilities = note_unigram_probability(midi_files)
bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)

In [28]:
trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)

In [29]:
quadgramTransitions, quadgramTransitionProbabilities = note_quadgram_probability(midi_files)

In [30]:
# sample beats
bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)

In [31]:
trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities = beat_trigram_probability(midi_files)

In [179]:
from itertools import islice

first_x_items = dict(islice(trigramBeatPosTransitionProbabilities.items(), 100))
print(first_x_items)

{(6, 6, '10_YardFight'): [0.5, 0.5], (2, 9, '10_YardFight'): [1.0], (2, 11, '10_YardFight'): [1.0], (2, 13, '10_YardFight'): [1.0], (6, 19, '10_YardFight'): [1.0], (2, 21, '10_YardFight'): [1.0], (2, 23, '10_YardFight'): [1.0], (2, 26, '10_YardFight'): [1.0], (6, 13, '10_YardFight'): [1.0], (1, 1, '1942'): [1.0], (1, 2, '1942'): [1.0], (1, 3, '1942'): [1.0], (1, 4, '1942'): [1.0], (1, 5, '1942'): [1.0], (1, 6, '1942'): [1.0], (1, 7, '1942'): [1.0], (1, 8, '1942'): [1.0], (1, 9, '1942'): [1.0], (1, 10, '1942'): [1.0], (1, 11, '1942'): [1.0], (1, 12, '1942'): [1.0], (1, 13, '1942'): [1.0], (1, 14, '1942'): [1.0], (1, 15, '1942'): [1.0], (1, 16, '1942'): [1.0], (1, 17, '1942'): [1.0], (1, 18, '1942'): [1.0], (1, 19, '1942'): [1.0], (1, 20, '1942'): [1.0], (1, 21, '1942'): [1.0], (1, 22, '1942'): [1.0], (1, 23, '1942'): [1.0], (1, 24, '1942'): [1.0], (1, 25, '1942'): [1.0], (1, 26, '1942'): [1.0], (1, 27, '1942'): [1.0], (1, 28, '1942'): [1.0], (1, 29, '1942'): [1.0], (1, 30, '1942'): [1.0

In [None]:
import pretty_midi
from IPython.display import Audio, display

# DragonWarriorIV
for i in range(3):
    path = f"markov_con_samples/DragonWarriorIV/markov_con_DWIV{i}.mid"
    music_generate(100, path, "DragonWarriorIV",
                   unigramProbabilities,
                   bigramTransitions, bigramTransitionProbabilities,
                   trigramTransitions, trigramTransitionProbabilities,
                   quadgramTransitions, quadgramTransitionProbabilities,
                   bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
                   trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)
    # Reload the generated file and set every non-drum track to a square-wave lead
    midi_data = pretty_midi.PrettyMIDI(path)
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        instrument.program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    # Overwrite the generated file with the modified instrument
    midi_data.write(path)

100
56
100
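
`music_generate` receives unigram through quadgram note tables, plus bigram and trigram beat-position tables. This suggests it falls back to a lower-order model when a longer context was never observed, but that is an assumption. The sketch below shows one common way to do such backoff sampling with `numpy.random.choice`. The function `sample_next`, the `(order, transitions, probabilities)` table layout, and the toy tables are all hypothetical.

```python
import numpy as np

def sample_next(history, game, tables, unigram):
    """Hypothetical sketch: sample the next value with n-gram backoff.

    tables: list of (order, transitions, probabilities), highest order first.
        transitions[key] lists possible next values and probabilities[key]
        the matching probabilities, with key = (*last order-1 values, game).
    unigram: (values, probabilities) pair used when no context matches.
    """
    for order, transitions, probabilities in tables:
        context = tuple(history[-(order - 1):]) if order > 1 else ()
        key = (*context, game)
        # use this order only if the history is long enough and the context was seen
        if len(context) == order - 1 and key in transitions:
            idx = np.random.choice(len(transitions[key]), p=probabilities[key])
            return transitions[key][idx]
    # every context was unseen: fall back to the unigram distribution
    values, probs = unigram
    return values[np.random.choice(len(values), p=probs)]

# usage on toy tables (hypothetical data)
np.random.seed(0)
toy_tables = [
    (3, {(6, 6, "ToyGame"): [10, 12]}, {(6, 6, "ToyGame"): [0.5, 0.5]}),
    (2, {(6, "ToyGame"): [6]}, {(6, "ToyGame"): [1.0]}),
]
toy_unigram = ([6, 10, 12], [0.5, 0.25, 0.25])
seq = [6]
for _ in range(5):
    seq.append(sample_next(seq, "ToyGame", toy_tables, toy_unigram))
print(seq)
```

Trying the highest order first keeps long-range patterns whenever the training data contains them. The unigram fallback guarantees that sampling never gets stuck on an unseen context.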


In [None]:
# FinalFantasyIII
for i in range(3):
    path = f"markov_con_samples/FinalFantasyIII/markov_con_FFIII{i}.mid"
    music_generate(100, path, "FinalFantasyIII",
                   unigramProbabilities,
                   bigramTransitions, bigramTransitionProbabilities,
                   trigramTransitions, trigramTransitionProbabilities,
                   quadgramTransitions, quadgramTransitionProbabilities,
                   bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
                   trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)
    # Reload the generated file and set every non-drum track to a square-wave lead
    midi_data = pretty_midi.PrettyMIDI(path)
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        instrument.program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    # Overwrite the generated file with the modified instrument
    midi_data.write(path)

58
100
31


In [None]:
# GanbareGoemonGaiden2_TenkanoZaih_
for i in range(3):
    path = f"markov_con_samples/GanbareGoemonGaiden2_TenkanoZaih_/markov_con_GGG2{i}.mid"
    music_generate(100, path, "GanbareGoemonGaiden2_TenkanoZaih_",
                   unigramProbabilities,
                   bigramTransitions, bigramTransitionProbabilities,
                   trigramTransitions, trigramTransitionProbabilities,
                   quadgramTransitions, quadgramTransitionProbabilities,
                   bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
                   trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)
    # Reload the generated file and set every non-drum track to a square-wave lead
    midi_data = pretty_midi.PrettyMIDI(path)
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        instrument.program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    # Overwrite the generated file with the modified instrument
    midi_data.write(path)

100
100
100


In [254]:
# Kirby_sAdventure
for i in range(3):
    path = f"markov_con_samples/Kirby_sAdventure/markov_con_KA{i}.mid"
    music_generate(100, path, "Kirby_sAdventure",
                   unigramProbabilities,
                   bigramTransitions, bigramTransitionProbabilities,
                   trigramTransitions, trigramTransitionProbabilities,
                   quadgramTransitions, quadgramTransitionProbabilities,
                   bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
                   trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)
    # Reload the generated file and set every non-drum track to a square-wave lead
    midi_data = pretty_midi.PrettyMIDI(path)
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        instrument.program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    # Overwrite the generated file with the modified instrument
    midi_data.write(path)

84
89
44


In [255]:
# WaiWaiWorld2_SOS__ParsleyJou
for i in range(3):
    path = f"markov_con_samples/WaiWaiWorld2_SOS__ParsleyJou/markov_con_WWW{i}.mid"
    music_generate(100, path, "WaiWaiWorld2_SOS__ParsleyJou",
                   unigramProbabilities,
                   bigramTransitions, bigramTransitionProbabilities,
                   trigramTransitions, trigramTransitionProbabilities,
                   quadgramTransitions, quadgramTransitionProbabilities,
                   bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
                   trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)
    # Reload the generated file and set every non-drum track to a square-wave lead
    midi_data = pretty_midi.PrettyMIDI(path)
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        instrument.program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    # Overwrite the generated file with the modified instrument
    midi_data.write(path)

100
69
65


In [260]:
# SuperMarioBros_
for i in range(3):
    path = f"markov_con_samples/SuperMarioBros_/markov_con_MARIO{i}.mid"
    music_generate(100, path, "SuperMarioBros_",
                   unigramProbabilities,
                   bigramTransitions, bigramTransitionProbabilities,
                   trigramTransitions, trigramTransitionProbabilities,
                   quadgramTransitions, quadgramTransitionProbabilities,
                   bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
                   trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)
    # Reload the generated file and set every non-drum track to a square-wave lead
    midi_data = pretty_midi.PrettyMIDI(path)
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        instrument.program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    # Overwrite the generated file with the modified instrument
    midi_data.write(path)

54
83
96


In [53]:
note_perplexities = []
beat_perplexities = []
# (game name, file prefix) for each set of three generated samples, in evaluation order
games = [
    ("DragonWarriorIV", "DWIV"),
    ("FinalFantasyIII", "FFIII"),
    ("GanbareGoemonGaiden2_TenkanoZaih_", "GGG2"),
    ("Kirby_sAdventure", "KA"),
    ("WaiWaiWorld2_SOS__ParsleyJou", "WWW"),
]
for game, prefix in games:
    for i in range(3):
        path = f"markov_con_samples/{game}/markov_con_{prefix}{i}.mid"
        # note perplexity under the quadgram model (with its lower-order tables)
        note_perplexities.append(float(note_quadgram_perplexity(
            path, game,
            unigramProbabilities,
            bigramTransitions, bigramTransitionProbabilities,
            trigramTransitions, trigramTransitionProbabilities,
            quadgramTransitions, quadgramTransitionProbabilities)))
        # beat-position perplexity under the trigram model (with its bigram table)
        beat_perplexities.append(float(beat_trigram_perplexity(
            path, game,
            bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities,
            trigramBeatPosTransitions, trigramBeatPosTransitionProbabilities)))

print(note_perplexities)
print(np.average(note_perplexities))
print(beat_perplexities)
print(np.average(beat_perplexities))

[1.546581111444354, 1.6420355092378545, 1.4833347216587007, 1.5239957046701205, 1.5244651145469572, 2.258497363924173, 1.5807951470366408, 1.7540791644513178, 1.4340261581648557, 1.7795987429674354, 1.691940617124893, 2.539552597769951, 1.7276863134813987, 2.312454898827371, 1.1469088924791164]
1.7297301371856764
[3.065628998325484, 3.440727073096312, 2.6155855664324754, 2.4887146004567056, 2.5906426232362665, 2.865517495542322, 2.6268585419593062, 2.6132582864530747, 2.4136805377567936, 1.2687810974867195, 1.3206599121502416, 1.043213111314071, 2.4253505229519554, 2.775566665593575, 2.2543527407876716]
2.3872358515695313
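
This sketch assumes `note_quadgram_perplexity` and `beat_trigram_perplexity` follow the standard definition. Under that definition, perplexity is the exponential of the average negative log-probability that the model assigns to each step of a sequence:

$\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log p_i\right)$

The helper `perplexity` below is hypothetical and computes this from a list of per-step probabilities. A value of 1.0 is the lower bound, reached only when every step has probability 1. A model that spreads its probability uniformly over k choices at every step scores k.

```python
import math

def perplexity(probabilities):
    """Hypothetical helper: perplexity from per-step model probabilities.

    PPL = exp(-(1/N) * sum(log p_i)); lower means the model found the
    sequence less surprising, and 1.0 is the minimum.
    """
    if not probabilities:
        raise ValueError("need at least one probability")
    log_sum = sum(math.log(p) for p in probabilities)
    return math.exp(-log_sum / len(probabilities))

print(perplexity([1.0] * 5))      # 1.0: every step was certain
print(perplexity([0.25] * 8))     # about 4: like a uniform choice among 4 values
print(perplexity([0.5, 0.125]))   # about 4: geometric mean probability is 0.25
```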
