In [2]:
import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

from func.midi_to_statematrix import *
from func.data import *
import func.multi_training

import random

import os
#import cPickle as pickle
import pickle

import signal

import numpy as np

path = os.getcwd()


In [3]:
pcs = func.multi_training.loadPieces(path  + '/music_test')

Loaded alb_esp2
Loaded alb_esp5
Loaded appass_2
Loaded appass_3
Loaded bach_846
Loaded bach_847
Loaded bach_850
Loaded beethoven_hammerklavier_1
Loaded beethoven_les_adieux_1
Loaded beethoven_les_adieux_2
Loaded beethoven_opus10_2
Loaded beethoven_opus10_3
Loaded beethoven_opus22_1
Loaded beethoven_opus22_4
Loaded beethoven_opus90_2


In [5]:
batch_width = 10 # number of sequences in a batch
batch_len = 16*8 # length of each sequence
division_len = 16 # interval between possible start locations

def loadPieces(dirpath):

    pieces = {}

    for fname in os.listdir(dirpath):
        if fname[-4:] not in ('.mid','.MID'):
            continue

        name = fname[:-4]

        outMatrix = midiToNoteStateMatrix(os.path.join(dirpath, fname))
        if len(outMatrix) < batch_len:
            continue

        pieces[name] = outMatrix
        print("Loaded {}".format(name))

    return pieces

def getPieceSegment(pieces):
    pcs=pieces.values()
    piece_output = random.choice(list(pcs))
    start = random.randrange(0,len(piece_output)-batch_len,division_len)
    
    # print "Range is {} {} {} -> {}".format(0,len(piece_output)-batch_len,division_len, start)

    seg_out = piece_output[start:start+batch_len]
    seg_in = noteStateMatrixToInputForm(seg_out)

    return seg_in, seg_out

def getPieceBatch(pieces):
    i,o = zip(*[getPieceSegment(pieces) for _ in range(batch_width)])
    return numpy.array(i), numpy.array(o)

In [6]:
def trainPiece(model,pieces,epochs,start=0):
    stopflag = [False]
    def signal_handler(signame, sf):
        stopflag[0] = True
    old_handler = signal.signal(signal.SIGINT, signal_handler)
    for i in range(start,start+epochs):
        if stopflag[0]:
            break
        error = model.update_fun(*getPieceBatch(pieces))
        if i % 100 == 0:
            print("epoch {}, error={}".format(i,error))
        if i % 500 == 0 or (i % 100 == 0 and i < 1000):
            xIpt, xOpt = map(numpy.array, getPieceSegment(pieces))
            noteStateMatrixToMidi(numpy.concatenate((numpy.expand_dims(xOpt[0], 0), model.predict_fun(batch_len, 1, xIpt[0])), axis=0),'output/sample{}'.format(i))
            pickle.dump(model.learned_config,open('output/params{}.p'.format(i), 'wb'))
    signal.signal(signal.SIGINT, old_handler)

In [7]:
song={}
song['beethoven_hammerklavier_1']=pcs['beethoven_hammerklavier_1']

In [11]:
def Model(t_layer_sizes,p_layer_sizes,xs,ys):

    t_input_size = 80

    

            #Lstm input data recquires size : batch_size,max_time (spanning back how many time steps), ect..
    
            #xs = tf.one_hot(xss, depth=1000, axis=-1)
            #xs_onehot = tf.one_hot(xs, depth=1000, axis=-1)

            # From our architecture definition, size of the notewise input

            # time network maps from notewise input size to various hidden sizes
    lstm_time=[]
    for i in t_layer_sizes:
        lstm_time.append(tf.contrib.rnn.LSTMCell(i))

    time_model=tf.contrib.rnn.MultiRNNCell(lstm_time)        
    init_state_time=time_model.zero_state(tf.shape(ys)[0],tf.float32)
    with tf.variable_scope('lstm1'):
        outputs_time,final_state_time=tf.nn.dynamic_rnn(time_model, xs, initial_state = init_state_time, dtype = tf.float32)
            #self.time_model = StackedCells( self.t_input_size, celltype=LSTM, layers = t_layer_sizes)
            #self.time_model.layers.append(PassthroughLayer())

            # pitch network takes last layer of time model and state of last note, moving upward
            # and eventually ends with a two-element sigmoid layer

    p_input_size = t_layer_sizes[-1] + 2


    lstm_pitch=[]

    for i in p_layer_sizes:
        lstm_pitch.append(tf.contrib.rnn.LSTMCell(i))
    lstm_pitch.append(tf.contrib.rnn.LSTMCell(2))


    pitch_model=tf.contrib.rnn.MultiRNNCell(lstm_pitch)

    init_state_pitch=pitch_model.zero_state(tf.shape(ys)[0],tf.float32)
    with tf.variable_scope('lstm2'):
        outputs_pitch,final_state_pitch=tf.nn.dynamic_rnn(pitch_model,outputs_time,initial_state = init_state_pitch,dtype = tf.float32)

    loss=tf.nn.sigmoid_cross_entropy_with_logits(labels=ys,logits=outputs_pitch)
    #loss=tf.reduce_mean(loss)
        
        
        
    return outputs_pitch,loss


def cross_entropy(output, input_y):
    with tf.name_scope('cross_entropy'):
        
        ce = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=output))

    return ce


def train_step(loss, learning_rate=1e-3):
    with tf.name_scope('train_step'):
        step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    return step


def evaluate(output, input_y):
    with tf.name_scope('evaluate'):
        pred = tf.argmax(output, axis=1)
        error_num = tf.count_nonzero(pred - tf.cast(input_y, tf.int64), name='error_num')
        #tf.summary.scalar('LeNet_error_num', error_num)
    return error_num

def training(song,t_layer_sizes,p_layer_sizes, pre_trained_model=None):
    
    tf.reset_default_graph()
    # define the variables and parameter needed during training
    with tf.name_scope('inputs'):
        xs = tf.placeholder(tf.float32, [None,None, t_input_size])
        ys = tf.placeholder(tf.float32, [None,None, 2])
        
        
    output, loss = Model(t_layer_sizes,p_layer_sizes,xs,ys)
    
    
    iters = int(np.array(list(song.values())[0]).shape[0] / batch_len)
    print('number of batches for training: {}'.format(iters))

    step = train_step(loss)
    eve = evaluate(output, ys)

    iter_total = 0
    best_acc = 0
    #cur_model_name = 'lenet_{}'.format(int(time.time()))

    epoch=20
    
    with tf.Session() as sess:
        #merge = tf.summary.merge_all()

        #writer = tf.summary.FileWriter("log/{}".format(cur_model_name), sess.graph)
        #saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        # try to restore the pre_trained
        if pre_trained_model is not None:
            try:
                print("Load the model from: {}".format(pre_trained_model))
                saver.restore(sess, 'model/{}'.format(pre_trained_model))
            except Exception:
                print("Load model Failed!")
                pass

        for epc in range(epoch):
            print("epoch {} ".format(epc + 1))

            for itr in range(iters):
                iter_total += 1
                training_batch_x,training_batch_y= map(numpy.array, getPieceSegment(song))
                
                _, cur_loss = sess.run([step, loss], feed_dict={xs: training_batch_x, ys: training_batch_y})
                print(cur_loss)
                break
            break
                    # save the merge result summary
                #writer.add_summary(merge_result, iter_total)

                    # when achieve the best validation accuracy, we store the model paramters
                    

    #print("Traning ends. The best valid accuracy is {}. Model named {}.".format(best_acc, cur_model_name))

In [12]:
t_input_size=80

In [13]:
a=training(song,[300,300],[100,50], pre_trained_model=None)

number of batches for training: 33
epoch 1 
[[[ 0.69311529  0.69311827]
  [ 0.69297081  0.69302309]
  [ 0.69260359  0.69282448]
  ..., 
  [ 0.58242041  0.62922245]
  [ 0.58260369  0.62931722]
  [ 0.58278167  0.62940717]]

 [[ 0.69312429  0.69311476]
  [ 0.69300181  0.69302177]
  [ 0.69266331  0.69283658]
  ..., 
  [ 0.57414371  0.62025809]
  [ 0.5742377   0.62033588]
  [ 0.57434899  0.62042552]]

 [[ 0.69312763  0.69314182]
  [ 0.69304109  0.6931318 ]
  [ 0.69281811  0.69312292]
  ..., 
  [ 0.59133685  0.63761163]
  [ 0.5913918   0.63752311]
  [ 0.59144253  0.63744259]]

 ..., 
 [[ 0.69314843  0.69311881]
  [ 0.69312876  0.69300878]
  [ 0.69300967  0.69274551]
  ..., 
  [ 0.56003022  0.6106264 ]
  [ 0.55985695  0.61066884]
  [ 0.55972266  0.61070776]]

 [[ 0.69317102  0.69314247]
  [ 0.69322014  0.69311678]
  [ 0.69326144  0.69302922]
  ..., 
  [ 0.58002663  0.62865454]
  [ 0.58014572  0.62870789]
  [ 0.58026892  0.62876314]]

 [[ 0.69316024  0.69314599]
  [ 0.69318157  0.69312871]
  [