Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Version 1.0 Major Commit

I should have committed sooner.  This is the first "working" cut.  Note the entry point class:  LaunchClassifier, and the arguments it takes.  Also note due to the complexity (time) of K-Nearest Neighbor, the test data has been trimmed down a bit more, and the old data that had 200 lines moved to their own respective subdirectories in test_data and training_data.  Skeleton code still exists but will be removed.  Trigger LaunchClassifier with no arguments to get usage printout.
  • Loading branch information...
commit a943ee77a95e77f09f40318f3ba09f21bf15aee3 1 parent 3bcbf9c
@jweaver authored
Showing with 8,008 additions and 4,096 deletions.
  1. +6 −0 runKNN8Test.sh
  2. +129 −0 src/com/weaverworx/usc/csci561/assignment3/LaunchClassifier.java
  3. +0 −58 src/com/weaverworx/usc/csci561/assignment3/Launcher.java
  4. +6 −6 src/com/weaverworx/usc/csci561/assignment3/skeleton/KNearestNeighbor_Skeleton.java
  5. +269 −155 src/com/weaverworx/usc/csci561/assignment3/skeleton/NaiveBayes_Skeleton.java
  6. +32 −0 src/com/weaverworx/usc/csci561/assignment3/util/ClassifierTypes.java
  7. +0 −97 src/com/weaverworx/usc/csci561/assignment3/util/KNNUtil.java
  8. +276 −0 src/com/weaverworx/usc/csci561/assignment3/util/LearningUtil.java
  9. +199 −0 test_data/200_lines/test0.txt
  10. +199 −0 test_data/200_lines/test1.txt
  11. +199 −0 test_data/200_lines/test2.txt
  12. +199 −0 test_data/200_lines/test3.txt
  13. +199 −0 test_data/200_lines/test4.txt
  14. +199 −0 test_data/200_lines/test5.txt
  15. +199 −0 test_data/200_lines/test6.txt
  16. +199 −0 test_data/200_lines/test7.txt
  17. +199 −0 test_data/200_lines/test8.txt
  18. +199 −0 test_data/200_lines/test9.txt
  19. +39 −189 test_data/test0.txt
  20. +47 −189 test_data/test1.txt
  21. +42 −189 test_data/test2.txt
  22. +41 −189 test_data/test3.txt
  23. +40 −189 test_data/test4.txt
  24. +35 −189 test_data/test5.txt
  25. +38 −189 test_data/test6.txt
  26. +42 −189 test_data/test7.txt
  27. +39 −189 test_data/test8.txt
  28. +41 −189 test_data/test9.txt
  29. +199 −0 training_data/200_lines/train0.txt
  30. +199 −0 training_data/200_lines/train1.txt
  31. +199 −0 training_data/200_lines/train2.txt
  32. +199 −0 training_data/200_lines/train3.txt
  33. +199 −0 training_data/200_lines/train4.txt
  34. +199 −0 training_data/200_lines/train5.txt
  35. +199 −0 training_data/200_lines/train6.txt
  36. +199 −0 training_data/200_lines/train7.txt
  37. +199 −0 training_data/200_lines/train8.txt
  38. +199 −0 training_data/200_lines/train9.txt
  39. +287 −189 training_data/train0.txt
  40. +328 −189 training_data/train1.txt
  41. +288 −189 training_data/train2.txt
  42. +297 −189 training_data/train3.txt
  43. +283 −189 training_data/train4.txt
  44. +262 −189 training_data/train5.txt
  45. +286 −189 training_data/train6.txt
  46. +304 −189 training_data/train7.txt
  47. +283 −189 training_data/train8.txt
  48. +288 −189 training_data/train9.txt
View
6 runKNN8Test.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+#
+# Run script to test K-Nearest Neighbor (KNN) with K = 8.
+
+
+java -cp bin com.weaverworx.usc.csci561.assignment3.LaunchClassifier knn -k=8
View
129 src/com/weaverworx/usc/csci561/assignment3/LaunchClassifier.java
@@ -0,0 +1,129 @@
+/**
+ * File: Launcher.java
+ * Author: Jack Weaver <jhweaver@usc.edu>, <weaver.jack@gmail.com>
+ * Course: CSCI 561, Spring 2012
+ * Assignment: Assignment 3 - Supervised Learning Systems
+ * Target: aludra.usc.edu running Java 1.6.0_23
+ */
+package com.weaverworx.usc.csci561.assignment3;
+
+import com.weaverworx.usc.csci561.assignment3.knn.KNNRecord;
+import com.weaverworx.usc.csci561.assignment3.util.ClassifierTypes;
+import com.weaverworx.usc.csci561.assignment3.util.FileReader;
+import com.weaverworx.usc.csci561.assignment3.util.LearningUtil;
+
+/**
+ * Main class to launch the application.
+ *
+ * @author jw
+ *
+ */
+public class LaunchClassifier {
+ private final static int EX_CLASS_INDEX = 784;
+
+ /**
+ * Entry point.
+ *
+ * @param args
+ */
+ public static void main(String[] args) {
+ ClassifierTypes classifier = LearningUtil.parseClassifierSystem(args);
+
+ if (classifier.name().compareTo(ClassifierTypes.K_NEAREST_NEIGHBOR.name()) == 0) {
+ int k = LearningUtil.parseKArgument(args); // Read k from user input
+ // Get the training data and test data
+ int[][] trainData = FileReader.getTrainingData(LearningUtil.NUMBER_OF_FEATURES,
+ LearningUtil.NUMBER_OF_CLASSES);
+ int[][] testData = FileReader.getTestData(LearningUtil.NUMBER_OF_FEATURES,
+ LearningUtil.NUMBER_OF_CLASSES);
+
+ // Set up the K-Nearest Neighbor Records
+ KNNRecord[] knnRecords = new KNNRecord[trainData.length];
+ for (int i = 0; i < trainData.length; i++) {
+ knnRecords[i] = new KNNRecord();
+ }
+
+ // Set up the arrays for # correct and # incorrect
+ int[] correct = new int[LearningUtil.NUMBER_OF_CLASSES], incorrect =
+ new int[LearningUtil.NUMBER_OF_CLASSES];
+ for (int i = 0; i < testData.length; i++) {
+ for (int j = 0; j < trainData.length; j++) {
+ // Set the distance & the record class
+ knnRecords[j].setExampleClass(trainData[j][EX_CLASS_INDEX]);
+ knnRecords[j].setDistance(LearningUtil
+ .getEuclideanDistance(testData[i], trainData[j]));
+ }
+ int actualClass = testData[i][EX_CLASS_INDEX];
+ int predictedClass = LearningUtil.predictKNN(k, knnRecords);
+ // Set the counters for accuracy, every time we correctly
+ // predict
+ // the numerical class, tally it. Otherwise tally the miss.
+ if (actualClass == predictedClass) {
+ correct[actualClass]++;
+ } else {
+ incorrect[actualClass]++;
+ }
+ }
+ // Display the results to Stdout
+ LearningUtil.outputKNNResultsToStdOut(k, correct, incorrect);
+
+
+ } else if (classifier.name().compareTo(ClassifierTypes.NAIVE_BAYES.name()) == 0) {
+ int threshold = LearningUtil.parseTArgument(args);
+ double[] N_Y = new double[LearningUtil.NUMBER_OF_CLASSES];
+ double[] P_Y = new double[LearningUtil.NUMBER_OF_CLASSES];
+ int[][] trainingData = FileReader.getTrainingData(LearningUtil.NUMBER_OF_FEATURES,
+ LearningUtil.NUMBER_OF_CLASSES);
+ int[][] testData = FileReader.getTestData(LearningUtil.NUMBER_OF_FEATURES,
+ LearningUtil.NUMBER_OF_CLASSES);
+
+ //Setup N_Y: the number of times a given training data entry
+ //appears among ALL the training data, ie- N_Y[1] is # of 1s
+ int totalTrainingSize = 0;
+ for (int i = 0; i < LearningUtil.NUMBER_OF_CLASSES; i++) {
+ for (int j = 0; j < trainingData.length; j++) {
+ //For each "record" in all test data
+ if (trainingData[j][trainingData[j].length -1] == i) {
+ N_Y[i] = N_Y[i] + 1;
+ totalTrainingSize++;
+ }
+ }
+ if (i == LearningUtil.NUMBER_OF_CLASSES - 1) {
+ for (int k = 0; k < P_Y.length; k++) {
+ P_Y[k] = N_Y[k] / totalTrainingSize;
+ }
+ }
+ }
+
+ double[][][] P_X_given_Y =
+ LearningUtil.getBayesConditionalProbabilities(threshold,
+ trainingData, N_Y);
+ int[] correct = new int[LearningUtil.NUMBER_OF_CLASSES];
+ int[] incorrect = new int[LearningUtil.NUMBER_OF_CLASSES];
+ for (int i = 0; i < testData.length; i++) { // for each test example
+ int actual_class = testData[i][EX_CLASS_INDEX];
+ /*
+ * predict by using P_Y and P_X_given_Y parameters threshold used
+ * for converting to binary (1,0) data format.
+ */
+ int predict_class = LearningUtil.predictBayes(testData[i], threshold,
+ P_Y, P_X_given_Y);
+ if (actual_class == predict_class) {
+ /*
+ * if actual_class same as predict_class,
+ * increasing correct of that class
+ * (correct[actual_class])
+ */
+ correct[actual_class]++;
+ } else {
+ incorrect[actual_class]++;
+ }
+ }
+ LearningUtil.outputBayesResultsToStdOut(threshold, correct, incorrect);
+ // End of NAIVE BAYES
+ } else {
+ System.out.println(LearningUtil.getUsage());
+ System.exit(0);
+ }
+ }
+}
View
58 src/com/weaverworx/usc/csci561/assignment3/Launcher.java
@@ -1,58 +0,0 @@
-/**
- * File: Launcher.java
- * Author: Jack Weaver <jhweaver@usc.edu>, <weaver.jack@gmail.com>
- * Course: CSCI 561, Spring 2012
- * Assignment: Assignment 3 - Supervised Learning Systems
- * Target: aludra.usc.edu running Java 1.6.0_23
- */
-package com.weaverworx.usc.csci561.assignment3;
-
-import com.weaverworx.usc.csci561.assignment3.knn.KNNRecord;
-import com.weaverworx.usc.csci561.assignment3.util.FileReader;
-import com.weaverworx.usc.csci561.assignment3.util.KNNUtil;
-
-/**
- * Main class to launch the application.
- *
- * @author jw
- *
- */
-public class Launcher {
- private final static int NUMBER_OF_CLASSES = 10;
- private final static int NUMBER_OF_FEATURES = 784;
- private final static int EX_CLASS_INDEX = 784;
-
- /**
- * Entry point.
- *
- * @param args
- */
- public static void main(String[] args) {
- int k = KNNUtil.parseKArgument(args); //Read k from user input
- //Get the training data and test data
- int[][] trainData = FileReader.getTrainingData(
- NUMBER_OF_FEATURES, NUMBER_OF_CLASSES);
- int[][] testData = FileReader.getTestData(
- NUMBER_OF_FEATURES, NUMBER_OF_CLASSES);
-
- //Set up the K-Nearest Neighbor Records
- KNNRecord[] knnRecords = new KNNRecord[trainData.length];
- for (int i = 0; i < trainData.length; i++) {
- knnRecords[i] = new KNNRecord();
- }
-
- //Set up the arrays for # correct and # incorrect
- int[] correct, incorrect = new int[NUMBER_OF_CLASSES];
- for (int i = 0; i < testData.length; i++) {
- for (int j = 0; j < trainData.length; j++) {
- //Set the distance & the record class
- knnRecords[j].setExampleClass(trainData[j][EX_CLASS_INDEX]);
- knnRecords[j].setDistance(KNNUtil.getEuclideanDistance(
- testData[i], trainData[j]));
- }
-
- }
-
-
- }
-}
View
12 src/com/weaverworx/usc/csci561/assignment3/skeleton/KNearestNeighbor_Skeleton.java
@@ -6,7 +6,7 @@
import com.weaverworx.usc.csci561.assignment3.knn.KNNRecord;
import com.weaverworx.usc.csci561.assignment3.util.FileReader;
-import com.weaverworx.usc.csci561.assignment3.util.KNNUtil;
+import com.weaverworx.usc.csci561.assignment3.util.LearningUtil;
/**
*
@@ -20,12 +20,12 @@
public static void main(String args[]) {
- int K = KNNUtil.parseKArgument(args); // read K from user input
+ int K = LearningUtil.parseKArgument(args); // read K from user input
int[][] train = FileReader.getTrainingData(numberOfFeatures, numberOfClasses); // last column 785 is a class of train
// image
int[][] test = FileReader.getTestData(numberOfFeatures, numberOfClasses); // last column 785 is a class of test
- // imag e
+ // image
KNNRecord[] knn_records = new KNNRecord[train.length]; //
for (int j = 0; j < train.length; j++) {
@@ -43,11 +43,11 @@ public static void main(String args[]) {
// knn_records[j].example_class = train[j][example_class_index]; // class is train class
knn_records[j].setExampleClass(train[j][example_class_index]);
// knn_records[j].distance = getEuclideanDistance(test[i], train[j]); // difference between test i and train j
- knn_records[j].setDistance(KNNUtil.getEuclideanDistance(test[i], train[j]));
+ knn_records[j].setDistance(LearningUtil.getEuclideanDistance(test[i], train[j]));
}
int actual_class = test[i][example_class_index];
- int predict_class = KNNUtil.predict(K, knn_records);
+ int predict_class = LearningUtil.predictKNN(K, knn_records);
if (actual_class == predict_class) {
correct[actual_class]++; // if actual_class same as
@@ -61,7 +61,7 @@ public static void main(String args[]) {
}
// display output
- KNNUtil.outputResultsToStdOut(K, correct, incorrect, numberOfClasses);
+ LearningUtil.outputKNNResultsToStdOut(K, correct, incorrect);
}
View
424 src/com/weaverworx/usc/csci561/assignment3/skeleton/NaiveBayes_Skeleton.java
@@ -6,166 +6,280 @@
import java.io.*;
import java.util.*;
+
+import com.weaverworx.usc.csci561.assignment3.util.FileReader;
+import com.weaverworx.usc.csci561.assignment3.util.LearningUtil;
+
/**
- *
+ *
* @author femto
*/
public class NaiveBayes_Skeleton {
-final static int numberOfClasses = 10;
-final static int numberOffeatures = 784;
-final static int example_class_index = 784;
-
-public static void main(String args[]){
-
-int threshold = readUserInput(); // read K from user input
-
-int [][] train = readTrainFile(); // last column 785 is a class of train image
-int [][] test = readTestFile(); // last column 785 is a class of test image
-
-double[] P_Y = getClassProbabilities(train); // Learn P(Y) parameters
-double[][][] P_X_given_Y = getConditionalProbabilities(threshold ,train); //Learn P(X|Y) parameters
- // threshold used for binarized
-
-int[] correct = new int[numberOfClasses];
-int[] incorrect = new int[numberOfClasses];
-
-
-
-for(int i=0;i<test.length;i++){ // for each test exaple
-
-int actual_class = test[i][example_class_index];
-int predict_class = predict(test[i],threshold,P_Y,P_X_given_Y);// predict by using P_Y and P_X_given_Y parameters
- // threshold used for binarized
-if(actual_class == predict_class){
- correct[actual_class]++; //if actual_class same as predict_class, increasing correct of that class (correct[actual_class])
- }else{
- incorrect[actual_class]++;
- }
-
-}
-
-// display output
-
-System.out.println("Threshold = " + threshold);
-
-for(int i =0;i<10;i++){
- double accuracy = (correct[i]*1.0)/(correct[i] + incorrect[i]);
- System.out.println("Class = " + i + " Correct = " + correct[i] + " Incorrect = " + incorrect[i] + " Accuracy = " + accuracy);
-}
-
-}
-
-static int readUserInput(){
- return 100;
-}
-static int[][] readTrainFile(){
-
- ArrayList<String> tem = new ArrayList<String>();
-
-
- for(int i=0;i<numberOfClasses;i++){
- try{
-
- String filename = "train"+i+".txt";
-
-FileInputStream fstream = new FileInputStream(filename);
- DataInputStream in = new DataInputStream(fstream);
- BufferedReader br = new BufferedReader(new InputStreamReader(in));
- String strLine;
-
- //Read File Line By Line with their class
- while ((strLine = br.readLine()) != null) {
- tem.add(strLine + "," + i);
- }
- //Close the input stream
- in.close();
- }catch (Exception e){//Catch exception if any
- e.printStackTrace();
- }
-
- }
- int[][] train = new int[tem.size()][numberOffeatures+1];
-
- for(int i=0;i<tem.size();i++){
- String[] line = tem.get(i).split(",");
-
- for(int j=0;j<line.length;j++){
- train[i][j] = Integer.parseInt(line[j]);
- }
-
- }
-
- return train;
-}
-
-static int[][] readTestFile(){
-
- ArrayList<String> tem = new ArrayList<String>();
-
-
- for(int i=0;i<numberOfClasses;i++){
- try{
-
- String filename = "test"+i+".txt";
-
-
- FileInputStream fstream = new FileInputStream(filename);
- DataInputStream in = new DataInputStream(fstream);
- BufferedReader br = new BufferedReader(new InputStreamReader(in));
- String strLine;
-
- //Read File Line By Line with their class
- while ((strLine = br.readLine()) != null) {
- tem.add(strLine + "," + i);
-
- }
-
-
- //Close the input stream
- in.close();
- }catch (Exception e){//Catch exception if any
- e.printStackTrace();
- }
-
- }
- int[][] test = new int[tem.size()][numberOffeatures+1];
-
- for(int i=0;i<tem.size();i++){
- String[] line = tem.get(i).split(",");
-
- for(int j=0;j<line.length;j++){
- test[i][j] = Integer.parseInt(line[j]);
- }
-
- }
-
- return test;
-}
-
-static double[] getClassProbabilities(int [][] train){
-
- double[] P_Y = new double[numberOfClasses];
-
- return P_Y;
-}
-
-static double[][][] getConditionalProbabilities(int threshold,int [][] train){
- double[][][] P_X_given_Y = new double[numberOfClasses][numberOffeatures][2];
-
- return P_X_given_Y;
-}
-
-static int predict(int[] input_image,double threshold,double[] P_Y,double[][][] P_X_given_Y){
-
-
- int max_index =0;
-
-
- return max_index;
-}
-
-
+ final static int numberOfClasses = 10;
+ final static int numberOfFeatures = 784;
+ final static int exampleClassIndex = 784;
+
+ private static double[] N_Y = new double[numberOfClasses];
+ private static double[] P_Y = new double[numberOfClasses];
+
+
+ public static void main(String args[]) {
+
+ int threshold = LearningUtil.parseKArgument(args); // read threshold
+
+ int[][] trainingData = FileReader.getTrainingData(numberOfFeatures, numberOfClasses); // last column 785 is a class of train
+ // image
+ int[][] testData = FileReader.getTestData(numberOfFeatures, numberOfClasses); // last column 785 is a class of test
+ // image
+
+ //Setup N_Y: the number of times a given training data appears amongst ALL the training data.
+ //eg - number of 1's in all training data: N_Y[1]
+ int totalTrainingSize = 0;
+ for (int i = 0; i < numberOfClasses; i++) {
+
+ for (int j = 0; j < trainingData.length; j++) {
+ //For each "record" in all test data
+ //
+ if (trainingData[j][trainingData[j].length -1] == i) {
+ N_Y[i] = N_Y[i] + 1;
+ totalTrainingSize++;
+ }
+ }
+ if (i == numberOfClasses - 1) {
+ for (int k = 0; k < P_Y.length; k++) {
+ P_Y[k] = N_Y[k] / totalTrainingSize;
+ }
+ }
+ }
+
+// double[] P_Y = getClassProbabilities(train); // Learn P(Y) parameters
+ double[][][] P_X_given_Y = getConditionalProbabilities(threshold, trainingData); // Learn
+ // P(X|Y)
+ // parameters
+ // threshold
+ // used
+ // for
+ // binarized
+
+ int[] correct = new int[numberOfClasses];
+ int[] incorrect = new int[numberOfClasses];
+
+ for (int i = 0; i < testData.length; i++) { // for each test example
+
+ int actual_class = testData[i][exampleClassIndex];
+ int predict_class = predict(testData[i], threshold, P_Y, P_X_given_Y);// predict
+ // by
+ // using
+ // P_Y
+ // and
+ // P_X_given_Y
+ // parameters
+ // threshold used for binarized
+ if (actual_class == predict_class) {
+ correct[actual_class]++; // if actual_class same as
+ // predict_class, increasing correct
+ // of that class
+ // (correct[actual_class])
+ } else {
+ incorrect[actual_class]++;
+ }
+
+ }
+
+ // display output
+
+ System.out.println("Threshold = " + threshold);
+
+ for (int i = 0; i < 10; i++) {
+ double accuracy = (correct[i] * 1.0) / (correct[i] + incorrect[i]);
+ System.out.println("Class = " + i + " Correct = " + correct[i]
+ + " Incorrect = " + incorrect[i] + " Accuracy = "
+ + accuracy);
+ }
+
+ }
+
+ static int readUserInput() {
+ return 100;
+ }
+
+ static int[][] readTrainFile() {
+
+ ArrayList<String> tem = new ArrayList<String>();
+
+ for (int i = 0; i < numberOfClasses; i++) {
+ try {
+
+ String filename = "train" + i + ".txt";
+
+ FileInputStream fstream = new FileInputStream(filename);
+ DataInputStream in = new DataInputStream(fstream);
+ BufferedReader br = new BufferedReader(
+ new InputStreamReader(in));
+ String strLine;
+
+ // Read File Line By Line with their class
+ while ((strLine = br.readLine()) != null) {
+ tem.add(strLine + "," + i);
+ }
+ // Close the input stream
+ in.close();
+ } catch (Exception e) {// Catch exception if any
+ e.printStackTrace();
+ }
+
+ }
+ int[][] train = new int[tem.size()][numberOfFeatures + 1];
+
+ for (int i = 0; i < tem.size(); i++) {
+ String[] line = tem.get(i).split(",");
+
+ for (int j = 0; j < line.length; j++) {
+ train[i][j] = Integer.parseInt(line[j]);
+ }
+
+ }
+
+ return train;
+ }
+
+ static int[][] readTestFile() {
+
+ ArrayList<String> tem = new ArrayList<String>();
+
+ for (int i = 0; i < numberOfClasses; i++) {
+ try {
+
+ String filename = "test" + i + ".txt";
+
+ FileInputStream fstream = new FileInputStream(filename);
+ DataInputStream in = new DataInputStream(fstream);
+ BufferedReader br = new BufferedReader(
+ new InputStreamReader(in));
+ String strLine;
+
+ // Read File Line By Line with their class
+ while ((strLine = br.readLine()) != null) {
+ tem.add(strLine + "," + i);
+
+ }
+
+ // Close the input stream
+ in.close();
+ } catch (Exception e) {// Catch exception if any
+ e.printStackTrace();
+ }
+
+ }
+ int[][] test = new int[tem.size()][numberOfFeatures + 1];
+
+ for (int i = 0; i < tem.size(); i++) {
+ String[] line = tem.get(i).split(",");
+
+ for (int j = 0; j < line.length; j++) {
+ test[i][j] = Integer.parseInt(line[j]);
+ }
+
+ }
+
+ return test;
+ }
+
+ static double[] getClassProbabilities(int[][] train) {
+
+ double[] P_Y = new double[numberOfClasses];
+
+ return P_Y;
+ }
+
+ static double[][][] getConditionalProbabilities(int threshold, int[][] trainingData) {
+ //Convert training data to 'binary training data'
+ int[][] binaryTrainingData = new int[trainingData.length][trainingData[trainingData.length - 1].length];
+ for (int i = 0; i < trainingData.length; i++) {
+ for (int j = 0; j < trainingData[trainingData.length - 1].length; j++) {
+ if (trainingData[i][j] <= threshold) {
+ binaryTrainingData[i][j] = 0;
+ }
+ else {
+ binaryTrainingData[i][j] = 1;
+ }
+ }
+ }
+
+
+ double[][][] P_X_given_Y = new double[numberOfClasses][numberOfFeatures][2];
+ double[][][] N_XY = new double[numberOfClasses][numberOfFeatures][2];
+
+// int c0 = 0, c1 = 0;
+ for (int i = 0; i < binaryTrainingData.length; i++) { //For each image in training data...
+ for (int j = 0; j < numberOfFeatures; j++) {
+ //Iterate over every image in training data, incrementing it's
+ //corresponding slot in the N_XY array
+ if (binaryTrainingData[i][j] == 0) {
+ N_XY[trainingData[i][trainingData[i].length -1]][j][0]++;
+ } else {
+ N_XY[trainingData[i][trainingData[i].length -1]][j][1]++;
+ }
+ }
+ }
+
+// for (int i = 0; i < numberOfClasses; i++) {
+// for (int j = 0; j < numberOfFeatures; j++) {
+//
+//
+// if (binaryTrainingData[i][j] == 0) {
+// N_XY[i][j][0] = c0;
+// c0 = 0;
+// } else {
+// N_XY[i][j][1] = c1;
+// c1 = 0;
+// }
+// }
+// }
+
+
+ for(int i = 0; i < numberOfClasses; i++) {
+ for(int j = 0; j < numberOfFeatures; j++) {
+ P_X_given_Y[i][j][0] = (N_XY[i][j][0] + 1) / (N_Y[i] + 2);
+ P_X_given_Y[i][j][1] = (N_XY[i][j][1] + 1) / (N_Y[i] + 2);
+ }
+ }
+
+
+ return P_X_given_Y;
+ }
+
+ static int predict(int[] inputImage, double threshold, double[] P_Y, double[][][] P_X_given_Y) {
+ int[] binaryInputImage = new int[inputImage.length];
+ for (int i = 0; i < binaryInputImage.length; i++) {
+ if (inputImage[i] > threshold) {
+ binaryInputImage[i] = 1;
+ } else {
+ binaryInputImage[i] = 0;
+ }
+ }
+
+
+ double[] results = new double[numberOfClasses];
+ for (int i = 0; i < results.length; i++) {
+ results[i] = Math.log10(P_Y[i] * 100);
+ for (int j = 0; j < numberOfFeatures; j++) {
+ results[i] = results[i] + Math.log10(P_X_given_Y[i][j][binaryInputImage[j]] * 100);
+ }
+ }
+
+ //Find the max index
+ int max_index = 0;
+ double max_num = 0;
+ for (int i = 0; i < results.length; i++) {
+ if (results[i] > max_num) {
+ max_num = results[i];
+ max_index = i;
+ }
+ }
+ return max_index;
+ }
}
View
32 src/com/weaverworx/usc/csci561/assignment3/util/ClassifierTypes.java
@@ -0,0 +1,32 @@
+/**
+ * File: CLASSIFIERS.java
+ * Author: Jack Weaver <jhweaver@usc.edu>
+ * Course: CSCI 561, Spring 2012
+ * Assignment: Assignment 3 - Supervised Learning Systems
+ * Target: aludra.usc.edu running Java 1.6.0_23
+ */
+package com.weaverworx.usc.csci561.assignment3.util;
+
+/**
+ * Standard list of input the user provides to internal mapping of the given
+ * classifier.
+ *
+ * @author jw
+ *
+ */
+public enum ClassifierTypes {
+ K_NEAREST_NEIGHBOR("knn"),
+ NAIVE_BAYES("nb");
+
+ private final String text;
+
+ private ClassifierTypes(String name) {
+ this.text = name;
+ }
+
+ @Override
+ public String toString() {
+ return this.text;
+ }
+
+}
View
97 src/com/weaverworx/usc/csci561/assignment3/util/KNNUtil.java
@@ -1,97 +0,0 @@
-/**
- * File: KNNUtil.java
- * Author: Jack Weaver <jhweaver@usc.edu>
- * Course: CSCI 561, Spring 2012
- * Assignment: Assignment 3 - Supervised Learning Systems
- * Target: aludra.usc.edu running Java 1.6.0_23
- */
-package com.weaverworx.usc.csci561.assignment3.util;
-
-import java.text.DecimalFormat;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-import com.weaverworx.usc.csci561.assignment3.knn.KNNRecord;
-
-/**
- * @author jw
- *
- */
-public class KNNUtil {
- private static final String TEN_SPACES = " ";
-
- private KNNUtil() { }
-
- public static int parseKArgument(String arguments[]) {
- int k = 0;
- String kString = arguments[0];
- k = Integer.parseInt(kString);
- return k;
- }
-
- public static void outputResultsToStdOut(int k, int[] correct, int[] incorrect,
- int numberOfClasses) {
- StringBuilder sb = new StringBuilder();
- sb.append("K = " + k + "\n");
- //Build the table column headers
- sb.append("Class").append(TEN_SPACES).append("Correct").append(TEN_SPACES)
- .append("Incorrect").append(TEN_SPACES).append("Accuracy");
- sb.append("\n");
- //Fill in the table
- DecimalFormat accuracyFormatter = new DecimalFormat("#.#####");
- for (int i = 0; i < numberOfClasses; i++) {
- double accuracy = (correct[i] * 1.0) / (correct[i] + incorrect[i]);
- sb.append(i).append(" ").append(TEN_SPACES);//Class
- sb.append(correct[i]).append(" ").append(TEN_SPACES);//# Correct
- sb.append(incorrect[i]).append(" ").append(TEN_SPACES);//# Incorrect
- sb.append(accuracyFormatter.format(accuracy)).append("\n");
- }
- System.out.println(sb.toString());
- }
-
- public static double getEuclideanDistance(int[] v1, int[] v2) {
- double distance = 0; // set distance to 0
- double runningTotal = 0;
- double delta = 0;
-
- for (int i : v1) {
- delta = (v1[i] - v2[i]);
- delta = Math.pow(delta, 2d);
- runningTotal += delta;
- }
- distance = Math.sqrt(runningTotal);
- return distance;
- }
-
- /**
- *
- * @param K
- * @param knn
- * @return
- */
- public static int predict(int K, KNNRecord[] knn) {
- //Prep the data
- List<KNNRecord> recordList = Arrays.asList(knn);
- Collections.sort(recordList);
-
- //Do the voting
- int[] voteArray = new int[10]; //index position represents the numerical class
- for(int i = 0; i < K; i++) {
- KNNRecord voter = recordList.get(i);
- int number_class = voter.getExampleClass();
- voteArray[number_class]++;
- }
-
- //Find out who got the most votes
- int winner = -1; //winner is the index with the most votes in the array
- int winnerVotes = 0; //the actual *count* of the most votes
- for (int i = 0; i < voteArray.length; i++) {
- if (voteArray[i] > winnerVotes) {
- winnerVotes = voteArray[i];
- winner = i;
- }
- }
- return winner;
- }
-}
View
276 src/com/weaverworx/usc/csci561/assignment3/util/LearningUtil.java
@@ -0,0 +1,276 @@
+/**
+ * File: LearningUtil.java
+ * Author: Jack Weaver <jhweaver@usc.edu>
+ * Course: CSCI 561, Spring 2012
+ * Assignment: Assignment 3 - Supervised Learning Systems
+ * Target: aludra.usc.edu running Java 1.6.0_23
+ */
+package com.weaverworx.usc.csci561.assignment3.util;
+
+import java.text.DecimalFormat;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Scanner;
+
+import com.weaverworx.usc.csci561.assignment3.LaunchClassifier;
+import com.weaverworx.usc.csci561.assignment3.knn.KNNRecord;
+
+/**
+ * Provides for common utility across both K-Nearest Neighbor (KNN) and Naive
+ * Bayes classifier algorithms.
+ *
+ * @author jw
+ *
+ */
+public class LearningUtil {
+ /*
+ * Hard-coded value of the number of classes (record types) for this
+ * classifier.
+ */
+ public static final int NUMBER_OF_CLASSES = 10;
+ /*
+ * Hard-coded value of the number of 'features' each record/class has in
+ * this classifier. This is how many data points there are per record/
+ * class.
+ */
+ public static final int NUMBER_OF_FEATURES = 784;
+ private static final String TEN_SPACES = " ";
+
+
+ private LearningUtil() { }
+
+ public static int parseTArgument(String arguments[]) {
+ int threshold = 0;
+ String tString = arguments[1];
+ Scanner scanner = new Scanner(tString).useDelimiter("-t=");
+ threshold = scanner.nextInt();
+ return threshold;
+ }
+
+ public static int parseKArgument(String arguments[]) {
+ int k = 0;
+ String kString = arguments[1];
+ Scanner scanner = new Scanner(kString).useDelimiter("-k=");
+ k = scanner.nextInt();
+ return k;
+ }
+
+ public static void outputBayesResultsToStdOut(int t, int[] correct,
+ int[] incorrect) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("T = " + t + "\n");
+ outputResults(sb, correct, incorrect);
+ }
+
+ public static void outputKNNResultsToStdOut(int k, int[] correct,
+ int[] incorrect) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("K = " + k + "\n");
+ outputResults(sb, correct, incorrect);
+ }
+
+ private static void outputResults(StringBuilder sb, int[] correct,
+ int[] incorrect) {
+ //Build the table column headers
+ sb.append("Class").append(TEN_SPACES).append("Correct").append(TEN_SPACES)
+ .append("Incorrect").append(TEN_SPACES).append("Accuracy");
+ sb.append("\n");
+ //Fill in the table
+ DecimalFormat accuracyFormatter = new DecimalFormat("#.#####");
+ for (int i = 0; i < NUMBER_OF_CLASSES; i++) {
+ double accuracy = (correct[i] * 1.0) / (correct[i] + incorrect[i]);
+ sb.append(i).append(" ").append(TEN_SPACES);//Class
+ sb.append(correct[i]).append(" ").append(TEN_SPACES);//# Correct
+ sb.append(incorrect[i]).append(" ").append(TEN_SPACES);//# Incorrect
+ sb.append(accuracyFormatter.format(accuracy)).append("\n");
+ }
+ System.out.println(sb.toString());
+ }
+
+ public static double getEuclideanDistance(int[] v1, int[] v2) {
+ double distance = 0; // set distance to 0
+ double runningTotal = 0;
+ double delta = 0;
+
+ for (int i = 0; i < v1.length; i++) {
+ delta = (v1[i] - v2[i]);
+ delta = Math.pow(delta, 2d);
+ runningTotal += delta;
+ }
+ distance = Math.sqrt(runningTotal);
+ return distance;
+ }
+
+ /**
+ *
+ * @param K
+ * @param knn
+ * @return
+ */
+ public static int predictKNN(int K, KNNRecord[] knn) {
+ //Prep the data
+ List<KNNRecord> recordList = Arrays.asList(knn);
+ Collections.sort(recordList);
+
+ //Do the voting
+ int[] voteArray = new int[10]; //index position represents the numerical class
+ for(int i = 0; i < K; i++) {
+ KNNRecord voter = recordList.get(i);
+ int number_class = voter.getExampleClass();
+ voteArray[number_class]++;
+ }
+
+ //Find out who got the most votes
+ int winner = -1; //winner is the index with the most votes in the array
+ int winnerVotes = 0; //the actual *count* of the most votes
+ for (int i = 0; i < voteArray.length; i++) {
+ if (voteArray[i] > winnerVotes) {
+ winnerVotes = voteArray[i];
+ winner = i;
+ }
+ }
+ return winner;
+ }
+
+ /**
+ *
+ * @return the usage (arguments, commands, etc) used on command line to
+ * operate the program.
+ */
+ public static String getUsage() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("Usage as follows:\n");
+ sb.append(LaunchClassifier.class.getName());
+ sb.append(" "); sb.append("[ classifier_type ]");
+ sb.append(" [ classifier_argument ]" + "\n");
+ sb.append("Where classifier_type is one of the following:\n\n");
+ for (ClassifierTypes classifierType : ClassifierTypes.values()) {
+ sb.append(classifierType.toString()); sb.append(" for ");
+ sb.append(classifierType.name());
+ sb.append("\n");
+ }
+ sb.append("\nClassifier Arguments:\n");
+ sb.append(ClassifierTypes.K_NEAREST_NEIGHBOR.name() + " -k[=value]\n");
+ sb.append(ClassifierTypes.NAIVE_BAYES.name() + " -t[=value]\n");
+ sb.append("\nNote: Each value is an integer, whether specifying k for" +
+ " neighbor count, or t for a Threshold.");
+ sb.append("\n\nExamples:\n");
+ sb.append(LaunchClassifier.class.getName() + " " +
+ ClassifierTypes.K_NEAREST_NEIGHBOR.toString() + " -k=10\n");
+ sb.append(LaunchClassifier.class.getName() + " " +
+ ClassifierTypes.NAIVE_BAYES.toString() + " -t=130\n");
+
+ return sb.toString();
+ }
+
+ /**
+ *
+ * @param threshold
+ * @param trainingData
+ * @param N_Y
+ * @return
+ */
+ public static double[][][] getBayesConditionalProbabilities(int threshold,
+ int[][] trainingData, double[] N_Y) {
+ //Convert training data to 'binary training data'
+ int[][] binaryTrainingData = new int[trainingData.length][trainingData[trainingData.length - 1].length];
+ for (int i = 0; i < trainingData.length; i++) {
+ for (int j = 0; j < trainingData[trainingData.length - 1].length; j++) {
+ if (trainingData[i][j] <= threshold) {
+ binaryTrainingData[i][j] = 0;
+ }
+ else {
+ binaryTrainingData[i][j] = 1;
+ }
+ }
+ }
+
+ double[][][] P_X_given_Y = new double[NUMBER_OF_CLASSES][NUMBER_OF_FEATURES][2];
+ double[][][] N_XY = new double[NUMBER_OF_CLASSES][NUMBER_OF_FEATURES][2];
+ for (int i = 0; i < binaryTrainingData.length; i++) {
+ //For each image in training data...
+ for (int j = 0; j < NUMBER_OF_FEATURES; j++) {
+ //Iterate over every image in training data, incrementing it's
+ //corresponding slot in the N_XY array
+ if (binaryTrainingData[i][j] == 0) {
+ N_XY[trainingData[i][trainingData[i].length -1]][j][0]++;
+ } else {
+ N_XY[trainingData[i][trainingData[i].length -1]][j][1]++;
+ }
+ }
+ }
+ for(int i = 0; i < NUMBER_OF_CLASSES; i++) {
+ for(int j = 0; j < NUMBER_OF_FEATURES; j++) {
+ P_X_given_Y[i][j][0] = (N_XY[i][j][0] + 1) / (N_Y[i] + 2);
+ P_X_given_Y[i][j][1] = (N_XY[i][j][1] + 1) / (N_Y[i] + 2);
+ }
+ }
+
+ return P_X_given_Y;
+ }
+
+ /**
+ *
+ * @param inputImage
+ * @param threshold
+ * @param P_Y
+ * @param P_X_given_Y
+ * @return
+ */
+ public static int predictBayes(int[] inputImage, double threshold,
+ double[] P_Y, double[][][] P_X_given_Y) {
+ int[] binaryInputImage = new int[inputImage.length];
+ for (int i = 0; i < binaryInputImage.length; i++) {
+ if (inputImage[i] > threshold) {
+ binaryInputImage[i] = 1;
+ } else {
+ binaryInputImage[i] = 0;
+ }
+ }
+
+
+ double[] results = new double[NUMBER_OF_CLASSES];
+ for (int i = 0; i < results.length; i++) {
+ results[i] = Math.log10(P_Y[i] * 100);
+ for (int j = 0; j < NUMBER_OF_FEATURES; j++) {
+ results[i] = results[i] + Math.log10(P_X_given_Y[i][j][binaryInputImage[j]] * 100);
+ }
+ }
+
+ //Find the max index
+ int max_index = 0;
+ double max_num = 0;
+ for (int i = 0; i < results.length; i++) {
+ if (results[i] > max_num) {
+ max_num = results[i];
+ max_index = i;
+ }
+ }
+ return max_index;
+ }
+
+
+ /**
+ * Given a list of classifiers, parses the input on args to determine which
+ * classifier is being tested/used. See
+ * @param args
+ * @param classifiers
+ */
+ public static ClassifierTypes parseClassifierSystem(String[] args) {
+ if (args.length < 2) {
+ System.out.println(LearningUtil.getUsage());
+ System.exit(0);
+ }
+ if (args[0].compareTo(ClassifierTypes.K_NEAREST_NEIGHBOR.toString()) == 0) {
+ return ClassifierTypes.K_NEAREST_NEIGHBOR;
+
+ } else if(args[0].compareTo(ClassifierTypes.NAIVE_BAYES.toString()) == 0) {
+ return ClassifierTypes.NAIVE_BAYES;
+ } else {
+ System.out.println(LearningUtil.getUsage());
+ System.exit(0);
+ }
+ return null;
+ }
+}
View
199 test_data/200_lines/test0.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test1.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test2.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test3.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test4.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test5.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test6.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test7.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test8.txt
199 additions, 0 deletions not shown
View
199 test_data/200_lines/test9.txt
199 additions, 0 deletions not shown
View
228 test_data/test0.txt
39 additions, 189 deletions not shown
View
236 test_data/test1.txt
47 additions, 189 deletions not shown
View
231 test_data/test2.txt
42 additions, 189 deletions not shown
View
230 test_data/test3.txt
41 additions, 189 deletions not shown
View
229 test_data/test4.txt
40 additions, 189 deletions not shown
View
224 test_data/test5.txt
35 additions, 189 deletions not shown
View
227 test_data/test6.txt
38 additions, 189 deletions not shown
View
231 test_data/test7.txt
42 additions, 189 deletions not shown
View
228 test_data/test8.txt
39 additions, 189 deletions not shown
View
230 test_data/test9.txt
41 additions, 189 deletions not shown
View
199 training_data/200_lines/train0.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train1.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train2.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train3.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train4.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train5.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train6.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train7.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train8.txt
199 additions, 0 deletions not shown
View
199 training_data/200_lines/train9.txt
199 additions, 0 deletions not shown
View
476 training_data/train0.txt
287 additions, 189 deletions not shown
View
517 training_data/train1.txt
328 additions, 189 deletions not shown
View
477 training_data/train2.txt
288 additions, 189 deletions not shown
View
486 training_data/train3.txt
297 additions, 189 deletions not shown
View
472 training_data/train4.txt
283 additions, 189 deletions not shown
View
451 training_data/train5.txt
262 additions, 189 deletions not shown
View
475 training_data/train6.txt
286 additions, 189 deletions not shown
View
493 training_data/train7.txt
304 additions, 189 deletions not shown
View
472 training_data/train8.txt
283 additions, 189 deletions not shown
View
477 training_data/train9.txt
288 additions, 189 deletions not shown

0 comments on commit a943ee7

Please sign in to comment.
Something went wrong with that request. Please try again.