### **INTEGRATION CHECKS**

In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrand
import pytest

**WaveFunction**

In [2]:
from Qubitly.states import WaveFunction

# Explicit amplitudes + qubit count; deliberately left unnormalized (norm = 2).
wf0 = WaveFunction(jnp.array([1, 1, 1, 1], dtype=jnp.complex64), 2)
print(wf0.norm, wf0.amplitudes, wf0, sep="\n")

# Basis states built from bit strings, then an equal-weight and a (1, 2)-weighted
# superposition of them; the printed superpositions come out normalized.
wf1 = WaveFunction.from_string('11')
wf2 = WaveFunction.from_string('01')
wf3 = WaveFunction.from_superposition([wf1, wf2])
wf4 = WaveFunction.from_superposition([wf1, wf2], [1, 2])
print(wf1, wf2, wf3, wf4, sep="\n")

2.0
[1.+0.j 1.+0.j 1.+0.j 1.+0.j]
WaveFunction: [1.+0.j 1.+0.j 1.+0.j 1.+0.j]
WaveFunction: [0.+0.j 0.+0.j 0.+0.j 1.+0.j]
WaveFunction: [0.+0.j 1.+0.j 0.+0.j 0.+0.j]
WaveFunction: [0.        +0.j 0.70710677+0.j 0.        +0.j 0.70710677+0.j]
WaveFunction: [0.       +0.j 0.8944272+0.j 0.       +0.j 0.4472136+0.j]


**Operator**

In [3]:
from Qubitly.gates import *

_00 = WaveFunction.from_string('00')
# Built from a plain list without n_qubits (presumably inferred from the length);
# it is never printed -- constructing it without error is the check.
_balanced = WaveFunction(amplitudes=[1, 1, 1, 1])
# |00> + |11>, intentionally left unnormalized.
_bell0 = WaveFunction(n_qubits=2, amplitudes=[1, 0, 0, 1])

# Pauli gates acting on |00>; qubit 0 is the least-significant bit of the amplitude index.
print(
    _00,
    SigmaX(0) * _00,
    SigmaX(1) * _00,  # TODO: what should SigmaX(2) do on a 2-qubit state?
    SigmaX(1) * (SigmaX(0) * _00),
    sep="\n",
)
print()

print(_bell0, SigmaX(0) * _bell0, SigmaZ(0) * _bell0, sep="\n")
print()

# Products of operators on the same site give X @ Y = iZ, whatever the site index ...
print(
    SimpleSigmaX().apply(SimpleSigmaY()).matrix,
    (SimpleSigmaX() * SimpleSigmaY()).matrix,
    (SigmaX(0) * SigmaY(0)).simple_op.matrix,
    (SigmaX(5) * SigmaY(5)).simple_op.matrix,
    sep="\n",
)
# ... while reducing a product of operators on different sites to a single-site
# matrix fails with an AssertionError.
with pytest.raises(AssertionError):
    print((SigmaX(0)*SigmaY(1)).simple_op.matrix)

WaveFunction: [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
WaveFunction: [0.+0.j 1.+0.j 0.+0.j 0.+0.j]
WaveFunction: [0.+0.j 0.+0.j 1.+0.j 0.+0.j]
WaveFunction: [0.+0.j 0.+0.j 0.+0.j 1.+0.j]

WaveFunction: [1.+0.j 0.+0.j 0.+0.j 1.+0.j]
WaveFunction: [0.+0.j 1.+0.j 1.+0.j 0.+0.j]
WaveFunction: [ 1.+0.j  0.+0.j  0.+0.j -1.+0.j]

[[0.+1.j 0.+0.j]
 [0.+0.j 0.-1.j]]
[[0.+1.j 0.+0.j]
 [0.+0.j 0.-1.j]]
[[0.+1.j 0.+0.j]
 [0.+0.j 0.-1.j]]
[[0.+1.j 0.+0.j]
 [0.+0.j 0.-1.j]]


In [4]:
# Two-site operator: SimpleSigmaZ on qubit 0, SimpleSigmaX on qubit 1.
O = Operator([SimpleSigmaZ(), SimpleSigmaX()], [0,1])
print(O.simple_op.matrix)
psi = WaveFunction.from_string('01')
print(psi)
print(O.simple_op.matrix @ psi.amplitudes)
print(O * psi)
# NOTE that the two lines above print the same thing, since qubit 0 is the rightmost
# Kronecker factor: matrix = matrix1 (x) matrix0 (this depends on how the function
# _apply_matrix_to_two_sites() is defined) and state = state1 (x) state0.
# Pin both facts, so that a change of convention fails loudly instead of silently.
assert jnp.allclose(
    O.simple_op.matrix,
    # X (qubit 1) (x) Z (qubit 0), written with literal Pauli matrices
    jnp.kron(jnp.array([[0, 1], [1, 0]]), jnp.array([[1, 0], [0, -1]])),
)
assert jnp.allclose(O.simple_op.matrix @ psi.amplitudes, (O * psi).amplitudes)

[[ 0.+0.j  0.+0.j  1.+0.j  0.+0.j]
 [ 0.+0.j -0.+0.j  0.+0.j -1.+0.j]
 [ 1.+0.j  0.+0.j  0.+0.j  0.+0.j]
 [ 0.+0.j -1.+0.j  0.+0.j -0.+0.j]]
WaveFunction: [0.+0.j 1.+0.j 0.+0.j 0.+0.j]
[ 0.+0.j  0.+0.j  0.+0.j -1.+0.j]
WaveFunction: [ 0.+0.j  0.+0.j  0.+0.j -1.+0.j]


**QuantumCircuit**

In [5]:
from Qubitly.states import CompBasisMeasurement
from Qubitly.gates import Hadamard, CNOT
from Qubitly.circuits import QuantumCircuit, CircuitLayer, CircuitError

In [6]:
qc = QuantumCircuit(
    CircuitLayer(Hadamard(0)),
    CircuitLayer(CNOT(control=0, target=1)),
    CompBasisMeasurement("m0", 0),
    CompBasisMeasurement("m1", 1),
)
# Wrap once and reuse, instead of calling jax.jit(qc) again for every run.
qc_jit = jax.jit(qc)

_00 = WaveFunction.from_string('00')

key = jrand.key(5)
# H(0) followed by CNOT(0 -> 1) entangles the qubits, but both are then measured:
# the printed WaveFunction is the post-measurement state (|00> for this key).
_measured, user_vars = qc_jit(_00, key)
print(_measured)

# The same result, accessed through its `.wf` attribute instead of unpacking.
result = qc_jit(_00, key)
print(result.wf)

# The circuit's own jit_call entry point, unpacked and by attribute.
_measured_jit, user_vars_jit = qc.jit_call(_00, key)
print(_measured_jit)

result_jit = qc.jit_call(_00, key)
print(result_jit.wf)

WaveFunction: [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
WaveFunction: [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
WaveFunction: [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
WaveFunction: [1.+0.j 0.+0.j 0.+0.j 0.+0.j]


In [7]:
import pytest

# Each layer below is invalid and must be rejected with a CircuitError.
# The factories delay construction, so the error (whether it comes from the gate,
# the layer or the circuit) is raised inside pytest.raises.
_invalid_layer_factories = {
    "same qubit used twice in one layer": lambda: CircuitLayer(Hadamard(0), Hadamard(0)),
    "gates overlapping on qubit 1": lambda: CircuitLayer(Hadamard(1), CNOT(control=1, target=0)),
    "CNOT with control == target": lambda: CircuitLayer(Hadamard(0), CNOT(control=1, target=1)),
}
for _reason, _make_layer in _invalid_layer_factories.items():
    with pytest.raises(CircuitError):
        QuantumCircuit(_make_layer())

**QuantumTeleportation**

In [8]:
import Qubitly.examples as examples

_input = WaveFunction(jnp.array([1, 1], dtype=jnp.complex64))
_input.normalize()  # in place -> (|0> + |1>)/sqrt(2)
print(_input)

key = jrand.key(2)
# Keep the 1-qubit `_input` intact, so the teleported qubit can be checked against it below.
# NOTE(review): assumes prepare_for_teleportation returns a new object rather than mutating its argument.
_prepared = examples.prepare_for_teleportation(_input)

# 1) Eager call.
_output, user_vars = examples.QuantumTeleportation(_prepared, key)
_output = examples.extract_teleported_qubit(_output, user_vars)
print(_output)

# 2) The circuit's own jit_call.
_output_jit_call, user_vars_jit = examples.QuantumTeleportation.jit_call(_prepared, key)
_output_jit_call = examples.extract_teleported_qubit(_output_jit_call, user_vars_jit)
print(_output_jit_call)

# 3) Wrapping the circuit with jax.jit directly.
QT_jit = jax.jit(examples.QuantumTeleportation)
_output_jit, user_vars_jit = QT_jit(_prepared, key)
_output_jit = examples.extract_teleported_qubit(_output_jit, user_vars_jit)
print(_output_jit)

# Teleportation must reproduce the input state up to a global phase,
# i.e. the fidelity |<input|output>| must be 1.
for _teleported in (_output, _output_jit_call, _output_jit):
    assert jnp.isclose(jnp.abs(jnp.vdot(_input.amplitudes, _teleported.amplitudes)), 1.0, atol=1e-5)

# NOTE: the jitted versions were observed to be ~3 orders of magnitude faster than
# the eager call; time them with %timeit to reproduce.

WaveFunction: [0.70710677+0.j 0.70710677+0.j]
WaveFunction: [0.70710677+0.j 0.70710677+0.j]
WaveFunction: [0.70710677+0.j 0.70710677+0.j]
WaveFunction: [0.70710677+0.j 0.70710677+0.j]


**How many times is the initializer of WaveFunction called inside a QuantumCircuit?**
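Add a debug print to the initializer of WaveFunction (`print(...)` or `jax.debug.print(...)`) to verify the call counts stated in the comments of the following cells. Note that these cells were not executed in the saved notebook (`In [None]`).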

In [None]:
# Single-gate circuit used to count WaveFunction.__init__ calls.
# NOTE(review): unlike the earlier cells, gates are passed without CircuitLayer, and these
# cells were never executed (In [None]) -- confirm QuantumCircuit accepts bare gates.
_00 = WaveFunction.from_string('00')
qc = QuantumCircuit(
    Hadamard(0)
)
qc_jit = jax.jit(qc)

In [None]:
result = qc(_00)
# One WaveFunction.__init__ call for the non-jitted version, as expected.
# NOTE(review): earlier circuits were called as qc(wf, key); confirm the key is optional when the circuit has no measurements.

In [None]:
result = qc_jit(_00)
# Three calls for the first jitted run (copy, build and bind result). In fact, if the circuit performed the same operation but didn't return a WaveFunction, the initializer would be called only twice.
# NOTE I'm not sure that `binding` is the right word -- presumably this is JAX rebuilding the output WaveFunction from its flattened pytree leaves (unflattening); confirm.

In [None]:
result = qc_jit(_00)
# Expect only one call for the jitted version running the second time (bind result): the trace is cached, so only the output WaveFunction is rebuilt.

In [None]:
# For more complex circuits (n_gates gates):
# - non-jitted version: one call for every layer, i.e. n_gates.
# - jitted version, first run: one might expect two calls for every layer and one final call for binding... but intermediate copies are actually avoided, so: one for copying at the beginning, one for each gate, one for binding at the end, i.e. n_gates + 2.
# - jitted version, later runs: one call (binding only).
qc = QuantumCircuit(
    Hadamard(0),
    Hadamard(1),
    CNOT(control=0, target=1)
)

qc_jit = jax.jit(qc)

In [None]:
result = qc(_00) # Three calls: one per gate (non-jitted)

In [None]:
result = qc_jit(_00) # Five calls: 1 copy + 3 gates + 1 bind (first jitted run)

In [None]:
result = qc_jit(_00) # One call: cached trace, only the output is bound

In [None]:
# Five gates; CZ and SigmaX come from the star import of Qubitly.gates.
qc = QuantumCircuit(
    Hadamard(0),
    Hadamard(1),
    CNOT(control=0, target=1),
    CZ(control=1, target=0),
    SigmaX(1),
)

qc_jit = jax.jit(qc)

In [None]:
result = qc(_00) # Five calls: one per gate (non-jitted)

In [None]:
result = qc_jit(_00) # Seven calls: 1 copy + 5 gates + 1 bind (first jitted run)

In [None]:
result = qc_jit(_00) # One call: cached trace, only the output is bound

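In short, the current implementation of `QuantumCircuit` — each layer returning a `WaveFunction` object rather than a bare array — adds an overhead linear in the number of circuit gates. There is one extra `WaveFunction.__init__` call per gate in the non-jitted run and while tracing the jitted one. Once compiled, a jitted call constructs a single `WaveFunction`.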

**Do jitted functions allow raising exceptions in WaveFunction.__init__()?**

In [9]:
# Function that creates a new WaveFunction object with wrong number of qubits
def try_and_change_n_qubits(wf: WaveFunction) -> WaveFunction:
    """Rewrap ``wf.amplitudes`` claiming one qubit more than ``wf`` has.

    The declared qubit count no longer matches the amplitudes, so the
    WaveFunction constructor is expected to raise ValueError (checked below).
    """
    return WaveFunction(wf.amplitudes, wf.n_qubits + 1)


try_and_change_n_qubits_jit = jax.jit(try_and_change_n_qubits)

In [10]:
wf = WaveFunction.from_string("00")

# Eager call: the WaveFunction constructor rejects the inconsistent n_qubits.
with pytest.raises(ValueError):
    new_wf = try_and_change_n_qubits(wf)

# Jitted call: the same ValueError is raised, since the Python constructor runs while tracing.
with pytest.raises(ValueError):
    new_wf = try_and_change_n_qubits_jit(wf)

The jitted version also raises. Note, however, that the error inside this function does not depend on the input it receives, so the function can never be jit-compiled successfully.
We now try a function that isn't inherently wrong: we jit-compile it and then check whether a wrong input triggers the exception.

In [11]:
def assert_2_qubit_wf(wf: WaveFunction):
    """Raise AssertionError unless ``wf`` has exactly two qubits; returns None.

    Under jax.jit the plain ``assert`` still works (it raises AssertionError, not a
    tracer error), so ``n_qubits`` stays a concrete Python value while tracing --
    presumably because it is tied to the static amplitude shape.
    """
    assert wf.n_qubits == 2


assert_2_qubit_wf_jit = jax.jit(assert_2_qubit_wf)

In [12]:
# wf_1 has one qubit, wf_2 has two.
wf_1 = WaveFunction.from_string("0")
wf_2 = WaveFunction.from_string("00")

with pytest.raises(AssertionError):
    assert_2_qubit_wf(wf_1)

with pytest.raises(AssertionError):
    assert_2_qubit_wf_jit(wf_1)
# This raises too, to no surprise: the assert runs while tracing.

assert_2_qubit_wf_jit(wf_2)
# This works fine as expected, and compiles the function for 2-qubit inputs.

# Calling it again with the 1-qubit input: a different input shape forces a re-trace,
# so the assert runs (and raises) again.
with pytest.raises(AssertionError):
    assert_2_qubit_wf_jit(wf_1)

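The exception is raised in the last case too: calling the function on an input of a different shape triggers re-tracing (and recompilation), so the Python `assert` runs again. This example is rather trivial, though. `n_qubits` must match the shape of the amplitudes (see the `ValueError` above), and shapes are static under `jit`, so the check never depends on traced values.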

In [13]:
def assert_is_00(wf: WaveFunction):
    """Attempt to raise ValueError unless ``wf`` is |00>, via jax.lax.cond.

    This demonstrates a pitfall: lax.cond traces *both* branches, so the Python
    ``raise`` in the false branch fires at trace time for every input (even |00>)
    and even without jit. For value-dependent runtime errors, consider
    ``jax.experimental.checkify`` instead.
    """
    target = jnp.array([1, 0, 0, 0], dtype=jnp.complex64)

    def _keep(operand):
        # True branch: nothing to do, pass the (None) operand through.
        return operand

    def raise_value_error(_):
        raise ValueError()

    jax.lax.cond(jnp.allclose(wf.amplitudes, target), _keep, raise_value_error, operand=None)


assert_is_00_jit = jax.jit(assert_is_00)

In [14]:
_00 = WaveFunction.from_string("00")
_11 = WaveFunction.from_string("11")

# assert_is_00 raises for EVERY input, eager or jitted: the raising branch of
# jax.lax.cond is traced even when the predicate is True (as it is for _00).
with pytest.raises(ValueError):
    assert_is_00(_00)

with pytest.raises(ValueError):
    assert_is_00_jit(_00)

with pytest.raises(ValueError):
    assert_is_00(_11)

with pytest.raises(ValueError):
    assert_is_00_jit(_11)

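JAX traces every branch of `jax.lax.cond` (unless JIT is disabled with `jax.disable_jit()`), so a Python exception raised in any branch fires at trace time. This happens for every input, even when the function is not jitted. Value-dependent runtime checks need a dedicated mechanism such as `jax.experimental.checkify`.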