Skip to content

Commit

Permalink
Revising entire project to simplify for persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
U-osiris\joeo committed Jan 30, 2012
1 parent 38e966c commit 826fb95
Show file tree
Hide file tree
Showing 33 changed files with 1,443 additions and 303 deletions.
37 changes: 37 additions & 0 deletions bayes/pom.xml
@@ -0,0 +1,37 @@
<?xml version="1.0"?>
<!-- Build descriptor for the "bayes" module; inherits from ml-parent. -->
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"
         xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <artifactId>ml-parent</artifactId>
        <groupId>com.enigmastation.ml</groupId>
        <version>4.0-SNAPSHOT</version>
    </parent>
    <groupId>com.enigmastation.ml</groupId>
    <artifactId>bayes</artifactId>
    <version>4.0-SNAPSHOT</version>
    <name>bayes</name>
    <url>http://maven.apache.org</url>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <!-- Test-only: TestNG drives the unit tests. -->
        <dependency>
            <groupId>org.testng</groupId>
            <artifactId>testng</artifactId>
            <version>6.3.1</version>
            <scope>test</scope>
        </dependency>
        <!-- Supplies the tokenizer and Porter stemmer used by PorterTokenizer. -->
        <dependency>
            <groupId>org.apache.lucene</groupId>
            <artifactId>lucene-analyzers</artifactId>
            <version>3.4.0</version>
        </dependency>
        <!-- Sibling module. ${project.groupId} replaces the unprefixed ${parent.groupId}
             expression; both resolve to com.enigmastation.ml for this module. -->
        <dependency>
            <groupId>${project.groupId}</groupId>
            <artifactId>common</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>
</project>
11 changes: 11 additions & 0 deletions bayes/src/main/java/com/enigmastation/ml/bayes/Classifier.java
@@ -0,0 +1,11 @@
package com.enigmastation.ml.bayes;

import java.util.Map;

/**
 * Contract for a trainable classifier: {@link #train(Object, Object)} associates a source with a
 * classification, and the {@code classify} methods map a source back to a classification.
 *
 * <p>Sources and classifications are untyped {@code Object}s; this interface places no
 * restriction on them.
 */
public interface Classifier {
/**
 * Classifies {@code source}.
 *
 * NOTE(review): the result when no classification can be chosen is implementation-defined;
 * confirm against the implementation in use.
 */
Object classify(Object source);
/**
 * Classifies {@code source}, with {@code defaultClassification} as the caller-supplied fallback.
 */
Object classify(Object source, Object defaultClassification);
/**
 * Classifies {@code source} with a caller-supplied fallback and a {@code strength} value.
 *
 * NOTE(review): the meaning of {@code strength} is not defined by this interface - confirm the
 * intended semantics before relying on it.
 */
Object classify(Object source, Object defaultClassification, double strength);
/**
 * Returns a map from classification to a score computed for {@code source}.
 *
 * NOTE(review): whether the values are normalised probabilities is not established here - verify.
 */
Map<Object, Double> getClassificationProbabilities(Object source);
/**
 * Records {@code source} as an example of {@code classification}.
 */
void train(Object source, Object classification);
}
7 changes: 7 additions & 0 deletions bayes/src/main/java/com/enigmastation/ml/bayes/Tokenizer.java
@@ -0,0 +1,7 @@
package com.enigmastation.ml.bayes;

import java.util.List;

/**
 * Turns a source object into a list of tokens.
 */
public interface Tokenizer {
/**
 * Splits {@code source} into tokens.
 *
 * NOTE(review): null handling and token type are left to implementations; the implementations
 * visible alongside this interface call {@code source.toString()} and emit {@code String} tokens.
 */
List<Object> tokenize(Object source);
}
@@ -0,0 +1,35 @@
package com.enigmastation.ml.bayes.impl;

import com.enigmastation.ml.bayes.Tokenizer;
import org.apache.lucene.analysis.PorterStemFilter;
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

/**
 * {@link Tokenizer} that splits text with Lucene's {@code StandardTokenizer} and reduces every
 * word to its Porter stem. Stems of two characters or fewer are discarded.
 */
public class PorterTokenizer implements Tokenizer {
    /**
     * Tokenizes and stems the string form of {@code source}.
     *
     * @param source object whose {@code toString()} supplies the text; must not be {@code null}
     * @return the lower-case stems longer than two characters, in order of appearance
     * @throws NullPointerException if {@code source} is {@code null}
     * @throws RuntimeException wrapping any {@link IOException} raised by the token stream
     */
    @Override
    public List<Object> tokenize(Object source) {
        // The Porter stemmer only recognises lower-case vowels and suffixes, so the text has to
        // be lower-cased before it reaches the filter; lower-casing the stems afterwards lets
        // "RUNNING" and "running" end up as different features. Locale.ROOT keeps the result
        // independent of the JVM's default locale.
        String text = source.toString().toLowerCase(java.util.Locale.ROOT);
        List<Object> tokens = new ArrayList<>(text.length() / 5);
        org.apache.lucene.analysis.Tokenizer tokenizer =
                new StandardTokenizer(Version.LUCENE_34, new StringReader(text));
        PorterStemFilter filter = new PorterStemFilter(tokenizer);
        // The filter shares its attribute source with the wrapped tokenizer.
        CharTermAttribute charTermAttribute = filter.getAttribute(CharTermAttribute.class);
        try {
            try {
                // Consumer workflow for a TokenStream: reset, iterate, end, close.
                filter.reset();
                while (filter.incrementToken()) {
                    String term = charTermAttribute.toString();
                    if (term.length() > 2) {
                        tokens.add(term);
                    }
                }
                filter.end();
            } finally {
                // Closing the filter also closes the wrapped tokenizer and its reader.
                filter.close();
            }
        } catch (IOException e) {
            throw new RuntimeException("Should not happen: " + e.getMessage(), e);
        }
        return tokens;
    }
}
@@ -0,0 +1,171 @@
package com.enigmastation.ml.bayes.impl;

import com.enigmastation.ml.bayes.Classifier;
import com.enigmastation.ml.bayes.Tokenizer;
import com.enigmastation.ml.common.collections.MapBuilder;
import com.enigmastation.ml.common.collections.ValueProvider;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Count-based naive Bayes {@link Classifier}.
 *
 * <p>Training tokenizes a source and records, per token ("feature"), how often it was seen under
 * each category, plus how many sources were trained per category. Classification multiplies the
 * weighted per-feature probabilities for each category and picks the largest product, subject to
 * a per-category threshold.
 *
 * <p>This class performs no synchronization.
 */
public class SimpleClassifier implements Classifier {
    /** feature -> (category -> number of times the feature was trained under that category). */
    Map<Object, Map<Object, Integer>> features;
    /** category -> number of sources trained under that category. */
    Map<Object, Integer> categories;
    Tokenizer tokenizer = new PorterTokenizer();
    /** category -> threshold used by {@code classify}; categories without an entry use 1.0. */
    Map<Object, Double> thresholds = new MapBuilder().defaultValue(1.0).build();

    /**
     * Creates a classifier on top of caller-supplied count maps.
     *
     * @param features   feature -> (category -> count)
     * @param categories category -> number of trained sources
     */
    SimpleClassifier(Map<Object, Map<Object, Integer>> features, Map<Object, Integer> categories) {
        this.features = features;
        this.categories = categories;
    }

    /** Creates an untrained classifier backed by default-valued maps from {@code MapBuilder}. */
    public SimpleClassifier() {
        features = new MapBuilder().valueProvider(new ValueProvider<Object, Object>() {
            @Override
            public Object getDefault(Object k) {
                return new MapBuilder().defaultValue(0).build();
            }
        }).build();
        categories = new MapBuilder().defaultValue(0).build();
    }

    /** Classifies {@code source}, falling back to the string {@code "none"}. */
    @Override
    public Object classify(Object source) {
        return classify(source, "none");
    }

    /** Classifies {@code source}, falling back to {@code defaultClassification}. */
    @Override
    public Object classify(Object source, Object defaultClassification) {
        return classify(source, defaultClassification, 0.0);
    }

    /**
     * Picks the category with the highest score for {@code source}.
     *
     * <p>The winner is only returned if, for every other category, that category's score times
     * the winner's threshold does not exceed the winner's score. Otherwise, or when no category
     * has a score above zero (nothing trained, or the product underflowed to zero),
     * {@code defaultClassification} is returned.
     *
     * @param source                object to classify
     * @param defaultClassification value returned when no category wins clearly
     * @param strength              currently not consulted by this implementation
     * @return the winning category or {@code defaultClassification}
     */
    @Override
    public Object classify(Object source, Object defaultClassification, double strength) {
        Map<Object, Double> probs = getClassificationProbabilities(source);
        double max = 0.0;
        Object category = null;
        for (Map.Entry<Object, Double> entry : probs.entrySet()) {
            if (entry.getValue() > max) {
                max = entry.getValue();
                category = entry.getKey();
            }
        }
        if (category == null) {
            // No candidate at all: previously this returned null (empty map) or failed with a
            // NullPointerException while unboxing probs.get(null).
            return defaultClassification;
        }
        double threshold = getThreshold(category);
        for (Map.Entry<Object, Double> entry : probs.entrySet()) {
            if (category.equals(entry.getKey())) {
                continue;
            }
            if ((entry.getValue() * threshold) > max) {
                return defaultClassification;
            }
        }
        return category;
    }

    /**
     * Computes {@link #docprob(Object, Object)} for every known category.
     *
     * @return a new map from category to its score for {@code source}
     */
    @Override
    public Map<Object, Double> getClassificationProbabilities(Object source) {
        Map<Object, Double> probabilities = new HashMap<Object, Double>();
        for (Object category : categories()) {
            probabilities.put(category, docprob(source, category));
        }
        return probabilities;
    }

    /**
     * Tokenizes {@code source}, counts each token under {@code classification}, and then counts
     * one more source for that classification.
     */
    @Override
    public void train(Object source, Object classification) {
        for (Object feature : tokenizer.tokenize(source)) {
            incf(feature, classification);
        }
        incc(classification);
    }

    /** Increments the count of {@code feature} within {@code category}. */
    private void incf(Object feature, Object category) {
        Map<Object, Integer> cat = features.get(feature);
        if (cat == null) {
            // Only reachable with maps that do not supply defaults (package-private constructor).
            cat = new HashMap<Object, Integer>();
        }
        // Stored back explicitly, as the original did after fetching the per-feature map.
        features.put(feature, cat);
        Integer current = cat.get(category);
        cat.put(category, current == null ? 1 : current + 1);
    }

    /** Increments the number of sources trained under {@code category}. */
    private void incc(Object category) {
        Integer current = categories.get(category);
        categories.put(category, current == null ? 1 : current + 1);
    }

    /** Returns the number of times a feature has occurred in a category, or 0 if never seen. */
    int fcount(Object feature, Object category) {
        if (features.containsKey(feature) && features.get(feature).containsKey(category)) {
            return features.get(feature).get(category);
        }
        return 0;
    }

    /** Returns the number of sources trained under {@code category}, or 0 if unknown. */
    int catcount(Object category) {
        if (categories.containsKey(category)) {
            return categories.get(category);
        }
        return 0;
    }

    /** Returns the total number of trained sources across all categories. */
    int totalcount() {
        int sum = 0;
        for (Integer i : categories.values()) {
            sum += i;
        }
        return sum;
    }

    /** Returns the key set of the category count map. */
    Set<Object> categories() {
        return categories.keySet();
    }

    /**
     * Returns {@code fcount(feature, category) / catcount(category)}, or 0.0 when the category
     * has no trained sources.
     */
    double fprob(Object feature, Object category) {
        if (catcount(category) == 0) {
            return 0.0;
        }
        return (1.0 * fcount(feature, category)) / catcount(category);
    }

    /**
     * Blends {@link #fprob(Object, Object)} with an assumed probability, weighting the observed
     * value by how often the feature was seen across all categories:
     * {@code (weight * assumedProbability + totals * fprob) / (weight + totals)}.
     */
    double weightedprob(Object feature, Object category, double weight, double assumedProbability) {
        double basicProbability = fprob(feature, category);
        double totals = 0;
        for (Object cat : categories()) {
            totals += fcount(feature, cat);
        }
        return ((weight * assumedProbability) + (totals * basicProbability)) / (weight + totals);
    }

    /** {@link #weightedprob(Object, Object, double, double)} with an assumed probability of 0.5. */
    double weightedprob(Object feature, Object category, double weight) {
        return weightedprob(feature, category, weight, 0.5);
    }

    /** {@link #weightedprob(Object, Object, double)} with a weight of 1.0. */
    double weightedprob(Object feature, Object category) {
        return weightedprob(feature, category, 1.0);
    }

    /* naive bayes, very naive - and not what we usually need. */
    /** Returns the product of the weighted probabilities of every token of {@code corpus}. */
    double docprob(Object corpus, Object category) {
        double p = 1.0;
        for (Object f : tokenizer.tokenize(corpus)) {
            p *= weightedprob(f, category);
        }
        return p;
    }

    /**
     * Returns {@link #docprob(Object, Object)} scaled by the category's share of all trained
     * sources, or 0.0 when nothing has been trained (instead of NaN from 0/0).
     */
    double prob(Object corpus, Object category) {
        int total = totalcount();
        if (total == 0) {
            return 0.0;
        }
        double catprob = (1.0 * catcount(category)) / total;
        double docprob = docprob(corpus, category);
        return docprob * catprob;
    }

    /** Sets the threshold applied when {@code category} is the winning candidate. */
    public void setThreshold(Object category, Double threshold) {
        thresholds.put(category, threshold);
    }

    /** Returns the threshold for {@code category}; 1.0 when none (or a null one) is stored. */
    public double getThreshold(Object category) {
        Double threshold = thresholds.get(category);
        return threshold == null ? 1.0 : threshold;
    }
}
@@ -0,0 +1,20 @@
package com.enigmastation.ml.bayes.impl;

import com.enigmastation.ml.bayes.Tokenizer;

import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

/**
 * {@link Tokenizer} that returns the whitespace-delimited pieces of a source's string form.
 */
public class SimpleTokenizer implements Tokenizer {
    /**
     * Reads {@code source.toString()} with a {@link Scanner} and collects each token that
     * matches {@code \S*}, in order.
     *
     * @param source object whose string form is split
     * @return the collected tokens
     */
    @Override
    public List<Object> tokenize(Object source) {
        final String text = source.toString();
        // Sizing guess only: roughly one token per six characters.
        final List<Object> result = new ArrayList<>(text.length() / 6);
        final Scanner words = new Scanner(text);
        for (;;) {
            if (!words.hasNext("\\S*")) {
                break;
            }
            result.add(words.next("\\S*"));
        }
        return result;
    }
}
25 changes: 25 additions & 0 deletions bayes/src/test/java/com/enigmastation/ml/bayes/TokenizerTest.java
@@ -0,0 +1,25 @@
package com.enigmastation.ml.bayes;

import com.enigmastation.ml.bayes.impl.PorterTokenizer;
import com.enigmastation.ml.bayes.impl.SimpleTokenizer;
import org.testng.annotations.Test;

import java.util.List;

import static org.testng.Assert.assertEquals;

/** Token-count checks for the two tokenizer implementations. */
public class TokenizerTest {
    /** Seven space-separated items must yield seven tokens. */
    @Test
    public void testSimpleTokenizer() {
        Tokenizer whitespaceSplitter = new SimpleTokenizer();
        int tokenCount = whitespaceSplitter.tokenize("1 2 3 4 5 6 7").size();
        assertEquals(tokenCount, 7);
    }

    /** The sample sentence must yield ten tokens from the Porter tokenizer. */
    @Test
    public void testPorterTokenizer() {
        Tokenizer stemmer = new PorterTokenizer();
        // Result deliberately unused: this call only has to complete without throwing.
        stemmer.tokenize("Now is the time for all good men to come to the aid of their finalizing country.");
        List<Object> stems = stemmer.tokenize("the quick brown fox jumps over the lazy dog's tail");
        assertEquals(stems.size(), 10);
    }
}
@@ -0,0 +1,66 @@
package com.enigmastation.ml.bayes.impl;

import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;

/** Exercises counting, probability and classification behaviour of {@link SimpleClassifier}. */
public class SimpleClassifierTest {
    /** A feature trained once in each category is counted once in each. */
    @Test
    public void testInternalsOfTrain() {
        SimpleClassifier classifier = new SimpleClassifier();
        classifier.train("the quick brown fox jumps over the lazy dog's tail", "good");
        classifier.train("make quick money in the online casino", "bad");
        int inGood = classifier.fcount("quick", "good");
        int inBad = classifier.fcount("quick", "bad");
        assertEquals(inGood, 1);
        assertEquals(inBad, 1);
    }

    /** Feature probabilities before and after a second pass over the training set. */
    @Test
    public void testfprob() {
        SimpleClassifier classifier = getTrainedClassifier();
        assertEquals(classifier.fprob("quick", "good"), 0.6666, 0.0001);
        double weightedOnce = classifier.weightedprob("monei", "good");
        assertEquals(weightedOnce, 0.25, 0.001);
        train(classifier);
        double weightedTwice = classifier.weightedprob("monei", "good");
        assertEquals(weightedTwice, 0.166, 0.001);
    }

    /** Category-weighted document probabilities for a two-word document. */
    @Test
    public void testBayes() {
        SimpleClassifier classifier = getTrainedClassifier();
        double good = classifier.prob("quick rabbit", "good");
        assertEquals(good, 0.15624, 0.0001);
        double bad = classifier.prob("quick rabbit", "bad");
        assertEquals(bad, 0.05, 0.001);
    }

    /** A category without an explicit threshold reports 1.0. */
    @Test
    public void testThreshold() {
        SimpleClassifier classifier = getTrainedClassifier();
        double threshold = classifier.getThreshold("bad");
        assertEquals(threshold, 1.0, 0.01);
    }

    /** Raising a threshold yields the default until enough extra training has been done. */
    @Test
    public void testClassification() {
        SimpleClassifier classifier = getTrainedClassifier();
        assertEquals(classifier.classify("quick rabbit", "unknown"), "good");
        assertEquals(classifier.classify("quick money", "unknown"), "bad");

        classifier.setThreshold("bad", 3.0);
        assertEquals(classifier.classify("quick money", "unknown"), "unknown");

        int remainingRounds = 10;
        while (remainingRounds > 0) {
            train(classifier);
            remainingRounds--;
        }
        assertEquals(classifier.classify("quick money", "unknown"), "bad");
    }

    /** Builds a classifier that has seen the sample set exactly once. */
    private SimpleClassifier getTrainedClassifier() {
        SimpleClassifier fresh = new SimpleClassifier();
        train(fresh);
        return fresh;
    }

    /** Feeds the five sample sentences, in a fixed order, to {@code cl}. */
    private void train(SimpleClassifier cl) {
        String[][] samples = {
            {"Nobody owns the water", "good"},
            {"The quick rabbit jumps fences", "good"},
            {"Buy pharmaceuticals now", "bad"},
            {"make quick money in the online casino", "bad"},
            {"the quick brown fox jumps", "good"},
        };
        for (String[] sample : samples) {
            cl.train(sample[0], sample[1]);
        }
    }
}

0 comments on commit 826fb95

Please sign in to comment.