Permalink
Browse files

training scripts improvements

  • Loading branch information...
1 parent 0e45725 commit 18ea436076495f0e0e0c819fac5df472afcb17ea @harthur committed Mar 22, 2013
View
@@ -5,7 +5,6 @@ var fs = require("fs"),
features = require("../features");
exports.collectData = collectData;
-exports.getDir = getDir;
exports.extractSamples = extractSamples;
/*
@@ -18,7 +17,7 @@ exports.extractSamples = extractSamples;
* file: 'test.jpg'
* }
*/
-function collectData(pos, neg, samples, limit, params) {
+function collectData(pos, neg, samples, limit, params, type) {
// number of samples to extract from each negative, 0 for whole image
samples = samples || 0;
params = params || {};
@@ -35,9 +34,27 @@ function collectData(pos, neg, samples, limit, params) {
data.sort(function() {
return 1 - 2 * Math.round(Math.random());
});
+
+ if (type == "svm") {
+ return SVMData(data);
+ }
return data;
}
+function SVMData(data) {
+ var inputs = new Array(data.length);
+ var labels = new Int8Array(data.length);
+
+ for (var i = 0; i < data.length; i++) {
+ inputs[i] = data[i].input;
+ labels[i] = data[i].output[0] || -1;
+ }
+ return {
+ inputs: inputs,
+ labels: labels
+ };
+}
+
function getDir(dir, isCat, samples, limit, params) {
var files = fs.readdirSync(dir);
@@ -32,31 +32,63 @@ var opts = nomnom.options({
var combos = [
{
HOG: {
- cellSize: 4,
+ cellSize: 3,
+ blockSize: 2,
+ blockStride: 2,
+ bins: 6,
+ norm: "L2"
+ },
+ nn: {
+ hiddenLayers: [10, 10],
+ binaryThresh: 0.8
+ },
+ train: {
+ errorThresh: 0.008
+ }
+},
+{
+ HOG: {
+ cellSize: 3,
+ blockSize: 2,
+ blockStride: 1,
+ bins: 6,
+ norm: "L2"
+ },
+ nn: {
+ hiddenLayers: [10, 10],
+ binaryThresh: 0.8
+ },
+ train: {
+ errorThresh: 0.008
+ }
+},
+{
+ HOG: {
+ cellSize: 2,
blockSize: 3,
blockStride: 3,
bins: 6,
norm: "L2"
},
nn: {
hiddenLayers: [10, 10],
- binaryThresh: 0.99
+ binaryThresh: 0.8
},
train: {
errorThresh: 0.008
}
},
{
HOG: {
- cellSize: 3,
- blockSize: 4,
- blockStride: 4,
+ cellSize: 4,
+ blockSize: 2,
+ blockStride: 1,
bins: 6,
norm: "L2"
},
nn: {
hiddenLayers: [10, 10],
- binaryThresh: 0.99
+ binaryThresh: 0.8
},
train: {
errorThresh: 0.008
@@ -25,6 +25,10 @@ var opts = nomnom.options({
sample: {
flag: true,
help: "sub-sample the negative images"
+ },
+ threshold: {
+ default: 0.99,
+ help: "threshold for classifying as a positive"
}
}).colors().parse();
@@ -37,7 +41,10 @@ function testNetwork() {
console.log("feature size", data[0].input.length);
var json = require(opts.json);
- var network = new brain.NeuralNetwork({binaryThresh: 0.99}).fromJSON(json);
+ var network = new brain.NeuralNetwork({
+ binaryThresh: opts.threshold
+ }).fromJSON(json);
+
var stats = network.test(data);
console.log("error: " + stats.error);
@@ -37,29 +37,38 @@ var combos = [
cellSize: 4,
blockSize: 2,
blockStride: 1,
- bins: 6,
+ bins: 7,
norm: "L2"
},
svm: {
numpasses: 5,
- kernel: 'rbf',
- rbfsigma: 0.5,
- C: 5
+ C: 0.001,
}
},
{
HOG: {
cellSize: 4,
blockSize: 2,
blockStride: 1,
- bins: 6,
+ bins: 9,
norm: "L2"
},
svm: {
numpasses: 5,
- kernel: 'rbf',
- rbfsigma: 2,
- C: 5
+ C: 0.001,
+ }
+},
+{
+ HOG: {
+ cellSize: 4,
+ blockSize: 2,
+ blockStride: 1,
+ bins: 4,
+ norm: "L2"
+ },
+ svm: {
+ numpasses: 5,
+ C: 0.001,
}
}
];
@@ -126,8 +135,8 @@ function getPrintout(tests) {
function testPartition(trainSet, testSet, params) {
var SVM = new svm.SVM();
- var inputs = [];
- var labels = [];
+ var inputs = new Array(trainSet.length);
+ var labels = new Int8Array(trainSet.length);
for (var i = 0; i < trainSet.length; i++) {
inputs[i] = trainSet[i].input;
@@ -165,6 +174,7 @@ function testPartition(trainSet, testSet, params) {
var stats = {
trainTime : beginTest - beginTrain,
+ trainTimePerIter: (beginTest - beginTrain) / trainingStats.iters,
testTime : endTest - beginTest,
iterations: trainingStats.iters,
falsePos: falsePos,
@@ -184,7 +194,8 @@ function crossValidate(data, params) {
var avgs = {
trainTime : 0,
testTime : 0,
- iterations: 0
+ iterations: 0,
+ trainTimePerIter: 0
};
var stats = {
View
@@ -47,12 +47,12 @@ var params = {
cellSize: 4,
blockSize: 2,
blockStride: 1,
- bins: 6,
+ bins: 7,
norm: "L2"
},
svm: {
numpasses: 3,
- C: 0.1,
+ C: 0.001,
kernel: 'linear'
}
};
@@ -62,17 +62,12 @@ trainSVM(params)
function trainSVM(params) {
var samples = opts.sample ? 1 : 0;
var data = collect.collectData(opts.pos, opts.neg, samples,
- opts.limit, params);
+ opts.limit, params, "svm");
- var inputs = [];
- var labels = [];
+ var inputs = data.inputs;
+ var labels = data.labels;
- for (var i = 0; i < data.length; i++) {
- inputs[i] = data[i].input;
- labels[i] = data[i].output[0] || -1;
- }
-
- console.log("training on", data.length);
+ console.log("training on", inputs.length);
var SVM = new svm.SVM();

0 comments on commit 18ea436

Please sign in to comment.