# Partial Trace Methods

In [29]:
using LinearAlgebra
using Random
using QuantumOptics


In [144]:
function partial_trace(rho::Matrix{Float64}, dims::Tuple{Int, Int}, keep::Int)
    dim1, dim2 = dims
    if keep == 0
        y = reshape(sum(reshape(rho, dim1, dim2, dim1, dim2), dims=(2, 4)), dim1, dim1)
    elseif keep == 1
        y = reshape(sum(reshape(rho, dim1, dim2, dim1, dim2), dims=(1, 3)), dim2, dim2)
    else
        throw(ArgumentError("The 'keep' argument must be 0 or 1."))
    end
    return y / tr(y)
end

partial_trace (generic function with 1 method)

In [150]:
A = [1.0 1. 1.; 0. 0. 1.; 0. 0. 0.]
B = [0.5 10. 2.; 10. 0.5 0.; 0. 8. 0.]
display(A)
display(B)

partial_trace(kron(A, B), (3, 3), 0)

3×3 Matrix{Float64}:
 1.0  1.0  1.0
 0.0  0.0  1.0
 0.0  0.0  0.0

3×3 Matrix{Float64}:
  0.5  10.0  2.0
 10.0   0.5  0.0
  0.0   8.0  0.0

3×3 Matrix{Float64}:
  0.5  10.0  2.0
 10.0   0.5  0.0
  0.0   8.0  0.0

In [140]:
basis = QuantumOptics.FockBasis(2)
ρₐ = QuantumOptics.DenseOperator(basis, A)
ρᵦ = QuantumOptics.DenseOperator(basis, B)
display(ρₐ)
display(ρᵦ)

Operator(dim=3x3)
  basis: Fock(cutoff=2)
 1.0  1.0  1.0
 0.0  0.0  1.0
 0.0  0.0  0.0

Operator(dim=3x3)
  basis: Fock(cutoff=2)
 0.5  0.0  2.0
 0.0  0.5  0.0
 0.0  0.0  0.0

In [141]:
ρ = ρₐ ⊗ ρᵦ
display(ρ)
QuantumOptics.ptrace(ρ, 2)

Operator(dim=9x9)
  basis: [Fock(cutoff=2) ⊗ Fock(cutoff=2)]
 0.5  0.5  0.5  0.0  0.0  0.0  2.0  2.0  2.0
 0.0  0.0  0.5  0.0  0.0  0.0  0.0  0.0  2.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.5  0.5  0.5  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.5  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

Operator(dim=3x3)
  basis: Fock(cutoff=2)
 1.0  1.0  1.0
 0.0  0.0  1.0
 0.0  0.0  0.0

In [76]:
basis = QuantumOptics.FockBasis(20)
a = destroy(basis)
at = create(basis)
n = number(basis)

ω = 0.1
H = ω * n

Ψ₀ = fockstate(basis, 1)
ρ₀ = Ψ₀ ⊗ dagger(Ψ₀)
ρ₁ = Ψ₀ ⊗ dagger(Ψ₀)
ρₜ = thermalstate(H, 0.1)
display(ρₜ)

ρ = ρₜ ⊗ ρₜ

QuantumOptics.ptrace(ρ, 1)

Operator(dim=21x21)
  basis: Fock(cutoff=20)
 0.632121+0.0im       0.0+0.0im  …         0.0+0.0im        0.0+0.0im
      0.0-0.0im  0.232544+0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im  …         0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im  …         0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.0+0.0im        0.0+0.0im
      0.0-0.0im       0.0-0.0im            0.

Operator(dim=21x21)
  basis: Fock(cutoff=20)
 0.632121+0.0im       0.0+0.0im  …         0.0+0.0im        0.0+0.0im
      0.0+0.0im  0.232544+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im  …         0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im  …         0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.0+0.0im        0.0+0.0im
      0.0+0.0im       0.0+0.0im            0.

In [27]:
using Base.Cartesian

N1 = 4
N2 = 4
N3 = 4

x = rand(Float64, N1, N2, N3, N1, N2, N3)
xx = rand(Float64, N1, N2)
a, b = size(r)
rand(Float64, 3, 3, 3)

3×3×3 Array{Float64, 3}:
[:, :, 1] =
 0.356158  0.0833558  0.962663
 0.983849  0.918505   0.746801
 0.677115  0.479717   0.825777

[:, :, 2] =
 0.197625  0.205668  0.165248
 0.729728  0.125497  0.322512
 0.199194  0.182216  0.0243705

[:, :, 3] =
 0.254571  0.762466  0.181103
 0.959139  0.885975  0.360793
 0.527045  0.202909  0.403571

In [None]:
function ptrace_forloops(x)
    n1, n2 = size(x)
    y = zeros(Float64, n2, n3, n2, n3)
    for i5=1:n3
        for i4=1:n2
            for i3=1:n3
                for i2=1:n2
                    for i1=1:n1
                        y[i2,i3,i4,i5] += x[i1,i2,i3,i1,i4,i5]
                    end
                end
            end
        end
    end
    y
end

In [15]:
function ptrace_forloops(x)
    n1, n2, n3 = size(x)
    y = zeros(Float64, n2, n3, n2, n3)
    for i5=1:n3
        for i4=1:n2
            for i3=1:n3
                for i2=1:n2
                    for i1=1:n1
                        y[i2,i3,i4,i5] += x[i1,i2,i3,i1,i4,i5]
                    end
                end
            end
        end
    end
    y
end

function ptrace_slicing(x::Array{Float64, 6})
    n1, n2, n3 = size(x)
    y = zeros(Float64, n2, n3, n2, n3)
    for i1=1:n1
        y += x[i1,:,:,i1,:,:]
    end
    y
end

function ptrace_cartesian(x::Array{Float64, 6})
    n1, n2, n3 = size(x)
    y = zeros(Float64, 1, n2, n3, 1, n2, n3)
    ymax = CartesianIndex(size(y))
    for I in CartesianIndices(size(x))
        if I.I[1] != I.I[4]
            continue
        end
        y[min(ymax, I)] += x[I]
    end
    reshape(y, n2, n3, n2, n3)
end

function ptrace_cartesian2(x::Array{Float64, 6})
    n1, n2, n3 = size(x)
    y = zeros(Float64, 1, n2, n3, 1, n2, n3)
    for I in CartesianIndices(size(y))
        for k in CartesianIndices((n1, 1, 1))
            delta = CartesianIndex(k, k)
            y[I] += x[I+delta-1]
        end
    end
    reshape(y, n2, n3, n2, n3)
end

# Partial trace for dense operators.
function _strides(shape::Vector{Int})
    N = length(shape)
    S = zeros(Int, N)
    S[N] = 1
    for m=N-1:-1:1
        S[m] = S[m+1]*shape[m+1]
    end
    return S
end

@generated function _ptrace(a::Matrix{Float64},
                                  shape_l::Vector{Int}, shape_r::Vector{Int},
                                  indices::Vector{Int})
    return quote
        a_strides_l = _strides(shape_l)
        result_shape_l = deepcopy(shape_l)
        result_shape_l[indices] = 1
        result_strides_l = _strides(result_shape_l)
        a_strides_r = _strides(shape_r)
        result_shape_r = deepcopy(shape_r)
        result_shape_r[indices] = 1
        result_strides_r = _strides(result_shape_r)
        N_result_l = prod(result_shape_l)
        N_result_r = prod(result_shape_r)
        result = zeros(Float64, N_result_l, N_result_r)
        @nexprs 1 (d->(Jr_{3}=1;Ir_{3}=1))
        @nloops 3 ir (d->1:shape_r[d]) (d->(Ir_{d-1}=Ir_d; Jr_{d-1}=Jr_d)) (d->(Ir_d+=a_strides_r[d]; if !(d in indices) Jr_d+=result_strides_r[d] end)) begin
            @nexprs 1 (d->(Jl_{3}=1;Il_{3}=1))
            @nloops 3 il (k->1:shape_l[k]) (k->(Il_{k-1}=Il_k; Jl_{k-1}=Jl_k; if (k in indices && il_k!=ir_k) Il_k+=a_strides_l[k]; continue end)) (k->(Il_k+=a_strides_l[k]; if !(k in indices) Jl_k+=result_strides_l[k] end)) begin
                #println("Jl_0: ", Jl_0, "; Jr_0: ", Jr_0, "; Il_0: ", Il_0, "; Ir_0: ", Ir_0)
                result[Jl_0, Jr_0] += a[Il_0, Ir_0]
            end
        end
        return result
    end
end

function ptrace_nloop(x)
    n1, n2, n3 = size(x)
    n = n1*n2*n3
    x = reshape(x, n, n)
    y = _ptrace(x, [n3,n2,n1], [n3,n2,n1], [3])
    reshape(y, n2, n3, n2, n3)
end

ptrace_nloop (generic function with 1 method)

In [17]:
result = ptrace_forloops(xx)

LoadError: BoundsError: attempt to access Tuple{Int64, Int64} at index [3]

In [None]:
println(dist(result, ptrace_slicing(x)))
println(dist(result, ptrace_cartesian(x)))
println(dist(result, ptrace_cartesian2(x)))
println(dist(result, ptrace_nloop(x)))


println("Explicit loops")
@time ptrace_forloops(x)
@time ptrace_forloops(x)

println("Slicing")
@time ptrace_slicing(x)
@time ptrace_slicing(x)

println("Cartesian Index")
@time ptrace_cartesian(x)
@time ptrace_cartesian(x)

println("Cartesian Index 2")
@time ptrace_cartesian2(x)
@time ptrace_cartesian2(x)

println("nloop")
@time ptrace_nloop(x)
@time ptrace_nloop(x)

0.0
0.0


LoadError: MethodError: no method matching -(::CartesianIndex{6}, ::Int64)
The function `-` exists, but no method is defined for this combination of argument types.

[0mClosest candidates are:
[0m  -([91m::Complex{Bool}[39m, ::Real)
[0m[90m   @[39m [90mBase[39m [90m[4mcomplex.jl:329[24m[39m
[0m  -([91m::Base.CoreLogging.LogLevel[39m, ::Integer)
[0m[90m   @[39m [90mBase[39m [90mlogging\[39m[90m[4mlogging.jl:133[24m[39m
[0m  -([91m::Missing[39m, ::Number)
[0m[90m   @[39m [90mBase[39m [90m[4mmissing.jl:123[24m[39m
[0m  ...


In [19]:
a = [1.0 1.0]
abs(a)

LoadError: MethodError: no method matching abs(::Matrix{Float64})
The function `abs` exists, but no method is defined for this combination of argument types.

[0mClosest candidates are:
[0m  abs([91m::Bool[39m)
[0m[90m   @[39m [90mBase[39m [90m[4mbool.jl:153[24m[39m
[0m  abs([91m::Pkg.Resolve.FieldValue[39m)
[0m[90m   @[39m [36mPkg[39m [90mC:\Users\fedes\.julia\juliaup\julia-1.11.1+0.x64.w64.mingw32\share\julia\stdlib\v1.11\Pkg\src\Resolve\[39m[90m[4mfieldvalues.jl:51[24m[39m
[0m  abs([91m::Missing[39m)
[0m[90m   @[39m [90mBase[39m [90m[4mmissing.jl:101[24m[39m
[0m  ...


In [65]:
partial_trace(r, 3, 1)

3×3 Matrix{Float64}:
 18.0  18.0  18.0
 18.0  18.0  18.0
 18.0  18.0  18.0