In [30]:
import numpy as np
from functools import reduce
from scipy.linalg import eigh
import matplotlib.pyplot as plt
from scipy.sparse import kron, identity, csr_matrix, lil_matrix, coo_matrix, issparse
from scipy.sparse.linalg import eigsh, eigs
from qutip import Qobj, ptrace
from qutip import commutator as qt_commutator
from tqdm import tqdm
from itertools import product
import torch

In [19]:
def partial_trace_qubit(rho, keep, dims):
    """Compute the partial trace of a density matrix of qubits."""
    print("keep",keep)
    print("dims",dims)
    keep_dims = np.prod([dims[i] for i in keep])
    print("keep_dims",keep_dims)
    trace_dims = np.prod([dims[i] for i in range(len(dims)) if i not in keep])
    print("trace_dims",trace_dims)
    rho = rho.reshape([keep_dims, trace_dims, keep_dims, trace_dims])
    print("rho",rho)
    return np.trace(rho, axis1=1, axis2=3).reshape([keep_dims, keep_dims])

def partial_trace_bipartite(psi):
    """
    Compute the partial trace of a square matrix over half of its dimensions.

    Parameters:
    - matrix (numpy.ndarray): Input density matrix of shape (d*d, d*d), where d is the subsystem size.
    - keep_dim (int): Dimension of the subsystem to keep.

    Returns:
    - numpy.ndarray: The reduced density matrix of the kept subsystem.
    """

    # Reshape into 4D tensor: (keep_dim, trace_dim, keep_dim, trace_dim)
    reshaped_psi = np.reshape(psi, (len(psi)//2, len(psi) - len(psi)//2))

    # Perform trace over last axis (tracing out the second subsystem)
    reduced_matrix = np.matmul(reshaped_psi, reshaped_psi.T.conj())

    return reduced_matrix

def vnentropy_qubit(psi, subsystem_size, total_size):

    '''Computes the bipartite entanglement entropy of a pure state.
    
    Parameters:
    psi : np.array
        The wavefunction (state vector) of the full system.
    subsystem_size : int
        The number of qubits in subsystem A.
    total_size : int
        The total number of qubits in the system.
    
    Returns:
    float
        The von Neumann entanglement entropy S_A.'''
    
    psi_matrix =  np.outer(psi, psi.conj())

    # Compute the reduced density matrix rho_A = Tr_B(|psi><psi|)
    rho_A = partial_trace_qubit(psi_matrix, list(range(subsystem_size)), [2]*total_size)  # Partial trace over B
    
    # Compute eigenvalues of rho_A
    eigenvalues = np.linalg.eigvalsh(rho_A)
    
    # Filter out zero eigenvalues to avoid numerical issues in log calculation
    eigenvalues = eigenvalues[eigenvalues > 0]
    
    # Compute von Neumann entropy S_A = -Tr(rho_A log rho_A)
    entropy = -np.sum(eigenvalues * np.log2(eigenvalues))
    
    return entropy

In [31]:
def isket_numpy(arr):
    """
    Check if a NumPy array is a ket (column vector).

    Parameters:
    - arr: np.ndarray, the array to check.

    Returns:
    - bool, True if the array is a ket, False otherwise.
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("Input must be a NumPy array")

    shape = arr.shape

    if len(shape) == 2 and shape[1] == 1:
        return True
    else:
        return False

def ptrace_numpy(Q, sel, dims):
    """
    Compute the partial trace of a density matrix of qubits using NumPy.

    Parameters:
    - Q: numpy object, the quantum object (density matrix or state vector).
    - sel: list of int, indices of the subsystems to keep.
    - dims: list of int, dimensions of the subsystems.

    Returns:
    - numpy object, the reduced density matrix after tracing out the specified subsystems.
    """
    # Get the dimensions of the subsystems
    rd = np.asarray(dims[0], dtype=np.int32).ravel()
    print("rd", rd)
    nd = len(rd)
    print("nd", nd)
    
    # Ensure sel is a sorted array of indices
    if isinstance(sel, int):
        sel = np.array([sel])
    else:
        sel = np.asarray(sel)
    sel = list(np.sort(sel))
    
    # Dimensions of the subsystems to keep
    dkeep = (rd[sel]).tolist()
    
    # Indices of the subsystems to trace out
    qtrace = list(set(np.arange(nd)) - set(sel))
    
    # Dimensions of the subsystems to trace out
    dtrace = (rd[qtrace]).tolist()
    
    # Reshape the density matrix or state vector
    rd = list(rd)
    if isket_numpy(Q):
        # Reshape and transpose for state vector
        reshaped_Q = Q.reshape(rd)
        print("reshaped_Q", reshaped_Q)
        transposed_Q = reshaped_Q.transpose(sel + qtrace)
        print("transposed_Q", transposed_Q)
        vmat = transposed_Q.reshape([np.prod(dkeep), np.prod(dtrace)])
        print("vmat", vmat)
        # Compute the reduced density matrix
        rhomat = vmat.dot(vmat.conj().T)
        print("rhomat", rhomat)
    else:
        # Reshape and transpose for density matrix
        reshaped_Q = Q.reshape(rd + rd)
        print("reshaped_Q", reshaped_Q)
        #print("reshaped_Q", reshaped_Q.shape)
        transposed_Q = reshaped_Q.transpose(qtrace + [nd + q for q in qtrace] + sel + [nd + q for q in sel])
        print("transposed_Q", transposed_Q)
        reshaped_transposed_Q = transposed_Q.reshape([np.prod(dtrace), np.prod(dtrace), np.prod(dkeep), np.prod(dkeep)])
        print("reshaped_transposed_Q", reshaped_transposed_Q)
        rhomat = np.trace(reshaped_transposed_Q)
        print("rhomat", rhomat)
        print("rhomat", rhomat.shape)
    return rhomat

def ptrace_numpy2(psi, keep, dims):
    """
    Compute the partial trace over specified subsystems (general, non-bipartite).

    Args:
        psi (np.ndarray): Full density matrix, shape (D, D), D = product of dims
        keep (list of int): Subsystems to keep (indices, 0-indexed)
        dims (list of int): List of subsystem dimensions (e.g. [2,2,2,...] for qubits)

    Returns:
        np.ndarray: Reduced density matrix of shape (d_keep, d_keep)
    """
    n = len(dims)
    D = np.prod(dims)
    if psi.shape != (D, D):
        raise ValueError("Density matrix shape does not match dims")

    trace = [i for i in range(n) if i not in keep]
    d_keep = np.prod([dims[i] for i in keep])

    # Reshape into 2n-index tensor
    psi_reshaped = psi.reshape(dims + dims)

    # Label axes: [0, ..., n-1, n, ..., 2n-1]
    input_labels = list(range(2 * n))
    output_labels = [i for i in range(n) if i in keep] + [i + n for i in range(n) if i in keep]

    # For traced-out axes, force contraction: make j_i = i_i
    for t in trace:
        input_labels[n + t] = input_labels[t]

    # Einsum using integer index notation
    reduced = np.einsum(psi_reshaped, input_labels, output_labels, optimize=True)

    return reduced.reshape(d_keep, d_keep)


def isket_torch(arr):
    """
    Check if a PyTorch tensor is a ket (column vector).

    Parameters:
    - arr: torch.Tensor, the array to check.

    Returns:
    - bool, True if the array is a ket, False otherwise.
    """
    if not isinstance(arr, torch.Tensor):
        raise ValueError("Input must be a PyTorch tensor")

    shape = arr.shape

    if len(shape) == 2 and shape[1] == 1:
        return True
    else:
        return False

def ptrace_torch(Q, sel, dims):
    """
    Compute the partial trace of a density matrix of qubits using PyTorch.

    Parameters:
    - Q: torch.Tensor, the quantum object (density matrix or state vector).
    - sel: list of int, indices of the subsystems to keep.
    - dims: list of int, dimensions of the subsystems.

    Returns:
    - torch.Tensor, the reduced density matrix after tracing out the specified subsystems.
    """
    # Get the dimensions of the subsystems
    rd = torch.tensor(dims[0], dtype=torch.int32).flatten()
    nd = len(rd)
    print("rd", rd)
    print("nd", nd)
    
    # Ensure sel is a sorted array of indices
    if isinstance(sel, int):
        sel = torch.tensor([sel])
    else:
        sel = torch.tensor(sel)
    sel = torch.sort(sel).values.tolist()
    
    # Dimensions of the subsystems to keep
    dkeep = rd[sel].tolist()
    
    # Indices of the subsystems to trace out
    qtrace = list(set(range(nd)) - set(sel))
    
    # Dimensions of the subsystems to trace out
    dtrace = rd[qtrace].tolist()
    
    # Reshape the density matrix or state vector
    rd = rd.tolist()
    if isket_torch(Q):
        # Reshape and transpose for state vector
        reshaped_Q = Q.reshape(rd)
        print("reshaped_Q", reshaped_Q)
        #print(reshaped_Q.shape)
        transposed_Q = reshaped_Q.permute(sel + qtrace)
        print("transposed_Q", transposed_Q)
        #print(transposed_Q.shape)
        vmat = transposed_Q.reshape([torch.prod(torch.tensor(dkeep)), torch.prod(torch.tensor(dtrace))])
        print("vmat", vmat)
        #print(vmat.shape)
        # Compute the reduced density matrix
        rhomat = vmat @ vmat.conj().T
        print("rhomat", rhomat) 
        #print(rhomat.shape)
    else:
        # Reshape and transpose for density matrix
        reshaped_Q = Q.reshape(rd + rd)
        #print("reshaped_Q", reshaped_Q.shape)
        print("reshaped_Q", reshaped_Q)
        transposed_Q = reshaped_Q.permute(qtrace + [nd + q for q in qtrace] + sel + [nd + q for q in sel])
        #print("transposed_Q", transposed_Q.shape)
        print("transposed_Q", transposed_Q)
        reshaped_transposed_Q = transposed_Q.reshape([torch.prod(torch.tensor(dtrace)), torch.prod(torch.tensor(dtrace)), torch.prod(torch.tensor(dkeep)), torch.prod(torch.tensor(dkeep))])
        #print("reshaped_transposed_Q", reshaped_transposed_Q.shape)
        #rhomat = torch.trace(reshaped_transposed_Q)
        rhomat = torch.einsum('ijkl->kl', reshaped_transposed_Q)
        print("rhomat", rhomat)
        print("rhomat", rhomat.shape)
        # Trace out the first two dimensions
        '''rhomat = torch.zeros((torch.prod(torch.tensor(dkeep)), torch.prod(torch.tensor(dkeep))), dtype=Q.dtype)
        for i in range(reshaped_transposed_Q.shape[0]):
            for j in range(reshaped_transposed_Q.shape[1]):
                rhomat += reshaped_transposed_Q[i, j, :, :]'''
        #print("rhomat", rhomat.shape)
    return rhomat

def ptrace_sparse(psi_sparse, keep, dims):
    """
    Compute the partial trace over arbitrary subsystems using sparse matrix operations.

    Args:
        psi_sparse (scipy.sparse matrix): Full density matrix of shape (D, D), where D = product(dims)
        keep (list of int): Subsystems to keep (indices, 0-indexed)
        dims (list of int): List of subsystem dimensions, e.g., [2]*n for n qubits

    Returns:
        scipy.sparse.csr_matrix: Reduced density matrix over kept subsystems
    """
    if not issparse(psi_sparse):
        raise ValueError("psi_sparse must be a scipy.sparse matrix")

    n = len(dims)
    D = np.prod(dims)
    if psi_sparse.shape != (D, D):
        raise ValueError("Density matrix shape does not match dims")

    trace = [i for i in range(n) if i not in keep]
    d_keep = np.prod([dims[i] for i in keep])

    # Prepare output
    data = []
    row_idx = []
    col_idx = []

    # Precompute bit masks
    def idx_to_bits(idx):
        return np.array(list(np.binary_repr(idx, width=n))).astype(int)

    psi_sparse = psi_sparse.tocoo()
    for i, j, val in zip(psi_sparse.row, psi_sparse.col, psi_sparse.data):
        bi = idx_to_bits(i)
        bj = idx_to_bits(j)

        # Only sum terms where traced-out subsystems agree
        if np.all(bi[trace] == bj[trace]):
            # Extract kept bits and convert to reduced indices
            i_red_bits = bi[keep]
            j_red_bits = bj[keep]
            i_red = int("".join(i_red_bits.astype(str)), 2)
            j_red = int("".join(j_red_bits.astype(str)), 2)

            data.append(val)
            row_idx.append(i_red)
            col_idx.append(j_red)

    return coo_matrix((data, (row_idx, col_idx)), shape=(d_keep, d_keep)).tocsr()


In [21]:
rho_phip = np.array([[1/2,0,0,1/2],[0,0,0,0],[0,0,0,0],[1/2,0,0,1/2]]) # bell state phi+
rho_phipq = Qobj(rho_phip, dims=[[2, 2], [2, 2]])
print(rho_phipq)
ptrace(rho_phipq, [0])
ptrace(rho_phipq, [1])

Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0.5]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.5 0.  0.  0.5]]


Quantum object: dims=[[2], [2]], shape=(2, 2), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0. ]
 [0.  0.5]]

In [22]:
rho_phip = np.array([[1/2,0,0,1/2],[0,0,0,0],[0,0,0,0],[1/2,0,0,1/2]]) # bell state phi+
print(rho_phip)
ptrace_numpy(rho_phip, [0], [[2, 2], [2, 2]])
ptrace_numpy(rho_phip, [1], [[2, 2], [2, 2]])

[[0.5 0.  0.  0.5]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.5 0.  0.  0.5]]
rd [2 2]
nd 2
reshaped_Q [[[[0.5 0. ]
   [0.  0.5]]

  [[0.  0. ]
   [0.  0. ]]]


 [[[0.  0. ]
   [0.  0. ]]

  [[0.5 0. ]
   [0.  0.5]]]]
transposed_Q [[[[0.5 0. ]
   [0.  0. ]]

  [[0.  0.5]
   [0.  0. ]]]


 [[[0.  0. ]
   [0.5 0. ]]

  [[0.  0. ]
   [0.  0.5]]]]
reshaped_transposed_Q [[[[0.5 0. ]
   [0.  0. ]]

  [[0.  0.5]
   [0.  0. ]]]


 [[[0.  0. ]
   [0.5 0. ]]

  [[0.  0. ]
   [0.  0.5]]]]
rhomat [[0.5 0. ]
 [0.  0.5]]
rhomat (2, 2)
rd [2 2]
nd 2
reshaped_Q [[[[0.5 0. ]
   [0.  0.5]]

  [[0.  0. ]
   [0.  0. ]]]


 [[[0.  0. ]
   [0.  0. ]]

  [[0.5 0. ]
   [0.  0.5]]]]
transposed_Q [[[[0.5 0. ]
   [0.  0. ]]

  [[0.  0.5]
   [0.  0. ]]]


 [[[0.  0. ]
   [0.5 0. ]]

  [[0.  0. ]
   [0.  0.5]]]]
reshaped_transposed_Q [[[[0.5 0. ]
   [0.  0. ]]

  [[0.  0.5]
   [0.  0. ]]]


 [[[0.  0. ]
   [0.5 0. ]]

  [[0.  0. ]
   [0.  0.5]]]]
rhomat [[0.5 0. ]
 [0.  0.5]]
rhomat (2, 2)


array([[0.5, 0. ],
       [0. , 0.5]])

In [25]:
rho_phip = np.array([[1/2,0,0,1/2],[0,0,0,0],[0,0,0,0],[1/2,0,0,1/2]]) # bell state phi+
print(rho_phip)
ptrace_numpy2(rho_phip, [0], [2, 2])
ptrace_numpy2(rho_phip, [1], [2, 2])

[[0.5 0.  0.  0.5]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.5 0.  0.  0.5]]


array([[0.5, 0. ],
       [0. , 0.5]])

In [None]:
rho_phip = torch.tensor([[1/2,0,0,1/2],[0,0,0,0],[0,0,0,0],[1/2,0,0,1/2]], dtype=torch.complex64) # bell state phi+
print(rho_phip)
ptrace_torch(rho_phip, [0], [[2, 2], [2, 2]])
ptrace_torch(rho_phip, [1], [[2, 2], [2, 2]])

tensor([[0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.5000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j],
        [0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.5000+0.j]])
rd tensor([2, 2], dtype=torch.int32)
nd 2
reshaped_Q tensor([[[[0.5000+0.j, 0.0000+0.j],
          [0.0000+0.j, 0.5000+0.j]],

         [[0.0000+0.j, 0.0000+0.j],
          [0.0000+0.j, 0.0000+0.j]]],


        [[[0.0000+0.j, 0.0000+0.j],
          [0.0000+0.j, 0.0000+0.j]],

         [[0.5000+0.j, 0.0000+0.j],
          [0.0000+0.j, 0.5000+0.j]]]])
transposed_Q tensor([[[[0.5000+0.j, 0.0000+0.j],
          [0.0000+0.j, 0.0000+0.j]],

         [[0.0000+0.j, 0.5000+0.j],
          [0.0000+0.j, 0.0000+0.j]]],


        [[[0.0000+0.j, 0.0000+0.j],
          [0.5000+0.j, 0.0000+0.j]],

         [[0.0000+0.j, 0.0000+0.j],
          [0.0000+0.j, 0.5000+0.j]]]])
rhomat tensor([[0.5000+0.j, 0.5000+0.j],
        [0.5000+0.j, 0.5000+0.j]])
rhomat torch.Size([2, 2])
r

tensor([[0.5000+0.j, 0.5000+0.j],
        [0.5000+0.j, 0.5000+0.j]])

In [34]:
rho_phip = np.array([[1/2,0,0,1/2],[0,0,0,0],[0,0,0,0],[1/2,0,0,1/2]]) # bell state phi+
rho_phips = csr_matrix(rho_phip)
out0 = ptrace_sparse(rho_phips, [0], [2, 2])
out1 = ptrace_sparse(rho_phips, [1], [2, 2])
print(out0.toarray())
print(out1.toarray())

[[0.5 0. ]
 [0.  0.5]]
[[0.5 0. ]
 [0.  0.5]]


In [None]:
ghz_state = np.array([1/np.sqrt(2), 0, 0, 0, 0, 0, 0, 1/np.sqrt(2)], dtype=complex) # GHZ state
rho_ghz = np.outer(ghz_state, ghz_state.conj())

rho_ghzq = Qobj(rho_ghz, dims=[[2, 2, 2], [2, 2, 2]])
print(rho_ghzq)
print(ptrace(rho_ghzq, [0,1]))
print(ptrace(rho_ghzq, [1,2]))
print(ptrace(rho_ghzq, [0,2]))

Quantum object: dims=[[2, 2, 2], [2, 2, 2]], shape=(8, 8), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0.  0.  0.  0.  0.5]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.5 0.  0.  0.  0.  0.  0.  0.5]]
Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0.5]]
Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0.5]]
Quantum object: dims=[[2, 2], [2, 2]], shape=(4, 4), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0. ]
 [0.  0.  0.  0.5]]


In [27]:
ghz_state = np.array([1/np.sqrt(2), 0, 0, 0, 0, 0, 0, 1/np.sqrt(2)], dtype=complex) # GHZ state
rho_ghz = np.outer(ghz_state, ghz_state.conj())

print(rho_ghzq)
print(ptrace_numpy(rho_ghz, [0,1], [[2, 2, 2], [2, 2, 2]]))
#print(ptrace_numpy(rho_ghz, [1,2], [[2, 2, 2], [2, 2, 2]]))
#print(ptrace_numpy(rho_ghz, [0,2], [[2, 2, 2], [2, 2, 2]]))

Quantum object: dims=[[2, 2, 2], [2, 2, 2]], shape=(8, 8), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0.  0.  0.  0.  0.5]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.5 0.  0.  0.  0.  0.  0.  0.5]]
rd [2 2 2]
nd 3
reshaped_Q [[[[[[0.5+0.j 0. +0.j]
     [0. +0.j 0. +0.j]]

    [[0. +0.j 0. +0.j]
     [0. +0.j 0.5+0.j]]]


   [[[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]

    [[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]]]



  [[[[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]

    [[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]]


   [[[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]

    [[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]]]]




 [[[[[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]

    [[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]]


   [[[0. +0.j 0. +0.j]
     [0. +0.j 0. +0.j]]

    [[0. +0.j 0. +0.j]
     [0. +0.j 0.

In [28]:
ghz_state = np.array([1/np.sqrt(2), 0, 0, 0, 0, 0, 0, 1/np.sqrt(2)], dtype=complex) # GHZ state
rho_ghz = np.outer(ghz_state, ghz_state.conj())

print(rho_ghzq)
print(ptrace_numpy2(rho_ghz, [0,1], [2, 2, 2]))
#print(ptrace_numpy2(rho_ghz, [1,2], [2, 2, 2]))
#print(ptrace_numpy2(rho_ghz, [0,2], [2, 2, 2]))

Quantum object: dims=[[2, 2, 2], [2, 2, 2]], shape=(8, 8), type='oper', dtype=Dense, isherm=True
Qobj data =
[[0.5 0.  0.  0.  0.  0.  0.  0.5]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.5 0.  0.  0.  0.  0.  0.  0.5]]
[[0.5+0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0.5+0.j]]


In [None]:
ghz_state = torch.tensor([1/np.sqrt(2), 0, 0, 0, 0, 0, 0, 1/np.sqrt(2)], dtype=torch.complex64) # GHZ state
rho_ghz = torch.outer(ghz_state, ghz_state.conj())

print(rho_ghz)
print(ptrace_torch(rho_ghz, [0,1], [[2, 2, 2], [2, 2, 2]]))
#print(ptrace_torch(rho_ghz, [1,2], [[2, 2, 2], [2, 2, 2]]))
#print(ptrace_torch(rho_ghz, [0,2], [[2, 2, 2], [2, 2, 2]]))

tensor([[0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.5000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.0000+0.j],
        [0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j,
         0.5000+0.j]])
rd tensor([2, 2, 2], dtype=torch.int32)
nd 3
reshaped_Q tensor([[[[[[0.5000+0.j