Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Mappings to common SymPy and SymEngine functions (#54)
* added some mappings for SymPy and SymEngine functions * ensure symbolic functions work on Chain, MultiVector and Simplex objects * added numeric conversion functions for SymPy and SymEngine * avoid Union types and type piracy in symbolic function definitions * remove extra line
- Loading branch information
1 parent
0dcd607
commit 3f38acf
Showing
3 changed files
with
84 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
|
||
""" | ||
Smoke tests for common SymPy and SymEngine functions. | ||
""" | ||
module SymEngineTests | ||
using Grassmann | ||
using Test | ||
|
||
@testset "Test SymEngine" begin | ||
using SymEngine | ||
@basis S"+++" | ||
x,y,z = symbols("x y z") | ||
|
||
# expanding the symbolic coefficient of a `Simplex` | ||
simp = (x+1)^2*v1 | ||
@show expand(simp) | ||
|
||
# expansion/substitution on each symbolic coefficient of a `MultiVector` | ||
mv = (x+y)^3 * v12 + (y+z) * v123 | ||
@show expand(mv) | ||
@show numeric_mv = N(subs(mv, Dict(x=>2, y=>2, z=>2))) | ||
@show map(typeof, numeric_mv.v) | ||
|
||
# expanding each symbolic coefficient of a `Chain` | ||
@show expand((x+1)*(x+2)*(v1+v2)) | ||
end | ||
end | ||
|
||
module SymPyTests | ||
using Grassmann | ||
using Test | ||
|
||
@testset "Test SymPy" begin | ||
using SymPy | ||
@basis S"+++" | ||
x,y,z = symbols("x,y,z") | ||
|
||
@show expand((x+y+z)^3*(v1+v12+v123)) | ||
|
||
expanded = expand((x+1)*(x+y)^5) | ||
@show unwieldly_multivector = z*v1 + expanded*v12 | ||
@show clean_multivector = factor(unwieldly_multivector) | ||
|
||
mv = (x+y)^3 * v12 + (y+z) * v123 | ||
|
||
@show numeric_mv = N(subs(mv, Dict(x=>2, y=>2, z=>2))) | ||
@show map(typeof, numeric_mv.v) | ||
|
||
|
||
end | ||
end |