In [2]:
#|default_exp utils
#|export
from typing import List, Tuple, Union
import torch

In [3]:
#|export
def iterable_have_common(a: List[int] | Tuple[int], b: List[int] | Tuple[int]) -> bool:
    """
    Check if two iterables have any common elements.

    Args:
        a: First iterable.
        b: Second iterable.

    Returns:
        True if there are common elements, False otherwise.
    """
    assert isinstance(a, (list, tuple)), "First argument must be a list or tuple."
    assert isinstance(b, (list, tuple)), "Second argument must be a list or tuple."
    return len(set(a) & set(b)) > 0

def inverse_permutation(permutation: List[int]) -> List[int]:
    """
    Invert a permutation of ``0..n-1``.

    Args:
        permutation: 1-D sequence in which every integer ``0..n-1`` appears
            exactly once, where ``n = len(permutation)``.

    Returns:
        List ``inv`` with ``inv[permutation[i]] == i`` for every ``i``
        (an empty input yields an empty list).

    Raises:
        AssertionError: if ``permutation`` is not one-dimensional or is not a
            permutation of ``0..n-1`` (duplicates, gaps, out-of-range values).
    """
    perm = torch.tensor(permutation, dtype=torch.long)
    assert perm.ndim == 1, "permutation must be one-dimensional"
    # Sorting a valid permutation yields exactly 0..n-1, and the sort order is
    # the inverse: order[j] is the position in `perm` that holds value j.
    # Validating first avoids returning garbage for inputs with duplicates.
    sorted_values, order = torch.sort(perm)
    assert torch.equal(sorted_values, torch.arange(perm.numel())), \
        f"permutation must contain each of 0..{perm.numel() - 1} exactly once, got {permutation}"
    return order.tolist()

In [None]:
#|export
def check_state_tensor(tensor: torch.Tensor):
    """
    Assert that ``tensor`` is a valid quantum state in tensor form.

    Requirements:
      * it is a ``torch.Tensor``;
      * its dtype is float32, float64, complex64 or complex128;
      * it has at least one dimension;
      * every dimension has size 2.

    Raises:
        AssertionError: with a message naming the first requirement that fails.
    """
    assert isinstance(tensor, torch.Tensor), "quantum_state must be a torch.Tensor"
    supported_dtypes = (torch.float32, torch.float64, torch.complex64, torch.complex128)
    assert tensor.dtype in supported_dtypes, "quantum_state must be a float or complex tensor"
    # Rank first: an empty shape would otherwise pass the size test vacuously.
    assert tensor.ndim > 0, "quantum_state must be a tensor with at least one dimension"
    assert set(tensor.shape) <= {2}, "quantum_state must be a tensor with all dimensions of size 2"

def check_quantum_gate(tensor: torch.Tensor, num_qubits: int):
    """
    Assert that ``tensor`` is a gate acting on ``num_qubits`` qubits.

    Two layouts are accepted:
      * matrix form (``tensor.ndim == 2``): a square matrix of side ``2 ** num_qubits``;
      * tensor form: exactly ``2 * num_qubits`` axes, each of size 2.

    Args:
        tensor: Candidate gate; dtype must be float32, float64, complex64 or complex128.
        num_qubits: Number of qubits the gate is expected to act on.

    Raises:
        AssertionError: if any of the checks fails.
    """
    assert isinstance(tensor, torch.Tensor), "quantum_gate must be a torch.Tensor"
    assert tensor.dtype in [torch.float32, torch.float64, torch.complex64, torch.complex128], \
        "quantum_gate must be a float or complex tensor"
    assert tensor.ndim >= 2, "quantum_gate must be a tensor with at least two dimensions"
    assert tensor.ndim % 2 == 0, "quantum_gate must have an even number of dimensions"

    if tensor.ndim == 2:
        # in matrix form
        assert tensor.shape[0] == tensor.shape[1] == 2 ** num_qubits, f"gate must be a square matrix with dimensions 2^num_qubits, got {tensor.shape}"
    else:
        assert tensor.ndim == 2 * num_qubits, f"gate must have 2 * num_qubits dimensions, got {tensor.ndim}"
        # Rank alone is not enough: mirror the matrix branch's size check by
        # requiring every axis to be 2 wide, the same per-axis layout that
        # check_state_tensor enforces for states.
        assert all(size == 2 for size in tensor.shape), \
            f"gate in tensor form must have all dimensions of size 2, got {tensor.shape}"

def unify_tensor_dtypes(t1: torch.Tensor, t2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Cast two tensors to a common dtype without losing precision or imaginary parts.

    Args:
        t1: First tensor; dtype must be float32, float64, complex64 or complex128.
        t2: Second tensor; same dtype restriction as ``t1``.

    Returns:
        ``(t1, t2)`` converted to ``torch.promote_types(t1.dtype, t2.dtype)``:
        complex wins over real and double precision wins over single, so e.g.
        (float64, complex64) -> complex128. If the dtypes already match, the
        inputs are returned unchanged.

    Raises:
        AssertionError: if either dtype is outside the supported set.
    """
    supported_dtypes = (torch.float32, torch.float64, torch.complex64, torch.complex128)
    assert t1.dtype in supported_dtypes, "quantum_state must be a float or complex tensor"
    assert t2.dtype in supported_dtypes, "quantum_state must be a float or complex tensor"
    if t1.dtype == t2.dtype:
        return t1, t2
    # For this closed set of dtypes PyTorch's promotion rules give exactly the
    # desired target, so no hand-maintained lookup table is needed.
    target_dtype = torch.promote_types(t1.dtype, t2.dtype)
    return t1.to(target_dtype), t2.to(target_dtype)

def map_float_to_complex(*, tensor: torch.Tensor | None = None, dtype: torch.dtype | None = None) -> torch.Tensor | torch.dtype:
    assert tensor is not None or dtype is not None, "Either tensor or dtype must be provided"
    original_dtype = tensor.dtype if tensor is not None else dtype
    assert original_dtype in [torch.float32, torch.float64], "dtype must be float32 or float64"
    to_dtype = torch.complex64 if original_dtype == torch.float32 else torch.complex128
    if tensor is not None:
        return tensor.to(to_dtype)
    return to_dtype

In [None]:
# Test unify_tensor_dtypes on every ordered pair of supported dtypes: both
# outputs must share one dtype that is complex if either input is complex and
# double precision if either input is double precision, and values must survive.
SUPPORTED_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128]
DOUBLE_DTYPES = (torch.float64, torch.complex128)
EXPECTED_DTYPE = {
    # (either input complex, either input double) -> unified dtype
    (False, False): torch.float32,
    (False, True): torch.float64,
    (True, False): torch.complex64,
    (True, True): torch.complex128,
}
for d1 in SUPPORTED_DTYPES:
    t1 = torch.tensor([1, 2], dtype=d1)
    for d2 in SUPPORTED_DTYPES:
        t2 = torch.tensor([1, 2], dtype=d2)
        u1, u2 = unify_tensor_dtypes(t1, t2)
        expected = EXPECTED_DTYPE[(d1.is_complex or d2.is_complex,
                                   d1 in DOUBLE_DTYPES or d2 in DOUBLE_DTYPES)]
        assert u1.dtype == u2.dtype == expected, \
            f"{d1}, {d2}: got {u1.dtype}, {u2.dtype}, expected {expected}"
        assert torch.equal(u1, t1.to(expected)) and torch.equal(u2, t2.to(expected))