Skip to content

Commit

Permalink
Mappings to common SymPy and SymEngine functions (#54)
Browse files Browse the repository at this point in the history
* added some mappings for SymPy and SymEngine functions

* ensure symbolic functions work on Chain, MultiVector and Simplex objects

* added numeric conversion functions for SymPy and SymEngine

* avoid Union types and type piracy in symbolic function definitions

* remove extra line
  • Loading branch information
micahscopes committed Apr 24, 2020
1 parent 0dcd607 commit 3f38acf
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/Grassmann.jl
Expand Up @@ -230,6 +230,22 @@ function generate_algebra(m,t,d=nothing,c=nothing)
generate_inverses(m,t)
!isnothing(d) && generate_derivation(m,t,d,c)
end
function generate_symbolic_methods(mod, symtype, methods_noargs, methods_args)
for method methods_noargs
@eval begin
local apply_symbolic(x) = map(v -> typeof(v) == $mod.$symtype ? $mod.$method(v) : v, x)
$mod.$method(x::T) where T<:TensorGraded = apply_symbolic(x)
$mod.$method(x::T) where T<:TensorMixed = apply_symbolic(x)
end
end
for method methods_args
@eval begin
local apply_symbolic(x, args...) = map(v -> typeof(v) == $mod.$symtype ? $mod.$method(v, args...) : v, x)
$mod.$method(x::T, args...) where T<:TensorGraded = apply_symbolic(x, args...)
$mod.$method(x::T, args...) where T<:TensorMixed = apply_symbolic(x, args...)
end
end
end

function __init__()
@require Reduce="93e0c654-6965-5f22-aba9-9c1ae6b3c259" begin
Expand All @@ -249,8 +265,19 @@ function __init__()
generate_derivation(:(Reduce.Algebra),T,:df,:RExpr)
end
end
@require SymPy="24249f21-da20-56a4-8eb1-6a02cf4ae2e6" generate_algebra(:SymPy,:Sym,:diff,:symbols)
@require SymEngine="123dc426-2d89-5057-bbad-38513e3affd8" generate_algebra(:SymEngine,:Basic,:diff,:symbols)
@require SymPy="24249f21-da20-56a4-8eb1-6a02cf4ae2e6" begin
generate_algebra(:SymPy,:Sym,:diff,:symbols)
generate_symbolic_methods(:SymPy,:Sym, (:expand,:factor,:together,:apart,:cancel), (:N,:subs))
for T ( Chain{V,G,SymPy.Sym} where {V,G},
MultiVector{V,SymPy.Sym} where V,
Simplex{V,G,SymPy.Sym} where {V,G} )
SymPy.collect(x::T, args...) = map(v -> typeof(v) == SymPy.Sym ? SymPy.collect(v, args...) : v, x)
end
end
@require SymEngine="123dc426-2d89-5057-bbad-38513e3affd8" begin
generate_algebra(:SymEngine,:Basic,:diff,:symbols)
generate_symbolic_methods(:SymEngine,:Basic, (:expand,:N), (:subs,:evalf))
end
@require AbstractAlgebra="c3fe647b-3220-5bb0-a1ea-a7954cac585d" generate_algebra(:AbstractAlgebra,:SetElem)
@require GaloisFields="8d0d7f98-d412-5cd4-8397-071c807280aa" generate_algebra(:GaloisFields,:AbstractGaloisField)
@require LightGraphs="093fc24a-ae57-5d10-9952-331d41423f4d" begin
Expand Down
4 changes: 4 additions & 0 deletions src/composite.jl
Expand Up @@ -447,3 +447,7 @@ end
@inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...)))
end
end

Base.map(fn, x::MultiVector{V}) where V = MultiVector{V}(map(fn, value(x)))
Base.map(fn, x::Chain{V,G}) where {V,G} = Chain{V,G}(map(fn,value(x)))
Base.map(fn, x::Simplex{V,G,B}) where {V,G,B} = fn(value(x))*B
51 changes: 51 additions & 0 deletions test/symbolictests.jl
@@ -0,0 +1,51 @@

"""
Smoke tests for common SymPy and SymEngine functions.
"""
module SymEngineTests
using Grassmann
using Test

@testset "Test SymEngine" begin
using SymEngine
@basis S"+++"
x,y,z = symbols("x y z")

# expanding the symbolic coefficient of a `Simplex`
simp = (x+1)^2*v1
@show expand(simp)

# expansion/substitution on each symbolic coefficient of a `MultiVector`
mv = (x+y)^3 * v12 + (y+z) * v123
@show expand(mv)
@show numeric_mv = N(subs(mv, Dict(x=>2, y=>2, z=>2)))
@show map(typeof, numeric_mv.v)

# expanding each symbolic coefficient of a `Chain`
@show expand((x+1)*(x+2)*(v1+v2))
end
end

module SymPyTests
using Grassmann
using Test

@testset "Test SymPy" begin
using SymPy
@basis S"+++"
x,y,z = symbols("x,y,z")

@show expand((x+y+z)^3*(v1+v12+v123))

expanded = expand((x+1)*(x+y)^5)
@show unwieldly_multivector = z*v1 + expanded*v12
@show clean_multivector = factor(unwieldly_multivector)

mv = (x+y)^3 * v12 + (y+z) * v123

@show numeric_mv = N(subs(mv, Dict(x=>2, y=>2, z=>2)))
@show map(typeof, numeric_mv.v)


end
end

0 comments on commit 3f38acf

Please sign in to comment.