-
Notifications
You must be signed in to change notification settings - Fork 23
/
LogisticTest.java
112 lines (95 loc) · 4.26 KB
/
LogisticTest.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package regression;
import static org.neo4j.helpers.collection.MapUtil.map;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;

import org.apache.mahout.common.RandomUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Result;
import org.neo4j.kernel.impl.proc.Procedures;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.test.TestGraphDatabaseFactory;
/**
 * Integration test for the {@code regression.logistic.*} procedures and the
 * {@code regression.logistic.predict} function registered from {@link Logistic}.
 *
 * <p>The test loads the iris data set from the test classpath, trains a model on a
 * shuffled subset of the rows, predicts the class of the remaining rows and prints
 * the number of correct and incorrect predictions.
 */
public class LogisticTest {

    /** Classpath location of the data set; the file lives under {@code src/test/resources}. */
    private static final String IRIS_RESOURCE = "/iris-full.csv";
    private static final String CSV_SEPARATOR = ",";

    /** Number of shuffled rows used for training. */
    private static final int TRAIN_SIZE = 100;
    /** Number of shuffled rows (following the training rows) used for evaluation. */
    private static final int TEST_SIZE = 50;
    /** Number of times the whole training subset is fed to the model. */
    private static final int TRAINING_PASSES = 30;

    private static GraphDatabaseService db;

    //TODO: larger data set, correct random function

    /** Starts an impermanent database and registers the procedures/functions under test. */
    @BeforeClass
    public static void setUp() throws Exception {
        db = new TestGraphDatabaseFactory().newImpermanentDatabase();
        Procedures procedures = ((GraphDatabaseAPI) db).getDependencyResolver().resolveDependency(Procedures.class);
        procedures.registerProcedure(Logistic.class);
        procedures.registerFunction(Logistic.class);
    }

    /** Shuts the database down; tolerates a {@code setUp} that failed before assigning {@code db}. */
    @AfterClass
    public static void tearDown() throws Exception {
        if (db != null) {
            db.shutdown();
        }
    }

    /**
     * Creates a three-class model, trains it on {@link #TRAIN_SIZE} rows for
     * {@link #TRAINING_PASSES} passes, evaluates it on {@link #TEST_SIZE} rows and
     * finally deletes the model.
     *
     * @throws Exception if the data set cannot be read or a query fails
     */
    @Test
    public void makeModel() throws Exception {
        List<Map<String, Double>> data = new ArrayList<>();
        List<String> target = new ArrayList<>();
        loadIris(data, target);

        // Fail with a readable message instead of an IndexOutOfBoundsException from subList.
        int required = TRAIN_SIZE + TEST_SIZE;
        Assert.assertTrue("expected at least " + required + " rows in " + IRIS_RESOURCE
                + " but found " + data.size(), data.size() >= required);

        List<Integer> order = new ArrayList<>(data.size());
        for (int i = 0; i < data.size(); i++) {
            order.add(i);
        }

        // Fixed seed so the train/test split is reproducible between runs.
        RandomUtils.useTestSeed();
        Random random = RandomUtils.getRandom();
        Collections.shuffle(order, random);
        List<Integer> train = order.subList(0, TRAIN_SIZE);
        List<Integer> test = order.subList(TRAIN_SIZE, required);

        db.execute("CALL regression.logistic.create('model', ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'], " +
                "{sepallength:'float', sepalwidth:'float', petallength:'float', petalwidth:'float'}, {prior:'L2'})").close();
        try {
            trainModel(train, data, target, random);
            evaluateModel(test, data, target);
        } finally {
            // Delete in finally so a failure mid-test does not leave the model behind.
            db.execute("CALL regression.logistic.delete('model')").close();
        }
    }

    /**
     * Reads the iris CSV from the classpath and appends one feature map and one class
     * label per data row to the given lists.
     *
     * @param data   receives the four numeric features of each row, keyed by feature name
     * @param target receives the class label of each row, in the same order as {@code data}
     * @throws IOException if the resource cannot be read
     */
    private static void loadIris(List<Map<String, Double>> data, List<String> target) throws IOException {
        InputStream in = LogisticTest.class.getResourceAsStream(IRIS_RESOURCE);
        Assert.assertNotNull("test data " + IRIS_RESOURCE + " not found on the classpath", in);
        try (BufferedReader br = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            br.readLine(); //skip headers
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline at end of file
                }
                // Column 0 is not used; presumably a row id -- verify against the CSV header.
                String[] flower = line.split(CSV_SEPARATOR);
                Map<String, Double> v = new HashMap<>(4);
                v.put("sepallength", Double.parseDouble(flower[1]));
                v.put("sepalwidth", Double.parseDouble(flower[2]));
                v.put("petallength", Double.parseDouble(flower[3]));
                v.put("petalwidth", Double.parseDouble(flower[4]));
                data.add(v);
                target.add(flower[5]); //class
            }
        }
    }

    /** Feeds every training row to the model, reshuffling the training order before each pass. */
    private static void trainModel(List<Integer> train, List<Map<String, Double>> data,
                                   List<String> target, Random random) {
        for (int pass = 0; pass < TRAINING_PASSES; pass++) {
            Collections.shuffle(train, random);
            for (int j : train) {
                db.execute("CALL regression.logistic.add('model', {output}, {inputs})",
                        map("inputs", data.get(j), "output", target.get(j))).close();
            }
        }
    }

    /**
     * Predicts the class of every test row and prints each outcome plus the totals.
     *
     * <p>NOTE(review): no accuracy threshold is asserted here, so a badly trained model
     * still passes -- pick a threshold once the expected accuracy is known.
     */
    private static void evaluateModel(List<Integer> test, List<Map<String, Double>> data, List<String> target) {
        int successes = 0;
        int failures = 0;
        for (int k : test) {
            String guess;
            try (Result result = db.execute("RETURN regression.logistic.predict('model', {inputs}) as prediction",
                    map("inputs", data.get(k)))) {
                guess = (String) result.next().get("prediction");
            }
            String t;
            if (guess.equals(target.get(k))) {
                t = "SUCCESS!";
                successes++;
            } else {
                t = "FAIL!";
                failures++;
            }
            System.out.format("Expected: %s, Actual: %s %s%n", target.get(k), guess, t);
        }
        System.out.format("SUCCESSES: %d%n", successes);
        System.out.format("FAILURES: %d%n", failures);
    }
}