Jisp, a dynamic Pythonic low-level IR
-------------------------------------

Within this notebook we demonstrate the latest feature of the Jax Integration.

We introduce Jisp, a new IR that represents hybrid programs embedded into the Jaxpr IR.

Creating a Jisp program is simple:

In [13]:
from qrisp import *
from jax import make_jaxpr

def test_f(i):
    """Build a small quantum routine on a ``QuantumFloat(i)``.

    An X gate on qubit 0 and a CX gate on qubits 0 and 2 are issued inside
    an ``invert()`` environment. Nothing is returned.
    """
    register = QuantumFloat(i)
    with invert():
        # Qubits are looked up separately for each gate, inside the
        # environment, so every gate gets its own lookups.
        flipped = register[0]
        x(flipped)
        first, third = register[0], register[2]
        cx(first, third)

# Trace test_f with the example argument 5 to obtain its Jisp program
# (a jaxpr extended with quantum primitives).
jisp_program = make_jaxpr(test_f)(5)


# First print: the raw program, where the invert() block shows up as a pair
# of q_env enter/exit equations.
print(jisp_program)
# Second print: the result of collect_environments on the inner jaxpr.
# NOTE(review): judging by the recorded output, this folds the enter/exit
# pair into a single q_env equation carrying a nested jaxpr -- confirm.
print(collect_environments(jisp_program.jaxpr))
# Disabled: conversion of the program to a quantum circuit via to_qc.
#print(to_qc(jisp_program)(5)[0])

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:QuantumCircuit[39m = qdef 
    c[35m:QuantumCircuit[39m d[35m:QubitArray[39m = create_qubits b a
    e[35m:QuantumCircuit[39m = q_env[stage=enter type=inversionenvironment] c
    f[35m:Qubit[39m = get_qubit d 0
    g[35m:QuantumCircuit[39m = x e f
    h[35m:Qubit[39m = get_qubit d 0
    i[35m:Qubit[39m = get_qubit d 2
    j[35m:QuantumCircuit[39m = cx g h i
    _[35m:QuantumCircuit[39m = q_env[stage=exit type=inversionenvironment] j
  [34m[22m[1min [39m[22m[22m() }
{ [34m[22m[1mlambda [39m[22m[22m; a[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:QuantumCircuit[39m = qdef 
    c[35m:QuantumCircuit[39m d[35m:QubitArray[39m = create_qubits b a
    _[35m:QuantumCircuit[39m = q_env[
      jaxpr={ [34m[22m[1mlambda [39m[22m[22md[35m:QubitArray[39m; e[35m:QuantumCircuit[39m. [34m[22m[1mlet
          [39m[22m[22mf[35m:Qubit

Jisp programs can be executed with the Jisp interpreter:

In [2]:
# Wrap the traced program in a callable using the Jisp interpreter.
# jisp_program is whatever an earlier cell last assigned to that name.
compiled_test_f = jisp_interpreter(jisp_program)

# Invoke the callable with a concrete argument and show what it returns.
print(compiled_test_f(5))

5                                                                                    [2K


One of the most powerful features of this IR is that it is fully dynamic, allowing many functions to be cached and reused.

In [4]:
import time

@qache
def inner_function(qv, i):
    """Apply CX to qubits 0 and 1 of ``qv``, then H to qubit ``i``.

    The one-second sleep stands in for an expensive compilation step.
    """
    first, second = qv[0], qv[1]
    cx(first, second)
    chosen = qv[i]
    h(chosen)
    # Complicated compilation, that takes a lot of time
    time.sleep(1)

def outer_function():
    """Call ``inner_function`` three times on a ``QuantumFloat(5)`` and measure it.

    The calls use i = 0, 1 and 2, in that order; the measurement result is
    returned.
    """
    register = QuantumFloat(5)

    for index in range(3):
        inner_function(register, index)

    return measure(register)

# Time how long tracing outer_function takes. inner_function sleeps for one
# second per execution of its body, so the elapsed time indicates how often
# the body actually ran (the recorded run printed roughly 1 s for three calls).
t0 = time.time()
jisp_program = make_jaxpr(outer_function)()
print(time.time()- t0)

# Show the traced program; in the recorded output inner_function appears once
# as a shared sub-jaxpr referenced by three pjit equations.
print(jisp_program)


1.0111033916473389
let inner_function = { [34m[22m[1mlambda [39m[22m[22m; a[35m:QuantumCircuit[39m b[35m:QubitArray[39m c[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:Qubit[39m = get_qubit b 0
    e[35m:Qubit[39m = get_qubit b 1
    f[35m:QuantumCircuit[39m = cx a d e
    g[35m:Qubit[39m = get_qubit b c
    h[35m:QuantumCircuit[39m = h f g
  [34m[22m[1min [39m[22m[22m(h,) } in
{ [34m[22m[1mlambda [39m[22m[22m; . [34m[22m[1mlet
    [39m[22m[22mi[35m:QuantumCircuit[39m = qdef 
    j[35m:QuantumCircuit[39m k[35m:QubitArray[39m = create_qubits i 5
    l[35m:QuantumCircuit[39m = pjit[name=inner_function jaxpr=inner_function] j k 0
    m[35m:QuantumCircuit[39m = pjit[name=inner_function jaxpr=inner_function] l k 1
    n[35m:QuantumCircuit[39m = pjit[name=inner_function jaxpr=inner_function] m k 2
    _[35m:QuantumCircuit[39m o[35m:i32[][39m = measure n k
    p[35m:i32[][39m = mul o 1
  [34m[22m[1min [39m[22m[22m(p,)

Furthermore, Jisp programs are seamlessly hybrid:

In [15]:

def test_f():
    """Measure a ``QuantumFloat(6, -2)`` set to 5.5 and add 10 to the outcome.

    The addition is a classical operation on the measurement result, so the
    function mixes quantum and classical steps.
    """
    qf = QuantumFloat(6, -2)
    qf[:] = 5.5

    outcome = measure(qf)
    return outcome + 10

# Trace the hybrid function (it takes no arguments) into a Jisp program.
jisp_program = make_jaxpr(test_f)()

# Execute the program through the Jisp interpreter and print its result,
# then print the program itself for inspection.
print(jisp_interpreter(jisp_program)())
print(jisp_program)

25.5                                                                                 [2K
{ [34m[22m[1mlambda [39m[22m[22m; . [34m[22m[1mlet
    [39m[22m[22ma[35m:QuantumCircuit[39m = qdef 
    b[35m:QuantumCircuit[39m c[35m:QubitArray[39m = create_qubits a 6
    d[35m:i32[][39m = get_size c
    _[35m:i32[][39m _[35m:i32[][39m _[35m:QubitArray[39m e[35m:QuantumCircuit[39m = while[
      body_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; f[35m:i32[][39m g[35m:i32[][39m h[35m:QubitArray[39m i[35m:QuantumCircuit[39m. [34m[22m[1mlet
          [39m[22m[22mj[35m:i32[][39m = add f 1
          k[35m:i32[][39m = shift_left 1 f
          l[35m:i32[][39m = and k 22
          m[35m:Qubit[39m = get_qubit h f
          n[35m:bool[][39m = ne l 0
          o[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] n
          p[35m:QuantumCircuit[39m = cond[
            branches=(
              { [34m[22m[1mlambda [39m[22m[22m;