Skip to content

Commit

Permalink
SAMME implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
chen0040 committed Jun 9, 2017
1 parent e5d7808 commit e12ffee
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,6 +1,8 @@
.idea/
*.iml

target/

# Compiled class file
*.class

Expand Down
35 changes: 35 additions & 0 deletions README.md
Expand Up @@ -143,6 +143,41 @@ for(int i=0; i < crossValidationData.rowCount(); ++i) {
evaluator.report();
```

### Classification via Ensemble (SAMME)

```java
InputStream irisStream = new FileInputStream("iris.data");
DataFrame irisData = DataQuery.csv(",")
.from(irisStream)
.selectColumn(0).asNumeric().asInput("Sepal Length")
.selectColumn(1).asNumeric().asInput("Sepal Width")
.selectColumn(2).asNumeric().asInput("Petal Length")
.selectColumn(3).asNumeric().asInput("Petal Width")
.selectColumn(4).asCategory().asOutput("Iris Type")
.build();

TupleTwo<DataFrame, DataFrame> parts = irisData.shuffle().split(0.9);

DataFrame trainingData = parts._1();
DataFrame crossValidationData = parts._2();

System.out.println(crossValidationData.head(10));

SAMME multiClassClassifier = new SAMME();
multiClassClassifier.fit(trainingData);

ClassifierEvaluator evaluator = new ClassifierEvaluator();

for(int i=0; i < crossValidationData.rowCount(); ++i) {
String predicted = multiClassClassifier.classify(crossValidationData.row(i));
String actual = crossValidationData.row(i).categoricalTarget();
System.out.println("predicted: " + predicted + "\tactual: " + actual);
evaluator.evaluate(actual, predicted);
}

evaluator.report();
```

To create and train a Bagging ensemble classifier:

### Anomaly Detection
Expand Down
133 changes: 133 additions & 0 deletions src/main/java/com/github/chen0040/trees/ensembles/SAMME.java
@@ -0,0 +1,133 @@
package com.github.chen0040.trees.ensembles;


import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.TupleTwo;
import com.github.chen0040.data.utils.discretizers.KMeansDiscretizer;
import com.github.chen0040.trees.id3.ID3;
import lombok.Getter;
import lombok.Setter;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;


/**
 * Created by xschen on 9/6/2017.
 *
 * Multi-class AdaBoost ensemble classifier based on the SAMME algorithm
 * (Zhu, Zou, Rosset, Hastie: "Multi-class AdaBoost", 2009).
 *
 * Training first builds a pool of {@link ID3} decision trees, each fitted on a
 * random sub-sample of the training data. Boosting then runs for
 * {@code treeCount} rounds; each round selects the pooled tree with the lowest
 * weighted training error, assigns it a weight {@code alpha}, and re-weights
 * the training rows to emphasize the misclassified ones.
 *
 * NOTE(review): not thread-safe; fit() and classify() share mutable state.
 */
public class SAMME {
    // Pool of candidate weak learners built at the start of fit().
    private final List<ID3> classifiers = new ArrayList<>();
    // Boosted ensemble: one (classifier index, alpha weight) entry per round.
    private final List<TupleTwo<Integer, Double>> model = new ArrayList<>();

    @Getter
    @Setter
    private int treeCount = 100;

    // ID3 works on discrete attributes, so numeric inputs are discretized first.
    private KMeansDiscretizer discretizer = new KMeansDiscretizer();

    @Getter
    private final List<String> classLabels = new ArrayList<>();

    @Setter
    @Getter
    public double dataSampleRate = 0.2; // value between 0 and 1

    public SAMME(){
    }

    /**
     * Trains the ensemble on the given data frame.
     *
     * @param frame training data; the last categorical column is the target
     */
    public void fit(DataFrame frame){

        frame = discretizer.fitAndTransform(frame);

        classifiers.clear();
        classLabels.clear();
        // Bug fix: without this, a second call to fit() would append new rounds
        // onto the stale model and reference the wrong classifier indices.
        model.clear();
        for(int m = 0; m < treeCount; ++m) {
            ID3 classifier = new ID3(false);
            // Bug fix: honor the configurable dataSampleRate instead of the
            // previously hard-coded 0.2 sub-sample fraction.
            classifier.fit(frame.shuffle().split(dataSampleRate)._1());
            classifiers.add(classifier);
        }

        final int N = frame.rowCount();
        double[] weights = new double[N];
        Set<String> labels = new HashSet<>();
        for(int i=0; i < N; ++i){
            weights[i] = 1.0 / N; // uniform initial row weights
            labels.add(frame.row(i).categoricalTarget());
        }
        classLabels.addAll(labels);
        int K = classLabels.size();

        for(int t = 0; t < treeCount; ++t) {

            // Select the pooled tree with the lowest weighted training error.
            double min_err = Double.MAX_VALUE;
            int M = -1;
            for (int m = 0; m < classifiers.size(); ++m) {
                ID3 classifier_m = classifiers.get(m);
                double err_m = 0;
                for (int i = 0; i < N; ++i) {
                    DataRow row = frame.row(i);
                    String predicted = classifier_m.classify(row);

                    if (!predicted.equals(row.categoricalTarget())) {
                        err_m += weights[i];
                    }
                }

                if (min_err > err_m) {
                    min_err = err_m;
                    M = m;
                }
            }

            // Bug fix: clamp the error away from 0 and 1 so that the log-odds
            // below stays finite (a perfect tree previously produced
            // alpha = Infinity and NaN weights).
            final double eps = 1e-10;
            min_err = Math.max(eps, Math.min(1.0 - eps, min_err));

            // Add next classifier.
            // Bug fix: SAMME (Zhu et al. 2009, Algorithm 2) defines
            //   alpha = log((1 - err) / err) + log(K - 1)
            // The previous 0.5 factor belongs to binary AdaBoost, not SAMME.
            ID3 classifier_t = classifiers.get(M);
            double alpha_t = Math.log((1 - min_err) / min_err) + Math.log(K - 1);
            model.add(new TupleTwo<>(M, alpha_t));

            // Update weights: misclassified rows (indicator II == 1) are boosted.
            double sum = 0;
            for(int i=0; i < N; ++i){
                DataRow row_i = frame.row(i);
                String predicted = classifier_t.classify(row_i);
                double II = predicted.equals(row_i.categoricalTarget()) ? 0 : 1;
                weights[i] = weights[i] * Math.exp(alpha_t * II);
                sum += weights[i];
            }

            // Normalize weights back to a distribution.
            for(int i=0; i < N; ++i) {
                weights[i] /= sum;
            }
        }
    }

    /**
     * Classifies a single row by weighted majority vote over the boosted rounds.
     *
     * @param row data row to classify (will be discretized with the fitted discretizer)
     * @return the class label with the highest total alpha-weighted vote
     * @throws IllegalStateException if fit(...) has not been called yet
     */
    public String classify(DataRow row) {
        if (model.isEmpty()) {
            throw new IllegalStateException("classify(...) called before fit(...)");
        }

        row = discretizer.transform(row);

        double max_sum_k = Double.NEGATIVE_INFINITY;
        int K = -1;
        for(int k =0; k < classLabels.size(); ++k){
            String candidate = classLabels.get(k);

            // Sum the alpha weights of every round that votes for this label.
            // Bug fix: iterate the model itself rather than assuming it holds
            // exactly treeCount entries.
            double sum_k = 0;
            for(TupleTwo<Integer, Double> t : model) {
                ID3 classifier_t = classifiers.get(t._1());
                double alpha_t = t._2();
                String predicted = classifier_t.classify(row);
                double II = predicted.equals(candidate) ? 1 : 0;
                sum_k += alpha_t * II;
            }

            if(sum_k > max_sum_k) {
                max_sum_k = sum_k;
                K = k;
            }
        }

        return classLabels.get(K);
    }
}
@@ -0,0 +1,54 @@
package com.github.chen0040.trees.ensembles;


import com.github.chen0040.data.evaluators.ClassifierEvaluator;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataQuery;
import com.github.chen0040.data.utils.TupleTwo;
import com.github.chen0040.trees.utils.FileUtils;
import org.testng.annotations.Test;

import java.io.IOException;
import java.io.InputStream;

import static org.testng.Assert.*;


/**
 * Created by xschen on 9/6/2017.
 */
public class SAMMEUnitTest {
    @Test
    public void test_iris() throws IOException {
        // Load the iris dataset: four numeric inputs, one categorical output.
        InputStream dataStream = FileUtils.getResource("iris.data");
        DataFrame frame = DataQuery.csv(",")
                .from(dataStream)
                .selectColumn(0).asNumeric().asInput("Sepal Length")
                .selectColumn(1).asNumeric().asInput("Sepal Width")
                .selectColumn(2).asNumeric().asInput("Petal Length")
                .selectColumn(3).asNumeric().asInput("Petal Width")
                .selectColumn(4).asCategory().asOutput("Iris Type")
                .build();

        // Hold out 10% of the shuffled data for cross-validation.
        TupleTwo<DataFrame, DataFrame> split = frame.shuffle().split(0.9);
        DataFrame train = split._1();
        DataFrame holdout = split._2();

        System.out.println(holdout.head(10));

        // Fit the SAMME ensemble on the training portion.
        SAMME samme = new SAMME();
        samme.fit(train);

        // Score every holdout row and accumulate the results.
        ClassifierEvaluator scorer = new ClassifierEvaluator();
        final int holdoutCount = holdout.rowCount();
        for (int rowIndex = 0; rowIndex < holdoutCount; ++rowIndex) {
            String predicted = samme.classify(holdout.row(rowIndex));
            String actual = holdout.row(rowIndex).categoricalTarget();
            System.out.println("predicted: " + predicted + "\tactual: " + actual);
            scorer.evaluate(actual, predicted);
        }

        scorer.report();
    }
}

0 comments on commit e12ffee

Please sign in to comment.