Skip to content

Commit

Permalink
more towards LDA tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nking committed Feb 8, 2016
1 parent 18b8053 commit 175d2c8
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/algorithms/imageProcessing/util/MatrixUtil.java
Expand Up @@ -417,7 +417,7 @@ public static double[][] multiplyByTranspose(double[][] p, double[][] n) {
* adapted from http://sebastianraschka.com/Articles/2014_python_lda.html
*
* To apply the results to a data matrix, use y = W^T * X where W is
* the returned matrix.
* the returned matrix (i.e., MatrixUtil.dot(w, data))
* Note, X internally has been scaled to unit standard deviation.
* @param data
* @param classes
Expand Down Expand Up @@ -453,7 +453,7 @@ public static SimpleMatrix createLDATransformation(SimpleMatrix data,
* adapted from http://sebastianraschka.com/Articles/2014_python_lda.html
*
* To apply the results to a data matrix, use y = W^T * X where W is
* the returned matrix.
* the returned matrix (i.e., MatrixUtil.dot(w, normData))
* Note, X internally has been scaled to unit standard deviation.
* @param normData data scaled to mean of 0 and a standard deviation of 1
* @param classes zero based transformed classes
Expand Down
128 changes: 128 additions & 0 deletions tests/algorithms/imageProcessing/util/MatrixUtilTest.java
@@ -1,5 +1,6 @@
package algorithms.imageProcessing.util;

import algorithms.util.PolygonAndPointPlotter;
import algorithms.util.ResourceFinder;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
Expand Down Expand Up @@ -347,6 +348,133 @@ public void testCreateLDATrasformation() throws Exception {
assertTrue(Math.abs(w2.get(1, 3) - 0.750) < 0.01);


int nr = w2.numRows();
int nc = w2.numCols();
int nr2 = normData.numRows();
int nc2 = normData.numCols();
// 2 X 150
SimpleMatrix dataTransformed = new SimpleMatrix(MatrixUtil.dot(w, normData));

float minX = Float.MAX_VALUE;
float maxX = Float.MIN_VALUE;
float minY = Float.MAX_VALUE;
float maxY = Float.MIN_VALUE;
int[] countClasses = new int[nClasses];
for (int col = 0; col < dataTransformed.numCols(); ++col) {
int k = (int)Math.round(classes.get(0, col));
countClasses[k]++;
float x = (float)dataTransformed.get(0, col);
float y = (float)dataTransformed.get(1, col);
if (x < minX) {
minX = x;
}
if (y < minY) {
minY = y;
}
if (x < maxX) {
maxX = x;
}
if (y > maxY) {
maxY = y;
}
}
maxX = 3;
maxY = 3;

PolygonAndPointPlotter plotter = new PolygonAndPointPlotter(minX, maxX,
minY, maxY);

for (int k = 0; k < nClasses; ++k) {
float[] xPoint = new float[countClasses[k]];
float[] yPoint = new float[countClasses[k]];
int count = 0;
for (int col = 0; col < dataTransformed.numCols(); ++col) {
if ((int)Math.round(classes.get(0, col)) != k) {
continue;
}
xPoint[count] = (float)dataTransformed.get(0, col);
yPoint[count] = (float)dataTransformed.get(1, col);
count++;
}
float[] xPoly = null;
float[] yPoly = null;
plotter.addPlot(xPoint, yPoint, xPoly, yPoly, "class " + k);
}
String file1 = plotter.writeFile();

/*System.out.println(String.format("(%.3f, %.3f)",
(float)dataTransformed.get(0, 0),
(float)dataTransformed.get(1, 0)));
System.out.println(String.format("(%.3f, %.3f)",
(float)dataTransformed.get(0, 1),
(float)dataTransformed.get(1, 1)));
System.out.println(String.format("(%.3f, %.3f)",
(float)dataTransformed.get(0, 2),
(float)dataTransformed.get(1, 2)));*/

assertTrue(Math.abs(dataTransformed.get(0, 0) - 1.791) < 0.01);
assertTrue(Math.abs(dataTransformed.get(1, 0) - 0.115) < 0.01);

assertTrue(Math.abs(dataTransformed.get(0, 1) - 1.583) < 0.01);
assertTrue(Math.abs(dataTransformed.get(1, 1) - -0.265) < 0.01);

assertTrue(Math.abs(dataTransformed.get(0, 2) - 1.664) < 0.01);
assertTrue(Math.abs(dataTransformed.get(1, 2) - -0.084) < 0.01);

// --- to make a transformation usable on features not normalized:
SimpleMatrix w3 = MatrixUtil.createLDATransformation(
dataAndClasses[0], dataAndClasses[1]);
SimpleMatrix dataTransformed3 = new SimpleMatrix(MatrixUtil.dot(w3,
dataAndClasses[0]));

minX = Float.MAX_VALUE;
maxX = Float.MIN_VALUE;
minY = Float.MAX_VALUE;
maxY = Float.MIN_VALUE;
countClasses = new int[nClasses];
for (int col = 0; col < dataTransformed3.numCols(); ++col) {
int k = (int)Math.round(classes.get(0, col));
countClasses[k]++;
float x = (float)dataTransformed3.get(0, col);
float y = (float)dataTransformed3.get(1, col);
if (x < minX) {
minX = x;
}
if (y < minY) {
minY = y;
}
if (x < maxX) {
maxX = x;
}
if (y > maxY) {
maxY = y;
}
}
maxX = -1*minX;
maxY = -1*minY;

PolygonAndPointPlotter plotter2 = new PolygonAndPointPlotter(minX, maxX,
minY, maxY);

for (int k = 0; k < nClasses; ++k) {
float[] xPoint = new float[countClasses[k]];
float[] yPoint = new float[countClasses[k]];
int count = 0;
for (int col = 0; col < dataTransformed3.numCols(); ++col) {
if ((int)Math.round(classes.get(0, col)) != k) {
continue;
}
xPoint[count] = (float)dataTransformed3.get(0, col);
yPoint[count] = (float)dataTransformed3.get(1, col);
count++;
}
float[] xPoly = null;
float[] yPoly = null;
plotter2.addPlot(xPoint, yPoint, xPoly, yPoly, "class " + k);
}
String file2 = plotter2.writeFile2();


}

private SimpleMatrix[] readIrisDataset() throws Exception {
Expand Down

0 comments on commit 175d2c8

Please sign in to comment.