Skip to content

Commit

Permalink
Add TensorCast support (#88)
Browse files Browse the repository at this point in the history
* move to @cast

* partial transpose in tensor cast

* use not-@cast as default

* restore tests, fix checks in ptrace
  • Loading branch information
lpawela committed Jan 24, 2021
1 parent 1762f4e commit a0195df
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 34 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[compat]
Expand Down
1 change: 1 addition & 0 deletions src/QuantumInformation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module QuantumInformation
using LinearAlgebra
using DocStringExtensions
using TensorOperations
using TensorCast
using Convex, SCS
using Random: AbstractRNG, GLOBAL_RNG

Expand Down
33 changes: 9 additions & 24 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ function ket(::Type{T}, val::Int, dim::Int) where T<:AbstractVector{<:Number}
ψ
end

function ket(::Type{T}, val::Int, dim::Int) where T<:Number
@warn "This method is deprecated and will be removed. Use calls like `ket(Matrix{ComplexF64}, 1, 2)`."
ket(Vector{T}, val, dim)
end
ket(::Type{T}, val::Int, dim::Int) where T<:Number = ket(Vector{T}, val, dim)

"""
$(SIGNATURES)
Expand All @@ -22,14 +19,11 @@ $(SIGNATURES)
Return complex column vector \$|val\\rangle\$ of unit norm describing quantum state.
"""
ket(val::Int, dim::Int) = ket(Vector{ComplexF64}, val, dim)
ket(val::Int, dim::Int) = ket(ComplexF64, val, dim)

function bra(::Type{T}, val::Int, dim::Int) where T<:Number
@warn "This method is deprecated and will be removed. Use calls like `bra(Matrix{ComplexF64}, 1, 2)`."
bra(Vector{T}, val, dim)
end

bra(::Type{T}, val::Int, dim::Int) where T<:AbstractVector{<:Number} = ket(T, val, dim)'
bra(::Type{T}, val::Int, dim::Int) where T<:Number = bra(Vector{T}, val, dim)

"""
$(SIGNATURES)
Expand All @@ -38,7 +32,7 @@ $(SIGNATURES)
Return Hermitian conjugate \$\\langle val| = |val\\rangle^\\dagger\$ of the ket with the same label.
"""
bra(val::Int, dim::Int) = bra(Vector{ComplexF64}, val, dim)
bra(val::Int, dim::Int) = bra(ComplexF64, val, dim)

function ketbra(::Type{T}, valk::Int, valb::Int, idim::Int, odim::Int) where T<:AbstractMatrix{<:Number}
idim > 0 && odim > 0 ? () : throw(ArgumentError("Matrix dimension has to be nonnegative"))
Expand All @@ -50,11 +44,7 @@ function ketbra(::Type{T}, valk::Int, valb::Int, idim::Int, odim::Int) where T<:
end

ketbra(::Type{T}, valk::Int, valb::Int, dim::Int) where T<:AbstractMatrix{<:Number} = ketbra(T, valk, valb, dim, dim)

function ketbra(::Type{T}, valk::Int, valb::Int, dim::Int) where T<:Number
@warn "This method is deprecated and will be removed. Use calls like `ketbra(Matrix{ComplexF64}, 1, 1, 2)`."
ketbra(Matrix{T}, valk, valb, dim)
end
ketbra(::Type{T}, valk::Int, valb::Int, dim::Int) where T<:Number = ketbra(Matrix{T}, valk, valb, dim)

"""
$(SIGNATURES)
Expand All @@ -64,7 +54,7 @@ $(SIGNATURES)
# Return outer product \$|valk\\rangle\\langle vakb|\$ of states \$|valk\\rangle\$ and \$|valb\\rangle\$.
"""
ketbra(valk::Int, valb::Int, dim::Int) = ketbra(Matrix{ComplexF64}, valk, valb, dim)
ketbra(valk::Int, valb::Int, dim::Int) = ketbra(ComplexF64, valk, valb, dim)


"""
Expand Down Expand Up @@ -94,14 +84,9 @@ $(SIGNATURES)
Returns `vec(ρ.T)`. Reshaping maps
matrix `ρ` into a vector row by row.
"""
res::AbstractMatrix{<:Number}) = vec(transpose(ρ))
res::AbstractMatrix{<:Number}) = @cast x[(j, i)] := ρ[i, j]

function unres::AbstractVector{<:Number}, cols::Int)
dim = length(ϕ)
rows = div(dim, cols)
rows*cols == dim ? () : throw(ArgumentError("Wrong number of columns"))
transpose(reshape(ϕ, cols, rows))
end
unres::AbstractVector{<:Number}, cols::Int) = @cast x[i, j] := ϕ[(j, i)] j:cols

"""
$(SIGNATURES)
Expand All @@ -122,7 +107,7 @@ $(SIGNATURES)
Return maximally mixed state \$\\frac{1}{d}\\sum_{i=0}^{d-1}|i\\rangle\\langle i |\$ of length \$d\$.
"""
max_mixed(d::Int) = Matrix(I/d, d, d) # eye(ComplexF64, d, d)/d
max_mixed(d::Int) = I(d)/d

"""
$(SIGNATURES)
Expand Down
4 changes: 2 additions & 2 deletions src/ptrace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ $(SIGNATURES)
"""
function ptrace::AbstractVector{<:Number}, idims::Vector{Int}, sys::Int)
# TODO : Allow mutlipartite systems
length(idims) == 2 ? () : throw(ArgumentError("idims has to be of length 2"))
_, cols = idims
m = unres(ψ, cols)
length(idims) == 2 ? () : throw(ArgumentError("idims has to be of length 2"))
if sys == 1
return transpose(m) * conj.(m)
elseif sys == 2
return m * transpose(conj.(m))
return m * m'
else
throw(ArgumentError("sys must be 1 or 2"))
end
Expand Down
31 changes: 29 additions & 2 deletions src/ptranspose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ function ptranspose(ρ::AbstractMatrix{<:Number}, idims::Vector{Int}, isystems::
tensor = reshape(ρ, [dims; dims]...)
perm = collect(1:(2offset))
for s in systems
idx1 = findall(x->x==s, perm)[1]
idx2 = findall(x->x==(s + offset), perm)[1]
idx1 = findfirst(x->x==s, perm)
idx2 = findfirst(x->x==(s + offset), perm)
perm[idx1], perm[idx2] = perm[idx2], perm[idx1]
end
tensor = permutedims(tensor, invperm(perm))
Expand All @@ -40,3 +40,30 @@ $(SIGNATURES)
- `sys`: transposed subsystem.
"""
ptranspose::AbstractMatrix{<:Number}, idims::Vector{Int}, sys::Int) = ptranspose(ρ, idims, [sys])

function _ptranspose::AbstractMatrix{<:Number}, idims::Vector{Int}, isystems::Vector{Int})
ns = length(idims)

ex1 = Expr(:ref, :x)
ex2 = Expr(:ref, ρ)

I = Expr(:tuple, [gensym() for _=1:ns]...)
J = Expr(:tuple, [gensym() for _=1:ns]...)

K = copy(I)
L = copy(J)

r = Expr(:tuple)
for (k, (i, j)) in enumerate(zip(K.args, L.args))
push!(r.args, :($i:$(idims[k])), :($j:$(idims[k])))
end
for s in isystems
K.args[s], L.args[s] = L.args[s], K.args[s]
end
push!(ex1.args, I, J)
push!(ex2.args, L, K)

ex = Expr(:(:=), ex1, ex2)
ex, r
@eval @cast $ex $r
end
10 changes: 7 additions & 3 deletions test/ptranspose.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
@testset "Partial transpose" begin

@testset "Dense matrices" begin
ρ = ComplexF64[1 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16]
ρ = [1 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16]
trans1 = [1 2 9 10; 5 6 13 14; 3 4 11 12; 7 8 15 16]
trans2 = [1 5 3 7; 2 6 4 8; 9 13 11 15; 10 14 12 16]
@test norm(ptranspose(ρ, [2, 2], [1]) - trans1) 0. atol=1e-15
@test norm(ptranspose(ρ, [2, 2], [2]) - trans2) 0. atol=1e-15

res1 = ptranspose(ρ, [2, 2], [1])
res2 = ptranspose(ρ, [2, 2], [2])

@test norm(res1 - trans1) 0. atol=1e-15
@test norm(res2 - trans2) 0. atol=1e-15

@test_throws ArgumentError ptranspose(ones(2, 3), [2, 2], 1)
@test_throws ArgumentError ptranspose(ones(4, 4), [2, 3], 1)
Expand Down
17 changes: 14 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@ using LinearAlgebra
# using SparseArrays
using Test

my_tests = ["utils.jl", "base.jl", "ptrace.jl", "ptranspose.jl", "reshuffle.jl",
"channels.jl", "functionals.jl", "gates.jl", "matrixbases.jl",
"permute_systems.jl", "randomqobjects.jl", "convex.jl"]
my_tests = [
"utils.jl",
"base.jl",
"ptrace.jl",
"ptranspose.jl",
"reshuffle.jl",
"channels.jl",
"functionals.jl",
"gates.jl",
"matrixbases.jl",
"permute_systems.jl",
"randomqobjects.jl",
"convex.jl"
]
for my_test in my_tests
include(my_test)
end

0 comments on commit a0195df

Please sign in to comment.