Permalink
Browse files

Initial commit

  • Loading branch information...
0 parents commit c9f295420757ebb3cdf769ec2e363b73f6ae6ebe @hexahedria committed Aug 1, 2015
Showing with 827 additions and 0 deletions.
  1. +1 −0 .gitignore
  2. +54 −0 data.py
  3. +46 −0 main.py
  4. +100 −0 midi_to_statematrix.py
  5. +355 −0 model.py
  6. +50 −0 multi_training.py
  7. +19 −0 out_to_in_op.py
  8. +30 −0 piece_training.py
  9. +99 −0 scale_training.py
  10. +73 −0 visualize.py
@@ -0,0 +1 @@
+output/*
@@ -0,0 +1,54 @@
+import itertools
+from midi_to_statematrix import upperBound, lowerBound
+
+def startSentinel():
+ def noteSentinel(note):
+ position = note
+ part_position = [position]
+
+ pitchclass = (note + lowerBound) % 12
+ part_pitchclass = [int(i == pitchclass) for i in range(12)]
+
+ return part_position + part_pitchclass + [0]*66 + [1]
+ return [noteSentinel(note) for note in range(upperBound-lowerBound)]
+
+def getOrDefault(l, i, d):
+ try:
+ return l[i]
+ except IndexError:
+ return d
+
+def buildContext(state):
+ context = [0]*12
+ for note, notestate in enumerate(state):
+ if notestate[0] == 1:
+ pitchclass = (note + lowerBound) % 12
+ context[pitchclass] += 1
+ return context
+
+def buildBeat(time):
+ return [2*x-1 for x in [time%2, (time//2)%2, (time//4)%2, (time//8)%2]]
+
+def noteInputForm(note, state, context, beat):
+ position = note
+ part_position = [position]
+
+ pitchclass = (note + lowerBound) % 12
+ part_pitchclass = [int(i == pitchclass) for i in range(12)]
+ # Concatenate the note states for the previous vicinity
+ part_prev_vicinity = list(itertools.chain.from_iterable((getOrDefault(state, note+i, [0,0]) for i in range(-12, 13))))
+
+ part_context = context[pitchclass:] + context[:pitchclass]
+
+ return part_position + part_pitchclass + part_prev_vicinity + part_context + beat + [0]
+
+def noteStateSingleToInputForm(state,time):
+ beat = buildBeat(time)
+ context = buildContext(state)
+ return [noteInputForm(note, state, context, beat) for note in range(len(state))]
+
+def noteStateMatrixToInputForm(statematrix):
+ # NOTE: May have to transpose this or transform it in some way to make Theano like it
+ #[startSentinel()] +
+ inputform = [ noteStateSingleToInputForm(state,time) for time,state in enumerate(statematrix) ]
+ return inputform
@@ -0,0 +1,46 @@
+import cPickle as pickle
+import gzip
+import numpy
+from midi_to_statematrix import *
+
+rel_modules = []
+
+import multi_training
+rel_modules.append(multi_training)
+import model
+rel_modules.append(model)
+
+def refresh():
+ for mod in rel_modules:
+ reload(mod)
+
+pcs = multi_training.loadPieces("bachmidi")
+# pickle.dump( pcs, gzip.GzipFile( "traindata.p.zip", "wb" ) )
+# pcs = pickle.load(gzip.GzipFile( "traindata.p.zip", "r"))
+
+m = model.Model([300,300],[100,50])
+
+multi_training.trainPiece(m, pcs, 10000)
+
+pickle.dump( m.learned_config, open( "output/final_learned_config.p", "wb" ) )
+
+def gen_adaptive(times):
+ xIpt, xOpt = map(lambda x: numpy.array(x, dtype='int8'), multi_training.getPieceSegment(pcs))
+ all_outputs = [xOpt[0]]
+ m.start_slow_walk(xIpt[0])
+ cons = 1
+ for time in range(multi_training.batch_len*times):
+ resdata = m.slow_walk_fun( cons )
+ nnotes = np.sum(resdata[-1][:,0])
+ if nnotes > 6:
+ if cons < 1:
+ cons = 1
+ cons += 0.01
+ elif nnotes < 2:
+ if cons > 1:
+ cons = 1
+ cons -= 0.01
+ else:
+ cons += (1 - cons)*0.3
+ all_outputs.append(resdata[-1])
+ noteStateMatrixToMidi(numpy.array(all_outputs),'output/final')
@@ -0,0 +1,100 @@
+import midi, numpy
+
+lowerBound = 36
+upperBound = 92
+
+def midiToNoteStateMatrix(midifile):
+
+ pattern = midi.read_midifile(midifile)
+
+ timeleft = [track[0].tick for track in pattern]
+
+ posns = [0 for track in pattern]
+
+ statematrix = []
+ span = upperBound-lowerBound
+ time = 0
+
+ state = [[0,0] for x in range(span)]
+ statematrix.append(state)
+ while True:
+ if time % (pattern.resolution / 4) == (pattern.resolution / 8):
+ # Crossed a note boundary. Create a new state, defaulting to holding notes
+ oldstate = state
+ state = [[oldstate[x][0],0] for x in range(span)]
+ statematrix.append(state)
+
+ for i in range(len(timeleft)):
+ while timeleft[i] == 0:
+ track = pattern[i]
+ pos = posns[i]
+
+ evt = track[pos]
+ if isinstance(evt, midi.NoteEvent):
+ try:
+ if isinstance(evt, midi.NoteOffEvent) or evt.velocity == 0:
+ state[evt.pitch-lowerBound] = [0, 0]
+ else:
+ state[evt.pitch-lowerBound] = [1, 1]
+ except IndexError:
+ print "Note {} at time {} out of bounds (ignoring)".format(evt.pitch, time)
+ elif isinstance(evt, midi.TimeSignatureEvent):
+ if evt.numerator not in (2, 4):
+ # We don't want to worry about non-4 time signatures. Bail early!
+ print "Found time signature event {}. Bailing!".format(evt)
+ return statematrix
+
+ try:
+ timeleft[i] = track[pos + 1].tick
+ posns[i] += 1
+ except IndexError:
+ timeleft[i] = None
+
+ if timeleft[i] is not None:
+ timeleft[i] -= 1
+
+ if all(t is None for t in timeleft):
+ break
+
+ time += 1
+
+ return statematrix
+
+def noteStateMatrixToMidi(statematrix, name="example"):
+ statematrix = numpy.asarray(statematrix)
+ pattern = midi.Pattern()
+ track = midi.Track()
+ pattern.append(track)
+
+ span = upperBound-lowerBound
+ tickscale = 55
+
+ lastcmdtime = 0
+ prevstate = [[0,0] for x in range(span)]
+ for time, state in enumerate(statematrix + [prevstate[:]]):
+ offNotes = []
+ onNotes = []
+ for i in range(span):
+ n = state[i]
+ p = prevstate[i]
+ if p[0] == 1:
+ if n[0] == 0:
+ offNotes.append(i)
+ elif n[1] == 1:
+ offNotes.append(i)
+ onNotes.append(i)
+ elif n[0] == 1:
+ onNotes.append(i)
+ for note in offNotes:
+ track.append(midi.NoteOffEvent(tick=(time-lastcmdtime)*tickscale, pitch=note+lowerBound))
+ lastcmdtime = time
+ for note in onNotes:
+ track.append(midi.NoteOnEvent(tick=(time-lastcmdtime)*tickscale, velocity=40, pitch=note+lowerBound))
+ lastcmdtime = time
+
+ prevstate = state
+
+ eot = midi.EndOfTrackEvent(tick=1)
+ track.append(eot)
+
+ midi.write_midifile("{}.mid".format(name), pattern)
Oops, something went wrong.

0 comments on commit c9f2954

Please sign in to comment.