# Imports and midi loading

Now we have a trainable discriminator - it's time to build the environment

TODO:
    - discriminator into env
        - reverse postprocessing to set range for actions
        - use discriminator prediction as reward
        - train discriminator after trajectory
        - make it more efficient by making discriminator stateful and always feed a single time step
        
    - models => lstm
   

In [1]:
## Imports and data loading

%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras import metrics

from musicrl.midi2vec import MidiVectorMapper
from musicrl.render import *
from musicrl.data import RandomMidiDataGenerator

import pretty_midi
from glob import glob

import pprint
pprint = pprint.PrettyPrinter(indent=4).pprint

pprint(discriminator.layers[0]._stateful)
pprint(discriminator.trainable)


REAL = 1
GEN = 0

In [2]:
filepaths = list(glob('maestro-v2.0.0/2008/**.midi'))
real_midis = [pretty_midi.PrettyMIDI(i) for i in filepaths]
mapper = MidiVectorMapper(real_midis)

In [3]:
mapper = MidiVectorMapper(real_midis)
real_seq = mapper.midi2vec(real_midis[1])
real_seq.shape

(60867, 5)

# Load the discriminator

In [34]:
discriminator = load_model("models/seq_lstm.h5")

In [38]:
print(discriminator.inputs)
discriminator.summary()

[<tf.Tensor 'lstm_3_input:0' shape=(None, None, 5) dtype=float32>]
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_3 (LSTM)                (None, None, 128)         68608     
_________________________________________________________________
time_distributed_6 (TimeDist (None, None, 128)         16512     
_________________________________________________________________
time_distributed_7 (TimeDist (None, None, 1)           129       
Total params: 85,249
Trainable params: 85,249
Non-trainable params: 0
_________________________________________________________________


In [37]:
mapper.dims

5

# Environment

In [109]:
import gym
import pretty_midi


class SeqEnvironment(gym.Env):
    """We ignore control change events for now
    """
    def __init__(
        self,
        discriminator,
        mapper,
        change_rate=0.01,
        batch_size=32
    ):
        super().__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        # self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
        # Example for using image as input:
        # self.observation_space = spaces.Box(low=0, high=255,
        #                                    shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
        self.discriminator = discriminator
        self.mapper = mapper
        self.observation_shape = [mapper.dims]
        self.change_rate = change_rate
        self.batch_size = batch_size
        self.reset()
        
    def step(self, actions):
        assert len(actions)==self.batch_size, f"Expected batch_size of {self.batch_size}"
        self.current_seqs.append(actions)
        bs = len(actions)
        self.observations  = self.observations \
            + np.random.normal(0, self.change_rate, size=[bs] + self.observation_shape)
        self.done = self.done | self.mapper.is_done(actions)
        reward = self.discriminator.predict(
            np.transpose(self.current_seqs, [1, 0, 2])
        )[:,-1,:]
        return np.array(self.observations), reward, self.done, None
    
    def reset(self):
        self.current_seqs = []
        self.done = np.array([False]*self.batch_size)
        self.observations = np.random.normal(0, 1, size=[self.batch_size] + self.observation_shape)
        return self.observations  # reward, done, info can't be included

    def render(self, mode='human'):
        pass
    
    def close (self):
        pass

# Training Loop

In [111]:
from musicrl.agent import *
from musicrl.models import *
from tqdm import *
import pdb

batch_size = 4
agent = DDPG(mapper.dims,mapper.dims,act_range=3)
env = SeqEnvironment(discriminator, mapper, batch_size=batch_size)


# First, gather experience
config = {
    "nb_episodes" : 1
}

print(range(config["nb_episodes"]))

tqdm_e = tqdm(range(config["nb_episodes"]), desc='Score', leave=True, unit=" episodes")
for e in tqdm_e:

    # Reset episode
    cumul_reward, done = np.zeros((batch_size, 1)), 0
    old_state = env.reset()
    actions, states, rewards = [], [], []
    
    step = 0
    while not np.mean(done)>0.5:
        print("Step", step)
        step += 1
        env.render()
        # Actor picks an action (following the deterministic policy)   
        
        actions = agent.policy_action(old_state)
        states, rewards, done, _ = env.step(actions) #new_states -> bs of new_state
        q_values = agent.critic.target_model.predict([states, agent.actor.target_model.predict(states)])
        # Compute critic target

        critic_target = agent.bellman(rewards, q_values, dones)
    
        # Train both networks on sampled batch, update target networks
        agent.update_models(states, actions, critic_target)
        # Update current state
        cumul_reward += rewards
    postprocess_and_synthesize(np.array([seq[0] for seq in env.current_seqs]))

    # Display score
    tqdm_e.set_description("Score: " + str(cumul_reward))
    tqdm_e.refresh()

Score:   0%|          | 0/1 [00:00<?, ? episodes/s]

range(0, 1)
Step 0
Step 1
Step 2
Step 3
Step 4
Step 5
Step 6
Step 7
Step 8
Step 9
Step 10
Step 11
Step 12
Step 13
Step 14
Step 15
Step 16
Step 17
Step 18
Step 19
Step 20
Step 21
Step 22
Step 23
Step 24
Step 25
Step 26
Step 27
Step 28
Step 29
Step 30
Step 31
Step 32
Step 33
Step 34
Step 35
Step 36
Step 37
Step 38
Step 39
Step 40
Step 41
Step 42
Step 43
Step 44
Step 45
Step 46
Step 47
Step 48
Step 49
Step 50
Step 51
Step 52
Step 53
Step 54
Step 55
Step 56
Step 57
Step 58
Step 59
Step 60
Step 61
Step 62
Step 63
Step 64
Step 65
Step 66
Step 67
Step 68
Step 69
Step 70
Step 71
Step 72
Step 73
Step 74
Step 75
Step 76
Step 77
Step 78
Step 79
Step 80
Step 81
Step 82
Step 83
Step 84
Step 85
Step 86
Step 87
Step 88
Step 89
Step 90
Step 91
Step 92
Step 93
Step 94
Step 95
Step 96
Step 97
Step 98
Step 99
Step 100
Step 101
Step 102
Step 103
Step 104
Step 105
Step 106
Step 107
Step 108
Step 109
Step 110
Step 111
Step 112
Step 113
Step 114
Step 115
Step 116
Step 117
Step 118
Step 119
Step 120
Step 121


Step 922
Step 923
Step 924
Step 925
Step 926
Step 927
Step 928
Step 929
Step 930
Step 931
Step 932
Step 933
Step 934
Step 935
Step 936
Step 937
Step 938
Step 939
Step 940
Step 941
Step 942
Step 943
Step 944
Step 945
Step 946
Step 947
Step 948
Step 949
Step 950
Step 951
Step 952
Step 953
Step 954
Step 955
Step 956
Step 957
Step 958
Step 959
Step 960
Step 961
Step 962
Step 963
Step 964
Step 965
Step 966
Step 967
Step 968
Step 969
Step 970
Step 971
Step 972
Step 973
Step 974
Step 975
Step 976
Step 977
Step 978
Step 979
Step 980
Step 981
Step 982
Step 983
Step 984
Step 985
Step 986
Step 987
Step 988
Step 989
Step 990
Step 991
Step 992
Step 993
Step 994
Step 995
Step 996
Step 997
Step 998
Step 999
Step 1000
Step 1001
Step 1002
Step 1003
Step 1004
Step 1005
Step 1006
Step 1007
Step 1008
Step 1009
Step 1010
Step 1011


Step 1012
Step 1013
Step 1014
Step 1015
Step 1016
Step 1017
Step 1018
Step 1019
Step 1020
Step 1021
Step 1022
Step 1023
Step 1024
Step 1025


Step 1026
Step 1027
Step 1028
Step 1029
Step 1030
Step 1031
Step 1032
Step 1033
Step 1034
Step 1035
Step 1036
Step 1037
Step 1038
Step 1039


Step 1040
Step 1041
Step 1042
Step 1043
Step 1044
Step 1045
Step 1046
Step 1047
Step 1048
Step 1049
Step 1050
Step 1051
Step 1052
Step 1053


Step 1054
Step 1055
Step 1056
Step 1057
Step 1058
Step 1059
Step 1060
Step 1061
Step 1062
Step 1063
Step 1064
Step 1065
Step 1066
Step 1067


Step 1068
Step 1069
Step 1070
Step 1071
Step 1072
Step 1073
Step 1074
Step 1075
Step 1076
Step 1077
Step 1078
Step 1079
Step 1080
Step 1081


Step 1082
Step 1083
Step 1084
Step 1085
Step 1086


Score:   0%|          | 0/1 [04:45<?, ? episodes/s]


KeyboardInterrupt: 

In [78]:
np.array(env.current_seqs).shape

(996, 4, 5)

# Listen to it

In [107]:
from musicrl.midi2vec import PostProcessor

def postprocess_and_synthesize(gen_seq):
    postprocess = PostProcessor([mapper.midi2vec(real_midi) for real_midi in real_midis[:5]])
    gen_seq[:,0] += 0.4
    gen_seq = postprocess(gen_seq)
    gen_midi = mapper.vec2midi(gen_seq)
    listen_to(gen_midi)

In [112]:
for i in range(env.batch_size):
    postprocess_and_synthesize(np.array([seq[i] for seq in env.current_seqs]))

In [101]:
gen_midi.instruments[0].notes

[]

In [92]:
mapper.__dict__

{'time_per_tick': 0.00130208,
 'dims': 5,
 'column_meaning': ['is_note',
  'note_pitch',
  'note_velocity',
  'note_duration',
  'is_end'],
 'no_sound': array([0., 0., 0., 0., 0.])}

In [106]:
0.00130208*len(gen_seq)

1.29687168