Skip to content

Commit

Permalink
Add permutedims(::Tensor, ...)
Browse files Browse the repository at this point in the history
Closes #121
  • Loading branch information
malmaud committed Jan 30, 2017
1 parent 94303d4 commit a6b6561
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ tf_promote(t, x) = Tensor(x)

convert_number(t, n) = n
convert_number(t, x::Number) = t(x)
convert_number(t, x::Union{AbstractArray, Tuple}) = map(t, x)

to_tensor(x::Union{Number, String, AbstractTensor}) = Tensor(x)
to_tensor(x::AbstractArray) = Tensor(x)
Expand Down
5 changes: 5 additions & 0 deletions src/ops/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,11 +550,16 @@ function Base.transpose(n::AbstractTensor, perm=nothing; name="transpose")
r = range(Tensor, 0, limit=rank(n))
perm = reverse(r, [true])
end
perm = convert_number(Int32, perm)
desc = NodeDescription("Transpose")
add_input(desc, Tensor(n))
add_input(desc, Tensor(perm))
end
Tensor(Operation(desc))
end

function Base.permutedims(n::AbstractTensor, perm; name="transpose")
transpose(n, perm.-1; name=name)
end

Base.ctranspose(n::AbstractTensor) = transpose(n)
1 change: 1 addition & 0 deletions test/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ one_tens = ones(Tensor, (5,5))
@test ones(10,5) == run(sess, tile(one_tens, [2; 1]))

@test ones(Float32, 4,3) == run(sess, transpose(ones(Tensor, (3, 4))))
@test ones(Float32, 4,3,2) == run(sess, permutedims(ones(Tensor, (4, 2, 3)), [1, 3, 2]))

@test hcat(ones(Float32, 5,5), zeros(Float32, 5)) == run(sess, pad(one_tens, [0 0; 0 1]))

Expand Down

0 comments on commit a6b6561

Please sign in to comment.