# 2.5: Controlled Quantum Gates and the TensorPureState Class

A single-qubit gate $\hat{O}$ acting on qubit $n$ of a multi-qubit state:

$|\psi'\rangle = \hat{O}|\psi\rangle \Leftrightarrow \psi'_{s_0\ldots s_n\ldots s_{N-1}} = \sum_{s} O_{s_n s}\,\psi_{s_0\ldots s\ldots s_{N-1}}$

![single_qubit_gate_tensor_network_representation](./images/single-qubit-gate-tensor-network-repr.png)
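
The index formula can be checked numerically. The sketch below uses NumPy rather than the PyTorch code later in this section, and `apply_single_qubit_gate` is a hypothetical helper name. It contracts $O$ with axis $n$ of the coefficient tensor, then compares the result with the dense operator $I\otimes I\otimes X\otimes I$ applied to the flattened state vector.

```python
import numpy as np

def apply_single_qubit_gate(psi: np.ndarray, op: np.ndarray, n: int) -> np.ndarray:
    """psi'[..., s_n, ...] = sum_s op[s_n, s] * psi[..., s, ...]"""
    # Contract the input (column) index of op with axis n of psi.
    # np.tensordot puts op's remaining (output) axis first ...
    out = np.tensordot(op, psi, axes=([1], [n]))
    # ... so move it back to position n.
    return np.moveaxis(out, 0, n)


rng = np.random.default_rng(0)
N = 4
# Random 4-qubit state stored as a (2, 2, 2, 2) coefficient tensor
psi = rng.normal(size=(2,) * N) + 1j * rng.normal(size=(2,) * N)
psi /= np.linalg.norm(psi)

X = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli X as the example O
n = 2
new_psi = apply_single_qubit_gate(psi, X, n)

# Dense check: I ⊗ I ⊗ X ⊗ I acting on the flattened vector.
# Row-major flattening makes qubit 0 the most significant bit,
# which matches the factor order of np.kron.
dense_op = np.kron(np.kron(np.eye(4), X), np.eye(2))
print(np.allclose(new_psi.reshape(-1), dense_op @ psi.reshape(-1)))  # True
```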

The general case of gates acting on a many-qubit state:

![multiple-single-qubit-gate](./images/multiple-single-qubit-gate.png)

![multiple-multi-qubit-gate](./images/multiple-multi-qubit-gate.png)
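
For a $k$-qubit gate the idea is the same. Contract the gate's $k$ input indices with the target axes of $\psi$ and leave every other axis untouched. The NumPy sketch below assumes the gate tensor is the $(2,)^{2k}$ reshape of a $2^k\times2^k$ matrix with the output indices first, which is the convention `apply_gate` uses further down. `apply_gate_tensor` is a hypothetical helper, and it is checked against a brute-force evaluation of the index sum on two non-adjacent qubits.

```python
import itertools
import numpy as np

def apply_gate_tensor(psi: np.ndarray, gate: np.ndarray, targets: list[int]) -> np.ndarray:
    """Apply a k-qubit gate to the axes of psi listed in `targets`.

    `gate` is a (2^k, 2^k) matrix or its (2,)*2k reshape, whose first k axes are
    outputs and last k axes are inputs; targets[0] is the most significant qubit.
    """
    k = len(targets)
    gate = gate.reshape((2,) * (2 * k))
    # Contract the gate's k input axes with the target axes of psi.
    out = np.tensordot(gate, psi, axes=(list(range(k, 2 * k)), list(targets)))
    # tensordot puts the k output axes first; move each back to its qubit slot.
    return np.moveaxis(out, list(range(k)), list(targets))


rng = np.random.default_rng(1)
psi = rng.normal(size=(2,) * 4) + 1j * rng.normal(size=(2,) * 4)
G = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))  # generic 2-qubit gate
targets = [3, 1]  # non-adjacent, and deliberately not in increasing order
new_psi = apply_gate_tensor(psi, G, targets)

# Brute-force check straight from the index formula:
# psi'[s] = sum_{a,b} G[(s3, s1), (a, b)] * psi[s with s3 -> a, s1 -> b]
G4 = G.reshape(2, 2, 2, 2)
ref = np.zeros_like(psi)
for s in itertools.product(range(2), repeat=4):
    for a, b in itertools.product(range(2), repeat=2):
        src = list(s)
        src[3], src[1] = a, b
        ref[s] += G4[s[3], s[1], a, b] * psi[tuple(src)]
print(np.allclose(new_psi, ref))  # True
```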

## Controlled Gates

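**Target qubit**: the qubit that the unitary transformation acts on.

**Control qubit**: the qubit whose state decides whether the unitary transformation is applied to the target qubit.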

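### Controlled-NOT gate (CNOT)

If the control qubit is in $|1\rangle$, the Pauli $X$ operator ($\sigma^x$) is applied to flip the target qubit; otherwise nothing is done. The control qubit itself is never changed.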

$\text{CNOT} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 1 & 0 \end{bmatrix}$

$\text{CNOT}|00\rangle = |00\rangle \qquad \langle00|\text{CNOT}|00\rangle = 1$

$\text{CNOT}|01\rangle = |01\rangle \qquad \langle01|\text{CNOT}|01\rangle = 1$

$\text{CNOT}|10\rangle = |11\rangle \qquad \langle11|\text{CNOT}|10\rangle = 1$

$\text{CNOT}|11\rangle = |10\rangle \qquad \langle10|\text{CNOT}|11\rangle = 1$

![cnot-gate](./images/cnot-gate.png)
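
Here is a quick NumPy check of the four relations above. It assumes the basis ordering $|\text{control}\,\text{target}\rangle$, with the control as the left (more significant) bit. `basis` is a hypothetical helper.

```python
import numpy as np

CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]])

def basis(bits: str) -> np.ndarray:
    """Basis vector |bits>; the left bit is the control, the right bit the target."""
    v = np.zeros(2 ** len(bits))
    v[int(bits, 2)] = 1
    return v

results = {}
for bits in ["00", "01", "10", "11"]:
    out = CNOT @ basis(bits)
    # The output is again a basis vector; read its label off the nonzero entry.
    results[bits] = format(int(np.argmax(out)), "02b")
print(results)  # {'00': '00', '01': '01', '10': '11', '11': '10'}

# Matrix element <11|CNOT|10>
print(basis("11") @ CNOT @ basis("10"))  # 1.0
```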

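### Controlled-unitary gate (controlled-U)

If the control qubit is in $|1\rangle$, the operator $U$ is applied to the target qubit; otherwise nothing is done. In the basis ordering $|00\rangle, |01\rangle, |10\rangle, |11\rangle$ (control qubit first) the matrix is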

$$
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & U_{0,0} & U_{0,1} \\
0 & 0 & U_{1,0} & U_{1,1}
\end{bmatrix}
$$

![controlled-U](./images/controlled-U.png)
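
The block structure can also be written as $|0\rangle\langle0|\otimes I + |1\rangle\langle1|\otimes U$. The sketch below builds the matrix this way and recovers CNOT for $U=\sigma^x$. It uses NumPy, puts the control on the left, and `controlled` is a hypothetical helper. The phase gate $S=\mathrm{diag}(1,i)$ is just an arbitrary second choice of $U$.

```python
import numpy as np

P0 = np.diag([1, 0])  # |0><0| on the control qubit
P1 = np.diag([0, 1])  # |1><1| on the control qubit

def controlled(U: np.ndarray) -> np.ndarray:
    """|0><0| ⊗ I + |1><1| ⊗ U, with the control as the first (left) qubit."""
    return np.kron(P0, np.eye(U.shape[0])) + np.kron(P1, U)

X = np.array([[0, 1], [1, 0]])
S = np.diag([1, 1j])  # phase gate, used here as an arbitrary single-qubit U

print(controlled(X).real.astype(int))  # reproduces the CNOT matrix
CS = controlled(S)
print(np.allclose(CS[2:, 2:], S), np.allclose(CS[:2, :2], np.eye(2)))  # True True
```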

### Computing controlled gates

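In an actual program we never need to write out the whole controlled-gate matrix. It is enough to apply the $U$ block above to the part of the state whose control qubit is $|1\rangle$, which greatly reduces the computational cost.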

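Example: take a 4-qubit state with coefficient tensor $\psi$ and apply a CNOT gate with qubit 1 as the control and qubit 2 as the target. Denote the coefficient tensor after the gate by $\Psi$.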

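1. Permute the indices into the order target qubit, other qubits, control qubit:

    $\psi_{s_0s_1s_2s_3} \xrightarrow{\text{permute}(2,0,3,1)} \psi'_{s_2s_0s_3s_1}$

    $s_2$ is the target qubit, $s_1$ is the control qubit, and $s_0$ and $s_3$ are the other qubits.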

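2. Slice, matricize, and multiply to apply $\sigma^x$:

    $\sigma^x(\psi'_{:,:,:,1})_{[0]} \xrightarrow{\text{reshape}(2,2,2)} \phi$

    1. Slice: $\psi'_{:,:,:,1}$ selects the part of the state whose control qubit is 1.
    2. Matricize: reshape $\psi'_{:,:,:,1}$ into a $2\times4$ matrix. This is what the subscript $[0]$ denotes: keep dimension 0 (the target qubit) as the row index and flatten the remaining dimensions into the column index.
    3. Multiply: left-multiply this matrix by $\sigma^x$.
    4. Reshape: turn the resulting $2\times4$ matrix back into the $2\times2\times2$ tensor $\phi$.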

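3. Concatenate along the control index to obtain the updated tensor $\tilde{\psi}'$:

    $\tilde{\psi}'_{:,:,:,0} = \psi'_{:,:,:,0}$

    $\tilde{\psi}'_{:,:,:,1} = \phi$

4. Permute the indices back to the original order $s_0s_1s_2s_3$ (the inverse of permute(2,0,3,1), namely permute(1,3,0,2)) to obtain $\Psi$.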
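
The four steps can be sketched in NumPy as follows. `cnot_sliced` is a hypothetical helper written only for this example: it hard-codes $\sigma^x$ and a single control qubit, and uses `np.argsort(perm)` for the inverse permutation. The result is compared with the full $16\times16$ matrix $I\otimes\text{CNOT}\otimes I$, in which qubit 1 (control) sits to the left of qubit 2 (target).

```python
import numpy as np

sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)

def cnot_sliced(psi: np.ndarray, control: int, target: int) -> np.ndarray:
    """CNOT via permute / slice / matricize / multiply / concatenate / permute back."""
    others = [q for q in range(psi.ndim) if q not in (control, target)]
    perm = [target] + others + [control]          # control=1, target=2 -> [2, 0, 3, 1]
    p = np.transpose(psi, perm)                    # step 1
    block = p[..., 1]                              # step 2.1: control qubit = 1
    mat = block.reshape(2, -1)                     # step 2.2: keep axis 0 (2 x 4 for 4 qubits)
    phi = (sigma_x @ mat).reshape(block.shape)     # steps 2.3 and 2.4
    new_p = np.stack([p[..., 0], phi], axis=-1)    # step 3: control = 0 part unchanged
    return np.transpose(new_p, np.argsort(perm))   # step 4: inverse permutation

rng = np.random.default_rng(2)
psi = rng.normal(size=(2,) * 4) + 1j * rng.normal(size=(2,) * 4)
Psi = cnot_sliced(psi, control=1, target=2)

# Reference: the full 16 x 16 matrix I ⊗ CNOT ⊗ I on the flattened state
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
full = np.kron(np.kron(np.eye(2), CNOT), np.eye(2))
print(np.allclose(Psi.reshape(-1), full @ psi.reshape(-1)))  # True
```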

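> The work drops from applying a 2-qubit gate to a 4-qubit state to applying a 1-qubit gate to the control-$|1\rangle$ half of the state, which is a 3-qubit slice. Counting multiplications, contracting the $4\times4$ CNOT with the state costs $4\times4\times4=64$, while the sliced method costs $2\times2\times4=16$.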
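
Here is a rough multiplication count that shows where the saving comes from in general. It assumes a dense contraction and ignores the cost of permuting and copying. With $n$ qubits, $k$ targets, and $c$ controls, treating the controlled gate as an ordinary $(k+c)$-qubit gate costs $2^{n+k+c}$ multiplications. The sliced method costs $2^{n+k-c}$, so the saving is a factor of $2^{2c}$. The function names are illustrative only.

```python
def mults_full_gate(n: int, k: int, c: int) -> int:
    """Controlled gate treated as a dense (k + c)-qubit gate: a 2^(k+c) x 2^(k+c)
    matrix times the 2^(k+c) x 2^(n-k-c) matricized state."""
    return 2 ** (k + c) * 2 ** (k + c) * 2 ** (n - k - c)

def mults_sliced(n: int, k: int, c: int) -> int:
    """Only the all-ones control slice is touched: a 2^k x 2^k matrix times
    a 2^k x 2^(n-k-c) matrix."""
    return 2 ** k * 2 ** k * 2 ** (n - k - c)

n, k, c = 4, 1, 1  # the CNOT example: 4 qubits, 1 target, 1 control
# dense 2^n x 2^n matrix-vector product, full-gate contraction, sliced method
print(4 ** n, mults_full_gate(n, k, c), mults_sliced(n, k, c))  # 256 64 16
```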

#### Example: applying CNOT to a GHZ state

> The Greenberger-Horne-Zeilinger (GHZ) state, also called the cat state, is defined as
>
> $|\psi\rangle = \frac{1}{\sqrt{2}}\left(\bigotimes_{n=0}^{N-1}|0_n\rangle + \bigotimes_{n=0}^{N-1}|1_n\rangle\right)$
>
> The 4-qubit GHZ state: $|\psi\rangle = \frac{1}{\sqrt{2}}(|0000\rangle + |1111\rangle)$
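
The definition translates directly into code. The NumPy sketch below (hypothetical helper `ghz_tensor`) forms the two $N$-fold tensor products with `np.kron`, then reshapes the $2^N$ vector into the $(2,)^N$ coefficient tensor used throughout this section.

```python
from functools import reduce
import numpy as np

def ghz_tensor(N: int) -> np.ndarray:
    """(2,)*N coefficient tensor of the N-qubit GHZ state."""
    zero = np.array([1, 0], dtype=complex)
    one = np.array([0, 1], dtype=complex)
    # Tensor products of N copies of |0> and of |1>, as 2^N vectors
    all_zeros = reduce(np.kron, [zero] * N)
    all_ones = reduce(np.kron, [one] * N)
    return ((all_zeros + all_ones) / np.sqrt(2)).reshape((2,) * N)

psi = ghz_tensor(4)
print(np.flatnonzero(psi.reshape(-1)))  # [ 0 15]: only |0000> and |1111>
```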

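Take the 4-qubit GHZ state as an example and apply a CNOT gate with qubit 1 as the control and qubit 2 as the target.

The quantum circuit is shown below: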

![ghz-cnot-example](./images/ghz-cnot-example.png)

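The controlled-gate method above and direct multiplication of the GHZ state by the full CNOT matrix give the same result, $\frac{1}{\sqrt{2}}(|0000\rangle + |1101\rangle)$. In the $|0000\rangle$ term, control qubit 1 is in $|0\rangle$, so the term is unchanged. In the $|1111\rangle$ term, target qubit 2 is flipped.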
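
Below is a NumPy sketch of this comparison, using the same qubit ordering as before (qubit 0 is the most significant bit). For CNOT the sliced method simplifies further. $\sigma^x$ only swaps the two values of the target index, so applying it to the control-$|1\rangle$ slice amounts to reversing that index.

```python
import numpy as np

N = 4
psi = np.zeros((2,) * N, dtype=complex)
psi[0, 0, 0, 0] = psi[1, 1, 1, 1] = 1 / np.sqrt(2)  # 4-qubit GHZ state

# Controlled-gate method, specialised to CNOT: only the control = 1 slice
# changes, and sigma^x acting on the target index just reverses it.
sliced = psi.copy()
sliced[:, 1, :, :] = psi[:, 1, ::-1, :]  # control qubit 1, target qubit 2

# Direct method: full 16 x 16 matrix I ⊗ CNOT ⊗ I
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
direct = np.kron(np.kron(np.eye(2), CNOT), np.eye(2)) @ psi.reshape(-1)

print(np.allclose(sliced.reshape(-1), direct))                   # True
print([format(int(i), "04b") for i in np.flatnonzero(direct)])  # ['0000', '1101']
```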


### Code

```python
#|default_exp tensor_gates.functional
#|export
from tensor_network.utils import iterable_have_common, inverse_permutation, check_quantum_gate, check_state_tensor, unify_tensor_dtypes
import torch
from typing import List
from einops import einsum
```

```python
#|export
def apply_gate(*, quantum_state: torch.Tensor, gate: torch.Tensor, target_qubit: int | List[int], control_qubit: int | List[int] | None = None) -> torch.Tensor:
    """Apply `gate` to `target_qubit`, conditioned on every qubit in `control_qubit` being |1>.

    `quantum_state` is the (2, 2, ..., 2) coefficient tensor, one axis per qubit.
    `gate` is either a (2^k, 2^k) matrix or its (2,) * 2k tensor form (first k axes
    are outputs, last k axes are inputs), where k is the number of target qubits.
    """
    check_state_tensor(quantum_state)

    # check argument types
    assert isinstance(target_qubit, (int, list)), "target_qubit must be int or list"
    assert control_qubit is None or isinstance(control_qubit, (int, list)), "control_qubit must be None, int or list"

    # normalize both arguments to lists
    if isinstance(target_qubit, int):
        target_qubit = [target_qubit]
    if control_qubit is None:
        control_qubit = []
    elif isinstance(control_qubit, int):
        control_qubit = [control_qubit]
    assert len(set(target_qubit)) == len(target_qubit), "target qubits must be distinct"
    assert len(set(control_qubit)) == len(control_qubit), "control qubits must be distinct"
    assert not iterable_have_common(target_qubit, control_qubit), "target qubit and control qubit must not overlap"

    num_qubits = quantum_state.ndim
    num_target_qubit = len(target_qubit)
    check_quantum_gate(gate, num_target_qubit)

    quantum_state, gate = unify_tensor_dtypes(quantum_state, gate)

    # check index ranges
    for qidx in target_qubit:
        assert 0 <= qidx < num_qubits, f"target qubit index {qidx} out of range"
    for qidx in control_qubit:
        assert 0 <= qidx < num_qubits, f"control qubit index {qidx} out of range"

    if gate.ndim == 2:
        # matrix form (2^k, 2^k) -> tensor form (2,) * 2k
        new_shape = [2] * (num_target_qubit * 2)
        gate = gate.reshape(new_shape)

    # qubits that are neither targets nor controls
    other_qubits = list(range(num_qubits))
    for qubit_idx in target_qubit:
        other_qubits.remove(qubit_idx)
    for qubit_idx in control_qubit:
        other_qubits.remove(qubit_idx)

    num_other_qubits = len(other_qubits)
    # Step 1: permute the axes into the order (targets, others, controls)
    permutation = target_qubit + other_qubits + control_qubit
    state = torch.permute(quantum_state, permutation)
    state_shape = state.shape  # (*target_qubit_shapes, *other_qubit_shapes, *control_qubit_shapes)
    # Flatten the control axes into one trailing axis of size 2^c
    # (size 1 when there are no control qubits)
    new_shape = [2] * (num_target_qubit + num_other_qubits) + [-1]
    state = state.reshape(new_shape)
    # Step 2: the gate acts only on the slice where all control qubits are 1,
    # i.e. the last entry of the flattened control axis
    unaffected_state = state[..., :-1]  # (*target_qubit_shapes, *other_qubit_shapes, flattened_dim - 1)
    state_to_apply_gate = state[..., -1]  # (*target_qubit_shapes, *other_qubit_shapes)
    # contract the gate's input axes (t*) with the target axes; g* are the gate's output axes
    target_qubit_names = [f"t{i}" for i in target_qubit]
    other_qubit_names = [f"o{i}" for i in other_qubits]
    gate_output_qubit_names = [f"g{i}" for i in target_qubit]
    einsum_str = "{gate_dims}, {state_dims} -> {output_dims}".format(
        gate_dims = " ".join(gate_output_qubit_names + target_qubit_names),
        state_dims = " ".join(target_qubit_names + other_qubit_names),
        output_dims = " ".join(gate_output_qubit_names + other_qubit_names)
    )
    new_state = einsum(gate, state_to_apply_gate, einsum_str)
    new_state = new_state.unsqueeze(-1)

    # Step 3: put the updated slice back as the last entry of the control axis
    final_state = torch.cat([unaffected_state, new_state], dim=-1)  # (*target_qubit_shapes, *other_qubit_shapes, flattened_dim)
    final_state = final_state.reshape(state_shape)  # (*target_qubit_shapes, *other_qubit_shapes, *control_qubit_shapes)
    # Step 4: undo the permutation to restore the original qubit order
    inversed_permutation = inverse_permutation(permutation)
    final_state = final_state.permute(inversed_permutation)
    return final_state
```


#### Testing

```python
# set up importing for ref code
from tensor_network import setup_ref_code_import

from Library.QuantumState import TensorPureState
from copy import deepcopy
```

```text
From setup_ref_code_import:
  Added reference_code_path='/Users/zhiqiu/offline_code/personal/tensor_network/reference_code' to sys.path.
  You can import the reference code now.
```


```python
# initialize a quantum state of 4 qubits
state = TensorPureState(nq=4, dtype=torch.complex128)

# random (not necessarily unitary) gates on 1 to 4 qubits, in matrix and tensor form
# initialize a random quantum gate of 1 qubit
gate1 = torch.randn(2, 2, dtype=torch.complex128)
# initialize a random quantum gate of 2 qubits
gate2_mat = torch.randn(4, 4, dtype=torch.complex128)
gate2_tensor = gate2_mat.reshape(2, 2, 2, 2)
# initialize a random quantum gate of 3 qubits
gate3_mat = torch.randn(8, 8, dtype=torch.complex128)
gate3_tensor = gate3_mat.reshape(2, 2, 2, 2, 2, 2)
# initialize a random quantum gate of 4 qubits
gate4_mat = torch.randn(16, 16, dtype=torch.complex128)
gate4_tensor = gate4_mat.reshape(2, 2, 2, 2, 2, 2, 2, 2)

# compare apply_gate (matrix and tensor forms) with the reference act_single_gate,
# acting on qubits 0..qubit_num-1 without control qubits
for i, (gate_mat, gate_tensor) in enumerate([
        (gate1, gate1),
        (gate2_mat, gate2_tensor),
        (gate3_mat, gate3_tensor),
        (gate4_mat, gate4_tensor)
    ]):
    qubit_num = i + 1
    s = deepcopy(state)
    t = s.tensor.clone()
    s.act_single_gate(gate=gate_tensor, pos=list(range(qubit_num)))
    result_ref = s.tensor
    result_mat = apply_gate(quantum_state=t, gate=gate_mat, target_qubit=list(range(qubit_num)))
    result_tensor = apply_gate(quantum_state=t, gate=gate_tensor, target_qubit=list(range(qubit_num)))
    assert torch.allclose(result_mat, result_ref), f"result: {result_mat}, result_ref: {result_ref}"
    assert torch.allclose(result_tensor, result_ref), f"result: {result_tensor}, result_ref: {result_ref}"
```
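
The test cell above only exercises `apply_gate` without control qubits. One way to cover the `control_qubit` path is to compare against a dense matrix built directly from the definition of a controlled gate. The helper below, `dense_controlled_gate`, is a hypothetical NumPy oracle and is not part of `tensor_network.utils` or the reference code. It assumes the same conventions as `apply_gate`: qubit 0 is the most significant bit of the flattened state, and `targets[0]` is the most significant bit of the gate matrix.

```python
import itertools
import numpy as np

def dense_controlled_gate(U, n, targets, controls):
    """Hypothetical test oracle: the full 2^n x 2^n matrix of a controlled gate.

    Qubit 0 is the most significant bit of the flattened index (row-major
    flattening of the state tensor); targets[0] is the most significant bit of U.
    """
    k = len(targets)
    U = np.asarray(U).reshape(2 ** k, 2 ** k)
    M = np.zeros((2 ** n, 2 ** n), dtype=complex)
    for bits in itertools.product(range(2), repeat=n):
        col = int("".join(map(str, bits)), 2)
        if not all(bits[q] == 1 for q in controls):
            M[col, col] = 1  # control condition not met: act as identity
            continue
        # input index of U, read off the target bits in the order of `targets`
        t_in = int("".join(str(bits[t]) for t in targets), 2)
        for t_out in range(2 ** k):
            out_bits = list(bits)
            for j, t in enumerate(targets):
                out_bits[t] = (t_out >> (k - 1 - j)) & 1
            row = int("".join(map(str, out_bits)), 2)
            M[row, col] += U[t_out, t_in]
    return M

X = np.array([[0, 1], [1, 0]])
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
print(np.allclose(dense_controlled_gate(X, 2, targets=[1], controls=[0]), CNOT))  # True
```

As a test, one could compare `apply_gate(quantum_state=t, gate=g, target_qubit=targets, control_qubit=controls).reshape(-1).numpy()` with `dense_controlled_gate(g.numpy(), 4, targets, controls) @ t.reshape(-1).numpy()`. This assumes `t` and `g` are CPU tensors.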
