In [1]:
from typing import List, Tuple, Dict, Optional

from qiskit_dynamics.array import Array
from qiskit.quantum_info import Operator
import jax
from qiskit_dynamics import DynamicsBackend, Solver
from needed_files.jax_solver import JaxSolver
import numpy as np
import copy

jax.config.update("jax_enable_x64", True)
# tell JAX we are using CPU
jax.config.update("jax_platform_name", "cpu")
# import Array and set default backend
Array.set_default_backend("jax")



#### Neighboring qubits for different Topologies

In [2]:
def linear_chain_neighbors(num_qudits):
    neighbors = {}
    for i in range(num_qudits):
        # Each qudit is connected to its immediate predecessor and successor
        neighbors[i] = []
        if i > 0:
            neighbors[i].append(i - 1)
        if i < num_qudits - 1:
            neighbors[i].append(i + 1)
    return neighbors

def ring_neighbors(num_qudits):
    neighbors = {}
    for i in range(num_qudits):
        # Each qudit is connected to its immediate predecessor and successor
        # The first and last qudits are also connected, forming a ring
        neighbors[i] = [(i - 1) % num_qudits, (i + 1) % num_qudits]
    return neighbors

def grid_neighbors(width, height):
    neighbors = {}
    for row in range(height):
        for col in range(width):
            index = row * width + col
            neighbors[index] = []
            if row > 0:
                neighbors[index].append((row - 1) * width + col)  # North
            if row < height - 1:
                neighbors[index].append((row + 1) * width + col)  # South
            if col > 0:
                neighbors[index].append(row * width + (col - 1))  # West
            if col < width - 1:
                neighbors[index].append(row * width + (col + 1))  # East
    return neighbors

def graph_neighbors(adjacency_matrix):
    neighbors = {}
    num_qudits = len(adjacency_matrix)
    for i in range(num_qudits):
        neighbors[i] = [j for j, connected in enumerate(adjacency_matrix[i]) if connected]
    return neighbors

In [3]:
dims = [3, 3]
freqs = [5.0, 5.0]
anharmonicities = [-0.33, -0.33]
rabi_freqs = [0.1, 0.1]

couplings = {(0, 1): 0.1, (1, 0): 0.1}
noise_couplings = {
    (0, 1): 0.02, 
    (1, 0): 0.06             
}
noise_couplings = {
    (0, 1): 0.05,  # Coupling from qubit 0 to qubit 1
    (1, 0): 0.04,  # Coupling from qubit 1 to qubit 0, demonstrating bidirectional specification
    #(1, 2): 0.03   # Coupling from qubit 1 to qubit 2, specified only in one direction
    # Note: (2, 1) is not specified, implying unidirectional or asymmetric coupling in this case
}
solver_options = None

In [4]:
def get_noise_coupling_neighbours(noise_couplings=None):
    # Initialize neighbor lists for each qubit
    # Assuming the maximum qubit index from the adjacency dictionary
    max_qubit_index = max(max(pair) for pair in noise_couplings.keys())
    neighbors = {i: [] for i in range(max_qubit_index + 1)}

    # Populate the neighbor lists based on the adjacency dictionary
    for (qubit1, qubit2), coupling_strength in noise_couplings.items():
        if coupling_strength > 0:  # Assuming a positive coupling strength indicates a connection
            if qubit1 != qubit2:  # Optional: Exclude self-loops if not needed
                # Check if qubit2 is not already a neighbor of qubit1 before adding
                if qubit2 not in neighbors[qubit1]:
                    neighbors[qubit1].append(qubit2)
                # Check if qubit1 is not already a neighbor of qubit2 before adding
                if qubit1 not in neighbors[qubit2]:
                    neighbors[qubit2].append(qubit1)
            else:
                # Handle self-loops if necessary, add only if not already present
                if qubit2 not in neighbors[qubit1]:
                    neighbors[qubit1].append(qubit2)

    return neighbors


In [5]:
get_noise_coupling_neighbours(noise_couplings=noise_couplings)

{0: [1], 1: [0]}

In [6]:
def custom_backend(
    dims: List[int],
    freqs: List[float],
    anharmonicities: List[float],
    rabi_freqs: List[float],
    couplings: Optional[Dict[Tuple[int, int], float]] = None,
    noise_couplings: Optional[Dict[Tuple[int, int], float]] = None,
    solver_options: Optional[Dict] = None,
):
    """
    Custom backend for the dynamics simulation.

    Args:
        dims: The dimensions of the subsystems.
        freqs: The frequencies of the subsystems.
        anharmonicities: The anharmonicities of the subsystems.
        couplings: The coupling constants between the subsystems.
        noise_couplings: The noise coupling constants between (neighbouring) qubits.
        rabi_freqs: The Rabi frequencies of the subsystems.

    """

    assert (
        len(dims) == len(freqs) == len(anharmonicities) == len(rabi_freqs)
    ), "The number of subsystems, frequencies, and anharmonicities must be equal."
    n_qubits = len(dims)
    print('n_qubits', n_qubits)
    a = [Operator(np.diag(np.sqrt(np.arange(1, dim)), 1)) for dim in dims]
    adag = [Operator(np.diag(np.sqrt(np.arange(1, dim)), -1)) for dim in dims]
    N = [Operator(np.diag(np.arange(dim))) for dim in dims]
    ident = [Operator(np.eye(dim, dtype=complex)) for dim in dims]

    full_ident = ident[0]
    for i in range(1, n_qubits):
        full_ident = full_ident.tensor(ident[i])

    N_ops = N
    a_ops = a
    adag_ops = adag
    for i in range(n_qubits):
        for j in range(n_qubits):
            if j > i:
                N_ops[i] = N_ops[i].expand(ident[j])
                a_ops[i] = a_ops[i].expand(ident[j])
                adag_ops[i] = adag_ops[i].expand(ident[j])
            elif j < i:
                N_ops[i] = N_ops[i].tensor(ident[j])
                a_ops[i] = a_ops[i].tensor(ident[j])
                adag_ops[i] = adag_ops[i].tensor(ident[j])

    static_ham = Operator(
        np.zeros((np.prod(dims), np.prod(dims)), dtype=complex),
        input_dims=tuple(dims),
        output_dims=tuple(dims),
    )

    for i in range(n_qubits):
        static_ham += 2 * np.pi * freqs[i] * N_ops[i] + np.pi * anharmonicities[
            i
        ] * N_ops[i] @ (N_ops[i] - full_ident)
    drive_ops = [
        2 * np.pi * rabi_freqs[i] * (a_ops[i] + adag_ops[i]) for i in range(n_qubits)
    ]

    drive_ops_errorfree = copy.deepcopy(drive_ops)

    ### Noise couplings ###
    if noise_couplings is not None:
        keys = list(noise_couplings.keys())
        # Get the neighbours of each qubit based on the noise couplings dictionary provided
        neighbours = get_noise_coupling_neighbours(noise_couplings)
        print('qubit neighbours', neighbours)
        for i, j in keys:
            if (j, i) not in noise_couplings:
                noise_couplings[(j, i)] = noise_couplings[(i, j)] # Make the noise coupling symmetric if not specified otherwise
        print('noise_couplings', noise_couplings)
        
        # Add spill-over error terms to the drive operators
        errors = {qbit: [] for qbit in range(n_qubits)}
        for qbit in range(n_qubits):
            print('qbit', qbit)
            for neighbour in neighbours[qbit]:
                if (qbit, neighbour) in noise_couplings:
                    print('neighbour', neighbour)
                    noise_strength = noise_couplings[(qbit, neighbour)]
                    print('noise_strength', noise_strength)
                    noise_contribution = noise_strength * (2 * np.pi * rabi_freqs[qbit] * (a_ops[neighbour] + adag_ops[neighbour]))
                    errors[qbit].append(noise_contribution)
                    drive_ops[qbit] += noise_contribution

        drive_ops_error = copy.deepcopy(drive_ops)

        for qbit in range(n_qubits):
            # Sum the error contributions from all neighbors for the current qubit
            total_noise_contribution = sum(errors[qbit]) if qbit in errors else 0

            # Calculate the expected total drive operation with noise for the current qubit
            expected_total_with_noise = drive_ops_errorfree[qbit] + total_noise_contribution

            # Check if the expected total matches the actual total drive operation with noise
            if not np.allclose(expected_total_with_noise, drive_ops_error[qbit]):
                raise ValueError(f'Error in adding noise to drive operators for qubit {qbit}')
        

    channels = {f"d{i}": freqs[i] for i in range(n_qubits)}
    ecr_ops = []
    num_controls = 0
    if couplings is not None:
        keys = list(couplings.keys())
        for i, j in keys:
            couplings[(j, i)] = couplings[(i, j)]
        for (i, j), coupling in couplings.items():
            static_ham += (
                2
                * np.pi
                * coupling
                * (a_ops[i] + adag_ops[i])
                @ (a_ops[j] + adag_ops[j])
            )
            channels[f"u{num_controls}"] = freqs[j]
            num_controls += 1
            ecr_ops.append(drive_ops[i])        

    dt = 2.2222e-10

    jax_solver = JaxSolver(
        static_hamiltonian=static_ham,
        hamiltonian_operators=drive_ops + ecr_ops,
        rotating_frame=static_ham,
        hamiltonian_channels=list(channels.keys()),
        channel_carrier_freqs=channels,
        dt=dt,
        evaluation_mode="dense",
    )

    solver = Solver(
        static_hamiltonian=static_ham,
        hamiltonian_operators=drive_ops + ecr_ops,
        rotating_frame=static_ham,
        hamiltonian_channels=list(channels.keys()),
        channel_carrier_freqs=channels,
        dt=dt,
        evaluation_mode="dense",
    )
    if solver_options is None:
        solver_options = {
            "method": "jax_odeint",
            "atol": 1e-5,
            "rtol": 1e-7,
            "hmax": dt,
        }

    jax_backend = DynamicsBackend(
        solver=jax_solver,
        subsystem_dims=dims,  # for computing measurement data
        solver_options=solver_options,  # to be used every time run is called
    )

    dynamics_backend = DynamicsBackend(
        solver=solver,
        subsystem_dims=dims,  # for computing measurement data
        solver_options=solver_options,  # to be used every time run is called
    )
    return jax_backend, dynamics_backend

In [7]:
jax_backend, dynamics_backend = custom_backend(dims, freqs, anharmonicities, rabi_freqs, couplings, noise_couplings, solver_options)

n_qubits 2
qubit neighbours {0: [1], 1: [0]}
noise_couplings {(0, 1): 0.05, (1, 0): 0.04}
qbit 0
neighbour 1
noise_strength 0.05
qbit 1
neighbour 0
noise_strength 0.04


I0000 00:00:1708485562.100495       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
