In [21]:
import magenta
import os

In [68]:
import ast
import os
import time

# internal imports

import tensorflow as tf
import magenta

from magenta.models.drums_rnn import drums_rnn_config_flags
from magenta.models.drums_rnn import drums_rnn_model
from magenta.models.drums_rnn import drums_rnn_sequence_generator

from magenta.protobuf import generator_pb2
from magenta.protobuf import music_pb2

In [82]:
import pretty_midi
# binNotes() below consumes pretty_midi Note objects (using .start and .pitch).

# MIDI drum pitches used by the Magenta drum_kit_rnn model:
# {36, 38, 42, 45, 46, 48, 49, 50, 51}

# 36 C1 Bass Drum 1
# 38 D1 Acoustic Snare
# 42 F#1 Closed Hi Hat
# 45 A1 Low Tom
# 46 Bb1 Open Hi-Hat
# 48 C2 Hi Mid Tom
# 49 C#2 Crash Cymbal 1
# 50 D2 High Tom
# 51 Eb2 Ride Cymbal 1

# Drum pitches that pass through remap() unchanged (names in the table above).
humanNotes = [36, 45, 46]
machineNotes = [38, 48, 51]

# Pitches folded onto another pitch by remap(); -1 means "drop this hit".
mapping = {
    42: 51,   # closed hi-hat -> ride cymbal 1
    49: -1,   # crash cymbal 1 -> dropped
    50: 48,   # high tom -> hi-mid tom
}


def uniqueNotes(bins):
    """Return the set of distinct pitches appearing anywhere in ``bins``.

    Args:
        bins: list of per-step pitch lists, as produced by ``binNotes``.

    Returns:
        set of pitches (empty for empty input).
    """
    # A comprehension avoids sum(bins, []), which re-copies the accumulated
    # list at every step (quadratic in the number of bins).
    return {pitch for step in bins for pitch in step}


def fragment(bins, i, length=12):
    """Return the ``i``-th window of ``length`` consecutive bins.

    ``i`` wraps around modulo the number of *complete* windows in ``bins``,
    so any integer (including negatives) is accepted; a trailing partial
    window is never returned.

    Args:
        bins: list of per-step pitch lists.
        i: window index; wraps around.
        length: window size in steps (default 12, the length of the primer
            used in this notebook).

    Returns:
        A slice of ``bins`` containing exactly ``length`` entries.

    Raises:
        ValueError: if ``length`` is not positive, or ``bins`` holds fewer
            than ``length`` entries (no complete window exists).
    """
    if length < 1:
        raise ValueError('fragment length must be positive, got %r' % (length,))
    # Floor division keeps the window count an int on Python 3 too; with `/`
    # the modulo below would yield a float and the slice would fail.
    num_windows = len(bins) // length
    if num_windows == 0:
        raise ValueError('need at least %d bins for one fragment, got %d'
                         % (length, len(bins)))
    i = i % num_windows
    return bins[length * i:length * (i + 1)]


def remap(pitch):
    """Map a drum pitch onto the reduced kit.

    Pitches listed in ``humanNotes`` or ``machineNotes`` are returned
    unchanged; any other pitch is looked up in ``mapping``, where a result
    of -1 means the hit should be dropped.

    Raises:
        KeyError: if ``pitch`` is neither kept nor listed in ``mapping``.
    """
    if pitch in humanNotes or pitch in machineNotes:
        return pitch
    return mapping[pitch]


# NOTE: binNotes was previously defined twice in this cell (identical
# bodies, the second silently shadowing the first); this is the single copy.
def binNotes(notes, length = 0.125):
    """Quantize notes into fixed-width time bins of remapped pitches.

    Args:
        notes: sequence of note objects exposing ``.start`` (seconds) and
            ``.pitch`` -- e.g. pretty_midi notes. Need not be sorted.
        length: bin width in seconds. The 0.125 s default presumably equals
            one generator step (60 / qpm / steps_per_quarter) -- confirm.

    Returns:
        list of lists; entry k holds the remapped pitches whose onset falls in
        [k*length, (k+1)*length). Hits remapped to -1 are dropped. Empty
        ``notes`` yields ``[]``.

    Raises:
        KeyError: propagated from ``remap`` for an unmapped pitch.
    """
    if not notes:
        return []
    # Size from the latest onset rather than notes[-1] so unsorted input
    # cannot index past the end of `bins`.
    last_start = max(note.start for note in notes)
    bins = [[] for _ in range(int(last_start / length) + 1)]
    for note in notes:
        midi = remap(note.pitch)
        if midi != -1:
            bins[int(note.start / length)].append(midi)
    return bins

def getDistance(fragment, primer=None):
    """Count hits in ``fragment`` that the primer does not have.

    For each primer step, the distinct pitches played in ``fragment`` at that
    step but absent from the primer step are counted. Pitches the primer
    plays but ``fragment`` omits are NOT counted, so the score is asymmetric.
    Steps of ``fragment`` beyond the primer's length are ignored.

    Args:
        fragment: list of per-step pitch lists, at least as long as the primer
            (e.g. a window returned by ``fragment()``).
        primer: indexable sequence of per-step pitch collections; defaults to
            the notebook-global ``primer_drums`` (looked up at call time).

    Returns:
        int, the number of extra hits.

    Raises:
        IndexError: if ``fragment`` is shorter than the primer.
    """
    if primer is None:
        primer = primer_drums
    distance = 0
    for i in range(len(primer)):
        # Pitches struck in this fragment step that the primer step lacks.
        distance += len(frozenset(fragment[i]).difference(primer[i]))
    return distance

In [83]:
# Load the pre-trained drum-kit bundle and look up the model config it names.
# NOTE(review): expanduser is a no-op here (the path has no "~"), so the file
# is resolved relative to the notebook's working directory.
bundle_file = os.path.expanduser("drum_kit_rnn.mag")
bundle = magenta.music.read_bundle_file(bundle_file)


config_id = bundle.generator_details.id
config = drums_rnn_model.default_configs[config_id]

# Beam-search settings. NOTE(review): the generation cell re-assigns
# beam_size and branch_factor; only the values here feed the batch-size cap
# below, so keep both cells in sync.
beam_size = 1
branch_factor = 1

# Cap the model batch size at beam_size * branch_factor.
config.hparams.batch_size = min(
      config.hparams.batch_size, beam_size * branch_factor)

In [84]:
# Build the drum RNN sequence generator from the config and bundle loaded in
# the previous cell; its steps_per_quarter is reused for timing later on.
generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
  model=drums_rnn_model.DrumsRnnModel(config),
  details=config.details,
  steps_per_quarter=config.steps_per_quarter,
  bundle=bundle)



In [190]:
# Generation settings. NOTE(review): beam_size / branch_factor are also set in
# the model-loading cell, where they cap config.hparams.batch_size; keep the
# two cells in sync.
qpm = 120
num_steps = 12*10  # total track length in steps, including the primer
temperature = 0.5
branch_factor = 1
beam_size = 1
steps_per_iteration = 1
# 12-step primer; each tuple lists the drum pitches struck on that step.
primer = "[(36,45), (), (36,), (), (36,), (36,), (), (36,), (36,46,), (45,), (36,46,), ()]"

# Parse the primer *string*; passing primer_drums (the name being defined)
# fails on a fresh kernel.
primer_drums = magenta.music.DrumTrack(
    [frozenset(pitches)
     for pitches in ast.literal_eval(primer)])

primer_sequence = primer_drums.to_sequence(qpm=qpm)
seconds_per_step = 60.0 / qpm / generator.steps_per_quarter
total_seconds = num_steps * seconds_per_step
generator_options = generator_pb2.GeneratorOptions()

input_sequence = primer_sequence
last_end_time = (max(n.end_time for n in primer_sequence.notes)
                 if primer_sequence.notes else 0)

# Generate from one step after the primer ends up to the requested length.
generate_section = generator_options.generate_sections.add(
    start_time=last_end_time + seconds_per_step,
    end_time=total_seconds)

if generate_section.start_time >= generate_section.end_time:
    # tf.logging.fatal only logs; raise as well so the cell stops here instead
    # of scoring a stale (or undefined) `instrument` from an earlier run.
    tf.logging.fatal(
      'Priming sequence is longer than the total number of steps '
      'requested: Priming sequence length: %s, Generation length '
      'requested: %s',
      generate_section.start_time, total_seconds)
    raise ValueError(
        'Priming sequence length (%s s) exceeds requested generation '
        'length (%s s)' % (generate_section.start_time, total_seconds))

generator_options.args['temperature'].float_value = temperature
generator_options.args['beam_size'].int_value = beam_size
generator_options.args['branch_factor'].int_value = branch_factor
generator_options.args['steps_per_iteration'].int_value = steps_per_iteration

generated_sequence = generator.generate(input_sequence, generator_options)
generated_midi = magenta.music.sequence_proto_to_pretty_midi(generated_sequence)
# NOTE(review): index 1 presumably skips a leading instrument that carries no
# drum notes -- confirm against sequence_proto_to_pretty_midi.
instrument = generated_midi.instruments[1]

# Score primer-length windows of the generated track against the primer.
# fragment() wraps its index, so windows repeat once i passes their count.
bins = binNotes(instrument.notes)
for i in range(12):
    print(getDistance(fragment(bins, i, length=len(primer_drums))))

INFO:tensorflow:Beam search yields sequence with log-likelihood: -58.464188 
0
3
1
4
0
3
1
4
0
3
0
3


In [192]:
# Show the primer next to the second primer-length window of the generated bins.
print(primer_drums)
print(fragment(bins, 1))

frozenset([])
[[], [], [], [], [36], [], [36], [], [36], [36], [], [36]]


0
6
5
7
2
7
3
7
2
7
0
6


In [165]:
# Display the second 12-step window of the generated bins.
fragment(bins, i=1)

[[], [], [], [], [36], [], [36], [], [36], [36], [], [36]]