diff --git a/src/algorithms/imageProcessing/util/MatrixUtil.java b/src/algorithms/imageProcessing/util/MatrixUtil.java index 7497519a1..9a858b86c 100644 --- a/src/algorithms/imageProcessing/util/MatrixUtil.java +++ b/src/algorithms/imageProcessing/util/MatrixUtil.java @@ -417,7 +417,7 @@ public static double[][] multiplyByTranspose(double[][] p, double[][] n) { * adapted from http://sebastianraschka.com/Articles/2014_python_lda.html * * To apply the results to a data matrix, use y = W^T * X where W is - * the returned matrix. + * the returned matrix (i.e., MatrixUtil.dot(w, data)) * Note, X internally has been scaled to unit standard deviation. * @param data * @param classes @@ -453,7 +453,7 @@ public static SimpleMatrix createLDATransformation(SimpleMatrix data, * adapted from http://sebastianraschka.com/Articles/2014_python_lda.html * * To apply the results to a data matrix, use y = W^T * X where W is - * the returned matrix. + * the returned matrix (i.e., MatrixUtil.dot(w, normData)) * Note, X internally has been scaled to unit standard deviation. * @param normData data scaled to mean of 0 and a standard deviation of 1 * @param classes zero based transformed classes diff --git a/tests/algorithms/imageProcessing/util/MatrixUtilTest.java b/tests/algorithms/imageProcessing/util/MatrixUtilTest.java index 7d686ed17..a6abe7b4d 100644 --- a/tests/algorithms/imageProcessing/util/MatrixUtilTest.java +++ b/tests/algorithms/imageProcessing/util/MatrixUtilTest.java @@ -1,5 +1,6 @@ package algorithms.imageProcessing.util; +import algorithms.util.PolygonAndPointPlotter; import algorithms.util.ResourceFinder; import java.io.BufferedInputStream; import java.io.BufferedReader; @@ -347,6 +348,133 @@ public void testCreateLDATrasformation() throws Exception { assertTrue(Math.abs(w2.get(1, 3) - 0.750) < 0.01); + int nr = w2.numRows(); + int nc = w2.numCols(); + int nr2 = normData.numRows(); + int nc2 = normData.numCols(); + // 2 X 150 + SimpleMatrix dataTransformed = new SimpleMatrix(MatrixUtil.dot(w, normData)); + + float minX = Float.MAX_VALUE; + float maxX = Float.MIN_VALUE; + float minY = Float.MAX_VALUE; + float maxY = Float.MIN_VALUE; + int[] countClasses = new int[nClasses]; + for (int col = 0; col < dataTransformed.numCols(); ++col) { + int k = (int)Math.round(classes.get(0, col)); + countClasses[k]++; + float x = (float)dataTransformed.get(0, col); + float y = (float)dataTransformed.get(1, col); + if (x < minX) { + minX = x; + } + if (y < minY) { + minY = y; + } + if (x < maxX) { + maxX = x; + } + if (y > maxY) { + maxY = y; + } + } + maxX = 3; + maxY = 3; + + PolygonAndPointPlotter plotter = new PolygonAndPointPlotter(minX, maxX, + minY, maxY); + + for (int k = 0; k < nClasses; ++k) { + float[] xPoint = new float[countClasses[k]]; + float[] yPoint = new float[countClasses[k]]; + int count = 0; + for (int col = 0; col < dataTransformed.numCols(); ++col) { + if ((int)Math.round(classes.get(0, col)) != k) { + continue; + } + xPoint[count] = (float)dataTransformed.get(0, col); + yPoint[count] = (float)dataTransformed.get(1, col); + count++; + } + float[] xPoly = null; + float[] yPoly = null; + plotter.addPlot(xPoint, yPoint, xPoly, yPoly, "class " + k); + } + String file1 = plotter.writeFile(); + + /*System.out.println(String.format("(%.3f, %.3f)", + (float)dataTransformed.get(0, 0), + (float)dataTransformed.get(1, 0))); + System.out.println(String.format("(%.3f, %.3f)", + (float)dataTransformed.get(0, 1), + (float)dataTransformed.get(1, 1))); + System.out.println(String.format("(%.3f, %.3f)", + (float)dataTransformed.get(0, 2), + (float)dataTransformed.get(1, 2)));*/ + + assertTrue(Math.abs(dataTransformed.get(0, 0) - 1.791) < 0.01); + assertTrue(Math.abs(dataTransformed.get(1, 0) - 0.115) < 0.01); + + assertTrue(Math.abs(dataTransformed.get(0, 1) - 1.583) < 0.01); + assertTrue(Math.abs(dataTransformed.get(1, 1) - -0.265) < 0.01); + + assertTrue(Math.abs(dataTransformed.get(0, 2) - 1.664) < 0.01); + assertTrue(Math.abs(dataTransformed.get(1, 2) - -0.084) < 0.01); + + // --- to make a transformation usable on features not normalized: + SimpleMatrix w3 = MatrixUtil.createLDATransformation( + dataAndClasses[0], dataAndClasses[1]); + SimpleMatrix dataTransformed3 = new SimpleMatrix(MatrixUtil.dot(w3, + dataAndClasses[0])); + + minX = Float.MAX_VALUE; + maxX = Float.MIN_VALUE; + minY = Float.MAX_VALUE; + maxY = Float.MIN_VALUE; + countClasses = new int[nClasses]; + for (int col = 0; col < dataTransformed3.numCols(); ++col) { + int k = (int)Math.round(classes.get(0, col)); + countClasses[k]++; + float x = (float)dataTransformed3.get(0, col); + float y = (float)dataTransformed3.get(1, col); + if (x < minX) { + minX = x; + } + if (y < minY) { + minY = y; + } + if (x < maxX) { + maxX = x; + } + if (y > maxY) { + maxY = y; + } + } + maxX = -1*minX; + maxY = -1*minY; + + PolygonAndPointPlotter plotter2 = new PolygonAndPointPlotter(minX, maxX, + minY, maxY); + + for (int k = 0; k < nClasses; ++k) { + float[] xPoint = new float[countClasses[k]]; + float[] yPoint = new float[countClasses[k]]; + int count = 0; + for (int col = 0; col < dataTransformed3.numCols(); ++col) { + if ((int)Math.round(classes.get(0, col)) != k) { + continue; + } + xPoint[count] = (float)dataTransformed3.get(0, col); + yPoint[count] = (float)dataTransformed3.get(1, col); + count++; + } + float[] xPoly = null; + float[] yPoly = null; + plotter2.addPlot(xPoint, yPoint, xPoly, yPoly, "class " + k); + } + String file2 = plotter2.writeFile2(); + + } private SimpleMatrix[] readIrisDataset() throws Exception {