Permalink
Browse files

Merge branch 'memoization'

  • Loading branch information...
Logan McGrath
Logan McGrath committed Jun 17, 2013
2 parents de6c6da + e392a2f commit 7d69d49a911d2d916701fa973e02ffabe82afe9d
@@ -8,12 +8,10 @@
public class BoundExpression extends Expression {
- private final Expression expression;
- private final Argument argument;
+ private Expression state;
public BoundExpression(Expression expression, Argument argument) {
- this.expression = expression;
- this.argument = argument;
+ state = new UnboundState(this, expression, argument);
}
@Override
@@ -33,42 +31,108 @@ public Expression apply(Expression argument) throws SterlingException {
@Override
public boolean equals(Object o) {
- if (o == this) {
- return true;
- } else if (o instanceof BoundExpression) {
- BoundExpression other = (BoundExpression) o;
- return Objects.equals(expression, other.expression)
- && Objects.equals(argument, other.argument);
- } else {
- return false;
- }
- }
-
- public Argument getArgument() {
- return argument;
- }
-
- public Expression getExpression() {
- return expression;
+ return o == this || o instanceof BoundExpression && Objects.equals(state, ((BoundExpression) o).state);
}
@Override
public int hashCode() {
- return Objects.hash(expression, argument);
+ return Objects.hash(state);
}
@Override
public Expression reduce() throws SterlingException {
- return bindArgument(expression, argument);
+ return state.reduce();
}
@Override
public String toString() {
- return stringify(this, expression, argument);
+ return stringify(this, state);
}
@Override
protected boolean isReducible() {
- return true;
+ return state.isReducible();
+ }
+
+ private static final class ReducedState extends Expression {
+
+ private final BoundExpression parent;
+ private final Expression expression;
+
+ public ReducedState(BoundExpression parent, Expression expression) {
+ this.parent = parent;
+ this.expression = expression;
+ }
+
+ @Override
+ public Expression reduce() throws SterlingException {
+ Expression expression = this.expression;
+ if (expression.isReducible()) {
+ expression = expression.reduce();
+ parent.state = new ReducedState(parent, expression);
+ }
+ return expression;
+ }
+
+ @Override
+ protected boolean isReducible() {
+ return expression.isReducible();
+ }
+ }
+
+ private static final class UnboundState extends Expression {
+
+ private final BoundExpression parent;
+ private final Expression expression;
+ private final Argument argument;
+
+ public UnboundState(BoundExpression parent, Expression expression, Argument argument) {
+ this.parent = parent;
+ this.expression = expression;
+ this.argument = argument;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) {
+ return true;
+ } else if (o instanceof UnboundState) {
+ UnboundState other = (UnboundState) o;
+ return Objects.equals(expression, other.expression)
+ && Objects.equals(argument, other.argument);
+ } else {
+ return false;
+ }
+ }
+
+ public Argument getArgument() {
+ return argument;
+ }
+
+ public Expression getExpression() {
+ return expression;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(expression, argument);
+ }
+
+ @Override
+ public Expression reduce() throws SterlingException {
+ Expression reduction = bindArgument(expression, argument);
+ parent.state = new ReducedState(parent, reduction);
+ return reduction;
+ }
+
+ @Override
+ public String toString() {
+ return stringify(this, expression, argument);
+ }
+
+ @Override
+ protected boolean isReducible() {
+ return true;
+ }
}
}
@@ -51,8 +51,7 @@ public Void visitArgumentExpression(Argument expression, PrinterState state) thr
@Override
public Void visitBindExpression(BoundExpression expression, PrinterState state) throws SterlingException {
state.begin(expression);
- visit(expression.getExpression(), state);
- visit(expression.getArgument(), state);
+ // TODO how to visit bound expression?
state.end();
return null;
}
@@ -4,17 +4,21 @@
import static org.sterling.runtime.expression.ExpressionFactory.bind;
import static org.sterling.util.StringUtil.stringify;
+import java.util.Map;
import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
import org.sterling.SterlingException;
public class Lambda extends Expression {
private final Expression expression;
private final Variable variable;
+ private final Map<Expression, Expression> memos;
public Lambda(Variable variable, Expression expression) {
this.expression = expression;
this.variable = variable;
+ this.memos = new ConcurrentHashMap<>();
}
@Override
@@ -24,7 +28,10 @@ public Lambda(Variable variable, Expression expression) {
@Override
public Expression apply(Expression argument) throws SterlingException {
- return bind(expression, argument(variable, argument));
+ if (!memos.containsKey(argument)) {
+ memos.put(argument, bind(expression, argument(variable, argument)));
+ }
+ return memos.get(argument);
}
@Override
@@ -30,7 +30,7 @@ public void setUp() {
@Test
public void shouldShowLambdaAndArgument_whenToString() {
- assertThat(expression.toString(), equalTo("(BoundExpression LAMBDA ARGUMENT)"));
+ assertThat(expression.toString(), equalTo("(BoundExpression (UnboundState LAMBDA ARGUMENT))"));
}
@Test
@@ -0,0 +1,124 @@
+package sterling.math;
+
+import static java.lang.System.currentTimeMillis;
+import static java.lang.System.out;
+import static org.sterling.runtime.GlobalModule.global;
+import static org.sterling.runtime.expression.ExpressionFactory.constant;
+
+import java.text.NumberFormat;
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.sterling.SterlingException;
+import org.sterling.runtime.expression.Expression;
+import org.sterling.runtime.expression.ExpressionLoader;
+import org.sterling.runtime.expression.IntegerConstant;
+
+@Ignore
+public class FibonacciBenchmarkTest {
+
+ private IntegerConstant input;
+ private ExpressionLoader loader;
+ private String fibonacci;
+ private int executions;
+ private int iterations;
+
+ @Before
+ public void setUp() {
+ input = constant(20);
+ loader = global();
+ fibonacci = "sterling/math/fibonacci";
+ executions = 100;
+ iterations = 10;
+ }
+
+ @Test
+ public void testBenchmark() throws SterlingException {
+ javaBenchmark();
+ sterlingBenchmark();
+ }
+
+ private void javaBenchmark() {
+ List<Interval> intervals = new ArrayList<>(iterations);
+ int value = input.getValue();
+ out.println("Java Benchmark");
+ out.println("--------------");
+ for (int i = 0; i < iterations; i++) {
+ long startTime = currentTimeMillis();
+ for (int j = 0; j < executions; j++) {
+ fibonacci(value);
+ }
+ intervals.add(printIteration(i, startTime, currentTimeMillis()));
+ }
+ out.println("--------------");
+ printAverage(intervals);
+ out.println();
+ }
+
+ private void sterlingBenchmark() throws SterlingException {
+ Expression expression = loader.load(fibonacci);
+ List<Interval> intervals = new ArrayList<>(iterations);
+ out.println("Sterling Benchmark");
+ out.println("------------------");
+ for (int i = 0; i < iterations; i++) {
+ long startTime = currentTimeMillis();
+ for (int j = 0; j < executions; j++) {
+ expression.apply(input).evaluate();
+ }
+ intervals.add(printIteration(i, startTime, currentTimeMillis()));
+ }
+ out.println("------------------");
+ printAverage(intervals);
+ out.println();
+ }
+
+ private void printAverage(List<Interval> intervals) {
+ long sum = 0;
+ for (Interval interval : intervals) {
+ sum += interval.getDifference();
+ }
+ out.println("Average for " + iterations + " iterations X " + executions + " executions: " + format(sum / intervals.size()));
+ }
+
+ private Interval printIteration(int iteration, long startTime, long endTime) {
+ Interval interval = new Interval(startTime, endTime);
+ out.println("Iteration " + iteration + ": executions = " + executions + "; elapsed = " + interval);
+ return interval;
+ }
+
+ private static int fibonacci(int input) {
+ if (input == 0) {
+ return 0;
+ } else if (input == 1) {
+ return input;
+ } else {
+ return fibonacci(input - 1) + fibonacci(input - 2);
+ }
+ }
+
+ private static String format(long millis) {
+ return NumberFormat.getInstance().format(millis) + " milliseconds";
+ }
+
+ private static final class Interval {
+
+ private final long start;
+ private final long end;
+
+ public Interval(long start, long end) {
+ this.start = start;
+ this.end = end;
+ }
+
+ public long getDifference() {
+ return end - start;
+ }
+
+ @Override
+ public String toString() {
+ return format(getDifference());
+ }
+ }
+}

0 comments on commit 7d69d49

Please sign in to comment.