In [103]:
import sys
sys.path.append("../src")

In [104]:
import glob
import os
import subprocess
from typing import List, Tuple, Union, Any

import hydra
import joblib
import music21
import numpy as np
from music21.chord import Chord
from music21.note import Note
from music21.stream.base import Score
from rich.progress import track

In [105]:
from utils import make_sequence

In [106]:
class MusicXMLFeature(object):
    def __init__(
        self,
        xml_file: str,
        key_root: str = "C",
        num_beats: int = 4,
        num_parts_of_beat: int = 4,
        max_measure_num: int = 240,
        min_note_num: int = 36,
        max_note_num: int = 84,
    ) -> None:
        assert key_root in ["C", "D", "E", "F", "G", "A", "B"]

        self.score = self._get_score(xml_file, key_root)
        self.num_beats = num_beats
        self.num_parts_of_beat = num_parts_of_beat
        self.max_measure_num = max_measure_num
        self.min_note_num = min_note_num
        self.max_note_num = max_note_num

    def _get_score(self, xml_file: str, root: str) -> Score:
        score: Score = music21.converter.parse(
            xml_file, format="musicxml"
        )  # type: ignore
        key = score.analyze("key")
        interval = music21.interval.Interval(
            key.tonic, music21.pitch.Pitch(root)  # type: ignore
        )
        score.transpose(interval, inPlace=True)
        return score

    def get_mode(self) -> str:
        key = self.score.analyze("key")
        mode = "None" if key is None else str(key.mode)
        return mode

    def get_note_seq(self) -> List[Union[None, Note]]:
        note_seq: List[Union[None, Note]] = [None] * int(
            self.max_measure_num * self.num_beats * self.num_parts_of_beat
        )

        for measure in self.score.parts[0].getElementsByClass("Measure"):
            for note in measure.getElementsByClass("Note"):
                onset = measure.offset + note._activeSiteStoredOffset
                offset = onset + note._duration.quarterLength

                start_idx = int(onset * self.num_parts_of_beat)
                end_idx = int(offset * self.num_parts_of_beat + 1)

                num_item = int(end_idx - start_idx)
                note_seq[start_idx:end_idx] = [note] * num_item

        return note_seq

    def get_onehot_note_seq(self) -> np.ndarray:
        note_seq = self.get_note_seq()
        note_num_seq = [
            int(n.pitch.midi - self.min_note_num) if n is not None else -1
            for n in note_seq
        ]
        num_note = self.max_note_num - self.min_note_num + 1
        onehot_note_seq = np.identity(num_note)[note_num_seq]
        return onehot_note_seq

    def get_seq_notenum(self) -> np.ndarray:
        seq_note = self.get_note_seq()
        seq_notenum = np.array(
            [
                int(n.pitch.midi) - self.min_note_num + 1
                if n is not None
                else 0
                for n in seq_note
            ]
        )
        return seq_notenum

    def get_chord_seq(self) -> List[Union[None, Chord]]:
        chord_seq: List[Union[None, Chord]] = [None] * int(
            self.max_measure_num * self.num_beats * self.num_parts_of_beat
        )

        for measure in self.score.parts[0].getElementsByClass("Measure"):
            for note in measure.getElementsByClass("ChordSymbol"):
                offset = measure.offset + note.offset

                start_idx = int(offset * self.num_parts_of_beat)
                end_idx = (
                    int(
                        (measure.offset + self.num_beats)
                        * self.num_parts_of_beat
                    )
                    + 1
                )
                num_item = int(end_idx - start_idx)
                chord_seq[start_idx:end_idx] = [note] * num_item

        return chord_seq

    def get_onehot_chord_seq(self) -> np.ndarray:
        chord_seq = self.get_chord_seq()
        onehot_chord_seq = np.zeros((len(chord_seq), 12))
        for i, chord in enumerate(chord_seq):
            if chord is None:
                continue
            for note in chord._notes:
                onehot_chord_seq[i, note.pitch.midi % 12] = 1
        return onehot_chord_seq

In [107]:
files = [
    "/workspace/data/xml/Celerity.xml",
    "/workspace/data/xml/Anthropology.xml",
    "/workspace/data/xml/Diverse.xml"
]

In [108]:
for f in files:
    feat = MusicXMLFeature(f)
    print(len(feat.get_note_seq()))
    print(len(feat.get_chord_seq()))
    print()

3840
3840

3840
3840

3840
3840



In [109]:
notes = feat.get_note_seq()
onehot = feat.get_onehot_note_seq()
chords = feat.get_chord_seq()
chords_chorma = feat.get_onehot_chord_seq()

In [110]:
seq_note = make_sequence(np.array(notes), 128)
seq_chord = make_sequence(chords_chorma, 128)

seq_note.shape, seq_chord.shape

((30, 128), (30, 128, 12))

In [111]:
notenum = feat.get_seq_notenum()
notenum[:20]

array([ 0,  0, 32, 32, 32, 32, 28, 28, 25, 25, 20, 20, 20,  0, 20, 20, 30,
       30, 22, 22])

In [112]:
notes[:20]

[None,
 None,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note E->,
 <music21.note.Note E->,
 <music21.note.Note C>,
 <music21.note.Note C>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 None,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note F>,
 <music21.note.Note F>,
 <music21.note.Note A>,
 <music21.note.Note A>]

In [113]:
chords[:20]

[<music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Dm>,
 <music21.harmony.ChordSymbol Dm>,
 <music21.harmony.ChordSymbol Dm>,
 <music21.harmony.ChordSymbol Dm>]

---

In [187]:
class MusicXMLFeature(object):
    def __init__(
        self,
        xml_file: str,
        key_root: str = "C",
        num_beats: int = 4,
        num_parts_of_beat: int = 4,
        max_measure_num: int = 240,
        min_note_num: int = 36,
        max_note_num: int = 84,
    ) -> None:
        assert key_root in ["C", "D", "E", "F", "G", "A", "B"]

        self.score = self._get_score(xml_file, key_root)
        self.num_beats = num_beats
        self.num_parts_of_beat = num_parts_of_beat
        self.max_measure_num = max_measure_num
        self.min_note_num = min_note_num
        self.max_note_num = max_note_num

        self.notes, self.chords = self.get_notes_and_chords()

    def _get_score(self, xml_file: str, root: str) -> Score:
        score: Score = music21.converter.parse(
            xml_file, format="musicxml"
        )  # type: ignore
        key = score.analyze("key")
        interval = music21.interval.Interval(
            key.tonic, music21.pitch.Pitch(root)  # type: ignore
        )
        score.transpose(interval, inPlace=True)
        return score

    def get_mode(self) -> str:
        key = self.score.analyze("key")
        mode = "None" if key is None else str(key.mode)
        return mode

    def get_notes_and_chords(
        self,
    ) -> Tuple[List[Union[None, Note]], List[Union[None, Chord]]]:
        notes = []
        chords = []
        for measure in self.score.parts[0].getElementsByClass("Measure"):
            m_notes = [None] * self.num_beats * self.num_parts_of_beat
            for note in measure.getElementsByClass("Note"):
                onset = note._activeSiteStoredOffset
                offset = onset + note._duration.quarterLength

                start_idx = int(onset * self.num_parts_of_beat)
                end_idx = int(offset * self.num_parts_of_beat) + 1
                end_idx = end_idx if end_idx < 16 else 16

                num_item = int(end_idx - start_idx)
                m_notes[start_idx:end_idx] = [note] * num_item
            notes.extend(m_notes)

            m_chords = [None] * self.num_beats * self.num_parts_of_beat
            for chord in measure.getElementsByClass("ChordSymbol"):
                offset = chord.offset

                start_idx = int(offset * self.num_parts_of_beat)
                end_idx = int(self.num_beats * self.num_parts_of_beat) + 1
                end_idx = end_idx if end_idx < 16 else 16

                num_item = int(end_idx - start_idx)
                m_chords[start_idx:end_idx] = [chord] * num_item
            chords.extend(m_chords)

        return notes, chords

    def get_seq_notenum(self) -> np.ndarray:
        # NOTE: 0 is empty note number.
        seq_notenum = [
            int(n.pitch.midi) - self.min_note_num + 1 if n is not None else 0
            for n in self.notes
        ]
        return np.array(seq_notenum)

    def get_seq_note_onehot(self) -> np.ndarray:
        notenum = self.get_seq_notenum()
        
        num_note = self.max_note_num - self.min_note_num + 1
        seq_note_onehot = np.identity(num_note)[notenum]
        return seq_note_onehot

    def get_seq_chord_chorma(self) -> np.ndarray:
        onehot_chord_seq = np.zeros((len(self.chords), 12))
        for i, chord in enumerate(self.chords):
            if chord is not None:
                for note in chord._notes:
                    onehot_chord_seq[i, note.pitch.midi % 12] = 1
        return onehot_chord_seq



In [188]:
for f in files:
    feat = MusicXMLFeature(f)
    notes, chords = feat.get_notes_and_chords()
    print(len(notes))
    print(len(chords))
    print()

1024
1024

2096
2096

1552
1552



In [189]:
len(feat.score.parts[0].getElementsByClass("Measure")) * 16

1552

In [190]:
notes, chords = feat.get_notes_and_chords()
seq_notenum = feat.get_seq_notenum()
seq_note_onehot = feat.get_seq_note_onehot()
seq_chord_chroma = feat.get_seq_chord_chorma()

In [191]:
make_sequence(seq_chord_chroma, 128).shape

(12, 128, 12)

In [192]:
make_sequence(seq_notenum, 128).shape

(12, 128)

In [193]:
make_sequence(seq_note_onehot, 128).shape

(12, 128, 49)

In [194]:
notes[:20]

[None,
 None,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note E->,
 <music21.note.Note E->,
 <music21.note.Note C>,
 <music21.note.Note C>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note G>,
 None,
 <music21.note.Note G>,
 <music21.note.Note G>,
 <music21.note.Note F>,
 <music21.note.Note F>,
 <music21.note.Note A>,
 <music21.note.Note A>]

In [195]:
chords[:20]

[<music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Cm>,
 <music21.harmony.ChordSymbol Dm>,
 <music21.harmony.ChordSymbol Dm>,
 <music21.harmony.ChordSymbol Dm>,
 <music21.harmony.ChordSymbol Dm>]

In [196]:
seq_notenum[:10]

array([ 0,  0, 32, 32, 32, 32, 28, 28, 25, 25])

In [197]:
seq_note_onehot[:10].argmax(axis=1)

array([ 0,  0, 32, 32, 32, 32, 28, 28, 25, 25])

In [198]:
seq_note_onehot.shape

(1552, 49)

In [215]:
import torch
import torch.nn

loss_fn = nn.CrossEntropyLoss()

bs = 64
inputs = torch.from_numpy(seq_note_onehot[:bs].astype(np.float64))
targets = torch.from_numpy(seq_notenum[:bs].astype(np.int64))

In [216]:
loss_fn(inputs, targets)

tensor(2.9263, dtype=torch.float64)