*Disclaimer*: This notebook borrows from several sources, including IBM's tutorial on [quantum generative adversarial networks](https://learn.qiskit.org/course/machine-learning/quantum-generative-adversarial-networks).

# Quantum Generative Adversarial Networks

In [None]:
import numpy as np
import tensorflow as tf

from qiskit import QuantumCircuit, Aer
from qiskit.visualization import plot_histogram
from qiskit.circuit import ParameterVector
from qiskit.circuit.library import TwoLocal

## Defining our real distribution based on Bell states

In one of our first lectures we got to know the Bell state, a maximally entangled quantum state. It can be constructed by applying a Hadamard gate to the first qubit, followed by a CNOT gate controlled by that qubit. The state we want to model with our quantum generative adversarial network is therefore $$|\psi\rangle = \frac{1}{\sqrt{2}}(|00\rangle + |11\rangle).$$
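
As a quick sanity check, the construction can be reproduced with plain NumPy matrices. This is a minimal sketch that does not use Qiskit; the names `H`, `CNOT` and `bell_state` are illustrative, and the basis ordering is assumed to follow Qiskit's little-endian convention (qubit 0 is the rightmost bit).

```python
import numpy as np

# Single-qubit Hadamard and identity
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
I2 = np.eye(2)

# CNOT with control qubit 0 and target qubit 1 in little-endian ordering:
# basis index = 2 * q1 + q0, so |01> (index 1) and |11> (index 3) are swapped
CNOT = np.array([[1, 0, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0],
                 [0, 1, 0, 0]])

zero_state = np.array([1.0, 0.0, 0.0, 0.0])  # |00>

# Hadamard on qubit 0 (rightmost factor of the Kronecker product), then CNOT
bell_state = CNOT @ np.kron(I2, H) @ zero_state

# Measurement probabilities for |00>, |01>, |10>, |11>
bell_probs = np.abs(bell_state) ** 2
print(bell_probs)
```

Only $|00\rangle$ and $|11\rangle$ carry probability, each one half, which is the distribution the generator should learn.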

In [None]:
# Defining the constants we will need for the selected example
REAL_DISTRIBUTION_NQUBITS = 2  # Number of qubits needed to model real distribution
DISCRIMINATOR_NQUBITS = REAL_DISTRIBUTION_NQUBITS + 1  # Number of qubits for discriminator
GENERATOR_NLAYERS = 2  # Number of layers used for the generator
EPOCHS = 100  # Number of epochs for training

In [None]:
# Circuit creating our real distribution
real_circuit = QuantumCircuit(REAL_DISTRIBUTION_NQUBITS)
real_circuit.h(0)
real_circuit.cx(0, 1);
real_circuit.draw("mpl")

## Defining the variational quantum generator and discriminator

How to choose a suitable ansatz for the generator and the discriminator is still an open research question, so most hyperparameters are chosen heuristically.

### Variational quantum generator

For our variational quantum generator we need to ensure that the ansatz has enough capacity and expressibility to fully reproduce the real quantum state $|\psi\rangle$. In the following we therefore use an ansatz based on successive $R_Y$ and $R_Z$ gates as well as entangling $CZ$ gates. This ansatz is expressive enough to represent the Bell state we are trying to model.

For this, we can use the [`TwoLocal`](https://qiskit.org/documentation/stable/0.39/stubs/qiskit.circuit.library.TwoLocal.html) ansatz provided by Qiskit.
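
To get a feeling for the size of this ansatz, the number of trainable angles can be counted by hand. The sketch below assumes that `TwoLocal` keeps its default final rotation layer (so there are `reps + 1` rotation layers) and that the $CZ$ entanglers carry no parameters; the helper name `two_local_parameter_count` is hypothetical and not part of Qiskit.

```python
def two_local_parameter_count(num_qubits: int, num_rotation_gates: int, reps: int) -> int:
    """Count rotation angles, assuming reps + 1 rotation layers and parameter-free entanglers."""
    rotation_layers = reps + 1  # one layer before each entangling block plus a final one
    return num_qubits * num_rotation_gates * rotation_layers


# Two qubits, two rotation gates (RY and RZ) per qubit, two repetitions
n_generator_params = two_local_parameter_count(num_qubits=2, num_rotation_gates=2, reps=2)
print(n_generator_params)  # 12
```

The result should agree with `generator.num_parameters` once the circuit has been built.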

In [None]:
generator = TwoLocal(
    REAL_DISTRIBUTION_NQUBITS,
    ['ry', 'rz'],  # Parameterized single qubit rotations
    'cz',  # Entangling gate
    'full', # Entanglement structure: all to all
    reps=GENERATOR_NLAYERS, # Number of layers
    parameter_prefix='θ_g',
    name='Generator')
generator = generator.decompose() # decompose into standard gates
generator.draw("mpl")

### Variational quantum discriminator

For the discriminator we create a custom ansatz and define our own [`ParameterVector`](https://qiskit.org/documentation/stable/0.39/stubs/qiskit.circuit.ParameterVector.html?highlight=parametervector).

In [None]:
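# Three rotation angles per qubit, plus three for the final rotations on the output qubit
discriminator_weights = ParameterVector('θ_d', 3 * DISCRIMINATOR_NQUBITS + 3)
discriminator = QuantumCircuit(DISCRIMINATOR_NQUBITS, name="Discriminator")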
discriminator.barrier()
discriminator.h(0)

for i in range(DISCRIMINATOR_NQUBITS):
    discriminator.rx(discriminator_weights[3 * i + 0], i)
    discriminator.ry(discriminator_weights[3 * i + 1], i)
    discriminator.rz(discriminator_weights[3 * i + 2], i)

discriminator.cx(0, 2)
discriminator.cx(1, 2)
discriminator.rx(discriminator_weights[9], 2)
discriminator.ry(discriminator_weights[10], 2)
discriminator.rz(discriminator_weights[11], 2)
discriminator.draw("mpl")

## Building the QGAN

We now have all the components we need:

- the real distribution,
- the generator creating our fake data, and
- the discriminator.

We can now construct the two circuits already presented in the slides. The first feeds the real data into our discriminator; the second feeds the generated state into the discriminator.

In [None]:
real_discriminator_circuit = QuantumCircuit(DISCRIMINATOR_NQUBITS)
real_discriminator_circuit.compose(real_circuit, inplace=True)
real_discriminator_circuit.compose(discriminator, inplace=True)
real_discriminator_circuit.draw("mpl")

In [None]:
generator_discriminator_circuit = QuantumCircuit(DISCRIMINATOR_NQUBITS)
generator_discriminator_circuit.compose(generator, inplace=True)
generator_discriminator_circuit.compose(discriminator, inplace=True)
generator_discriminator_circuit.draw("mpl")

## Defining the cost function

Remember the minimax decision rule we have presented earlier:

$$\min_{\vec\theta_G}\max_{\vec\theta_D}\left(\Pr(D(\vec\theta_D, \mathrm{real}) = |\mathrm{real}\rangle) + \Pr(D(\vec\theta_D, G(\vec\theta_G)) = |\mathrm{fake}\rangle)\right).$$

Based on this, we can define a loss function for the discriminator and the generator. For the discriminator we get:

$$\mathrm{Cost}_D = \Pr(D(\vec\theta_D, G(\vec\theta_G)) = |\mathrm{real}\rangle) - \Pr(D(\vec\theta_D, \mathrm{real})=|\mathrm{real}\rangle).$$

Minimising $\mathrm{Cost}_D$ entails maximising the probability of correctly classifying real data while minimising the probability of wrongly classifying fake data as real.
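
The step from the minimax rule to $\mathrm{Cost}_D$ can be written out explicitly. The sketch below assumes that the discriminator output is binary, so that the probabilities of the outcomes $|\mathrm{real}\rangle$ and $|\mathrm{fake}\rangle$ sum to one:

$$\Pr(D(\vec\theta_D, G(\vec\theta_G)) = |\mathrm{fake}\rangle) = 1 - \Pr(D(\vec\theta_D, G(\vec\theta_G)) = |\mathrm{real}\rangle).$$

Substituting this into the objective that the discriminator maximises gives

$$\Pr(D(\vec\theta_D, \mathrm{real}) = |\mathrm{real}\rangle) + 1 - \Pr(D(\vec\theta_D, G(\vec\theta_G)) = |\mathrm{real}\rangle) = 1 - \mathrm{Cost}_D,$$

so maximising the objective over $\vec\theta_D$ is the same as minimising $\mathrm{Cost}_D$.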

The generator's cost function can be the negative of the discriminator's cost. However, the generator cannot influence how real data is classified, so its optimal strategy is simply to maximise the probability that the discriminator misclassifies fake data as real. Thus, the term concerning the real quantum state can be omitted:

$$\mathrm{Cost}_G=-\Pr(D(\vec\theta_D, G(\vec\theta_G))=|\mathrm{real}\rangle).$$

In [None]:
N_GPARAMS = generator.num_parameters
N_DPARAMS = discriminator.num_parameters

When implementing the cost functions above, keep Qiskit's qubit ordering in mind. We want to sum the probabilities of all states in which the last qubit (qubit 2) is one, that is, the probability of a sample being classified as $|\mathrm{real}\rangle = |1\rangle$. In the order in which the qubits appear in the circuit these are the states $|XX1\rangle$. Because Qiskit writes bitstrings in reverse order, they correspond to $|1XX\rangle$, that is, to the indices `0b100` and above in the probability vector.
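
The effect of this ordering can be checked on a toy probability vector. The following is a minimal sketch with made-up numbers; the helper `prob_qubit_is_one` is hypothetical and assumes that the basis index encodes qubit $k$ in bit $k$ (little-endian), as Qiskit does.

```python
import numpy as np


def prob_qubit_is_one(probs: np.ndarray, qubit: int) -> float:
    """Sum the probabilities of all basis states in which `qubit` is 1."""
    indices = np.arange(len(probs))
    mask = ((indices >> qubit) & 1) == 1  # select indices whose bit `qubit` is set
    return float(np.sum(probs[mask]))


# Toy distribution over the 8 basis states of 3 qubits (illustrative values)
example_probs = np.arange(8) / 28.0

# For the highest qubit, the bit mask and the slice select the same entries
p_mask = prob_qubit_is_one(example_probs, qubit=2)
p_slice = float(np.sum(example_probs[0b100:]))
print(p_mask, p_slice)
```

For any qubit other than the highest one, the selected indices are not contiguous, so the simple slice only works because the discriminator's output is read from qubit 2.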

In [None]:
# We'll use Statevector to retrieve statevector of given circuit
from qiskit.quantum_info import Statevector

def discriminator_cost(disc_params):
    """Discriminator cost function for the optimizer to minimize."""
    # .numpy() method extracts numpy array from TF tensor
    curr_params = np.append(disc_params.numpy(), gen_params.numpy())
    gendisc_probs = Statevector(
        generator_discriminator_circuit.bind_parameters(curr_params)).probabilities()
    realdisc_probs = Statevector(
        real_discriminator_circuit.bind_parameters(disc_params.numpy())).probabilities()
    # Total probability of measuring |1> on qubit 2 for generated (fake) data
    prob_fake_true = np.sum(gendisc_probs[0b100:])
    # Total probability of measuring |1> on qubit 2 for real data
    prob_real_true = np.sum(realdisc_probs[0b100:])
    cost = prob_fake_true - prob_real_true
    return cost

def generator_cost(gen_params):
    """Generator cost function for the optimizer to minimize."""
    # .numpy() method extracts numpy array from TF tensor
    curr_params = np.append(disc_params.numpy(), gen_params.numpy())
    state_probs = Statevector(
        generator_discriminator_circuit.bind_parameters(curr_params)).probabilities()
    # Get total prob of measuring |1> on q2
    prob_fake_true = np.sum(state_probs[0b100:])
    cost = -prob_fake_true
    return cost

### Evaluating the goodness of our whole model

The Kullback-Leibler (KL) divergence quantifies how much one probability distribution differs from another. We therefore define a helper function that calculates the KL divergence between the model and the target distribution, which is commonly done to track the generator's progress when training GANs. A lower KL divergence indicates that the two distributions are more similar, with a value of $0$ implying that they are identical.
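
For discrete distributions the divergence is $D_{\mathrm{KL}}(P\,\|\,Q) = \sum_x P(x)\ln\frac{P(x)}{Q(x)}$, with $P$ the target and $Q$ the model. The following minimal sketch evaluates it for the Bell-state distribution against a uniform model; the function name `kl_divergence_sketch` and the example dictionaries are illustrative, and the model is assumed to assign a non-zero probability to every bitstring in the target's support.

```python
import math


def kl_divergence_sketch(target: dict, model: dict) -> float:
    """KL divergence D(target || model) for distributions given as bitstring -> probability."""
    total = 0.0
    for bitstring, p in target.items():
        if p == 0.0:
            continue  # terms with zero target probability contribute nothing
        q = model[bitstring]
        total += p * math.log(p / q)
    return total


bell = {"00": 0.5, "01": 0.0, "10": 0.0, "11": 0.5}
uniform = {bitstring: 0.25 for bitstring in bell}

kl_uniform = kl_divergence_sketch(bell, uniform)  # equals ln(2)
kl_same = kl_divergence_sketch(bell, bell)        # equals 0
print(kl_uniform, kl_same)
```

A generator that starts out close to uniform would therefore begin at a divergence of roughly $\ln 2 \approx 0.69$ and should move towards zero during training.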

In [None]:
def calculate_kl_div(model_distribution: dict, target_distribution: dict):
    """Gauge model performance using Kullback Leibler Divergence"""
    kl_div = 0
    for bitstring, p_data in target_distribution.items():
        if np.isclose(p_data, 0, atol=1e-8):
            continue
        if bitstring in model_distribution.keys():
            kl_div += (p_data * np.log(p_data)
                 - p_data * np.log(model_distribution[bitstring]))
        else:
            kl_div += p_data * np.log(p_data) - p_data * np.log(1e-6)
    return kl_div

## Training our quantum GAN

### Compiling our models for training

For simplicity we use the [`CircuitQNN`](https://qiskit.org/documentation/machine-learning/stubs/qiskit_machine_learning.neural_networks.CircuitQNN.html), which compiles the parameterised quantum circuit and handles the calculation of gradients.

Please note that the parameters are ordered alphabetically by name internally. If you rename the parameter vectors, make sure to adapt the parameter ranges accordingly.
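
To make the gradient calculation less opaque, here is a minimal sketch of the parameter-shift rule, one standard way to obtain exact gradients of circuit probabilities. Whether `CircuitQNN` uses exactly this rule in the configuration used here is an assumption. The example is a single $R_Y(\theta)$ rotation acting on $|0\rangle$, simulated with NumPy, and the function names are illustrative.

```python
import numpy as np


def prob_one(theta: float) -> float:
    """Probability of measuring |1> after RY(theta) acts on |0>."""
    ry = np.array([[np.cos(theta / 2), -np.sin(theta / 2)],
                   [np.sin(theta / 2),  np.cos(theta / 2)]])
    state = ry @ np.array([1.0, 0.0])
    return float(np.abs(state[1]) ** 2)


def parameter_shift_gradient(theta: float) -> float:
    """Derivative of prob_one from two evaluations shifted by +/- pi/2."""
    shift = np.pi / 2
    return 0.5 * (prob_one(theta + shift) - prob_one(theta - shift))


theta = 0.7
shift_grad = parameter_shift_gradient(theta)
analytic_grad = 0.5 * np.sin(theta)  # d/dθ sin²(θ/2) = sin(θ) / 2
print(shift_grad, analytic_grad)
```

Unlike a finite-difference estimate, the two shifted evaluations reproduce the analytic derivative up to floating-point error for this kind of rotation gate.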

In [None]:
from qiskit.utils import QuantumInstance
from qiskit_machine_learning.neural_networks import CircuitQNN

# define the quantum instance (statevector simulator)
qi_sv = QuantumInstance(Aer.get_backend('aer_simulator_statevector'))

# specify QNN to update generator weights
generator_qnn = CircuitQNN(
    generator_discriminator_circuit,  # parameterized circuit
    # frozen input arguments (discriminator weights)
    generator_discriminator_circuit.parameters[:N_DPARAMS],
    # differentiable weights (generator weights)
    generator_discriminator_circuit.parameters[N_DPARAMS:],
    sparse=True, # returns sparse probability vector
    quantum_instance=qi_sv)

# specify QNNs to update discriminator weights
discriminator_fake_qnn = CircuitQNN(
    generator_discriminator_circuit, # parameterized circuit
    # frozen input arguments (generator weights)
    generator_discriminator_circuit.parameters[N_DPARAMS:],
    # differentiable weights (discriminator weights)
    generator_discriminator_circuit.parameters[:N_DPARAMS],
    sparse=True, # get sparse probability vector
    quantum_instance=qi_sv)

discriminator_real_qnn = CircuitQNN(
    real_discriminator_circuit, # parameterized circuit
    [], # no input parameters
    # differentiable weights (discriminator weights)
    generator_discriminator_circuit.parameters[:N_DPARAMS],
    sparse=True, # get sparse probability vector
    quantum_instance=qi_sv)

For training we use TensorFlow Keras and create an instance of the Adam optimiser for each of the two models, the generator and the discriminator. Adam is widely used in classical machine learning and usually outperforms vanilla gradient descent.
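
For intuition about what the optimiser does with a gradient, the textbook Adam update can be sketched in a few lines of NumPy. This is a hedged illustration rather than the Keras implementation, which may differ in details such as the handling of `eps`; the function `adam_step`, the hyperparameter values and the toy objective $f(\theta) = \theta^2$ are all assumptions made for the example.

```python
import numpy as np


def adam_step(theta, grad, m, v, t, lr=0.02, beta1=0.9, beta2=0.999, eps=1e-7):
    """One textbook Adam update; returns the new parameters and moment estimates."""
    m = beta1 * m + (1 - beta1) * grad        # running mean of the gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of the squared gradients
    m_hat = m / (1 - beta1 ** t)              # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


# Minimise the toy objective f(theta) = theta^2, whose gradient is 2 * theta
theta = np.array([1.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 501):
    grad = 2 * theta
    theta, m, v = adam_step(theta, grad, m, v, t)

print(theta)  # close to the minimum at 0
```

Because each step is normalised by the running gradient magnitude, the step size stays close to the learning rate even when the raw gradients are small, which is one reason Adam is popular for noisy objectives.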

In [None]:
import pickle # to serialize and deserialize variables
# Initialize parameters
init_gen_params = np.random.uniform(
    low=-np.pi,
    high=np.pi,
    size=(N_GPARAMS,))
init_disc_params = np.random.uniform(
    low=-np.pi,
    high=np.pi,
    size=(N_DPARAMS,))
gen_params = tf.Variable(init_gen_params)
disc_params = tf.Variable(init_disc_params)

First, let's look at our starting distribution that is generated by our generator model.

In [None]:
init_gen_circuit = generator.bind_parameters(init_gen_params)
init_prob_dict = Statevector(init_gen_circuit).probabilities_dict()

import matplotlib.pyplot as plt
fig, ax1 = plt.subplots(1, 1, sharey=True)
ax1.set_title("Initial generator distribution")
plot_histogram(init_prob_dict, ax=ax1)

In [None]:
# Initialize Adam optimizer from Keras
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.02)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.02)

In [None]:
# Initialize variables to track metrics while training
best_gen_params = tf.Variable(init_gen_params)
gloss = []
dloss = []
kl_div = []

### Training

When training classical GANs, it is not uncommon to update the two networks an unequal number of times. Here we use five discriminator updates per generator update, a ratio that was determined through trial and error.

In [None]:
DISCRIMINATOR_UPDATE_STEPS = 5 # N discriminator updates per generator update

In [None]:
TABLE_HEADERS = "Epoch | Generator cost | Discriminator cost | KL Div. |"
print(TABLE_HEADERS)
for epoch in range(EPOCHS):
    #--- Quantum discriminator parameter updates ---#
    for disc_train_step in range(DISCRIMINATOR_UPDATE_STEPS):
        # Partial derivatives wrt θ_D
        d_fake = discriminator_fake_qnn.backward(gen_params, disc_params
                                       )[1].todense()[0, 0b100:]
        d_fake = np.sum(d_fake, axis=0)
        d_real = discriminator_real_qnn.backward([], disc_params
                                       )[1].todense()[0, 0b100:]
        d_real = np.sum(d_real, axis=0)
        # Recall Cost_D structure
        grad_dcost = [d_fake[i] - d_real[i] for i in range(N_DPARAMS)]
        grad_dcost = tf.convert_to_tensor(grad_dcost)
        # Update disc params with gradient
        discriminator_optimizer.apply_gradients(zip([grad_dcost],
                                                    [disc_params]))
        # Track discriminator loss
        if disc_train_step % DISCRIMINATOR_UPDATE_STEPS == 0:
            dloss.append(discriminator_cost(disc_params))

    #--- Quantum generator parameter updates ---#
    for gen_train_step in range(1):
        # Compute partial derivatives of the probability that fake data
        # is classified as real wrt each generator weight
        grads = generator_qnn.backward(disc_params, gen_params)
        grads = grads[1].todense()[0][0b100:]
        # Recall Cost_G structure and the linearity of
        # the derivative operation
        grads = -np.sum(grads, axis=0)
        grads = tf.convert_to_tensor(grads)
        # Update gen params with gradient
        generator_optimizer.apply_gradients(zip([grads], [gen_params]))
        gloss.append(generator_cost(gen_params))

    #--- Track KL and save best performing generator weights ---#
    # Create test circuit with updated gen parameters
    gen_checkpoint_circuit = generator.bind_parameters(gen_params.numpy())
    # Retrieve probability distribution of current generator
    gen_prob_dict = Statevector(gen_checkpoint_circuit).probabilities_dict()
    # Constant real probability distribution
    real_prob_dict = Statevector(real_circuit).probabilities_dict()
    current_kl = calculate_kl_div(gen_prob_dict, real_prob_dict)
    kl_div.append(current_kl)
    if np.min(kl_div) == current_kl:
        # New best
        # serialize & deserialize to get an independent copy of the parameters
        best_gen_params = pickle.loads(pickle.dumps(gen_params))
    if epoch % 10 == 0:
        # print table every 10 epochs
        for header, val in zip(TABLE_HEADERS.split('|'),
                              (epoch, gloss[-1], dloss[-1], kl_div[-1])):
            print(f"{val:.3g} ".rjust(len(header)), end="|")
        print()

## Visualising progress and results

In [None]:
fig, (loss, kl) = plt.subplots(2, sharex=True,
                               gridspec_kw={'height_ratios': [0.75, 1]},
                               figsize=(6,4))
fig.suptitle('QGAN training stats')
fig.supxlabel('Training step')
loss.plot(range(len(gloss)), gloss, label="Generator loss")
loss.plot(range(len(dloss)), dloss, label="Discriminator loss",
          color="C3")
loss.legend()
loss.set(ylabel='Loss')
kl.plot(range(len(kl_div)), kl_div, label="KL Divergence (zero is best)",
        color="C1")
kl.set(ylabel='KL Divergence')
kl.legend()
fig.tight_layout();

In [None]:
# Create test circuit with new parameters
gen_checkpoint_circuit = generator.bind_parameters(
    best_gen_params.numpy())
gen_prob_dict = Statevector(gen_checkpoint_circuit).probabilities_dict()
real_prob_dict = Statevector(real_circuit).probabilities_dict() # constant
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,3))
plot_histogram(gen_prob_dict, ax=ax1)
ax1.set_title("Trained generator distribution")
plot_histogram(real_prob_dict, ax=ax2)
ax2.set_title("Real distribution")
fig.tight_layout()

Our model already approximates the Bell state quite well.

However, because the two models compete with each other, training can be fragile and may fail to converge. This is especially true for QGANs, for which best practices are still lacking and remain the subject of ongoing research. A common failure mode is vanishing gradients for the generator, caused by a discriminator that is too good.