Skip to content

Commit

Permalink
add sparse bind
Browse files Browse the repository at this point in the history
  • Loading branch information
kailaix committed Oct 7, 2020
1 parent 2350c96 commit d9c9159
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/TagBot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ jobs:
- uses: julia-actions/setup-julia@latest
with:
version: 1.4
- name: Install ADCME
run: julia --color=yes -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.build("ADCME"); using ADCME'
- name: Install-Dependencies
run: julia --project=docs/ --color=yes -e 'using Pkg; Pkg.add("Documenter"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.build("ADCME"); using ADCME; install_adept(); ADCME.precompile()'
run: julia --project=docs/ --color=yes -e 'using Pkg; Pkg.add("Documenter"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.build("ADCME"); using ADCME'
- name: Build and deploy
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
Expand Down
7 changes: 6 additions & 1 deletion src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -736,4 +736,9 @@ function trisolve(a::Union{PyObject, Array{Float64,1}},b::Union{PyObject, Array{
a,b,c,d = convert_to_tensor(Any[a,b,c,d], [Float64,Float64,Float64,Float64])
out = tri_solve_(a,b,c,d)
set_shape(out, (n,))
end
end


function Base.:bind(op::SparseTensor, ops...)
bind(op.o.values, ops...)
end
21 changes: 21 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ reset_default_graph(); sess = Session()
@test run(sess, b_)2.0


a = Variable(ones(10))
b = Variable(ones(10))
for i = 1:10
control_dependencies(a) do
a = scatter_add(a, i, b[i])
end
end

a_ = spdiag(ones(10))
b_ = Variable(0.0)
op = assign(b_, 2.0)
a_ = a_*2
a_ = bind(a_, op)


init(sess)
run(sess, a_)
@test run(sess, a)ones(10)*2
@test run(sess, b_)2.0


end

@testset "while loop" begin
Expand Down

0 comments on commit d9c9159

Please sign in to comment.