diff --git a/src/main/java/com/cscotta/recordinality/Recordinality.java b/src/main/java/com/cscotta/recordinality/Recordinality.java index 4d5c44f..d40fa4d 100644 --- a/src/main/java/com/cscotta/recordinality/Recordinality.java +++ b/src/main/java/com/cscotta/recordinality/Recordinality.java @@ -25,7 +25,7 @@ import com.google.common.hash.HashFunction; import com.google.common.collect.ImmutableSet; -public class Recordinality { +public class Recordinality { private final int sampleSize; private final int seed = new Random().nextInt(); @@ -46,7 +46,7 @@ public Recordinality(int sampleSize) { /* * Observes a value in a stream. */ - public void observe(String element) { + public void observe(T element) { boolean inserted = insertIfFits(element); if (inserted) modifications.incrementAndGet(); } @@ -84,8 +84,17 @@ public Set getSample() { /* * Inserts a record into our k-set if it fits. */ - private boolean insertIfFits(String element) { - long hashedValue = hash.hashString(element).asLong(); + private boolean insertIfFits(T element) { + long hashedValue = 0L; + + if (element.getClass() == String.class) + hashedValue = hash.hashString((String)element).asLong(); + else if (element.getClass() == Long.class) + hashedValue = hash.hashLong((Long)element).asLong(); + else if (element.getClass() == Integer.class) + hashedValue = hash.hashInt((Integer)element).asLong(); + else + throw new ClassCastException(); // Short-circuit if our k-set is saturated. Common case. if (hashedValue < cachedMin.get() && kMapSize >= sampleSize) @@ -141,10 +150,10 @@ public String toString() { */ public class Element { - public final String value; + public final T value; public final AtomicLong count; - public Element(String value) { + public Element(T value) { this.value = value; this.count = new AtomicLong(1); } diff --git a/src/test/java/com/cscotta/recordinality/RecordinalityTest.java b/src/test/java/com/cscotta/recordinality/RecordinalityTest.java index 0813e07..5cd5ccd 100644 --- a/src/test/java/com/cscotta/recordinality/RecordinalityTest.java +++ b/src/test/java/com/cscotta/recordinality/RecordinalityTest.java @@ -67,7 +67,7 @@ public Result call() throws Exception { long start = System.currentTimeMillis(); final double[] results = new double[numRuns]; for (int i = 0; i < numRuns; i++) { - Recordinality rec = new Recordinality(kSize); + Recordinality rec = new Recordinality(kSize); for (String line : lines) rec.observe(line); results[i] = rec.estimateCardinality(); }