# Introduction

In this notebook we explore the use of the swap test to approximate an unknown quantum state. We begin with a brief discussion of single-qubit state creation. We then implement the swap test for approximating a single-qubit state. We finish with an extension of the swap test to product states. 

Note: Qiskit throws deprecation warnings all over the place. I'll do my best to suppress them, but I'm sure some will slip through. My apologies. 

# Imports

In [None]:
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
from math import pi
from random import uniform

from qiskit import Aer, QuantumCircuit, execute
from qutip import Bloch, Qobj

import numpy as np

# Part 1 - Generating quantum states

To begin, we verify that Qiskit's quantum gate set can generate any single-qubit state starting from |0>. We do this with a u-gate, specified by angles theta and phi. 

With its third angle (lambda) set to 0, the u-gate can be further broken down into a theta rotation about the y-axis of the Bloch sphere followed by a phi rotation about the z-axis, up to a global phase.
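The decomposition can be checked numerically without Qiskit. The sketch below writes out the u-gate matrix from its standard definition and compares U(theta, phi, 0)|0> against Rz(phi) Ry(theta)|0>. The helpers `u_gate`, `ry` and `rz` are hypothetical names for this illustration, not Qiskit API.

```python
import numpy as np

def u_gate(theta, phi, lam):
    """Matrix of the U(theta, phi, lambda) gate, written out from its standard definition."""
    return np.array([
        [np.cos(theta / 2), -np.exp(1j * lam) * np.sin(theta / 2)],
        [np.exp(1j * phi) * np.sin(theta / 2), np.exp(1j * (phi + lam)) * np.cos(theta / 2)],
    ])

def ry(theta):
    """Rotation by theta about the Bloch-sphere y-axis."""
    return np.array([[np.cos(theta / 2), -np.sin(theta / 2)],
                     [np.sin(theta / 2),  np.cos(theta / 2)]])

def rz(phi):
    """Rotation by phi about the Bloch-sphere z-axis."""
    return np.diag([np.exp(-1j * phi / 2), np.exp(1j * phi / 2)])

theta, phi = 1.1, 2.3
ket0 = np.array([1, 0])

# U(theta, phi, 0)|0> = cos(theta/2)|0> + e^{i phi} sin(theta/2)|1>
state_u = u_gate(theta, phi, 0) @ ket0
expected = np.array([np.cos(theta / 2), np.exp(1j * phi) * np.sin(theta / 2)])

# Rz(phi) Ry(theta)|0> equals U(theta, phi, 0)|0> once multiplied by the global phase e^{i phi/2}
state_rot = rz(phi) @ ry(theta) @ ket0
global_phase = np.exp(1j * phi / 2)

print(np.allclose(state_u, expected))                  # True
print(np.allclose(state_u, global_phase * state_rot))  # True
```

Because theta ranges over [0, pi] and phi over [0, 2 pi), the amplitudes cos(theta/2) and e^{i phi} sin(theta/2) reach every point of the Bloch sphere, which is what the animation below checks visually.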

In [None]:
### Verification - Does the u-gate cover the Bloch sphere? ###
# Code source: http://qutip.org/docs/latest/guide/guide-bloch.html#animating-with-the-bloch-sphere
import matplotlib.animation as animation
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.pyplot import figure
%matplotlib inline

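fig = figure()
# Recent Matplotlib versions no longer add a bare Axes3D(fig) to the figure,
# so create the 3D axes through add_subplot instead.
ax = fig.add_subplot(111, projection='3d', azim=-40, elev=30)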
sphere = Bloch(axes=ax)
sphere.point_color = ['b']
sphere.point_marker = ['o']
sphere.point_size = [20, 27, 30, 40]
sphere.zlpos = [1.5, -1.5]
sphere.xlpos = [1.5, -1.5]

   
def animate(i, num_iters, backend):
    states = []
    for j in range(num_iters):
        qc = QuantumCircuit(1)
        qc.u(pi*(i/num_iters), 2*pi*(j/num_iters), 0, 0)
        states.append(Qobj(execute(qc, backend).result().get_statevector()))
    sphere.add_states(states, kind='point')
    sphere.make_sphere()
    return ax

num_iters = 30
ani = animation.FuncAnimation(fig, animate, fargs=(num_iters, Aer.get_backend('statevector_simulator')), frames=np.arange(num_iters))
ani.save('bloch_sphere.mp4', fps=15)

Alright, cool, just wanted to make sure. Next up, let's introduce the swap test circuit.

# Part 2 - Swap test implementation, state approximation

Below is the circuit diagram for state creation along with the swap test proper.

In [None]:
### Swap test ###

from qiskit.circuit import Parameter

theta1 = Parameter('t1')
phi1 = Parameter('p1')
theta2 = Parameter('t2')
phi2 = Parameter('p2')

qc = QuantumCircuit(3, 1) # 1 classical register to store the test result
qc.u(theta1, phi1, 0, 1)
qc.u(theta2, phi2, 0, 2)
qc.barrier()
qc.h(0)
qc.cswap(0, 1, 2)
qc.h(0)
qc.measure(0, 0)
qc.draw()

This circuit is parametrized by four angles: two describe the target state, the other two the initial state. In our optimization procedure, we fix the target state parameters at the start and vary the initial state parameters until the squared inner product |<ψ|φ>|^2 (the quantity the swap test actually estimates) is close to 1. 

In [None]:
def angles2amplitude(state):
    ''' Convert the (theta, phi) representation to the amplitude representation. '''
    qc = QuantumCircuit(1)
    qc.u(state[0], state[1], 0, 0)
    backend = Aer.get_backend('statevector_simulator')
    job = execute(qc, backend)
    result = job.result()
    return result.get_statevector()

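def amplitude2angles(state):
    ''' Convert the amplitude representation [a, b] back to (theta, phi).
    The global phase is discarded, so phi is the phase of b relative to a. '''
    a, b = np.asarray(state, dtype=complex)
    theta = 2 * np.arccos(np.clip(np.abs(a), 0.0, 1.0))
    phi = np.mod(np.angle(b) - np.angle(a), 2 * pi)
    return [theta, phi]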

def swap_test(state1, state2, backend, num_shots=1024):
    qc = QuantumCircuit(3, 1)
    qc.h(0)
    qc.u(state1[0], state1[1], 0, 1)
    qc.u(state2[0], state2[1], 0, 2)
    qc.cswap(0, 1, 2)
    qc.h(0)
    qc.measure(0, 0)
        
    job = execute(qc, backend, shots=num_shots)
    counts = job.result().get_counts()
    if '1' in counts:
        return counts['1']
    else:
        return 0
    
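def inner_product_approx(state1, state2, backend, num_shots=1024):
    ''' Estimate |<state1|state2>|^2 for states given by their Bloch angles.
    The ancilla reads 1 with probability (1 - |<state1|state2>|^2) / 2, so the
    squared inner product is 1 - 2 * (fraction of shots that read 1). '''
    counts = swap_test(state1, state2, backend, num_shots)
    return 1 - ((2 / num_shots)*counts)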

If we were only comparing |0>'s and |1>'s, we could get away with a simpler use of the swap test, declaring the states distinct the moment a 1 appears in our measurement results. For comparing arbitrary qubits, we want to have a little more direction for our optimizer. If the states are different, how different are they?
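The swap test answers that question through its outcome statistics: for states |ψ> and |φ>, the ancilla reads 1 with probability (1 - |<ψ|φ>|^2)/2, so |<ψ|φ>|^2 = 1 - 2 P(1). That is the rescaling `inner_product_approx` applies to the fraction of 1s. The sketch below is a small NumPy statevector simulation of the 3-qubit circuit (no Qiskit; `bloch_state` and `swap_test_p1` are hypothetical helpers, and the ancilla is placed as the most significant qubit purely for convenience).

```python
import numpy as np

def bloch_state(theta, phi):
    """Single-qubit state U(theta, phi, 0)|0> as an amplitude vector."""
    return np.array([np.cos(theta / 2), np.exp(1j * phi) * np.sin(theta / 2)])

def swap_test_p1(a, b):
    """Exact probability of reading 1 on the ancilla of the swap test.

    Qubit order in the 8-dim vector is (ancilla, a, b), ancilla most significant.
    """
    ket0 = np.array([1, 0], dtype=complex)
    state = np.kron(ket0, np.kron(a, b))
    h = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
    h_anc = np.kron(h, np.eye(4))  # Hadamard on the ancilla only
    # CSWAP: identity when the ancilla is 0, SWAP of the two data qubits when it is 1
    swap = np.eye(4)[[0, 2, 1, 3]]
    cswap = np.block([[np.eye(4), np.zeros((4, 4))],
                      [np.zeros((4, 4)), swap]])
    state = h_anc @ cswap @ h_anc @ state
    # Ancilla = 1 corresponds to the last four amplitudes
    return float(np.sum(np.abs(state[4:]) ** 2))

a = bloch_state(1.1, 2.3)
b = bloch_state(0.4, 5.0)
fidelity = abs(np.vdot(a, b)) ** 2
p1 = swap_test_p1(a, b)
print(np.isclose(p1, (1 - fidelity) / 2))  # True
print(np.isclose(1 - 2 * p1, fidelity))    # True
```

Identical states never produce a 1, while orthogonal states (such as |0> and |1>) produce a 1 half the time, so P(1) varies smoothly between those extremes and gives the optimizer a sense of "how different".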

We use the `inner_product_approx` function to estimate the inner product from repeated measurements of the swap circuit, but how many repeats are necessary to get a good enough result? Below we set up a convergence study for this function that can be run on a variety of states. As a future improvement, I would like to give more careful coverage of the possible states, but I think this is a good start. 
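A back-of-the-envelope answer comes from treating each shot as a Bernoulli trial that reads 1 with probability p1 = (1 - F)/2, where F = |<ψ|φ>|^2. The estimate 1 - 2k/N then has standard deviation 2 sqrt(p1 (1 - p1) / N), which is at most 1/sqrt(N) (reached when F = 0). The sketch below assumes an ideal, noiseless device, so shot noise is the only error; `estimator_std` and `shots_for_precision` are hypothetical helpers.

```python
import numpy as np

def estimator_std(fidelity, num_shots):
    """Standard deviation of 1 - 2k/N, where k ~ Binomial(N, p1) and p1 = (1 - F) / 2."""
    p1 = (1 - fidelity) / 2
    return 2 * np.sqrt(p1 * (1 - p1) / num_shots)

def shots_for_precision(target_std, fidelity=0.0):
    """Smallest N whose standard deviation is at most target_std (fidelity 0 is the worst case)."""
    p1 = (1 - fidelity) / 2
    return int(np.ceil(4 * p1 * (1 - p1) / target_std ** 2))

print(estimator_std(0.0, 1024))   # 0.03125, i.e. 1/sqrt(1024)
print(shots_for_precision(0.01))  # 10000

# Monte Carlo check: binomial sampling stands in for running the circuit
rng = np.random.default_rng(0)
fidelity, num_shots = 0.6, 1024
k = rng.binomial(num_shots, (1 - fidelity) / 2, size=20_000)
estimates = 1 - 2 * k / num_shots
print(estimates.std(), estimator_std(fidelity, num_shots))  # the two should agree closely
```

So the 1024 shots used later give roughly two-decimal precision at worst, and each extra decimal digit costs about 100 times more shots.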

In [None]:
warnings.filterwarnings('ignore', category=DeprecationWarning)
import matplotlib.pyplot as plt

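def inner_product_exact(state1, state2):
    ''' Exact |<state1|state2>|^2 from the simulated amplitudes. '''
    state1 = np.asarray(angles2amplitude(state1))
    state2 = np.asarray(angles2amplitude(state2))
    return np.abs(np.vdot(state1, state2))**2
    
def swap_test_convergence(state1, state2, backend, max_shots=2**11, num_samples=100, num_trials=10):
    ''' Repeat inner_product_approx num_trials times at each of a log-spaced set of shot counts. '''
    ipe = inner_product_exact(state1, state2)
    shots = []
    # Shot counts must be integers; rounding the log-spaced grid creates duplicates, so drop them
    shot_counts = np.unique(np.rint(np.geomspace(1, max_shots, num_samples)).astype(int))
    for n in shot_counts:
        trial_results = [inner_product_approx(state1, state2, backend, num_shots=int(n))
                         for _ in range(num_trials)]
        shots.append(trial_results)
    return ipe, shot_counts, np.array(shots)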
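The cell above only collects the data. One way to turn the returned tuple into a convergence plot is sketched below: compute the mean absolute error and spread of the trials at each shot count, then plot the error on log-log axes next to a 1/sqrt(N) reference line (the worst-case standard deviation of an estimate of the form 1 - 2k/N). `convergence_summary` and `plot_convergence` are hypothetical helpers, and the tuple layout (exact value, shot counts, trials array of shape shot-counts by trials) is taken from `swap_test_convergence`.

```python
import numpy as np

def convergence_summary(ipe, shot_counts, shots):
    """Mean absolute error and spread of the swap-test estimate at each shot count.

    shots has shape (len(shot_counts), num_trials), as returned by swap_test_convergence.
    """
    shots = np.asarray(shots, dtype=float)
    mean_abs_error = np.mean(np.abs(shots - ipe), axis=1)
    spread = np.std(shots, axis=1)
    return np.asarray(shot_counts), mean_abs_error, spread

def plot_convergence(ipe, shot_counts, shots):
    """Log-log plot of the mean absolute error against the number of shots."""
    import matplotlib.pyplot as plt  # imported here so the summary works without matplotlib
    n, err, _ = convergence_summary(ipe, shot_counts, shots)
    fig, ax = plt.subplots()
    ax.loglog(n, err, 'o-', label='mean |estimate - exact|')
    ax.loglog(n, 1 / np.sqrt(n), '--', label='1/sqrt(N) reference')
    ax.set_xlabel('shots per estimate')
    ax.set_ylabel('error')
    ax.legend()
    return fig
```

In the notebook this would be called as `plot_convergence(*swap_test_convergence(state1, state2, Aer.get_backend('qasm_simulator')))` for a pair of Bloch-angle states.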

Next comes the challenge. Given a randomly generated single-qubit quantum state, I want to replicate it in a fresh state of my own. I'm assuming I have a lot of copies of this unknown state on hand, because the swap test needs to be run more than once to get any degree of certainty. 

By making multiple runs of the swap test we can approximate the inner product of the two states. This inner product will be used as a cost function for our optimizer, which will tune the (theta, phi) parameters of our input state until we reach the target state. 

We're making the assumption that we don't know the parameters of the generated quantum state. If we did, we would have a closed-form expression for the cost function and could use some form of gradient descent to find the desired parameters. Without them, we only have noisy estimates of the cost from repeated swap tests, so we have to resort to derivative-free global minimisers (like brute-force grid search or annealing). With the framework in place below we'll be in a position to tune our optimizer and our hyperparameters. 
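For reference, the closed form is short. With Bloch vectors n1 and n2, |<ψ1|ψ2>|^2 = (1 + n1·n2)/2 = cos^2(Δ/2), where Δ is the angle between the two vectors, so the noise-free cost landscape has a single maximum where the Bloch vectors coincide. The sketch below checks this identity against a direct amplitude calculation; `bloch_vector`, `fidelity_closed_form` and `fidelity_from_amplitudes` are hypothetical helpers.

```python
import numpy as np

def bloch_vector(theta, phi):
    """Unit Bloch vector of the state U(theta, phi, 0)|0>."""
    return np.array([np.sin(theta) * np.cos(phi),
                     np.sin(theta) * np.sin(phi),
                     np.cos(theta)])

def fidelity_closed_form(state1, state2):
    """|<psi1|psi2>|^2 = (1 + n1 . n2) / 2 for Bloch angles state = (theta, phi)."""
    return (1 + bloch_vector(*state1) @ bloch_vector(*state2)) / 2

def fidelity_from_amplitudes(state1, state2):
    """The same quantity computed directly from the amplitudes, for comparison."""
    def ket(theta, phi):
        return np.array([np.cos(theta / 2), np.exp(1j * phi) * np.sin(theta / 2)])
    return abs(np.vdot(ket(*state1), ket(*state2))) ** 2

s1, s2 = (1.1, 2.3), (0.4, 5.0)
print(np.isclose(fidelity_closed_form(s1, s2), fidelity_from_amplitudes(s1, s2)))  # True
print(np.isclose(fidelity_closed_form(s1, s1), 1.0))                               # True
```

The swap test samples exactly this quantity, with binomial shot noise on top, which is why the optimizers below only ever see a noisy version of it.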

In [None]:
### Approximating a single-qubit quantum state

from scipy.optimize import brute, differential_evolution, shgo, dual_annealing
from qiskit.visualization import plot_bloch_vector, plot_bloch_multivector

# I like the idea of slowing down as we approach the target point.
# The estimated inner product will fluctuate from run to run, but with enough shots
# per swap test this hopefully won't be too much of a problem.

def cost(init_state, *args):
    return -1*inner_product_approx(init_state, args[0], args[1], args[2])

backend = Aer.get_backend('qasm_simulator')
initial_state = [0, 0] # CHANGEME
target_state = [uniform(0, pi), uniform(0, 2*pi)] # CHANGEME

### Brute force minimization ###

parameter_range = ((0, pi), (0, 2*pi))
approx_state_brute = brute(cost, parameter_range, (target_state, backend, 1024), Ns=50)
print("Target state was theta = " + str(target_state[0]) + ", phi = " + str(target_state[1]))
print("State minimising cost: theta = " + str(approx_state_brute[0]) + ", phi = " + str(approx_state_brute[1]))

### Differential evolution ###



### SHGO (simplicial homology global optimisation) ###

approx_state_shg = shgo(cost, parameter_range, (target_state, backend, 1024))
approx_state_shg = approx_state_shg.x
print("SHGO output: theta = " + str(approx_state_shg[0]) + ", phi = " + str(approx_state_shg[1]))

### Dual annealing ###

# approx_state_da = dual_annealing(cost, parameter_range, (target_state, backend, 1024))

### Basic minimisation ###
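Each brute or SHGO run above calls the simulator thousands of times, so it may help to tune optimizers against a cheap stand-in first. The sketch below assumes an ideal device, so the only noise is shot noise: it replaces the circuit with a binomial draw at p1 = (1 - F)/2 and runs a plain grid search, similar in spirit to `brute` with `Ns=50` but without its polishing step. All names here (`mock_inner_product_approx`, `mock_cost`, `grid_minimise`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(seed=1)

def exact_fidelity(state1, state2):
    """|<psi1|psi2>|^2 for single-qubit states given by Bloch angles (theta, phi)."""
    def ket(theta, phi):
        return np.array([np.cos(theta / 2), np.exp(1j * phi) * np.sin(theta / 2)])
    return abs(np.vdot(ket(*state1), ket(*state2))) ** 2

def mock_inner_product_approx(state1, state2, num_shots=1024):
    """Stand-in for inner_product_approx: draw the number of 1s from a binomial
    distribution with p1 = (1 - F) / 2 instead of running a circuit."""
    p1 = (1 - exact_fidelity(state1, state2)) / 2
    ones = rng.binomial(num_shots, p1)
    return 1 - 2 * ones / num_shots

def mock_cost(init_state, target_state, num_shots=1024):
    """Negated estimate, so that minimising the cost maximises the overlap."""
    return -mock_inner_product_approx(init_state, target_state, num_shots)

def grid_minimise(cost, target_state, ns=50):
    """Plain grid search over theta in [0, pi] and phi in [0, 2 pi]."""
    best, best_cost = None, np.inf
    for theta in np.linspace(0, np.pi, ns):
        for phi in np.linspace(0, 2 * np.pi, ns):
            c = cost((theta, phi), target_state)
            if c < best_cost:
                best, best_cost = (theta, phi), c
    return best

target = (1.2, 4.0)
found = grid_minimise(mock_cost, target)
print(found, exact_fidelity(found, target))
```

Because many grid points near the target see zero 1s in 1024 shots, the noisy cost has a flat plateau there; that is one reason to consider more shots, or a local refinement step, once the search is close.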


# Part 3 - Using the swap test for multiple-qubit states 

With the single-qubit case out of the way, we next turn to what ends up being a simpler problem. Consider a multi-qubit state, but take out all the complexity: a product state where all the qubits are either |0> or |1>, like |01001>. Kind of like a quantum bit string. 

If we're given a state like this and asked to match it, we don't have to deal with the optimizing procedure used above. Look at it in a qubit-by-qubit fashion, and consider the first qubit of our target state. If the first qubit of our initial state matches the target qubit, we will never measure 1 on the ancilla qubit (assuming a perfect quantum computer, but let's not go into the weeds here). If they don't match, each run of the swap test gives a 1 with probability 1/2, so we'll soon see one and can flip our qubit to match. 

That's a lot of words to say that we can just iterate through candidate initial states, using grid search until we achieve a perfect match. 
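The qubit-by-qubit idea needs only one swap test per qubit pair, that is n tests rather than the up to 2^n candidates a full grid search may visit. The sketch below simulates it classically under the assumption of an ideal device, where a single-qubit swap test between |0> and |1> reads 1 with probability 1/2 per shot and between equal basis states never does. `mock_qubit_swap_test` and `match_qubit_by_qubit` are hypothetical helpers, not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(seed=7)

def mock_qubit_swap_test(bit_a, bit_b, num_shots=64):
    """Ideal swap test on one pair of basis-state qubits.
    Returns True if any shot reads 1, i.e. a mismatch was detected."""
    p1 = 0.5 if bit_a != bit_b else 0.0
    return rng.binomial(num_shots, p1) > 0

def match_qubit_by_qubit(target_bits, num_shots=64):
    """Start from all zeros and flip each qubit whose swap test reports a mismatch."""
    guess = [0] * len(target_bits)
    for i, target_bit in enumerate(target_bits):
        if mock_qubit_swap_test(guess[i], target_bit, num_shots):
            guess[i] = 1 - guess[i]
    return tuple(guess)

print(match_qubit_by_qubit((1, 1, 0, 1)))  # (1, 1, 0, 1)
```

With 64 shots a mismatched qubit slips through only with probability 2^-64, so the printed result is the target in practice. Doing this on a real circuit would need a separate ancilla (or a separate run) per qubit pair rather than the single shared ancilla used below.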

In [None]:
### Constructing a better swap circuit ###

state_size = 5
qc = QuantumCircuit(2*state_size + 1, 1)
qc.h(0)
for i in range(1, state_size + 1):
    qc.cswap(0, i, i + state_size)
qc.h(0)
qc.measure(0, 0)
qc.draw()

This circuit performs the same function as in the single-qubit case, except we now apply a controlled-swap gate on a qubit-by-qubit basis. If you compute the measurement probabilities for the ancilla qubit, you'll find the same expression as the single-qubit swap test.
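The calculation is short. Write |ψ> and |φ> for the two n-qubit registers; applying the qubit-by-qubit controlled swaps with a common control is the same as one controlled swap of the whole registers, so the single-qubit derivation carries over unchanged:

```latex
\begin{aligned}
|0\rangle|\psi\rangle|\phi\rangle
&\xrightarrow{H} \tfrac{1}{\sqrt{2}}\big(|0\rangle + |1\rangle\big)|\psi\rangle|\phi\rangle \\
&\xrightarrow{\text{CSWAP}} \tfrac{1}{\sqrt{2}}\big(|0\rangle|\psi\rangle|\phi\rangle + |1\rangle|\phi\rangle|\psi\rangle\big) \\
&\xrightarrow{H} \tfrac{1}{2}|0\rangle\big(|\psi\rangle|\phi\rangle + |\phi\rangle|\psi\rangle\big)
 + \tfrac{1}{2}|1\rangle\big(|\psi\rangle|\phi\rangle - |\phi\rangle|\psi\rangle\big) \\[4pt]
P(1) &= \tfrac{1}{4}\big\| |\psi\rangle|\phi\rangle - |\phi\rangle|\psi\rangle \big\|^2
      = \tfrac{1}{4}\big(2 - 2\,|\langle\psi|\phi\rangle|^2\big)
      = \tfrac{1}{2} - \tfrac{1}{2}\,|\langle\psi|\phi\rangle|^2 .
\end{aligned}
```

The middle step uses <ψ|<φ| |φ>|ψ> = <ψ|φ><φ|ψ> = |<ψ|φ>|^2. For two different bit strings the states are orthogonal, so P(1) = 1/2; for identical bit strings P(1) = 0.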

Since we're dealing with a simple kind of product state, brute force is a more viable option here: we search through all possible bit strings until we find a match. If our swap test returns a 1 at least once in N runs of the circuit, we know that at least one of the qubits is incorrectly set. Two different bit strings are orthogonal states, so each run returns 1 with probability 1/2, and the chance that a mismatch returns only 0s over N runs is 2^-N. That is vanishingly small long before N reaches several thousand.
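To put numbers on "vanishingly small", the sketch below computes the chance that a wrong bit string passes N shots, a union bound on grid search returning a wrong answer across all 2^n - 1 wrong candidates, and the shots needed for a chosen error rate. It assumes an ideal, noiseless device; the function names are hypothetical.

```python
import math

def false_match_probability(num_shots):
    """Chance that two different basis states pass the swap test (every shot reads 0).
    Distinct basis states are orthogonal, so each shot reads 1 with probability 1/2."""
    return 0.5 ** num_shots

def search_false_match_bound(num_qubits, num_shots):
    """Union bound on grid search returning a wrong bit string:
    at most 2**n - 1 wrong candidates, each passing with probability 2**-N."""
    return (2 ** num_qubits - 1) * false_match_probability(num_shots)

def shots_for_false_match_rate(max_rate):
    """Smallest number of shots N with 2**-N <= max_rate."""
    return math.ceil(math.log2(1 / max_rate))

print(false_match_probability(128))      # about 2.9e-39
print(search_false_match_bound(4, 128))  # about 4.4e-38
print(shots_for_false_match_rate(1e-9))  # 30
```

By this estimate the 128 shots used below are far more than needed on a perfect device; on real hardware, gate and readout errors would dominate long before shot counts do.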

In [None]:
warnings.filterwarnings('ignore', category=DeprecationWarning)
import numpy as np
from qiskit import QuantumRegister, ClassicalRegister
from itertools import product

initial_state = (0, 0, 0, 0)
target_state  = (1, 1, 0, 1)

def bitlist2int(bit_list):
    s = map(str, bit_list)
    s = ''.join(s)
    return int(s, base=2)

# Swap test for multiple qubits.
# Returns True if any shot measures 1 on the ancilla, i.e. the states were seen to differ.
def swap_test_mq(state1, state2, backend, num_shots=1024):
    n = len(state1)
    q = QuantumRegister(2*n+1)
    c = ClassicalRegister(1)
    qc = QuantumCircuit(q, c)
    
    # One-hot amplitude vectors for the two computational basis states
    # (both registers use the same bit ordering, so the comparison is consistent)
    state1_reg = np.zeros(2**n)
    state1_reg[bitlist2int(state1)] = 1
    state2_reg = np.zeros(2**n)
    state2_reg[bitlist2int(state2)] = 1
    
    qc.initialize(state1_reg, q[1:n+1])
    qc.initialize(state2_reg, q[n+1:])
    qc.h(0)
    for i in range(1, n+1):
        qc.cswap(0, i, i+n)
    qc.h(0)
    qc.measure(0, 0)
    
    job = execute(qc, backend, shots=num_shots)
    counts = job.result().get_counts()
    return counts.get('1', 0) > 0

qasm_backend = Aer.get_backend('qasm_simulator')
num_shots = 128

def grid_search(target_state, backend, num_shots):
    ''' Try every bit string in order; return the first one the swap test cannot
    distinguish from target_state, or None if none passes. '''
    bitstrings = product([0, 1], repeat=len(target_state))
    for state in bitstrings:
        if not swap_test_mq(state, target_state, backend, num_shots):
            return state
    return None

grid_search(target_state, qasm_backend, num_shots)