In [None]:
import numpy as np
import tensorflow as tf

import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
from magenta.music.sequences_lib import concatenate_sequences

from utils import strip_to_melody, remove_melody
from get_model import get_model
from keras.callbacks import ModelCheckpoint

%load_ext autoreload
%autoreload 2

In [None]:
def inference(midi_input, melody_model, trio_model, config,
              use_original_melody=True, dim_melody=None):
    '''
        Generate a three-part arrangement from the melody of a MIDI file.

        The MIDI is converted to trio tensors, the melody slice of every tensor
        is passed to `melody_model` to predict a latent code, and `trio_model`
        decodes those codes into note sequences that are concatenated into one
        longer piece.

        midi_input:   the read in midi file (i.e. midi_file.read())
        melody_model: Our model that takes a melody and predicts the latent space encoding
        trio_model:   The magenta trio model that we are using to decode the latent vector
        config:       Used for data conversions.
                      Can get by e.g. config = configs.CONFIG_MAP[model_name]
                      Might be able to just use the config in trio_model? Unsure.
        use_original_melody:
                      If True (default), the melody of the generated trio is
                      removed and the melody of the input is merged in instead.
                      If False, the generated trio is returned untouched.
        dim_melody:   Number of leading features (last axis) of each trio tensor
                      that make up the melody. Defaults to the size of the last
                      axis of `melody_model.input_shape`, i.e. whatever feature
                      depth the melody model was built to consume.

        Returns the resulting NoteSequence.

        Raises ValueError if the converter yields no tensors for the input, or
        if `dim_melody` is not given and cannot be read from `melody_model`.
    '''
    # Convert the midi to a NoteSequence
    input_note_seq = mm.midi_to_sequence_proto(midi_input)

    # Convert the sequence to trio tensors.
    trio_tensors = config.data_converter.to_tensors(input_note_seq).outputs
    if len(trio_tensors) == 0:
        # Fail early with a readable message instead of an opaque shape error
        # from the model further down.
        raise ValueError(
            'The data converter produced no tensors for this MIDI input; '
            'nothing to run inference on.')

    if dim_melody is None:
        # The melody slice is fed straight into melody_model.predict, so its
        # width has to equal the model's input feature depth.
        # NOTE(review): assumes a single-input model - confirm against get_model().
        dim_melody = melody_model.input_shape[-1]
        if dim_melody is None:
            raise ValueError(
                'Could not infer the melody depth from melody_model.input_shape; '
                'pass dim_melody explicitly.')

    # TODO: decide whether the tensors must be filtered to one fixed shape
    # before stacking; tensors of differing shapes will not stack into a
    # regular array.
    melody_tensors = np.array([tensor[:, :dim_melody] for tensor in trio_tensors])

    # Get the latent representation of just the melody using our trained model
    latent_code = melody_model.predict(melody_tensors)

    # Decode the latent representation of the melody into 3 parts using Trio
    # Note that this returns an array of different, related musical sections.
    # We use concatenate_sequences to combine them all into one longer piece.
    output_trio_seq = concatenate_sequences(trio_model.decode(latent_code))

    if not use_original_melody:
        return output_trio_seq

    # Splice in the original melody: keep only the generated accompaniment...
    accompaniment_seq = remove_melody(output_trio_seq)

    # ...and merge the melody of the input into it. MergeFrom mutates
    # accompaniment_seq in place.
    # NOTE(review): MergeFrom merges every field of the message, not just the
    # notes - confirm tempo/timing of the two sequences line up.
    original_melody_seq = strip_to_melody(input_note_seq)
    accompaniment_seq.MergeFrom(original_melody_seq)
    return accompaniment_seq
    

# Testing Inference

## Parameter Setup 

### Trio Model

In [None]:
# Candidate MusicVAE configuration names; only the trio one is used below.
model_name_melody_2bar = 'cat-mel_2bar_big'
model_name_melody_16bar = 'hierdec-mel_16bar'
model_name_trio_16bar = 'hierdec-trio_16bar'

# Select the configuration that drives both data conversion and decoding.
model_name = model_name_trio_16bar
config = configs.CONFIG_MAP[model_name]

# Restore the pretrained model from the local checkpoint named after it.
trio_model = TrainedModel(
    config,
    checkpoint_dir_or_path='./models/pretrained/{}.ckpt'.format(model_name),
    batch_size=16,
)

### MIDI input

In [None]:
# Reset first so a failed read cannot leave a stale value from an earlier run.
midi = None
with open('./data/lmd_clean/Toto/Africa.3.mid', mode='rb') as midi_file:
    midi = midi_file.read()

# Parse the raw bytes into a NoteSequence; it is played back further down
# for comparison with the generated result.
input_note_seq = mm.midi_to_sequence_proto(midi)

### Load the Melody Model

In [None]:
# Build the melody model via the project's get_model() helper, then restore
# trained weights from a saved checkpoint file.
# NOTE(review): the architecture returned by get_model() must match the one
# the checkpoint was saved from - confirm when changing either.
melody_model = get_model()
melody_model.load_weights("./models/checkpoints/bi_rnn_test/01-0.0682.hdf5")

## Running inference

In [None]:
# Run the full pipeline on the raw MIDI bytes; use_original_melody keeps its
# default, so the input melody is merged with the generated accompaniment.
generated_seq = inference(
    midi_input=midi,
    melody_model=melody_model,
    trio_model=trio_model,
    config=config,
)

In [None]:
# Play back the original input sequence for reference, using the fluidsynth synth.
mm.play_sequence(input_note_seq, synth=mm.fluidsynth)

In [None]:
# Play back the sequence returned by inference() to compare against the input.
mm.play_sequence(generated_seq, synth=mm.fluidsynth)

In [None]:
# Display the parsed input NoteSequence as the cell output.
# NOTE(review): this prints the whole message, which can be very long for a
# full song - consider showing a summary instead before sharing the notebook.
input_note_seq