Skip to content

Commit

Permalink
extended SparseArray multiplication syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravala committed Mar 25, 2020
1 parent b8219e1 commit e242eba
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -5,7 +5,7 @@ os:
- osx
julia:
- 1.0
- 1.3
- 1.4
- nightly
jobs:
allow_failures:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Grassmann"
uuid = "4df31cd9-4c27-5bea-88d0-e6a7146666d8"
authors = ["Michael Reed"]
version = "0.5.5"
version = "0.5.6"

[deps]
AbstractTensors = "a8e43f4a-99b7-5565-8bf1-0165161caaea"
Expand Down
2 changes: 1 addition & 1 deletion appveyor.yml
@@ -1,7 +1,7 @@
environment:
matrix:
- julia_version: 1
- julia_version: 1.3
- julia_version: 1.4
- julia_version: nightly

platform:
Expand Down
10 changes: 1 addition & 9 deletions src/Grassmann.jl
Expand Up @@ -354,17 +354,9 @@ function __init__()
P = ChainBundle([Chain{V,1,Float64}(vcat(1,p[:,k])) for k 1:size(p,2)])
E = ChainBundle([Chain{P(2:s...),1,Int}(Int.(e[1:s-1,k])) for k 1:size(e,2)])
T = ChainBundle([Chain{P,1,Int}(Int.(t[1:s,k])) for k 1:size(t,2)])
matlab(p,bundle(P)); matlab(e,bundle(E)); matlab(t,bundle(T))
#matlab(p,bundle(P)); matlab(e,bundle(E)); matlab(t,bundle(T))
return (P,E,T)
end
export pdegrad
pdegrad(t::ChainBundle,Φ) = pdegrad(points(t),t,Φ)
function pdegrad(p,t,Φ)
P,T = matlab(p),matlab(t)
u,v = MATLAB.mxcall.(:pdeprtni,1,Ref(P),Ref(T),MATLAB.mxcall(:pdegrad,2,P,T,Φ))
V = DirectSum.parent(p)(2,3)
[Chain{V,1}(SVector(u[k],v[k])) for k 1:length(p)]
end
end
end

Expand Down
15 changes: 15 additions & 0 deletions src/composite.jl
Expand Up @@ -381,3 +381,18 @@ Base.isfinite(a::MultiVector) = prod(isfinite.(value(a)))
Base.rationalize(t::Type,a::Chain{V,G,T};tol::Real=eps(T)) where {V,G,T} = Chain{V,G}(rationalize.(t,value(a),tol))
Base.rationalize(t::Type,a::MultiVector{V,T};tol::Real=eps(T)) where {V,T} = MultiVector{V}(rationalize.(t,value(a),tol))
Base.rationalize(t::T;kvs...) where T<:TensorAlgebra = rationalize(Int,t;kvs...)

*(A::SparseMatrixCSC{TA,S}, x::StridedVector{Chain{V,G,𝕂,X}}) where {TA,S,V,G,𝕂,X} =
(T = promote_type(TA, Chain{V,G,𝕂,X}); SparseArrays.mul!(similar(x, T, A.m), A, x, 1, 0))
*(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Chain{V,G,𝕂,X}}) where {TA,S,V,G,𝕂,X} =
(T = promote_type(TA, Chain{V,G,𝕂,X}); mul!(similar(B, T, (A.m, size(B, 2))), A, B, 1, 0))
*(adjA::LinearAlgebra.Adjoint{<:Any,<:SparseMatrixCSC{TA,S}}, x::StridedVector{Chain{V,G,𝕂,X}}) where {TA,S,V,G,𝕂,X} =
(T = promote_type(TA, Chain{V,G,𝕂,X}); mul!(similar(x, T, size(adjA, 1)), adjA, x, 2, 0))
*(transA::LinearAlgebra.Transpose{<:Any,<:SparseMatrixCSC{TA,S}}, x::StridedVector{Chain{V,G,𝕂,X}}) where {TA,S,V,G,𝕂,X} =
(T = promote_type(TA, Chain{V,G,𝕂,X}); mul!(similar(x, T, size(transA, 1)), transA, x, 1, 0))
if VERSION >= v"1.4"
*(adjA::LinearAlgebra.Adjoint{<:Any,<:SparseMatrixCSC{TA,S}}, B::SparseArrays.AdjOrTransStridedOrTriangularMatrix{Chain{V,G,𝕂,X}}) where {TA,S,V,G,𝕂,X} =
(T = promote_type(TA, Chain{V,G,𝕂,X}); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, 1, 0))
*(transA::LinearAlgebra.Transpose{<:Any,<:SparseMatrixCSC{TA,S}}, B::SparseArrays.AdjOrTransStridedOrTriangularMatrix{Chain{V,G,𝕂,X}}) where {TA,S,V,G,𝕂,X} =
(T = promote_type(TA, Chain{V,G,𝕂,X}); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, 1, 0))
end
5 changes: 5 additions & 0 deletions src/multivectors.jl
Expand Up @@ -41,6 +41,11 @@ Chain(v::Chain{V,G,𝕂}) where {V,G,𝕂} = Chain{V,G}(SVector{binomial(ndims(V
Chain{𝕂}(v::SubManifold{V,G}) where {V,G,𝕂} = Chain(one(𝕂),v)
Chain{𝕂}(v::Simplex{V,G,B}) where {V,G,B,𝕂} = Chain{𝕂}(v.v,basis(v))
Chain{𝕂}(v::Chain{V,G}) where {V,G,𝕂} = Chain{V,G}(SVector{binomial(ndims(V),G),𝕂}(v.v))
Chain{V,G,T,X}(x::Simplex{V,0}) where {V,G,T,X} = Chain{V,G}(zeros(mvec(ndims(V),G,T)))
function Chain{V,0,T,X}(x::Simplex{V,0,v}) where {V,T,X,v}
N = ndims(V)
Chain{V,0}(setblade!(zeros(mvec(N,0,T)),value(x),bits(v),Val{N}()))
end

export Chain
getindex(m::Chain,i::Int) = m.v[i]
Expand Down

0 comments on commit e242eba

Please sign in to comment.