Skip to content
Browse files

improvements to testing

  • Loading branch information...
1 parent 108b736 commit 0b5d6ce98b1be1a93c634d71e2dc3008b19f95e5 @harthur committed Jun 1, 2012
Showing with 141 additions and 58 deletions.
  1. +3 −3 hog-params.json
  2. +39 −6 kittydar.js
  3. +17 −4 test.js
  4. +5 −4 testing/test.js
  5. +6 −6 training/findparams.js
  6. +71 −35 training/train.js
View
6 hog-params.json
@@ -1,7 +1,7 @@
{
- "cellSize": 6,
+ "cellSize": 4,
"blockSize": 2,
- "blockStride": 1,
- "bins": 9,
+ "blockStride": 2,
+ "bins": 6,
"norm": "L2"
}
View
45 kittydar.js
@@ -8,23 +8,46 @@ var net = new brain.NeuralNetwork().fromJSON(network);
var threshold = 0.9;
exports.detectCats = function detectCats(canvas) {
- var found = [];
- var minScale = 48;
- var maxScale = Math.min(canvas.width, canvas.height);
+ var width = canvas.width,
+ height = canvas.height;
+
+ // scale to reduce computation time
+ var scale = 360 / Math.max(width, height);
+ width = width * scale;
+ height = height * scale;
+
+ canvas = resizeCanvas(canvas, width, height);
+
+ var min = 48;
+ var max = Math.min(width, height);
- var step = 14;
var cats = [];
var total = 0;
- for (var scale = minScale; scale < maxScale; scale += step) {
- var info = detectAtScale(canvas, scale, minScale);
+ for (var size = min; size < max; size += 12) {
+ var info = detectAtScale(canvas, size, min);
cats = cats.concat(info.cats);
total += info.total;
}
+ console.log(cats[0])
+ console.log(scale)
+
+ cats = cats.map(function(cat) {
+ return {
+ x: cat.x / scale,
+ y: cat.y / scale,
+ width: cat.width / scale,
+ height: cat.height / scale
+ }
+ });
+
+ console.log(cats[0]);
+
return {cats: cats, total: total};
}
function isCat(canvas) {
var fts = features.extractFeatures(canvas);
+ //console.log(fts.length)
var prob = net.run(fts)[0];
return prob;
}
@@ -49,6 +72,16 @@ function detectAtScale(canvas, scale, resizeTo) {
return {cats: cats, total: count};
}
+function resizeCanvas(canvas, width, height) {
+ var resizeCanvas = new Canvas(width, height);
+ var ctx = resizeCanvas.getContext('2d');
+ ctx.patternQuality = "best";
+
+ ctx.drawImage(canvas, 0, 0, canvas.width, canvas.height,
+ 0, 0, width, height);
+ return resizeCanvas;
+}
+
function cropAndResize(canvas, x, y, size, resizeTo) {
var resizeCanvas = new Canvas(resizeTo, resizeTo);
var ctx = resizeCanvas.getContext('2d');
View
21 test.js
@@ -32,21 +32,34 @@ test(r1, r2, true);
test(r2, r1, true);
+r1 = {x: 0, y: 0, width: 10, height: 10};
+r2 = {x: 2, y: 2, width: 8, height: 8};
+
+test(r1, r2, true);
+test(r2, r1, true);
+
+r1 = {x: 0, y: 0, width: 10, height: 10};
+r2 = {x: 2, y: 2, width: 1, height: 1};
+
+test(r1, r2, false);
+test(r2, r1, true);
+
+
function doesOverlap(cat, rect) {
var overlapW, overlapH;
if (cat.x > rect.x) {
- overlapW = (rect.x + rect.width) - cat.x;
+ overlapW = Math.min((rect.x + rect.width) - cat.x, cat.width);
}
else {
- overlapW = (cat.x + cat.width) - rect.x;
+ overlapW = Math.min((cat.x + cat.width) - rect.x, rect.width);
}
if (cat.y > rect.y) {
- overlapH = (rect.y + rect.height) - cat.y;
+ overlapH = Math.min((rect.y + rect.height) - cat.y, cat.height);
}
else {
- overlapH = (cat.y + cat.height) - rect.y;
+ overlapH = Math.min((cat.y + cat.height) - rect.y, rect.height);
}
if (overlapW > 0 && overlapH > 0) {
View
9 testing/test.js
@@ -40,6 +40,7 @@ function runTest() {
utils.drawImgToCanvas(file, function(canvas) {
console.time("detecting")
+ console.log(canvas)
var info = kittydar.detectCats(canvas);
var cats = info.cats;
console.timeEnd("detecting")
@@ -84,17 +85,17 @@ function doesOverlap(cat, rect) {
var overlapW, overlapH;
if (cat.x > rect.x) {
- overlapW = (rect.x + rect.width) - cat.x;
+ overlapW = Math.min((rect.x + rect.width) - cat.x, cat.width);
}
else {
- overlapW = (cat.x + cat.width) - rect.x;
+ overlapW = Math.min((cat.x + cat.width) - rect.x, rect.width);
}
if (cat.y > rect.y) {
- overlapH = (rect.y + rect.height) - cat.y;
+ overlapH = Math.min((rect.y + rect.height) - cat.y, cat.height);
}
else {
- overlapH = (cat.y + cat.height) - rect.y;
+ overlapH = Math.min((cat.y + cat.height) - rect.y, rect.height);
}
if (overlapW > 0 && overlapH > 0) {
View
12 training/findparams.js
@@ -38,21 +38,21 @@ function getCombos() {
var combos = [{
cellSize: 4,
- blockSize: 2,
- blockStride: 1,
+ blockSize: 3,
+ blockStride: 3,
bins: 6,
norm: "L2"
},
{
cellSize: 4,
- blockSize: 2,
+ blockSize: 3,
blockStride: 2,
bins: 6,
norm: "L2"
},
{
- cellSize: 6,
- blockSize: 2,
+ cellSize: 4,
+ blockSize: 3,
blockStride: 1,
bins: 6,
norm: "L2"
@@ -117,7 +117,7 @@ function testParams(canvases, params) {
})
var opts = {
- hiddenLayers: [40]
+ hiddenLayers: [30]
};
var trainOpts = {
errorThresh: 0.006,
View
106 training/train.js
@@ -1,56 +1,50 @@
-var cradle = require("cradle"),
+var fs = require("fs"),
brain = require("brain"),
- fs = require("fs");
-
-var db = new(cradle.Connection)().database('cats-hog-c6-b9');
-
-db.all({include_docs: true}, function(err, res) {
- if (err) {
- console.log(err);
- }
- else {
- var posData = [];
- var negData = [];
-
- res.rows.forEach(function(row) {
- var doc = row.doc;
- if (doc.output[0]) {
- posData.push(doc);
- }
- else {
- negData.push(doc);
- }
- });
+ path = require("path"),
+ async = require("async"),
+ _ = require("underscore"),
+ utils = require("../utils"),
+ features = require("../features");
- var posSize = 5000;
- var negSize = 5000;
- var data = posData.slice(0, posSize).concat(negData.slice(0, negSize));
- console.log("training with", data.length);
- console.log(posSize, "positives", negSize, "negatives")
+testParams();
+
+function testParams(params) {
+ getCanvases(function(canvases) {
+ var data = canvases.map(function(canvas) {
+ var fts = features.extractFeatures(canvas.canvas, params);
+ return {
+ input: fts,
+ output: [canvas.isCat]
+ };
+ });
var opts = {
- hiddenLayers: [40]
+ hiddenLayers: [30]
};
var trainOpts = {
errorThresh: 0.006,
log: true
};
+
var stats = brain.crossValidate(brain.NeuralNetwork, data, opts, trainOpts);
- console.log("averages:", stats.avgs);
- console.log("parameters:", stats.parameters);
+ stats.featureSize = data[0].input.length;
+
+ console.log("params", stats.params);
+ console.log("stats", stats.stats);
+ console.log("avgs", stats.avgs);
fs.writeFile('misclasses.json', JSON.stringify(stats.misclasses, 4), function (err) {
if (err) throw err;
- console.log('saved misclasses');
+ console.log('saved misclasses to misclasses.json');
});
var minError = 1;
var network;
stats.sets.forEach(function(set) {
- if (set.stats.error < minError) {
- minError = set.stats.error;
+ if (set.error < minError) {
+ minError = set.error;
network = set.network;
}
})
@@ -60,5 +54,47 @@ db.all({include_docs: true}, function(err, res) {
if (err) throw err;
console.log('saved network to cv-network.json');
});
- }
-});
+ })
+}
+
+function getCanvases(callback) {
+ var posDir = __dirname + "/POSITIVES/";
+
+ fs.readdir(posDir, function(err, files) {
+ if (err) throw err;
+
+ getDir(posDir, files, 1, function(posData) {
+ var negsDir = __dirname + "/NEGATIVES/";
+ fs.readdir(negsDir, function(err, files) {
+ if (err) throw err;
+
+ getDir(negsDir, files, 0, function(negData) {
+ var data = posData.concat(negData);
+
+ callback(data);
+ })
+ })
+ })
+ });
+}
+
+function getDir(dir, files, isCat, callback) {
+ var limit = 5000;
+ var images = files.filter(function(file) {
+ return path.extname(file) == ".jpg";
+ });
+ images = images.slice(0, limit);
+
+ var data = [];
+
+ async.map(images, function(file, done) {
+ file = dir + file;
+
+ utils.drawImgToCanvas(file, function(canvas) {
+ done(null, {canvas: canvas, file: file, isCat: isCat});
+ });
+ },
+ function(err, canvases) {
+ callback(canvases);
+ });
+}

0 comments on commit 0b5d6ce

Please sign in to comment.
Something went wrong with that request. Please try again.