diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..466e248
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+out/
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
Classification of float values in Java
+ + +## Description + + +## Getting Started + + +### Dependencies + + +### Installing + + +### Executing program + + +### Error Codes + + +## Scripts + + +## Help + + +## Authors + +Contributors names and contact info + +* Max Wenk + * [@max-acc](https://github.com/max-acc) + +## Version History + + +## License + +This project is licensed under the "GNU Affero General Public License v3.0" License - see the LICENSE.md file for details. + +## Acknowledgments + +* [README Template](https://gist.github.com/DomPizzie/7a5ff55ffa9081f2de27c315f5018afc) \ No newline at end of file diff --git a/src/Main.java b/src/Main.java index ea04f70..177ea7a 100644 --- a/src/Main.java +++ b/src/Main.java @@ -21,6 +21,8 @@ public static void main (String[] args) throws Exception { ob.dataSubdivision(); //System.out.println(ob.feedback()[0][2]); ob.distanceClassification(); + + ob.evaluateResults(); } } diff --git a/src/classification/ClassificationOfFloatValues.java b/src/classification/ClassificationOfFloatValues.java index 7ffc9b3..20557b7 100644 --- a/src/classification/ClassificationOfFloatValues.java +++ b/src/classification/ClassificationOfFloatValues.java @@ -27,6 +27,10 @@ public class ClassificationOfFloatValues { private boolean dataSubdivisionBool = false; // Data has been divided into training and test data private String MLAlgorithm; // Variable for saving which machine learning algorithm has been used + // --- Classification result data variables + private String[][][] predictedTestData; + private int[][] sortedProbability; + // Function to add the members of the class public float[][] output() { return this.predictorData; } @@ -83,6 +87,13 @@ public void dataValidation (float trainingData) { // --- Functions for evaluating the machine learning results ------------------------------------------------------- + public void evaluateResults() { + DATA_evaluation evaluationObject = new DATA_evaluation(this.testDataResults, + this.columnCount - this.numberOfTrainingData, + this.predictedTestData, + this.sortedProbability); + evaluationObject.confusionMatrix(); + } public void confusionMatrix() { if (this.MLAlgorithm == "DistanceClassification") { System.out.println("nice confusion"); @@ -134,6 +145,10 @@ public void distanceClassification (){ // Testing the distance classification model classificationObject.setTestData(this.testDataPredictors, this.testDataResults, this.rowCount, this.columnCount - this.numberOfTrainingData); classificationObject.testModel(); + + // Get the test data + this.predictedTestData = classificationObject.getPredictedTestData(); + this.sortedProbability = classificationObject.getSortedProbability(); } } diff --git a/src/classification/DATA_evaluation.java b/src/classification/DATA_evaluation.java new file mode 100644 index 0000000..f0b2649 --- /dev/null +++ b/src/classification/DATA_evaluation.java @@ -0,0 +1,59 @@ +package classification; + +public class DATA_evaluation { + private String[][][] predictedTestData; + private int[][] sortedProbability; + + private int columnCount; + private int numberOfClasses; + private String [] testDataResults; + private int[][] confustionMatrix = new int[3][2]; + + protected DATA_evaluation(String[] testDataResults, int columnCount, String[][][] predictedTestData, int[][] sortedProbability) { + this.testDataResults = testDataResults; + this.columnCount = columnCount; + this.predictedTestData = predictedTestData; + this.sortedProbability = sortedProbability; + } + + protected void confusionMatrix() { + System.out.println(this.testDataResults[0]); + System.out.println(this.columnCount); + System.out.println(this.predictedTestData[0][0][0]); + System.out.println(this.predictedTestData[0][0][1]); + System.out.println(this.predictedTestData[0][1][0]); + System.out.println(this.predictedTestData[0][1][1]); + System.out.println(this.predictedTestData[0][2][0]); + System.out.println(this.predictedTestData[0][2][1]); + System.out.println(this.sortedProbability[0][0]); + System.out.println(this.sortedProbability[0][1]); + System.out.println(this.sortedProbability[0][2]); + + // Resetting the confusion matrix + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 2; j++) { + this.confustionMatrix[i][j] = 0; + } + } + + for (int i = 0; i < this.columnCount; i++) { + if (this.testDataResults[i].equals(this.predictedTestData[i][this.sortedProbability[i][0]][0])) { + this.confustionMatrix[this.sortedProbability[i][0]][0]++; + } + else { + this.confustionMatrix[this.sortedProbability[i][0]][1]++; + } + } + + + + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 2; j++) { + System.out.print(this.confustionMatrix[i][j] + " "); + } + System.out.println(); + } + + + } +} \ No newline at end of file diff --git a/src/classification/DistanceClassification.java b/src/classification/DistanceClassification.java index 33678c2..2ee4009 100644 --- a/src/classification/DistanceClassification.java +++ b/src/classification/DistanceClassification.java @@ -21,7 +21,8 @@ public class DistanceClassification { private int testDataRowCount; private int testDataColumnCount; - private String[][] predictedTestData; + private String[][][] predictedTestData; + private int[][] sortedProbability; protected DistanceClassification(float [][] trainingDataPredictors, String [] trainingDataResults, int rowCount, int columnCount, float density) { @@ -42,6 +43,14 @@ protected float[][][] getSortedClassificationData() { return this.sortedClassificationData; } + protected String[][][] getPredictedTestData() { + return this.predictedTestData; + } + + protected int[][] getSortedProbability() { + return this.sortedProbability; + } + protected float[][] getFeatureMean() { return this.featureMean; } @@ -54,7 +63,7 @@ protected void setTestData(float[][] testDataPredictors, String[] testDataResult } protected void testModel() { - testClassifcationModel(); + testClassificationModel(); } private void getClassificationClasses() { @@ -141,10 +150,12 @@ private void calcFeatureMean() { } } - private void testClassifcationModel() { - this.predictedTestData = new String[testDataColumnCount][2]; + private void testClassificationModel() { + this.predictedTestData = new String[this.testDataColumnCount][this.numberOfClasses][2]; + this.sortedProbability = new int[this.testDataColumnCount][this.numberOfClasses]; // Check the distance for every class + for (int i = 0; i < this.testDataColumnCount; i++) { float[][] tempDelta = new float[this.numberOfClasses][this.rowCount]; for (int j = 0; j < this.numberOfClasses; j++) { @@ -152,36 +163,55 @@ private void testClassifcationModel() { tempDelta[j][k] = this.featureMean[j][k] - this.testDataPredictors[i][k]; } - tempDelta[j][this.rowCount -1] = (float) Math.sqrt( - Math.pow(tempDelta[j][0], 2) + - Math.pow(tempDelta[j][1], 2) + - Math.pow(tempDelta[j][2], 2) + - Math.pow(tempDelta[j][3], 2)); - System.out.print(classes[j] + " "); - System.out.println(tempDelta[j][this.rowCount -1]); - } - + this.predictedTestData[i][j][0] = this.classes[j]; + float tempCalcDistance = 0; - /* - for (int j = 0; j < this.numberOfClasses; j++) { for (int k = 0; k < this.rowCount -1; k++) { - System.out.print(tempDelta[j][k]); - System.out.print(" "); + tempCalcDistance += (float) Math.pow(tempDelta[j][k], 2); } - System.out.println(); - } - */ - this.predictedTestData[i][0] = "best result"; - this.predictedTestData[i][1] = "probability"; + tempDelta[j][this.rowCount -1] = (float) Math.sqrt(tempCalcDistance); + this.predictedTestData[i][j][1] = Float.toString(tempDelta[j][this.rowCount -1]); - //this.predictedTestData[i] = "best result"; - } - + System.out.print(" " + this.predictedTestData[i][j][0] + " "); + System.out.println(this.predictedTestData[i][j][1]); + } + float min = Float.valueOf(this.predictedTestData[i][0][1]); + int tempIndex = 0; + this.sortedProbability[i][0] = tempIndex; + for (int j = 0; j < this.numberOfClasses; j++) { + if (min > Float.valueOf(this.predictedTestData[i][j][1])) { + min = Float.valueOf(this.predictedTestData[i][j][1]); + tempIndex = j; + } + } + this.sortedProbability[i][0] = tempIndex; + for (int j = 1; j < this.numberOfClasses; j++) { + tempIndex = 0; + this.sortedProbability[i][j] = tempIndex; + min = Float.valueOf(this.predictedTestData[i][this.sortedProbability[i][j-1]][1]); + for (int k = 0; k < this.numberOfClasses; k++) { + if ((min > Float.valueOf(this.predictedTestData[i][k][1]) && + Float.valueOf(this.predictedTestData[i][k][1]) > Float.valueOf(this.predictedTestData[i][this.sortedProbability[i][j-1]][1])) || + (Float.valueOf(this.predictedTestData[i][k][1]) > Float.valueOf(this.predictedTestData[i][this.sortedProbability[i][j-1]][1]) && + min == Float.valueOf(this.predictedTestData[i][this.sortedProbability[i][j-1]][1]))) { + if (min > Float.valueOf(this.predictedTestData[i][k][1]) && + Float.valueOf(this.predictedTestData[i][k][1]) > Float.valueOf(this.predictedTestData[i][this.sortedProbability[i][j-1]][1])) { + } + if (Float.valueOf(this.predictedTestData[i][k][1]) > Float.valueOf(this.predictedTestData[i][this.sortedProbability[i][j-1]][1]) && + k != this.sortedProbability[i][j-1]) { + } + min = Float.valueOf(this.predictedTestData[i][k][1]); + tempIndex = k; + } + } + this.sortedProbability[i][j] = tempIndex; + } + } } }