Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldjohnson committed Aug 1, 2015
0 parents commit c9f2954
Show file tree
Hide file tree
Showing 10 changed files with 827 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
output/*
54 changes: 54 additions & 0 deletions data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import itertools
from midi_to_statematrix import upperBound, lowerBound

def startSentinel():
def noteSentinel(note):
position = note
part_position = [position]

pitchclass = (note + lowerBound) % 12
part_pitchclass = [int(i == pitchclass) for i in range(12)]

return part_position + part_pitchclass + [0]*66 + [1]
return [noteSentinel(note) for note in range(upperBound-lowerBound)]

def getOrDefault(l, i, d):
try:
return l[i]
except IndexError:
return d

def buildContext(state):
context = [0]*12
for note, notestate in enumerate(state):
if notestate[0] == 1:
pitchclass = (note + lowerBound) % 12
context[pitchclass] += 1
return context

def buildBeat(time):
return [2*x-1 for x in [time%2, (time//2)%2, (time//4)%2, (time//8)%2]]

def noteInputForm(note, state, context, beat):
position = note
part_position = [position]

pitchclass = (note + lowerBound) % 12
part_pitchclass = [int(i == pitchclass) for i in range(12)]
# Concatenate the note states for the previous vicinity
part_prev_vicinity = list(itertools.chain.from_iterable((getOrDefault(state, note+i, [0,0]) for i in range(-12, 13))))

part_context = context[pitchclass:] + context[:pitchclass]

return part_position + part_pitchclass + part_prev_vicinity + part_context + beat + [0]

def noteStateSingleToInputForm(state,time):
beat = buildBeat(time)
context = buildContext(state)
return [noteInputForm(note, state, context, beat) for note in range(len(state))]

def noteStateMatrixToInputForm(statematrix):
# NOTE: May have to transpose this or transform it in some way to make Theano like it
#[startSentinel()] +
inputform = [ noteStateSingleToInputForm(state,time) for time,state in enumerate(statematrix) ]
return inputform
46 changes: 46 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import cPickle as pickle
import gzip
import numpy
from midi_to_statematrix import *

rel_modules = []

import multi_training
rel_modules.append(multi_training)
import model
rel_modules.append(model)

def refresh():
for mod in rel_modules:
reload(mod)

pcs = multi_training.loadPieces("bachmidi")
# pickle.dump( pcs, gzip.GzipFile( "traindata.p.zip", "wb" ) )
# pcs = pickle.load(gzip.GzipFile( "traindata.p.zip", "r"))

m = model.Model([300,300],[100,50])

multi_training.trainPiece(m, pcs, 10000)

pickle.dump( m.learned_config, open( "output/final_learned_config.p", "wb" ) )

def gen_adaptive(times):
xIpt, xOpt = map(lambda x: numpy.array(x, dtype='int8'), multi_training.getPieceSegment(pcs))
all_outputs = [xOpt[0]]
m.start_slow_walk(xIpt[0])
cons = 1
for time in range(multi_training.batch_len*times):
resdata = m.slow_walk_fun( cons )
nnotes = np.sum(resdata[-1][:,0])
if nnotes > 6:
if cons < 1:
cons = 1
cons += 0.01
elif nnotes < 2:
if cons > 1:
cons = 1
cons -= 0.01
else:
cons += (1 - cons)*0.3
all_outputs.append(resdata[-1])
noteStateMatrixToMidi(numpy.array(all_outputs),'output/final')
100 changes: 100 additions & 0 deletions midi_to_statematrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import midi, numpy

lowerBound = 36
upperBound = 92

def midiToNoteStateMatrix(midifile):

pattern = midi.read_midifile(midifile)

timeleft = [track[0].tick for track in pattern]

posns = [0 for track in pattern]

statematrix = []
span = upperBound-lowerBound
time = 0

state = [[0,0] for x in range(span)]
statematrix.append(state)
while True:
if time % (pattern.resolution / 4) == (pattern.resolution / 8):
# Crossed a note boundary. Create a new state, defaulting to holding notes
oldstate = state
state = [[oldstate[x][0],0] for x in range(span)]
statematrix.append(state)

for i in range(len(timeleft)):
while timeleft[i] == 0:
track = pattern[i]
pos = posns[i]

evt = track[pos]
if isinstance(evt, midi.NoteEvent):
try:
if isinstance(evt, midi.NoteOffEvent) or evt.velocity == 0:
state[evt.pitch-lowerBound] = [0, 0]
else:
state[evt.pitch-lowerBound] = [1, 1]
except IndexError:
print "Note {} at time {} out of bounds (ignoring)".format(evt.pitch, time)
elif isinstance(evt, midi.TimeSignatureEvent):
if evt.numerator not in (2, 4):
# We don't want to worry about non-4 time signatures. Bail early!
print "Found time signature event {}. Bailing!".format(evt)
return statematrix

try:
timeleft[i] = track[pos + 1].tick
posns[i] += 1
except IndexError:
timeleft[i] = None

if timeleft[i] is not None:
timeleft[i] -= 1

if all(t is None for t in timeleft):
break

time += 1

return statematrix

def noteStateMatrixToMidi(statematrix, name="example"):
statematrix = numpy.asarray(statematrix)
pattern = midi.Pattern()
track = midi.Track()
pattern.append(track)

span = upperBound-lowerBound
tickscale = 55

lastcmdtime = 0
prevstate = [[0,0] for x in range(span)]
for time, state in enumerate(statematrix + [prevstate[:]]):
offNotes = []
onNotes = []
for i in range(span):
n = state[i]
p = prevstate[i]
if p[0] == 1:
if n[0] == 0:
offNotes.append(i)
elif n[1] == 1:
offNotes.append(i)
onNotes.append(i)
elif n[0] == 1:
onNotes.append(i)
for note in offNotes:
track.append(midi.NoteOffEvent(tick=(time-lastcmdtime)*tickscale, pitch=note+lowerBound))
lastcmdtime = time
for note in onNotes:
track.append(midi.NoteOnEvent(tick=(time-lastcmdtime)*tickscale, velocity=40, pitch=note+lowerBound))
lastcmdtime = time

prevstate = state

eot = midi.EndOfTrackEvent(tick=1)
track.append(eot)

midi.write_midifile("{}.mid".format(name), pattern)
Loading

0 comments on commit c9f2954

Please sign in to comment.