In [None]:
# Uncomment to install Magenta into this kernel's environment (consider pinning a version):
# %pip install magenta

In [None]:
import numpy as np
import magenta.music as mm
import magenta.music.midi_io as midi_io
from magenta.models.score2perf import score2perf
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import decoding
from tensor2tensor.data_generators import text_encoder

In [None]:
# Model / hparams registry names used to build the tensor2tensor estimator,
# plus the checkpoint that estimator.predict() restores from.
model_name, hparams_set = 'transformer', 'transformer_tpu'
# NOTE(review): the "16" in the checkpoint name presumably matches
# hparams.num_hidden_layers = 16 set further down — confirm.
ckpt_path = './melody_conditioned_model_16.ckpt/melody_conditioned_model_16.ckpt'

In [None]:
class MelodyToPianoPerformanceProblem(score2perf.AbsoluteMelody2PerfProblem):
    """Variant of score2perf.AbsoluteMelody2PerfProblem with add_eos_symbol fixed to True.

    Everything other than the `add_eos_symbol` property is inherited unchanged.
    """

    # Read-only property that always evaluates to True.
    add_eos_symbol = property(
        fget=lambda self: True,
        doc="Always True for this problem.",
    )


In [None]:
# Instantiate the melody-conditioned problem and fetch its feature encoders.
# The 'inputs' encoder is used below to encode the melody, and the 'targets'
# encoder to decode generated samples.
problem = MelodyToPianoPerformanceProblem()
melody_conditioned_encoders = problem.get_feature_encoders()

In [None]:
# Build hparams from the registered `hparams_set`, attach the problem's hparams,
# then override model settings (the overrides must come after
# add_problem_hparams so they are applied to the final object).
hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
trainer_lib.add_problem_hparams(hparams, problem)
# NOTE(review): presumably must match the layer count of the restored
# checkpoint (its filename also contains "16") — confirm.
hparams.num_hidden_layers = 16
# 'random' sampling — presumably draws tokens from the model distribution
# instead of taking the argmax. NOTE(review): no random seed is set in the
# visible cells, so generated samples are presumably not reproducible.
hparams.sampling_method = 'random'

In [None]:
# Decoding settings handed to the estimator: a single beam, and alpha
# (presumably the beam-search length-penalty exponent) zeroed out.
decode_hparams = decoding.decode_hparams()
decode_hparams.beam_size = 1
decode_hparams.alpha = 0.0

In [None]:
# Create the RunConfig from the hparams, then the tensor2tensor Estimator for
# `model_name`. decode_hparams is passed through, presumably so that
# estimator.predict() uses the decoding settings configured above.
run_config = trainer_lib.create_run_config(hparams)
estimator = trainer_lib.create_estimator(model_name, hparams, run_config,
                                         decode_hparams=decode_hparams)

In [None]:
# Placeholder globals that input_generator() reads on every iteration; later
# cells rebind `inputs` (the encoded melody) and `decode_length`.
inputs, decode_length = [], 0

In [None]:
# Create input generator.
def input_generator():
    """Endlessly yield feature dicts built from the current module globals.

    `inputs` and `decode_length` are looked up afresh on every iteration, so
    rebinding those globals between `next()` calls changes the next example.
    Neither name is assigned here, so no `global` declaration is needed.

    Yields:
        dict with keys, in order:
          'inputs':        int32 array np.array([[inputs]]) — shape
                           (1, 1, len(inputs)) assuming `inputs` is a flat
                           sequence of token ids.
          'targets':       empty int32 array of shape (1, 0).
          'decode_length': 0-d int32 array holding `decode_length`.
    """
    while True:
        features = {}
        features['inputs'] = np.array([[inputs]], dtype=np.int32)
        features['targets'] = np.zeros([1, 0], dtype=np.int32)
        features['decode_length'] = np.array(decode_length, dtype=np.int32)
        yield features


In [None]:
# Wrap the endless input_generator() as an Estimator input_fn and start
# prediction from the checkpoint; predict() yields samples consumed via next().
input_fn = decoding.make_input_fn_from_generator(input_generator())
melody_conditioned_samples = estimator.predict(input_fn, checkpoint_path=ckpt_path)
# Discard ("burn") the first sample: when the cells run top to bottom it is
# requested while `inputs` / `decode_length` still hold the placeholders [] / 0.
# NOTE(review): presumably this flushes an example that the input_fn builds
# eagerly from the generator — confirm against make_input_fn_from_generator.
# The trailing ';' keeps IPython from dumping the whole discarded output dict.
next(melody_conditioned_samples);

In [None]:
# Load the source MIDI and keep only the notes of the instrument returned by
# mm.infer_melody_for_sequence, to serve as the conditioning melody.
# NOTE(review): infer_melody_for_sequence may also modify melody_ns in place —
# confirm against the Magenta / note_seq docs.
melody_ns = mm.midi_file_to_note_sequence("chopin.mid")
melody_instrument = mm.infer_melody_for_sequence(melody_ns)
notes = [note for note in melody_ns.notes
         if note.instrument == melody_instrument]
# Replace the sequence's notes with just the melody notes, sorted by onset.
del melody_ns.notes[:]
melody_ns.notes.extend(sorted(notes, key=lambda note: note.start_time))
# Set each note's end_time to the next note's start_time so consecutive
# melody notes neither overlap nor leave gaps (the last note is untouched).
for i in range(len(melody_ns.notes) - 1):
    melody_ns.notes[i].end_time = melody_ns.notes[i + 1].start_time
# Encode the melody to token ids. This rebinds the `inputs` global that
# input_generator() reads when the next example is requested.
inputs = melody_conditioned_encoders['inputs'].encode_note_sequence(melody_ns)

In [None]:
# Decode a list of IDs.
def decode(ids, encoder):
    """Truncate `ids` at the first EOS token and decode the remainder.

    Args:
        ids: iterable of token ids (e.g. a sample's 'outputs' array).
        encoder: object exposing `decode(list_of_ids)`.

    Returns:
        Whatever `encoder.decode` returns for the ids preceding the first
        `text_encoder.EOS_ID`, or for all ids when no EOS is present.
    """
    token_ids = list(ids)
    eos_id = text_encoder.EOS_ID
    # Position of the first EOS, or the full length when there is none.
    cut = next(
        (pos for pos, tok in enumerate(token_ids) if tok == eos_id),
        len(token_ids),
    )
    return encoder.decode(token_ids[:cut])

In [None]:
# Allow up to 4096 decoding steps (presumably output tokens); input_generator()
# reads this global when the next example is requested, so it must be set
# before calling next().
decode_length = 4096
# Draw one sample conditioned on the current `inputs` melody.
sample_ids = next(melody_conditioned_samples)['outputs']
# Decode the ids (truncated at EOS) with the targets encoder; the result is a
# MIDI file path, which is loaded back as a NoteSequence.
# NOTE(review): if encoder.decode writes a temporary file, it is never removed
# here — confirm and clean up if so.
midi_filename = decode(sample_ids, encoder=melody_conditioned_encoders['targets'])
accompaniment_ns = mm.midi_file_to_note_sequence(midi_filename)
# Save the generated performance as output.mid in the current working directory.
mm.sequence_proto_to_midi_file(accompaniment_ns, "output.mid")

mm.plot_sequence(accompaniment_ns)
mm.play_sequence(accompaniment_ns)