Skip to content

Commit

Permalink
cleaned up the polymorphism a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
horeilly1101 committed Jun 16, 2019
1 parent 5be6176 commit b1f0694
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 78 deletions.
8 changes: 4 additions & 4 deletions src/main/java/com/deriv/expression/Constant.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ public class Constant implements Expression {
/**
* Singleton instance of constant 0.
*/
private static final Expression ADD_ID = constant(0);
private static final Constant ADD_ID = new Constant(0);

/**
* Singleton instance of constant 1.
*/
private static final Expression MULT_ID = constant(1);
private static final Constant MULT_ID = new Constant(1);

/**
* Singleton instance of constant e.
*/
private static final Expression E = new Variable("e");
private static final Variable E = new Variable("e");

/**
* Singleton instance of constant pi.
*/
private static final Expression PI = new Variable("蟺");
private static final Variable PI = new Variable("蟺");

/**
* The integer value of a constant.
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/deriv/expression/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private Tensor(List<Expression> lines) {
* @param lines input
* @return boolean
*/
private static boolean isValid(List<Expression> lines) {
private static boolean isValid(List<? extends Expression> lines) {
// it doesn't make sense to have an empty tensor
if (lines.size() < 1)
return false;
Expand Down Expand Up @@ -69,7 +69,7 @@ public static Expression of(Expression... lines) {
* @param lines input
* @return tensor
*/
public static Expression of(List<Expression> lines) {
public static Expression of(List<? extends Expression> lines) {
if(!isValid(lines))
throw new RuntimeException("Each Expression must have the same depth!");

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/deriv/expression/Variable.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class Variable implements Expression {
/**
* Singleton instance of variable x.
*/
private static final Expression X = new Variable("x");
private static final Variable X = new Variable("x");

/**
* String that represents the variable. (e.g. x)
Expand Down Expand Up @@ -49,7 +49,7 @@ public static Expression var(String var) {
* Static constructor for a variable with name "x".
* @return variable
*/
public static Expression x() {
public static Variable x() {
return X;
}

Expand Down
14 changes: 7 additions & 7 deletions src/test/java/com/deriv/expression/AddTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,39 +96,39 @@ void divAddTest() {
void evaluateTest() {
// x + x + 2
Expression ex = add(x(), x(), constant(2));
Optional<Expression> eval = ex.evaluate(x().asVariable(), constant(2));
Optional<Expression> eval = ex.evaluate(x(), constant(2));
assertTrue(eval.isPresent());
assertEquals(constant(6), eval.get());

// x + a + 3, where a is a constant
Expression ex2 = add(x(), constant("a"), constant(3));
Optional<Expression> eval2 = ex2.evaluate(x().asVariable(), constant(3));
Optional<Expression> eval2 = ex2.evaluate(x(), constant(3));
assertTrue(eval2.isPresent());
assertEquals(add(constant("a"), constant(6)), eval2.get());

// x + 1/x
// can't divide by 0!
Expression ex3 = add(x(), poly(x(), -1));
assertEquals(Optional.empty(), ex3.evaluate(x().asVariable(), addID()));
assertEquals(Optional.empty(), ex3.evaluate(x(), addID()));
}

@Test
void derivativeTest() {
// x + x + 2
Expression ex = add(x(), x(), constant(2));
assertEquals(constant(2), ex.differentiate(x().asVariable()).get());
assertEquals(constant(2), ex.differentiate(x()).get());

// a * x + 3, where a is a constant
Expression ex2 = add(mult(x(), constant("a")), constant(3));
assertEquals(constant("a"), ex2.differentiate(x().asVariable()).get());
assertEquals(constant("a"), ex2.differentiate(x()).get());

// x + ln(x)
Expression ex3 = add(x(), ln(x()));
assertEquals(add(multID(), poly(x(), -1)), ex3.differentiate(x().asVariable()).get());
assertEquals(add(multID(), poly(x(), -1)), ex3.differentiate(x()).get());

// x + sin(x)
Expression ex4 = add(x(), sin(x()));
assertEquals(add(multID(), cos(x())), ex4.differentiate(x().asVariable()).get());
assertEquals(add(multID(), cos(x())), ex4.differentiate(x()).get());
}

@Test
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/com/deriv/expression/ConstantTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ void derivativeTest() {
// of completeness

Expression c = constant("c");
assertEquals(addID(), c.differentiate(x().asVariable()).get());
assertEquals(addID(), c.differentiate(x()).get());

Expression one = multID();
assertEquals(addID(), one.differentiate(x().asVariable()).get());
assertEquals(addID(), one.differentiate(x()).get());

Expression e = e();
assertEquals(addID(), e.differentiate(x().asVariable()).get());
assertEquals(addID(), e.differentiate(x()).get());

Expression pi = pi();
assertEquals(addID(), pi.differentiate(x().asVariable()).get());
assertEquals(addID(), pi.differentiate(x()).get());
}

@Test
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/com/deriv/expression/DivTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ void addTest() {
void evaluateTest() {
// x / 3
Expression ex = div(x(), constant(3));
assertEquals(constant(27), ex.evaluate(x().asVariable(), constant(81)).get());
assertEquals(constant(27), ex.evaluate(x(), constant(81)).get());

// (x + 1) / 3
Expression ex2 = div(add(x(), multID()), constant(3));
assertEquals(constant(27), ex2.evaluate(x().asVariable(), constant(80)).get());
assertEquals(constant(27), ex2.evaluate(x(), constant(80)).get());

// ln(x + 1) / x
Expression ex3 = div(ln(add(x(), multID())), x());
assertEquals(ln(constant(2)), ex3.evaluate(x().asVariable(), multID()).get());
assertEquals(ln(constant(2)), ex3.evaluate(x(), multID()).get());
}

@Test
Expand Down
12 changes: 6 additions & 6 deletions src/test/java/com/deriv/expression/LogTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,33 @@ void evaluateTest() {
Expression lg = log(constant(2), x());

// evaluate at 2
Optional<Expression> eval = lg.evaluate(x().asVariable(), constant(2));
Optional<Expression> eval = lg.evaluate(x(), constant(2));
assertTrue(eval.isPresent());
assertEquals(multID(), eval.get());

// evaluate at 5
Optional<Expression> eval2 = lg.evaluate(x().asVariable(), constant(5));
Optional<Expression> eval2 = lg.evaluate(x(), constant(5));
assertTrue(eval2.isPresent());
assertEquals(log(constant(2), constant(5)), eval2.get());

// evaluate at -1
Optional<Expression> eval3 = lg.evaluate(x().asVariable(), negate(multID()));
Optional<Expression> eval3 = lg.evaluate(x(), negate(multID()));
assertFalse(eval3.isPresent());
}

@Test
void differentiateTest() {
// ln(x)
Expression ln = ln(x());
assertEquals(div(multID(), x()), ln.differentiate(x().asVariable()).get());
assertEquals(div(multID(), x()), ln.differentiate(x()).get());

// log(e, x)
Expression ln2 = log(e(), x());
assertEquals(ln.differentiate(x().asVariable()).get(), ln2.differentiate(x().asVariable()).get());
assertEquals(ln.differentiate(x()).get(), ln2.differentiate(x()).get());

// log(2, x)
Expression lg = log(constant(2), x());
assertEquals(div(multID(), mult(ln(constant(2)), x())), lg.differentiate(x().asVariable()).get());
assertEquals(div(multID(), mult(ln(constant(2)), x())), lg.differentiate(x()).get());
}

@Test
Expand Down
14 changes: 7 additions & 7 deletions src/test/java/com/deriv/expression/MultTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,39 +95,39 @@ void distributeTest() {
void evaluateTest() {
// x * x * 2
Expression ex = mult(x(), x(), constant(2));
Optional<Expression> eval = ex.evaluate(x().asVariable(), constant(2));
Optional<Expression> eval = ex.evaluate(x(), constant(2));
assertTrue(eval.isPresent());
assertEquals(constant(8), eval.get());

// x * a * 3, where a is a constant
Expression ex2 = mult(x(), constant("a"), constant(3));
Optional<Expression> eval2 = ex2.evaluate(x().asVariable(), constant(3));
Optional<Expression> eval2 = ex2.evaluate(x(), constant(3));
assertTrue(eval2.isPresent());
assertEquals(mult(constant("a"), constant(9)), eval2.get());

// 3 / x
Expression ex3 = div(constant(3), x());
Optional<Expression> eval3 = ex3.evaluate(x().asVariable(), addID());
Optional<Expression> eval3 = ex3.evaluate(x(), addID());
assertFalse(eval3.isPresent());
}

@Test
void derivativeTest() {
// x * x * 2
Expression ex = mult(x(), x(), constant(2));
assertEquals(mult(constant(4), x()), ex.differentiate(x().asVariable()).get());
assertEquals(mult(constant(4), x()), ex.differentiate(x()).get());

// x * a * 3, where a is a constant
Expression ex2 = mult(x(), constant("a"), constant(3));
assertEquals(mult(constant("a"), constant(3)), ex2.differentiate(x().asVariable()).get());
assertEquals(mult(constant("a"), constant(3)), ex2.differentiate(x()).get());

// x * ln(x)
Expression ex3 = mult(x(), ln(x()));
assertEquals(add(multID(), ln(x())), ex3.differentiate(x().asVariable()).get());
assertEquals(add(multID(), ln(x())), ex3.differentiate(x()).get());

// x * sin(x)
Expression ex4 = mult(x(), sin(x()));
assertEquals(add(sin(x()), mult(x(), cos(x()))), ex4.differentiate(x().asVariable()).get());
assertEquals(add(sin(x()), mult(x(), cos(x()))), ex4.differentiate(x()).get());
}

@Test
Expand Down
24 changes: 12 additions & 12 deletions src/test/java/com/deriv/expression/ParallelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ void addEvaluateTest() {
runComparison("addEvaluateTest",
() -> ExpressionUtils.linearityHelper(
result.asAdd().getTerms(),
x -> x.evaluate(x().asVariable(), multID()))
x -> x.evaluate(x(), multID()))
.map(Add::add), // add the results

() -> sequentialLinearityHelper(
result.asAdd().getTerms(),
x -> x.evaluate(x().asVariable(), multID()))
x -> x.evaluate(x(), multID()))
.map(Add::add)); // add the results
}

Expand All @@ -152,11 +152,11 @@ void addDerivativeTest() {
runComparison("addDerivativeTest",
() -> ExpressionUtils.linearityHelper(
result.asAdd().getTerms(),
x -> x.differentiate(x().asVariable())).map(Add::add),
x -> x.differentiate(x())).map(Add::add),

() -> sequentialLinearityHelper(
result.asAdd().getTerms(),
x -> x.differentiate(x().asVariable())).map(Add::add));
x -> x.differentiate(x())).map(Add::add));
}

// @Test
Expand All @@ -166,11 +166,11 @@ void addDerivativeTest2() {
runComparison("addDerivativeTest2",
() -> ExpressionUtils.linearityHelper(
result.asAdd().getTerms(),
x -> x.differentiate(x().asVariable())).map(Add::add),
x -> x.differentiate(x())).map(Add::add),

() -> sequentialLinearityHelper(
result.asAdd().getTerms(),
x -> x.differentiate(x().asVariable())).map(Add::add));
x -> x.differentiate(x())).map(Add::add));
}

// @Test
Expand All @@ -180,30 +180,30 @@ void multEvaluateTest() {
runComparison("multEvaluateTest",
() -> ExpressionUtils.linearityHelper(
result.asMult().getFactors(),
x -> x.evaluate(x().asVariable(), multID())).map(Mult::mult),
x -> x.evaluate(x(), multID())).map(Mult::mult),

() -> sequentialLinearityHelper(
result.asMult().getFactors(),
x -> x.evaluate(x().asVariable(), multID())).map(Mult::mult));
x -> x.evaluate(x(), multID())).map(Mult::mult));
}

// @Test
void multDerivativeTest() {
Expression result = mult(sinList(1_000));

runComparison("multDerivativeTest",
() -> result.differentiate(x().asVariable()), // parallel
() -> result.differentiate(x()), // parallel

() -> sequentialMultDerivative(result.asMult().getFactors(), x().asVariable()));
() -> sequentialMultDerivative(result.asMult().getFactors(), x()));
}

// @Test
void multDerivativeTest2() {
Expression result = mult(expoList(500));

runComparison("multDerivativeTest2",
() -> result.differentiate(x().asVariable()), // parallel
() -> result.differentiate(x()), // parallel

() -> sequentialMultDerivative(result.asMult().getFactors(), x().asVariable()));
() -> sequentialMultDerivative(result.asMult().getFactors(), x()));
}
}
18 changes: 9 additions & 9 deletions src/test/java/com/deriv/expression/PowerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,45 +76,45 @@ void simplifyTest() {
void evaluateTest() {
// 5 ^ x
Expression ex = exponential(5, x());
Optional<Expression> eval = ex.evaluate(x().asVariable(), constant(3));
Optional<Expression> eval = ex.evaluate(x(), constant(3));
assertTrue(eval.isPresent());
assertEquals(constant(125), eval.get());

// x ^ 4
Expression ex2 = poly(x(), 4);
Optional<Expression> eval2 = ex2.evaluate(x().asVariable(), constant(2));
Optional<Expression> eval2 = ex2.evaluate(x(), constant(2));
assertTrue(eval2.isPresent());
assertEquals(constant(16), eval2.get());
assertEquals(addID(), ex2.evaluate(x().asVariable(), addID()).get());
assertEquals(addID(), ex2.evaluate(x(), addID()).get());

// x ^ x
Expression ex3 = power(x(), x());
Optional<Expression> eval3 = ex3.evaluate(x().asVariable(), constant(3));
Optional<Expression> eval3 = ex3.evaluate(x(), constant(3));
assertTrue(eval3.isPresent());
assertEquals(constant(27), eval3.get());

// 1 / 0
Expression ex4 = poly(x(), -1);
assertEquals(Optional.empty(), ex4.evaluate(x().asVariable(), addID()));
assertEquals(Optional.empty(), ex4.evaluate(x(), addID()));

// x ^ -2
Expression ex5 = poly(x(), -2);
assertFalse(ex5.evaluate(x().asVariable(), addID()).isPresent());
assertFalse(ex5.evaluate(x(), addID()).isPresent());
}

@Test
void differentiateTest() {
// 5 ^ x
Expression ex = exponential(3, x());
assertEquals(mult(power(constant(3), x()), ln(constant(3))), ex.differentiate(x().asVariable()).get());
assertEquals(mult(power(constant(3), x()), ln(constant(3))), ex.differentiate(x()).get());

// x ^ 4
Expression ex2 = poly(x(), 4);
assertEquals(mult(constant(4), poly(x(), 3)), ex2.differentiate(x().asVariable()).get());
assertEquals(mult(constant(4), poly(x(), 3)), ex2.differentiate(x()).get());

// x ^ x
Expression ex3 = power(x(), x());
assertEquals(mult(add(multID(), ln(x())), power(x(), x())), ex3.differentiate(x().asVariable()).get());
assertEquals(mult(add(multID(), ln(x())), power(x(), x())), ex3.differentiate(x()).get());
}

@Test
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/com/deriv/expression/TensorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ void ofTest() {
Tensor.of(add(x(), constant(2), poly(x(), 2)))
);
System.out.println(ten);
System.out.println(ten.evaluate(x().asVariable(), constant(3)).get());
System.out.println(ten.differentiate(x().asVariable()).get());
System.out.println(ten.evaluate(x(), constant(3)).get());
System.out.println(ten.differentiate(x()).get());
}

@Test
Expand Down

0 comments on commit b1f0694

Please sign in to comment.