In [228]:
import os, csv

In [271]:
def arrayParser(arr):
    """Parse a bracketed, comma-separated number string into a list of floats.

    CSV cells store vectors as text such as "[0.1, 2.5, 3]". The brackets are
    stripped, the remainder is split on commas and each token is cast to float.

    arr -- the cell text.

    Returns a real list (not a lazy ``map``), so callers can take ``len()``
    and index it on both Python 2 and Python 3. An empty vector ("[]" or a
    blank string) yields []. Raises ValueError if a token is not a number.
    """
    body = arr.replace('[', '').replace(']', '').strip()
    if not body:
        # ''.split(',') would give [''] and float('') would raise
        return []
    return [float(token) for token in body.split(',')]

In [272]:
class LVQNeuron:
    """A single LVQ output neuron: an identifier plus its weight vector."""

    def __init__(self, name):
        # LVQNet passes the neuron's integer index as `name`
        self.name = name
        self.weights = []  # list of floats, populated by setWeights

    def setWeights(self, weights):
        """Replace this neuron's weight vector.

        weights -- either a bracketed, comma-separated string as stored in the
                   CSVs (e.g. "[0.1, 2.5]") or any iterable of numbers (e.g.
                   the list returned by LVQNet.calibrate).

        The vector is replaced rather than extended, so repeated calls do not
        grow it. An empty string vector ("[]") yields no weights.
        """
        if hasattr(weights, 'split'):
            # text form from the CSV: strip brackets and split on commas
            body = weights.replace('[', '').replace(']', '').strip()
            values = body.split(',') if body else []
        else:
            values = weights
        self.weights = [float(value) for value in values]

    def __len__(self):
        return len(self.weights)

In [318]:
class LVQNet:
    """Learning Vector Quantization network with one prototype neuron per class.

    inCount  -- intended input vector length (stored in self.inputs; not
                enforced by this class).
    outCount -- number of output neurons, keyed 0 .. outCount-1.
    """
    def __init__(self, inCount, outCount):
        self.inputs   = inCount
        self.outputs  = outCount
        self.alpha    = 0.1  # learning rate used by calibrate
        self.csvCount = 0    # neurons initialised from CSVs so far (capped at outputs)
        self.neurons  = {}   # neuron index -> LVQNeuron

        for n in range(outCount):
            self.neurons[n] = LVQNeuron(n)

    def __len__(self):
        return len(self.neurons)

    def getWeights(self, neuronNo):
        """Return the weight list of neuron `neuronNo` (the stored list, not a copy)."""
        return self.neurons[neuronNo].weights

    # STEP 0
    def enterCSV(self, filepath):
        """Initialise the next uninitialised neuron from the first row of a CSV.

        The weight vector is taken from column 1 of the first row. Once every
        neuron has been initialised, a message is printed and nothing changes.

        Raises ValueError if the CSV contains no rows.
        """
        if self.csvCount >= self.outputs:
            print("Reached limit of neurons")  # (TODO) Throw error
            return

        with open(filepath, 'r') as f:
            reader = csv.reader(f, delimiter=',')
            # builtin next() works on Python 2 and 3; reader.next() is 2-only
            row = next(reader, None)
            if row is None:
                raise ValueError("CSV file %s has no rows" % (filepath,))
            self.neurons[self.csvCount].setWeights(row[1])
            self.csvCount += 1

        print("Successfully added neuron from CSV %s" % (filepath,))

    # STEP 3.1
    def edist(self, inputs, weights):
        """Return the Euclidean distance between two equal-length sequences.

        Raises ValueError when the lengths differ. Returning None instead would
        let minDist's min() silently choose that neuron under Python 2, where
        None compares below every number.
        """
        if len(inputs) != len(weights):
            raise ValueError("%d different length than %d" % (len(inputs), len(weights)))
        return sum((x - w) ** 2 for x, w in zip(inputs, weights)) ** 0.5

    # STEP 3.2
    def minDist(self, inputVector):
        """Return the index of the neuron whose weights are closest to inputVector.

        Ties go to the lowest neuron index. Raises ValueError if the network
        has no neurons or a neuron's weight length differs from the input's
        (e.g. a neuron never initialised by enterCSV).
        """
        if not self.neurons:
            raise ValueError("network has no neurons")
        # min() keeps the first minimum it sees, so iterate indices in order
        return min(sorted(self.neurons),
                   key=lambda idx: self.edist(inputVector, self.neurons[idx].weights))

    # STEP 4 -- In progress ...
    def calibrate(self, neuron, guessNo, inputVector, data=None):
        """Compute LVQ-updated weights for the winning neuron `guessNo`.

        neuron      -- true label of inputVector (the driver passes
                       dataset.getVectorLabel(i)); compared with the label
                       that `data` maps guessNo to.
        guessNo     -- index of the winning neuron, e.g. from minDist.
        inputVector -- the training vector.
        data        -- object with lookupInstrument(index) -> label (an
                       LVQData). Defaults to the notebook-global `dataset`
                       for backward compatibility.

        Returns a new weight list: moved toward inputVector by self.alpha on a
        right guess, away from it on a wrong guess. The stored weights are not
        modified; the caller must assign the result.
        Raises ValueError if the weight and input lengths differ.
        """
        if data is None:
            data = dataset  # legacy fallback: global defined by the driver cell
        guess   = data.lookupInstrument(guessNo)
        weights = self.getWeights(guessNo)
        alpha   = self.alpha

        if len(weights) != len(inputVector):
            raise ValueError("%d different length than %d" % (len(inputVector), len(weights)))

        if neuron == guess:
            print("Right Guess")
            return [w + alpha * (x - w) for w, x in zip(weights, inputVector)]
        else:
            print("Wrong Guess")
            return [w - alpha * (x - w) for w, x in zip(weights, inputVector)]

In [319]:
class LVQData:
    """Labelled training vectors plus an integer -> label (instrument) map."""

    def __init__(self):
        self.data          = []  # list of (vector, label) tuples
        self.instrumentMap = {}  # instrument index -> label, in load order
        self.instrumentNum = 0   # index the next loaded label will receive

    def loadCSV(self, filepath, label):
        """Append every row after the first of `filepath` as (vector, label).

        The first row is skipped because LVQNet.enterCSV uses each CSV's first
        row to initialise a neuron. Column 1 of each remaining row is parsed
        with arrayParser. Blank lines are ignored. `label` is registered under
        the next instrument index, so CSVs should be loaded in the same order
        they were entered into the network for indices and labels to line up.

        Returns the full, cumulative data list.
        """
        with open(filepath, 'r') as f:
            reader = csv.reader(f, delimiter=',')
            # builtin next() works on Python 2 and 3; default avoids a bare
            # StopIteration on an empty file
            next(reader, None)

            for row in reader:
                if not row:
                    continue  # csv.reader yields [] for blank lines
                # Tuple with STFT bins and then the label
                self.data.append((arrayParser(row[1]), label))

        self.instrumentMap[self.instrumentNum] = label
        self.instrumentNum += 1
        return self.data

    def getVector(self, index):
        """Return the vector of the `index`-th loaded sample."""
        return self.data[index][0]

    def getVectorLabel(self, index):
        """Return the label of the `index`-th loaded sample."""
        return self.data[index][1]

    def lookupInstrument(self, index):
        """Return the label registered under instrument index `index`."""
        return self.instrumentMap[index]

In [322]:
if __name__ == '__main__':
    ### Driver: walk through the API / algorithm once

    # One (csv path, label) pair per output neuron, in neuron order.
    sources = [('./data/snareFrames.csv', 'snare'),
               ('./data/kickDrumFrames.csv', 'kick-drum')]

    # Network: 1025 inputs, 2 output neurons
    koho = LVQNet(1025, 2)

    # STEP 0: the first row of each CSV seeds one neuron
    for csv_path, _ in sources:
        koho.enterCSV(csv_path)

    # Training data: remaining rows of each CSV, labelled. `dataset` must stay
    # a global because LVQNet.calibrate looks it up by that name.
    dataset = LVQData()
    for csv_path, instrument in sources:
        dataset.loadCSV(csv_path, instrument)

    # Sample frames; presumably frame 2 is a snare and frame 8 a kick,
    # depending on how many rows the snare CSV holds -- verify.
    snare_frame, kick_frame = 2, 8

    # (TODO) Put the koho minDists in loop / logic
    guess = koho.minDist(dataset.getVector(snare_frame))
    guess2 = koho.minDist(dataset.getVector(kick_frame))

    newWeights = koho.calibrate(dataset.getVectorLabel(kick_frame),
                                guess2,
                                dataset.getVector(kick_frame))
    print(newWeights)

Successfully added neuron from CSV ./data/snareFrames.csv
Successfully added neuron from CSV ./data/kickDrumFrames.csv
Right Guess
[0.02385396493001224, 0.023288319632411004, 3.0494112730026246, 16.155873680114745, 0.341811615601182, 208.28802337646485, 778.9589782714844, 605.64931640625, 174.16514358520507, 196.72050247192382, 91.85171051025391, 100.65526275634765, 45.71753768920898, 157.35950775146483, 211.898592376709, 99.4897575378418, 384.4477600097656, 308.5795394897461, 60.56137962341309, 47.116682624816896, 22.00955948829651, 7.6986008644104, 5.7671294689178465, 24.75441446304321, 30.672997379302977, 39.31307201385498, 54.13671188354492, 37.75927543640137, 8.347852993011475, 11.784475898742675, 70.1565731048584, 55.176285934448245, 51.51765089035034, 52.66171226501465, 55.65469741821289, 30.457315540313722, 6.9520911693573, 13.066671562194824, 10.93045802116394, 4.695454263687134, 1.7397945761680602, 3.921765661239624, 10.850006484985352, 31.00029811859131, 28.797776985168458, 