# Environment

Now that we have a trainable discriminator, it's time to build the `gym` environment.

In [25]:
## Imports and data loading

%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras import metrics

from musicrl.midi2vec import MidiVectorMapper
from musicrl.render import *
from musicrl.random_generator import generate_random_midi, resemble_midi
from musicrl.data import RandomMidiDataGenerator

import pretty_midi
from glob import glob


# Integer class labels: REAL for real MIDI data, GEN for generated data
# (presumably the discriminator's training targets — confirm in training code).
REAL, GEN = 1, 0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Sort the paths: glob() returns them in filesystem-dependent order, and later
# cells select files by index (e.g. real_midis[1]), so an unsorted list makes
# the notebook non-reproducible across machines.
# Without recursive=True, "**" behaves exactly like "*", so only .midi files
# directly inside the 2008 folder are matched; the pattern now says so.
filepaths = sorted(glob('maestro-v2.0.0/2008/*.midi'))
real_midis = [pretty_midi.PrettyMIDI(path) for path in filepaths]
mapper = MidiVectorMapper(real_midis)

In [None]:
TODO
- vec2note
- append

In [4]:
# Reuse the mapper built together with real_midis in the loading cell;
# rebuilding it here repeated that work on every run. (%autoreload 2 already
# upgrades methods of existing instances when musicrl is edited.)
real_seq = mapper.midi2vec(real_midis[1])
real_seq.shape

(2939, 9)

In [27]:
# Keep only the event vectors that the mapper decodes into a pretty_midi.Note;
# every other event type is dropped.
notes = [
    vec
    for vec in real_seq
    if isinstance(mapper.action2note(vec), pretty_midi.Note)
]

# Show the first component of each kept note vector.
[vec[0] for vec in notes]

[0.9895833333333333,
 1.0520833333333333,
 1.12109375,
 1.1744791666666665,
 1.234375,
 1.3072916666666665,
 1.4075520833333333,
 1.45703125,
 1.5377604166666665,
 1.5950520833333333,
 1.6705729166666665,
 1.7486979166666665,
 1.8098958333333333,
 1.8815104166666665,
 1.9244791666666665,
 1.9934895833333333,
 2.0768229166666665,
 2.140625,
 2.2096354166666665,
 2.28515625,
 2.3580729166666665,
 2.41015625,
 2.484375,
 2.548177083333333,
 2.6354166666666665,
 2.7265625,
 2.7682291666666665,
 2.8515625,
 2.9088541666666665,
 2.984375,
 3.05859375,
 3.140625,
 3.2018229166666665,
 3.2604166666666665,
 3.34765625,
 3.415364583333333,
 3.4518229166666665,
 3.58984375,
 3.68359375,
 3.765625,
 3.8346354166666665,
 3.891927083333333,
 3.984375,
 4.049479166666666,
 4.135416666666666,
 4.192708333333333,
 4.276041666666666,
 4.319010416666666,
 4.412760416666666,
 4.479166666666666,
 4.546875,
 4.5859375,
 4.6953125,
 4.701822916666666,
 4.79296875,
 4.881510416666666,
 4.923177083333333,
 5.0

In [32]:
import gym
import pretty_midi


class MelEnvironment(gym.Env):
    """Gym environment that builds a melody one MIDI event at a time.

    Each action is one event vector (as produced by ``mapper.midi2vec``).
    Every action is appended to ``current_seq``. Only events that the mapper
    decodes into a ``pretty_midi.Note`` are added to ``current_midi`` and
    synthesized; other events (e.g. control changes) are ignored for now.

    TODO: define ``self.action_space`` / ``self.observation_space``
    (gym.spaces objects) once the event-vector bounds are fixed.
    """

    def __init__(self, discriminator, mapper, fr=44100, max_steps=None):
        """
        Args:
            discriminator: model used to score the generated sequence.
                NOTE(review): assumed to expose a Keras-style ``predict``.
            mapper: object providing ``action2note`` to decode actions.
            fr: audio sample rate (Hz) passed to synthesis.
            max_steps: optional episode length; ``None`` means ``step``
                never reports ``done`` on its own.
        """
        super().__init__()
        self.discriminator = discriminator
        self.mapper = mapper
        self.fr = fr
        self.max_steps = max_steps
        self._start_episode()

    def _start_episode(self):
        """Clear the event sequence and create an empty one-instrument MIDI."""
        self.current_seq = []
        self.current_midi = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=300)
        self.current_midi.instruments.append(pretty_midi.Instrument(program=0))

    def _observation(self):
        """Return the events chosen so far, stacked into a NumPy array."""
        return np.array(self.current_seq)

    def _reward(self, observation):
        """Score ``observation`` with the discriminator.

        NOTE(review): assumes a Keras-style model that takes a
        (batch, time, features) array; the last value of its output is used,
        which also covers per-timestep outputs. Confirm against the
        discriminator definition.
        """
        prediction = self.discriminator.predict(observation[np.newaxis])
        return float(np.ravel(prediction)[-1])

    def step(self, action):
        """Append ``action`` to the episode.

        Returns:
            (observation, reward, done, info) in the classic gym format.
        """
        self.current_seq.append(action)
        # Use the mapper this environment was built with, not a notebook global.
        event = self.mapper.action2note(action)
        if isinstance(event, pretty_midi.Note):
            instrument = self.current_midi.instruments[0]
            if not instrument.notes:
                # First note: nothing has been synthesized yet, so add it
                # and synthesize the instrument from scratch.
                instrument.notes.append(event)
                instrument.synthesize(self.fr)
            else:
                # append_and_synthesize is called on the Instrument (as in the
                # demo cell); presumably provided by musicrl.render — confirm.
                instrument.append_and_synthesize(event)

        observation = self._observation()
        reward = self._reward(observation)
        done = self.max_steps is not None and len(self.current_seq) >= self.max_steps
        info = {}
        return observation, reward, done, info

    def reset(self):
        """Start a new episode and return the initial (empty) observation."""
        self._start_episode()
        return self._observation()  # reward, done, info can't be included

    def render(self, mode='human'):
        pass

    def close(self):
        pass

In [31]:
# Rebuild the extracted notes incrementally and play the result:
# synthesize from the first note, then extend the waveform note by note.
fr = 44100  # audio sample rate (Hz) for synthesis and playback
first_action = notes[0]
seq = [first_action]
midi = mapper.vec2midi(np.array(seq))
wav = midi.synthesize(fr)

instrument = midi.instruments[0]
for action in notes[1:]:
    note = mapper.action2note(action)
    instrument.append_and_synthesize(note)

display(Audio(data=instrument.synthesized, rate=fr))
