From ef7cce5a017005d4829ddb9e9802a134e52519df Mon Sep 17 00:00:00 2001 From: dmcnelis Date: Sun, 4 Mar 2012 16:50:03 +0000 Subject: [PATCH] Major changes to project structure and usage. --- README.mkd | 17 +- .../me/mcnelis/rudder/data/FeatureType.java | 2 +- .../java/me/mcnelis/rudder/data/Label.java | 5 +- .../me/mcnelis/rudder/data/MockRecord.java | 22 - .../mcnelis/rudder/data/NumericFeature.java | 2 - .../java/me/mcnelis/rudder/data/Record.java | 395 ------------------ .../mcnelis/rudder/data/RecordInterface.java | 133 ------ .../rudder/data/collections/RecordList.java | 126 ------ .../NaiveBayesClassification.java | 8 +- .../regression/MultiLinearRegression.java | 41 +- .../ml/unsupervised/clustering/Cluster.java | 39 +- .../ml/unsupervised/clustering/DBScan.java | 41 +- .../unsupervised/clustering/DensityBased.java | 32 +- .../ml/unsupervised/clustering/KMeans.java | 77 ++-- .../me/mcnelis/rudder/data/MockRecord.java | 72 ++++ .../mcnelis/rudder/data/MockRecordTest.java | 72 ---- .../me/mcnelis/rudder/data/RecordTest.java | 151 ------- .../NaiveBayesClassificationTest.java | 73 ++-- .../regression/MultiLinearRegressionTest.java | 42 +- .../unsupervised/clustering/DBScanTest.java | 15 +- .../unsupervised/clustering/KMeansTest.java | 20 +- 21 files changed, 322 insertions(+), 1063 deletions(-) delete mode 100644 src/main/java/me/mcnelis/rudder/data/MockRecord.java delete mode 100644 src/main/java/me/mcnelis/rudder/data/Record.java delete mode 100644 src/main/java/me/mcnelis/rudder/data/RecordInterface.java delete mode 100644 src/main/java/me/mcnelis/rudder/data/collections/RecordList.java create mode 100644 src/test/java/me/mcnelis/rudder/data/MockRecord.java delete mode 100644 src/test/java/me/mcnelis/rudder/data/MockRecordTest.java delete mode 100644 src/test/java/me/mcnelis/rudder/data/RecordTest.java diff --git a/README.mkd b/README.mkd index 7e64f3f..bde5d23 100644 --- a/README.mkd +++ b/README.mkd @@ -15,25 +15,34 @@ the issue. Example of running clustering - class MyRecordClass extends Record { + class MyRecordClass { @Feature double myFeature1 @Feature double myFeature2 @Feature double myFeature3 - @Label + @Label(setlabel="setLabel") String myLabel; Object nonFeatureItem; String nonFeatureOrLabelString; + + void setLabel(String label); + + //Only use non-primitives in the set method signature + void setLabel(Double doubleLabel); + //Other stuff I want to do + + //This API uses equals often, so your object should override it + boolean equals(Object o); } class MyClusterRunner { public static void main(String[] args) { double epsilon = 4d; //Maximum distance between records in a cluster int minClusterSize = 3; //Minimum number of records in a cluster - RecordList myRecords = new RecordList(); + IRudderList myRecords = new RudderList(); //populate your object DBScan dbscan = new DBScan(epsilon, minClusterSize); db.setSourceData(list); @@ -41,7 +50,7 @@ Example of running clustering //do stuff with your clusters i.e. print record label and cluster for(int i=0; i flags; - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#equals(java.lang.Object) - */ - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - RecordInterface other = (RecordInterface) obj; - double[] orig = this.getFeatureAndLabelDoubleArray(); - double[] test = other.getFeatureAndLabelDoubleArray(); - return MathUtils.equals(orig, test); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#getAllFeatures(java.lang.Object[]) - */ - public Object[] getAllFeatures() { - ArrayList arr = new ArrayList(); - - Field[] fields = this.getClass().getDeclaredFields(); - for(Field f : fields) { - Annotation annotation = f.getAnnotation(NumericFeature.class); - if(annotation == null) { - annotation = f.getAnnotation(TextFeature.class); - } - if (annotation != null) { - try { - f.setAccessible(true); - arr.add(f.get(this)); - } catch (IllegalArgumentException e) { - - e.printStackTrace(); - } catch (IllegalAccessException e) { - - e.printStackTrace(); - } - } - } - - return arr.toArray(); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#getDoubleLabel() - */ - public synchronized double getDoubleLabel() { - - Field[] fields = this.getClass().getDeclaredFields(); - for(Field f : fields) { - Annotation annotation = f.getAnnotation(Label.class); - - if (annotation != null) { - try { - f.setAccessible(true); - return (f.getDouble(this)); - } catch (IllegalArgumentException e) { - //Not a double, so just move forward - } catch (IllegalAccessException e) { - - e.printStackTrace(); - } - } - } - return Double.NaN; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#getFeatureAndLabelArray() - */ - public synchronized double[] getFeatureAndLabelDoubleArray() { - ResizableDoubleArray arr = new ResizableDoubleArray(); - - Field[] fields = this.getClass().getDeclaredFields(); - for(Field f : fields) { - Annotation annotation = f.getAnnotation(NumericFeature.class); - if(annotation == null) - annotation = f.getAnnotation(Label.class); - if (annotation != null) { - try { - f.setAccessible(true); - arr.addElement(f.getDouble(this)); - } catch (IllegalArgumentException e) { - //Field or label isn't a double, so ignore it - } catch (IllegalAccessException e) { - - e.printStackTrace(); - } - } - } - - return arr.getElements(); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#getFeatureArray() - */ - public synchronized double[] getFeatureDoubleArray() { - ResizableDoubleArray arr = new ResizableDoubleArray(); - - Field[] fields = this.getClass().getDeclaredFields(); - for(Field f : fields) { - Annotation annotation = f.getAnnotation(NumericFeature.class); - if (annotation != null) { - try { - f.setAccessible(true); - double dval = f.getDouble(this); - if(!Double.isNaN(dval)) { - - arr.addElement(f.getDouble(this)); - } - } catch (IllegalArgumentException e) { - - e.printStackTrace(); - } catch (IllegalAccessException e) { - - e.printStackTrace(); - } - } - } - - return arr.getElements(); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#getLabel() - */ - public synchronized String getLabel() { - String stringLabel = ""; - - Field[] fields = this.getClass().getDeclaredFields(); - for(Field f : fields) { - Annotation annotation = f.getAnnotation(Label.class); - - if (annotation != null) { - try { - f.setAccessible(true); - stringLabel += f.get(this); - } catch (IllegalArgumentException e) { - //Not a double, so just move forward - } catch (IllegalAccessException e) { - - e.printStackTrace(); - } - } - } - return stringLabel.toLowerCase(); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#hashCode() - */ - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((flags == null) ? 0 : flags.hashCode()); - return result; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#isAssigned() - */ - public synchronized boolean isAssigned() { - if(this.flags.contains(RecordFlags.ASSIGNED)) - return true; - else - return false; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#isNoise() - */ - public synchronized boolean isNoise() { - if(this.flags == null) - return false; - else { - if(this.flags.contains(RecordFlags.NOISE)) - return true; - else - return false; - } - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#isVisited() - */ - public synchronized boolean isVisited() { - if(this.flags == null) - return false; - else { - if(this.flags.contains(RecordFlags.VISITED)) - return true; - else - return false; - } - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#setAssigned(boolean) - */ - public synchronized void setAssigned(boolean visited) { - if(this.flags == null) - this.flags = new HashSet(); - if(visited) - this.flags.add(RecordFlags.ASSIGNED); - else if (!visited) - this.flags.remove(RecordFlags.ASSIGNED); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#setFeature(java.lang.String, double) - */ - public synchronized boolean setFeature(String featureName, double d) - throws FeatureNotFoundException { - - Field f = null; - try { - f = this.getClass().getDeclaredField(featureName); - } catch (SecurityException e) { - throw new FeatureNotFoundException("Security exception"); - } catch (NoSuchFieldException e) { - throw new FeatureNotFoundException(featureName + " is not a class member"); - } - Annotation annotation = f.getAnnotation(NumericFeature.class); - if(annotation == null) - throw new FeatureNotFoundException("There is no annotation here"); - - if(annotation instanceof NumericFeature) { - try { - f.setAccessible(true); - f.setDouble(this, d); - } catch (IllegalArgumentException e) { - throw new FeatureNotFoundException("Illegal argument"); - } catch (IllegalAccessException e) { - // TODO Auto-generated catch block - throw new FeatureNotFoundException("Illegal access"); - } - return true; - } - - - throw new FeatureNotFoundException("Field is not a feature"); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#setFeature() - */ - public boolean setFeature(String featureName, String featureStr) - throws FeatureNotFoundException { - Field f = null; - try { - - f = this.getClass().getDeclaredField(featureName); - } catch (SecurityException e) { - throw new FeatureNotFoundException("Security exception"); - } catch (NoSuchFieldException e) { - throw new FeatureNotFoundException(featureName + " is not a class member"); - } - Annotation annotation = f.getAnnotation(TextFeature.class); - if(annotation == null) - throw new FeatureNotFoundException("There is no annotation here"); - - if(annotation instanceof TextFeature) { - try { - f.setAccessible(true); - f.set(this, featureStr.toLowerCase()); - } catch (IllegalArgumentException e) { - e.printStackTrace(); - throw new FeatureNotFoundException("Illegal argument"); - } catch (IllegalAccessException e) { - // TODO Auto-generated catch block - throw new FeatureNotFoundException("Illegal access"); - } - return true; - } - - - throw new FeatureNotFoundException("Field is not a feature"); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#setLabel(java.lang.String, java.lang.Object) - */ - public synchronized boolean setLabel(String labelName, Object o) throws Exception { - - Field f = null; - try { - f = this.getClass().getDeclaredField(labelName); - } catch (SecurityException e) { - throw e; - } catch (NoSuchFieldException e) { - throw e; - } - Annotation annotation = f.getAnnotation(Label.class); - if(annotation == null) - throw new FeatureNotFoundException("There is no annotation here"); - - if(annotation instanceof Label) { - - f.setAccessible(true); - f.set(this, o); - - - return true; - } - return false; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#setNoise(boolean) - */ - public synchronized void setNoise(boolean noise) { - if(this.flags == null) - this.flags = new HashSet(); - if(noise) - this.flags.add(RecordFlags.NOISE); - else if (!noise) - this.flags.remove(RecordFlags.NOISE); - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#setVisited(boolean) - */ - public synchronized void setVisited(boolean visited) { - if(this.flags == null) - this.flags = new HashSet(); - if(visited) - this.flags.add(RecordFlags.VISITED); - else if (!visited) - this.flags.remove(RecordFlags.VISITED); - } - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.RecordInterface#toString() - */ - @Override - public String toString() { - - Field[] fields = this.getClass().getDeclaredFields(); - StringBuffer sb = new StringBuffer(); - sb.append(this.getClass().getName()); - sb.append(":\n"); - for(Field f : fields) { - Annotation[] as = f.getAnnotations(); - for (Annotation a : as) { - sb.append(a.toString()); - sb.append(" "); - } - sb.append(f.getName()); - sb.append(": "); - try { - f.setAccessible(true); - sb.append(f.get(this)); - } catch (IllegalArgumentException e) { - e.printStackTrace(); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - sb.append("\n"); - } - sb.append("]"); - return sb.toString(); - } -} diff --git a/src/main/java/me/mcnelis/rudder/data/RecordInterface.java b/src/main/java/me/mcnelis/rudder/data/RecordInterface.java deleted file mode 100644 index c0147a4..0000000 --- a/src/main/java/me/mcnelis/rudder/data/RecordInterface.java +++ /dev/null @@ -1,133 +0,0 @@ -package me.mcnelis.rudder.data; - -import me.mcnelis.rudder.exceptions.FeatureNotFoundException; - -public interface RecordInterface { - - /** - * Set a feature variable for a class, currently all - * features must be double - * @param featureName - * @param d - * @return - * @throws FeatureNotFoundException - */ - public abstract boolean setFeature(String featureName, double d) - throws FeatureNotFoundException; - - /** - * Set a feature variable for a class, currently all - * features must be double - * @param featureName - * @param d - * @return - * @throws FeatureNotFoundException - */ - public abstract boolean setFeature(String featureName, String d) - throws FeatureNotFoundException; - - /** - * Set the label(s) for the record - * @param labelName - * @param Object to set label as - * @return success of lsetting the label - * @throws Exception, usually if you pass an object that can't be cast - */ - public abstract boolean setLabel(String labelName, Object o) - throws Exception; - - /** - * Make it easy to see your records when you pass it out to a string - * @return - */ - public abstract String toString(); - - /** - * - * @return double array of features for processing - */ - public abstract double[] getFeatureDoubleArray(); - - /** - * Return all of the features in their original object form - */ - public abstract Object[] getAllFeatures(); - - /** - * In unsupervised learning the order of your features is irrelevant - * so it doesn't matter what you're going through, as long as the label - * is a number. - * - * If your label is a string, you should handle labeling your data in a little - * different manner (i.e. give your label a double value type and have your - * label be another class member. When you set that class member, it updates - * the label). This is irrelevant if you're not planning on doing any - * unsupervised learning on this dataset - * - * @return array of all your feature and labels for unsupervised learning - */ - public abstract double[] getFeatureAndLabelDoubleArray(); - - /** - * Get a double representation of the label. - * @return - */ - public abstract double getDoubleLabel(); - - /** - * Some algorithms need to be able to test if - * a record is to be considered noise - * @return - */ - public abstract boolean isNoise(); - - /** - * Set noise flag - * @param noise - */ - public abstract void setNoise(boolean noise); - - /** - * Some algorithms need to know if the record - * has been visited / processed yet - * @return - */ - public abstract boolean isVisited(); - - /** - * Set the visited flag - * @param visited - */ - public abstract void setVisited(boolean visited); - - /** - * Set the 'assigned' flag -- mainly for - * clustering and classification purposes - * @param visited - */ - public abstract void setAssigned(boolean visited); - - /** - * Some algorithms need to know if the record has - * already been assigned - * @return - */ - public abstract boolean isAssigned(); - - /** - * Should be a part of your class - * @return - */ - public abstract int hashCode(); - - /** - * Test if one record is a duplicate of another - * record - * @param obj - * @return - */ - public abstract boolean equals(Object obj); - - public abstract String getLabel(); - -} \ No newline at end of file diff --git a/src/main/java/me/mcnelis/rudder/data/collections/RecordList.java b/src/main/java/me/mcnelis/rudder/data/collections/RecordList.java deleted file mode 100644 index e43562c..0000000 --- a/src/main/java/me/mcnelis/rudder/data/collections/RecordList.java +++ /dev/null @@ -1,126 +0,0 @@ -package me.mcnelis.rudder.data.collections; - -import java.util.ArrayList; -import java.util.List; - -import me.mcnelis.rudder.data.RecordInterface; - -/** - * - * @author dmcnelis - * - * @param - */ -@Deprecated -public class RecordList extends ArrayList implements IRudderList -{ - - protected int clusterIdx=-1; - /** - * - */ - private static final long serialVersionUID = 1L; - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.collections.RudderListInterface#getUnsupervisedDoubleDoubleArray() - */ - public double[][] getUnsupervisedDoubleDoubleArray() { - double[][] d = new double[this.size()][]; - int cnt=0; - synchronized (this) { - for (RecordInterface r : this) { - d[cnt] = r.getFeatureAndLabelDoubleArray(); - cnt++; - } - } - return d; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.collections.RudderListInterface#getSupervisedFeatures() - */ - public double[][] getSupervisedFeatures() { - - double[][] d = new double[this.size()][]; - int cnt=0; - synchronized (this) { - for (RecordInterface r : this) { - d[cnt] = r.getFeatureDoubleArray(); - cnt++; - } - } - return d; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.collections.RudderListInterface#getSupervisedLabels() - */ - public double[] getSupervisedLabels() { - double[] d = new double[this.size()]; - int cnt=0; - synchronized (this) { - for (RecordInterface r : this) { - d[cnt] = r.getDoubleLabel(); - cnt++; - } - } - - return d; - } - - public int getClusterId() { - return clusterIdx; - } - public void setClusterId(int id) { - this.clusterIdx = id; - } - - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.collections.RudderListInterface#getUnsupervisedSampleDoubleArray() - */ - public double[] getUnsupervisedSampleDoubleArray() { - return this.get(0).getFeatureAndLabelDoubleArray(); - } - /* (non-Javadoc) - * @see me.mcnelis.rudder.data.collections.RudderListInterface#getSupervisedSampleDoubleArray() - */ - public double[] getSupervisedSampleDoubleArray() { - return this.get(0).getFeatureDoubleArray(); - } - - public String getStringLabel(Object r) - { - try - { - RecordInterface rI = (RecordInterface) r; - return rI.getLabel(); - } - catch (ClassCastException cce) - { - cce.printStackTrace(); - } - - return ""; - } - - public List getRecordFeatures(Object r) - { - try - { - RecordInterface rI = (RecordInterface) r; - List list = new ArrayList(); - for(Object o : rI.getAllFeatures()) - { - list.add(o); - } - - return list; - } - catch (ClassCastException cce) - { - cce.printStackTrace(); - } - - return null; - } -} diff --git a/src/main/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassification.java b/src/main/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassification.java index b252f13..a3c0cf2 100644 --- a/src/main/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassification.java +++ b/src/main/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassification.java @@ -6,13 +6,13 @@ import java.util.Map; import java.util.Map.Entry; -import me.mcnelis.rudder.data.RecordInterface; +import org.apache.log4j.Logger; + import me.mcnelis.rudder.data.collections.IRudderList; -import me.mcnelis.rudder.data.collections.RecordList; public class NaiveBayesClassification { - + private static final Logger LOG = Logger.getLogger(NaiveBayesClassification.class); protected Map> classList = new HashMap>(); protected IRudderList records; @@ -97,7 +97,7 @@ public Map getClassScores(Object r) scores += score; idx++; } - + LOG.debug(label + ": " + scores); labelScores.put(label, scores); } diff --git a/src/main/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegression.java b/src/main/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegression.java index 413859b..c2b9e2d 100644 --- a/src/main/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegression.java +++ b/src/main/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegression.java @@ -1,13 +1,14 @@ package me.mcnelis.rudder.ml.supervised.regression; -import me.mcnelis.rudder.data.RecordInterface; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.collections.IRudderList; +import me.mcnelis.rudder.data.collections.RudderList; import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.stat.descriptive.SynchronizedMultivariateSummaryStatistics; import org.apache.commons.math.stat.regression.GLSMultipleLinearRegression; import org.apache.commons.math.stat.regression.MultipleLinearRegression; import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression; +import org.apache.log4j.Logger; /** * Wrapper for @link @@ -17,10 +18,10 @@ * @author dmcnelis * */ -public class MultiLinearRegression +public class MultiLinearRegression { - - protected RecordList records; + private static final Logger LOG = Logger.getLogger(MultiLinearRegression.class); + protected IRudderList records; protected double[] betas; protected SynchronizedMultivariateSummaryStatistics stats; protected MultipleLinearRegression ols; @@ -32,13 +33,13 @@ public MultiLinearRegression() } @SuppressWarnings("unchecked") - public MultiLinearRegression(RecordList records) + public MultiLinearRegression(IRudderList records) { try { synchronized (this) { - this.records = (RecordList) records; + this.records = (IRudderList) records; this.stats = new SynchronizedMultivariateSummaryStatistics( this.records.getSupervisedSampleDoubleArray().length, false); @@ -51,7 +52,7 @@ public MultiLinearRegression(RecordList records) } @SuppressWarnings("unchecked") - public MultiLinearRegression(RecordList records, + public MultiLinearRegression(IRudderList records, RegressionTypes type) { @@ -60,7 +61,7 @@ public MultiLinearRegression(RecordList records, synchronized (this) { this.type = type; - this.records = (RecordList) records; + this.records = (IRudderList) records; this.stats = new SynchronizedMultivariateSummaryStatistics( this.records.getSupervisedSampleDoubleArray().length, false); @@ -78,25 +79,35 @@ public MultiLinearRegression(RecordList records, * @param record * @return success on adding record, negative if unable to add */ - public synchronized boolean addRecord(RecordInterface record) + @SuppressWarnings("unchecked") + public synchronized boolean addRecord(Object record) { if (this.records == null) { - this.records = new RecordList(); - this.stats = new SynchronizedMultivariateSummaryStatistics( - record.getFeatureDoubleArray().length, false); + this.records = new RudderList(); + } try { - this.stats.addValue(record.getFeatureDoubleArray()); - return this.records.add(record); + this.records.add((T) record); + + if(this.stats == null) + { + this.stats = new SynchronizedMultivariateSummaryStatistics( + this.records.getSupervisedSampleDoubleArray().length, false); + } + this.stats.addValue(this.records.getNumericFeatureArray(record)); + + return true; } catch (DimensionMismatchException e) { + LOG.error(e); return false; } catch (@SuppressWarnings("deprecation") org.apache.commons.math.DimensionMismatchException e) { + LOG.error(e); // Deprecated, will remove when we move to Commons Math 3.0 return false; } diff --git a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/Cluster.java b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/Cluster.java index 5a9ac91..5e96c62 100644 --- a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/Cluster.java +++ b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/Cluster.java @@ -3,8 +3,7 @@ import java.util.ArrayList; import java.util.Arrays; -import me.mcnelis.rudder.data.RecordInterface; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.collections.RudderList; import org.apache.commons.math.stat.descriptive.SynchronizedSummaryStatistics; @@ -15,7 +14,7 @@ * @author dmcnelis * */ -public class Cluster extends RecordList +public class Cluster extends RudderList { /** * @@ -27,25 +26,26 @@ public Cluster() { } - public void addRecord(RecordInterface r) - { + public void addRecord(T r) + { this.add(r); } - public void combineClusters(Cluster c) + @SuppressWarnings("unchecked") + public void combineClusters(Cluster cluster) { - for (Object o : c.getRecords()) + for (Object o : cluster.getRecords()) { - RecordInterface r = (RecordInterface) o; - if (!this.contains(r)) + + if (!this.contains(o)) { - this.add(r); + this.add((T) o); } } } - public RecordList getRecords() + public Cluster getRecords() { return this; } @@ -77,12 +77,12 @@ public void setCentroid(double[] d) protected synchronized double[] calculateCentroid() { - this.centroid = new double[this.get(0).getFeatureAndLabelDoubleArray().length]; + this.centroid = new double[this.getUnsupervisedDoubleArray(this.get(0)).length]; ArrayList stats = new ArrayList(); - for (RecordInterface elem : this) + for (Object elem : this) { - double[] arr = elem.getFeatureAndLabelDoubleArray(); + double[] arr = this.getUnsupervisedDoubleArray(elem); for (int i = 0; i < arr.length; i++) { @@ -141,11 +141,12 @@ public boolean equals(Object obj) { return false; } - if (!(obj instanceof Cluster)) + if (!(obj instanceof Cluster)) { return false; } - Cluster other = (Cluster) obj; + @SuppressWarnings("unchecked") + Cluster other = (Cluster) obj; if (!Arrays.equals(centroid, other.centroid)) { return false; @@ -153,4 +154,10 @@ public boolean equals(Object obj) return true; } + + public boolean isAssigned(Object o) + { + // TODO Auto-generated method stub + return false; + } } diff --git a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScan.java b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScan.java index 3d1001d..137b5df 100644 --- a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScan.java +++ b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScan.java @@ -3,40 +3,41 @@ import java.util.ArrayList; import java.util.List; -import me.mcnelis.rudder.data.RecordInterface; +import org.apache.log4j.Logger; public class DBScan extends DensityBased { - + private static final Logger LOG = Logger.getLogger(DBScan.class); + protected DBScan(double epsilon, int minPts) { super(epsilon, minPts); } @Override - protected List cluster() + protected List> cluster() { if (this.clusters == null) { - this.clusters = new ArrayList(); + this.clusters = new ArrayList>(); } - for (RecordInterface r : this.sourceData) + for (Object r : this.sourceData) { - if (!r.isVisited()) + if (!this.sourceData.isVisited(r)) { - r.setVisited(true); - Cluster c = this.rangeQuery(r); + this.sourceData.setVisited(r,true); + Cluster c = this.rangeQuery(r); if (c.getRecords().size() < this.minPts) { - r.setNoise(true); + this.sourceData.setNoise(r,true); } else { - Cluster addCluster = this.expandCluster(r, c); + Cluster addCluster = this.expandCluster(r, c); this.clusters.add(addCluster); } @@ -45,31 +46,31 @@ protected List cluster() return this.clusters; } - protected Cluster expandCluster(RecordInterface r, Cluster c) + protected Cluster expandCluster(Object r, Cluster c) { - - Cluster newCluster = new Cluster(); + LOG.trace("Expanding cluster"); + Cluster newCluster = new Cluster(); newCluster.addRecord(r); for (Object o : c.getRecords()) { - RecordInterface rPrime = (RecordInterface) o; - if (!rPrime.isVisited()) + + if (!c.isVisited(o)) { - rPrime.setVisited(true); - Cluster cluster = this.rangeQuery(rPrime); + c.setVisited(o,true); + Cluster cluster = this.rangeQuery(o); if (cluster.getRecords().size() >= this.minPts) { - cluster.combineClusters(this.expandCluster(rPrime, cluster)); + cluster.combineClusters(this.expandCluster(o, cluster)); } } - if (!rPrime.isAssigned()) + if (!c.isAssigned(o)) { - newCluster.addRecord(rPrime); + newCluster.addRecord(o); } } diff --git a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DensityBased.java b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DensityBased.java index f43a5b9..fe8a8d5 100644 --- a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DensityBased.java +++ b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/DensityBased.java @@ -2,8 +2,7 @@ import java.util.List; -import me.mcnelis.rudder.data.RecordInterface; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.collections.IRudderList; import org.apache.commons.math.stat.descriptive.SynchronizedSummaryStatistics; import org.apache.commons.math.util.MathUtils; @@ -15,13 +14,12 @@ public abstract class DensityBased protected int minPts; protected int minClusters; protected SynchronizedSummaryStatistics distance = new SynchronizedSummaryStatistics(); - protected List clusters; - protected RecordList sourceData; + protected List> clusters; + protected IRudderList sourceData; - @SuppressWarnings("unchecked") - public void setSourceData(RecordList rl) + public void setSourceData(IRudderList rl) { - this.sourceData = (RecordList) rl; + this.sourceData = (IRudderList) rl; } /** @@ -36,7 +34,7 @@ protected DensityBased(double epsilon, int minPts) this.minPts = minPts; } - public List getClusters() + public List> getClusters() { if (this.clusters == null) { @@ -45,7 +43,7 @@ public List getClusters() return this.clusters; } - protected abstract List cluster(); + protected abstract List> cluster(); /** * Find the neighbors of r within range (this.epsilon) @@ -55,24 +53,24 @@ public List getClusters() * @param Record * @return Cluster of nearest neighbors to Record */ - protected Cluster rangeQuery(RecordInterface r) + protected Cluster rangeQuery(Object r) { - Cluster c = new Cluster(); + Cluster c = new Cluster(); c.addRecord(r); - - for (RecordInterface r2 : this.sourceData) + + for (Object r2 : this.sourceData) { if (!r.equals(r2)) { double mDistance = MathUtils.distance( - r2.getFeatureAndLabelDoubleArray(), - r.getFeatureAndLabelDoubleArray()); + this.sourceData.getUnsupervisedDoubleArray(r2), + this.sourceData.getUnsupervisedDoubleArray(r)); this.distance.addValue(mDistance); if (mDistance < this.epsilon) { - r2.setNoise(false); + this.sourceData.setNoise(r2, false); c.addRecord(r2); } @@ -98,7 +96,7 @@ public SynchronizedSummaryStatistics getDistanceStats() return this.distance; } - public RecordList getSourceData() + public IRudderList getSourceData() { return this.sourceData; } diff --git a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeans.java b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeans.java index 522a860..31e1e42 100644 --- a/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeans.java +++ b/src/main/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeans.java @@ -1,31 +1,35 @@ package me.mcnelis.rudder.ml.unsupervised.clustering; +import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Random; -import me.mcnelis.rudder.data.RecordInterface; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.collections.IRudderList; +import me.mcnelis.rudder.data.collections.RudderList; import org.apache.commons.math.stat.descriptive.SynchronizedSummaryStatistics; import org.apache.commons.math.util.MathUtils; +import org.apache.log4j.Logger; -public class KMeans +public class KMeans { + private static final Logger LOG = Logger.getLogger(KMeans.class); + protected int k; - protected List clusters; + protected List> clusters; protected double[][] previousCenters; protected long maxIterations; protected long currentIteration = 0l; protected double minMovement = .0000001d; - protected RecordList sourceData; + protected IRudderList sourceData; - @SuppressWarnings("unchecked") - public KMeans(int k, RecordList data) + public KMeans(int k, IRudderList data) { this.k = k; - this.sourceData = (RecordList) data; + this.sourceData = data; this.init(); } @@ -45,7 +49,7 @@ private void init() this.maxIterations = 1000000000000l; this.minMovement = .0000001d; } - public List cluster() + public List> cluster() { // Get the random centers @@ -57,9 +61,13 @@ public List cluster() this.assignClusters(); // This could be parallelized - for (Cluster c : this.clusters) + for (Cluster c : this.clusters) { - c.calculateCentroid(); + LOG.debug("Cluster size: " + c.size()); + if(c.size() > 0) + { + c.calculateCentroid(); + } } } @@ -105,15 +113,23 @@ protected void assignPreviousClusters() protected void assignClusters() { - for (Object o : this.sourceData) + LOG.debug("Source data size: " + this.sourceData.size()); + List> tempClusters = new ArrayList>(); + + for(int i = 0; i < this.clusters.size(); i++) { - RecordInterface r = (RecordInterface) o; + tempClusters.add(new Cluster()); + } + for (T o : this.sourceData) + { + double min = Double.NaN; int clusterIdx = -1; for (int i = 0; i < this.clusters.size(); i++) { double distance = MathUtils.distance(this.clusters.get(i) - .getCentroid(), r.getFeatureAndLabelDoubleArray()); + .getCentroid(), this.sourceData.getUnsupervisedDoubleArray(o)); + LOG.trace("Record: " + Arrays.toString(this.sourceData.getUnsupervisedDoubleArray(o))); if (Double.isNaN(min) || distance < min) { clusterIdx = i; @@ -121,9 +137,9 @@ protected void assignClusters() } } - this.clusters.get(clusterIdx).addRecord(r); + tempClusters.get(clusterIdx).addRecord(o); } - + this.clusters = tempClusters; } /** @@ -134,19 +150,26 @@ protected void assignClusters() * * @return clusters with a center set */ - protected List chooseRandomCenters() + protected List> chooseRandomCenters() { - ArrayList randomCluster = new ArrayList(); + ArrayList> randomCluster = new ArrayList>(); Random generator = new Random(); for (int i = 0; i < this.k; i++) { - Cluster c = new Cluster(); - - c.setCentroid(((RecordInterface) this.sourceData.get(generator - .nextInt(this.sourceData.size()))) - .getFeatureAndLabelDoubleArray()); + Cluster c = new Cluster(); + + c.setCentroid( + this.sourceData.getUnsupervisedDoubleArray( + this.sourceData.get( + generator.nextInt( + this.sourceData.size() + ) + ) + ) + ); + randomCluster.add(c); } this.clusters = randomCluster; @@ -163,12 +186,12 @@ public void setK(int k) this.k = k; } - public List getClusters() + public List> getClusters() { return clusters; } - public void setClusters(List clusters) + public void setClusters(List> clusters) { this.clusters = clusters; } @@ -183,12 +206,12 @@ public void setMinMovement(double minMovement) this.minMovement = minMovement; } - public RecordList getSourceData() + public IRudderList getSourceData() { return sourceData; } - public void setSourceData(RecordList sourceData) + public void setSourceData(IRudderList sourceData) { this.sourceData = sourceData; } diff --git a/src/test/java/me/mcnelis/rudder/data/MockRecord.java b/src/test/java/me/mcnelis/rudder/data/MockRecord.java new file mode 100644 index 0000000..af0aadd --- /dev/null +++ b/src/test/java/me/mcnelis/rudder/data/MockRecord.java @@ -0,0 +1,72 @@ +package me.mcnelis.rudder.data; +import java.lang.reflect.Field; + +import me.mcnelis.rudder.data.Label; + +public class MockRecord{ + + @NumericFeature + public double feature1; + + @NumericFeature + public double feature2; + + @NumericFeature + public double feature3; + + + protected double nonFeature; + + @Label(setlabel="setDoubleLabel") + protected double doubleLabel; + + @Label(setlabel="setStringLabel") + protected String stringLabel; + + public void setLabel(double d) + { + this.doubleLabel = d; + } + + public void setLabel(String s) + { + this.stringLabel = s; + } + + public void setDoubleLabel(double d) + { + this.doubleLabel = d; + } + + public void setStringLabel(String s) + { + this.stringLabel = s; + } + + public void setFeature(String field, double val) + { + Field[] fields = this.getClass().getFields(); + for(Field f : fields) + { + if(f.getName().equalsIgnoreCase(field)) + { + f.setAccessible(true); + try + { + f.set(this, val); + } + catch (IllegalArgumentException e) + { + // TODO Auto-generated catch block + e.printStackTrace(); + } + catch (IllegalAccessException e) + { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + } + +} diff --git a/src/test/java/me/mcnelis/rudder/data/MockRecordTest.java b/src/test/java/me/mcnelis/rudder/data/MockRecordTest.java deleted file mode 100644 index 32fbb36..0000000 --- a/src/test/java/me/mcnelis/rudder/data/MockRecordTest.java +++ /dev/null @@ -1,72 +0,0 @@ -package me.mcnelis.rudder.data; - -import static org.junit.Assert.*; - -import me.mcnelis.rudder.exceptions.FeatureNotFoundException; - -import org.junit.Test; - -public class MockRecordTest { - - @Test - public void testSetFeature() { - MockRecord r = new MockRecord(); - try { - assertTrue(r.setFeature("feature1", 4d)); - } catch (FeatureNotFoundException e) { - e.printStackTrace(); - fail("Feature wasn't found!"); - } - } - - @Test - public void testSetFeatureFailureNonFeature() { - MockRecord r = new MockRecord(); - try { - r.setFeature("nonFeature", 4d); - } catch (FeatureNotFoundException e) { - assertTrue(true); - } - } - - @Test - public void testSetFeatureNoField() { - MockRecord r = new MockRecord(); - try { - r.setFeature("bogusFeature", 4d); - } catch (FeatureNotFoundException e) { - assertTrue(true); - } - } - - @Test - public void testSetLabel() { - MockRecord r = new MockRecord(); - - try { - assertTrue(r.setLabel("doubleLabel", 4d)); - } catch (Exception e) { - fail("Should not have exception"); - } - - } - - @Test - public void testGetFeatureArray() { - MockRecord r = new MockRecord(); - try { - r.setFeature("feature1", 1d); - r.setFeature("feature2", 2d); - r.setFeature("feature3", 3d); - } catch (FeatureNotFoundException e) { - fail("Should have executed cleanly"); - } - - double[] arr = r.getFeatureDoubleArray(); - assertEquals(1d, arr[0], .0001); - assertEquals(2d, arr[1], .0001); - assertEquals(3d, arr[2], .0001); - - - } -} diff --git a/src/test/java/me/mcnelis/rudder/data/RecordTest.java b/src/test/java/me/mcnelis/rudder/data/RecordTest.java deleted file mode 100644 index 9380c9a..0000000 --- a/src/test/java/me/mcnelis/rudder/data/RecordTest.java +++ /dev/null @@ -1,151 +0,0 @@ -package me.mcnelis.rudder.data; - -import static org.junit.Assert.*; - -import me.mcnelis.rudder.data.collections.RecordList; -import me.mcnelis.rudder.exceptions.FeatureNotFoundException; - -import org.junit.Test; - -public class RecordTest { - - @SuppressWarnings("unused") - @Test - public void testSetFeature() { - RecordList list = new RecordList(); - for (int i=0; i<10000; i++) { - Record r = new Record() { - - @NumericFeature double feature1; - @NumericFeature double feature2; - @NumericFeature double feature3; - }; - try { - r.setFeature("feature1", Math.pow(i,2)*3.2d); - r.setFeature("feature2", Math.pow(i,3)/2.3d); - r.setFeature("feature3", 2 * Math.sqrt(i+1)); - } catch (FeatureNotFoundException e) { - fail("Should not have exception"); - } - - list.add(r); - } - - assertTrue(true); - } - - @Test - public void testGetLabelOneDoubleField() { - - Record r = new Record() { - - @NumericFeature double feature1; - @NumericFeature double feature2; - @Label double label; - }; - try { - r.setFeature("feature1", Math.pow(2,2)*3.2d); - r.setFeature("feature2", Math.pow(2,3)/2.3d); - r.setLabel("label", 12d); - } catch (FeatureNotFoundException e) { - fail("Should not have exception"); - } catch (Exception e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - - assertEquals("12.0", r.getLabel()); - } - - @Test - public void testGetLabelOneStringField() { - - Record r = new Record() { - - @NumericFeature double feature1; - @NumericFeature double feature2; - @Label String label; - }; - try { - r.setFeature("feature1", Math.pow(2,2)*3.2d); - r.setFeature("feature2", Math.pow(2,3)/2.3d); - r.setLabel("label", "success"); - } catch (FeatureNotFoundException e) { - fail("Should not have exception"); - } catch (Exception e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - - assertEquals("success", r.getLabel()); - } - - @Test - public void testGetLabelTwoFields() { - - Record r = new Record() { - - @NumericFeature double feature1; - @NumericFeature double feature2; - @Label String label; - @Label double label2; - - }; - try { - r.setFeature("feature1", Math.pow(2,2)*3.2d); - r.setFeature("feature2", Math.pow(2,3)/2.3d); - r.setLabel("label", "success"); - r.setLabel("label2", 12d); - } catch (FeatureNotFoundException e) { - fail("Should not have exception"); - } catch (Exception e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - - assertEquals("success12.0", r.getLabel()); - } - - @Test - public void testSetStringFeature() { - - Record r = new Record() { - - @TextFeature String feature1; - @NumericFeature double feature2; - @Label String label; - @Label double label2; - - }; - try { - r.setFeature("feature1", "test"); - r.setFeature("feature2", Math.pow(2,3)/2.3d); - r.setLabel("label", "success"); - r.setLabel("label2", 12d); - } catch (FeatureNotFoundException e) { - e.printStackTrace(); - fail("Should not have exception"); - } catch (Exception e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - assertEquals("test",(String)r.getAllFeatures()[0]); - - } - - @Test - public void testUndefinedFeatureDropout() { - Record r = new Record() { - @NumericFeature double f1; - @NumericFeature double f2 = Double.NaN; - - }; - try { - r.setFeature("f1", 1d); - } catch (FeatureNotFoundException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - assertEquals(1, r.getFeatureDoubleArray().length); - } -} diff --git a/src/test/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassificationTest.java b/src/test/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassificationTest.java index ae338e8..b09e894 100644 --- a/src/test/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassificationTest.java +++ b/src/test/java/me/mcnelis/rudder/ml/supervised/classification/NaiveBayesClassificationTest.java @@ -1,14 +1,13 @@ package me.mcnelis.rudder.ml.supervised.classification; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; -import java.util.HashMap; +import java.lang.reflect.Field; +import me.mcnelis.rudder.data.FeatureType; import me.mcnelis.rudder.data.Label; -import me.mcnelis.rudder.data.Record; import me.mcnelis.rudder.data.TextFeature; import me.mcnelis.rudder.data.collections.IRudderList; -import me.mcnelis.rudder.data.collections.RecordList; import me.mcnelis.rudder.data.collections.RudderList; import me.mcnelis.rudder.exceptions.FeatureNotFoundException; @@ -28,7 +27,7 @@ public void testClassScores() { MockTextFeature m = new MockTextFeature(); MockTextFeature m2 = new MockTextFeature(); - try { + m.setFeature("text1", "up"); m.setFeature("text1", "down"); m.setFeature("text1", "left"); @@ -40,10 +39,7 @@ public void testClassScores() { m2.setFeature("text3", "b"); m2.setFeature("text4", "b"); m2.setFeature("text5", "b"); - } catch (FeatureNotFoundException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } + assertEquals("contra", bayes.getLabel(m).toLowerCase()); assertEquals("double dragon", bayes.getLabel(m2).toLowerCase()); @@ -68,7 +64,7 @@ private IRudderList getMockTextFeatures() { mtf1.setFeature("text3", "left"); mtf1.setFeature("text4", "right"); mtf1.setFeature("text5", "up"); - mtf1.setLabel("label", "contra"); + mtf1.setLabel("contra"); list.add(mtf1); mtf2.setFeature("text1", "left"); @@ -76,7 +72,7 @@ private IRudderList getMockTextFeatures() { mtf2.setFeature("text3", "up"); mtf2.setFeature("text4", "down"); mtf2.setFeature("text5", "left"); - mtf2.setLabel("label", "contra"); + mtf2.setLabel("contra"); list.add(mtf2); mtf3.setFeature("text1", "up"); @@ -84,7 +80,7 @@ private IRudderList getMockTextFeatures() { mtf3.setFeature("text3", "left"); mtf3.setFeature("text4", "right"); mtf3.setFeature("text5", "up"); - mtf3.setLabel("label", "contra"); + mtf3.setLabel("contra"); list.add(mtf3); mtf4.setFeature("text1", "left"); @@ -92,7 +88,7 @@ private IRudderList getMockTextFeatures() { mtf4.setFeature("text3", "left"); mtf4.setFeature("text4", "right"); mtf4.setFeature("text5", "up"); - mtf4.setLabel("label", "contra"); + mtf4.setLabel("contra"); list.add(mtf4); mtf5.setFeature("text1", "up"); @@ -100,7 +96,7 @@ private IRudderList getMockTextFeatures() { mtf5.setFeature("text3", "left"); mtf5.setFeature("text4", "right"); mtf5.setFeature("text5", "down"); - mtf5.setLabel("label", "contra"); + mtf5.setLabel("contra"); list.add(mtf5); mtf6.setFeature("text1", "A"); @@ -108,7 +104,7 @@ private IRudderList getMockTextFeatures() { mtf6.setFeature("text3", "left"); mtf6.setFeature("text4", "c"); mtf6.setFeature("text5", "up"); - mtf6.setLabel("label", "Double dragon"); + mtf6.setLabel("Double dragon"); list.add(mtf6); mtf7.setFeature("text1", "b"); @@ -116,7 +112,7 @@ private IRudderList getMockTextFeatures() { mtf7.setFeature("text3", "left"); mtf7.setFeature("text4", "right"); mtf7.setFeature("text5", "c"); - mtf7.setLabel("label", "double dragon"); + mtf7.setLabel("double dragon"); list.add(mtf7); mtf8.setFeature("text1", "a"); @@ -124,7 +120,7 @@ private IRudderList getMockTextFeatures() { mtf8.setFeature("text3", "left"); mtf8.setFeature("text4", "c"); mtf8.setFeature("text5", "left"); - mtf8.setLabel("label", "double dragon"); + mtf8.setLabel("double dragon"); list.add(mtf8); mtf9.setFeature("text1", "a"); @@ -132,7 +128,7 @@ private IRudderList getMockTextFeatures() { mtf9.setFeature("text3", "right"); mtf9.setFeature("text4", "c"); mtf9.setFeature("text5", "up"); - mtf9.setLabel("label", "double dragon"); + mtf9.setLabel("double dragon"); list.add(mtf9); mtf10.setFeature("text1", "b"); @@ -140,13 +136,11 @@ private IRudderList getMockTextFeatures() { mtf10.setFeature("text3", "b"); mtf10.setFeature("text4", "b"); mtf10.setFeature("text5", "b"); - mtf10.setLabel("label", "double dragon"); + mtf10.setLabel("double dragon"); list.add(mtf10); - } catch (FeatureNotFoundException e) { - // TODO Auto-generated catch block - e.printStackTrace(); + } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); @@ -155,7 +149,7 @@ private IRudderList getMockTextFeatures() { return list; } - class MockTextFeature extends Record { + class MockTextFeature { @TextFeature public String text1; @@ -171,7 +165,38 @@ class MockTextFeature extends Record { @TextFeature public String text5; - @Label + @Label(setlabel="setLabel", type=FeatureType.TEXT) public String label; + + public void setLabel(String label) + { + this.label = label; + } + + public void setFeature(String field, String val) + { + Field[] fields = this.getClass().getFields(); + for(Field f : fields) + { + if(f.getName().equalsIgnoreCase(field)) + { + f.setAccessible(true); + try + { + f.set(this, val); + } + catch (IllegalArgumentException e) + { + // TODO Auto-generated catch block + e.printStackTrace(); + } + catch (IllegalAccessException e) + { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + } } } diff --git a/src/test/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegressionTest.java b/src/test/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegressionTest.java index d960dd3..21615e0 100644 --- a/src/test/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegressionTest.java +++ b/src/test/java/me/mcnelis/rudder/ml/supervised/regression/MultiLinearRegressionTest.java @@ -3,15 +3,16 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; -import me.mcnelis.rudder.data.NumericFeature; import me.mcnelis.rudder.data.MockRecord; -import me.mcnelis.rudder.data.Record; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.NumericFeature; +import me.mcnelis.rudder.data.collections.IRudderList; +import me.mcnelis.rudder.data.collections.RudderList; import org.junit.Test; @@ -21,23 +22,42 @@ public class MultiLinearRegressionTest { @Test public void testAddRecordSuccess() { - MultiLinearRegression ols = new MultiLinearRegression(); - MockRecord r = new MockRecord(); + MultiLinearRegression ols = new MultiLinearRegression(); + MockRecord r = new MockRecord(){ + @SuppressWarnings("unused") + @NumericFeature + double f2 = 0d; + }; assertTrue(ols.addRecord(r)); } @Test public void testAddRecordFailure() { - MultiLinearRegression ols = new MultiLinearRegression(); - MockRecord r = new MockRecord(); + MultiLinearRegression ols = new MultiLinearRegression(); + MockRecord r = new MockRecord(){ + @SuppressWarnings("unused") + @NumericFeature + double f2 = 0d; + @SuppressWarnings("unused") + @NumericFeature + double f3 = 0d; + }; assertTrue(ols.addRecord(r)); - Record r2 = new Record(){ + MockRecord r2 = new MockRecord(){ @SuppressWarnings("unused") @NumericFeature double f2 = 0d; + }; - assertFalse(ols.addRecord(r2)); + try + { + ols.addRecord(r2); + fail("Added records with different dimensions"); + } catch(Exception e) + { + assertTrue(true); + } } @@ -51,14 +71,14 @@ public void testRunRegression() { e.printStackTrace(); } - RecordList list = new RecordList(); + IRudderList list = new RudderList(); String[] line = null; try { while((line = r.readNext())!=null) { MockRecord m = new MockRecord(); try { - m.setLabel("doubleLabel", Double.parseDouble(line[0])); + m.setDoubleLabel(Double.parseDouble(line[0])); m.setFeature("feature1", Double.parseDouble(line[1])); m.setFeature("feature2", Double.parseDouble(line[2])); m.setFeature("feature3", Double.parseDouble(line[3])); diff --git a/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScanTest.java b/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScanTest.java index 0f1eed3..9d0df82 100644 --- a/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScanTest.java +++ b/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/DBScanTest.java @@ -6,7 +6,8 @@ import java.util.List; import me.mcnelis.rudder.data.MockRecord; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.collections.IRudderList; +import me.mcnelis.rudder.data.collections.RudderList; import me.mcnelis.rudder.exceptions.FeatureNotFoundException; import org.junit.Test; @@ -17,7 +18,7 @@ public class DBScanTest @Test public void testGetClusters() { - RecordList list = new RecordList(); + IRudderList list = new RudderList(); for (int j = 0; j < 50; j = j + 10) { for (int i = 0; i < 10; i++) @@ -28,13 +29,9 @@ public void testGetClusters() r.setFeature("feature1", 1 * j + i * .0000001); r.setFeature("feature2", 1 * j + i * .0000001); r.setFeature("feature3", 1 * j + i * .0000001); - r.setLabel("stringLabel", "label"); - } - catch (FeatureNotFoundException e) - { - e.printStackTrace(); - fail("Should not have exception"); + r.setLabel("label"); } + catch (Exception e) { // TODO Auto-generated catch block @@ -47,7 +44,7 @@ public void testGetClusters() DBScan db = new DBScan(5d, 2); db.setSourceData(list); - List clusters = db.getClusters(); + List> clusters = db.getClusters(); assertEquals(5, clusters.size()); } diff --git a/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeansTest.java b/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeansTest.java index 4767a9e..822ef03 100644 --- a/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeansTest.java +++ b/src/test/java/me/mcnelis/rudder/ml/unsupervised/clustering/KMeansTest.java @@ -3,7 +3,8 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import me.mcnelis.rudder.data.MockRecord; -import me.mcnelis.rudder.data.collections.RecordList; +import me.mcnelis.rudder.data.collections.IRudderList; +import me.mcnelis.rudder.data.collections.RudderList; import me.mcnelis.rudder.exceptions.FeatureNotFoundException; import org.junit.Test; @@ -12,22 +13,19 @@ public class KMeansTest { @Test public void testClustering() { - RecordList list = new RecordList(); + IRudderList list = new RudderList(); for (int i=0; i<10000; i++) { MockRecord r = new MockRecord(); - try { - r.setFeature("feature1", Math.pow(i,2)*3.2d); - r.setFeature("feature2", Math.pow(i,3)/2.3d); - r.setFeature("feature3", 2 * Math.sqrt(i+1)); - } catch (FeatureNotFoundException e) { - e.printStackTrace(); - fail("Should not have exception"); - } + + r.setFeature("feature1", Math.pow(i,2)*3.2d); + r.setFeature("feature2", Math.pow(i,3)/2.3d); + r.setFeature("feature3", 2 * Math.sqrt(i+1)); + list.add(r); } - KMeans k = new KMeans(3, list); + KMeans k = new KMeans(3, list); k.cluster(); assertTrue(true); }