In [None]:
from pathlib import Path
from PIL import Image
import numpy as np
from sympy import Matrix

# Source image, relative to the notebook's working directory.
filename = Path("../img/color/Jayhawk_512x512.jpg")

# Preview the input image.
with Image.open(filename, mode="r") as preview:
    display(preview)

# Execution mode used by the simulation cell:
#   True  -> measure all qubits and rebuild the state from sampled counts (shot noise)
#   False -> save and read back the exact statevector (noiseless)
noisy_execution = False

# When True, display_quantum_circuit renders circuit diagrams.
draw_qc = True

In [None]:
def display_quantum_circuit(qc):
    """Render ``qc`` with the matplotlib drawer, if drawing is enabled.

    Reads the notebook-global flag ``draw_qc``; when it is false this is a no-op.
    The circuit is drawn with ``reverse_bits=True``.
    """
    if not draw_qc:
        return
    figure = qc.draw('mpl', reverse_bits=True)
    display(figure)

In [None]:
from thesis.filters import *

stride = 1
num_layers = 2

# Convolution kernel; commented lines are alternative filters from thesis.filters.
kernel_raw = avg_filter(2, dim = 2)
# kernel_raw = sobel_filter(3, axis=0)
# kernel_raw = laplacian()
# kernel_raw = gaussian_blur()

# Zero-pad every kernel axis (on the high side) up to the next power of two,
# presumably so each axis maps onto a whole number of qubits.
pad_after = [2 ** int(np.ceil(np.log2(size))) - size for size in kernel_raw.shape]
npad = tuple((0, extra) for extra in pad_after)
kernel = np.pad(kernel_raw, pad_width=npad, mode="constant", constant_values=0)

Matrix(kernel)

In [None]:
from thesis.quantum import *

# Flatten (with padding) the image into a vector; the remaining values are its dimensions.
psi, *dims = flatten_image(filename, pad=True)
# Normalize the vector, keeping its magnitude so the output can be rescaled later.
psi, mag = normalize(psi, include_magnitude=True)

n_dim = len(dims)

# Give each image dimension its own contiguous block of qubit indices, in order.
wires = []
num_qubits = 0
for dim in dims:
    width = to_qubits(dim)
    wires.append(list(range(num_qubits, num_qubits + width)))
    num_qubits += width

In [None]:
from qiskit import QuantumCircuit
from qiskit.quantum_info.operators import Operator

# One circuit qubit for every image qubit assigned above.
qc = QuantumCircuit(num_qubits)
# Prepare the normalized image vector as the circuit's initial state.
qc.initialize(params=psi)

display_quantum_circuit(qc)

In [None]:
from thesis.quantum.qiskit import *

def conv_pool_hybrid(
    qc: QuantumCircuit, wires, kernel, stride: int = 1, n_dim: int = None
):
    """Append one convolution + pooling layer to ``qc`` and return the new qubit layout.

    ``wires`` is a list of qubit-index lists. The first ``n_dim`` entries belong to
    the image dimensions; the following ``kernel.ndim`` entries (created on the
    first call if missing) collect the qubits assigned to each kernel axis.

    For each kernel axis ``i`` needing ``fq = to_qubits(kernel.shape[i])`` qubits:

    * if image dimension ``i`` still has more than ``fq`` qubits, its first ``fq``
      qubits become controls for ``shift(-stride)`` operations, and the next ``fq``
      qubits are moved from image dimension ``i`` to kernel axis ``i``;
    * otherwise every remaining qubit of image dimension ``i`` is moved to kernel
      axis ``i`` and no shift is applied.

    Finally the normalized, column-major-flattened kernel is loaded with ``c2q``
    onto the last ``fq`` qubits of each kernel axis.

    Args:
        qc: circuit that gates are appended to (modified in place).
        wires: current qubit layout. It is NOT modified; a new layout is returned.
        kernel: numpy array holding the (padded) filter.
        stride: shift distance passed to ``shift`` as ``-stride``.
        n_dim: number of image dimensions. Defaults to ``len(wires)``, which is only
            correct on the first call (before kernel-axis entries exist) — pass it
            explicitly when stacking layers.

    Returns:
        ``(new_wires, kernel_mag)``, where ``kernel_mag`` is the magnitude reported
        by ``normalize`` for the flattened kernel.
    """
    if n_dim is None:
        n_dim = len(wires)

    # Work on a copy so the caller's layout (and its inner lists) is never mutated.
    wires = [list(w) for w in wires]
    # Ensure one entry per kernel axis after the n_dim image entries; on later
    # calls these entries already exist and nothing is added.
    wires.extend([] for _ in range(kernel.ndim + n_dim - len(wires)))

    kernel_shape_q = [to_qubits(filter_size) for filter_size in kernel.shape]

    params, kernel_mag = normalize(kernel.flatten(order="F"), include_magnitude=True)

    ### Shift operation
    for i, fq in enumerate(kernel_shape_q):
        if len(wires[i]) > fq:
            ctrl_qubits, img_qubits = wires[i][:fq], wires[i][fq:]
            wires[i] = ctrl_qubits + img_qubits[fq:]
            wires[n_dim + i] += img_qubits[:fq]

            for j, control_qubit in enumerate(ctrl_qubits):
                shift(qc, -stride, targets=img_qubits[j:], control=control_qubit)
        else:
            # Too few image qubits left: hand them all to this kernel axis.
            wires[n_dim + i] += wires[i]
            wires[i] = []

    ### Filter using C2Q
    # Take the last fq qubits of each kernel axis. Slice from len(w) - fq rather
    # than -fq, because w[-0:] would select the whole list when fq == 0.
    kernel_qubits = [
        q
        for w, fq in zip(wires[n_dim:], kernel_shape_q)
        for q in w[max(len(w) - fq, 0):]
    ]
    c2q(qc, params, targets=kernel_qubits, transpose=True)

    return wires, kernel_mag

In [None]:
# Snapshot the layout before convolution; a later cell compares against it to
# work out how much each image dimension shrank.
# NOTE: re-running this cell alone appends more layers to `qc` and snapshots the
# already-convolved layout — re-run from the circuit-construction cell instead.
wires_old = list(wires)
kernel_mag = 1
for _layer in range(num_layers):
    # Each layer appends gates to qc and hands back the updated qubit layout.
    wires, layer_mag = conv_pool_hybrid(qc, wires, kernel, stride, n_dim)
    # Accumulate the kernel normalization magnitudes across layers.
    kernel_mag = kernel_mag * layer_mag

display_quantum_circuit(qc)

In [None]:
print(wires, n_dim)

### Permutations for qiskit only
# Walk the kernel axes (last axis first) and, within each, its qubits (last
# first), calling rotate(qc, qubit, offset + j), where offset counts every qubit
# in the layout entries before this axis. NOTE(review): presumably this moves
# each kernel qubit to index offset + j — confirm against thesis.quantum.qiskit.rotate.
kernel_shape_q = [to_qubits(filter_size) for filter_size in kernel.shape]
for i in reversed(range(len(kernel_shape_q))):
    offset = sum(len(w) for w in wires[:n_dim + i])
    axis_wires = wires[n_dim + i]
    for j in reversed(range(len(axis_wires))):
        rotate(qc, axis_wires[j], offset + j)

display_quantum_circuit(qc)

In [None]:
from qiskit import Aer, execute
from qiskit.tools import job_monitor

backend = Aer.get_backend('aer_simulator')
shots = backend.configuration().max_shots

# Execute a copy of the circuit: adding measure_all / save_statevector to `qc`
# itself made this cell non-re-runnable (a second run appended a duplicate
# measurement or save instruction to the shared circuit).
qc_exec = qc.copy()
if noisy_execution:
    # Sample measurement counts (shot noise only; no device noise model is set).
    qc_exec.measure_all()
else:
    # Read back the exact simulated statevector.
    qc_exec.save_statevector()

job = execute(qc_exec, backend=backend, shots=shots)
job_monitor(job)

result = job.result()

if noisy_execution:
    counts = result.get_counts(qc_exec)
    psi_out = from_counts(counts, shots=shots, num_qubits=num_qubits)
else:
    psi_out = result.get_statevector(qc_exec).data

In [None]:
# Per-dimension downscaling factor: 2 ** (image qubits removed from that dimension).
# zip stops at the n_dim image entries because wires_old holds only those.
# NOTE: this rebinds `dims`; re-running this cell by itself rescales a second time.
scaling = [2 ** (len(before) - len(after)) for after, before in zip(wires, wires_old)]
dims = [size // factor for size, factor in zip(dims, scaling)]
# Number of amplitudes spanned by the (power-of-two padded) output image.
num_states = np.prod([2 ** to_qubits(size) for size in dims])

In [None]:
# Which block of `num_states` consecutive amplitudes to read back (0 = the first).
i = 0
# Slice the amplitudes themselves. On a numpy array, `psi_out.data` is the raw
# memoryview buffer, which only multiplied correctly via numpy-scalar coercion.
# np.asarray also accepts an array-like (e.g. a Statevector) from from_counts.
amplitudes = np.asarray(psi_out)
img = amplitudes[i * num_states:(i + 1) * num_states]
# Rescale by the magnitudes removed when normalizing the image and the kernels.
# TODO: an extra `* np.sqrt(2**num_ancilla)` factor was sketched here, but
# `num_ancilla` is not defined in this notebook.
norm = mag * kernel_mag
img = construct_img(norm * img, dims)
img.save("output.png")

display(img)