# Environment

Now that we have a trainable discriminator, it's time to build the environment.

## Imports and Setup

In [None]:
## Imports and data loading

%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras import metrics

import gym

from musicrl.midi2vec import MidiVectorMapper, PostProcessor
from musicrl.render import *
from musicrl.random_generator import resemble_midi, repair_generated_seq
from musicrl.data import RandomMidiDataGenerator
from musicrl import mel_lstm

import pretty_midi
from glob import glob


# Labels for real vs. generated samples (presumably the discriminator targets)
REAL, GEN = 1, 0

# Shorthand for np.newaxis (which is None): x[na] prepends an axis of length 1
na = np.newaxis

In [None]:
# glob returns files in arbitrary (filesystem-dependent) order, but later cells
# pick songs by index (e.g. real_midis[1]), so sort for a reproducible order.
# Without recursive=True, '**' behaves exactly like '*', so spell it plainly.
filepaths = sorted(glob('maestro-v2.0.0/2008/*.midi'))
real_midis = [pretty_midi.PrettyMIDI(path) for path in filepaths]
mapper = MidiVectorMapper(real_midis)

In [None]:
# Rebuild the mapper (already built in the previous cell) -- NOTE(review):
# presumably so this cell can be re-run on its own, e.g. after autoreload
# picks up changes to MidiVectorMapper; confirm it is still needed.
mapper = MidiVectorMapper(real_midis)
# Encode one real song into the mapper's vector representation
real_seq = mapper.midi2vec(real_midis[1])
# Display the shape of the encoded sequence
real_seq.shape

In [None]:
# Display the length of one action vector in the mapper's vector space
mapper.dims

## Environment

In [None]:
# import gym
import pretty_midi


class MelEnvironment(gym.Env):
    """Environment to train generating midi data in a self defined vector space.

    The midi vector representation is defined via the ``mapper`` object.
    Every action is appended to the current song. Whenever an action produces
    a note, the single instrument is re-synthesized and the waveform is
    preprocessed for the discriminator (e.g. into a mel spectrogram). The
    discriminator's prediction for the last time frame then serves as reward.
    The last ``N_TIMESTEPS`` frames of the preprocessed waveform serve as
    observation; they are zero-padded at the front while the song is short.
    One episode is understood as one song.

    Works on a single song/trajectory at a time: the discriminator is called
    with a batch of size 1.

    Gets:
        discriminator: keras.Model: np.array(1, #time_frames, 128) -> np.array(1, #time_frames, 1)
        preprocess_wav: callable (waveform, sample_rate) -> np.array(#time_frames, 128)
        mapper: musicrl.midi2vec.MidiVectorMapper
        N_TIMESTEPS: int: number of time frames used to build the observation
        MAX_NUM_ACTIONS: int: number of actions after which a trajectory ends
    """

    def __init__(self, discriminator, preprocess_wav, mapper, N_TIMESTEPS=100, MAX_NUM_ACTIONS=10000):
        super().__init__()
        # N_TIMESTEPS defines the observation: this many time frames of the
        # spectrogram are fed back to the generator
        self.N_TIMESTEPS = N_TIMESTEPS
        self.MAX_NUM_ACTIONS = MAX_NUM_ACTIONS
        # Action and observation space; they must be gym.spaces objects
        self.action_space = gym.spaces.Box(0, np.inf, shape=(mapper.dims,))
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf,
                                                shape=(self.N_TIMESTEPS, 128), dtype=np.float32)
        self.discriminator = discriminator
        self.preprocess_wav = preprocess_wav
        self.mapper = mapper
        # Sample rate used both for synthesis and for preprocessing
        self.fr = 44100
        self._start_new_song()

    def _empty_observation(self):
        """Return an all-zero observation matching ``observation_space``."""
        return np.zeros((self.N_TIMESTEPS, 128), dtype=np.float32)

    def _start_new_song(self):
        """Clear all per-song state and return the initial observation."""
        self.current_seq = []
        self.current_midi = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=300)
        self.current_midi.instruments.append(pretty_midi.Instrument(program=0))
        self.current_observation = self._empty_observation()
        self.current_prediction = None
        self.rewards = []
        return self.current_observation

    def _update_wav(self, action):
        """Append the action to current_seq and synthesize the sound if a note was played.

        Gets:
            action: np.array of shape (mapper.dims,)
        Returns:
            updated: bool, True iff a note was played
        """
        self.current_seq.append(action)
        event = self.mapper.action2note(action, start=len(self.current_seq) * self.mapper.time_per_tick)
        if not isinstance(event, pretty_midi.Note):
            return False
        instrument = self.current_midi.instruments[0]
        if not instrument.notes:
            # First note of the song: nothing has been synthesized yet, so
            # synthesize from scratch
            instrument.notes.append(event)
            instrument.synthesize(self.fr)
        else:
            instrument.append_and_synthesize(event)
        return True

    def step(self, action):
        """Play ``action`` and score the song so far.

        Returns the gym 4-tuple ``(observation, reward, done, info)``. If the
        action produced no note, the waveform is unchanged, so the previous
        observation is returned with reward 0.
        """
        note_played = self._update_wav(action)
        # TODO: add a end token to mapper (issue #1)
        out_of_actions = len(self.current_seq) >= self.MAX_NUM_ACTIONS
        if not note_played:
            # action[4] acts as an end-of-song flag; the action budget must
            # also end silent trajectories, otherwise they never terminate
            done = bool(action[4] > 0.5) or out_of_actions
            return self.current_observation, 0, done, {}
        preprocessed = self.preprocess_wav(self.current_wav, self.fr)[np.newaxis]
        prediction = self.discriminator.predict_on_batch(preprocessed)
        observation = self._empty_observation()
        # Right-align the newest frames. Guard n_frames == 0 because
        # observation[-0:] would address the whole array.
        n_frames = min(self.N_TIMESTEPS, preprocessed.shape[1])
        if n_frames:
            observation[-n_frames:] = preprocessed[0, -n_frames:]
        self.current_observation = observation
        self.current_prediction = prediction
        # Reward: discriminator output for the last time frame of the song
        reward = prediction[0, -1, 0]
        self.rewards.append(reward)
        return observation, reward, out_of_actions, {}

    def reset(self):
        """Start a new, empty song and return the all-zero observation."""
        return self._start_new_song()

    @property
    def current_wav(self):
        """Waveform synthesized so far for the single instrument."""
        return self.current_midi.instruments[0].synthesized

    def render(self, mode='human'):
        plot_spectro(self.current_observation.T, "Current observation")

    def close(self):
        pass
        

Let's exercise the environment by replaying the actions of a real MIDI sequence, just to check that everything works as expected.

In [None]:
discriminator = load_model("models/mel_lstm.h5")

# Observations are the last 1000 spectrogram frames (N_TIMESTEPS=1000)
env = MelEnvironment(discriminator, mel_lstm.preprocess_wav, mapper, 1000)

# Replay the actions of the real song through the environment. Stop after
# n_steps actions; the generator experiment below uses at most this many
# observations. Note the check must compare i + 1 as a whole: `i+1 % 10000`
# would parse as `i + (1 % 10000)`.
n_steps = 10000
observations = []
for i, action in enumerate(real_seq):
    observation, _, _, _ = env.step(action)
    observations.append(observation)
    if i + 1 == n_steps:
        env.render()
        plt.show()
        break

In [None]:
# Plot the most recent observation (spectrogram frames) of the environment
env.render()

In [None]:
# Play the waveform synthesized by the environment, at the same sample rate it
# was synthesized with. NOTE(review): Audio is not imported explicitly in this
# notebook; presumably it comes from `from musicrl.render import *` or should
# be `from IPython.display import Audio` -- confirm.
Audio(env.current_wav, rate=env.fr)

In [None]:
# Convert the action sequence replayed through the environment back to MIDI
# and hand it to listen_to (presumably an audio player widget)
again = mapper.vec2midi(env.current_seq)
listen_to(again)

In [None]:
# Reference: the real song whose actions were replayed above
listen_to(real_midis[1])

In [None]:
# env.rewards only receives an entry when a step actually played a note, so
# the x-axis counts notes, not actions
plt.plot(env.rewards)
plt.xlabel("Note index")
plt.ylabel("Discriminator reward")
plt.show()

In [None]:
# Re-encode the real song. NOTE(review): this repeats an earlier cell,
# presumably to get a fresh copy in case the replay modified it; confirm
# whether it is still needed.
real_seq = mapper.midi2vec(real_midis[1])

# Generator

In order to implement the actor, we need a generator that takes observations and outputs an action. Before we build the pieces for the reinforcement learning training loop, we will try to find an architecture that takes realistic inputs and generates something playable.

I have implemented `midi2vec.PostProcessor` so that it transforms the output of an untrained LSTM into something that is not just silence. Otherwise the actor would never press a key, never get any reward, and never learn anything. A drama that we need to avoid!

In [None]:

# Untrained generator: maps one flattened observation (N_TIMESTEPS x 128
# spectrogram frames) to one action vector. A stateful LSTM needs a fixed
# batch size, so both the batch size and the feature size are derived from
# the recorded observations instead of being hard-coded. The replay may stop
# before 10000 actions if the song is shorter.
observations = np.array(observations[:10000])
n_samples = len(observations)
states = observations.reshape(n_samples, 1, -1)

model = Sequential()
model.add(LSTM(128,
               return_sequences=False,
               batch_input_shape=(n_samples, 1, states.shape[-1]),
               stateful=True))
model.add(Dense(128, activation='sigmoid'))
model.add(Dense(mapper.dims, activation='relu'))
# NOTE: the model is not trained in this cell; compile() only prepares it
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[metrics.binary_accuracy])

# Postprocessor built from a few real songs, turning raw network output into
# something playable
postprocess = PostProcessor([mapper.midi2vec(real_midi) for real_midi in real_midis[:5]])

# Feed all samples as one batch so it matches the stateful batch size
gen_seq = model.predict(states, batch_size=n_samples)
gen_seq = postprocess(gen_seq)
gen_midi = mapper.vec2midi(gen_seq)
listen_to(gen_midi)

## Random generator on new `MidiVectorMapper`

This is a small leftover from developing the postprocessor, and a demonstration that it is no longer hard to generate something: Gaussian noise almost does the trick.

In [None]:
# Gaussian noise shaped like a sequence of 15000 action vectors, fed through
# the same postprocessor as the generator output
noise = np.random.normal(0.2, 1, size=(15000, mapper.dims))
# Zero column 4: MelEnvironment.step treats action[4] > 0.5 as end-of-song
noise[:, 4] = 0
rand_seq = postprocess(noise)
rand_midi = mapper.vec2midi(rand_seq)
listen_to(rand_midi)